# Streamlit app: load Discord chat logs that carry precomputed embeddings,
# project the embeddings to 2D (PCA or t-SNE) and show an interactive scatter.

import ast
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Columns read by the hover text, the filters and the data table.
REQUIRED_COLUMNS = ("author_name", "timestamp_utc", "content")
# Maximum number of message characters shown in a hover label / table cell.
HOVER_CONTENT_LIMIT = 200
TABLE_CONTENT_LIMIT = 100

# Set page config
st.set_page_config(
    page_title="Discord Chat Embeddings Visualizer",
    page_icon="🗨️",
    layout="wide",
)

# Title and description
st.title("🗨️ Discord Chat Embeddings Visualizer")
st.markdown("Explore Discord chat messages through their vector embeddings in 2D space")


def _truncate(text, limit):
    """Return *text* cut to *limit* characters, adding '...' only if it was cut."""
    return text[:limit] + ("..." if len(text) > limit else "")


@st.cache_data
def load_all_chat_data():
    """Load every CSV file from the discord_chat_logs folder into one DataFrame.

    Each file's rows get a ``source_file`` column holding the file stem.
    Files that fail to parse are reported with ``st.error`` and skipped.

    Returns:
        The concatenated DataFrame (fresh 0..n-1 index), or an empty
        DataFrame when no file could be loaded.
    """
    chat_logs_path = Path("../../discord_chat_logs")

    # Display the path for debugging
    st.write(f"Looking for CSV files in: {chat_logs_path}")
    st.write(f"Path exists: {chat_logs_path.exists()}")

    all_data = []
    # Sorted so that load order (and therefore trace order/colours) is stable.
    for csv_file in sorted(chat_logs_path.glob("*.csv")):
        try:
            df = pd.read_csv(csv_file)
        except Exception as e:
            # Deliberately broad: one unreadable file must not abort the app.
            st.error(f"❌ Error loading {csv_file.name}: {e}")
            continue
        df['source_file'] = csv_file.stem  # Add source file name
        all_data.append(df)
        st.write(f"✅ Loaded {len(df)} messages from {csv_file.name}")

    if not all_data:
        st.error("No data loaded!")
        return pd.DataFrame()

    combined_df = pd.concat(all_data, ignore_index=True)
    st.success(
        f"🎉 Successfully loaded {len(combined_df)} total messages "
        f"from {len(all_data)} files"
    )
    return combined_df


@st.cache_data
def parse_embeddings(df):
    """Parse the ``content_embedding`` column from strings into a numpy array.

    Each cell is expected to be the string form of a non-empty list. Cells
    that cannot be parsed, or that parse to something else, are skipped.
    If the parsed lists have differing lengths, only those with the most
    common length are kept so the result is a rectangular array.

    Args:
        df: DataFrame with a ``content_embedding`` column.

    Returns:
        ``(embeddings_array, valid_df)`` where ``valid_df`` holds the rows of
        ``df`` whose embedding was kept, in the same order as the array rows.
        When nothing is usable the array has length 0 and ``valid_df`` is empty.
    """
    if 'content_embedding' not in df.columns:
        st.error("Column 'content_embedding' not found in the loaded data")
        return np.array([]), df.iloc[[]].copy()

    parsed = []  # (positional row index, embedding list)
    for idx, embedding_str in enumerate(df['content_embedding']):
        try:
            # Parse the string representation of the list
            embedding = ast.literal_eval(embedding_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Exceptions literal_eval documents for malformed input
            # (missing values read as NaN land here too).
            continue
        if isinstance(embedding, list) and len(embedding) > 0:
            parsed.append((idx, embedding))

    if parsed:
        dims = [len(vector) for _, vector in parsed]
        expected_dim = max(set(dims), key=dims.count)
        kept = [(idx, vector) for idx, vector in parsed if len(vector) == expected_dim]
        dropped = len(parsed) - len(kept)
        if dropped:
            st.warning(
                f"Skipped {dropped} embeddings whose dimension differs from {expected_dim}"
            )
    else:
        kept = []

    valid_indices = [idx for idx, _ in kept]
    embeddings = [vector for _, vector in kept]

    embeddings_array = np.array(embeddings)
    valid_df = df.iloc[valid_indices].copy()

    st.info(f"📊 Parsed {len(embeddings)} valid embeddings from {len(df)} messages")
    st.info(f"🔢 Embedding dimension: {embeddings_array.shape[1] if len(embeddings) > 0 else 0}")

    return embeddings_array, valid_df


@st.cache_data
def reduce_dimensions(embeddings, method="PCA", n_components=2):
    """Reduce embeddings to ``n_components`` dimensions using PCA or t-SNE.

    Args:
        embeddings: 2D array-like with one row per message.
        method: ``"PCA"`` or ``"t-SNE"``.
        n_components: Number of output dimensions.

    Returns:
        The array returned by the reducer's ``fit_transform``.

    Raises:
        ValueError: If ``method`` is not recognised or fewer than two
            embeddings are supplied.
    """
    n_samples = len(embeddings)
    if n_samples < 2:
        raise ValueError("Need at least 2 embeddings to reduce dimensions")

    if method == "PCA":
        reducer = PCA(n_components=n_components, random_state=42)
    elif method == "t-SNE":
        # Perplexity is capped at n_samples - 1 so small selections still run.
        reducer = TSNE(
            n_components=n_components,
            random_state=42,
            perplexity=min(30, n_samples - 1),
        )
    else:
        raise ValueError(f"Unknown dimension reduction method: {method!r}")

    return reducer.fit_transform(embeddings)


def create_hover_text(df):
    """Create one hover label per row of ``df`` for the Plotly scatter.

    Each label lists author, timestamp, source file and the message content
    (cut to HOVER_CONTENT_LIMIT characters). Fields are separated by ``<br>``,
    the line-break tag Plotly hover labels understand.

    Args:
        df: DataFrame with ``author_name``, ``timestamp_utc``, ``source_file``
            and ``content`` columns.

    Returns:
        A list of strings in row order.
    """
    hover_text = []
    rows = zip(df['author_name'], df['timestamp_utc'], df['source_file'], df['content'])
    for author, timestamp, source, content in rows:
        # Handle potential NaN or non-string content
        if pd.isna(content):
            content_text = "[No content]"
        else:
            content_text = _truncate(str(content), HOVER_CONTENT_LIMIT)

        hover_text.append(
            f"Author: {author}<br>"
            f"Timestamp: {timestamp}<br>"
            f"Source: {source}<br>"
            f"Content: {content_text}"
        )
    return hover_text


def main():
    """Run the app: load data, apply sidebar filters, plot, and show stats."""
    # Load data
    with st.spinner("Loading chat data..."):
        df = load_all_chat_data()

    if df.empty:
        st.stop()

    # Parse embeddings
    with st.spinner("Parsing embeddings..."):
        embeddings, valid_df = parse_embeddings(df)

    if len(embeddings) == 0:
        st.error("No valid embeddings found!")
        st.stop()

    missing_columns = [name for name in REQUIRED_COLUMNS if name not in valid_df.columns]
    if missing_columns:
        st.error(f"Missing required column(s): {', '.join(missing_columns)}")
        st.stop()

    # Sidebar controls
    st.sidebar.header("🎛️ Visualization Controls")

    # Dimension reduction method
    method = st.sidebar.selectbox(
        "Dimension Reduction Method",
        ["PCA", "t-SNE"],
        help="PCA is faster, t-SNE may reveal better clusters",
    )

    # Source file filter
    source_files = valid_df['source_file'].unique().tolist()
    selected_sources = st.sidebar.multiselect(
        "Filter by Source Files",
        source_files,
        default=source_files,
        help="Select which chat log files to include",
    )

    # Author filter
    authors = valid_df['author_name'].unique().tolist()
    selected_authors = st.sidebar.multiselect(
        "Filter by Authors",
        authors,
        default=authors[:10],  # Limit to first 10 for performance
        help="Select which authors to include",
    )

    # Filter data: one positional mask selects both the rows and their
    # embeddings, which are stored in the same order as valid_df.
    keep_mask = (
        valid_df['source_file'].isin(selected_sources)
        & valid_df['author_name'].isin(selected_authors)
    ).to_numpy()
    filtered_df = valid_df[keep_mask]

    if filtered_df.empty:
        st.warning("No data matches the current filters!")
        st.stop()

    if len(filtered_df) < 2:
        st.warning("At least 2 messages are needed for a 2D projection; widen the filters.")
        st.stop()

    # Get corresponding embeddings
    filtered_embeddings = embeddings[keep_mask]

    st.info(f"📈 Visualizing {len(filtered_df)} messages")

    # Reduce dimensions
    with st.spinner(f"Reducing dimensions using {method}..."):
        reduced_embeddings = reduce_dimensions(filtered_embeddings, method)

    # Create hover text
    hover_text = create_hover_text(filtered_df)

    # Create the plot
    fig = go.Figure()

    # Color by source file
    colors = px.colors.qualitative.Set1
    source_column = filtered_df['source_file'].to_numpy()

    for i, source in enumerate(selected_sources):
        source_mask = source_column == source
        if not source_mask.any():
            continue

        source_embeddings = reduced_embeddings[source_mask]
        source_hover = [text for text, keep in zip(hover_text, source_mask) if keep]

        fig.add_trace(go.Scatter(
            x=source_embeddings[:, 0],
            y=source_embeddings[:, 1],
            mode='markers',
            name=source,
            marker=dict(
                size=8,
                color=colors[i % len(colors)],  # cycle when sources outnumber colours
                opacity=0.7,
                line=dict(width=1, color='white'),
            ),
            hovertemplate='%{hovertext}',
            hovertext=source_hover,
        ))

    fig.update_layout(
        title=f"Discord Chat Messages - {method} Visualization",
        xaxis_title=f"{method} Component 1",
        yaxis_title=f"{method} Component 2",
        hovermode='closest',
        width=1000,
        height=700,
    )

    # Display the plot
    st.plotly_chart(fig, use_container_width=True)

    # Statistics
    col1, col2, col3 = st.columns(3)

    with col1:
        st.metric("Total Messages", len(filtered_df))

    with col2:
        st.metric("Unique Authors", filtered_df['author_name'].nunique())

    with col3:
        st.metric("Source Files", len(selected_sources))

    # Show data table
    if st.checkbox("Show Data Table"):
        st.subheader("📋 Message Data")
        display_df = filtered_df[['timestamp_utc', 'author_name', 'source_file', 'content']].copy()
        # Truncate for display; missing content is left as-is and the
        # ellipsis is added only to messages that were actually shortened.
        display_df['content'] = display_df['content'].map(
            lambda value: value if pd.isna(value) else _truncate(str(value), TABLE_CONTENT_LIMIT)
        )
        st.dataframe(display_df, use_container_width=True)


if __name__ == "__main__":
    main()