"""Async embedding service for semantic similarity in MCP tools.

This module provides high-performance async embedding functionality with
caching, batch processing, and memory optimization.
"""

import hashlib
import logging
import math
import os
from typing import Any

import openai
from dotenv import load_dotenv

from .ontology import MCPPrompt, MCPTool
from .performance import (
    AsyncBatchProcessor,
    EmbeddingCache,
    async_cached,
    main_cache,
    performance_monitor,
    performance_monitor_instance,
)

# Load environment variables
load_dotenv()

logger = logging.getLogger(__name__)


class AsyncEmbeddingService:
    """High-performance async embedding service with caching and optimization."""

    def __init__(self, embedding_dim: int = 128, batch_size: int = 10):
        """Initialize the async embedding service.

        Args:
            embedding_dim: Dimension of the embedding vectors
            batch_size: Number of embeddings to process in a batch
        """
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size

        # Initialize caching
        self.embedding_cache = EmbeddingCache(max_size=2000, compression_precision=6)
        self.batch_processor = AsyncBatchProcessor(batch_size=batch_size, max_concurrent=5)

        # Initialize OpenAI client
        api_key = os.getenv("OPENAI_API_KEY")
        if api_key:
            self.openai_client: openai.AsyncOpenAI | None = openai.AsyncOpenAI(api_key=api_key)
            self.model_name = "text-embedding-3-small"
        else:
            self.openai_client = None
            self.model_name = "mock"
            logger.warning("OPENAI_API_KEY not found. Using mock embeddings.")

        # Performance tracking
        self.embedding_requests = 0
        self.cache_hits = 0
        # In-process cache for get_embedding(), keyed by md5("model:text");
        # the batch path uses self.embedding_cache instead.
        self._cache: dict[str, list[float]] = {}

    @performance_monitor(performance_monitor_instance)
    async def get_embedding(self, text: str, use_cache: bool = True) -> list[float] | None:
        """Generate an embedding vector for the given text using OpenAI API.

        Args:
            text: The text to embed
            use_cache: Whether to use caching

        Returns:
            A list of floats representing the embedding vector; falls back to a
            deterministic mock embedding if the OpenAI API is unavailable or the call fails
        """
        self.embedding_requests += 1

        # Compute the cache key once and reuse it for lookup and storage
        cache_key = hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()

        # Try cache first
        if use_cache and cache_key in self._cache:
            self.cache_hits += 1
            return self._cache[cache_key]

        if not self.openai_client:
            logger.debug("OpenAI client not available, using mock embedding")
            embedding = self._create_mock_embedding(text)

            # Cache the mock embedding
            if use_cache:
                self._cache[cache_key] = embedding

            return embedding

        try:
            # Preprocess text
            cleaned_text = text.replace("\n", " ").strip()

            # Limit text length to avoid API limits
            if len(cleaned_text) > 8000:
                cleaned_text = cleaned_text[:8000]

            # Make async API call to OpenAI
            response = await self.openai_client.embeddings.create(
                input=cleaned_text,
                model=self.model_name
            )

            # Extract embedding from response
            embedding = response.data[0].embedding

            # Cache the result
            if use_cache:
                self._cache[cache_key] = embedding

            return embedding

        except openai.APIError as e:
            logger.error(f"OpenAI API error when generating embedding: {e}")
            return self._create_mock_embedding(text)
        except Exception as e:
            logger.error(f"Unexpected error when generating embedding: {e}")
            return self._create_mock_embedding(text)

    async def get_embeddings_batch(self, texts: list[str], use_cache: bool = True) -> list[list[float] | None]:
        """Generate embeddings for multiple texts efficiently.

        Args:
            texts: List of texts to embed
            use_cache: Whether to use caching

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        # Check cache for all texts first
        embeddings = []
        uncached_indices = []
        uncached_texts = []

        if use_cache:
            for i, text in enumerate(texts):
                cached = await self.embedding_cache.get_embedding(text, self.model_name)
                if cached is not None:
                    embeddings.append(cached)
                    self.cache_hits += 1
                else:
                    embeddings.append(None)
                    uncached_indices.append(i)
                    uncached_texts.append(text)
        else:
            uncached_indices = list(range(len(texts)))
            uncached_texts = texts
            embeddings = [None] * len(texts)

        # Process uncached texts in batches
        if uncached_texts:
            if self.openai_client:
                try:
                    # Process in batches for API efficiency
                    batch_embeddings = await self._process_embedding_batches(uncached_texts)

                    # Fill in the results and cache them
                    for idx, embedding in zip(uncached_indices, batch_embeddings, strict=False):
                        embeddings[idx] = embedding
                        if use_cache and embedding is not None:
                            await self.embedding_cache.set_embedding(
                                texts[idx], embedding, self.model_name
                            )

                except Exception as e:
                    logger.error(f"Batch embedding failed: {e}")
                    # Fallback to mock embeddings
                    for idx, text in zip(uncached_indices, uncached_texts, strict=False):
                        embeddings[idx] = self._create_mock_embedding(text)
            else:
                # Use mock embeddings
                for idx, text in zip(uncached_indices, uncached_texts, strict=False):
                    embeddings[idx] = self._create_mock_embedding(text)

        self.embedding_requests += len(texts)
        return embeddings

    async def _process_embedding_batches(self, texts: list[str]) -> list[list[float] | None]:
        """Process texts in batches for API efficiency."""
        all_embeddings = []

        # Process in chunks of API batch size
        api_batch_size = min(100, len(texts))  # OpenAI limit

        for i in range(0, len(texts), api_batch_size):
            batch_texts = texts[i:i + api_batch_size]

            try:
                # Clean texts
                cleaned_texts = [
                    text.replace("\n", " ").strip()[:8000]
                    for text in batch_texts
                ]

                # Make batch API call
                response = await self.openai_client.embeddings.create(
                    input=cleaned_texts,
                    model=self.model_name
                )

                # Extract embeddings
                batch_embeddings = [data.embedding for data in response.data]
                all_embeddings.extend(batch_embeddings)

            except Exception as e:
                logger.error(f"Batch API call failed: {e}")
                # Fallback to mock for this batch
                mock_embeddings = [self._create_mock_embedding(text) for text in batch_texts]
                all_embeddings.extend(mock_embeddings)

        return all_embeddings

    def _create_mock_embedding(self, text: str) -> list[float]:
        """Create a deterministic mock embedding."""
        # Create a deterministic hash-based embedding
        text_hash = hashlib.sha256(text.encode()).hexdigest()

        # Convert hash to numbers and normalize
        embedding = []
        for i in range(0, min(len(text_hash), self.embedding_dim * 2), 2):
            hex_pair = text_hash[i : i + 2]
            value = int(hex_pair, 16) / 255.0
            embedding.append(value)

        # Pad or truncate to desired dimension
        while len(embedding) < self.embedding_dim:
            embedding.append(0.0)

        return embedding[:self.embedding_dim]

    @async_cached(main_cache)
    async def embed_tool_description(self, tool: MCPTool) -> list[float] | None:
        """Generate an embedding for a tool's description.

        Args:
            tool: The MCPTool to embed

        Returns:
            Embedding vector for the tool's description
        """
        # Combine name, description, and tags for richer embedding
        combined_text = f"{tool.name} {tool.description}"
        if tool.tags:
            combined_text += f" {' '.join(tool.tags)}"

        return await self.get_embedding(combined_text)

    @async_cached(main_cache)
    async def embed_prompt_description(self, prompt: MCPPrompt) -> list[float] | None:
        """Generate an embedding for a prompt's description.

        Args:
            prompt: The MCPPrompt to embed

        Returns:
            Embedding vector for the prompt's description
        """
        # Combine name, description, and template for richer embedding
        combined_text = f"{prompt.name} {prompt.description}"
        if prompt.template_string:
            # Include template but limit length
            template_preview = prompt.template_string[:200]
            combined_text += f" {template_preview}"

        return await self.get_embedding(combined_text)

    async def embed_tools_batch(self, tools: list[MCPTool]) -> list[list[float] | None]:
        """Generate embeddings for multiple tools efficiently."""
        tool_texts = []
        for tool in tools:
            combined_text = f"{tool.name} {tool.description}"
            if tool.tags:
                combined_text += f" {' '.join(tool.tags)}"
            tool_texts.append(combined_text)

        return await self.get_embeddings_batch(tool_texts)

    async def embed_prompts_batch(self, prompts: list[MCPPrompt]) -> list[list[float] | None]:
        """Generate embeddings for multiple prompts efficiently."""
        prompt_texts = []
        for prompt in prompts:
            combined_text = f"{prompt.name} {prompt.description}"
            if prompt.template_string:
                template_preview = prompt.template_string[:200]
                combined_text += f" {template_preview}"
            prompt_texts.append(combined_text)

        return await self.get_embeddings_batch(prompt_texts)

    def compute_similarity(
        self, embedding1: list[float], embedding2: list[float]
    ) -> float:
        """Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding vector
            embedding2: Second embedding vector

        Returns:
            Similarity score between 0 and 1
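
        Example (a minimal illustration; identical vectors map to 1.0):
            >>> AsyncEmbeddingService().compute_similarity([1.0, 0.0], [1.0, 0.0])
            1.0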
        """
        if not embedding1 or not embedding2:
            return 0.0

        # Ensure same length
        min_len = min(len(embedding1), len(embedding2))
        vec1 = embedding1[:min_len]
        vec2 = embedding2[:min_len]

        # Compute dot product
        dot_product = sum(a * b for a, b in zip(vec1, vec2, strict=False))

        # Compute magnitudes
        magnitude1 = math.sqrt(sum(a * a for a in vec1))
        magnitude2 = math.sqrt(sum(b * b for b in vec2))

        # Avoid division by zero
        if magnitude1 == 0 or magnitude2 == 0:
            return 0.0

        # Cosine similarity, normalized to 0-1 range
        cosine_sim = dot_product / (magnitude1 * magnitude2)
        return max(0.0, min(1.0, (cosine_sim + 1) / 2))

    async def find_similar_tools(
        self, query: str, tools: list[MCPTool], top_k: int = 5
    ) -> list[tuple[MCPTool, float]]:
        """Find tools most similar to the given query.

        Args:
            query: The search query text
            tools: List of tools to search through
            top_k: Maximum number of tools to return

        Returns:
            List of (tool, similarity_score) tuples, sorted by similarity
        """
        if not tools:
            return []

        # Get query embedding
        query_embedding = await self.get_embedding(query)
        if query_embedding is None:
            return [(tool, 0.0) for tool in tools[:top_k]]

        # Get tool embeddings in batch
        tool_embeddings = await self.embed_tools_batch(tools)

        # Compute similarities
        tool_similarities = []
        for tool, tool_embedding in zip(tools, tool_embeddings, strict=False):
            if tool_embedding is not None:
                similarity = self.compute_similarity(query_embedding, tool_embedding)
                tool_similarities.append((tool, similarity))
            else:
                tool_similarities.append((tool, 0.0))

        # Sort by similarity (descending) and return top_k
        tool_similarities.sort(key=lambda x: x[1], reverse=True)
        return tool_similarities[:top_k]

    async def find_similar_prompts(
        self, query: str, prompts: list[MCPPrompt], top_k: int = 5
    ) -> list[tuple[MCPPrompt, float]]:
        """Find prompts most similar to the given query.

        Args:
            query: The search query text
            prompts: List of prompts to search through
            top_k: Maximum number of prompts to return

        Returns:
            List of (prompt, similarity_score) tuples, sorted by similarity
        """
        if not prompts:
            return []

        # Get query embedding
        query_embedding = await self.get_embedding(query)
        if query_embedding is None:
            return [(prompt, 0.0) for prompt in prompts[:top_k]]

        # Get prompt embeddings in batch
        prompt_embeddings = await self.embed_prompts_batch(prompts)

        # Compute similarities
        prompt_similarities = []
        for prompt, prompt_embedding in zip(prompts, prompt_embeddings, strict=False):
            if prompt_embedding is not None:
                similarity = self.compute_similarity(query_embedding, prompt_embedding)
                prompt_similarities.append((prompt, similarity))
            else:
                prompt_similarities.append((prompt, 0.0))

        # Sort by similarity (descending) and return top_k
        prompt_similarities.sort(key=lambda x: x[1], reverse=True)
        return prompt_similarities[:top_k]

    def get_performance_stats(self) -> dict[str, Any]:
        """Get embedding service performance statistics."""
        cache_hit_rate = (
            self.cache_hits / self.embedding_requests
            if self.embedding_requests > 0 else 0
        )

        return {
            "total_embedding_requests": self.embedding_requests,
            "cache_hits": self.cache_hits,
            "cache_hit_rate": cache_hit_rate,
            "cache_size": len(self._cache),
            "model_name": self.model_name,
            "embedding_dim": self.embedding_dim,
            "batch_size": self.batch_size,
            "openai_available": self.openai_client is not None
        }

    async def warm_up_cache(self, texts: list[str]) -> None:
        """Pre-populate cache with common texts."""
        logger.info(f"Warming up embedding cache with {len(texts)} texts...")
        await self.get_embeddings_batch(texts, use_cache=True)
        logger.info("Cache warm-up completed")

    async def clear_cache(self) -> None:
        """Clear all caches."""
        # Clear both the local dict used by get_embedding() and the shared
        # EmbeddingCache used by the batch path.
        self._cache.clear()
        await self.embedding_cache.cache.clear()
        logger.info("Embedding cache cleared")