From 647111e9d33680677f8f26f7c6acfef054a54d31 Mon Sep 17 00:00:00 2001 From: Azeem Fidahusein Date: Mon, 11 Aug 2025 02:49:41 +0100 Subject: [PATCH] 3d viz --- apps/cluster_map/main.py | 7 +- apps/cluster_map/ui_components.py | 12 ++- apps/cluster_map/visualization.py | 139 +++++++++++++++++++++--------- 3 files changed, 111 insertions(+), 47 deletions(-) diff --git a/apps/cluster_map/main.py b/apps/cluster_map/main.py index 395e6f9..a8f6d9b 100644 --- a/apps/cluster_map/main.py +++ b/apps/cluster_map/main.py @@ -67,10 +67,12 @@ def main(): st.info(f"📈 Visualizing {len(filtered_df)} messages") # Reduce dimensions + n_components = 3 if params['enable_3d'] else 2 with st.spinner(f"Reducing dimensions using {params['method']}..."): reduced_embeddings = reduce_dimensions( filtered_embeddings, method=params['method'], + n_components=n_components, spread_factor=params['spread_factor'], perplexity_factor=params['perplexity_factor'], min_dist_factor=params['min_dist_factor'] @@ -110,7 +112,8 @@ def main(): point_size=params['point_size'], point_opacity=params['point_opacity'], density_based_sizing=params['density_based_sizing'], - size_variation=params['size_variation'] + size_variation=params['size_variation'], + enable_3d=params['enable_3d'] ) st.plotly_chart(fig, use_container_width=True) @@ -121,7 +124,7 @@ def main(): # Display clustering results and export options display_clustering_results( filtered_df, cluster_labels, reduced_embeddings, - params['method'], params['clustering_method'] + params['method'], params['clustering_method'], params['enable_3d'] ) # Display data table diff --git a/apps/cluster_map/ui_components.py b/apps/cluster_map/ui_components.py index 83b7944..a02c831 100644 --- a/apps/cluster_map/ui_components.py +++ b/apps/cluster_map/ui_components.py @@ -30,6 +30,13 @@ def create_method_controls(): """Create controls for dimension reduction and clustering methods""" st.sidebar.header("🎛️ Visualization Controls") + # 3D visualization toggle + enable_3d = st.sidebar.checkbox( + "Enable 3D Visualization", + value=False, + help="Switch between 2D and 3D visualization. 3D uses 3 components instead of 2." + ) + # Dimension reduction method method = st.sidebar.selectbox( "Dimension Reduction Method", @@ -45,7 +52,7 @@ def create_method_controls(): help="Apply clustering to identify groups. HDBSCAN and OPTICS can find variable density clusters." ) - return method, clustering_method + return method, clustering_method, enable_3d def create_clustering_controls(clustering_method): @@ -196,7 +203,7 @@ def display_performance_warnings(filtered_df, method, clustering_method): def get_all_ui_parameters(valid_df): """Get all UI parameters in a single function call""" # Method selection - method, clustering_method = create_method_controls() + method, clustering_method, enable_3d = create_method_controls() # Clustering parameters n_clusters = create_clustering_controls(clustering_method) @@ -219,6 +226,7 @@ def get_all_ui_parameters(valid_df): return { 'method': method, 'clustering_method': clustering_method, + 'enable_3d': enable_3d, 'n_clusters': n_clusters, 'spread_factor': spread_factor, 'perplexity_factor': perplexity_factor, diff --git a/apps/cluster_map/visualization.py b/apps/cluster_map/visualization.py index 93f38d4..66d2e2d 100644 --- a/apps/cluster_map/visualization.py +++ b/apps/cluster_map/visualization.py @@ -47,7 +47,7 @@ def calculate_point_sizes(reduced_embeddings, density_based_sizing=False, def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text, - point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA"): + point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False): """Create a plot colored by clusters""" fig = go.Figure() @@ -63,26 +63,43 @@ def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise" - fig.add_trace(go.Scatter( - x=cluster_embeddings[:, 0], - y=cluster_embeddings[:, 1], - mode='markers', - name=cluster_name, - marker=dict( - size=cluster_sizes, - color=colors[i % len(colors)], - opacity=point_opacity, - line=dict(width=1, color='white') - ), - hovertemplate='%{hovertext}', - hovertext=cluster_hover - )) + if enable_3d: + fig.add_trace(go.Scatter3d( + x=cluster_embeddings[:, 0], + y=cluster_embeddings[:, 1], + z=cluster_embeddings[:, 2], + mode='markers', + name=cluster_name, + marker=dict( + size=cluster_sizes, + color=colors[i % len(colors)], + opacity=point_opacity, + line=dict(width=1, color='white') + ), + hovertemplate='%{hovertext}', + hovertext=cluster_hover + )) + else: + fig.add_trace(go.Scatter( + x=cluster_embeddings[:, 0], + y=cluster_embeddings[:, 1], + mode='markers', + name=cluster_name, + marker=dict( + size=cluster_sizes, + color=colors[i % len(colors)], + opacity=point_opacity, + line=dict(width=1, color='white') + ), + hovertemplate='%{hovertext}', + hovertext=cluster_hover + )) return fig def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text, - point_sizes, point_opacity=DEFAULT_POINT_OPACITY): + point_sizes, point_opacity=DEFAULT_POINT_OPACITY, enable_3d=False): """Create a plot colored by source files""" fig = go.Figure() colors = px.colors.qualitative.Set1 @@ -94,20 +111,37 @@ def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources source_hover = [hover_text[j] for j, mask in enumerate(source_mask) if mask] source_sizes = [point_sizes[j] for j, mask in enumerate(source_mask) if mask] - fig.add_trace(go.Scatter( - x=source_embeddings[:, 0], - y=source_embeddings[:, 1], - mode='markers', - name=source, - marker=dict( - size=source_sizes, - color=colors[i % len(colors)], - opacity=point_opacity, - line=dict(width=1, color='white') - ), - hovertemplate='%{hovertext}', - hovertext=source_hover - )) + if enable_3d: + fig.add_trace(go.Scatter3d( + x=source_embeddings[:, 0], + y=source_embeddings[:, 1], + z=source_embeddings[:, 2], + mode='markers', + name=source, + marker=dict( + size=source_sizes, + color=colors[i % len(colors)], + opacity=point_opacity, + line=dict(width=1, color='white') + ), + hovertemplate='%{hovertext}', + hovertext=source_hover + )) + else: + fig.add_trace(go.Scatter( + x=source_embeddings[:, 0], + y=source_embeddings[:, 1], + mode='markers', + name=source, + marker=dict( + size=source_sizes, + color=colors[i % len(colors)], + opacity=point_opacity, + line=dict(width=1, color='white') + ), + hovertemplate='%{hovertext}', + hovertext=source_hover + )) return fig @@ -115,7 +149,7 @@ def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None, selected_sources=None, method="PCA", clustering_method="None", point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY, - density_based_sizing=False, size_variation=2.0): + density_based_sizing=False, size_variation=2.0, enable_3d=False): """Create the main visualization plot""" # Create hover text @@ -128,23 +162,37 @@ def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=No # Create plot based on coloring strategy if cluster_labels is not None: fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, - hover_text, point_sizes, point_opacity, method) + hover_text, point_sizes, point_opacity, method, enable_3d) else: if selected_sources is None: selected_sources = filtered_df['source_file'].unique() fig = create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, - hover_text, point_sizes, point_opacity) + hover_text, point_sizes, point_opacity, enable_3d) # Update layout title_suffix = f" with {clustering_method}" if clustering_method != "None" else "" - fig.update_layout( - title=f"Discord Chat Messages - {method} Visualization{title_suffix}", - xaxis_title=f"{method} Component 1", - yaxis_title=f"{method} Component 2", - hovermode='closest', - width=1000, - height=700 - ) + dimension_text = "3D" if enable_3d else "2D" + + if enable_3d: + fig.update_layout( + title=f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}", + scene=dict( + xaxis_title=f"{method} Component 1", + yaxis_title=f"{method} Component 2", + zaxis_title=f"{method} Component 3" + ), + width=1000, + height=700 + ) + else: + fig.update_layout( + title=f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}", + xaxis_title=f"{method} Component 1", + yaxis_title=f"{method} Component 2", + hovermode='closest', + width=1000, + height=700 + ) return fig @@ -182,7 +230,7 @@ def display_summary_stats(filtered_df, selected_sources): st.metric("Source Files", len(selected_sources)) -def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method): +def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method, enable_3d=False): """Display clustering results and export options""" if cluster_labels is None: return @@ -195,16 +243,21 @@ def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, export_df['x_coordinate'] = reduced_embeddings[:, 0] export_df['y_coordinate'] = reduced_embeddings[:, 1] + # Add z coordinate if 3D + if enable_3d and reduced_embeddings.shape[1] >= 3: + export_df['z_coordinate'] = reduced_embeddings[:, 2] + # Show cluster distribution cluster_dist = pd.Series(cluster_labels).value_counts().sort_index() st.bar_chart(cluster_dist) # Download option csv_data = export_df.to_csv(index=False) + dimension_text = "3D" if enable_3d else "2D" st.download_button( label="📥 Download Clustering Results (CSV)", data=csv_data, - file_name=f"chat_clusters_{method}_{clustering_method}.csv", + file_name=f"chat_clusters_{method}_{clustering_method}_{dimension_text}.csv", mime="text/csv" )