refactor
This commit is contained in:
86
apps/cluster_map/data_loader.py
Normal file
86
apps/cluster_map/data_loader.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Data loading and parsing utilities for Discord chat logs.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from config import CHAT_LOGS_PATH
|
||||
|
||||
|
||||
@st.cache_data
def load_all_chat_data():
    """Load and concatenate every CSV file in the discord_chat_logs folder.

    Returns:
        pd.DataFrame: all messages combined, with an added ``source_file``
        column holding the originating CSV's file stem. Returns an empty
        DataFrame when no file could be loaded.
    """
    chat_logs_path = Path(CHAT_LOGS_PATH)

    with st.expander("📁 Loading Details", expanded=False):
        # Surface the resolved path so a misconfigured CHAT_LOGS_PATH is
        # easy to spot from the UI.
        st.write(f"Looking for CSV files in: {chat_logs_path}")
        st.write(f"Path exists: {chat_logs_path.exists()}")

        all_data = []

        for csv_file in chat_logs_path.glob("*.csv"):
            try:
                df = pd.read_csv(csv_file)
                df['source_file'] = csv_file.stem  # Add source file name
                all_data.append(df)
                st.write(f"✅ Loaded {len(df)} messages from {csv_file.name}")
            except Exception as e:
                # Best-effort load: report the broken file and keep going.
                st.error(f"❌ Error loading {csv_file.name}: {e}")

        # Guard clause replaces the original's redundant bookkeeping
        # (assigning an empty frame in the else branch AND re-testing
        # all_data on the final return line).
        if not all_data:
            st.error("No data loaded!")
            return pd.DataFrame()

        combined_df = pd.concat(all_data, ignore_index=True)
        st.success(f"🎉 Successfully loaded {len(combined_df)} total messages from {len(all_data)} files")

    return combined_df
|
||||
|
||||
|
||||
@st.cache_data
def parse_embeddings(df):
    """Parse the ``content_embedding`` column from string to a numpy array.

    Args:
        df: DataFrame whose ``content_embedding`` column holds string
            representations of Python lists (e.g. "[0.1, 0.2, ...]").

    Returns:
        tuple: ``(embeddings_array, valid_df)`` where ``embeddings_array``
        is shaped (n_valid, dim) and ``valid_df`` is a copy of the rows
        whose embedding parsed into a non-empty list.
    """
    embeddings = []
    valid_indices = []

    for idx, embedding_str in enumerate(df['content_embedding']):
        # literal_eval safely parses the string representation of the list
        # (no code execution). Catch only what it is documented to raise —
        # ValueError/SyntaxError on malformed text, TypeError on non-string
        # input (e.g. NaN from missing values) — instead of a bare
        # `except Exception as e` that bound an unused name and hid bugs.
        try:
            embedding = ast.literal_eval(embedding_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        if isinstance(embedding, list) and len(embedding) > 0:
            embeddings.append(embedding)
            valid_indices.append(idx)

    embeddings_array = np.array(embeddings)
    valid_df = df.iloc[valid_indices].copy()

    st.info(f"📊 Parsed {len(embeddings)} valid embeddings from {len(df)} messages")
    st.info(f"🔢 Embedding dimension: {embeddings_array.shape[1] if len(embeddings) > 0 else 0}")

    return embeddings_array, valid_df
|
||||
|
||||
|
||||
def filter_data(df, selected_sources, selected_authors):
    """Filter dataframe by selected sources and authors.

    An empty selection for either facet means "no restriction". This fixes
    the original asymmetry where an empty ``selected_sources`` defaulted to
    all sources but an empty ``selected_authors`` silently returned an
    empty frame.

    Args:
        df: DataFrame with ``source_file`` and ``author_name`` columns.
        selected_sources: iterable of source-file stems to keep.
        selected_authors: iterable of author names to keep.

    Returns:
        pd.DataFrame: the rows matching both selections.
    """
    if not selected_sources:
        selected_sources = df['source_file'].unique()
    if not selected_authors:
        selected_authors = df['author_name'].unique()

    mask = df['source_file'].isin(selected_sources) & df['author_name'].isin(selected_authors)
    return df[mask]
|
||||
|
||||
|
||||
def get_filtered_embeddings(embeddings, valid_df, filtered_df):
    """Return the embedding rows that correspond to ``filtered_df``.

    Args:
        embeddings: array whose rows align positionally with ``valid_df``.
        valid_df: DataFrame whose index labels identify each embedding row.
        filtered_df: subset of ``valid_df`` (shares its index labels).

    Returns:
        np.ndarray: rows of ``embeddings`` whose label survives filtering,
        in ``valid_df`` order.
    """
    # Set membership is O(1); the original tested `idx in filtered_indices`
    # against a list, which made the whole selection O(n*m).
    wanted = set(filtered_df.index)
    positions = [pos for pos, label in enumerate(valid_df.index) if label in wanted]
    return embeddings[positions]
|
||||
Reference in New Issue
Block a user