"""
Main application logic for the Discord Chat Embeddings Visualizer.
"""

import warnings

import streamlit as st

# Silence every warning process-wide. This runs before the project modules
# below are imported, so warnings raised while importing them are hidden too.
warnings.filterwarnings('ignore')

# Project-local modules
from ui_components import (
    setup_page_config,
    display_title_and_description,
    get_all_ui_parameters,
    display_performance_warnings,
)
from data_loader import (
    load_all_chat_data,
    parse_embeddings,
    filter_data,
    get_filtered_embeddings,
)
from dimensionality_reduction import (
    reduce_dimensions,
    apply_density_based_jittering,
)
from clustering import apply_clustering
from visualization import (
    create_visualization_plot,
    display_clustering_metrics,
    display_summary_stats,
    display_clustering_results,
    display_data_table,
)


def _halt(notify, message):
    """Show ``message`` through ``notify`` and stop the current script run."""
    notify(message)
    st.stop()


def main():
    """Run the visualizer page from data loading through to the data table.

    Steps, in order: configure the page, load the chat data, parse the
    embeddings, read the UI parameters, filter the data, reduce dimensions,
    cluster, optionally jitter the reduced points, then render the clustering
    metrics, the plot, the summary statistics, the clustering results and the
    data table.

    The run is stopped early with ``st.stop()`` when no data could be loaded,
    when no valid embeddings are found, or when the filters match nothing.
    """
    setup_page_config()
    display_title_and_description()

    # --- Load and validate the input data --------------------------------
    with st.spinner("Loading chat data..."):
        raw_df = load_all_chat_data()
    if raw_df.empty:
        _halt(st.error, "No data could be loaded. Please check the data directory.")

    with st.spinner("Parsing embeddings..."):
        embeddings, valid_df = parse_embeddings(raw_df)
    # Explicit length check: plain truthiness is ambiguous if this is an array.
    if len(embeddings) == 0:
        _halt(st.error, "No valid embeddings found!")

    # --- Read the UI parameters and apply the filters ---------------------
    params = get_all_ui_parameters(valid_df)
    selected_sources = params['selected_sources']

    filtered_df = filter_data(valid_df, selected_sources, params['selected_authors'])
    if filtered_df.empty:
        _halt(st.warning, "No data matches the current filters!")

    method = params['method']
    clustering_method = params['clustering_method']
    display_performance_warnings(filtered_df, method, clustering_method)

    filtered_embeddings = get_filtered_embeddings(embeddings, valid_df, filtered_df)
    st.info(f"📈 Visualizing {len(filtered_df)} messages")

    # --- Dimensionality reduction -----------------------------------------
    with st.spinner(f"Reducing dimensions using {method}..."):
        points = reduce_dimensions(
            filtered_embeddings,
            method=method,
            spread_factor=params['spread_factor'],
            perplexity_factor=params['perplexity_factor'],
            min_dist_factor=params['min_dist_factor'],
        )

    # --- Clustering --------------------------------------------------------
    # Clustering is given the filtered embeddings, not the reduced points.
    with st.spinner(f"Applying {clustering_method}..."):
        cluster_labels, silhouette_avg, calinski_harabasz = apply_clustering(
            filtered_embeddings,
            clustering_method=clustering_method,
            n_clusters=params['n_clusters'],
        )

    # --- Optional jittering of the reduced points --------------------------
    if params['apply_jittering']:
        with st.spinner("Applying smart jittering to separate overlapping points..."):
            points = apply_density_based_jittering(
                points,
                density_scaling=params['density_based_jitter'],
                jitter_strength=params['jitter_strength'],
            )

    # --- Rendering ----------------------------------------------------------
    display_clustering_metrics(
        cluster_labels,
        silhouette_avg,
        calinski_harabasz,
        params['show_cluster_metrics'],
    )

    plot_options = {
        'reduced_embeddings': points,
        'filtered_df': filtered_df,
        'cluster_labels': cluster_labels,
        # An empty source selection is handed to the plot as None.
        'selected_sources': selected_sources or None,
        'method': method,
        'clustering_method': clustering_method,
        'point_size': params['point_size'],
        'point_opacity': params['point_opacity'],
        'density_based_sizing': params['density_based_sizing'],
        'size_variation': params['size_variation'],
    }
    fig = create_visualization_plot(**plot_options)
    st.plotly_chart(fig, use_container_width=True)

    # With no explicit selection, summarise every source left after filtering.
    if selected_sources:
        summary_sources = selected_sources
    else:
        summary_sources = filtered_df['source_file'].unique()
    display_summary_stats(filtered_df, summary_sources)

    # Clustering results (the original notes this section also offers export).
    display_clustering_results(
        filtered_df,
        cluster_labels,
        points,
        method,
        clustering_method,
    )

    display_data_table(filtered_df, cluster_labels)


if __name__ == "__main__":
    main()