Files
cult-scraper/scripts/embedder.py
2025-08-11 01:51:34 +01:00

112 lines
4.4 KiB
Python

# embedder.py (batch embedder)
# Description: A script that processes every CSV file in a directory:
# for a user-chosen text column in each file it computes text embeddings,
# adds them as a new column, and saves the result back to the original file.
#
# This script assumes the TextEmbedder class is importable from
# `embed_class.py` in the same directory.
import os
import pandas as pd
from embed_class import TextEmbedder # Importing the class from main.py
def create_sample_files(directory: str):
    """Write two small demo CSV files (products and reviews) into *directory*.

    The directory is created first when it does not exist yet. Any existing
    ``products.csv`` / ``reviews.csv`` inside it is overwritten.
    """
    needs_dir = not os.path.exists(directory)
    if needs_dir:
        print(f"Creating sample directory: '{directory}'")
        os.makedirs(directory)

    # File name -> column data; dicts keep insertion order, so products.csv
    # is written before reviews.csv.
    samples = {
        'products.csv': {  # product descriptions
            'product_name': ['Smart Watch', 'Wireless Mouse', 'Keyboard'],
            'description': [
                'A watch that tracks fitness and notifications.',
                'Ergonomic mouse with long battery life.',
                'Mechanical keyboard with RGB lighting.',
            ],
        },
        'reviews.csv': {  # customer reviews
            'review_id': [101, 102, 103],
            'comment_text': [
                'The product exceeded my expectations!',
                'It arrived late and was the wrong color.',
                'I would definitely recommend this to a friend.',
            ],
        },
    }
    for file_name, columns in samples.items():
        target = os.path.join(directory, file_name)
        pd.DataFrame(columns).to_csv(target, index=False)

    print(f"Created sample files in '{directory}'.")
def _find_csv_files(directory_path: str) -> list:
    """Return the sorted names of regular files in *directory_path* ending in .csv (any case).

    Raises:
        FileNotFoundError: if *directory_path* does not exist.
        NotADirectoryError: if *directory_path* is not a directory.
    """
    return sorted(
        name
        for name in os.listdir(directory_path)
        if name.lower().endswith('.csv')
        # Skip sub-directories (or other non-files) that happen to end in .csv.
        and os.path.isfile(os.path.join(directory_path, name))
    )


def _write_csv_atomically(df: pd.DataFrame, file_path: str) -> None:
    """Save *df* to *file_path* without ever leaving a half-written original.

    The data is written to a sibling ``.tmp`` file first and then moved over
    the original with ``os.replace``, so a failed or interrupted write (full
    disk, Ctrl-C, ...) leaves the original CSV untouched.
    """
    tmp_path = file_path + '.tmp'
    try:
        df.to_csv(tmp_path, index=False)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Best-effort removal of the partial temp file, then propagate.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def process_csvs_in_directory(directory_path: str, model_name: str = 'all-MiniLM-L6-v2'):
    """
    Finds all CSV files in a directory, embeds a user-specified text column,
    and overwrites each original CSV with the new data.

    The directory is scanned before the model is loaded, so a missing or empty
    directory is reported without paying the model start-up cost. Files are
    handled in sorted order; for each one the user is prompted for a column
    name. A failure on one file is reported and processing moves on to the
    next file. If stdin is closed, the remaining files are skipped.

    Args:
        directory_path (str): The path to the directory containing CSV files.
        model_name (str): The sentence-transformer model to use for embedding.
    """
    print(f"Starting batch processing for directory: '{directory_path}'")
    # 1. Find all CSV files in the directory (cheap, so do it before loading the model).
    try:
        csv_files = _find_csv_files(directory_path)
    except FileNotFoundError:
        print(f"Error: Directory not found at '{directory_path}'. Please create it and add CSV files.")
        return
    except NotADirectoryError:
        print(f"Error: '{directory_path}' is not a directory.")
        return
    if not csv_files:
        print("No CSV files found in the directory.")
        return
    print(f"Found {len(csv_files)} CSV files to process.")

    # 2. Initialize the TextEmbedder. This loads the model, which can take a moment.
    try:
        embedder = TextEmbedder(model_name)
    except Exception as e:
        print(f"Failed to initialize TextEmbedder. Aborting. Error: {e}")
        return

    # 3. Loop through each CSV file
    for filename in csv_files:
        file_path = os.path.join(directory_path, filename)
        print(f"\n--- Processing file: {filename} ---")
        try:
            df = pd.read_csv(file_path)
        except Exception as e:
            print(f"An error occurred while processing {filename}: {e}")
            continue  # Move to the next file
        print("Available columns:", list(df.columns))

        # Ask the user for the column to embed.
        try:
            requested = input(f"Enter the name of the column to embed for '{filename}': ")
        except EOFError:
            # No interactive input left; every further prompt would fail the same way.
            print("No more input available (stdin closed). Stopping batch processing.")
            break
        # Use the name verbatim if it matches; otherwise tolerate stray whitespace.
        column_to_embed = requested if requested in df.columns else requested.strip()
        if column_to_embed not in df.columns:
            print(f"Column '{requested}' not found. Skipping this file.")
            continue

        # 4. Embed the column and 5. save back over the original file.
        try:
            df_with_embeddings = embedder.embed_dataframe_column(df, column_to_embed)
            _write_csv_atomically(df_with_embeddings, file_path)
        except Exception as e:
            print(f"An error occurred while processing {filename}: {e}")
            continue  # Move to the next file
        print(f"Successfully processed and saved '{filename}'.")
    print("\nBatch processing complete.")
# --- Main Execution Block ---
if __name__ == '__main__':
    import sys

    # Directory holding the CSV files to process. By default this is the
    # `discord_chat_logs` folder one level above this script. It is resolved
    # from the script's own location, so the script works from any working
    # directory. Pass a directory path as the first command-line argument to
    # override it.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_dir = os.path.normpath(os.path.join(script_dir, '..', 'discord_chat_logs'))
    CSV_DIRECTORY = sys.argv[1] if len(sys.argv) > 1 else default_dir
    # Uncomment to (over)write two demo CSV files into CSV_DIRECTORY,
    # creating the directory first if it does not exist.
    # create_sample_files(CSV_DIRECTORY)
    # Run the main processing function on the directory
    process_csvs_in_directory(CSV_DIRECTORY)