3d viz
This commit is contained in:
@@ -67,10 +67,12 @@ def main():
|
|||||||
st.info(f"📈 Visualizing {len(filtered_df)} messages")
|
st.info(f"📈 Visualizing {len(filtered_df)} messages")
|
||||||
|
|
||||||
# Reduce dimensions
|
# Reduce dimensions
|
||||||
|
n_components = 3 if params['enable_3d'] else 2
|
||||||
with st.spinner(f"Reducing dimensions using {params['method']}..."):
|
with st.spinner(f"Reducing dimensions using {params['method']}..."):
|
||||||
reduced_embeddings = reduce_dimensions(
|
reduced_embeddings = reduce_dimensions(
|
||||||
filtered_embeddings,
|
filtered_embeddings,
|
||||||
method=params['method'],
|
method=params['method'],
|
||||||
|
n_components=n_components,
|
||||||
spread_factor=params['spread_factor'],
|
spread_factor=params['spread_factor'],
|
||||||
perplexity_factor=params['perplexity_factor'],
|
perplexity_factor=params['perplexity_factor'],
|
||||||
min_dist_factor=params['min_dist_factor']
|
min_dist_factor=params['min_dist_factor']
|
||||||
@@ -110,7 +112,8 @@ def main():
|
|||||||
point_size=params['point_size'],
|
point_size=params['point_size'],
|
||||||
point_opacity=params['point_opacity'],
|
point_opacity=params['point_opacity'],
|
||||||
density_based_sizing=params['density_based_sizing'],
|
density_based_sizing=params['density_based_sizing'],
|
||||||
size_variation=params['size_variation']
|
size_variation=params['size_variation'],
|
||||||
|
enable_3d=params['enable_3d']
|
||||||
)
|
)
|
||||||
|
|
||||||
st.plotly_chart(fig, use_container_width=True)
|
st.plotly_chart(fig, use_container_width=True)
|
||||||
@@ -121,7 +124,7 @@ def main():
|
|||||||
# Display clustering results and export options
|
# Display clustering results and export options
|
||||||
display_clustering_results(
|
display_clustering_results(
|
||||||
filtered_df, cluster_labels, reduced_embeddings,
|
filtered_df, cluster_labels, reduced_embeddings,
|
||||||
params['method'], params['clustering_method']
|
params['method'], params['clustering_method'], params['enable_3d']
|
||||||
)
|
)
|
||||||
|
|
||||||
# Display data table
|
# Display data table
|
||||||
|
|||||||
@@ -30,6 +30,13 @@ def create_method_controls():
|
|||||||
"""Create controls for dimension reduction and clustering methods"""
|
"""Create controls for dimension reduction and clustering methods"""
|
||||||
st.sidebar.header("🎛️ Visualization Controls")
|
st.sidebar.header("🎛️ Visualization Controls")
|
||||||
|
|
||||||
|
# 3D visualization toggle
|
||||||
|
enable_3d = st.sidebar.checkbox(
|
||||||
|
"Enable 3D Visualization",
|
||||||
|
value=False,
|
||||||
|
help="Switch between 2D and 3D visualization. 3D uses 3 components instead of 2."
|
||||||
|
)
|
||||||
|
|
||||||
# Dimension reduction method
|
# Dimension reduction method
|
||||||
method = st.sidebar.selectbox(
|
method = st.sidebar.selectbox(
|
||||||
"Dimension Reduction Method",
|
"Dimension Reduction Method",
|
||||||
@@ -45,7 +52,7 @@ def create_method_controls():
|
|||||||
help="Apply clustering to identify groups. HDBSCAN and OPTICS can find variable density clusters."
|
help="Apply clustering to identify groups. HDBSCAN and OPTICS can find variable density clusters."
|
||||||
)
|
)
|
||||||
|
|
||||||
return method, clustering_method
|
return method, clustering_method, enable_3d
|
||||||
|
|
||||||
|
|
||||||
def create_clustering_controls(clustering_method):
|
def create_clustering_controls(clustering_method):
|
||||||
@@ -196,7 +203,7 @@ def display_performance_warnings(filtered_df, method, clustering_method):
|
|||||||
def get_all_ui_parameters(valid_df):
|
def get_all_ui_parameters(valid_df):
|
||||||
"""Get all UI parameters in a single function call"""
|
"""Get all UI parameters in a single function call"""
|
||||||
# Method selection
|
# Method selection
|
||||||
method, clustering_method = create_method_controls()
|
method, clustering_method, enable_3d = create_method_controls()
|
||||||
|
|
||||||
# Clustering parameters
|
# Clustering parameters
|
||||||
n_clusters = create_clustering_controls(clustering_method)
|
n_clusters = create_clustering_controls(clustering_method)
|
||||||
@@ -219,6 +226,7 @@ def get_all_ui_parameters(valid_df):
|
|||||||
return {
|
return {
|
||||||
'method': method,
|
'method': method,
|
||||||
'clustering_method': clustering_method,
|
'clustering_method': clustering_method,
|
||||||
|
'enable_3d': enable_3d,
|
||||||
'n_clusters': n_clusters,
|
'n_clusters': n_clusters,
|
||||||
'spread_factor': spread_factor,
|
'spread_factor': spread_factor,
|
||||||
'perplexity_factor': perplexity_factor,
|
'perplexity_factor': perplexity_factor,
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def calculate_point_sizes(reduced_embeddings, density_based_sizing=False,
|
|||||||
|
|
||||||
|
|
||||||
def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text,
|
def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text,
|
||||||
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA"):
|
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False):
|
||||||
"""Create a plot colored by clusters"""
|
"""Create a plot colored by clusters"""
|
||||||
fig = go.Figure()
|
fig = go.Figure()
|
||||||
|
|
||||||
@@ -63,26 +63,43 @@ def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover
|
|||||||
|
|
||||||
cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise"
|
cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise"
|
||||||
|
|
||||||
fig.add_trace(go.Scatter(
|
if enable_3d:
|
||||||
x=cluster_embeddings[:, 0],
|
fig.add_trace(go.Scatter3d(
|
||||||
y=cluster_embeddings[:, 1],
|
x=cluster_embeddings[:, 0],
|
||||||
mode='markers',
|
y=cluster_embeddings[:, 1],
|
||||||
name=cluster_name,
|
z=cluster_embeddings[:, 2],
|
||||||
marker=dict(
|
mode='markers',
|
||||||
size=cluster_sizes,
|
name=cluster_name,
|
||||||
color=colors[i % len(colors)],
|
marker=dict(
|
||||||
opacity=point_opacity,
|
size=cluster_sizes,
|
||||||
line=dict(width=1, color='white')
|
color=colors[i % len(colors)],
|
||||||
),
|
opacity=point_opacity,
|
||||||
hovertemplate='%{hovertext}<extra></extra>',
|
line=dict(width=1, color='white')
|
||||||
hovertext=cluster_hover
|
),
|
||||||
))
|
hovertemplate='%{hovertext}<extra></extra>',
|
||||||
|
hovertext=cluster_hover
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
fig.add_trace(go.Scatter(
|
||||||
|
x=cluster_embeddings[:, 0],
|
||||||
|
y=cluster_embeddings[:, 1],
|
||||||
|
mode='markers',
|
||||||
|
name=cluster_name,
|
||||||
|
marker=dict(
|
||||||
|
size=cluster_sizes,
|
||||||
|
color=colors[i % len(colors)],
|
||||||
|
opacity=point_opacity,
|
||||||
|
line=dict(width=1, color='white')
|
||||||
|
),
|
||||||
|
hovertemplate='%{hovertext}<extra></extra>',
|
||||||
|
hovertext=cluster_hover
|
||||||
|
))
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text,
|
def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text,
|
||||||
point_sizes, point_opacity=DEFAULT_POINT_OPACITY):
|
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, enable_3d=False):
|
||||||
"""Create a plot colored by source files"""
|
"""Create a plot colored by source files"""
|
||||||
fig = go.Figure()
|
fig = go.Figure()
|
||||||
colors = px.colors.qualitative.Set1
|
colors = px.colors.qualitative.Set1
|
||||||
@@ -94,20 +111,37 @@ def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources
|
|||||||
source_hover = [hover_text[j] for j, mask in enumerate(source_mask) if mask]
|
source_hover = [hover_text[j] for j, mask in enumerate(source_mask) if mask]
|
||||||
source_sizes = [point_sizes[j] for j, mask in enumerate(source_mask) if mask]
|
source_sizes = [point_sizes[j] for j, mask in enumerate(source_mask) if mask]
|
||||||
|
|
||||||
fig.add_trace(go.Scatter(
|
if enable_3d:
|
||||||
x=source_embeddings[:, 0],
|
fig.add_trace(go.Scatter3d(
|
||||||
y=source_embeddings[:, 1],
|
x=source_embeddings[:, 0],
|
||||||
mode='markers',
|
y=source_embeddings[:, 1],
|
||||||
name=source,
|
z=source_embeddings[:, 2],
|
||||||
marker=dict(
|
mode='markers',
|
||||||
size=source_sizes,
|
name=source,
|
||||||
color=colors[i % len(colors)],
|
marker=dict(
|
||||||
opacity=point_opacity,
|
size=source_sizes,
|
||||||
line=dict(width=1, color='white')
|
color=colors[i % len(colors)],
|
||||||
),
|
opacity=point_opacity,
|
||||||
hovertemplate='%{hovertext}<extra></extra>',
|
line=dict(width=1, color='white')
|
||||||
hovertext=source_hover
|
),
|
||||||
))
|
hovertemplate='%{hovertext}<extra></extra>',
|
||||||
|
hovertext=source_hover
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
fig.add_trace(go.Scatter(
|
||||||
|
x=source_embeddings[:, 0],
|
||||||
|
y=source_embeddings[:, 1],
|
||||||
|
mode='markers',
|
||||||
|
name=source,
|
||||||
|
marker=dict(
|
||||||
|
size=source_sizes,
|
||||||
|
color=colors[i % len(colors)],
|
||||||
|
opacity=point_opacity,
|
||||||
|
line=dict(width=1, color='white')
|
||||||
|
),
|
||||||
|
hovertemplate='%{hovertext}<extra></extra>',
|
||||||
|
hovertext=source_hover
|
||||||
|
))
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
@@ -115,7 +149,7 @@ def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources
|
|||||||
def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None,
|
def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None,
|
||||||
selected_sources=None, method="PCA", clustering_method="None",
|
selected_sources=None, method="PCA", clustering_method="None",
|
||||||
point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY,
|
point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY,
|
||||||
density_based_sizing=False, size_variation=2.0):
|
density_based_sizing=False, size_variation=2.0, enable_3d=False):
|
||||||
"""Create the main visualization plot"""
|
"""Create the main visualization plot"""
|
||||||
|
|
||||||
# Create hover text
|
# Create hover text
|
||||||
@@ -128,23 +162,37 @@ def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=No
|
|||||||
# Create plot based on coloring strategy
|
# Create plot based on coloring strategy
|
||||||
if cluster_labels is not None:
|
if cluster_labels is not None:
|
||||||
fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels,
|
fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels,
|
||||||
hover_text, point_sizes, point_opacity, method)
|
hover_text, point_sizes, point_opacity, method, enable_3d)
|
||||||
else:
|
else:
|
||||||
if selected_sources is None:
|
if selected_sources is None:
|
||||||
selected_sources = filtered_df['source_file'].unique()
|
selected_sources = filtered_df['source_file'].unique()
|
||||||
fig = create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources,
|
fig = create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources,
|
||||||
hover_text, point_sizes, point_opacity)
|
hover_text, point_sizes, point_opacity, enable_3d)
|
||||||
|
|
||||||
# Update layout
|
# Update layout
|
||||||
title_suffix = f" with {clustering_method}" if clustering_method != "None" else ""
|
title_suffix = f" with {clustering_method}" if clustering_method != "None" else ""
|
||||||
fig.update_layout(
|
dimension_text = "3D" if enable_3d else "2D"
|
||||||
title=f"Discord Chat Messages - {method} Visualization{title_suffix}",
|
|
||||||
xaxis_title=f"{method} Component 1",
|
if enable_3d:
|
||||||
yaxis_title=f"{method} Component 2",
|
fig.update_layout(
|
||||||
hovermode='closest',
|
title=f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}",
|
||||||
width=1000,
|
scene=dict(
|
||||||
height=700
|
xaxis_title=f"{method} Component 1",
|
||||||
)
|
yaxis_title=f"{method} Component 2",
|
||||||
|
zaxis_title=f"{method} Component 3"
|
||||||
|
),
|
||||||
|
width=1000,
|
||||||
|
height=700
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
fig.update_layout(
|
||||||
|
title=f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}",
|
||||||
|
xaxis_title=f"{method} Component 1",
|
||||||
|
yaxis_title=f"{method} Component 2",
|
||||||
|
hovermode='closest',
|
||||||
|
width=1000,
|
||||||
|
height=700
|
||||||
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
@@ -182,7 +230,7 @@ def display_summary_stats(filtered_df, selected_sources):
|
|||||||
st.metric("Source Files", len(selected_sources))
|
st.metric("Source Files", len(selected_sources))
|
||||||
|
|
||||||
|
|
||||||
def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method):
|
def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method, enable_3d=False):
|
||||||
"""Display clustering results and export options"""
|
"""Display clustering results and export options"""
|
||||||
if cluster_labels is None:
|
if cluster_labels is None:
|
||||||
return
|
return
|
||||||
@@ -195,16 +243,21 @@ def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings,
|
|||||||
export_df['x_coordinate'] = reduced_embeddings[:, 0]
|
export_df['x_coordinate'] = reduced_embeddings[:, 0]
|
||||||
export_df['y_coordinate'] = reduced_embeddings[:, 1]
|
export_df['y_coordinate'] = reduced_embeddings[:, 1]
|
||||||
|
|
||||||
|
# Add z coordinate if 3D
|
||||||
|
if enable_3d and reduced_embeddings.shape[1] >= 3:
|
||||||
|
export_df['z_coordinate'] = reduced_embeddings[:, 2]
|
||||||
|
|
||||||
# Show cluster distribution
|
# Show cluster distribution
|
||||||
cluster_dist = pd.Series(cluster_labels).value_counts().sort_index()
|
cluster_dist = pd.Series(cluster_labels).value_counts().sort_index()
|
||||||
st.bar_chart(cluster_dist)
|
st.bar_chart(cluster_dist)
|
||||||
|
|
||||||
# Download option
|
# Download option
|
||||||
csv_data = export_df.to_csv(index=False)
|
csv_data = export_df.to_csv(index=False)
|
||||||
|
dimension_text = "3D" if enable_3d else "2D"
|
||||||
st.download_button(
|
st.download_button(
|
||||||
label="📥 Download Clustering Results (CSV)",
|
label="📥 Download Clustering Results (CSV)",
|
||||||
data=csv_data,
|
data=csv_data,
|
||||||
file_name=f"chat_clusters_{method}_{clustering_method}.csv",
|
file_name=f"chat_clusters_{method}_{clustering_method}_{dimension_text}.csv",
|
||||||
mime="text/csv"
|
mime="text/csv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user