beter clusters and qol
This commit is contained in:
@@ -47,7 +47,8 @@ def calculate_point_sizes(reduced_embeddings, density_based_sizing=False,
|
||||
|
||||
|
||||
def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover_text,
|
||||
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False):
|
||||
point_sizes, point_opacity=DEFAULT_POINT_OPACITY, method="PCA", enable_3d=False,
|
||||
cluster_names=None):
|
||||
"""Create a plot colored by clusters"""
|
||||
fig = go.Figure()
|
||||
|
||||
@@ -61,7 +62,11 @@ def create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels, hover
|
||||
cluster_hover = [hover_text[j] for j, mask in enumerate(cluster_mask) if mask]
|
||||
cluster_sizes = [point_sizes[j] for j, mask in enumerate(cluster_mask) if mask]
|
||||
|
||||
cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise"
|
||||
# Use generated name if available, otherwise fall back to default
|
||||
if cluster_names and cluster_id in cluster_names:
|
||||
cluster_name = cluster_names[cluster_id]
|
||||
else:
|
||||
cluster_name = f"Cluster {cluster_id}" if cluster_id != -1 else "Noise"
|
||||
|
||||
if enable_3d:
|
||||
fig.add_trace(go.Scatter3d(
|
||||
@@ -149,7 +154,8 @@ def create_source_colored_plot(reduced_embeddings, filtered_df, selected_sources
|
||||
def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=None,
|
||||
selected_sources=None, method="PCA", clustering_method="None",
|
||||
point_size=DEFAULT_POINT_SIZE, point_opacity=DEFAULT_POINT_OPACITY,
|
||||
density_based_sizing=False, size_variation=2.0, enable_3d=False):
|
||||
density_based_sizing=False, size_variation=2.0, enable_3d=False,
|
||||
cluster_names=None):
|
||||
"""Create the main visualization plot"""
|
||||
|
||||
# Create hover text
|
||||
@@ -162,7 +168,8 @@ def create_visualization_plot(reduced_embeddings, filtered_df, cluster_labels=No
|
||||
# Create plot based on coloring strategy
|
||||
if cluster_labels is not None:
|
||||
fig = create_clustered_plot(reduced_embeddings, filtered_df, cluster_labels,
|
||||
hover_text, point_sizes, point_opacity, method, enable_3d)
|
||||
hover_text, point_sizes, point_opacity, method, enable_3d,
|
||||
cluster_names)
|
||||
else:
|
||||
if selected_sources is None:
|
||||
selected_sources = filtered_df['source_file'].unique()
|
||||
@@ -276,3 +283,29 @@ def display_data_table(filtered_df, cluster_labels=None):
|
||||
|
||||
display_df['content'] = display_df['content'].str[:100] + '...' # Truncate for display
|
||||
st.dataframe(display_df, use_container_width=True)
|
||||
|
||||
|
||||
def display_cluster_summary(cluster_names, cluster_labels):
|
||||
"""Display a summary of cluster names and their sizes"""
|
||||
if not cluster_names or cluster_labels is None:
|
||||
return
|
||||
|
||||
st.subheader("🏷️ Cluster Summary")
|
||||
|
||||
# Create summary data
|
||||
cluster_summary = []
|
||||
for cluster_id, name in cluster_names.items():
|
||||
count = np.sum(cluster_labels == cluster_id)
|
||||
cluster_summary.append({
|
||||
'Cluster ID': cluster_id,
|
||||
'Cluster Name': name,
|
||||
'Message Count': count,
|
||||
'Percentage': f"{100 * count / len(cluster_labels):.1f}%"
|
||||
})
|
||||
|
||||
# Sort by message count
|
||||
cluster_summary.sort(key=lambda x: x['Message Count'], reverse=True)
|
||||
|
||||
# Display as table
|
||||
summary_df = pd.DataFrame(cluster_summary)
|
||||
st.dataframe(summary_df, use_container_width=True, hide_index=True)
|
||||
|
||||
Reference in New Issue
Block a user