text embedding script and class
This commit is contained in:
111
scripts/embedder.py
Normal file
111
scripts/embedder.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# batch_embedder.py
|
||||
# Description: A script to process all CSV files in a directory,
|
||||
# add text embeddings to a specified column, and
|
||||
# save the results back to the original files.
|
||||
#
|
||||
# This script assumes the TextEmbedder class is in a file named `main.py`
|
||||
# in the same directory.
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
from embed_class import TextEmbedder # Importing the class from main.py
|
||||
|
||||
def create_sample_files(directory: str):
|
||||
"""Creates a few sample CSV files for demonstration purposes."""
|
||||
if not os.path.exists(directory):
|
||||
print(f"Creating sample directory: '{directory}'")
|
||||
os.makedirs(directory)
|
||||
|
||||
# Sample file 1: Product descriptions
|
||||
df1_data = {'product_name': ['Smart Watch', 'Wireless Mouse', 'Keyboard'],
|
||||
'description': ['A watch that tracks fitness and notifications.', 'Ergonomic mouse with long battery life.', 'Mechanical keyboard with RGB lighting.']}
|
||||
df1 = pd.DataFrame(df1_data)
|
||||
df1.to_csv(os.path.join(directory, 'products.csv'), index=False)
|
||||
|
||||
# Sample file 2: Customer reviews
|
||||
df2_data = {'review_id': [101, 102, 103],
|
||||
'comment_text': ['The product exceeded my expectations!', 'It arrived late and was the wrong color.', 'I would definitely recommend this to a friend.']}
|
||||
df2 = pd.DataFrame(df2_data)
|
||||
df2.to_csv(os.path.join(directory, 'reviews.csv'), index=False)
|
||||
|
||||
print(f"Created sample files in '{directory}'.")
|
||||
|
||||
|
||||
def process_csvs_in_directory(directory_path: str, model_name: str = 'all-MiniLM-L6-v2'):
|
||||
"""
|
||||
Finds all CSV files in a directory, embeds a user-specified text column,
|
||||
and overwrites the original CSV with the new data.
|
||||
|
||||
Args:
|
||||
directory_path (str): The path to the directory containing CSV files.
|
||||
model_name (str): The sentence-transformer model to use for embedding.
|
||||
"""
|
||||
print(f"Starting batch processing for directory: '{directory_path}'")
|
||||
|
||||
# 1. Initialize the TextEmbedder
|
||||
# This will load the model, which can take a moment.
|
||||
try:
|
||||
embedder = TextEmbedder(model_name)
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize TextEmbedder. Aborting. Error: {e}")
|
||||
return
|
||||
|
||||
# 2. Find all CSV files in the directory
|
||||
try:
|
||||
all_files = os.listdir(directory_path)
|
||||
csv_files = [f for f in all_files if f.endswith('.csv')]
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Directory not found at '{directory_path}'. Please create it and add CSV files.")
|
||||
return
|
||||
|
||||
if not csv_files:
|
||||
print("No CSV files found in the directory.")
|
||||
return
|
||||
|
||||
print(f"Found {len(csv_files)} CSV files to process.")
|
||||
|
||||
# 3. Loop through each CSV file
|
||||
for filename in csv_files:
|
||||
file_path = os.path.join(directory_path, filename)
|
||||
print(f"\n--- Processing file: {filename} ---")
|
||||
|
||||
try:
|
||||
# Read the CSV into a DataFrame
|
||||
df = pd.read_csv(file_path)
|
||||
print("Available columns:", list(df.columns))
|
||||
|
||||
# Ask the user for the column to embed
|
||||
column_to_embed = input(f"Enter the name of the column to embed for '{filename}': ")
|
||||
|
||||
# Check if the column exists
|
||||
if column_to_embed not in df.columns:
|
||||
print(f"Column '{column_to_embed}' not found. Skipping this file.")
|
||||
continue
|
||||
|
||||
# 4. Use the embedder to add the new column
|
||||
df_with_embeddings = embedder.embed_dataframe_column(df, column_to_embed)
|
||||
|
||||
# 5. Save the modified DataFrame back to the original file
|
||||
df_with_embeddings.to_csv(file_path, index=False)
|
||||
print(f"Successfully processed and saved '{filename}'.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred while processing {filename}: {e}")
|
||||
continue # Move to the next file
|
||||
|
||||
print("\nBatch processing complete.")
|
||||
|
||||
|
||||
# --- Main Execution Block ---
|
||||
if __name__ == '__main__':
|
||||
# Define the directory where your CSV files are located.
|
||||
# The script will look for a folder named 'csv_data' in the current directory.
|
||||
CSV_DIRECTORY = 'csv_data'
|
||||
|
||||
# This function will create the 'csv_data' directory and some sample
|
||||
# files if they don't exist. You can comment this out if you have your own files.
|
||||
create_sample_files(CSV_DIRECTORY)
|
||||
|
||||
# Run the main processing function on the directory
|
||||
process_csvs_in_directory(CSV_DIRECTORY)
|
||||
|
||||
Reference in New Issue
Block a user