3d viz
This commit is contained in:
@@ -47,7 +47,7 @@ def calculate_point_sizes(reduced_embeddings, density_based_sizing=False,
|
||||
|
||||
|
||||
def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text,
|
||||
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA"):
|
||||
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False):
|
||||
"""Create a plot colored by clusters"""
|
||||
fig = go.Figure()
|
||||
|
||||
@@ -63,26 +63,43 @@ def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover
|
||||
|
||||
cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise"
|
||||
|
||||
fig.add_trace(go.Scatter(
|
||||
x=cluster_embeddings[:, 0],
|
||||
y=cluster_embeddings[:, 1],
|
||||
mode='markers',
|
||||
name=cluster_name,
|
||||
marker=dict(
|
||||
size=cluster_sizes,
|
||||
color=colors[i % len(colors)],
|
||||
opacity=point_opacity,
|
||||
line=dict(width=1, color='white')
|
||||
),
|
||||
hovertemplate='%{hovertext}<extra></extra>',
|
||||
hovertext=cluster_hover
|
||||
))
|
||||
if enable_3d:
|
||||
fig.add_trace(go.Scatter3d(
|
||||
x=cluster_embeddings[:, 0],
|
||||
y=cluster_embeddings[:, 1],
|
||||
z=cluster_embeddings[:, 2],
|
||||
mode='markers',
|
||||
name=cluster_name,
|
||||
marker=dict(
|
||||
size=cluster_sizes,
|
||||
color=colors[i % len(colors)],
|
||||
opacity=point_opacity,
|
||||
line=dict(width=1, color='white')
|
||||
),
|
||||
hovertemplate='%{hovertext}<extra></extra>',
|
||||
hovertext=cluster_hover
|
||||
))
|
||||
else:
|
||||
fig.add_trace(go.Scatter(
|
||||
x=cluster_embeddings[:, 0],
|
||||
y=cluster_embeddings[:, 1],
|
||||
mode='markers',
|
||||
name=cluster_name,
|
||||
marker=dict(
|
||||
size=cluster_sizes,
|
||||
color=colors[i % len(colors)],
|
||||
opacity=point_opacity,
|
||||
line=dict(width=1, color='white')
|
||||
),
|
||||
hovertemplate='%{hovertext}<extra></extra>',
|
||||
hovertext=cluster_hover
|
||||
))
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources, hover_text,
|
||||
point_sizes, point_opacity=DEFAULT_POINT_OPACITY):
|
||||
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, enable_3d=False):
|
||||
"""Create a plot colored by source files"""
|
||||
fig = go.Figure()
|
||||
colors = px.colors.qualitative.Set1
|
||||
@@ -94,20 +111,37 @@ def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources
|
||||
source_hover = [hover_text[j] for j, mask in enumerate(source_mask) if mask]
|
||||
source_sizes = [point_sizes[j] for j, mask in enumerate(source_mask) if mask]
|
||||
|
||||
fig.add_trace(go.Scatter(
|
||||
x=source_embeddings[:, 0],
|
||||
y=source_embeddings[:, 1],
|
||||
mode='markers',
|
||||
name=source,
|
||||
marker=dict(
|
||||
size=source_sizes,
|
||||
color=colors[i % len(colors)],
|
||||
opacity=point_opacity,
|
||||
line=dict(width=1, color='white')
|
||||
),
|
||||
hovertemplate='%{hovertext}<extra></extra>',
|
||||
hovertext=source_hover
|
||||
))
|
||||
if enable_3d:
|
||||
fig.add_trace(go.Scatter3d(
|
||||
x=source_embeddings[:, 0],
|
||||
y=source_embeddings[:, 1],
|
||||
z=source_embeddings[:, 2],
|
||||
mode='markers',
|
||||
name=source,
|
||||
marker=dict(
|
||||
size=source_sizes,
|
||||
color=colors[i % len(colors)],
|
||||
opacity=point_opacity,
|
||||
line=dict(width=1, color='white')
|
||||
),
|
||||
hovertemplate='%{hovertext}<extra></extra>',
|
||||
hovertext=source_hover
|
||||
))
|
||||
else:
|
||||
fig.add_trace(go.Scatter(
|
||||
x=source_embeddings[:, 0],
|
||||
y=source_embeddings[:, 1],
|
||||
mode='markers',
|
||||
name=source,
|
||||
marker=dict(
|
||||
size=source_sizes,
|
||||
color=colors[i % len(colors)],
|
||||
opacity=point_opacity,
|
||||
line=dict(width=1, color='white')
|
||||
),
|
||||
hovertemplate='%{hovertext}<extra></extra>',
|
||||
hovertext=source_hover
|
||||
))
|
||||
|
||||
return fig
|
||||
|
||||
@@ -115,7 +149,7 @@ def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources
|
||||
def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None,
|
||||
selected_sources=None, method="PCA", clustering_method="None",
|
||||
point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY,
|
||||
density_based_sizing=False, size_variation=2.0):
|
||||
density_based_sizing=False, size_variation=2.0, enable_3d=False):
|
||||
"""Create the main visualization plot"""
|
||||
|
||||
# Create hover text
|
||||
@@ -128,23 +162,37 @@ def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=No
|
||||
# Create plot based on coloring strategy
|
||||
if cluster_labels is not None:
|
||||
fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels,
|
||||
hover_text, point_sizes, point_opacity, method)
|
||||
hover_text, point_sizes, point_opacity, method, enable_3d)
|
||||
else:
|
||||
if selected_sources is None:
|
||||
selected_sources = filtered_df['source_file'].unique()
|
||||
fig = create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources,
|
||||
hover_text, point_sizes, point_opacity)
|
||||
hover_text, point_sizes, point_opacity, enable_3d)
|
||||
|
||||
# Update layout
|
||||
title_suffix = f" with {clustering_method}" if clustering_method != "None" else ""
|
||||
fig.update_layout(
|
||||
title=f"Discord Chat Messages - {method} Visualization{title_suffix}",
|
||||
xaxis_title=f"{method} Component 1",
|
||||
yaxis_title=f"{method} Component 2",
|
||||
hovermode='closest',
|
||||
width=1000,
|
||||
height=700
|
||||
)
|
||||
dimension_text = "3D" if enable_3d else "2D"
|
||||
|
||||
if enable_3d:
|
||||
fig.update_layout(
|
||||
title=f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}",
|
||||
scene=dict(
|
||||
xaxis_title=f"{method} Component 1",
|
||||
yaxis_title=f"{method} Component 2",
|
||||
zaxis_title=f"{method} Component 3"
|
||||
),
|
||||
width=1000,
|
||||
height=700
|
||||
)
|
||||
else:
|
||||
fig.update_layout(
|
||||
title=f"Discord Chat Messages - {method} {dimension_text} Visualization{title_suffix}",
|
||||
xaxis_title=f"{method} Component 1",
|
||||
yaxis_title=f"{method} Component 2",
|
||||
hovermode='closest',
|
||||
width=1000,
|
||||
height=700
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
@@ -182,7 +230,7 @@ def display_summary_stats(filtered_df, selected_sources):
|
||||
st.metric("Source Files", len(selected_sources))
|
||||
|
||||
|
||||
def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method):
|
||||
def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings, method, clustering_method, enable_3d=False):
|
||||
"""Display clustering results and export options"""
|
||||
if cluster_labels is None:
|
||||
return
|
||||
@@ -195,16 +243,21 @@ def display_clustering_results(filtered_df, cluster_labels, reduced_embeddings,
|
||||
export_df['x_coordinate'] = reduced_embeddings[:, 0]
|
||||
export_df['y_coordinate'] = reduced_embeddings[:, 1]
|
||||
|
||||
# Add z coordinate if 3D
|
||||
if enable_3d and reduced_embeddings.shape[1] >= 3:
|
||||
export_df['z_coordinate'] = reduced_embeddings[:, 2]
|
||||
|
||||
# Show cluster distribution
|
||||
cluster_dist = pd.Series(cluster_labels).value_counts().sort_index()
|
||||
st.bar_chart(cluster_dist)
|
||||
|
||||
# Download option
|
||||
csv_data = export_df.to_csv(index=False)
|
||||
dimension_text = "3D" if enable_3d else "2D"
|
||||
st.download_button(
|
||||
label="📥 Download Clustering Results (CSV)",
|
||||
data=csv_data,
|
||||
file_name=f"chat_clusters_{method}_{clustering_method}.csv",
|
||||
file_name=f"chat_clusters_{method}_{clustering_method}_{dimension_text}.csv",
|
||||
mime="text/csv"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user