# main.py
#
# Description: A simple Python class to generate text embeddings using sentence-transformers.
#
# Required libraries:
#     pip install sentence-transformers pandas torch
#
# This script defines a TextEmbedder class that can be used to:
#   1. Load a pre-trained sentence-transformer model.
#   2. Embed a single string or a list of strings into vectors.
#   3. Embed an entire text column in a pandas DataFrame and add the embeddings as a new column.

from typing import List, Optional, Union

import pandas as pd
from sentence_transformers import SentenceTransformer


class TextEmbedder:
    """
    A simple class to handle text embedding using sentence-transformers.

    Attributes:
        model_name (str): Name of the sentence-transformer model to load.
        model: The loaded ``SentenceTransformer`` instance, or ``None`` if
            loading failed. Every public method checks for ``None`` and then
            degrades gracefully (returns ``None`` / the unmodified DataFrame)
            instead of raising.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """
        Initializes the TextEmbedder and loads the specified model.

        Args:
            model_name (str): The name of the sentence-transformer model to use.
                Defaults to 'all-MiniLM-L6-v2', a small and efficient model.
        """
        self.model_name = model_name
        self.model: Optional[SentenceTransformer] = None
        self.load_model()

    def load_model(self) -> None:
        """
        Loads the sentence-transformer model named by ``self.model_name``.

        This method is called automatically during initialization. Loading is
        best-effort: any exception is printed and ``self.model`` is left as
        ``None``. The other methods can then refuse to run instead of crashing.
        """
        try:
            print(f"Loading model: '{self.model_name}'...")
            self.model = SentenceTransformer(self.model_name)
            print("Model loaded successfully.")
        except Exception as e:
            # Deliberately broad: callers rely on the `model is None` check
            # rather than on catching an exception from the constructor.
            print(f"Error loading model: {e}")
            self.model = None

    def embed(
        self, text: Union[str, List[str]]
    ) -> Optional[Union[List[float], List[List[float]]]]:
        """
        Generates vector embeddings for a given string or list of strings.

        Args:
            text (Union[str, List[str]]): A single string or a list of strings to embed.

        Returns:
            For a single string: one embedding as a flat list of floats.
            For a list of strings: a list of embeddings, one list of floats per
            input string. An empty list yields an empty list.
            Returns None if the model is not loaded.
        """
        if self.model is None:
            print("Model is not loaded. Cannot perform inference.")
            return None

        # Nothing to encode: skip the model call instead of depending on how
        # the library treats an empty batch.
        if isinstance(text, (list, tuple)) and not text:
            return []

        print("Embedding text...")
        # Ask encode() for a single NumPy array (1-D for a string, 2-D for a
        # list). One .tolist() then turns it into plain Python lists, instead
        # of converting each per-row tensor separately.
        embeddings = self.model.encode(text, convert_to_numpy=True)
        return embeddings.tolist()

    def embed_dataframe_column(self, df: pd.DataFrame, column_name: str) -> pd.DataFrame:
        """
        Embeds the text in a specified DataFrame column and adds the embeddings
        as a new column named ``f"{column_name}_embedding"``.

        The DataFrame is modified in place and also returned for convenience.
        Missing values (NaN/None) in the source column are embedded as the
        empty string. Every other value is converted with ``str()`` first.

        Args:
            df (pd.DataFrame): The pandas DataFrame to process.
            column_name (str): The name of the column containing the text to embed.

        Returns:
            pd.DataFrame: The same DataFrame object with a new column holding one
            embedding (list of floats) per row. It is returned unmodified if the
            model is not loaded or the column does not exist.
        """
        if self.model is None:
            print("Model is not loaded. Cannot process DataFrame.")
            return df

        if column_name not in df.columns:
            print(f"Error: Column '{column_name}' not found in the DataFrame.")
            return df

        # Stringify every value, then mask the originally-missing cells back
        # to ''. The previous `astype(str).fillna('')` never filled anything:
        # astype(str) had already turned NaN into the text 'nan' (and None
        # into 'None'), which were then embedded as if they were real words.
        column = df[column_name]
        text_to_embed = column.astype(str).where(column.notna(), '').tolist()

        # Generate embeddings for the entire column's text in one batch.
        embeddings = self.embed(text_to_embed)

        # Compare against None rather than truthiness, so an empty DataFrame
        # still receives its (empty) embedding column.
        if embeddings is not None:
            new_column_name = f'{column_name}_embedding'
            df[new_column_name] = embeddings
            print(f"Successfully added '{new_column_name}' to the DataFrame.")

        return df


# --- Example Usage ---
if __name__ == '__main__':
    # The constructor loads the model right away, so this step may take a moment.
    demo_embedder = TextEmbedder(model_name='all-MiniLM-L6-v2')

    # Demo 1: a single sentence yields one flat vector.
    print("\n--- Embedding a single string ---")
    sample_sentence = "This is a simple test sentence."
    sample_vector = demo_embedder.embed(sample_sentence)
    if sample_vector:
        print(f"Original string: '{sample_sentence}'")
        # Show only a short prefix of the vector to keep the output readable.
        print(f"Resulting vector (first 5 dims): {sample_vector[:5]}")
        print(f"Vector dimension: {len(sample_vector)}")

    # Demo 2: a batch of sentences yields one vector per sentence.
    print("\n--- Embedding a list of strings ---")
    sample_batch = ["The quick brown fox jumps over the lazy dog.", "Hello, world!"]
    batch_vectors = demo_embedder.embed(sample_batch)
    if batch_vectors:
        for sentence, sentence_vector in zip(sample_batch, batch_vectors):
            print(f"Original string: '{sentence}'")
            print(f"Resulting vector (first 5 dims): {sentence_vector[:5]}")
            print(f"Vector dimension: {len(sentence_vector)}\n")

    # Demo 3: embed a text column of a small product table.
    print("\n--- Embedding a DataFrame column ---")
    products = pd.DataFrame(
        {
            'product_id': [1, 2, 3],
            'description': [
                'A comfortable cotton t-shirt.',
                'High-quality noise-cancelling headphones.',
                'A book about the history of computing.',
            ],
        }
    )

    print("Original DataFrame:")
    print(products)

    enriched = demo_embedder.embed_dataframe_column(products, 'description')

    print("\nDataFrame with embeddings:")
    # to_string() renders every row and column instead of pandas' truncated repr.
    print(enriched.to_string())