# batch_embedder.py
# Description: A script to process all CSV files in a directory,
# add text embeddings to a user-chosen column, and
# save the results back to the original files.
#
# This script assumes the TextEmbedder class is importable from a module
# named `embed_class` (e.g. `embed_class.py` in the same directory).

import argparse
import contextlib
import os
import shutil
import tempfile
from typing import List

import pandas as pd

from embed_class import TextEmbedder  # Project-local embedding wrapper.

# Defaults used when the script is run without command-line arguments.
DEFAULT_CSV_DIRECTORY = '../discord_chat_logs'
DEFAULT_MODEL_NAME = 'all-MiniLM-L6-v2'


def create_sample_files(directory: str):
    """Create a few sample CSV files for demonstration purposes.

    The directory is created if it is missing. ``products.csv`` and
    ``reviews.csv`` are always written, overwriting any existing files
    with those names.

    Args:
        directory (str): Directory in which to write the sample files.
    """
    if not os.path.exists(directory):
        print(f"Creating sample directory: '{directory}'")
        # exist_ok guards against the directory appearing between the check and here.
        os.makedirs(directory, exist_ok=True)

    # Sample file 1: Product descriptions
    df1_data = {
        'product_name': ['Smart Watch', 'Wireless Mouse', 'Keyboard'],
        'description': [
            'A watch that tracks fitness and notifications.',
            'Ergonomic mouse with long battery life.',
            'Mechanical keyboard with RGB lighting.',
        ],
    }
    df1 = pd.DataFrame(df1_data)
    df1.to_csv(os.path.join(directory, 'products.csv'), index=False)

    # Sample file 2: Customer reviews
    df2_data = {
        'review_id': [101, 102, 103],
        'comment_text': [
            'The product exceeded my expectations!',
            'It arrived late and was the wrong color.',
            'I would definitely recommend this to a friend.',
        ],
    }
    df2 = pd.DataFrame(df2_data)
    df2.to_csv(os.path.join(directory, 'reviews.csv'), index=False)

    print(f"Created sample files in '{directory}'.")


def _list_csv_files(directory_path: str) -> List[str]:
    """Return the sorted names of regular files in *directory_path* ending in .csv.

    The extension match is case-insensitive, and directories are excluded.
    Sorting gives a deterministic order for the interactive per-file prompt.

    Raises:
        OSError: If the directory cannot be listed (e.g. FileNotFoundError,
            NotADirectoryError, PermissionError).
    """
    return sorted(
        name
        for name in os.listdir(directory_path)
        if name.lower().endswith('.csv')
        and os.path.isfile(os.path.join(directory_path, name))
    )


def _write_csv_atomically(df: pd.DataFrame, file_path: str) -> None:
    """Write *df* to *file_path* without risking a half-written original.

    The data is written to a temporary file in the same directory and then
    swapped in with os.replace, so a failure during the write leaves the
    original file untouched.
    """
    directory = os.path.dirname(os.path.abspath(file_path))
    fd, tmp_path = tempfile.mkstemp(suffix='.csv.tmp', dir=directory)
    try:
        # newline='' as recommended by pandas when writing to an open handle.
        with os.fdopen(fd, 'w', newline='', encoding='utf-8') as tmp_file:
            df.to_csv(tmp_file, index=False)
        # mkstemp creates the file with restrictive permissions; keep the original's.
        with contextlib.suppress(OSError):
            shutil.copymode(file_path, tmp_path)
        os.replace(tmp_path, file_path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


def _embed_csv_file(embedder, file_path: str, filename: str) -> None:
    """Prompt for a column in one CSV, embed it, and overwrite the file.

    If the user names a column that does not exist, a message is printed and
    the file is left unchanged. Any other error propagates to the caller.
    """
    df = pd.read_csv(file_path)
    print("Available columns:", list(df.columns))

    # Ask the user for the column to embed
    column_to_embed = input(f"Enter the name of the column to embed for '{filename}': ")

    if column_to_embed not in df.columns:
        print(f"Column '{column_to_embed}' not found. Skipping this file.")
        return

    # NOTE(review): the representation of the new column is decided by
    # TextEmbedder.embed_dataframe_column. If it stores lists/arrays, to_csv
    # writes their string repr. Confirm that is the intended on-disk format.
    df_with_embeddings = embedder.embed_dataframe_column(df, column_to_embed)

    _write_csv_atomically(df_with_embeddings, file_path)
    print(f"Successfully processed and saved '{filename}'.")


def process_csvs_in_directory(directory_path: str, model_name: str = DEFAULT_MODEL_NAME):
    """
    Finds all CSV files in a directory, embeds a user-specified text column,
    and overwrites the original CSV with the new data.

    Each file is processed independently. An error in one file is reported
    and the batch moves on to the next. If standard input is exhausted
    (EOF), the remaining files are skipped.

    Args:
        directory_path (str): The path to the directory containing CSV files.
        model_name (str): The sentence-transformer model to use for embedding.
    """
    print(f"Starting batch processing for directory: '{directory_path}'")

    # 1. Find CSV files first, so the (slow) model is not loaded when there
    #    is nothing to process.
    try:
        csv_files = _list_csv_files(directory_path)
    except FileNotFoundError:
        print(f"Error: Directory not found at '{directory_path}'. Please create it and add CSV files.")
        return
    except OSError as e:
        print(f"Error: Could not list directory '{directory_path}': {e}")
        return

    if not csv_files:
        print("No CSV files found in the directory.")
        return

    print(f"Found {len(csv_files)} CSV files to process.")

    # 2. Initialize the TextEmbedder (this loads the model, which can take a moment).
    try:
        embedder = TextEmbedder(model_name)
    except Exception as e:
        # Top-level boundary: without an embedder nothing can be processed.
        print(f"Failed to initialize TextEmbedder. Aborting. Error: {e}")
        return

    # 3. Loop through each CSV file.
    for filename in csv_files:
        file_path = os.path.join(directory_path, filename)
        print(f"\n--- Processing file: {filename} ---")
        try:
            _embed_csv_file(embedder, file_path, filename)
        except EOFError:
            # No more interactive input is possible; prompting again would fail for every file.
            print("No more input available (EOF). Skipping remaining files.")
            break
        except Exception as e:
            print(f"An error occurred while processing {filename}: {e}")

    print("\nBatch processing complete.")


def _parse_args() -> argparse.Namespace:
    """Parse optional command-line overrides. With no arguments, the defaults are used."""
    parser = argparse.ArgumentParser(
        description="Add text embeddings to a chosen column of every CSV in a directory.",
    )
    parser.add_argument(
        'directory',
        nargs='?',
        default=DEFAULT_CSV_DIRECTORY,
        help=f"Directory containing CSV files (default: {DEFAULT_CSV_DIRECTORY}).",
    )
    parser.add_argument(
        '--model',
        default=DEFAULT_MODEL_NAME,
        help=f"Sentence-transformer model name (default: {DEFAULT_MODEL_NAME}).",
    )
    parser.add_argument(
        '--create-samples',
        action='store_true',
        help="Write sample products.csv/reviews.csv into the directory first (overwrites them).",
    )
    return parser.parse_args()


# --- Main Execution Block ---
if __name__ == '__main__':
    args = _parse_args()

    if args.create_samples:
        create_sample_files(args.directory)

    # Run the main processing function on the directory
    process_csvs_in_directory(args.directory, model_name=args.model)