text embedding script and class

This commit is contained in:
2025-08-11 01:47:52 +01:00
parent 7ca86d7751
commit fb3fb70cc5
2 changed files with 258 additions and 0 deletions

147
scripts/embed_class.py Normal file
View File

@@ -0,0 +1,147 @@
# main.py
# Description: A simple Python class to generate text embeddings using sentence-transformers.
#
# Required libraries:
# pip install sentence-transformers pandas torch
#
# This script defines a TextEmbedder class that can be used to:
# 1. Load a pre-trained sentence-transformer model.
# 2. Embed a single string or a list of strings into vectors.
# 3. Embed an entire text column in a pandas DataFrame and add the embeddings as a new column.
import pandas as pd
from sentence_transformers import SentenceTransformer
from typing import List, Union
class TextEmbedder:
"""
A simple class to handle text embedding using sentence-transformers.
"""
def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
"""
Initializes the TextEmbedder and loads the specified model.
Args:
model_name (str): The name of the sentence-transformer model to use.
Defaults to 'all-MiniLM-L6-v2', a small and efficient model.
"""
self.model_name = model_name
self.model = None
self.load_model()
def load_model(self):
"""
Loads the sentence-transformer model from Hugging Face.
This method is called automatically during initialization.
"""
try:
print(f"Loading model: '{self.model_name}'...")
self.model = SentenceTransformer(self.model_name)
print("Model loaded successfully.")
except Exception as e:
print(f"Error loading model: {e}")
self.model = None
def embed(self, text: Union[str, List[str]]):
"""
Generates vector embeddings for a given string or list of strings.
Args:
text (Union[str, List[str]]): A single string or a list of strings to embed.
Returns:
A list of vector embeddings. Each embedding is a list of floats.
Returns None if the model is not loaded.
"""
if self.model is None:
print("Model is not loaded. Cannot perform inference.")
return None
print(f"Embedding text...")
# The model's encode function handles both single strings and lists of strings.
embeddings = self.model.encode(text, convert_to_numpy=False)
# We convert to a list of lists for easier use with pandas.
if isinstance(text, str):
return embeddings.tolist()
return [emb.tolist() for emb in embeddings]
def embed_dataframe_column(self, df: pd.DataFrame, column_name: str) -> pd.DataFrame:
"""
Embeds the text in a specified DataFrame column and adds the embeddings
as a new column to the DataFrame.
Args:
df (pd.DataFrame): The pandas DataFrame to process.
column_name (str): The name of the column containing the text to embed.
Returns:
pd.DataFrame: The original DataFrame with a new column containing the embeddings.
Returns the original DataFrame unmodified if an error occurs.
"""
if self.model is None:
print("Model is not loaded. Cannot process DataFrame.")
return df
if column_name not in df.columns:
print(f"Error: Column '{column_name}' not found in the DataFrame.")
return df
# Ensure the column is of string type and handle potential missing values (NaN)
# by filling them with an empty string.
text_to_embed = df[column_name].astype(str).fillna('').tolist()
# Generate embeddings for the entire column's text
embeddings = self.embed(text_to_embed)
if embeddings:
# Add the embeddings as a new column
new_column_name = f'{column_name}_embedding'
df[new_column_name] = embeddings
print(f"Successfully added '{new_column_name}' to the DataFrame.")
return df
# --- Example Usage ---
if __name__ == '__main__':
# 1. Initialize the embedder. This will automatically load the model.
embedder = TextEmbedder(model_name='all-MiniLM-L6-v2')
# 2. Embed a single string
print("\n--- Embedding a single string ---")
single_string = "This is a simple test sentence."
vector = embedder.embed(single_string)
if vector:
print(f"Original string: '{single_string}'")
# Print the first 5 dimensions of the vector for brevity
print(f"Resulting vector (first 5 dims): {vector[:5]}")
print(f"Vector dimension: {len(vector)}")
# 3. Embed a list of strings
print("\n--- Embedding a list of strings ---")
list_of_strings = ["The quick brown fox jumps over the lazy dog.", "Hello, world!"]
vectors = embedder.embed(list_of_strings)
if vectors:
for i, text in enumerate(list_of_strings):
print(f"Original string: '{text}'")
print(f"Resulting vector (first 5 dims): {vectors[i][:5]}")
print(f"Vector dimension: {len(vectors[i])}\n")
# 4. Embed a pandas DataFrame column
print("\n--- Embedding a DataFrame column ---")
# Create a sample DataFrame
data = {'product_id': [1, 2, 3],
'description': ['A comfortable cotton t-shirt.', 'High-quality noise-cancelling headphones.', 'A book about the history of computing.']}
my_df = pd.DataFrame(data)
print("Original DataFrame:")
print(my_df)
# Embed the 'description' column
df_with_embeddings = embedder.embed_dataframe_column(my_df, 'description')
print("\nDataFrame with embeddings:")
# Using .to_string() to ensure the full content is displayed
print(df_with_embeddings.to_string())