"""Async embedding service for semantic similarity in MCP tools.

This module provides high-performance async embedding functionality with
caching, batch processing, and memory optimization.
"""

import hashlib
import logging
import math
import os
from typing import Any

import openai
from dotenv import load_dotenv

from .ontology import MCPPrompt, MCPTool
from .performance import (
    AsyncBatchProcessor,
    EmbeddingCache,
    async_cached,
    main_cache,
    performance_monitor,
    performance_monitor_instance,
)

load_dotenv()

logger = logging.getLogger(__name__)

class AsyncEmbeddingService:
    """High-performance async embedding service with caching and optimization."""

    def __init__(self, embedding_dim: int = 128, batch_size: int = 10):
        """Initialize the async embedding service.

        Args:
            embedding_dim: Dimension of the embedding vectors
            batch_size: Number of embeddings to process in a batch
        """
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size

        # Shared cache and batch-processing helpers from the performance module.
        self.embedding_cache = EmbeddingCache(max_size=2000, compression_precision=6)
        self.batch_processor = AsyncBatchProcessor(batch_size=batch_size, max_concurrent=5)

        # Use the OpenAI API when a key is configured; otherwise fall back to
        # deterministic mock embeddings so the service still works offline.
        api_key = os.getenv("OPENAI_API_KEY")
        if api_key:
            self.openai_client: openai.AsyncOpenAI | None = openai.AsyncOpenAI(api_key=api_key)
            self.model_name = "text-embedding-3-small"
        else:
            self.openai_client = None
            self.model_name = "mock"
            logger.warning("OPENAI_API_KEY not found. Using mock embeddings.")

        # Request counters plus a simple in-process cache used by get_embedding.
        self.embedding_requests = 0
        self.cache_hits = 0
        self._cache: dict[str, list[float]] = {}

    @performance_monitor(performance_monitor_instance)
    async def get_embedding(self, text: str, use_cache: bool = True) -> list[float] | None:
        """Generate an embedding vector for the given text using the OpenAI API.

        Args:
            text: The text to embed
            use_cache: Whether to use caching

        Returns:
            A list of floats representing the embedding vector; falls back to a
            deterministic mock embedding if the API is unavailable or the call fails
        """
        self.embedding_requests += 1

        # Fast in-process cache keyed by model name and text.
        cache_key = (
            hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()
            if use_cache
            else None
        )
        if cache_key is not None and cache_key in self._cache:
            self.cache_hits += 1
            return self._cache[cache_key]

        # Without an OpenAI client, use a deterministic mock embedding instead.
        if not self.openai_client:
            logger.debug("OpenAI client not available, using mock embedding")
            embedding = self._create_mock_embedding(text)
            if cache_key is not None:
                self._cache[cache_key] = embedding
            return embedding

        try:
            # Normalize whitespace and truncate overly long inputs.
            cleaned_text = text.replace("\n", " ").strip()
            if len(cleaned_text) > 8000:
                cleaned_text = cleaned_text[:8000]

            response = await self.openai_client.embeddings.create(
                input=cleaned_text,
                model=self.model_name,
            )
            embedding = response.data[0].embedding

            if cache_key is not None:
                self._cache[cache_key] = embedding
            return embedding

        except openai.APIError as e:
            logger.error(f"OpenAI API error when generating embedding: {e}")
            return self._create_mock_embedding(text)
        except Exception as e:
            logger.error(f"Unexpected error when generating embedding: {e}")
            return self._create_mock_embedding(text)

    async def get_embeddings_batch(self, texts: list[str], use_cache: bool = True) -> list[list[float] | None]:
        """Generate embeddings for multiple texts efficiently.

        Args:
            texts: List of texts to embed
            use_cache: Whether to use caching

        Returns:
            List of embedding vectors, in the same order as the input texts
        """
        if not texts:
            return []

        # Resolve cache hits first; remember which texts still need an API call.
        embeddings: list[list[float] | None] = []
        uncached_indices: list[int] = []
        uncached_texts: list[str] = []

        if use_cache:
            for i, text in enumerate(texts):
                cached = await self.embedding_cache.get_embedding(text, self.model_name)
                if cached is not None:
                    embeddings.append(cached)
                    self.cache_hits += 1
                else:
                    embeddings.append(None)
                    uncached_indices.append(i)
                    uncached_texts.append(text)
        else:
            uncached_indices = list(range(len(texts)))
            uncached_texts = texts
            embeddings = [None] * len(texts)

        # Embed the remaining texts via the API, falling back to mock embeddings on failure.
        if uncached_texts:
            if self.openai_client:
                try:
                    batch_embeddings = await self._process_embedding_batches(uncached_texts)
                    for idx, embedding in zip(uncached_indices, batch_embeddings, strict=False):
                        embeddings[idx] = embedding
                        if use_cache and embedding is not None:
                            await self.embedding_cache.set_embedding(
                                texts[idx], embedding, self.model_name
                            )
                except Exception as e:
                    logger.error(f"Batch embedding failed: {e}")
                    for idx, text in zip(uncached_indices, uncached_texts, strict=False):
                        embeddings[idx] = self._create_mock_embedding(text)
            else:
                for idx, text in zip(uncached_indices, uncached_texts, strict=False):
                    embeddings[idx] = self._create_mock_embedding(text)

        self.embedding_requests += len(texts)
        return embeddings

    async def _process_embedding_batches(self, texts: list[str]) -> list[list[float] | None]:
        """Process texts in batches for API efficiency."""
        all_embeddings: list[list[float] | None] = []

        # The embeddings endpoint accepts multiple inputs per request; cap each request at 100.
        api_batch_size = min(100, len(texts))

        for i in range(0, len(texts), api_batch_size):
            batch_texts = texts[i:i + api_batch_size]

            try:
                # Normalize whitespace and truncate overly long inputs.
                cleaned_texts = [
                    text.replace("\n", " ").strip()[:8000]
                    for text in batch_texts
                ]

                response = await self.openai_client.embeddings.create(
                    input=cleaned_texts,
                    model=self.model_name,
                )

                batch_embeddings = [data.embedding for data in response.data]
                all_embeddings.extend(batch_embeddings)

            except Exception as e:
                logger.error(f"Batch API call failed: {e}")
                # Keep positions aligned by substituting mock embeddings for the failed batch.
                mock_embeddings = [self._create_mock_embedding(text) for text in batch_texts]
                all_embeddings.extend(mock_embeddings)

        return all_embeddings

    def _create_mock_embedding(self, text: str) -> list[float]:
        """Create a deterministic mock embedding from a hash of the text."""
        text_hash = hashlib.sha256(text.encode()).hexdigest()

        # Turn consecutive hex pairs of the digest into floats in [0, 1].
        embedding = []
        for i in range(0, min(len(text_hash), self.embedding_dim * 2), 2):
            hex_pair = text_hash[i : i + 2]
            value = int(hex_pair, 16) / 255.0
            embedding.append(value)

        # Zero-pad (or truncate) to the configured dimension.
        while len(embedding) < self.embedding_dim:
            embedding.append(0.0)

        return embedding[:self.embedding_dim]
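
    # Worked example for the mock scheme above: sha256("abc") begins with "ba7816...",
    # so the first values are 0xba / 255 ≈ 0.729, 0x78 / 255 ≈ 0.471, 0x16 / 255 ≈ 0.086;
    # a 64-character digest yields 32 hash-derived values, zero-padded to embedding_dim.
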
    @async_cached(main_cache)
    async def embed_tool_description(self, tool: MCPTool) -> list[float] | None:
        """Generate an embedding for a tool's description.

        Args:
            tool: The MCPTool to embed

        Returns:
            Embedding vector for the tool's description
        """
        combined_text = f"{tool.name} {tool.description}"
        if tool.tags:
            combined_text += f" {' '.join(tool.tags)}"

        return await self.get_embedding(combined_text)

    @async_cached(main_cache)
    async def embed_prompt_description(self, prompt: MCPPrompt) -> list[float] | None:
        """Generate an embedding for a prompt's description.

        Args:
            prompt: The MCPPrompt to embed

        Returns:
            Embedding vector for the prompt's description
        """
        combined_text = f"{prompt.name} {prompt.description}"
        if prompt.template_string:
            template_preview = prompt.template_string[:200]
            combined_text += f" {template_preview}"

        return await self.get_embedding(combined_text)

    async def embed_tools_batch(self, tools: list[MCPTool]) -> list[list[float] | None]:
        """Generate embeddings for multiple tools efficiently."""
        tool_texts = []
        for tool in tools:
            combined_text = f"{tool.name} {tool.description}"
            if tool.tags:
                combined_text += f" {' '.join(tool.tags)}"
            tool_texts.append(combined_text)

        return await self.get_embeddings_batch(tool_texts)

    async def embed_prompts_batch(self, prompts: list[MCPPrompt]) -> list[list[float] | None]:
        """Generate embeddings for multiple prompts efficiently."""
        prompt_texts = []
        for prompt in prompts:
            combined_text = f"{prompt.name} {prompt.description}"
            if prompt.template_string:
                template_preview = prompt.template_string[:200]
                combined_text += f" {template_preview}"
            prompt_texts.append(combined_text)

        return await self.get_embeddings_batch(prompt_texts)

    def compute_similarity(
        self, embedding1: list[float], embedding2: list[float]
    ) -> float:
        """Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding vector
            embedding2: Second embedding vector

        Returns:
            Similarity score between 0 and 1
        """
        if not embedding1 or not embedding2:
            return 0.0

        # Compare only the overlapping dimensions if the vectors differ in length.
        min_len = min(len(embedding1), len(embedding2))
        vec1 = embedding1[:min_len]
        vec2 = embedding2[:min_len]

        dot_product = sum(a * b for a, b in zip(vec1, vec2, strict=False))
        magnitude1 = math.sqrt(sum(a * a for a in vec1))
        magnitude2 = math.sqrt(sum(b * b for b in vec2))

        if magnitude1 == 0 or magnitude2 == 0:
            return 0.0

        # Rescale cosine similarity from [-1, 1] to [0, 1].
        cosine_sim = dot_product / (magnitude1 * magnitude2)
        return max(0.0, min(1.0, (cosine_sim + 1) / 2))
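
    # Worked example: identical vectors such as [1.0, 0.0] and [1.0, 0.0] have cosine 1.0
    # and score 1.0; orthogonal vectors like [1.0, 0.0] and [0.0, 1.0] have cosine 0.0 and
    # score 0.5; opposite vectors have cosine -1.0 and score 0.0.
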
    async def find_similar_tools(
        self, query: str, tools: list[MCPTool], top_k: int = 5
    ) -> list[tuple[MCPTool, float]]:
        """Find tools most similar to the given query.

        Args:
            query: The search query text
            tools: List of tools to search through
            top_k: Maximum number of tools to return

        Returns:
            List of (tool, similarity_score) tuples, sorted by similarity
        """
        if not tools:
            return []

        query_embedding = await self.get_embedding(query)
        if query_embedding is None:
            return [(tool, 0.0) for tool in tools[:top_k]]

        # Embed all tools in one batch, then score each against the query.
        tool_embeddings = await self.embed_tools_batch(tools)

        tool_similarities = []
        for tool, tool_embedding in zip(tools, tool_embeddings, strict=False):
            if tool_embedding is not None:
                similarity = self.compute_similarity(query_embedding, tool_embedding)
                tool_similarities.append((tool, similarity))
            else:
                tool_similarities.append((tool, 0.0))

        # Highest similarity first.
        tool_similarities.sort(key=lambda x: x[1], reverse=True)
        return tool_similarities[:top_k]

    async def find_similar_prompts(
        self, query: str, prompts: list[MCPPrompt], top_k: int = 5
    ) -> list[tuple[MCPPrompt, float]]:
        """Find prompts most similar to the given query.

        Args:
            query: The search query text
            prompts: List of prompts to search through
            top_k: Maximum number of prompts to return

        Returns:
            List of (prompt, similarity_score) tuples, sorted by similarity
        """
        if not prompts:
            return []

        query_embedding = await self.get_embedding(query)
        if query_embedding is None:
            return [(prompt, 0.0) for prompt in prompts[:top_k]]

        prompt_embeddings = await self.embed_prompts_batch(prompts)

        prompt_similarities = []
        for prompt, prompt_embedding in zip(prompts, prompt_embeddings, strict=False):
            if prompt_embedding is not None:
                similarity = self.compute_similarity(query_embedding, prompt_embedding)
                prompt_similarities.append((prompt, similarity))
            else:
                prompt_similarities.append((prompt, 0.0))

        prompt_similarities.sort(key=lambda x: x[1], reverse=True)
        return prompt_similarities[:top_k]

    def get_performance_stats(self) -> dict[str, Any]:
        """Get embedding service performance statistics."""
        cache_hit_rate = (
            self.cache_hits / self.embedding_requests
            if self.embedding_requests > 0 else 0.0
        )

        return {
            "total_embedding_requests": self.embedding_requests,
            "cache_hits": self.cache_hits,
            "cache_hit_rate": cache_hit_rate,
            "cache_size": len(self._cache),
            "model_name": self.model_name,
            "embedding_dim": self.embedding_dim,
            "batch_size": self.batch_size,
            "openai_available": self.openai_client is not None,
        }

    async def warm_up_cache(self, texts: list[str]) -> None:
        """Pre-populate the cache with common texts."""
        logger.info(f"Warming up embedding cache with {len(texts)} texts...")
        await self.get_embeddings_batch(texts, use_cache=True)
        logger.info("Cache warm-up completed")

    async def clear_cache(self) -> None:
        """Clear all caches."""
        # Clear both the in-process dict used by get_embedding and the shared EmbeddingCache.
        self._cache.clear()
        await self.embedding_cache.cache.clear()
        logger.info("Embedding cache cleared")
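

# ---------------------------------------------------------------------------
# Minimal usage sketch (an addition, not part of the original module): it assumes
# the file is executed in a context where the relative imports above resolve
# (e.g. via `python -m` on the containing package). Without OPENAI_API_KEY set,
# the service falls back to deterministic mock embeddings, so this runs offline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        service = AsyncEmbeddingService(embedding_dim=128, batch_size=10)

        # Single embedding (from the API when configured, otherwise the mock fallback).
        vec = await service.get_embedding("summarize a pull request")
        print(f"embedding length: {len(vec) if vec else 0}")

        # Batch path: cache lookups first, then one API call per batch of texts.
        vecs = await service.get_embeddings_batch(
            ["search the codebase", "open an issue", "summarize a pull request"]
        )
        print(f"batch lengths: {[len(v) if v else 0 for v in vecs]}")

        # Cosine similarity rescaled to [0, 1].
        if vecs[0] and vecs[1]:
            print(f"similarity: {service.compute_similarity(vecs[0], vecs[1]):.3f}")

        print(service.get_performance_stats())

    asyncio.run(_demo())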