"""
Visualization functions for creating interactive plots and displays.
"""
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from dimensionality_reduction import calculate_local_density_scaling
from config import MESSAGE_CONTENT_PREVIEW_LENGTH, DEFAULT_POINT_SIZE, DEFAULT_POINT_OPACITY
def create_hover_text(df):
    """Build one plotly hover label per row of ``df``.

    Each label shows the author, timestamp, source file and a content preview
    of at most ``MESSAGE_CONTENT_PREVIEW_LENGTH`` characters (followed by
    ``...`` when truncated). Fields are joined with ``<br>`` because plotly
    renders hover-label line breaks from that HTML tag.

    Args:
        df: DataFrame with ``author_name``, ``timestamp_utc``, ``source_file``
            and ``content`` columns.

    Returns:
        list[str]: hover labels, positionally aligned with the rows of ``df``.
    """
    def _preview(content):
        # Missing content (None/NaN) gets a placeholder instead of the text "nan".
        if content is None or pd.isna(content):
            return "[No content]"
        text = str(content)
        if len(text) > MESSAGE_CONTENT_PREVIEW_LENGTH:
            return text[:MESSAGE_CONTENT_PREVIEW_LENGTH] + '...'
        return text

    # Iterating the columns directly avoids building a Series per row (iterrows).
    rows = zip(df['author_name'], df['timestamp_utc'], df['source_file'], df['content'])
    return [
        f"Author: {author}<br>"
        f"Timestamp: {timestamp}<br>"
        f"Source: {source}<br>"
        f"Content: {_preview(content)}"
        for author, timestamp, source, content in rows
    ]
def calculate_point_sizes(reduced_embeddings, density_based_sizing=False,
                          point_size=DEFAULT_POINT_SIZE, size_variation=2.0):
    """Return the marker size to use for each embedded point.

    Without density-based sizing, every point gets ``point_size`` (returned as
    a plain list). With it, each point's size is
    ``point_size * (1 + (1 - density) * (size_variation - 1))``, where
    ``density`` comes from ``calculate_local_density_scaling``, so points in
    sparser regions are drawn larger. This presumably assumes densities are
    normalized to [0, 1]; TODO confirm against the helper.
    """
    if density_based_sizing:
        sparseness = 1.0 - calculate_local_density_scaling(reduced_embeddings)
        growth = size_variation - 1.0
        return point_size * (1.0 + sparseness * growth)
    return [point_size] * len(reduced_embeddings)
def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text,
                          point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False):
    """Create a scatter figure with one trace per cluster.

    Args:
        reduced_embeddings: array-like of projected coordinates, one row per
            point. Column 2 is used as z when ``enable_3d`` is true.
        filtered_df: unused here; kept for interface compatibility.
        cluster_labels: array-like of per-point cluster ids. The id ``-1`` is
            shown as "Noise".
        hover_text: per-point hover strings, positionally aligned with the rows.
        point_sizes: per-point marker sizes, positionally aligned with the rows.
        point_opacity: marker opacity applied to every trace.
        method: unused here; kept for interface compatibility.
        enable_3d: draw ``Scatter3d`` traces instead of 2-D ``Scatter`` traces.

    Returns:
        plotly.graph_objects.Figure: the figure with one trace per cluster.
    """
    # Normalize inputs so list-typed labels/sizes work and can be fancy-indexed.
    embeddings = np.asarray(reduced_embeddings)
    labels = np.asarray(cluster_labels)
    sizes = np.asarray(point_sizes)
    colors = px.colors.qualitative.Set3 + px.colors.qualitative.Pastel
    trace_cls = go.Scatter3d if enable_3d else go.Scatter

    fig = go.Figure()
    for i, cluster_id in enumerate(np.unique(labels)):
        # Positional indices of this cluster's points (non-empty by construction).
        idx = np.flatnonzero(labels == cluster_id)
        coords = {'x': embeddings[idx, 0], 'y': embeddings[idx, 1]}
        if enable_3d:
            coords['z'] = embeddings[idx, 2]
        fig.add_trace(trace_cls(
            **coords,
            mode='markers',
            name=f"Cluster {cluster_id}" if cluster_id != -1 else "Noise",
            marker=dict(
                size=sizes[idx],
                color=colors[i % len(colors)],
                opacity=point_opacity,
                line=dict(width=1, color='white')
            ),
            hovertemplate='%{hovertext}',
            hovertext=[hover_text[j] for j in idx]
        ))
    return fig
def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text,
                               point_sizes, point_opacity=DEFAULT_POINT_OPACITY, enable_3d=False):
    """Create a scatter figure with one trace per source file.

    A source's colour depends on its position in ``selected_sources``. Sources
    with no rows in ``filtered_df`` get no trace but still consume a colour
    slot. Rows whose source is not in ``selected_sources`` are not plotted.

    Args:
        reduced_embeddings: array-like of projected coordinates, positionally
            aligned with ``filtered_df``. Column 2 is used as z when
            ``enable_3d`` is true.
        filtered_df: DataFrame with a ``source_file`` column.
        selected_sources: iterable of source-file values to draw, in colour order.
        hover_text: per-point hover strings, positionally aligned with the rows.
        point_sizes: per-point marker sizes, positionally aligned with the rows.
        point_opacity: marker opacity applied to every trace.
        enable_3d: draw ``Scatter3d`` traces instead of 2-D ``Scatter`` traces.

    Returns:
        plotly.graph_objects.Figure: the figure with one trace per non-empty source.
    """
    embeddings = np.asarray(reduced_embeddings)
    sizes = np.asarray(point_sizes)
    source_column = filtered_df['source_file']
    colors = px.colors.qualitative.Set1
    trace_cls = go.Scatter3d if enable_3d else go.Scatter

    fig = go.Figure()
    for i, source in enumerate(selected_sources):
        # na_value=False keeps nullable string columns (with <NA>) from
        # producing an unusable boolean mask.
        mask = (source_column == source).to_numpy(dtype=bool, na_value=False)
        idx = np.flatnonzero(mask)
        if idx.size == 0:
            continue
        coords = {'x': embeddings[idx, 0], 'y': embeddings[idx, 1]}
        if enable_3d:
            coords['z'] = embeddings[idx, 2]
        fig.add_trace(trace_cls(
            **coords,
            mode='markers',
            name=source,
            marker=dict(
                size=sizes[idx],
                color=colors[i % len(colors)],
                opacity=point_opacity,
                line=dict(width=1, color='white')
            ),
            hovertemplate='%{hovertext}',
            hovertext=[hover_text[j] for j in idx]
        ))
    return fig
def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None,
                              selected_sources=None, method="PCA", clustering_method="None",
                              point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY,
                              density_based_sizing=False, size_variation=2.0, enable_3d=False):
    """Build the main embedding scatter figure.

    Points are coloured by cluster when ``cluster_labels`` is given. Otherwise
    they are coloured by source file; when ``selected_sources`` is None, every
    source present in ``filtered_df`` is used. The figure then gets a title,
    axis titles (on the 3-D scene or the 2-D axes) and a fixed 1000x700 size.
    In 2-D, ``hovermode`` is set to 'closest'.
    """
    labels_text = create_hover_text(filtered_df)
    sizes = calculate_point_sizes(reduced_embeddings, density_based_sizing,
                                  point_size, size_variation)

    if cluster_labels is None:
        sources = filtered_df['source_file'].unique() if selected_sources is None else selected_sources
        fig = create_source_colored_plot(reduced_embeddings, filtered_df, sources,
                                         labels_text, sizes, point_opacity, enable_3d)
    else:
        fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels,
                                    labels_text, sizes, point_opacity, method, enable_3d)

    title_suffix = "" if clustering_method == "None" else f" with {clustering_method}"
    dimension_text = "3D" if enable_3d else "2D"
    layout = {
        "title": f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}",
        "width": 1000,
        "height": 700,
    }
    if enable_3d:
        layout["scene"] = dict(
            xaxis_title=f"{method} Component 1",
            yaxis_title=f"{method} Component 2",
            zaxis_title=f"{method} Component 3"
        )
    else:
        layout["xaxis_title"] = f"{method} Component 1"
        layout["yaxis_title"] = f"{method} Component 2"
        layout["hovermode"] = 'closest'
    fig.update_layout(**layout)
    return fig
def display_clustering_metrics(cluster_labels, silhouette_avg, calinski_harabasz, show_metrics=True):
    """Render clustering quality metrics in three Streamlit metric columns.

    Nothing is rendered when ``cluster_labels`` is None or ``show_metrics`` is
    false.

    Args:
        cluster_labels: array-like of per-point cluster ids. Points labelled
            ``-1`` are excluded from the "Clusters Found" count.
        silhouette_avg: silhouette score, or None to show "N/A".
        calinski_harabasz: Calinski-Harabasz index, or None to show "N/A".
        show_metrics: set to False to suppress the display.
    """
    if cluster_labels is None or not show_metrics:
        return
    # asarray makes the `!= -1` mask elementwise even when a plain list is passed.
    labels = np.asarray(cluster_labels)
    n_clusters_found = len(np.unique(labels[labels != -1]))

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Clusters Found", n_clusters_found)
    with col2:
        st.metric("Silhouette Score",
                  "N/A" if silhouette_avg is None else f"{silhouette_avg:.3f}")
    with col3:
        st.metric("Calinski-Harabasz Index",
                  "N/A" if calinski_harabasz is None else f"{calinski_harabasz:.1f}")
def display_summary_stats(filtered_df, selected_sources):
    """Show message, author and source-file counts in three Streamlit metric columns.

    Args:
        filtered_df: messages with an ``author_name`` column, plus a
            ``source_file`` column when ``selected_sources`` is None.
        selected_sources: collection of selected source files. When None, the
            number of distinct ``source_file`` values in ``filtered_df`` is
            shown instead of raising TypeError on ``len(None)``.
    """
    if selected_sources is None:
        n_sources = filtered_df['source_file'].nunique()
    else:
        n_sources = len(selected_sources)

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Messages", len(filtered_df))
    with col2:
        st.metric("Unique Authors", filtered_df['author_name'].nunique())
    with col3:
        st.metric("Source Files", n_sources)
def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method, enable_3d=False):
    """Show a per-cluster size chart and offer the clustering as a CSV download.

    Does nothing when ``cluster_labels`` is None. Otherwise the exported CSV
    contains a copy of ``filtered_df`` with ``cluster_id``, ``x_coordinate``
    and ``y_coordinate`` columns added. In 3-D mode, when the embedding has at
    least three columns, a ``z_coordinate`` column is added as well.
    """
    if cluster_labels is None:
        return

    st.subheader("📊 Clustering Results")

    # Columns appended to the export, in this order.
    extra_columns = {
        'cluster_id': cluster_labels,
        'x_coordinate': reduced_embeddings[:, 0],
        'y_coordinate': reduced_embeddings[:, 1],
    }
    if enable_3d and reduced_embeddings.shape[1] >= 3:
        extra_columns['z_coordinate'] = reduced_embeddings[:, 2]
    export_df = filtered_df.assign(**extra_columns)

    # Bar per cluster id (including -1 noise), ordered by id.
    st.bar_chart(pd.Series(cluster_labels).value_counts().sort_index())

    dims = "3D" if enable_3d else "2D"
    st.download_button(
        label="📥 Download Clustering Results (CSV)",
        data=export_df.to_csv(index=False),
        file_name=f"chat_clusters_{method}_{clustering_method}_{dims}.csv",
        mime="text/csv"
    )
def display_data_table(filtered_df, cluster_labels=None, preview_length=100):
    """Render the filtered messages as a table when the user opts in.

    The table appears only after the "Show Data Table" checkbox is ticked.
    String content longer than ``preview_length`` characters is cut to that
    length and suffixed with ``...``. Shorter strings and non-string cells
    (e.g. NaN) are shown unchanged.

    Args:
        filtered_df: messages with ``timestamp_utc``, ``author_name``,
            ``source_file`` and ``content`` columns.
        cluster_labels: optional per-row cluster ids, added as a ``cluster`` column.
        preview_length: maximum number of content characters shown per row.
    """
    if not st.checkbox("Show Data Table"):
        return
    st.subheader("📋 Message Data")
    display_df = filtered_df[['timestamp_utc', 'author_name', 'source_file', 'content']].copy()
    # Add clustering info if available
    if cluster_labels is not None:
        display_df['cluster'] = cluster_labels

    def _truncate(value):
        # Only append the ellipsis when text was actually cut; leave
        # non-string cells alone rather than coercing them to NaN.
        if isinstance(value, str) and len(value) > preview_length:
            return value[:preview_length] + '...'
        return value

    display_df['content'] = display_df['content'].map(_truncate)
    st.dataframe(display_df, use_container_width=True)