Files
cult-scraper/apps/cluster_map/main.py
2025-08-11 02:37:21 +01:00

133 lines
4.3 KiB
Python

"""
Main application logic for the Discord Chat Embeddings Visualizer.
"""
import streamlit as st
import warnings
warnings.filterwarnings('ignore')
# Import custom modules
from ui_components import (
setup_page_config, display_title_and_description, get_all_ui_parameters,
display_performance_warnings
)
from data_loader import (
load_all_chat_data, parse_embeddings, filter_data, get_filtered_embeddings
)
from dimensionality_reduction import (
reduce_dimensions, apply_density_based_jittering
)
from clustering import apply_clustering
from visualization import (
create_visualization_plot, display_clustering_metrics, display_summary_stats,
display_clustering_results, display_data_table
)
def main():
    """Drive the visualizer page from data loading through to the data table.

    Stages run strictly in the order written: page setup, data loading,
    embedding parsing, sidebar/UI parameters, filtering, dimensionality
    reduction, clustering, optional jittering, then rendering of metrics,
    the plot, summary statistics, clustering results and the data table.

    ``st.stop()`` is called when a stage leaves nothing to work with: no
    rows loaded, no parsable embeddings, or filters that match no rows.
    """
    # Page chrome is emitted before anything else.
    setup_page_config()
    display_title_and_description()

    # ---- Load and parse ----------------------------------------------
    with st.spinner("Loading chat data..."):
        chat_df = load_all_chat_data()
    if chat_df.empty:
        st.error("No data could be loaded. Please check the data directory.")
        st.stop()

    with st.spinner("Parsing embeddings..."):
        embeddings, valid_df = parse_embeddings(chat_df)
    if not len(embeddings):
        st.error("No valid embeddings found!")
        st.stop()

    # ---- User parameters and filtering -------------------------------
    ui = get_all_ui_parameters(valid_df)
    selected_sources = ui['selected_sources']
    subset_df = filter_data(valid_df, selected_sources, ui['selected_authors'])
    if subset_df.empty:
        st.warning("No data matches the current filters!")
        st.stop()

    method = ui['method']
    clustering_method = ui['clustering_method']
    display_performance_warnings(subset_df, method, clustering_method)

    # Keep the embeddings in step with the rows that survived filtering.
    subset_embeddings = get_filtered_embeddings(embeddings, valid_df, subset_df)
    st.info(f"📈 Visualizing {len(subset_df)} messages")

    # ---- Dimensionality reduction ------------------------------------
    with st.spinner(f"Reducing dimensions using {method}..."):
        reduction_options = {
            key: ui[key]
            for key in ('spread_factor', 'perplexity_factor', 'min_dist_factor')
        }
        coords = reduce_dimensions(
            subset_embeddings, method=method, **reduction_options
        )

    # ---- Clustering --------------------------------------------------
    # Clustering is given the filtered embeddings themselves, not the
    # reduced coordinates produced above.
    with st.spinner(f"Applying {clustering_method}..."):
        clustering_outcome = apply_clustering(
            subset_embeddings,
            clustering_method=clustering_method,
            n_clusters=ui['n_clusters'],
        )
        labels, silhouette_avg, calinski_harabasz = clustering_outcome

    # ---- Optional jittering ------------------------------------------
    # When enabled, the jittered coordinates replace the reduced ones for
    # everything rendered below.
    if ui['apply_jittering']:
        with st.spinner("Applying smart jittering to separate overlapping points..."):
            coords = apply_density_based_jittering(
                coords,
                density_scaling=ui['density_based_jitter'],
                jitter_strength=ui['jitter_strength'],
            )

    # ---- Rendering ---------------------------------------------------
    display_clustering_metrics(
        labels, silhouette_avg, calinski_harabasz, ui['show_cluster_metrics']
    )

    appearance = {
        key: ui[key]
        for key in (
            'point_size', 'point_opacity', 'density_based_sizing', 'size_variation'
        )
    }
    figure = create_visualization_plot(
        reduced_embeddings=coords,
        filtered_df=subset_df,
        cluster_labels=labels,
        # An empty source selection is passed on as None.
        selected_sources=selected_sources or None,
        method=method,
        clustering_method=clustering_method,
        **appearance,
    )
    st.plotly_chart(figure, use_container_width=True)

    # With no explicit source selection, summarise every source file that
    # is present in the filtered rows.
    if selected_sources:
        summary_sources = selected_sources
    else:
        summary_sources = subset_df['source_file'].unique()
    display_summary_stats(subset_df, summary_sources)

    # Clustering results and export options.
    display_clustering_results(
        subset_df, labels, coords, method, clustering_method
    )

    display_data_table(subset_df, labels)
# Script entry point: build the page only when this file is executed as the
# main script, not when it is imported as a module.
if __name__ == "__main__":
    main()