227 lines
8.7 KiB
Python
227 lines
8.7 KiB
Python
"""
|
|
Clustering algorithms and evaluation metrics.
|
|
"""
|
|
|
|
import numpy as np
|
|
import streamlit as st
|
|
from sklearn.cluster import SpectralClustering, AgglomerativeClustering, OPTICS
|
|
from sklearn.mixture import GaussianMixture
|
|
from sklearn.preprocessing import StandardScaler
|
|
from sklearn.metrics import silhouette_score, calinski_harabasz_score
|
|
import hdbscan
|
|
import pandas as pd
|
|
from collections import Counter
|
|
import re
|
|
from config import DEFAULT_RANDOM_STATE
|
|
|
|
|
|
# Words too generic to say anything about a cluster's topic. Kept at module
# level so the set is built once rather than on every call.
_STOP_WORDS = frozenset({
    'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
    'by', 'from', 'up', 'about', 'into', 'through', 'during', 'before', 'after',
    'above', 'below', 'between', 'among', 'until', 'without', 'under', 'over',
    'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had',
    'do', 'does', 'did', 'will', 'would', 'could', 'should', 'may', 'might',
    'i', 'you', 'he', 'she', 'it', 'we', 'they', 'me', 'him', 'her', 'us', 'them',
    'my', 'your', 'his', 'its', 'our', 'their', 'this', 'that', 'these', 'those',
    'just', 'like', 'get', 'know', 'think', 'see', 'go', 'come', 'say', 'said',
    'yeah', 'yes', 'no', 'oh', 'ok', 'okay', 'well', 'so', 'if', 'when',
    'what', 'where', 'why', 'how', 'who', 'which', 'than', 'then', 'now', 'here',
    'there', 'also', 'too', 'very', 'really', 'pretty', 'much', 'more', 'most',
    'some', 'any', 'all', 'many', 'few', 'little', 'big', 'small', 'good', 'bad',
})

# Custom emoji markup: static `<:name:id>` and animated `<a:name:id>`.
# Without the optional `a`, an animated emoji's name survives punctuation
# stripping and gets counted as an ordinary word.
_CUSTOM_EMOJI_RE = re.compile(r'<a?:\w+:\d+>')


def summarize_cluster_content(cluster_messages, max_words=3):
    """
    Generate a meaningful name for a cluster based on its message content.

    The messages are joined, lower-cased and stripped of URLs, Discord
    mentions, custom emoji markup and punctuation. The most frequent
    remaining words (alphabetic, 3-15 characters, not stop words) form the
    name, skipping words that share a 4-character prefix with one already
    chosen.

    Args:
        cluster_messages: List of message contents in the cluster
        max_words: Maximum number of words in the cluster name

    Returns:
        str: Generated cluster name, e.g. "Deploy + Failed (12)", or one of
            the fallbacks "Empty Cluster", "Empty Content",
            "Special Characters", "Chat (N msgs)", "Discussion (N msgs)".
    """
    if not cluster_messages:
        return "Empty Cluster"

    # Missing values (None / NaN) contribute no text.
    all_text = " ".join(str(msg) for msg in cluster_messages if pd.notna(msg))
    if not all_text.strip():
        return "Empty Content"

    text = all_text.lower()
    text = re.sub(r'http[s]?://\S+', '', text)   # Remove URLs
    text = re.sub(r'<@\d+>', '', text)           # Remove Discord mentions
    text = _CUSTOM_EMOJI_RE.sub('', text)        # Remove custom emojis (incl. animated)
    text = re.sub(r'[^\w\s]', ' ', text)         # Remove punctuation
    text = re.sub(r'\s+', ' ', text).strip()     # Normalize whitespace

    if not text:
        return "Special Characters"

    # Keep only alphabetic, reasonably sized, non-stop words.
    filtered_words = [
        word for word in text.split()
        if word not in _STOP_WORDS
        and 3 <= len(word) <= 15
        and word.isalpha()
    ]

    if not filtered_words:
        return f"Chat ({len(cluster_messages)} msgs)"

    # Take twice as many candidates as needed so that near-duplicates can be
    # skipped below without running short.
    most_common = Counter(filtered_words).most_common(max_words * 2)

    selected_words = []
    for word, _count in most_common:
        # Treat words sharing a 4-character prefix as the same concept
        # (e.g. "deploy" / "deploying") and keep only the most frequent.
        is_near_duplicate = any(
            word.startswith(existing[:4]) or existing.startswith(word[:4])
            for existing in selected_words
        )
        if not is_near_duplicate:
            selected_words.append(word)
        if len(selected_words) >= max_words:
            break

    if not selected_words:
        return f"Discussion ({len(cluster_messages)} msgs)"

    # Title-cased words joined by " + ", with the message count for context.
    cluster_name = " + ".join(selected_words).title()
    cluster_name += f" ({len(cluster_messages)})"

    return cluster_name
|
|
|
|
|
|
def generate_cluster_names(filtered_df, cluster_labels):
    """
    Build a display name for every cluster label.

    The label -1 is always named "Noise/Outliers". Every other label is
    named by passing the 'content' values of its rows to
    summarize_cluster_content.

    Args:
        filtered_df: DataFrame with message data; its 'content' column is read
        cluster_labels: Array of cluster labels for each message, or None

    Returns:
        dict: Mapping from cluster_id to cluster_name (empty when
            cluster_labels is None)
    """
    if cluster_labels is None:
        return {}

    def _name_for(label):
        # Noise gets a fixed name; it is never summarized.
        if label == -1:
            return "Noise/Outliers"
        members = filtered_df[cluster_labels == label]['content'].tolist()
        return summarize_cluster_content(members)

    # np.unique yields the labels in sorted order, which fixes the dict order.
    return {label: _name_for(label) for label in np.unique(cluster_labels)}
|
|
|
|
|
|
def _fit_cluster_labels(scaled_embeddings, clustering_method, n_clusters):
    """Fit the requested algorithm and return its labels, or None if the method name is unknown."""
    n_samples = len(scaled_embeddings)

    if clustering_method == "HDBSCAN":
        min_cluster_size = max(2, n_samples // 20)  # Adaptive min cluster size
        clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size,
                                    min_samples=1, cluster_selection_epsilon=0.5)
    elif clustering_method == "Spectral Clustering":
        clusterer = SpectralClustering(n_clusters=n_clusters, random_state=DEFAULT_RANDOM_STATE,
                                       affinity='rbf', gamma=1.0)
    elif clustering_method == "Gaussian Mixture":
        clusterer = GaussianMixture(n_components=n_clusters, random_state=DEFAULT_RANDOM_STATE,
                                    covariance_type='full', max_iter=200)
    elif clustering_method == "Agglomerative (Ward)":
        clusterer = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
    elif clustering_method == "Agglomerative (Complete)":
        clusterer = AgglomerativeClustering(n_clusters=n_clusters, linkage='complete')
    elif clustering_method == "OPTICS":
        min_samples = max(2, n_samples // 50)
        clusterer = OPTICS(min_samples=min_samples, xi=0.05, min_cluster_size=0.1)
    else:
        return None

    return clusterer.fit_predict(scaled_embeddings)


def apply_clustering(embeddings, clustering_method="None", n_clusters=5):
    """
    Apply clustering algorithm to embeddings and return labels and metrics.

    Embeddings are standardized before clustering. Quality metrics are
    computed on the standardized data, excluding points labelled -1 (noise),
    and only when at least two clusters remain.

    Args:
        embeddings: High-dimensional embeddings to cluster
        clustering_method: Name of clustering algorithm ("HDBSCAN",
            "Spectral Clustering", "Gaussian Mixture", "Agglomerative (Ward)",
            "Agglomerative (Complete)", "OPTICS"); "None" disables clustering
        n_clusters: Number of clusters (for methods that require it)

    Returns:
        tuple: (cluster_labels, silhouette_score, calinski_harabasz_score).
            All three are None when clustering is disabled, the method name
            is unknown, there are not more samples than n_clusters, or
            fitting fails. The two metrics are None when they cannot be
            computed; the labels are still returned in that case.
    """
    # NOTE: the sample-count guard applies to every method, including the
    # density-based ones that do not take n_clusters.
    if clustering_method == "None" or len(embeddings) <= n_clusters:
        return None, None, None

    # Standardize embeddings for better clustering
    scaled_embeddings = StandardScaler().fit_transform(embeddings)

    try:
        cluster_labels = _fit_cluster_labels(scaled_embeddings, clustering_method, n_clusters)
    except Exception as e:
        st.warning(f"Clustering failed: {str(e)}")
        return None, None, None

    if cluster_labels is None:
        return None, None, None

    silhouette_avg = None
    calinski_harabasz = None

    # Metrics are scored separately from fitting: a scoring failure must not
    # throw away labels that were fitted successfully.
    try:
        if len(np.unique(cluster_labels)) > 1:
            keep = cluster_labels != -1  # Remove noise points for HDBSCAN/OPTICS
            valid_labels = cluster_labels[keep]
            valid_embeddings = scaled_embeddings[keep]

            if len(valid_labels) > 0 and len(np.unique(valid_labels)) > 1:
                silhouette_avg = silhouette_score(valid_embeddings, valid_labels)
                calinski_harabasz = calinski_harabasz_score(valid_embeddings, valid_labels)
    except Exception as e:
        st.warning(f"Could not compute clustering quality metrics: {str(e)}")
        # Report both metrics as missing rather than a half-computed pair.
        silhouette_avg = None
        calinski_harabasz = None

    return cluster_labels, silhouette_avg, calinski_harabasz
|
|
|
|
|
|
def get_cluster_statistics(cluster_labels):
    """
    Get basic statistics about clustering results.

    Args:
        cluster_labels: Array-like of integer cluster labels, where -1 marks
            noise points; or None

    Returns:
        dict: Empty when cluster_labels is None, otherwise with keys
            "n_clusters" (number of distinct labels other than -1),
            "n_noise_points" (count of -1 labels),
            "cluster_distribution" (np.bincount of the non-noise labels, or
            [] when there are no clusters) and
            "unique_clusters" (sorted distinct labels, noise included).
    """
    if cluster_labels is None:
        return {}

    # Coerce once: with a plain list, `labels == -1` is the scalar False and
    # boolean-mask indexing does not work.
    labels = np.asarray(cluster_labels)
    is_noise = labels == -1

    unique_clusters = np.unique(labels)
    n_clusters = len(unique_clusters[unique_clusters != -1])  # Exclude noise cluster (-1)
    n_noise = np.sum(is_noise)

    return {
        "n_clusters": n_clusters,
        "n_noise_points": n_noise,
        "cluster_distribution": np.bincount(labels[~is_noise]) if n_clusters > 0 else [],
        "unique_clusters": unique_clusters
    }
|