beter clusters and qol

This commit is contained in:
2025-08-11 03:04:50 +01:00
parent 647111e9d3
commit 2b8659fc95
5 changed files with 234 additions and 15 deletions

View File

@@ -47,7 +47,8 @@ def calculate_point_sizes(reduced_embeddings, density_based_sizing=False,
def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text,
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False):
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False,
cluster_names=None):
"""Create a plot colored by clusters"""
fig = go.Figure()
@@ -61,7 +62,11 @@ def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover
cluster_hover = [hover_text[j] for j, mask in enumerate(cluster_mask) if mask]
cluster_sizes = [point_sizes[j] for j, mask in enumerate(cluster_mask) if mask]
cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise"
# Use generated name if available, otherwise fall back to default
if cluster_names and cluster_id in cluster_names:
cluster_name = cluster_names[cluster_id]
else:
cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise"
if enable_3d:
fig.add_trace(go.Scatter3d(
@@ -149,7 +154,8 @@ def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources
def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None,
selected_sources=None, method="PCA", clustering_method="None",
point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY,
density_based_sizing=False, size_variation=2.0, enable_3d=False):
density_based_sizing=False, size_variation=2.0, enable_3d=False,
cluster_names=None):
"""Create the main visualization plot"""
# Create hover text
@@ -162,7 +168,8 @@ def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=No
# Create plot based on coloring strategy
if cluster_labels is not None:
fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels,
hover_text, point_sizes, point_opacity, method, enable_3d)
hover_text, point_sizes, point_opacity, method, enable_3d,
cluster_names)
else:
if selected_sources is None:
selected_sources = filtered_df['source_file'].unique()
@@ -276,3 +283,29 @@ def display_data_table(filtered_df, cluster_labels=None):
display_df['content'] = display_df['content'].str[:100] + '...' # Truncate for display
st.dataframe(display_df, use_container_width=True)
def display_cluster_summary(cluster_names, cluster_labels):
"""Display a summary of cluster names and their sizes"""
if not cluster_names or cluster_labels is None:
return
st.subheader("🏷️ Cluster Summary")
# Create summary data
cluster_summary = []
for cluster_id, name in cluster_names.items():
count = np.sum(cluster_labels == cluster_id)
cluster_summary.append({
'Cluster ID': cluster_id,
'Cluster Name': name,
'Message Count': count,
'Percentage': f"{100 * count / len(cluster_labels):.1f}%"
})
# Sort by message count
cluster_summary.sort(key=lambda x: x['Message Count'], reverse=True)
# Display as table
summary_df = pd.DataFrame(cluster_summary)
st.dataframe(summary_df, use_container_width=True, hide_index=True)