This commit is contained in:
2025-08-11 02:49:41 +01:00
parent 4ca7e8ab61
commit 647111e9d3
3 changed files with 111 additions and 47 deletions

View File

@@ -67,10 +67,12 @@ def main():
st.info(f"📈 Visualizing {len(filtered_df)} messages") st.info(f"📈 Visualizing {len(filtered_df)} messages")
# Reduce dimensions # Reduce dimensions
n_components = 3 if params['enable_3d'] else 2
with st.spinner(f"Reducing dimensions using {params['method']}..."): with st.spinner(f"Reducing dimensions using {params['method']}..."):
reduced_embeddings = reduce_dimensions( reduced_embeddings = reduce_dimensions(
filtered_embeddings, filtered_embeddings,
method=params['method'], method=params['method'],
n_components=n_components,
spread_factor=params['spread_factor'], spread_factor=params['spread_factor'],
perplexity_factor=params['perplexity_factor'], perplexity_factor=params['perplexity_factor'],
min_dist_factor=params['min_dist_factor'] min_dist_factor=params['min_dist_factor']
@@ -110,7 +112,8 @@ def main():
point_size=params['point_size'], point_size=params['point_size'],
point_opacity=params['point_opacity'], point_opacity=params['point_opacity'],
density_based_sizing=params['density_based_sizing'], density_based_sizing=params['density_based_sizing'],
size_variation=params['size_variation'] size_variation=params['size_variation'],
enable_3d=params['enable_3d']
) )
st.plotly_chart(fig, use_container_width=True) st.plotly_chart(fig, use_container_width=True)
@@ -121,7 +124,7 @@ def main():
# Display clustering results and export options # Display clustering results and export options
display_clustering_results( display_clustering_results(
filtered_df, cluster_labels, reduced_embeddings, filtered_df, cluster_labels, reduced_embeddings,
params['method'], params['clustering_method'] params['method'], params['clustering_method'], params['enable_3d']
) )
# Display data table # Display data table

View File

@@ -30,6 +30,13 @@ def create_method_controls():
"""Create controls for dimension reduction and clustering methods""" """Create controls for dimension reduction and clustering methods"""
st.sidebar.header("🎛️ Visualization Controls") st.sidebar.header("🎛️ Visualization Controls")
# 3D visualization toggle
enable_3d = st.sidebar.checkbox(
"Enable 3D Visualization",
value=False,
help="Switch between 2D and 3D visualization. 3D uses 3 components instead of 2."
)
# Dimension reduction method # Dimension reduction method
method = st.sidebar.selectbox( method = st.sidebar.selectbox(
"Dimension Reduction Method", "Dimension Reduction Method",
@@ -45,7 +52,7 @@ def create_method_controls():
help="Apply clustering to identify groups. HDBSCAN and OPTICS can find variable density clusters." help="Apply clustering to identify groups. HDBSCAN and OPTICS can find variable density clusters."
) )
return method, clustering_method return method, clustering_method, enable_3d
def create_clustering_controls(clustering_method): def create_clustering_controls(clustering_method):
@@ -196,7 +203,7 @@ def display_performance_warnings(filtered_df, method, clustering_method):
def get_all_ui_parameters(valid_df): def get_all_ui_parameters(valid_df):
"""Get all UI parameters in a single function call""" """Get all UI parameters in a single function call"""
# Method selection # Method selection
method, clustering_method = create_method_controls() method, clustering_method, enable_3d = create_method_controls()
# Clustering parameters # Clustering parameters
n_clusters = create_clustering_controls(clustering_method) n_clusters = create_clustering_controls(clustering_method)
@@ -219,6 +226,7 @@ def get_all_ui_parameters(valid_df):
return { return {
'method': method, 'method': method,
'clustering_method': clustering_method, 'clustering_method': clustering_method,
'enable_3d': enable_3d,
'n_clusters': n_clusters, 'n_clusters': n_clusters,
'spread_factor': spread_factor, 'spread_factor': spread_factor,
'perplexity_factor': perplexity_factor, 'perplexity_factor': perplexity_factor,

View File

@@ -47,7 +47,7 @@ def calculate_point_sizes(reduced_embeddings, density_based_sizing=False,
def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text, def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text,
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA"): point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False):
"""Create a plot colored by clusters""" """Create a plot colored by clusters"""
fig = go.Figure() fig = go.Figure()
@@ -63,6 +63,23 @@ def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover
cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise" cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise"
if enable_3d:
fig.add_trace(go.Scatter3d(
x=cluster_embeddings[:, 0],
y=cluster_embeddings[:, 1],
z=cluster_embeddings[:, 2],
mode='markers',
name=cluster_name,
marker=dict(
size=cluster_sizes,
color=colors[i % len(colors)],
opacity=point_opacity,
line=dict(width=1, color='white')
),
hovertemplate='%{hovertext}<extra></extra>',
hovertext=cluster_hover
))
else:
fig.add_trace(go.Scatter( fig.add_trace(go.Scatter(
x=cluster_embeddings[:, 0], x=cluster_embeddings[:, 0],
y=cluster_embeddings[:, 1], y=cluster_embeddings[:, 1],
@@ -82,7 +99,7 @@ def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover
def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text, def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text,
point_sizes, point_opacity=DEFAULT_POINT_OPACITY): point_sizes, point_opacity=DEFAULT_POINT_OPACITY, enable_3d=False):
"""Create a plot colored by source files""" """Create a plot colored by source files"""
fig = go.Figure() fig = go.Figure()
colors = px.colors.qualitative.Set1 colors = px.colors.qualitative.Set1
@@ -94,6 +111,23 @@ def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources
source_hover = [hover_text[j] for j, mask in enumerate(source_mask) if mask] source_hover = [hover_text[j] for j, mask in enumerate(source_mask) if mask]
source_sizes = [point_sizes[j] for j, mask in enumerate(source_mask) if mask] source_sizes = [point_sizes[j] for j, mask in enumerate(source_mask) if mask]
if enable_3d:
fig.add_trace(go.Scatter3d(
x=source_embeddings[:, 0],
y=source_embeddings[:, 1],
z=source_embeddings[:, 2],
mode='markers',
name=source,
marker=dict(
size=source_sizes,
color=colors[i % len(colors)],
opacity=point_opacity,
line=dict(width=1, color='white')
),
hovertemplate='%{hovertext}<extra></extra>',
hovertext=source_hover
))
else:
fig.add_trace(go.Scatter( fig.add_trace(go.Scatter(
x=source_embeddings[:, 0], x=source_embeddings[:, 0],
y=source_embeddings[:, 1], y=source_embeddings[:, 1],
@@ -115,7 +149,7 @@ def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources
def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None, def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None,
selected_sources=None, method="PCA", clustering_method="None", selected_sources=None, method="PCA", clustering_method="None",
point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY, point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY,
density_based_sizing=False, size_variation=2.0): density_based_sizing=False, size_variation=2.0, enable_3d=False):
"""Create the main visualization plot""" """Create the main visualization plot"""
# Create hover text # Create hover text
@@ -128,17 +162,31 @@ def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=No
# Create plot based on coloring strategy # Create plot based on coloring strategy
if cluster_labels is not None: if cluster_labels is not None:
fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels,
hover_text, point_sizes, point_opacity, method) hover_text, point_sizes, point_opacity, method, enable_3d)
else: else:
if selected_sources is None: if selected_sources is None:
selected_sources = filtered_df['source_file'].unique() selected_sources = filtered_df['source_file'].unique()
fig = create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, fig = create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources,
hover_text, point_sizes, point_opacity) hover_text, point_sizes, point_opacity, enable_3d)
# Update layout # Update layout
title_suffix = f" with {clustering_method}" if clustering_method != "None" else "" title_suffix = f" with {clustering_method}" if clustering_method != "None" else ""
dimension_text = "3D" if enable_3d else "2D"
if enable_3d:
fig.update_layout( fig.update_layout(
title=f"Discord Chat Messages - {method} Visualization{title_suffix}", title=f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}",
scene=dict(
xaxis_title=f"{method} Component 1",
yaxis_title=f"{method} Component 2",
zaxis_title=f"{method} Component 3"
),
width=1000,
height=700
)
else:
fig.update_layout(
title=f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}",
xaxis_title=f"{method} Component 1", xaxis_title=f"{method} Component 1",
yaxis_title=f"{method} Component 2", yaxis_title=f"{method} Component 2",
hovermode='closest', hovermode='closest',
@@ -182,7 +230,7 @@ def display_summary_stats(filtered_df, selected_sources):
st.metric("Source Files", len(selected_sources)) st.metric("Source Files", len(selected_sources))
def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method): def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method, enable_3d=False):
"""Display clustering results and export options""" """Display clustering results and export options"""
if cluster_labels is None: if cluster_labels is None:
return return
@@ -195,16 +243,21 @@ def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings,
export_df['x_coordinate'] = reduced_embeddings[:, 0] export_df['x_coordinate'] = reduced_embeddings[:, 0]
export_df['y_coordinate'] = reduced_embeddings[:, 1] export_df['y_coordinate'] = reduced_embeddings[:, 1]
# Add z coordinate if 3D
if enable_3d and reduced_embeddings.shape[1] >= 3:
export_df['z_coordinate'] = reduced_embeddings[:, 2]
# Show cluster distribution # Show cluster distribution
cluster_dist = pd.Series(cluster_labels).value_counts().sort_index() cluster_dist = pd.Series(cluster_labels).value_counts().sort_index()
st.bar_chart(cluster_dist) st.bar_chart(cluster_dist)
# Download option # Download option
csv_data = export_df.to_csv(index=False) csv_data = export_df.to_csv(index=False)
dimension_text = "3D" if enable_3d else "2D"
st.download_button( st.download_button(
label="📥 Download Clustering Results (CSV)", label="📥 Download Clustering Results (CSV)",
data=csv_data, data=csv_data,
file_name=f"chat_clusters_{method}_{clustering_method}.csv", file_name=f"chat_clusters_{method}_{clustering_method}_{dimension_text}.csv",
mime="text/csv" mime="text/csv"
) )