""" Visualization functions for creating interactive plots and displays. """ import pandas as pd import numpy as np import plotly.express as px import plotly.graph_objects as go import streamlit as st from dimensionality_reduction import calculate_local_density_scaling from config import MESSAGE_CONTENT_PREVIEW_LENGTH, DEFAULT_POINT_SIZE, DEFAULT_POINT_OPACITY def create_hover_text(df): """Create hover text for plotly""" hover_text = [] for _, row in df.iterrows(): text = f"Author: {row['author_name']}
" text += f"Timestamp: {row['timestamp_utc']}
" text += f"Source: {row['source_file']}
" # Handle potential NaN or non-string content content = row['content'] if pd.isna(content) or content is None: content_text = "[No content]" else: content_str = str(content) content_text = content_str[:MESSAGE_CONTENT_PREVIEW_LENGTH] + ('...' if len(content_str) > MESSAGE_CONTENT_PREVIEW_LENGTH else '') text += f"Content: {content_text}" hover_text.append(text) return hover_text def calculate_point_sizes(reduced_embeddings, density_based_sizing=False, point_size=DEFAULT_POINT_SIZE, size_variation=2.0): """Calculate point sizes based on density if enabled""" if not density_based_sizing: return [point_size] * len(reduced_embeddings) local_densities = calculate_local_density_scaling(reduced_embeddings) # Invert densities so sparse areas get larger points inverted_densities = 1.0 - local_densities # Scale point sizes point_sizes = point_size * (1.0 + inverted_densities * (size_variation - 1.0)) return point_sizes def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text, point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False, cluster_names=None): """Create a plot colored by clusters""" fig = go.Figure() unique_clusters = np.unique(cluster_labels) colors = px.colors.qualitative.Set3 + px.colors.qualitative.Pastel for i, cluster_id in enumerate(unique_clusters): cluster_mask = cluster_labels == cluster_id if cluster_mask.any(): cluster_embeddings = reduced_embeddings[cluster_mask] cluster_hover = [hover_text[j] for j, mask in enumerate(cluster_mask) if mask] cluster_sizes = [point_sizes[j] for j, mask in enumerate(cluster_mask) if mask] # Use generated name if available, otherwise fall back to default if cluster_names and cluster_id in cluster_names: cluster_name = cluster_names[cluster_id] else: cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise" if enable_3d: fig.add_trace(go.Scatter3d( x=cluster_embeddings[:, 0], y=cluster_embeddings[:, 1], z=cluster_embeddings[:, 2], mode='markers', name=cluster_name, marker=dict( size=cluster_sizes, color=colors[i % len(colors)], opacity=point_opacity, line=dict(width=1, color='white') ), hovertemplate='%{hovertext}', hovertext=cluster_hover )) else: fig.add_trace(go.Scatter( x=cluster_embeddings[:, 0], y=cluster_embeddings[:, 1], mode='markers', name=cluster_name, marker=dict( size=cluster_sizes, color=colors[i % len(colors)], opacity=point_opacity, line=dict(width=1, color='white') ), hovertemplate='%{hovertext}', hovertext=cluster_hover )) return fig def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text, point_sizes, point_opacity=DEFAULT_POINT_OPACITY, enable_3d=False): """Create a plot colored by source files""" fig = go.Figure() colors = px.colors.qualitative.Set1 for i, source in enumerate(selected_sources): source_mask = filtered_df['source_file'] == source if source_mask.any(): source_embeddings = reduced_embeddings[source_mask] source_hover = [hover_text[j] for j, mask in enumerate(source_mask) if mask] source_sizes = [point_sizes[j] for j, mask in enumerate(source_mask) if mask] if enable_3d: fig.add_trace(go.Scatter3d( x=source_embeddings[:, 0], y=source_embeddings[:, 1], z=source_embeddings[:, 2], mode='markers', name=source, marker=dict( size=source_sizes, color=colors[i % len(colors)], opacity=point_opacity, line=dict(width=1, color='white') ), hovertemplate='%{hovertext}', hovertext=source_hover )) else: fig.add_trace(go.Scatter( x=source_embeddings[:, 0], y=source_embeddings[:, 1], mode='markers', name=source, marker=dict( size=source_sizes, color=colors[i % len(colors)], opacity=point_opacity, line=dict(width=1, color='white') ), hovertemplate='%{hovertext}', hovertext=source_hover )) return fig def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None, selected_sources=None, method="PCA", clustering_method="None", point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY, density_based_sizing=False, size_variation=2.0, enable_3d=False, cluster_names=None): """Create the main visualization plot""" # Create hover text hover_text = create_hover_text(filtered_df) # Calculate point sizes point_sizes = calculate_point_sizes(reduced_embeddings, density_based_sizing, point_size, size_variation) # Create plot based on coloring strategy if cluster_labels is not None: fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text, point_sizes, point_opacity, method, enable_3d, cluster_names) else: if selected_sources is None: selected_sources = filtered_df['source_file'].unique() fig = create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text, point_sizes, point_opacity, enable_3d) # Update layout title_suffix = f" with {clustering_method}" if clustering_method != "None" else "" dimension_text = "3D" if enable_3d else "2D" if enable_3d: fig.update_layout( title=f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}", scene=dict( xaxis_title=f"{method} Component 1", yaxis_title=f"{method} Component 2", zaxis_title=f"{method} Component 3" ), width=1000, height=700 ) else: fig.update_layout( title=f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}", xaxis_title=f"{method} Component 1", yaxis_title=f"{method} Component 2", hovermode='closest', width=1000, height=700 ) return fig def display_clustering_metrics(cluster_labels, silhouette_avg, calinski_harabasz, show_metrics=True): """Display clustering quality metrics""" if cluster_labels is not None and show_metrics: col1, col2, col3 = st.columns(3) with col1: n_clusters_found = len(np.unique(cluster_labels[cluster_labels != -1])) st.metric("Clusters Found", n_clusters_found) with col2: if silhouette_avg is not None: st.metric("Silhouette Score", f"{silhouette_avg:.3f}") else: st.metric("Silhouette Score", "N/A") with col3: if calinski_harabasz is not None: st.metric("Calinski-Harabasz Index", f"{calinski_harabasz:.1f}") else: st.metric("Calinski-Harabasz Index", "N/A") def display_summary_stats(filtered_df, selected_sources): """Display summary statistics""" col1, col2, col3 = st.columns(3) with col1: st.metric("Total Messages", len(filtered_df)) with col2: st.metric("Unique Authors", filtered_df['author_name'].nunique()) with col3: st.metric("Source Files", len(selected_sources)) def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method, enable_3d=False): """Display clustering results and export options""" if cluster_labels is None: return st.subheader("📊 Clustering Results") # Add cluster information to dataframe for export export_df = filtered_df.copy() export_df['cluster_id'] = cluster_labels export_df['x_coordinate'] = reduced_embeddings[:, 0] export_df['y_coordinate'] = reduced_embeddings[:, 1] # Add z coordinate if 3D if enable_3d and reduced_embeddings.shape[1] >= 3: export_df['z_coordinate'] = reduced_embeddings[:, 2] # Show cluster distribution cluster_dist = pd.Series(cluster_labels).value_counts().sort_index() st.bar_chart(cluster_dist) # Download option csv_data = export_df.to_csv(index=False) dimension_text = "3D" if enable_3d else "2D" st.download_button( label="📥 Download Clustering Results (CSV)", data=csv_data, file_name=f"chat_clusters_{method}_{clustering_method}_{dimension_text}.csv", mime="text/csv" ) def display_data_table(filtered_df, cluster_labels=None): """Display the data table with optional clustering information""" if not st.checkbox("Show Data Table"): return st.subheader("📋 Message Data") display_df = filtered_df[['timestamp_utc', 'author_name', 'source_file', 'content']].copy() # Add clustering info if available if cluster_labels is not None: display_df['cluster'] = cluster_labels display_df['content'] = display_df['content'].str[:100] + '...' # Truncate for display st.dataframe(display_df, use_container_width=True) def display_cluster_summary(cluster_names, cluster_labels): """Display a summary of cluster names and their sizes""" if not cluster_names or cluster_labels is None: return st.subheader("🏷️ Cluster Summary") # Create summary data cluster_summary = [] for cluster_id, name in cluster_names.items(): count = np.sum(cluster_labels == cluster_id) cluster_summary.append({ 'Cluster ID': cluster_id, 'Cluster Name': name, 'Message Count': count, 'Percentage': f"{100 * count / len(cluster_labels):.1f}%" }) # Sort by message count cluster_summary.sort(key=lambda x: x['Message Count'], reverse=True) # Display as table summary_df = pd.DataFrame(cluster_summary) st.dataframe(summary_df, use_container_width=True, hide_index=True)