"""Async embedding service for semantic similarity in MCP tools. This module provides high-performance async embedding functionality with caching, batch processing, and memory optimization. """ import hashlib import logging import math import os from typing import Any import openai from dotenv import load_dotenv from .ontology import MCPPrompt, MCPTool from .performance import ( AsyncBatchProcessor, EmbeddingCache, async_cached, main_cache, performance_monitor, performance_monitor_instance, ) # Load environment variables load_dotenv() logger = logging.getLogger(__name__) class AsyncEmbeddingService: """High-performance async embedding service with caching and optimization.""" def __init__(self, embedding_dim: int = 128, batch_size: int = 10): """Initialize the async embedding service. Args: embedding_dim: Dimension of the embedding vectors batch_size: Number of embeddings to process in a batch """ self.embedding_dim = embedding_dim self.batch_size = batch_size # Initialize caching self.embedding_cache = EmbeddingCache(max_size=2000, compression_precision=6) self.batch_processor = AsyncBatchProcessor(batch_size=batch_size, max_concurrent=5) # Initialize OpenAI client api_key = os.getenv("OPENAI_API_KEY") if api_key: self.openai_client: openai.AsyncOpenAI | None = openai.AsyncOpenAI(api_key=api_key) self.model_name = "text-embedding-3-small" else: self.openai_client = None self.model_name = "mock" logger.warning("OPENAI_API_KEY not found. Using mock embeddings.") # Performance tracking self.embedding_requests = 0 self.cache_hits = 0 self._cache = {} @performance_monitor(performance_monitor_instance) async def get_embedding(self, text: str, use_cache: bool = True) -> list[float] | None: """Generate an embedding vector for the given text using OpenAI API. Args: text: The text to embed use_cache: Whether to use caching Returns: A list of floats representing the embedding vector, or None if API call fails """ self.embedding_requests += 1 # Try cache first if use_cache: cache_key = hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest() if cache_key in self._cache: self.cache_hits += 1 return self._cache[cache_key] if not self.openai_client: logger.debug("OpenAI client not available, using mock embedding") embedding = self._create_mock_embedding(text) # Cache the mock embedding if use_cache: cache_key = hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest() self._cache[cache_key] = embedding return embedding try: # Preprocess text cleaned_text = text.replace("\n", " ").strip() # Limit text length to avoid API limits if len(cleaned_text) > 8000: cleaned_text = cleaned_text[:8000] # Make async API call to OpenAI response = await self.openai_client.embeddings.create( input=cleaned_text, model=self.model_name ) # Extract embedding from response embedding = response.data[0].embedding # Cache the result if use_cache: cache_key = hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest() self._cache[cache_key] = embedding return embedding except openai.APIError as e: logger.error(f"OpenAI API error when generating embedding: {e}") return self._create_mock_embedding(text) except Exception as e: logger.error(f"Unexpected error when generating embedding: {e}") return self._create_mock_embedding(text) async def get_embeddings_batch(self, texts: list[str], use_cache: bool = True) -> list[list[float] | None]: """Generate embeddings for multiple texts efficiently. 
    async def get_embeddings_batch(
        self, texts: list[str], use_cache: bool = True
    ) -> list[list[float] | None]:
        """Generate embeddings for multiple texts efficiently.

        Args:
            texts: List of texts to embed
            use_cache: Whether to use caching

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        # Check cache for all texts first
        embeddings = []
        uncached_indices = []
        uncached_texts = []

        if use_cache:
            for i, text in enumerate(texts):
                cached = await self.embedding_cache.get_embedding(text, self.model_name)
                if cached is not None:
                    embeddings.append(cached)
                    self.cache_hits += 1
                else:
                    embeddings.append(None)
                    uncached_indices.append(i)
                    uncached_texts.append(text)
        else:
            uncached_indices = list(range(len(texts)))
            uncached_texts = texts
            embeddings = [None] * len(texts)

        # Process uncached texts in batches
        if uncached_texts:
            if self.openai_client:
                try:
                    # Process in batches for API efficiency
                    batch_embeddings = await self._process_embedding_batches(uncached_texts)

                    # Fill in the results and cache them
                    for idx, embedding in zip(uncached_indices, batch_embeddings, strict=False):
                        embeddings[idx] = embedding
                        if use_cache and embedding is not None:
                            await self.embedding_cache.set_embedding(
                                texts[idx], embedding, self.model_name
                            )
                except Exception as e:
                    logger.error(f"Batch embedding failed: {e}")
                    # Fallback to mock embeddings
                    for idx, text in zip(uncached_indices, uncached_texts, strict=False):
                        embeddings[idx] = self._create_mock_embedding(text)
            else:
                # Use mock embeddings
                for idx, text in zip(uncached_indices, uncached_texts, strict=False):
                    embeddings[idx] = self._create_mock_embedding(text)

        self.embedding_requests += len(texts)
        return embeddings

    async def _process_embedding_batches(self, texts: list[str]) -> list[list[float] | None]:
        """Process texts in batches for API efficiency."""
        all_embeddings = []

        # Process in chunks of the API batch size
        api_batch_size = min(100, len(texts))  # OpenAI limit

        for i in range(0, len(texts), api_batch_size):
            batch_texts = texts[i : i + api_batch_size]

            try:
                # Clean texts
                cleaned_texts = [
                    text.replace("\n", " ").strip()[:8000] for text in batch_texts
                ]

                # Make batch API call
                response = await self.openai_client.embeddings.create(
                    input=cleaned_texts, model=self.model_name
                )

                # Extract embeddings
                batch_embeddings = [data.embedding for data in response.data]
                all_embeddings.extend(batch_embeddings)

            except Exception as e:
                logger.error(f"Batch API call failed: {e}")
                # Fallback to mock for this batch
                mock_embeddings = [self._create_mock_embedding(text) for text in batch_texts]
                all_embeddings.extend(mock_embeddings)

        return all_embeddings

    def _create_mock_embedding(self, text: str) -> list[float]:
        """Create a deterministic mock embedding."""
        # Create a deterministic hash-based embedding
        text_hash = hashlib.sha256(text.encode()).hexdigest()

        # Convert hash to numbers and normalize
        embedding = []
        for i in range(0, min(len(text_hash), self.embedding_dim * 2), 2):
            hex_pair = text_hash[i : i + 2]
            value = int(hex_pair, 16) / 255.0
            embedding.append(value)

        # Pad or truncate to the desired dimension (a SHA-256 hex digest yields
        # at most 32 values, so higher dimensions are zero-padded)
        while len(embedding) < self.embedding_dim:
            embedding.append(0.0)

        return embedding[: self.embedding_dim]
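    # Illustrative sketch of the mock path (hand-derived, worth verifying):
    # hashlib.sha256(b"hello").hexdigest() begins "2cf2...", so the first two
    # mock components would be int("2c", 16) / 255.0 ≈ 0.173 and
    # int("f2", 16) / 255.0 ≈ 0.949. With the default embedding_dim of 128,
    # the 32 hash-derived values are followed by 96 zeros.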
    @async_cached(main_cache)
    async def embed_tool_description(self, tool: MCPTool) -> list[float] | None:
        """Generate an embedding for a tool's description.

        Args:
            tool: The MCPTool to embed

        Returns:
            Embedding vector for the tool's description
        """
        # Combine name, description, and tags for a richer embedding
        combined_text = f"{tool.name} {tool.description}"
        if tool.tags:
            combined_text += f" {' '.join(tool.tags)}"

        return await self.get_embedding(combined_text)

    @async_cached(main_cache)
    async def embed_prompt_description(self, prompt: MCPPrompt) -> list[float] | None:
        """Generate an embedding for a prompt's description.

        Args:
            prompt: The MCPPrompt to embed

        Returns:
            Embedding vector for the prompt's description
        """
        # Combine name, description, and template for a richer embedding
        combined_text = f"{prompt.name} {prompt.description}"
        if prompt.template_string:
            # Include the template but limit its length
            template_preview = prompt.template_string[:200]
            combined_text += f" {template_preview}"

        return await self.get_embedding(combined_text)

    async def embed_tools_batch(self, tools: list[MCPTool]) -> list[list[float] | None]:
        """Generate embeddings for multiple tools efficiently."""
        tool_texts = []
        for tool in tools:
            combined_text = f"{tool.name} {tool.description}"
            if tool.tags:
                combined_text += f" {' '.join(tool.tags)}"
            tool_texts.append(combined_text)

        return await self.get_embeddings_batch(tool_texts)

    async def embed_prompts_batch(self, prompts: list[MCPPrompt]) -> list[list[float] | None]:
        """Generate embeddings for multiple prompts efficiently."""
        prompt_texts = []
        for prompt in prompts:
            combined_text = f"{prompt.name} {prompt.description}"
            if prompt.template_string:
                template_preview = prompt.template_string[:200]
                combined_text += f" {template_preview}"
            prompt_texts.append(combined_text)

        return await self.get_embeddings_batch(prompt_texts)

    def compute_similarity(
        self, embedding1: list[float], embedding2: list[float]
    ) -> float:
        """Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding vector
            embedding2: Second embedding vector

        Returns:
            Similarity score between 0 and 1
        """
        if not embedding1 or not embedding2:
            return 0.0

        # Ensure same length
        min_len = min(len(embedding1), len(embedding2))
        vec1 = embedding1[:min_len]
        vec2 = embedding2[:min_len]

        # Compute dot product
        dot_product = sum(a * b for a, b in zip(vec1, vec2, strict=False))

        # Compute magnitudes
        magnitude1 = math.sqrt(sum(a * a for a in vec1))
        magnitude2 = math.sqrt(sum(b * b for b in vec2))

        # Avoid division by zero
        if magnitude1 == 0 or magnitude2 == 0:
            return 0.0

        # Cosine similarity, rescaled from [-1, 1] to [0, 1]
        cosine_sim = dot_product / (magnitude1 * magnitude2)
        return max(0.0, min(1.0, (cosine_sim + 1) / 2))
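    # Worked example of the rescaling above: for vec1 = [1.0, 0.0] and
    # vec2 = [0.0, 1.0] the dot product is 0 and both magnitudes are 1, so
    # the raw cosine similarity is 0.0 and the rescaled score is
    # (0.0 + 1) / 2 = 0.5. Identical vectors score 1.0; opposite vectors 0.0.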
    async def find_similar_tools(
        self, query: str, tools: list[MCPTool], top_k: int = 5
    ) -> list[tuple[MCPTool, float]]:
        """Find the tools most similar to the given query.

        Args:
            query: The search query text
            tools: List of tools to search through
            top_k: Maximum number of tools to return

        Returns:
            List of (tool, similarity_score) tuples, sorted by similarity
        """
        if not tools:
            return []

        # Get query embedding
        query_embedding = await self.get_embedding(query)
        if query_embedding is None:
            return [(tool, 0.0) for tool in tools[:top_k]]

        # Get tool embeddings in batch
        tool_embeddings = await self.embed_tools_batch(tools)

        # Compute similarities
        tool_similarities = []
        for tool, tool_embedding in zip(tools, tool_embeddings, strict=False):
            if tool_embedding is not None:
                similarity = self.compute_similarity(query_embedding, tool_embedding)
                tool_similarities.append((tool, similarity))
            else:
                tool_similarities.append((tool, 0.0))

        # Sort by similarity (descending) and return the top_k
        tool_similarities.sort(key=lambda x: x[1], reverse=True)
        return tool_similarities[:top_k]

    async def find_similar_prompts(
        self, query: str, prompts: list[MCPPrompt], top_k: int = 5
    ) -> list[tuple[MCPPrompt, float]]:
        """Find the prompts most similar to the given query.

        Args:
            query: The search query text
            prompts: List of prompts to search through
            top_k: Maximum number of prompts to return

        Returns:
            List of (prompt, similarity_score) tuples, sorted by similarity
        """
        if not prompts:
            return []

        # Get query embedding
        query_embedding = await self.get_embedding(query)
        if query_embedding is None:
            return [(prompt, 0.0) for prompt in prompts[:top_k]]

        # Get prompt embeddings in batch
        prompt_embeddings = await self.embed_prompts_batch(prompts)

        # Compute similarities
        prompt_similarities = []
        for prompt, prompt_embedding in zip(prompts, prompt_embeddings, strict=False):
            if prompt_embedding is not None:
                similarity = self.compute_similarity(query_embedding, prompt_embedding)
                prompt_similarities.append((prompt, similarity))
            else:
                prompt_similarities.append((prompt, 0.0))

        # Sort by similarity (descending) and return the top_k
        prompt_similarities.sort(key=lambda x: x[1], reverse=True)
        return prompt_similarities[:top_k]

    def get_performance_stats(self) -> dict[str, Any]:
        """Get embedding service performance statistics."""
        cache_hit_rate = (
            self.cache_hits / self.embedding_requests
            if self.embedding_requests > 0
            else 0
        )

        return {
            "total_embedding_requests": self.embedding_requests,
            "cache_hits": self.cache_hits,
            "cache_hit_rate": cache_hit_rate,
            "cache_size": len(self._cache),
            "model_name": self.model_name,
            "embedding_dim": self.embedding_dim,
            "batch_size": self.batch_size,
            "openai_available": self.openai_client is not None,
        }

    async def warm_up_cache(self, texts: list[str]) -> None:
        """Pre-populate the cache with common texts."""
        logger.info(f"Warming up embedding cache with {len(texts)} texts...")
        await self.get_embeddings_batch(texts, use_cache=True)
        logger.info("Cache warm-up completed")

    async def clear_cache(self) -> None:
        """Clear all caches."""
        await self.embedding_cache.cache.clear()
        # Also drop the in-process dict cache used by get_embedding so both
        # cache layers are reset together
        self._cache.clear()
        logger.info("Embedding cache cleared")
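
# Minimal smoke-test sketch. This assumes the module is run inside its package
# (relative imports above won't resolve as a standalone script). Without an
# OPENAI_API_KEY it exercises only the deterministic mock path, so no network
# access is needed.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        service = AsyncEmbeddingService()
        vec_a = await service.get_embedding("list open pull requests")
        vec_b = await service.get_embedding("show open PRs in the repo")
        if vec_a is not None and vec_b is not None:
            print(f"similarity: {service.compute_similarity(vec_a, vec_b):.3f}")
        print(service.get_performance_stats())

    asyncio.run(_demo())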