112 lines
4.4 KiB
Python
112 lines
4.4 KiB
Python
# batch_embedder.py
|
|
# Description: A script to process all CSV files in a directory,
|
|
# add a column of text embeddings computed from a user-chosen column, and
|
|
# save the results back to the original files.
|
|
#
|
|
# This script assumes the TextEmbedder class is in a file named `embed_class.py`
|
|
# in the same directory.
|
|
|
|
import os
|
|
import pandas as pd
|
|
from embed_class import TextEmbedder # Importing the class from main.py
|
|
|
|
def create_sample_files(directory: str):
    """Create two small sample CSV files in *directory* for demonstration purposes.

    Writes ``products.csv`` (columns ``product_name``, ``description``) and
    ``reviews.csv`` (columns ``review_id``, ``comment_text``). The directory,
    including any missing parents, is created if needed. Existing files with
    those two names are overwritten.

    Args:
        directory (str): Directory to write the sample files into.

    Raises:
        FileExistsError: If *directory* exists but is not a directory.
    """
    if not os.path.exists(directory):
        print(f"Creating sample directory: '{directory}'")
    # exist_ok=True avoids the race between a separate exists() check and
    # makedirs(). If the path is an existing regular file, this raises a clear
    # FileExistsError here instead of an obscure failure later in to_csv().
    os.makedirs(directory, exist_ok=True)

    # Sample file 1: Product descriptions
    products = pd.DataFrame({
        'product_name': ['Smart Watch', 'Wireless Mouse', 'Keyboard'],
        'description': [
            'A watch that tracks fitness and notifications.',
            'Ergonomic mouse with long battery life.',
            'Mechanical keyboard with RGB lighting.',
        ],
    })
    products.to_csv(os.path.join(directory, 'products.csv'), index=False)

    # Sample file 2: Customer reviews
    reviews = pd.DataFrame({
        'review_id': [101, 102, 103],
        'comment_text': [
            'The product exceeded my expectations!',
            'It arrived late and was the wrong color.',
            'I would definitely recommend this to a friend.',
        ],
    })
    reviews.to_csv(os.path.join(directory, 'reviews.csv'), index=False)

    print(f"Created sample files in '{directory}'.")
|
|
|
|
|
|
def _list_csv_files(directory_path: str) -> list[str]:
    """Return sorted names of regular files in *directory_path* ending in ``.csv`` (any case).

    Raises:
        FileNotFoundError: If *directory_path* does not exist.
        NotADirectoryError: If *directory_path* is not a directory.
    """
    return sorted(
        name
        for name in os.listdir(directory_path)
        if name.lower().endswith('.csv')
        and os.path.isfile(os.path.join(directory_path, name))
    )


def _write_csv_atomically(df: pd.DataFrame, file_path: str) -> None:
    """Write *df* to *file_path* without leaving a truncated original behind on failure.

    The frame is first written to a sibling ``<file_path>.tmp`` and then moved
    over the original with ``os.replace``. If serialization fails, the
    original CSV is left untouched.
    """
    tmp_path = f"{file_path}.tmp"
    try:
        df.to_csv(tmp_path, index=False)
        os.replace(tmp_path, file_path)
    finally:
        # The temp file only still exists if to_csv() or os.replace() failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _embed_single_csv(embedder, directory_path: str, filename: str) -> None:
    """Prompt for a column in one CSV file, embed it, and save the result over the file.

    *embedder* must provide ``embed_dataframe_column(df, column)`` returning a
    DataFrame. The file is skipped, with a message, if the chosen column does
    not exist.
    """
    file_path = os.path.join(directory_path, filename)
    print(f"\n--- Processing file: {filename} ---")

    df = pd.read_csv(file_path)
    print("Available columns:", list(df.columns))

    raw_name = input(f"Enter the name of the column to embed for '{filename}': ")
    # Prefer an exact match (headers may legitimately contain spaces), but
    # tolerate stray whitespace typed around the name.
    column_to_embed = raw_name if raw_name in df.columns else raw_name.strip()
    if column_to_embed not in df.columns:
        print(f"Column '{column_to_embed}' not found. Skipping this file.")
        return

    df_with_embeddings = embedder.embed_dataframe_column(df, column_to_embed)
    _write_csv_atomically(df_with_embeddings, file_path)
    print(f"Successfully processed and saved '{filename}'.")


def process_csvs_in_directory(directory_path: str, model_name: str = 'all-MiniLM-L6-v2'):
    """
    Finds all CSV files in a directory, embeds a user-specified text column,
    and overwrites the original CSV with the new data.

    Files are processed in sorted order. For each file, the user is prompted
    on stdin for the column to embed. Each file is replaced safely: it is
    written to a temporary file first, so a failed write never truncates the
    original. An error in one file is reported and processing continues with
    the next file.

    Args:
        directory_path (str): The path to the directory containing CSV files.
        model_name (str): The sentence-transformer model to use for embedding.
    """
    print(f"Starting batch processing for directory: '{directory_path}'")

    # 1. Find the CSV files first, so a missing or empty directory is reported
    #    before paying the (potentially slow) model-loading cost.
    try:
        csv_files = _list_csv_files(directory_path)
    except FileNotFoundError:
        print(f"Error: Directory not found at '{directory_path}'. Please create it and add CSV files.")
        return
    except NotADirectoryError:
        print(f"Error: '{directory_path}' is not a directory.")
        return

    if not csv_files:
        print("No CSV files found in the directory.")
        return

    print(f"Found {len(csv_files)} CSV files to process.")

    # 2. Initialize the TextEmbedder. This loads the model, which can take a moment.
    try:
        embedder = TextEmbedder(model_name)
    except Exception as e:
        print(f"Failed to initialize TextEmbedder. Aborting. Error: {e}")
        return

    # 3. Embed each file. One bad file should not abort the whole batch.
    for filename in csv_files:
        try:
            _embed_single_csv(embedder, directory_path, filename)
        except Exception as e:
            print(f"An error occurred while processing {filename}: {e}")

    print("\nBatch processing complete.")
|
|
|
|
|
|
# --- Main Execution Block ---
|
|
if __name__ == '__main__':
    import sys  # only needed when run as a script

    # Directory containing the CSV files to process. A relative path is
    # resolved against the current working directory, not this script's
    # location. Pass a directory as the first command-line argument to
    # override it, e.g. `python batch_embedder.py ./csv_data`.
    CSV_DIRECTORY = sys.argv[1] if len(sys.argv) > 1 else '../discord_chat_logs'

    # Uncomment to create CSV_DIRECTORY (if missing) with two demo files,
    # products.csv and reviews.csv. Existing files with those names are overwritten.
    # create_sample_files(CSV_DIRECTORY)

    # Run the main processing function on the directory.
    process_csvs_in_directory(CSV_DIRECTORY)
|
|
|