text embedding script and class
This commit is contained in:
147
scripts/embed_class.py
Normal file
147
scripts/embed_class.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# main.py
|
||||
# Description: A simple Python class to generate text embeddings using sentence-transformers.
|
||||
#
|
||||
# Required libraries:
|
||||
# pip install sentence-transformers pandas torch
|
||||
#
|
||||
# This script defines a TextEmbedder class that can be used to:
|
||||
# 1. Load a pre-trained sentence-transformer model.
|
||||
# 2. Embed a single string or a list of strings into vectors.
|
||||
# 3. Embed an entire text column in a pandas DataFrame and add the embeddings as a new column.
|
||||
|
||||
import pandas as pd
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from typing import List, Union
|
||||
|
||||
class TextEmbedder:
|
||||
"""
|
||||
A simple class to handle text embedding using sentence-transformers.
|
||||
"""
|
||||
def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
|
||||
"""
|
||||
Initializes the TextEmbedder and loads the specified model.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the sentence-transformer model to use.
|
||||
Defaults to 'all-MiniLM-L6-v2', a small and efficient model.
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.model = None
|
||||
self.load_model()
|
||||
|
||||
def load_model(self):
|
||||
"""
|
||||
Loads the sentence-transformer model from Hugging Face.
|
||||
This method is called automatically during initialization.
|
||||
"""
|
||||
try:
|
||||
print(f"Loading model: '{self.model_name}'...")
|
||||
self.model = SentenceTransformer(self.model_name)
|
||||
print("Model loaded successfully.")
|
||||
except Exception as e:
|
||||
print(f"Error loading model: {e}")
|
||||
self.model = None
|
||||
|
||||
def embed(self, text: Union[str, List[str]]):
|
||||
"""
|
||||
Generates vector embeddings for a given string or list of strings.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): A single string or a list of strings to embed.
|
||||
|
||||
Returns:
|
||||
A list of vector embeddings. Each embedding is a list of floats.
|
||||
Returns None if the model is not loaded.
|
||||
"""
|
||||
if self.model is None:
|
||||
print("Model is not loaded. Cannot perform inference.")
|
||||
return None
|
||||
|
||||
print(f"Embedding text...")
|
||||
# The model's encode function handles both single strings and lists of strings.
|
||||
embeddings = self.model.encode(text, convert_to_numpy=False)
|
||||
# We convert to a list of lists for easier use with pandas.
|
||||
if isinstance(text, str):
|
||||
return embeddings.tolist()
|
||||
return [emb.tolist() for emb in embeddings]
|
||||
|
||||
|
||||
def embed_dataframe_column(self, df: pd.DataFrame, column_name: str) -> pd.DataFrame:
|
||||
"""
|
||||
Embeds the text in a specified DataFrame column and adds the embeddings
|
||||
as a new column to the DataFrame.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The pandas DataFrame to process.
|
||||
column_name (str): The name of the column containing the text to embed.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The original DataFrame with a new column containing the embeddings.
|
||||
Returns the original DataFrame unmodified if an error occurs.
|
||||
"""
|
||||
if self.model is None:
|
||||
print("Model is not loaded. Cannot process DataFrame.")
|
||||
return df
|
||||
|
||||
if column_name not in df.columns:
|
||||
print(f"Error: Column '{column_name}' not found in the DataFrame.")
|
||||
return df
|
||||
|
||||
# Ensure the column is of string type and handle potential missing values (NaN)
|
||||
# by filling them with an empty string.
|
||||
text_to_embed = df[column_name].astype(str).fillna('').tolist()
|
||||
|
||||
# Generate embeddings for the entire column's text
|
||||
embeddings = self.embed(text_to_embed)
|
||||
|
||||
if embeddings:
|
||||
# Add the embeddings as a new column
|
||||
new_column_name = f'{column_name}_embedding'
|
||||
df[new_column_name] = embeddings
|
||||
print(f"Successfully added '{new_column_name}' to the DataFrame.")
|
||||
|
||||
return df
|
||||
|
||||
# --- Example Usage ---
|
||||
if __name__ == '__main__':
|
||||
# 1. Initialize the embedder. This will automatically load the model.
|
||||
embedder = TextEmbedder(model_name='all-MiniLM-L6-v2')
|
||||
|
||||
# 2. Embed a single string
|
||||
print("\n--- Embedding a single string ---")
|
||||
single_string = "This is a simple test sentence."
|
||||
vector = embedder.embed(single_string)
|
||||
if vector:
|
||||
print(f"Original string: '{single_string}'")
|
||||
# Print the first 5 dimensions of the vector for brevity
|
||||
print(f"Resulting vector (first 5 dims): {vector[:5]}")
|
||||
print(f"Vector dimension: {len(vector)}")
|
||||
|
||||
# 3. Embed a list of strings
|
||||
print("\n--- Embedding a list of strings ---")
|
||||
list_of_strings = ["The quick brown fox jumps over the lazy dog.", "Hello, world!"]
|
||||
vectors = embedder.embed(list_of_strings)
|
||||
if vectors:
|
||||
for i, text in enumerate(list_of_strings):
|
||||
print(f"Original string: '{text}'")
|
||||
print(f"Resulting vector (first 5 dims): {vectors[i][:5]}")
|
||||
print(f"Vector dimension: {len(vectors[i])}\n")
|
||||
|
||||
|
||||
# 4. Embed a pandas DataFrame column
|
||||
print("\n--- Embedding a DataFrame column ---")
|
||||
# Create a sample DataFrame
|
||||
data = {'product_id': [1, 2, 3],
|
||||
'description': ['A comfortable cotton t-shirt.', 'High-quality noise-cancelling headphones.', 'A book about the history of computing.']}
|
||||
my_df = pd.DataFrame(data)
|
||||
|
||||
print("Original DataFrame:")
|
||||
print(my_df)
|
||||
|
||||
# Embed the 'description' column
|
||||
df_with_embeddings = embedder.embed_dataframe_column(my_df, 'description')
|
||||
|
||||
print("\nDataFrame with embeddings:")
|
||||
# Using .to_string() to ensure the full content is displayed
|
||||
print(df_with_embeddings.to_string())
|
||||
|
||||
111
scripts/embedder.py
Normal file
111
scripts/embedder.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# batch_embedder.py
|
||||
# Description: A script to process all CSV files in a directory,
|
||||
# add text embeddings to a specified column, and
|
||||
# save the results back to the original files.
|
||||
#
|
||||
# This script assumes the TextEmbedder class is in a file named `main.py`
|
||||
# in the same directory.
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
from embed_class import TextEmbedder # Importing the class from main.py
|
||||
|
||||
def create_sample_files(directory: str):
|
||||
"""Creates a few sample CSV files for demonstration purposes."""
|
||||
if not os.path.exists(directory):
|
||||
print(f"Creating sample directory: '{directory}'")
|
||||
os.makedirs(directory)
|
||||
|
||||
# Sample file 1: Product descriptions
|
||||
df1_data = {'product_name': ['Smart Watch', 'Wireless Mouse', 'Keyboard'],
|
||||
'description': ['A watch that tracks fitness and notifications.', 'Ergonomic mouse with long battery life.', 'Mechanical keyboard with RGB lighting.']}
|
||||
df1 = pd.DataFrame(df1_data)
|
||||
df1.to_csv(os.path.join(directory, 'products.csv'), index=False)
|
||||
|
||||
# Sample file 2: Customer reviews
|
||||
df2_data = {'review_id': [101, 102, 103],
|
||||
'comment_text': ['The product exceeded my expectations!', 'It arrived late and was the wrong color.', 'I would definitely recommend this to a friend.']}
|
||||
df2 = pd.DataFrame(df2_data)
|
||||
df2.to_csv(os.path.join(directory, 'reviews.csv'), index=False)
|
||||
|
||||
print(f"Created sample files in '{directory}'.")
|
||||
|
||||
|
||||
def process_csvs_in_directory(directory_path: str, model_name: str = 'all-MiniLM-L6-v2'):
|
||||
"""
|
||||
Finds all CSV files in a directory, embeds a user-specified text column,
|
||||
and overwrites the original CSV with the new data.
|
||||
|
||||
Args:
|
||||
directory_path (str): The path to the directory containing CSV files.
|
||||
model_name (str): The sentence-transformer model to use for embedding.
|
||||
"""
|
||||
print(f"Starting batch processing for directory: '{directory_path}'")
|
||||
|
||||
# 1. Initialize the TextEmbedder
|
||||
# This will load the model, which can take a moment.
|
||||
try:
|
||||
embedder = TextEmbedder(model_name)
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize TextEmbedder. Aborting. Error: {e}")
|
||||
return
|
||||
|
||||
# 2. Find all CSV files in the directory
|
||||
try:
|
||||
all_files = os.listdir(directory_path)
|
||||
csv_files = [f for f in all_files if f.endswith('.csv')]
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Directory not found at '{directory_path}'. Please create it and add CSV files.")
|
||||
return
|
||||
|
||||
if not csv_files:
|
||||
print("No CSV files found in the directory.")
|
||||
return
|
||||
|
||||
print(f"Found {len(csv_files)} CSV files to process.")
|
||||
|
||||
# 3. Loop through each CSV file
|
||||
for filename in csv_files:
|
||||
file_path = os.path.join(directory_path, filename)
|
||||
print(f"\n--- Processing file: {filename} ---")
|
||||
|
||||
try:
|
||||
# Read the CSV into a DataFrame
|
||||
df = pd.read_csv(file_path)
|
||||
print("Available columns:", list(df.columns))
|
||||
|
||||
# Ask the user for the column to embed
|
||||
column_to_embed = input(f"Enter the name of the column to embed for '{filename}': ")
|
||||
|
||||
# Check if the column exists
|
||||
if column_to_embed not in df.columns:
|
||||
print(f"Column '{column_to_embed}' not found. Skipping this file.")
|
||||
continue
|
||||
|
||||
# 4. Use the embedder to add the new column
|
||||
df_with_embeddings = embedder.embed_dataframe_column(df, column_to_embed)
|
||||
|
||||
# 5. Save the modified DataFrame back to the original file
|
||||
df_with_embeddings.to_csv(file_path, index=False)
|
||||
print(f"Successfully processed and saved '{filename}'.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred while processing {filename}: {e}")
|
||||
continue # Move to the next file
|
||||
|
||||
print("\nBatch processing complete.")
|
||||
|
||||
|
||||
# --- Main Execution Block ---
|
||||
if __name__ == '__main__':
|
||||
# Define the directory where your CSV files are located.
|
||||
# The script will look for a folder named 'csv_data' in the current directory.
|
||||
CSV_DIRECTORY = 'csv_data'
|
||||
|
||||
# This function will create the 'csv_data' directory and some sample
|
||||
# files if they don't exist. You can comment this out if you have your own files.
|
||||
create_sample_files(CSV_DIRECTORY)
|
||||
|
||||
# Run the main processing function on the directory
|
||||
process_csvs_in_directory(CSV_DIRECTORY)
|
||||
|
||||
Reference in New Issue
Block a user