"""Async embedding service for semantic similarity in MCP tools.
This module provides high-performance async embedding functionality with
caching, batch processing, and memory optimization.
"""
import hashlib
import logging
import math
import os
from typing import Any
import openai
from dotenv import load_dotenv
from .ontology import MCPPrompt, MCPTool
from .performance import (
AsyncBatchProcessor,
EmbeddingCache,
async_cached,
main_cache,
performance_monitor,
performance_monitor_instance,
)
# Load environment variables
load_dotenv()
logger = logging.getLogger(__name__)
class AsyncEmbeddingService:
"""High-performance async embedding service with caching and optimization."""
def __init__(self, embedding_dim: int = 128, batch_size: int = 10):
"""Initialize the async embedding service.
Args:
embedding_dim: Dimension of the embedding vectors
batch_size: Number of embeddings to process in a batch
"""
self.embedding_dim = embedding_dim
self.batch_size = batch_size
# Initialize caching
self.embedding_cache = EmbeddingCache(max_size=2000, compression_precision=6)
self.batch_processor = AsyncBatchProcessor(batch_size=batch_size, max_concurrent=5)
# Initialize OpenAI client
api_key = os.getenv("OPENAI_API_KEY")
if api_key:
self.openai_client: openai.AsyncOpenAI | None = openai.AsyncOpenAI(api_key=api_key)
self.model_name = "text-embedding-3-small"
else:
self.openai_client = None
self.model_name = "mock"
logger.warning("OPENAI_API_KEY not found. Using mock embeddings.")
# Performance tracking
self.embedding_requests = 0
self.cache_hits = 0
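        # Simple dict cache keyed by md5("{model}:{text}"), used by get_embedding();
        # the batch path relies on self.embedding_cache instead.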
        self._cache: dict[str, list[float]] = {}
@performance_monitor(performance_monitor_instance)
async def get_embedding(self, text: str, use_cache: bool = True) -> list[float] | None:
"""Generate an embedding vector for the given text using OpenAI API.
Args:
text: The text to embed
use_cache: Whether to use caching
Returns:
            A list of floats representing the embedding vector; falls back to a
            deterministic mock embedding if the API is unavailable or the call fails
"""
self.embedding_requests += 1
# Try cache first
if use_cache:
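            # md5 serves only as a cheap, stable cache key here, not as a security hash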
cache_key = hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()
if cache_key in self._cache:
self.cache_hits += 1
return self._cache[cache_key]
if not self.openai_client:
logger.debug("OpenAI client not available, using mock embedding")
embedding = self._create_mock_embedding(text)
# Cache the mock embedding
if use_cache:
cache_key = hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()
self._cache[cache_key] = embedding
return embedding
try:
# Preprocess text
cleaned_text = text.replace("\n", " ").strip()
# Limit text length to avoid API limits
if len(cleaned_text) > 8000:
cleaned_text = cleaned_text[:8000]
# Make async API call to OpenAI
response = await self.openai_client.embeddings.create(
input=cleaned_text,
model=self.model_name
)
# Extract embedding from response
embedding = response.data[0].embedding
# Cache the result
if use_cache:
cache_key = hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()
self._cache[cache_key] = embedding
return embedding
except openai.APIError as e:
logger.error(f"OpenAI API error when generating embedding: {e}")
return self._create_mock_embedding(text)
except Exception as e:
logger.error(f"Unexpected error when generating embedding: {e}")
return self._create_mock_embedding(text)
async def get_embeddings_batch(self, texts: list[str], use_cache: bool = True) -> list[list[float] | None]:
"""Generate embeddings for multiple texts efficiently.
Args:
texts: List of texts to embed
use_cache: Whether to use caching
Returns:
List of embedding vectors
"""
if not texts:
return []
# Check cache for all texts first
embeddings = []
uncached_indices = []
uncached_texts = []
if use_cache:
for i, text in enumerate(texts):
cached = await self.embedding_cache.get_embedding(text, self.model_name)
if cached is not None:
embeddings.append(cached)
self.cache_hits += 1
else:
embeddings.append(None)
uncached_indices.append(i)
uncached_texts.append(text)
else:
uncached_indices = list(range(len(texts)))
uncached_texts = texts
embeddings = [None] * len(texts)
# Process uncached texts in batches
if uncached_texts:
if self.openai_client:
try:
# Process in batches for API efficiency
batch_embeddings = await self._process_embedding_batches(uncached_texts)
# Fill in the results and cache them
for idx, embedding in zip(uncached_indices, batch_embeddings, strict=False):
embeddings[idx] = embedding
if use_cache and embedding is not None:
await self.embedding_cache.set_embedding(
texts[idx], embedding, self.model_name
)
except Exception as e:
logger.error(f"Batch embedding failed: {e}")
# Fallback to mock embeddings
for idx, text in zip(uncached_indices, uncached_texts, strict=False):
embeddings[idx] = self._create_mock_embedding(text)
else:
# Use mock embeddings
for idx, text in zip(uncached_indices, uncached_texts, strict=False):
embeddings[idx] = self._create_mock_embedding(text)
self.embedding_requests += len(texts)
return embeddings
async def _process_embedding_batches(self, texts: list[str]) -> list[list[float] | None]:
"""Process texts in batches for API efficiency."""
all_embeddings = []
        # Process in conservative chunks; the embeddings API also caps how many
        # inputs a single request may contain
        api_batch_size = min(100, len(texts))
for i in range(0, len(texts), api_batch_size):
batch_texts = texts[i:i + api_batch_size]
try:
# Clean texts
cleaned_texts = [
text.replace("\n", " ").strip()[:8000]
for text in batch_texts
]
# Make batch API call
response = await self.openai_client.embeddings.create(
input=cleaned_texts,
model=self.model_name
)
# Extract embeddings
batch_embeddings = [data.embedding for data in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
logger.error(f"Batch API call failed: {e}")
# Fallback to mock for this batch
mock_embeddings = [self._create_mock_embedding(text) for text in batch_texts]
all_embeddings.extend(mock_embeddings)
return all_embeddings
def _create_mock_embedding(self, text: str) -> list[float]:
"""Create a deterministic mock embedding."""
# Create a deterministic hash-based embedding
text_hash = hashlib.sha256(text.encode()).hexdigest()
        # Convert consecutive hex byte pairs into values scaled to [0, 1]
embedding = []
for i in range(0, min(len(text_hash), self.embedding_dim * 2), 2):
hex_pair = text_hash[i : i + 2]
value = int(hex_pair, 16) / 255.0
embedding.append(value)
# Pad or truncate to desired dimension
while len(embedding) < self.embedding_dim:
embedding.append(0.0)
return embedding[:self.embedding_dim]
@async_cached(main_cache)
async def embed_tool_description(self, tool: MCPTool) -> list[float] | None:
"""Generate an embedding for a tool's description.
Args:
tool: The MCPTool to embed
Returns:
Embedding vector for the tool's description
"""
# Combine name, description, and tags for richer embedding
combined_text = f"{tool.name} {tool.description}"
if tool.tags:
combined_text += f" {' '.join(tool.tags)}"
return await self.get_embedding(combined_text)
@async_cached(main_cache)
async def embed_prompt_description(self, prompt: MCPPrompt) -> list[float] | None:
"""Generate an embedding for a prompt's description.
Args:
prompt: The MCPPrompt to embed
Returns:
Embedding vector for the prompt's description
"""
# Combine name, description, and template for richer embedding
combined_text = f"{prompt.name} {prompt.description}"
if prompt.template_string:
# Include template but limit length
template_preview = prompt.template_string[:200]
combined_text += f" {template_preview}"
return await self.get_embedding(combined_text)
async def embed_tools_batch(self, tools: list[MCPTool]) -> list[list[float] | None]:
"""Generate embeddings for multiple tools efficiently."""
tool_texts = []
for tool in tools:
combined_text = f"{tool.name} {tool.description}"
if tool.tags:
combined_text += f" {' '.join(tool.tags)}"
tool_texts.append(combined_text)
return await self.get_embeddings_batch(tool_texts)
async def embed_prompts_batch(self, prompts: list[MCPPrompt]) -> list[list[float] | None]:
"""Generate embeddings for multiple prompts efficiently."""
prompt_texts = []
for prompt in prompts:
combined_text = f"{prompt.name} {prompt.description}"
if prompt.template_string:
template_preview = prompt.template_string[:200]
combined_text += f" {template_preview}"
prompt_texts.append(combined_text)
return await self.get_embeddings_batch(prompt_texts)
def compute_similarity(
self, embedding1: list[float], embedding2: list[float]
) -> float:
"""Compute cosine similarity between two embeddings.
Args:
embedding1: First embedding vector
embedding2: Second embedding vector
Returns:
Similarity score between 0 and 1
"""
if not embedding1 or not embedding2:
return 0.0
# Ensure same length
min_len = min(len(embedding1), len(embedding2))
vec1 = embedding1[:min_len]
vec2 = embedding2[:min_len]
# Compute dot product
dot_product = sum(a * b for a, b in zip(vec1, vec2, strict=False))
# Compute magnitudes
magnitude1 = math.sqrt(sum(a * a for a in vec1))
magnitude2 = math.sqrt(sum(b * b for b in vec2))
# Avoid division by zero
if magnitude1 == 0 or magnitude2 == 0:
return 0.0
        # Cosine similarity lies in [-1, 1]; rescale to [0, 1] via (cos + 1) / 2
cosine_sim = dot_product / (magnitude1 * magnitude2)
return max(0.0, min(1.0, (cosine_sim + 1) / 2))
async def find_similar_tools(
self, query: str, tools: list[MCPTool], top_k: int = 5
) -> list[tuple[MCPTool, float]]:
"""Find tools most similar to the given query.
Args:
query: The search query text
tools: List of tools to search through
top_k: Maximum number of tools to return
Returns:
List of (tool, similarity_score) tuples, sorted by similarity
"""
if not tools:
return []
# Get query embedding
query_embedding = await self.get_embedding(query)
if query_embedding is None:
return [(tool, 0.0) for tool in tools[:top_k]]
# Get tool embeddings in batch
tool_embeddings = await self.embed_tools_batch(tools)
# Compute similarities
tool_similarities = []
for tool, tool_embedding in zip(tools, tool_embeddings, strict=False):
if tool_embedding is not None:
similarity = self.compute_similarity(query_embedding, tool_embedding)
tool_similarities.append((tool, similarity))
else:
tool_similarities.append((tool, 0.0))
# Sort by similarity (descending) and return top_k
tool_similarities.sort(key=lambda x: x[1], reverse=True)
return tool_similarities[:top_k]
async def find_similar_prompts(
self, query: str, prompts: list[MCPPrompt], top_k: int = 5
) -> list[tuple[MCPPrompt, float]]:
"""Find prompts most similar to the given query.
Args:
query: The search query text
prompts: List of prompts to search through
top_k: Maximum number of prompts to return
Returns:
List of (prompt, similarity_score) tuples, sorted by similarity
"""
if not prompts:
return []
# Get query embedding
query_embedding = await self.get_embedding(query)
if query_embedding is None:
return [(prompt, 0.0) for prompt in prompts[:top_k]]
# Get prompt embeddings in batch
prompt_embeddings = await self.embed_prompts_batch(prompts)
# Compute similarities
prompt_similarities = []
for prompt, prompt_embedding in zip(prompts, prompt_embeddings, strict=False):
if prompt_embedding is not None:
similarity = self.compute_similarity(query_embedding, prompt_embedding)
prompt_similarities.append((prompt, similarity))
else:
prompt_similarities.append((prompt, 0.0))
# Sort by similarity (descending) and return top_k
prompt_similarities.sort(key=lambda x: x[1], reverse=True)
return prompt_similarities[:top_k]
def get_performance_stats(self) -> dict[str, Any]:
"""Get embedding service performance statistics."""
cache_hit_rate = (
self.cache_hits / self.embedding_requests
if self.embedding_requests > 0 else 0
)
return {
"total_embedding_requests": self.embedding_requests,
"cache_hits": self.cache_hits,
"cache_hit_rate": cache_hit_rate,
"cache_size": len(self._cache),
"model_name": self.model_name,
"embedding_dim": self.embedding_dim,
"batch_size": self.batch_size,
"openai_available": self.openai_client is not None
}
async def warm_up_cache(self, texts: list[str]) -> None:
"""Pre-populate cache with common texts."""
logger.info(f"Warming up embedding cache with {len(texts)} texts...")
await self.get_embeddings_batch(texts, use_cache=True)
logger.info("Cache warm-up completed")
    async def clear_cache(self) -> None:
        """Clear all caches."""
        self._cache.clear()
        await self.embedding_cache.cache.clear()
        logger.info("Embedding caches cleared")