"""
Main application logic for the Discord Chat Embeddings Visualizer.
"""

import streamlit as st
import pandas as pd
import warnings

# NOTE(review): this silences *every* warning emitted after this point,
# including any raised while importing the project modules below. It is kept
# for backward compatibility; narrowing it to specific categories/modules
# would make genuine problems easier to spot.
warnings.filterwarnings('ignore')

# Import custom modules
from ui_components import (
    setup_page_config,
    display_title_and_description,
    get_all_ui_parameters,
    display_performance_warnings,
)
from data_loader import (
    load_all_chat_data,
    parse_embeddings,
    filter_data,
    get_filtered_embeddings,
)
from dimensionality_reduction import (
    reduce_dimensions,
    apply_density_based_jittering,
)
from clustering import apply_clustering, generate_cluster_names
from visualization import (
    create_visualization_plot,
    display_clustering_metrics,
    display_summary_stats,
    display_clustering_results,
    display_data_table,
    display_cluster_summary,
)


def _load_valid_data():
    """Load all chat data and parse its embeddings, halting the app on failure.

    Returns:
        A ``(embeddings, valid_df)`` pair as produced by ``parse_embeddings``.
        The script stops via ``st.stop()`` if no data loads or no embedding
        is valid.
    """
    with st.spinner("Loading chat data..."):
        df = load_all_chat_data()

    if df.empty:
        st.error("No data could be loaded. Please check the data directory.")
        st.stop()

    with st.spinner("Parsing embeddings..."):
        embeddings, valid_df = parse_embeddings(df)

    # len() rather than truthiness: embeddings may be an array-like whose
    # truth value is ambiguous.
    if len(embeddings) == 0:
        st.error("No valid embeddings found!")
        st.stop()

    return embeddings, valid_df


def _summarize_sources(valid_df):
    """Build a one-row-per-source overview table.

    Each row holds the source file name, its message count, its number of
    distinct authors and its min/max ``timestamp_utc``. Rows follow the order
    in which sources first appear in ``valid_df``.
    """
    source_info = []
    for source in valid_df['source_file'].unique():
        source_data = valid_df[valid_df['source_file'] == source]
        first_ts = source_data['timestamp_utc'].min()
        last_ts = source_data['timestamp_utc'].max()
        source_info.append({
            'Source File': source,
            'Messages': len(source_data),
            'Unique Authors': source_data['author_name'].nunique(),
            'Date Range': f"{first_ts} to {last_ts}",
        })
    return pd.DataFrame(source_info)


def _render_source_overview(valid_df):
    """Show the 'pick a source' prompt together with a table of available sources."""
    st.info("📂 **Select source files from the sidebar to begin visualization**")
    st.markdown("### Available Data Sources:")
    st.dataframe(_summarize_sources(valid_df), use_container_width=True, hide_index=True)
    st.markdown("👈 **Use the sidebar to select which sources to visualize**")


def _reduce_and_cluster(filtered_embeddings, params):
    """Project embeddings for plotting, cluster them, and optionally jitter the projection.

    Clustering runs on the full-dimensional ``filtered_embeddings``, not on the
    projection. Jittering is applied afterwards, so it only affects plotted
    coordinates.

    Returns:
        ``(reduced_embeddings, cluster_labels, silhouette_avg, calinski_harabasz)``.
    """
    n_components = 3 if params['enable_3d'] else 2
    with st.spinner(f"Reducing dimensions using {params['method']}..."):
        reduced_embeddings = reduce_dimensions(
            filtered_embeddings,
            method=params['method'],
            n_components=n_components,
            spread_factor=params['spread_factor'],
            perplexity_factor=params['perplexity_factor'],
            min_dist_factor=params['min_dist_factor'],
        )

    with st.spinner(f"Applying {params['clustering_method']}..."):
        cluster_labels, silhouette_avg, calinski_harabasz = apply_clustering(
            filtered_embeddings,
            clustering_method=params['clustering_method'],
            n_clusters=params['n_clusters'],
        )

    if params['apply_jittering']:
        with st.spinner("Applying smart jittering to separate overlapping points..."):
            reduced_embeddings = apply_density_based_jittering(
                reduced_embeddings,
                density_scaling=params['density_based_jitter'],
                jitter_strength=params['jitter_strength'],
            )

    return reduced_embeddings, cluster_labels, silhouette_avg, calinski_harabasz


def main():
    """Main application function.

    Sets up the page, loads data, and reads sidebar parameters. It then filters
    the data, projects and clusters the embeddings, and renders the plot,
    metrics, statistics, export options and data table. Execution ends early
    via ``st.stop()`` in any of these cases:
    - nothing loads;
    - no embeddings are valid;
    - no source is selected;
    - the filters match nothing.
    """
    setup_page_config()
    display_title_and_description()

    embeddings, valid_df = _load_valid_data()

    params = get_all_ui_parameters(valid_df)

    # Nothing to visualize until at least one source is chosen; show what is
    # available instead.
    if not params['selected_sources']:
        _render_source_overview(valid_df)
        st.stop()

    # From here on params['selected_sources'] is guaranteed non-empty, because
    # st.stop() above halts the script.
    selected_sources = params['selected_sources']

    filtered_df = filter_data(valid_df, selected_sources, params['selected_authors'])
    if filtered_df.empty:
        st.warning("No data matches the current filters! Try selecting different sources or authors.")
        st.stop()

    display_performance_warnings(filtered_df, params['method'], params['clustering_method'])

    # Keep embeddings row-aligned with the filtered DataFrame.
    filtered_embeddings = get_filtered_embeddings(embeddings, valid_df, filtered_df)

    st.info(f"📈 Visualizing {len(filtered_df)} messages")

    reduced_embeddings, cluster_labels, silhouette_avg, calinski_harabasz = (
        _reduce_and_cluster(filtered_embeddings, params)
    )

    # Names are only meaningful when a clustering method produced labels.
    cluster_names = None
    if cluster_labels is not None:
        with st.spinner("Generating cluster names..."):
            cluster_names = generate_cluster_names(filtered_df, cluster_labels)

    display_clustering_metrics(
        cluster_labels,
        silhouette_avg,
        calinski_harabasz,
        params['show_cluster_metrics'],
    )

    if cluster_names:
        display_cluster_summary(cluster_names, cluster_labels)

    fig = create_visualization_plot(
        reduced_embeddings=reduced_embeddings,
        filtered_df=filtered_df,
        cluster_labels=cluster_labels,
        selected_sources=selected_sources,
        method=params['method'],
        clustering_method=params['clustering_method'],
        point_size=params['point_size'],
        point_opacity=params['point_opacity'],
        density_based_sizing=params['density_based_sizing'],
        size_variation=params['size_variation'],
        enable_3d=params['enable_3d'],
        cluster_names=cluster_names,
    )
    st.plotly_chart(fig, use_container_width=True)

    display_summary_stats(filtered_df, selected_sources)

    display_clustering_results(
        filtered_df,
        cluster_labels,
        reduced_embeddings,
        params['method'],
        params['clustering_method'],
        params['enable_3d'],
    )

    display_data_table(filtered_df, cluster_labels)


if __name__ == "__main__":
    main()