279 lines
11 KiB
Python
279 lines
11 KiB
Python
"""
|
|
Visualization functions for creating interactive plots and displays.
|
|
"""
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
import plotly.express as px
|
|
import plotly.graph_objects as go
|
|
import streamlit as st
|
|
from dimensionality_reduction import calculate_local_density_scaling
|
|
from config import MESSAGE_CONTENT_PREVIEW_LENGTH, DEFAULT_POINT_SIZE, DEFAULT_POINT_OPACITY
|
|
|
|
|
|
def create_hover_text(df, preview_length=None):
    """Build one HTML hover label per row of ``df`` for use in Plotly traces.

    Each label shows the author, timestamp, source file and a truncated
    preview of the message content, separated by ``<br>`` tags.

    Args:
        df: DataFrame providing the ``author_name``, ``timestamp_utc``,
            ``source_file`` and ``content`` columns.
        preview_length: Maximum number of content characters to show before
            an ellipsis is appended. When ``None`` (the default) the
            module-level ``MESSAGE_CONTENT_PREVIEW_LENGTH`` is used.

    Returns:
        A list of strings, one per row, in row order.
    """
    if preview_length is None:
        preview_length = MESSAGE_CONTENT_PREVIEW_LENGTH

    hover_text = []
    # Iterate the four needed columns in lockstep instead of materialising a
    # Series per row with ``iterrows``.
    rows = zip(df['author_name'], df['timestamp_utc'], df['source_file'], df['content'])
    for author, timestamp, source, content in rows:
        # ``pd.isna`` returns an array for list-like values, which cannot be
        # used as a condition, so only ask it about scalars.
        is_missing = content is None or (
            pd.api.types.is_scalar(content) and pd.isna(content)
        )
        if is_missing:
            content_text = "[No content]"
        else:
            content_str = str(content)
            content_text = content_str[:preview_length]
            if len(content_str) > preview_length:
                content_text += '...'

        hover_text.append(
            f"<b>Author:</b> {author}<br>"
            f"<b>Timestamp:</b> {timestamp}<br>"
            f"<b>Source:</b> {source}<br>"
            f"<b>Content:</b> {content_text}"
        )
    return hover_text
|
|
|
|
|
|
def calculate_point_sizes(reduced_embeddings, density_based_sizing=False,
                          point_size=DEFAULT_POINT_SIZE, size_variation=2.0):
    """Return a marker size for every embedded point.

    Args:
        reduced_embeddings: The embedded points; only its length is used
            here, apart from being handed to the density helper.
        density_based_sizing: When false, every point gets ``point_size``.
        point_size: Base marker size.
        size_variation: Multiplier applied to ``point_size`` where the
            density value is 0; a density value of 1 leaves the base size
            unchanged.

    Returns:
        A list of ``point_size`` repeated once per point when density sizing
        is off; otherwise the scaled sizes computed from
        ``calculate_local_density_scaling`` (presumably a NumPy array of
        values in [0, 1] -- confirm against that helper).
    """
    if density_based_sizing:
        densities = calculate_local_density_scaling(reduced_embeddings)
        # Flip the scale so that sparse neighbourhoods get the larger markers.
        sparseness = 1.0 - densities
        return point_size * (1.0 + sparseness * (size_variation - 1.0))

    return [point_size] * len(reduced_embeddings)
|
|
|
|
|
|
def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text,
                          point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False):
    """Create a scatter figure with one trace per cluster label.

    Args:
        reduced_embeddings: 2-D array-like of point coordinates; columns 0
            and 1 are plotted, plus column 2 when ``enable_3d`` is true.
        filtered_df: Not used by this function; kept so the signature stays
            compatible with existing callers.
        cluster_labels: One label per point (array, list or Series). The
            label ``-1`` is shown in the legend as "Noise".
        hover_text: Sequence of hover strings, indexed like the points.
        point_sizes: Sequence of marker sizes, indexed like the points.
        point_opacity: Marker opacity applied to every trace.
        method: Not used by this function; kept for interface compatibility.
        enable_3d: Emit ``Scatter3d`` traces instead of ``Scatter`` traces.

    Returns:
        A ``plotly.graph_objects.Figure`` holding the cluster traces.
    """
    fig = go.Figure()

    # Normalise to arrays so plain Python lists work too: ``list == scalar``
    # is a single bool, not an element-wise mask.
    labels = np.asarray(cluster_labels)
    coords = np.asarray(reduced_embeddings)

    colors = px.colors.qualitative.Set3 + px.colors.qualitative.Pastel
    trace_type = go.Scatter3d if enable_3d else go.Scatter

    for i, cluster_id in enumerate(np.unique(labels)):
        member_idx = np.flatnonzero(labels == cluster_id)
        if member_idx.size == 0:
            # Reached e.g. for a NaN label, which never compares equal to itself.
            continue

        cluster_points = coords[member_idx]
        axes = dict(x=cluster_points[:, 0], y=cluster_points[:, 1])
        if enable_3d:
            axes['z'] = cluster_points[:, 2]

        fig.add_trace(trace_type(
            mode='markers',
            name=f"Cluster {cluster_id}" if cluster_id != -1 else "Noise",
            marker=dict(
                size=[point_sizes[j] for j in member_idx],
                # Wrap around the palette when there are more clusters than colours.
                color=colors[i % len(colors)],
                opacity=point_opacity,
                line=dict(width=1, color='white')
            ),
            hovertemplate='%{hovertext}<extra></extra>',
            hovertext=[hover_text[j] for j in member_idx],
            **axes
        ))

    return fig
|
|
|
|
|
|
def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text,
                               point_sizes, point_opacity=DEFAULT_POINT_OPACITY, enable_3d=False):
    """Create a scatter figure with one trace per source file.

    Args:
        reduced_embeddings: Array of point coordinates; columns 0 and 1 are
            plotted, plus column 2 when ``enable_3d`` is true.
        filtered_df: DataFrame whose ``source_file`` column assigns each
            point to a source.
        selected_sources: Iterable of source names to plot. Sources with no
            matching rows are skipped.
        hover_text: Sequence of hover strings, indexed like the points.
        point_sizes: Sequence of marker sizes, indexed like the points.
        point_opacity: Marker opacity applied to every trace.
        enable_3d: Emit ``Scatter3d`` traces instead of ``Scatter`` traces.

    Returns:
        A ``plotly.graph_objects.Figure`` holding the per-source traces.
    """
    palette = px.colors.qualitative.Set1
    figure = go.Figure()

    for position, source in enumerate(selected_sources):
        belongs = filtered_df['source_file'] == source
        if not belongs.any():
            continue

        # Positional indices of the rows that come from this source.
        rows = [row for row, keep in enumerate(belongs) if keep]
        coords = reduced_embeddings[belongs]

        trace_args = dict(
            x=coords[:, 0],
            y=coords[:, 1],
            mode='markers',
            name=source,
            marker=dict(
                size=[point_sizes[row] for row in rows],
                # Colours repeat once the palette is exhausted.
                color=palette[position % len(palette)],
                opacity=point_opacity,
                line=dict(width=1, color='white'),
            ),
            hovertemplate='%{hovertext}<extra></extra>',
            hovertext=[hover_text[row] for row in rows],
        )

        if enable_3d:
            figure.add_trace(go.Scatter3d(z=coords[:, 2], **trace_args))
        else:
            figure.add_trace(go.Scatter(**trace_args))

    return figure
|
|
|
|
|
|
def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None,
                              selected_sources=None, method="PCA", clustering_method="None",
                              point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY,
                              density_based_sizing=False, size_variation=2.0, enable_3d=False):
    """Build the main scatter figure, coloured by cluster or by source file.

    Points are coloured by cluster when ``cluster_labels`` is given;
    otherwise by source file, defaulting to every distinct ``source_file``
    value in ``filtered_df`` when ``selected_sources`` is ``None``.

    Args:
        reduced_embeddings: Array of point coordinates to plot.
        filtered_df: DataFrame describing the plotted messages.
        cluster_labels: Optional per-point cluster labels.
        selected_sources: Optional iterable of source names to plot.
        method: Name of the reduction method; used in the title and axis labels.
        clustering_method: Name appended to the title unless it is ``"None"``.
        point_size: Base marker size.
        point_opacity: Marker opacity.
        density_based_sizing: Forwarded to ``calculate_point_sizes``.
        size_variation: Forwarded to ``calculate_point_sizes``.
        enable_3d: Produce a 3D scene instead of a 2D plot.

    Returns:
        The configured Plotly figure (1000 x 700).
    """
    labels_for_hover = create_hover_text(filtered_df)
    marker_sizes = calculate_point_sizes(
        reduced_embeddings, density_based_sizing, point_size, size_variation
    )

    # Pick the colouring strategy.
    if cluster_labels is None:
        if selected_sources is None:
            selected_sources = filtered_df['source_file'].unique()
        fig = create_source_colored_plot(
            reduced_embeddings, filtered_df, selected_sources,
            labels_for_hover, marker_sizes, point_opacity, enable_3d
        )
    else:
        fig = create_clustered_plot(
            reduced_embeddings, filtered_df, cluster_labels,
            labels_for_hover, marker_sizes, point_opacity, method, enable_3d
        )

    # Title and axis labels are shared between the 2D and 3D layouts.
    title_suffix = "" if clustering_method == "None" else f" with {clustering_method}"
    dimension_text = "3D" if enable_3d else "2D"
    title = f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}"
    x_label = f"{method} Component 1"
    y_label = f"{method} Component 2"

    if not enable_3d:
        fig.update_layout(
            title=title,
            xaxis_title=x_label,
            yaxis_title=y_label,
            hovermode='closest',
            width=1000,
            height=700,
        )
        return fig

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title=x_label,
            yaxis_title=y_label,
            zaxis_title=f"{method} Component 3",
        ),
        width=1000,
        height=700,
    )
    return fig
|
|
|
|
|
|
def display_clustering_metrics(cluster_labels, silhouette_avg, calinski_harabasz, show_metrics=True):
    """Render clustering quality metrics in three Streamlit columns.

    Nothing is rendered when ``cluster_labels`` is ``None`` or
    ``show_metrics`` is false.

    Args:
        cluster_labels: Per-point cluster labels (array, list or Series).
            The label ``-1`` is excluded from the cluster count.
        silhouette_avg: Silhouette score, or ``None`` to show "N/A".
        calinski_harabasz: Calinski-Harabasz index, or ``None`` to show "N/A".
        show_metrics: Set to false to suppress the display.
    """
    if cluster_labels is None or not show_metrics:
        return

    # Normalise first: on a plain list, ``labels != -1`` is a single bool and
    # indexing with it would silently pick one element instead of filtering.
    labels = np.asarray(cluster_labels)
    n_clusters_found = len(np.unique(labels[labels != -1]))

    silhouette_text = "N/A" if silhouette_avg is None else f"{silhouette_avg:.3f}"
    calinski_text = "N/A" if calinski_harabasz is None else f"{calinski_harabasz:.1f}"

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Clusters Found", n_clusters_found)
    with col2:
        st.metric("Silhouette Score", silhouette_text)
    with col3:
        st.metric("Calinski-Harabasz Index", calinski_text)
|
|
|
|
|
|
def display_summary_stats(filtered_df, selected_sources):
    """Render message, author and source-file counts as three Streamlit metrics.

    Args:
        filtered_df: DataFrame of messages; its length and the number of
            distinct ``author_name`` values are displayed.
        selected_sources: Sized collection of source names; its length is
            displayed.
    """
    columns = st.columns(3)
    metrics = (
        ("Total Messages", len(filtered_df)),
        ("Unique Authors", filtered_df['author_name'].nunique()),
        ("Source Files", len(selected_sources)),
    )

    # One metric per column, left to right.
    for column, (label, value) in zip(columns, metrics):
        with column:
            st.metric(label, value)
|
|
|
|
|
|
def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method, enable_3d=False):
    """Show the cluster size distribution and offer the results as a CSV download.

    Does nothing when ``cluster_labels`` is ``None``.

    Args:
        filtered_df: DataFrame of messages; it is copied, not modified.
        cluster_labels: Per-row cluster labels, exported as ``cluster_id``.
        reduced_embeddings: Array of coordinates; columns 0 and 1 are
            exported as ``x_coordinate`` / ``y_coordinate``, and column 2 as
            ``z_coordinate`` when ``enable_3d`` is true and it exists.
        method: Reduction method name; used in the download file name.
        clustering_method: Clustering method name; used in the download
            file name.
        enable_3d: Whether the visualization is 3D.
    """
    if cluster_labels is None:
        return

    st.subheader("📊 Clustering Results")

    # Work on a copy so the caller's frame is left untouched.
    export_df = filtered_df.assign(
        cluster_id=cluster_labels,
        x_coordinate=reduced_embeddings[:, 0],
        y_coordinate=reduced_embeddings[:, 1],
    )
    if enable_3d and reduced_embeddings.shape[1] >= 3:
        export_df['z_coordinate'] = reduced_embeddings[:, 2]

    # Number of points per cluster, ordered by cluster label.
    label_counts = pd.Series(cluster_labels).value_counts()
    st.bar_chart(label_counts.sort_index())

    dimension_text = "3D" if enable_3d else "2D"
    download_name = f"chat_clusters_{method}_{clustering_method}_{dimension_text}.csv"
    st.download_button(
        label="📥 Download Clustering Results (CSV)",
        data=export_df.to_csv(index=False),
        file_name=download_name,
        mime="text/csv",
    )
|
|
|
|
|
|
def display_data_table(filtered_df, cluster_labels=None, max_content_length=100):
    """Show the message table behind a "Show Data Table" checkbox.

    Nothing is rendered unless the checkbox is ticked.

    Args:
        filtered_df: DataFrame providing the ``timestamp_utc``,
            ``author_name``, ``source_file`` and ``content`` columns.
        cluster_labels: Optional per-row cluster labels, shown in an extra
            ``cluster`` column.
        max_content_length: Maximum number of content characters shown per
            row; longer content is cut and suffixed with ``'...'``.
    """
    if not st.checkbox("Show Data Table"):
        return

    st.subheader("📋 Message Data")
    display_df = filtered_df[['timestamp_utc', 'author_name', 'source_file', 'content']].copy()

    # Add clustering info if available
    if cluster_labels is not None:
        display_df['cluster'] = cluster_labels

    def _preview(value):
        # Leave missing values as they are so the table shows them as empty.
        if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
            return value
        text = str(value)
        # Only mark content as truncated when something was actually cut off.
        if len(text) > max_content_length:
            return text[:max_content_length] + '...'
        return text

    display_df['content'] = display_df['content'].map(_preview)
    st.dataframe(display_df, use_container_width=True)
|