text embedding script and class
This commit is contained in:
147
scripts/embed_class.py
Normal file
147
scripts/embed_class.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
# main.py
|
||||||
|
# Description: A simple Python class to generate text embeddings using sentence-transformers.
|
||||||
|
#
|
||||||
|
# Required libraries:
|
||||||
|
# pip install sentence-transformers pandas torch
|
||||||
|
#
|
||||||
|
# This script defines a TextEmbedder class that can be used to:
|
||||||
|
# 1. Load a pre-trained sentence-transformer model.
|
||||||
|
# 2. Embed a single string or a list of strings into vectors.
|
||||||
|
# 3. Embed an entire text column in a pandas DataFrame and add the embeddings as a new column.
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
class TextEmbedder:
|
||||||
|
"""
|
||||||
|
A simple class to handle text embedding using sentence-transformers.
|
||||||
|
"""
|
||||||
|
def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
|
||||||
|
"""
|
||||||
|
Initializes the TextEmbedder and loads the specified model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): The name of the sentence-transformer model to use.
|
||||||
|
Defaults to 'all-MiniLM-L6-v2', a small and efficient model.
|
||||||
|
"""
|
||||||
|
self.model_name = model_name
|
||||||
|
self.model = None
|
||||||
|
self.load_model()
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
"""
|
||||||
|
Loads the sentence-transformer model from Hugging Face.
|
||||||
|
This method is called automatically during initialization.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
print(f"Loading model: '{self.model_name}'...")
|
||||||
|
self.model = SentenceTransformer(self.model_name)
|
||||||
|
print("Model loaded successfully.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading model: {e}")
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
def embed(self, text: Union[str, List[str]]):
|
||||||
|
"""
|
||||||
|
Generates vector embeddings for a given string or list of strings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (Union[str, List[str]]): A single string or a list of strings to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of vector embeddings. Each embedding is a list of floats.
|
||||||
|
Returns None if the model is not loaded.
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
print("Model is not loaded. Cannot perform inference.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"Embedding text...")
|
||||||
|
# The model's encode function handles both single strings and lists of strings.
|
||||||
|
embeddings = self.model.encode(text, convert_to_numpy=False)
|
||||||
|
# We convert to a list of lists for easier use with pandas.
|
||||||
|
if isinstance(text, str):
|
||||||
|
return embeddings.tolist()
|
||||||
|
return [emb.tolist() for emb in embeddings]
|
||||||
|
|
||||||
|
|
||||||
|
def embed_dataframe_column(self, df: pd.DataFrame, column_name: str) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Embeds the text in a specified DataFrame column and adds the embeddings
|
||||||
|
as a new column to the DataFrame.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df (pd.DataFrame): The pandas DataFrame to process.
|
||||||
|
column_name (str): The name of the column containing the text to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame: The original DataFrame with a new column containing the embeddings.
|
||||||
|
Returns the original DataFrame unmodified if an error occurs.
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
print("Model is not loaded. Cannot process DataFrame.")
|
||||||
|
return df
|
||||||
|
|
||||||
|
if column_name not in df.columns:
|
||||||
|
print(f"Error: Column '{column_name}' not found in the DataFrame.")
|
||||||
|
return df
|
||||||
|
|
||||||
|
# Ensure the column is of string type and handle potential missing values (NaN)
|
||||||
|
# by filling them with an empty string.
|
||||||
|
text_to_embed = df[column_name].astype(str).fillna('').tolist()
|
||||||
|
|
||||||
|
# Generate embeddings for the entire column's text
|
||||||
|
embeddings = self.embed(text_to_embed)
|
||||||
|
|
||||||
|
if embeddings:
|
||||||
|
# Add the embeddings as a new column
|
||||||
|
new_column_name = f'{column_name}_embedding'
|
||||||
|
df[new_column_name] = embeddings
|
||||||
|
print(f"Successfully added '{new_column_name}' to the DataFrame.")
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
# --- Example Usage ---
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 1. Initialize the embedder. This will automatically load the model.
|
||||||
|
embedder = TextEmbedder(model_name='all-MiniLM-L6-v2')
|
||||||
|
|
||||||
|
# 2. Embed a single string
|
||||||
|
print("\n--- Embedding a single string ---")
|
||||||
|
single_string = "This is a simple test sentence."
|
||||||
|
vector = embedder.embed(single_string)
|
||||||
|
if vector:
|
||||||
|
print(f"Original string: '{single_string}'")
|
||||||
|
# Print the first 5 dimensions of the vector for brevity
|
||||||
|
print(f"Resulting vector (first 5 dims): {vector[:5]}")
|
||||||
|
print(f"Vector dimension: {len(vector)}")
|
||||||
|
|
||||||
|
# 3. Embed a list of strings
|
||||||
|
print("\n--- Embedding a list of strings ---")
|
||||||
|
list_of_strings = ["The quick brown fox jumps over the lazy dog.", "Hello, world!"]
|
||||||
|
vectors = embedder.embed(list_of_strings)
|
||||||
|
if vectors:
|
||||||
|
for i, text in enumerate(list_of_strings):
|
||||||
|
print(f"Original string: '{text}'")
|
||||||
|
print(f"Resulting vector (first 5 dims): {vectors[i][:5]}")
|
||||||
|
print(f"Vector dimension: {len(vectors[i])}\n")
|
||||||
|
|
||||||
|
|
||||||
|
# 4. Embed a pandas DataFrame column
|
||||||
|
print("\n--- Embedding a DataFrame column ---")
|
||||||
|
# Create a sample DataFrame
|
||||||
|
data = {'product_id': [1, 2, 3],
|
||||||
|
'description': ['A comfortable cotton t-shirt.', 'High-quality noise-cancelling headphones.', 'A book about the history of computing.']}
|
||||||
|
my_df = pd.DataFrame(data)
|
||||||
|
|
||||||
|
print("Original DataFrame:")
|
||||||
|
print(my_df)
|
||||||
|
|
||||||
|
# Embed the 'description' column
|
||||||
|
df_with_embeddings = embedder.embed_dataframe_column(my_df, 'description')
|
||||||
|
|
||||||
|
print("\nDataFrame with embeddings:")
|
||||||
|
# Using .to_string() to ensure the full content is displayed
|
||||||
|
print(df_with_embeddings.to_string())
|
||||||
|
|
||||||
111
scripts/embedder.py
Normal file
111
scripts/embedder.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# batch_embedder.py
|
||||||
|
# Description: A script to process all CSV files in a directory,
|
||||||
|
# add text embeddings to a specified column, and
|
||||||
|
# save the results back to the original files.
|
||||||
|
#
|
||||||
|
# This script assumes the TextEmbedder class is in a file named `main.py`
|
||||||
|
# in the same directory.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
from embed_class import TextEmbedder # Importing the class from main.py
|
||||||
|
|
||||||
|
def create_sample_files(directory: str):
|
||||||
|
"""Creates a few sample CSV files for demonstration purposes."""
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
print(f"Creating sample directory: '{directory}'")
|
||||||
|
os.makedirs(directory)
|
||||||
|
|
||||||
|
# Sample file 1: Product descriptions
|
||||||
|
df1_data = {'product_name': ['Smart Watch', 'Wireless Mouse', 'Keyboard'],
|
||||||
|
'description': ['A watch that tracks fitness and notifications.', 'Ergonomic mouse with long battery life.', 'Mechanical keyboard with RGB lighting.']}
|
||||||
|
df1 = pd.DataFrame(df1_data)
|
||||||
|
df1.to_csv(os.path.join(directory, 'products.csv'), index=False)
|
||||||
|
|
||||||
|
# Sample file 2: Customer reviews
|
||||||
|
df2_data = {'review_id': [101, 102, 103],
|
||||||
|
'comment_text': ['The product exceeded my expectations!', 'It arrived late and was the wrong color.', 'I would definitely recommend this to a friend.']}
|
||||||
|
df2 = pd.DataFrame(df2_data)
|
||||||
|
df2.to_csv(os.path.join(directory, 'reviews.csv'), index=False)
|
||||||
|
|
||||||
|
print(f"Created sample files in '{directory}'.")
|
||||||
|
|
||||||
|
|
||||||
|
def process_csvs_in_directory(directory_path: str, model_name: str = 'all-MiniLM-L6-v2'):
|
||||||
|
"""
|
||||||
|
Finds all CSV files in a directory, embeds a user-specified text column,
|
||||||
|
and overwrites the original CSV with the new data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory_path (str): The path to the directory containing CSV files.
|
||||||
|
model_name (str): The sentence-transformer model to use for embedding.
|
||||||
|
"""
|
||||||
|
print(f"Starting batch processing for directory: '{directory_path}'")
|
||||||
|
|
||||||
|
# 1. Initialize the TextEmbedder
|
||||||
|
# This will load the model, which can take a moment.
|
||||||
|
try:
|
||||||
|
embedder = TextEmbedder(model_name)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to initialize TextEmbedder. Aborting. Error: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. Find all CSV files in the directory
|
||||||
|
try:
|
||||||
|
all_files = os.listdir(directory_path)
|
||||||
|
csv_files = [f for f in all_files if f.endswith('.csv')]
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: Directory not found at '{directory_path}'. Please create it and add CSV files.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not csv_files:
|
||||||
|
print("No CSV files found in the directory.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Found {len(csv_files)} CSV files to process.")
|
||||||
|
|
||||||
|
# 3. Loop through each CSV file
|
||||||
|
for filename in csv_files:
|
||||||
|
file_path = os.path.join(directory_path, filename)
|
||||||
|
print(f"\n--- Processing file: {filename} ---")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read the CSV into a DataFrame
|
||||||
|
df = pd.read_csv(file_path)
|
||||||
|
print("Available columns:", list(df.columns))
|
||||||
|
|
||||||
|
# Ask the user for the column to embed
|
||||||
|
column_to_embed = input(f"Enter the name of the column to embed for '{filename}': ")
|
||||||
|
|
||||||
|
# Check if the column exists
|
||||||
|
if column_to_embed not in df.columns:
|
||||||
|
print(f"Column '{column_to_embed}' not found. Skipping this file.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 4. Use the embedder to add the new column
|
||||||
|
df_with_embeddings = embedder.embed_dataframe_column(df, column_to_embed)
|
||||||
|
|
||||||
|
# 5. Save the modified DataFrame back to the original file
|
||||||
|
df_with_embeddings.to_csv(file_path, index=False)
|
||||||
|
print(f"Successfully processed and saved '{filename}'.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred while processing {filename}: {e}")
|
||||||
|
continue # Move to the next file
|
||||||
|
|
||||||
|
print("\nBatch processing complete.")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main Execution Block ---
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Define the directory where your CSV files are located.
|
||||||
|
# The script will look for a folder named 'csv_data' in the current directory.
|
||||||
|
CSV_DIRECTORY = 'csv_data'
|
||||||
|
|
||||||
|
# This function will create the 'csv_data' directory and some sample
|
||||||
|
# files if they don't exist. You can comment this out if you have your own files.
|
||||||
|
create_sample_files(CSV_DIRECTORY)
|
||||||
|
|
||||||
|
# Run the main processing function on the directory
|
||||||
|
process_csvs_in_directory(CSV_DIRECTORY)
|
||||||
|
|
||||||
Reference in New Issue
Block a user