170 lines
5.9 KiB
Python
170 lines
5.9 KiB
Python
"""
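Main application logic for the Discord Chat Embeddings Visualizer.

Loads chat data, projects message embeddings to 2D/3D, optionally clusters
them, and renders interactive plots, summaries and tables with Streamlit.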
"""
|
|
|
|
import streamlit as st
|
|
import warnings
|
|
# Install a catch-all "ignore" filter for every Warning category at import
# time. NOTE(review): this also hides deprecation warnings coming from
# third-party libraries; consider narrowing it to specific categories.
warnings.simplefilter('ignore')
|
|
|
|
# Import custom modules
|
|
from ui_components import (
|
|
setup_page_config, display_title_and_description, get_all_ui_parameters,
|
|
display_performance_warnings
|
|
)
|
|
from data_loader import (
|
|
load_all_chat_data, parse_embeddings, filter_data, get_filtered_embeddings
|
|
)
|
|
from dimensionality_reduction import (
|
|
reduce_dimensions, apply_density_based_jittering
|
|
)
|
|
from clustering import apply_clustering, generate_cluster_names
|
|
from visualization import (
|
|
create_visualization_plot, display_clustering_metrics, display_summary_stats,
|
|
display_clustering_results, display_data_table, display_cluster_summary
|
|
)
|
|
|
|
|
|
def _summarize_sources(valid_df):
    """Build the per-source overview table shown before any source is selected.

    Args:
        valid_df: DataFrame of messages with parsed embeddings. It must have
            ``source_file``, ``author_name`` and ``timestamp_utc`` columns.

    Returns:
        A pandas DataFrame with one row per distinct ``source_file`` value, in
        order of first appearance. Its columns are "Source File", "Messages",
        "Unique Authors" and "Date Range".
    """
    import pandas as pd

    rows = []
    for source in valid_df['source_file'].unique():
        source_data = valid_df[valid_df['source_file'] == source]
        timestamps = source_data['timestamp_utc']
        rows.append({
            'Source File': source,
            'Messages': len(source_data),
            'Unique Authors': source_data['author_name'].nunique(),
            'Date Range': f"{timestamps.min()} to {timestamps.max()}",
        })
    return pd.DataFrame(rows)


def _prompt_for_sources(valid_df):
    """Render the 'select sources' landing view, then halt the script run via st.stop()."""
    st.info("📂 **Select source files from the sidebar to begin visualization**")
    st.markdown("### Available Data Sources:")
    # Show available sources as an informational table
    st.dataframe(_summarize_sources(valid_df), use_container_width=True, hide_index=True)
    st.markdown("👈 **Use the sidebar to select which sources to visualize**")
    st.stop()


def main():
    """Main application function.

    Pipeline: load chat data -> parse embeddings -> read sidebar parameters ->
    filter -> reduce dimensions -> cluster -> optional jitter -> render plot,
    metrics, summaries and tables.

    Every guard calls ``st.stop()`` to end the current Streamlit run. It is
    followed by an explicit ``return`` so the rest of the pipeline never runs
    on missing data, even if ``st.stop()`` returns (e.g. outside a Streamlit
    script run).
    """
    # Set up page configuration and header
    setup_page_config()
    display_title_and_description()

    # Load data
    with st.spinner("Loading chat data..."):
        df = load_all_chat_data()
    if df.empty:
        st.error("No data could be loaded. Please check the data directory.")
        st.stop()
        return

    # Parse embeddings
    with st.spinner("Parsing embeddings..."):
        embeddings, valid_df = parse_embeddings(df)
    # len() rather than truthiness: presumably `embeddings` is an array type
    # whose truth value is ambiguous — TODO confirm against parse_embeddings.
    if len(embeddings) == 0:
        st.error("No valid embeddings found!")
        st.stop()
        return

    # Get UI parameters
    params = get_all_ui_parameters(valid_df)
    selected_sources = params['selected_sources']

    # Nothing selected yet: show the available sources and stop here. Past this
    # guard, `selected_sources` is always non-empty.
    if not selected_sources:
        _prompt_for_sources(valid_df)
        return

    # Filter data
    filtered_df = filter_data(valid_df, selected_sources, params['selected_authors'])
    if filtered_df.empty:
        st.warning("No data matches the current filters! Try selecting different sources or authors.")
        st.stop()
        return

    # Display performance warnings
    display_performance_warnings(filtered_df, params['method'], params['clustering_method'])

    # Get the embeddings matching the filtered rows
    filtered_embeddings = get_filtered_embeddings(embeddings, valid_df, filtered_df)

    st.info(f"📈 Visualizing {len(filtered_df)} messages")

    # Reduce dimensions to a 2D or 3D projection for plotting
    n_components = 3 if params['enable_3d'] else 2
    with st.spinner(f"Reducing dimensions using {params['method']}..."):
        reduced_embeddings = reduce_dimensions(
            filtered_embeddings,
            method=params['method'],
            n_components=n_components,
            spread_factor=params['spread_factor'],
            perplexity_factor=params['perplexity_factor'],
            min_dist_factor=params['min_dist_factor'],
        )

    # Cluster the original (unreduced) embeddings. The jittering below only
    # rewrites `reduced_embeddings`, so it cannot change cluster assignment.
    with st.spinner(f"Applying {params['clustering_method']}..."):
        cluster_labels, silhouette_avg, calinski_harabasz = apply_clustering(
            filtered_embeddings,
            clustering_method=params['clustering_method'],
            n_clusters=params['n_clusters'],
        )

    # Apply jittering if requested
    if params['apply_jittering']:
        with st.spinner("Applying smart jittering to separate overlapping points..."):
            reduced_embeddings = apply_density_based_jittering(
                reduced_embeddings,
                density_scaling=params['density_based_jitter'],
                jitter_strength=params['jitter_strength'],
            )

    # Generate cluster names only when clustering produced labels
    cluster_names = None
    if cluster_labels is not None:
        with st.spinner("Generating cluster names..."):
            cluster_names = generate_cluster_names(filtered_df, cluster_labels)

    # Display clustering metrics
    display_clustering_metrics(
        cluster_labels, silhouette_avg, calinski_harabasz,
        params['show_cluster_metrics'],
    )

    # Display cluster summary with names
    if cluster_names:
        display_cluster_summary(cluster_names, cluster_labels)

    # Create and display the main plot
    fig = create_visualization_plot(
        reduced_embeddings=reduced_embeddings,
        filtered_df=filtered_df,
        cluster_labels=cluster_labels,
        selected_sources=selected_sources,
        method=params['method'],
        clustering_method=params['clustering_method'],
        point_size=params['point_size'],
        point_opacity=params['point_opacity'],
        density_based_sizing=params['density_based_sizing'],
        size_variation=params['size_variation'],
        enable_3d=params['enable_3d'],
        cluster_names=cluster_names,
    )
    st.plotly_chart(fig, use_container_width=True)

    # Display summary statistics
    display_summary_stats(filtered_df, selected_sources)

    # Display clustering results and export options
    display_clustering_results(
        filtered_df, cluster_labels, reduced_embeddings,
        params['method'], params['clustering_method'], params['enable_3d'],
    )

    # Display data table
    display_data_table(filtered_df, cluster_labels)
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the app only when this file is executed directly, not when imported.
    main()
|