text embedding script and class

This commit is contained in:
2025-08-11 01:47:52 +01:00
parent 7ca86d7751
commit fb3fb70cc5
2 changed files with 258 additions and 0 deletions

147
scripts/embed_class.py Normal file
View File

@@ -0,0 +1,147 @@
# main.py
# Description: A simple Python class to generate text embeddings using sentence-transformers.
#
# Required libraries:
# pip install sentence-transformers pandas torch
#
# This script defines a TextEmbedder class that can be used to:
# 1. Load a pre-trained sentence-transformer model.
# 2. Embed a single string or a list of strings into vectors.
# 3. Embed an entire text column in a pandas DataFrame and add the embeddings as a new column.
import pandas as pd
from sentence_transformers import SentenceTransformer
from typing import List, Union
class TextEmbedder:
"""
A simple class to handle text embedding using sentence-transformers.
"""
def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
"""
Initializes the TextEmbedder and loads the specified model.
Args:
model_name (str): The name of the sentence-transformer model to use.
Defaults to 'all-MiniLM-L6-v2', a small and efficient model.
"""
self.model_name = model_name
self.model = None
self.load_model()
def load_model(self):
"""
Loads the sentence-transformer model from Hugging Face.
This method is called automatically during initialization.
"""
try:
print(f"Loading model: '{self.model_name}'...")
self.model = SentenceTransformer(self.model_name)
print("Model loaded successfully.")
except Exception as e:
print(f"Error loading model: {e}")
self.model = None
def embed(self, text: Union[str, List[str]]):
"""
Generates vector embeddings for a given string or list of strings.
Args:
text (Union[str, List[str]]): A single string or a list of strings to embed.
Returns:
A list of vector embeddings. Each embedding is a list of floats.
Returns None if the model is not loaded.
"""
if self.model is None:
print("Model is not loaded. Cannot perform inference.")
return None
print(f"Embedding text...")
# The model's encode function handles both single strings and lists of strings.
embeddings = self.model.encode(text, convert_to_numpy=False)
# We convert to a list of lists for easier use with pandas.
if isinstance(text, str):
return embeddings.tolist()
return [emb.tolist() for emb in embeddings]
def embed_dataframe_column(self, df: pd.DataFrame, column_name: str) -> pd.DataFrame:
"""
Embeds the text in a specified DataFrame column and adds the embeddings
as a new column to the DataFrame.
Args:
df (pd.DataFrame): The pandas DataFrame to process.
column_name (str): The name of the column containing the text to embed.
Returns:
pd.DataFrame: The original DataFrame with a new column containing the embeddings.
Returns the original DataFrame unmodified if an error occurs.
"""
if self.model is None:
print("Model is not loaded. Cannot process DataFrame.")
return df
if column_name not in df.columns:
print(f"Error: Column '{column_name}' not found in the DataFrame.")
return df
# Ensure the column is of string type and handle potential missing values (NaN)
# by filling them with an empty string.
text_to_embed = df[column_name].astype(str).fillna('').tolist()
# Generate embeddings for the entire column's text
embeddings = self.embed(text_to_embed)
if embeddings:
# Add the embeddings as a new column
new_column_name = f'{column_name}_embedding'
df[new_column_name] = embeddings
print(f"Successfully added '{new_column_name}' to the DataFrame.")
return df
# --- Example Usage ---
if __name__ == '__main__':
# 1. Initialize the embedder. This will automatically load the model.
embedder = TextEmbedder(model_name='all-MiniLM-L6-v2')
# 2. Embed a single string
print("\n--- Embedding a single string ---")
single_string = "This is a simple test sentence."
vector = embedder.embed(single_string)
if vector:
print(f"Original string: '{single_string}'")
# Print the first 5 dimensions of the vector for brevity
print(f"Resulting vector (first 5 dims): {vector[:5]}")
print(f"Vector dimension: {len(vector)}")
# 3. Embed a list of strings
print("\n--- Embedding a list of strings ---")
list_of_strings = ["The quick brown fox jumps over the lazy dog.", "Hello, world!"]
vectors = embedder.embed(list_of_strings)
if vectors:
for i, text in enumerate(list_of_strings):
print(f"Original string: '{text}'")
print(f"Resulting vector (first 5 dims): {vectors[i][:5]}")
print(f"Vector dimension: {len(vectors[i])}\n")
# 4. Embed a pandas DataFrame column
print("\n--- Embedding a DataFrame column ---")
# Create a sample DataFrame
data = {'product_id': [1, 2, 3],
'description': ['A comfortable cotton t-shirt.', 'High-quality noise-cancelling headphones.', 'A book about the history of computing.']}
my_df = pd.DataFrame(data)
print("Original DataFrame:")
print(my_df)
# Embed the 'description' column
df_with_embeddings = embedder.embed_dataframe_column(my_df, 'description')
print("\nDataFrame with embeddings:")
# Using .to_string() to ensure the full content is displayed
print(df_with_embeddings.to_string())

111
scripts/embedder.py Normal file
View File

@@ -0,0 +1,111 @@
# batch_embedder.py
# Description: A script to process all CSV files in a directory,
# add text embeddings to a specified column, and
# save the results back to the original files.
#
# This script assumes the TextEmbedder class is in a file named `main.py`
# in the same directory.
import os
import pandas as pd
from embed_class import TextEmbedder # Importing the class from main.py
def create_sample_files(directory: str):
"""Creates a few sample CSV files for demonstration purposes."""
if not os.path.exists(directory):
print(f"Creating sample directory: '{directory}'")
os.makedirs(directory)
# Sample file 1: Product descriptions
df1_data = {'product_name': ['Smart Watch', 'Wireless Mouse', 'Keyboard'],
'description': ['A watch that tracks fitness and notifications.', 'Ergonomic mouse with long battery life.', 'Mechanical keyboard with RGB lighting.']}
df1 = pd.DataFrame(df1_data)
df1.to_csv(os.path.join(directory, 'products.csv'), index=False)
# Sample file 2: Customer reviews
df2_data = {'review_id': [101, 102, 103],
'comment_text': ['The product exceeded my expectations!', 'It arrived late and was the wrong color.', 'I would definitely recommend this to a friend.']}
df2 = pd.DataFrame(df2_data)
df2.to_csv(os.path.join(directory, 'reviews.csv'), index=False)
print(f"Created sample files in '{directory}'.")
def process_csvs_in_directory(directory_path: str, model_name: str = 'all-MiniLM-L6-v2'):
"""
Finds all CSV files in a directory, embeds a user-specified text column,
and overwrites the original CSV with the new data.
Args:
directory_path (str): The path to the directory containing CSV files.
model_name (str): The sentence-transformer model to use for embedding.
"""
print(f"Starting batch processing for directory: '{directory_path}'")
# 1. Initialize the TextEmbedder
# This will load the model, which can take a moment.
try:
embedder = TextEmbedder(model_name)
except Exception as e:
print(f"Failed to initialize TextEmbedder. Aborting. Error: {e}")
return
# 2. Find all CSV files in the directory
try:
all_files = os.listdir(directory_path)
csv_files = [f for f in all_files if f.endswith('.csv')]
except FileNotFoundError:
print(f"Error: Directory not found at '{directory_path}'. Please create it and add CSV files.")
return
if not csv_files:
print("No CSV files found in the directory.")
return
print(f"Found {len(csv_files)} CSV files to process.")
# 3. Loop through each CSV file
for filename in csv_files:
file_path = os.path.join(directory_path, filename)
print(f"\n--- Processing file: {filename} ---")
try:
# Read the CSV into a DataFrame
df = pd.read_csv(file_path)
print("Available columns:", list(df.columns))
# Ask the user for the column to embed
column_to_embed = input(f"Enter the name of the column to embed for '{filename}': ")
# Check if the column exists
if column_to_embed not in df.columns:
print(f"Column '{column_to_embed}' not found. Skipping this file.")
continue
# 4. Use the embedder to add the new column
df_with_embeddings = embedder.embed_dataframe_column(df, column_to_embed)
# 5. Save the modified DataFrame back to the original file
df_with_embeddings.to_csv(file_path, index=False)
print(f"Successfully processed and saved '{filename}'.")
except Exception as e:
print(f"An error occurred while processing {filename}: {e}")
continue # Move to the next file
print("\nBatch processing complete.")
# --- Main Execution Block ---
if __name__ == '__main__':
# Define the directory where your CSV files are located.
# The script will look for a folder named 'csv_data' in the current directory.
CSV_DIRECTORY = 'csv_data'
# This function will create the 'csv_data' directory and some sample
# files if they don't exist. You can comment this out if you have your own files.
create_sample_files(CSV_DIRECTORY)
# Run the main processing function on the directory
process_csvs_in_directory(CSV_DIRECTORY)