# Origin: git patch fb3fb70cc5de5bd0a17e3377288a13aaa4c0a5dc
# From: Azeem Fidahusein
# Date: Mon, 11 Aug 2025 01:47:52 +0100
# Subject: [PATCH] text embedding script and class
# New files introduced by the patch: scripts/embed_class.py, scripts/embedder.py
#
# ===========================================================================
# File: scripts/embed_class.py
# ===========================================================================
# embed_class.py
# Description: A simple Python class to generate text embeddings using sentence-transformers.
#
# Required libraries:
# pip install sentence-transformers pandas torch
#
# This script defines a TextEmbedder class that can be used to:
# 1. Load a pre-trained sentence-transformer model.
# 2. Embed a single string or a list of strings into vectors.
# 3. Embed an entire text column in a pandas DataFrame and add the embeddings as a new column.

import pandas as pd
from sentence_transformers import SentenceTransformer
from typing import List, Optional, Union


class TextEmbedder:
    """
    A simple class to handle text embedding using sentence-transformers.

    Errors are reported with ``print`` instead of being raised: when the model
    cannot be loaded, ``self.model`` stays ``None``, ``embed`` returns ``None``
    and ``embed_dataframe_column`` returns the DataFrame untouched.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """
        Initializes the TextEmbedder and loads the specified model.

        Args:
            model_name (str): The name of the sentence-transformer model to use.
                Defaults to 'all-MiniLM-L6-v2'.
        """
        self.model_name = model_name
        self.model = None
        self.load_model()

    def load_model(self) -> None:
        """
        Loads the sentence-transformer model named by ``self.model_name``.

        This method is called automatically during initialization. On any
        failure the error is printed and ``self.model`` is left as ``None``.
        """
        try:
            print(f"Loading model: '{self.model_name}'...")
            self.model = SentenceTransformer(self.model_name)
            print("Model loaded successfully.")
        except Exception as e:
            # Deliberate best-effort: callers detect the failure through
            # `self.model is None` rather than through an exception.
            print(f"Error loading model: {e}")
            self.model = None

    def embed(
        self, text: Union[str, List[str]]
    ) -> Optional[Union[List[float], List[List[float]]]]:
        """
        Generates vector embeddings for a given string or list of strings.

        Args:
            text (Union[str, List[str]]): A single string or a list of strings to embed.

        Returns:
            For a single string, the result of ``.tolist()`` on the model
            output (one embedding vector). For a list of strings, a list with
            one such vector per input, in input order.
            Returns None if the model is not loaded.
        """
        if self.model is None:
            print("Model is not loaded. Cannot perform inference.")
            return None

        print("Embedding text...")
        # The model's encode function handles both single strings and lists of strings.
        embeddings = self.model.encode(text, convert_to_numpy=False)
        # Convert to plain Python lists for easier use with pandas.
        if isinstance(text, str):
            return embeddings.tolist()
        return [emb.tolist() for emb in embeddings]

    def embed_dataframe_column(self, df: pd.DataFrame, column_name: str) -> pd.DataFrame:
        """
        Embeds the text in a specified DataFrame column and adds the embeddings
        as a new column named ``<column_name>_embedding``.

        The DataFrame is modified in place and also returned.

        Args:
            df (pd.DataFrame): The pandas DataFrame to process.
            column_name (str): The name of the column containing the text to embed.

        Returns:
            pd.DataFrame: The same DataFrame object, with the embedding column
            added. Returned unmodified if the model is not loaded, the column
            does not exist, or embedding produced no result.
        """
        if self.model is None:
            print("Model is not loaded. Cannot process DataFrame.")
            return df

        if column_name not in df.columns:
            print(f"Error: Column '{column_name}' not found in the DataFrame.")
            return df

        # Fill missing values BEFORE casting to str. Casting first would turn
        # NaN into the literal text 'nan', which fillna can no longer detect,
        # so missing cells would be embedded as the word "nan".
        text_to_embed = df[column_name].fillna('').astype(str).tolist()

        # Generate embeddings for the entire column's text
        embeddings = self.embed(text_to_embed)

        # Compare against None explicitly: an empty DataFrame yields an empty
        # (falsy) list, which is still a valid result.
        if embeddings is not None:
            new_column_name = f'{column_name}_embedding'
            df[new_column_name] = embeddings
            print(f"Successfully added '{new_column_name}' to the DataFrame.")

        return df


# --- Example Usage ---
if __name__ == '__main__':
    # 1. Initialize the embedder. This will automatically load the model.
    embedder = TextEmbedder(model_name='all-MiniLM-L6-v2')

    # 2. Embed a single string
    print("\n--- Embedding a single string ---")
    single_string = "This is a simple test sentence."
    vector = embedder.embed(single_string)
    if vector is not None:
        print(f"Original string: '{single_string}'")
        # Print the first 5 dimensions of the vector for brevity
        print(f"Resulting vector (first 5 dims): {vector[:5]}")
        print(f"Vector dimension: {len(vector)}")

    # 3. Embed a list of strings
    print("\n--- Embedding a list of strings ---")
    list_of_strings = ["The quick brown fox jumps over the lazy dog.", "Hello, world!"]
    vectors = embedder.embed(list_of_strings)
    if vectors is not None:
        for text, text_vector in zip(list_of_strings, vectors):
            print(f"Original string: '{text}'")
            print(f"Resulting vector (first 5 dims): {text_vector[:5]}")
            print(f"Vector dimension: {len(text_vector)}\n")

    # 4. Embed a pandas DataFrame column
    print("\n--- Embedding a DataFrame column ---")
    # Create a sample DataFrame
    data = {'product_id': [1, 2, 3],
            'description': ['A comfortable cotton t-shirt.',
                            'High-quality noise-cancelling headphones.',
                            'A book about the history of computing.']}
    my_df = pd.DataFrame(data)

    print("Original DataFrame:")
    print(my_df)

    # Embed the 'description' column
    df_with_embeddings = embedder.embed_dataframe_column(my_df, 'description')

    print("\nDataFrame with embeddings:")
    # Using .to_string() to ensure the full content is displayed
    print(df_with_embeddings.to_string())

# ===========================================================================
# File: scripts/embedder.py
# ===========================================================================
# embedder.py
# Description: A script to process all CSV files in a directory,
#              add text embeddings to a specified column, and
#              save the results back to the original files.
#
# This script assumes the TextEmbedder class is in a file named `embed_class.py`
# in the same directory.

import os
import pandas as pd
from embed_class import TextEmbedder  # TextEmbedder is defined in embed_class.py


def create_sample_files(directory: str):
    """
    Creates a few sample CSV files for demonstration purposes.

    The directory is created when missing. A sample file is only written when
    no file of that name exists yet, so CSVs already present in the directory
    (including ones a previous run has added embeddings to) are not clobbered.

    Args:
        directory (str): Directory that should contain the sample CSV files.
    """
    if not os.path.exists(directory):
        print(f"Creating sample directory: '{directory}'")
        os.makedirs(directory)

    samples = {
        # Sample file 1: Product descriptions
        'products.csv': {
            'product_name': ['Smart Watch', 'Wireless Mouse', 'Keyboard'],
            'description': ['A watch that tracks fitness and notifications.',
                            'Ergonomic mouse with long battery life.',
                            'Mechanical keyboard with RGB lighting.'],
        },
        # Sample file 2: Customer reviews
        'reviews.csv': {
            'review_id': [101, 102, 103],
            'comment_text': ['The product exceeded my expectations!',
                             'It arrived late and was the wrong color.',
                             'I would definitely recommend this to a friend.'],
        },
    }

    created_any = False
    for sample_name, sample_data in samples.items():
        sample_path = os.path.join(directory, sample_name)
        if os.path.exists(sample_path):
            continue
        pd.DataFrame(sample_data).to_csv(sample_path, index=False)
        created_any = True

    if created_any:
        print(f"Created sample files in '{directory}'.")
    else:
        print(f"Sample files already exist in '{directory}'; leaving them unchanged.")


def process_csvs_in_directory(directory_path: str, model_name: str = 'all-MiniLM-L6-v2'):
    """
    Finds all CSV files in a directory, embeds a user-specified text column,
    and overwrites the original CSV with the new data.

    For each file the user is prompted on stdin for the column to embed. Files
    are skipped (and left unchanged) when the column does not exist, when no
    embedding column was produced, or when any error occurs while handling
    that file.

    Args:
        directory_path (str): The path to the directory containing CSV files.
        model_name (str): The sentence-transformer model to use for embedding.
    """
    print(f"Starting batch processing for directory: '{directory_path}'")

    # 1. Initialize the TextEmbedder
    #    This will load the model, which can take a moment.
    try:
        embedder = TextEmbedder(model_name)
    except Exception as e:
        print(f"Failed to initialize TextEmbedder. Aborting. Error: {e}")
        return

    # TextEmbedder.load_model swallows load errors and leaves `.model` as None,
    # so the try/except above does not catch a failed load. Without this check
    # every CSV would be rewritten unchanged and reported as a success.
    if getattr(embedder, 'model', None) is None:
        print("Failed to initialize TextEmbedder. Aborting. Error: model could not be loaded.")
        return

    # 2. Find all CSV files in the directory (sorted for a stable prompt order)
    try:
        csv_files = sorted(f for f in os.listdir(directory_path) if f.endswith('.csv'))
    except (FileNotFoundError, NotADirectoryError):
        print(f"Error: Directory not found at '{directory_path}'. Please create it and add CSV files.")
        return

    if not csv_files:
        print("No CSV files found in the directory.")
        return

    print(f"Found {len(csv_files)} CSV files to process.")

    # 3. Loop through each CSV file
    for filename in csv_files:
        file_path = os.path.join(directory_path, filename)
        print(f"\n--- Processing file: {filename} ---")

        try:
            # Read the CSV into a DataFrame
            df = pd.read_csv(file_path)
            print("Available columns:", list(df.columns))

            # Ask the user for the column to embed; strip stray whitespace
            # so "description " still matches the "description" column.
            column_to_embed = input(
                f"Enter the name of the column to embed for '{filename}': "
            ).strip()

            # Check if the column exists
            if column_to_embed not in df.columns:
                print(f"Column '{column_to_embed}' not found. Skipping this file.")
                continue

            # 4. Use the embedder to add the new column
            df_with_embeddings = embedder.embed_dataframe_column(df, column_to_embed)

            # embed_dataframe_column signals failure by returning the frame
            # without the '<column>_embedding' column; do not rewrite the
            # file or claim success in that case.
            embedding_column = f'{column_to_embed}_embedding'
            if embedding_column not in df_with_embeddings.columns:
                print(f"No embeddings were produced for '{filename}'. File left unchanged.")
                continue

            # 5. Save the modified DataFrame back to the original file
            df_with_embeddings.to_csv(file_path, index=False)
            print(f"Successfully processed and saved '{filename}'.")

        except Exception as e:
            # Per-file boundary: report the problem and move on to the next file.
            print(f"An error occurred while processing {filename}: {e}")
            continue

    print("\nBatch processing complete.")


# --- Main Execution Block ---
if __name__ == '__main__':
    # Define the directory where your CSV files are located.
    # The script will look for a folder named 'csv_data' in the current directory.
    CSV_DIRECTORY = 'csv_data'

    # This function will create the 'csv_data' directory and some sample
    # files if they don't exist. You can comment this out if you have your own files.
    create_sample_files(CSV_DIRECTORY)

    # Run the main processing function on the directory
    process_csvs_in_directory(CSV_DIRECTORY)