# Files
# cult-scraper/apps/cluster_map/visualization.py
#
# 312 lines
# 12 KiB
# Python
#
"""
Visualization functions for creating interactive plots and displays.
"""
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from dimensionality_reduction import calculate_local_density_scaling
from config import MESSAGE_CONTENT_PREVIEW_LENGTH, DEFAULT_POINT_SIZE, DEFAULT_POINT_OPACITY
def create_hover_text(df):
    """Build an HTML hover string for each message row.

    Args:
        df: DataFrame with 'author_name', 'timestamp_utc', 'source_file'
            and 'content' columns.

    Returns:
        list[str]: one Plotly-ready HTML snippet per row, in row order.
    """
    hover_text = []
    for _, row in df.iterrows():
        text = f"<b>Author:</b> {row['author_name']}<br>"
        text += f"<b>Timestamp:</b> {row['timestamp_utc']}<br>"
        text += f"<b>Source:</b> {row['source_file']}<br>"
        content = row['content']
        # pd.isna already covers None (and NaN), so no separate None check is needed.
        if pd.isna(content):
            content_text = "[No content]"
        else:
            content_str = str(content)
            content_text = content_str[:MESSAGE_CONTENT_PREVIEW_LENGTH]
            # Only flag truncation when content was actually cut.
            if len(content_str) > MESSAGE_CONTENT_PREVIEW_LENGTH:
                content_text += '...'
        text += f"<b>Content:</b> {content_text}"
        hover_text.append(text)
    return hover_text
def calculate_point_sizes(reduced_embeddings, density_based_sizing=False,
                          point_size=DEFAULT_POINT_SIZE, size_variation=2.0):
    """Compute per-point marker sizes.

    With density-based sizing disabled, every point gets ``point_size``.
    Otherwise points in sparse regions are scaled up toward
    ``point_size * size_variation`` so they remain visible.
    """
    if not density_based_sizing:
        # Uniform sizing: one entry per embedding row.
        return [point_size] * len(reduced_embeddings)
    # Flip the density scale so that sparse areas (low density) get the
    # largest markers, then interpolate between 1x and size_variation x.
    sparsity = 1.0 - calculate_local_density_scaling(reduced_embeddings)
    return point_size * (1.0 + sparsity * (size_variation - 1.0))
def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text,
                          point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False,
                          cluster_names=None):
    """Create a scatter plot with one trace (and one colour) per cluster.

    Args:
        reduced_embeddings: 2D array of point coordinates (>= 3 columns when
            ``enable_3d`` is True).
        filtered_df: message dataframe (kept for interface compatibility).
        cluster_labels: per-point cluster ids; -1 marks noise points.
        hover_text: per-point HTML hover strings.
        point_sizes: per-point marker sizes.
        point_opacity: marker opacity for all traces.
        method: dimensionality-reduction method name (unused here; kept for
            interface compatibility).
        enable_3d: render go.Scatter3d traces instead of go.Scatter.
        cluster_names: optional {cluster_id: display name} mapping.

    Returns:
        go.Figure with one trace per non-empty cluster.
    """
    fig = go.Figure()
    unique_clusters = np.unique(cluster_labels)
    colors = px.colors.qualitative.Set3 + px.colors.qualitative.Pastel
    for i, cluster_id in enumerate(unique_clusters):
        cluster_mask = cluster_labels == cluster_id
        if not cluster_mask.any():
            continue
        cluster_embeddings = reduced_embeddings[cluster_mask]
        cluster_hover = [hover_text[j] for j, mask in enumerate(cluster_mask) if mask]
        cluster_sizes = [point_sizes[j] for j, mask in enumerate(cluster_mask) if mask]
        # Prefer a generated, human-readable name; fall back to a numeric label.
        if cluster_names and cluster_id in cluster_names:
            cluster_name = cluster_names[cluster_id]
        else:
            cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise"
        # 2D and 3D traces share everything except the trace class and z data,
        # so build the common keyword arguments once instead of duplicating them.
        trace_kwargs = dict(
            x=cluster_embeddings[:, 0],
            y=cluster_embeddings[:, 1],
            mode='markers',
            name=cluster_name,
            marker=dict(
                size=cluster_sizes,
                color=colors[i % len(colors)],
                opacity=point_opacity,
                line=dict(width=1, color='white')
            ),
            hovertemplate='%{hovertext}<extra></extra>',
            hovertext=cluster_hover
        )
        if enable_3d:
            fig.add_trace(go.Scatter3d(z=cluster_embeddings[:, 2], **trace_kwargs))
        else:
            fig.add_trace(go.Scatter(**trace_kwargs))
    return fig
def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text,
                               point_sizes, point_opacity=DEFAULT_POINT_OPACITY, enable_3d=False):
    """Create a scatter plot with one trace (and one colour) per source file.

    Args:
        reduced_embeddings: 2D array of point coordinates (>= 3 columns when
            ``enable_3d`` is True).
        filtered_df: message dataframe with a 'source_file' column aligned
            with ``reduced_embeddings`` rows.
        selected_sources: iterable of source-file names, one trace each.
        hover_text: per-point HTML hover strings.
        point_sizes: per-point marker sizes.
        point_opacity: marker opacity for all traces.
        enable_3d: render go.Scatter3d traces instead of go.Scatter.

    Returns:
        go.Figure with one trace per non-empty source.
    """
    fig = go.Figure()
    colors = px.colors.qualitative.Set1
    for i, source in enumerate(selected_sources):
        source_mask = filtered_df['source_file'] == source
        if not source_mask.any():
            continue
        source_embeddings = reduced_embeddings[source_mask]
        source_hover = [hover_text[j] for j, mask in enumerate(source_mask) if mask]
        source_sizes = [point_sizes[j] for j, mask in enumerate(source_mask) if mask]
        # 2D and 3D traces share everything except the trace class and z data,
        # so build the common keyword arguments once instead of duplicating them.
        trace_kwargs = dict(
            x=source_embeddings[:, 0],
            y=source_embeddings[:, 1],
            mode='markers',
            name=source,
            marker=dict(
                size=source_sizes,
                color=colors[i % len(colors)],
                opacity=point_opacity,
                line=dict(width=1, color='white')
            ),
            hovertemplate='%{hovertext}<extra></extra>',
            hovertext=source_hover
        )
        if enable_3d:
            fig.add_trace(go.Scatter3d(z=source_embeddings[:, 2], **trace_kwargs))
        else:
            fig.add_trace(go.Scatter(**trace_kwargs))
    return fig
def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None,
                              selected_sources=None, method="PCA", clustering_method="None",
                              point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY,
                              density_based_sizing=False, size_variation=2.0, enable_3d=False,
                              cluster_names=None):
    """Create the main visualization figure.

    Colours by cluster when ``cluster_labels`` is given, otherwise by source
    file (defaulting to every source present in ``filtered_df``). Applies a
    shared title/size layout for both the 2D and 3D variants.

    Returns:
        go.Figure ready for st.plotly_chart.
    """
    hover_text = create_hover_text(filtered_df)
    point_sizes = calculate_point_sizes(reduced_embeddings, density_based_sizing,
                                        point_size, size_variation)
    # Choose the colouring strategy: clusters take precedence over sources.
    if cluster_labels is not None:
        fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels,
                                    hover_text, point_sizes, point_opacity, method,
                                    enable_3d, cluster_names)
    else:
        if selected_sources is None:
            selected_sources = filtered_df['source_file'].unique()
        fig = create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources,
                                         hover_text, point_sizes, point_opacity, enable_3d)
    # Title and figure size are identical for 2D and 3D; only the axis-title
    # placement differs (3D axes live inside a 'scene' dict).
    title_suffix = f" with {clustering_method}" if clustering_method != "None" else ""
    dimension_text = "3D" if enable_3d else "2D"
    title = f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}"
    if enable_3d:
        fig.update_layout(
            title=title,
            scene=dict(
                xaxis_title=f"{method} Component 1",
                yaxis_title=f"{method} Component 2",
                zaxis_title=f"{method} Component 3"
            ),
            width=1000,
            height=700
        )
    else:
        fig.update_layout(
            title=title,
            xaxis_title=f"{method} Component 1",
            yaxis_title=f"{method} Component 2",
            hovermode='closest',
            width=1000,
            height=700
        )
    return fig
def display_clustering_metrics(cluster_labels, silhouette_avg, calinski_harabasz, show_metrics=True):
    """Render clustering quality metrics as three Streamlit metric widgets."""
    if cluster_labels is None or not show_metrics:
        return
    col1, col2, col3 = st.columns(3)
    with col1:
        # Noise points (label -1) do not count as a cluster.
        st.metric("Clusters Found", len(np.unique(cluster_labels[cluster_labels != -1])))
    with col2:
        score = "N/A" if silhouette_avg is None else f"{silhouette_avg:.3f}"
        st.metric("Silhouette Score", score)
    with col3:
        index = "N/A" if calinski_harabasz is None else f"{calinski_harabasz:.1f}"
        st.metric("Calinski-Harabasz Index", index)
def display_summary_stats(filtered_df, selected_sources):
    """Show headline counts for the current selection in three columns."""
    columns = st.columns(3)
    with columns[0]:
        st.metric("Total Messages", len(filtered_df))
    with columns[1]:
        st.metric("Unique Authors", filtered_df['author_name'].nunique())
    with columns[2]:
        st.metric("Source Files", len(selected_sources))
def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method, enable_3d=False):
    """Show the cluster size distribution and offer a CSV download of the results."""
    if cluster_labels is None:
        return
    st.subheader("📊 Clustering Results")
    # Enrich a copy of the dataframe with cluster ids and plot coordinates
    # so the export reproduces exactly what the user sees.
    export_df = filtered_df.copy()
    export_df['cluster_id'] = cluster_labels
    export_df['x_coordinate'] = reduced_embeddings[:, 0]
    export_df['y_coordinate'] = reduced_embeddings[:, 1]
    if enable_3d and reduced_embeddings.shape[1] >= 3:
        export_df['z_coordinate'] = reduced_embeddings[:, 2]
    # Bar chart of message count per cluster id, ordered by id.
    st.bar_chart(pd.Series(cluster_labels).value_counts().sort_index())
    dimension_text = "3D" if enable_3d else "2D"
    st.download_button(
        label="📥 Download Clustering Results (CSV)",
        data=export_df.to_csv(index=False),
        file_name=f"chat_clusters_{method}_{clustering_method}_{dimension_text}.csv",
        mime="text/csv"
    )
def display_data_table(filtered_df, cluster_labels=None):
    """Render the raw message table behind a "Show Data Table" checkbox.

    Args:
        filtered_df: message dataframe with 'timestamp_utc', 'author_name',
            'source_file' and 'content' columns.
        cluster_labels: optional per-row cluster ids to add as a 'cluster'
            column.
    """
    if not st.checkbox("Show Data Table"):
        return
    st.subheader("📋 Message Data")
    display_df = filtered_df[['timestamp_utc', 'author_name', 'source_file', 'content']].copy()
    if cluster_labels is not None:
        display_df['cluster'] = cluster_labels

    def _truncate(content, limit=100):
        # Missing content would otherwise propagate NaN through the string ops.
        if pd.isna(content):
            return ""
        text = str(content)
        # Only append an ellipsis when the message was actually cut.
        return text if len(text) <= limit else text[:limit] + '...'

    display_df['content'] = display_df['content'].apply(_truncate)
    st.dataframe(display_df, use_container_width=True)
def display_cluster_summary(cluster_names, cluster_labels):
    """Render a table of generated cluster names with counts and percentages."""
    if not cluster_names or cluster_labels is None:
        return
    st.subheader("🏷️ Cluster Summary")
    total = len(cluster_labels)
    rows = []
    for cluster_id, name in cluster_names.items():
        count = np.sum(cluster_labels == cluster_id)
        rows.append({
            'Cluster ID': cluster_id,
            'Cluster Name': name,
            'Message Count': count,
            'Percentage': f"{100 * count / total:.1f}%"
        })
    # Largest clusters first.
    rows.sort(key=lambda row: row['Message Count'], reverse=True)
    st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)