refactor
This commit is contained in:
225
apps/cluster_map/visualization.py
Normal file
225
apps/cluster_map/visualization.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Visualization functions for creating interactive plots and displays.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
import streamlit as st
|
||||
from dimensionality_reduction import calculate_local_density_scaling
|
||||
from config import MESSAGE_CONTENT_PREVIEW_LENGTH, DEFAULT_POINT_SIZE, DEFAULT_POINT_OPACITY
|
||||
|
||||
|
||||
def create_hover_text(df):
|
||||
"""Create hover text for plotly"""
|
||||
hover_text = []
|
||||
for _, row in df.iterrows():
|
||||
text = f"<b>Author:</b> {row['author_name']}<br>"
|
||||
text += f"<b>Timestamp:</b> {row['timestamp_utc']}<br>"
|
||||
text += f"<b>Source:</b> {row['source_file']}<br>"
|
||||
|
||||
# Handle potential NaN or non-string content
|
||||
content = row['content']
|
||||
if pd.isna(content) or content is None:
|
||||
content_text = "[No content]"
|
||||
else:
|
||||
content_str = str(content)
|
||||
content_text = content_str[:MESSAGE_CONTENT_PREVIEW_LENGTH] + ('...' if len(content_str) > MESSAGE_CONTENT_PREVIEW_LENGTH else '')
|
||||
|
||||
text += f"<b>Content:</b> {content_text}"
|
||||
hover_text.append(text)
|
||||
return hover_text
|
||||
|
||||
|
||||
def calculate_point_sizes(reduced_embeddings, density_based_sizing=False,
|
||||
point_size=DEFAULT_POINT_SIZE, size_variation=2.0):
|
||||
"""Calculate point sizes based on density if enabled"""
|
||||
if not density_based_sizing:
|
||||
return [point_size] * len(reduced_embeddings)
|
||||
|
||||
local_densities = calculate_local_density_scaling(reduced_embeddings)
|
||||
# Invert densities so sparse areas get larger points
|
||||
inverted_densities = 1.0 - local_densities
|
||||
# Scale point sizes
|
||||
point_sizes = point_size * (1.0 + inverted_densities * (size_variation - 1.0))
|
||||
return point_sizes
|
||||
|
||||
|
||||
def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text,
|
||||
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA"):
|
||||
"""Create a plot colored by clusters"""
|
||||
fig = go.Figure()
|
||||
|
||||
unique_clusters = np.unique(cluster_labels)
|
||||
colors = px.colors.qualitative.Set3 + px.colors.qualitative.Pastel
|
||||
|
||||
for i, cluster_id in enumerate(unique_clusters):
|
||||
cluster_mask = cluster_labels == cluster_id
|
||||
if cluster_mask.any():
|
||||
cluster_embeddings = reduced_embeddings[cluster_mask]
|
||||
cluster_hover = [hover_text[j] for j, mask in enumerate(cluster_mask) if mask]
|
||||
cluster_sizes = [point_sizes[j] for j, mask in enumerate(cluster_mask) if mask]
|
||||
|
||||
cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise"
|
||||
|
||||
fig.add_trace(go.Scatter(
|
||||
x=cluster_embeddings[:, 0],
|
||||
y=cluster_embeddings[:, 1],
|
||||
mode='markers',
|
||||
name=cluster_name,
|
||||
marker=dict(
|
||||
size=cluster_sizes,
|
||||
color=colors[i % len(colors)],
|
||||
opacity=point_opacity,
|
||||
line=dict(width=1, color='white')
|
||||
),
|
||||
hovertemplate='%{hovertext}<extra></extra>',
|
||||
hovertext=cluster_hover
|
||||
))
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text,
|
||||
point_sizes, point_opacity=DEFAULT_POINT_OPACITY):
|
||||
"""Create a plot colored by source files"""
|
||||
fig = go.Figure()
|
||||
colors = px.colors.qualitative.Set1
|
||||
|
||||
for i, source in enumerate(selected_sources):
|
||||
source_mask = filtered_df['source_file'] == source
|
||||
if source_mask.any():
|
||||
source_embeddings = reduced_embeddings[source_mask]
|
||||
source_hover = [hover_text[j] for j, mask in enumerate(source_mask) if mask]
|
||||
source_sizes = [point_sizes[j] for j, mask in enumerate(source_mask) if mask]
|
||||
|
||||
fig.add_trace(go.Scatter(
|
||||
x=source_embeddings[:, 0],
|
||||
y=source_embeddings[:, 1],
|
||||
mode='markers',
|
||||
name=source,
|
||||
marker=dict(
|
||||
size=source_sizes,
|
||||
color=colors[i % len(colors)],
|
||||
opacity=point_opacity,
|
||||
line=dict(width=1, color='white')
|
||||
),
|
||||
hovertemplate='%{hovertext}<extra></extra>',
|
||||
hovertext=source_hover
|
||||
))
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None,
|
||||
selected_sources=None, method="PCA", clustering_method="None",
|
||||
point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY,
|
||||
density_based_sizing=False, size_variation=2.0):
|
||||
"""Create the main visualization plot"""
|
||||
|
||||
# Create hover text
|
||||
hover_text = create_hover_text(filtered_df)
|
||||
|
||||
# Calculate point sizes
|
||||
point_sizes = calculate_point_sizes(reduced_embeddings, density_based_sizing,
|
||||
point_size, size_variation)
|
||||
|
||||
# Create plot based on coloring strategy
|
||||
if cluster_labels is not None:
|
||||
fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels,
|
||||
hover_text, point_sizes, point_opacity, method)
|
||||
else:
|
||||
if selected_sources is None:
|
||||
selected_sources = filtered_df['source_file'].unique()
|
||||
fig = create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources,
|
||||
hover_text, point_sizes, point_opacity)
|
||||
|
||||
# Update layout
|
||||
title_suffix = f" with {clustering_method}" if clustering_method != "None" else ""
|
||||
fig.update_layout(
|
||||
title=f"Discord Chat Messages - {method} Visualization{title_suffix}",
|
||||
xaxis_title=f"{method} Component 1",
|
||||
yaxis_title=f"{method} Component 2",
|
||||
hovermode='closest',
|
||||
width=1000,
|
||||
height=700
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def display_clustering_metrics(cluster_labels, silhouette_avg, calinski_harabasz, show_metrics=True):
|
||||
"""Display clustering quality metrics"""
|
||||
if cluster_labels is not None and show_metrics:
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
n_clusters_found = len(np.unique(cluster_labels[cluster_labels != -1]))
|
||||
st.metric("Clusters Found", n_clusters_found)
|
||||
with col2:
|
||||
if silhouette_avg is not None:
|
||||
st.metric("Silhouette Score", f"{silhouette_avg:.3f}")
|
||||
else:
|
||||
st.metric("Silhouette Score", "N/A")
|
||||
with col3:
|
||||
if calinski_harabasz is not None:
|
||||
st.metric("Calinski-Harabasz Index", f"{calinski_harabasz:.1f}")
|
||||
else:
|
||||
st.metric("Calinski-Harabasz Index", "N/A")
|
||||
|
||||
|
||||
def display_summary_stats(filtered_df, selected_sources):
|
||||
"""Display summary statistics"""
|
||||
col1, col2, col3 = st.columns(3)
|
||||
|
||||
with col1:
|
||||
st.metric("Total Messages", len(filtered_df))
|
||||
|
||||
with col2:
|
||||
st.metric("Unique Authors", filtered_df['author_name'].nunique())
|
||||
|
||||
with col3:
|
||||
st.metric("Source Files", len(selected_sources))
|
||||
|
||||
|
||||
def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method):
|
||||
"""Display clustering results and export options"""
|
||||
if cluster_labels is None:
|
||||
return
|
||||
|
||||
st.subheader("📊 Clustering Results")
|
||||
|
||||
# Add cluster information to dataframe for export
|
||||
export_df = filtered_df.copy()
|
||||
export_df['cluster_id'] = cluster_labels
|
||||
export_df['x_coordinate'] = reduced_embeddings[:, 0]
|
||||
export_df['y_coordinate'] = reduced_embeddings[:, 1]
|
||||
|
||||
# Show cluster distribution
|
||||
cluster_dist = pd.Series(cluster_labels).value_counts().sort_index()
|
||||
st.bar_chart(cluster_dist)
|
||||
|
||||
# Download option
|
||||
csv_data = export_df.to_csv(index=False)
|
||||
st.download_button(
|
||||
label="📥 Download Clustering Results (CSV)",
|
||||
data=csv_data,
|
||||
file_name=f"chat_clusters_{method}_{clustering_method}.csv",
|
||||
mime="text/csv"
|
||||
)
|
||||
|
||||
|
||||
def display_data_table(filtered_df, cluster_labels=None):
|
||||
"""Display the data table with optional clustering information"""
|
||||
if not st.checkbox("Show Data Table"):
|
||||
return
|
||||
|
||||
st.subheader("📋 Message Data")
|
||||
display_df = filtered_df[['timestamp_utc', 'author_name', 'source_file', 'content']].copy()
|
||||
|
||||
# Add clustering info if available
|
||||
if cluster_labels is not None:
|
||||
display_df['cluster'] = cluster_labels
|
||||
|
||||
display_df['content'] = display_df['content'].str[:100] + '...' # Truncate for display
|
||||
st.dataframe(display_df, use_container_width=True)
|
||||
Reference in New Issue
Block a user