Spaces:
Sleeping
Sleeping
| import logging | |
| import hashlib | |
| import json | |
| from datetime import datetime, timedelta | |
| from collections import deque | |
| from sambanova import SambaNova | |
| from utils.config import Config | |
class SambaNovaService:
    """Wrapper around the SambaNova chat-completions client.

    On top of the raw client this class provides:

    * selection among several API keys (fewest recorded errors first),
    * an in-memory response cache with a fixed time-to-live,
    * per-key call/error counters over a rolling one-hour window, used to
      skip keys that reached ``quota_limit_per_hour``.

    The class keeps plain dict/deque state and does no locking.
    """

    # System prompt sent with every request. The Vietnamese text instructs the
    # model to answer as a Vietnamese legal consultant with accurate, detailed,
    # legally grounded answers.
    SYSTEM_PROMPT = (
        "Bạn là một chuyên gia tư vấn pháp luật Việt Nam. Hãy cung cấp câu trả lời chính xác, chi tiết và có căn cứ pháp lý."
    )
    # User-facing message returned when every attempt failed.
    FAILURE_MESSAGE = "Lỗi khi gọi API sau nhiều lần thử. Vui lòng thử lại sau."

    def __init__(self):
        """Load the API keys, set up cache/quota state and create the client."""
        self.api_keys = self._load_api_keys()
        self.key_index = 0
        self.client = None
        self.key_rotation_queue = deque(self.api_keys)

        # Caching
        self.response_cache = {}
        self.cache_ttl = 3600  # seconds (1 hour)

        # Quota tracking: one counter dict per key
        started = datetime.now()
        self.quota_tracker = {key: self._new_tracker(started) for key in self.api_keys}
        self.quota_limit_per_hour = 100

        self.initialize_client()

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _mask_key(api_key):
        """Return the short label used for a key in logs and status output.

        The label is the first 10 characters followed by ``...``. ``str()``
        is applied first so a missing (``None``) key cannot raise while a
        log message is being built.
        """
        return f"{str(api_key)[:10]}..."

    @staticmethod
    def _new_tracker(now=None):
        """Return a fresh counter dict whose window starts at ``now``."""
        return {'calls': 0, 'errors': 0, 'last_reset': now or datetime.now()}

    def _get_tracker(self, api_key):
        """Return the counters for ``api_key``, rolling the hourly window.

        A tracker is created on first use. If more than one hour has passed
        since ``last_reset`` the counters are zeroed and the window restarts.
        """
        now = datetime.now()
        tracker = self.quota_tracker.get(api_key)
        if tracker is None:
            tracker = self.quota_tracker[api_key] = self._new_tracker(now)
        elif now - tracker['last_reset'] > timedelta(hours=1):
            tracker['calls'] = 0
            tracker['errors'] = 0
            tracker['last_reset'] = now
        return tracker

    # ------------------------------------------------------------------ #
    # Keys and client
    # ------------------------------------------------------------------ #
    def _load_api_keys(self):
        """Load API keys from the environment, falling back to Config.

        ``SAMBANOVA_API_KEYS`` is read as a comma-separated list; blanks are
        dropped and duplicates removed (first occurrence wins) so that the
        rotation queue and the quota tracker describe the same set of keys.
        When the variable yields no keys, ``[Config.SAMBANOVA_API_KEY]`` is
        returned.
        """
        import os

        raw = os.getenv('SAMBANOVA_API_KEYS', '')
        stripped = (part.strip() for part in raw.split(','))
        # dict.fromkeys de-duplicates while preserving order.
        keys_list = list(dict.fromkeys(key for key in stripped if key))
        return keys_list if keys_list else [Config.SAMBANOVA_API_KEY]

    def initialize_client(self, api_key=None):
        """Create ``self.client`` for ``api_key`` (or the key at ``key_index``).

        Any failure is logged and leaves ``self.client`` set to ``None``;
        nothing is raised.
        """
        try:
            key = api_key or self.api_keys[self.key_index]
            self.client = SambaNova(
                api_key=key,
                base_url=Config.SAMBANOVA_BASE_URL,
            )
            logging.info(f"SambaNova client initialized with key: {self._mask_key(key)}")
        except Exception as e:
            logging.error(f"Failed to initialize SambaNova client: {e}")
            self.client = None

    # ------------------------------------------------------------------ #
    # Response cache
    # ------------------------------------------------------------------ #
    def _get_cache_key(self, prompt, model_name, temperature, top_p):
        """Return a hex digest identifying one (prompt, model, sampling) request.

        The fields are serialised as a JSON array before hashing so that the
        field boundaries are unambiguous (joining with ``_`` would map
        ``("a_b", "c")`` and ``("a", "b_c")`` to the same key).
        """
        # default=str is deliberate: the payload is only hashed, never parsed.
        payload = json.dumps([prompt, model_name, temperature, top_p], default=str)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def _get_cached_response(self, cache_key):
        """Return the cached response for ``cache_key``, or ``None``.

        An entry that has reached its expiry time is deleted and treated as a
        miss.
        """
        entry = self.response_cache.get(cache_key)
        if entry is None:
            return None
        if datetime.now() < entry['expires_at']:
            logging.info(f"Cache hit for key: {cache_key}")
            return entry['response']
        del self.response_cache[cache_key]
        return None

    def _cache_response(self, cache_key, response):
        """Store ``response`` under ``cache_key`` for ``cache_ttl`` seconds.

        Expired entries are purged first; otherwise entries whose key is never
        requested again would stay in memory forever.
        """
        now = datetime.now()
        stale = [k for k, entry in self.response_cache.items() if entry['expires_at'] <= now]
        for k in stale:
            del self.response_cache[k]
        self.response_cache[cache_key] = {
            'response': response,
            'expires_at': now + timedelta(seconds=self.cache_ttl),
            'cached_at': now,
        }

    # ------------------------------------------------------------------ #
    # Quota tracking and key selection
    # ------------------------------------------------------------------ #
    def _update_quota(self, api_key, is_error=False):
        """Record one call (or one error when ``is_error``) for ``api_key``."""
        tracker = self._get_tracker(api_key)
        tracker['errors' if is_error else 'calls'] += 1
        logging.debug(
            f"Quota for {self._mask_key(api_key)}: "
            f"{tracker['calls']} calls, {tracker['errors']} errors"
        )

    def _is_quota_exceeded(self, api_key):
        """Return True if ``api_key`` used up its calls for the current hour.

        Unknown keys are never considered exceeded. The hourly window is
        rolled here as well, so a key that hit the limit becomes usable again
        once its window has elapsed.
        """
        if api_key not in self.quota_tracker:
            return False
        return self._get_tracker(api_key)['calls'] >= self.quota_limit_per_hour

    def _rotate_key(self):
        """Advance the rotation queue, re-create the client, return the key.

        With at most one configured key there is nothing to rotate and the
        first key is returned unchanged.
        """
        if len(self.api_keys) <= 1:
            return self.api_keys[0]
        self.key_rotation_queue.rotate(-1)
        next_key = self.key_rotation_queue[0]
        self.initialize_client(next_key)
        logging.info(f"Rotated to key: {self._mask_key(next_key)}")
        return next_key

    def _get_best_key(self, exclude=()):
        """Return the within-quota key with the fewest recorded errors.

        Args:
            exclude: keys to avoid if any other within-quota key exists
                (typically the keys already tried for the current request).

        Returns:
            The chosen key. If every key is over quota, the next key from the
            rotation queue is returned instead.
        """
        within_quota = [key for key in self.api_keys if not self._is_quota_exceeded(key)]
        if not within_quota:
            logging.warning("All keys have exceeded quota, resetting rotation")
            return self._rotate_key()
        # Prefer untried keys, but fall back to tried ones rather than fail.
        candidates = [key for key in within_quota if key not in exclude] or within_quota
        return min(candidates, key=lambda k: self._get_tracker(k)['errors'])

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def generate_response(self, prompt, model_name="Qwen3-235B",
                          temperature=0.1, top_p=0.1, max_retries=3):
        """Generate content using the SambaNova API with caching and retries.

        Args:
            prompt: user message sent to the model.
            model_name: model identifier passed to the API.
            temperature: sampling temperature passed to the API.
            top_p: nucleus-sampling value passed to the API.
            max_retries: maximum number of attempts before giving up.

        Returns:
            The content of the first choice (from cache when available), or
            ``FAILURE_MESSAGE`` when all attempts failed. Exceptions raised by
            the API call are logged and counted, not propagated.
        """
        cache_key = self._get_cache_key(prompt, model_name, temperature, top_p)
        cached_response = self._get_cached_response(cache_key)
        if cached_response is not None:
            return cached_response

        if not self.api_keys:
            logging.error("No SambaNova API keys configured")
            return self.FAILURE_MESSAGE

        attempted_keys = set()
        for attempt in range(1, max_retries + 1):
            # Chosen before the try block so that a failure is always charged
            # to the key that was actually used.
            current_key = self._get_best_key(exclude=attempted_keys)
            attempted_keys.add(current_key)

            if self._is_quota_exceeded(current_key):
                logging.warning(f"Key {self._mask_key(current_key)} quota exceeded, rotating")
                continue

            self.initialize_client(current_key)
            try:
                if self.client is None:
                    raise RuntimeError("SambaNova client not initialized")

                response = self.client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": self.SYSTEM_PROMPT},
                        {"role": "user", "content": prompt},
                    ],
                    temperature=temperature,
                    top_p=top_p,
                )
                self._update_quota(current_key, is_error=False)

                if response and response.choices:
                    result = response.choices[0].message.content
                    if result:
                        # Empty content is returned but not cached.
                        self._cache_response(cache_key, result)
                    logging.info(
                        f"Successfully generated response using key: {self._mask_key(current_key)}"
                    )
                    return result

                logging.error("SambaNova API returned empty response")
                self._update_quota(current_key, is_error=True)
            except Exception as e:
                self._update_quota(current_key, is_error=True)
                logging.warning(
                    f"Error with key {self._mask_key(current_key)}: {e}. "
                    f"Retry {attempt}/{max_retries}"
                )

        logging.error(f"Failed to generate response after {max_retries} retries")
        return self.FAILURE_MESSAGE

    def get_quota_status(self):
        """Return per-key quota counters, keyed by the masked key label.

        Each value holds ``calls``, ``errors``, ``limit``, ``remaining`` and
        ``last_reset`` (ISO-8601 string) for the key's current window.
        """
        status = {}
        for key in list(self.quota_tracker):
            tracker = self._get_tracker(key)
            status[self._mask_key(key)] = {
                'calls': tracker['calls'],
                'errors': tracker['errors'],
                'limit': self.quota_limit_per_hour,
                'remaining': max(0, self.quota_limit_per_hour - tracker['calls']),
                'last_reset': tracker['last_reset'].isoformat(),
            }
        return status

    def clear_cache(self):
        """Remove every entry from the response cache."""
        self.response_cache.clear()
        logging.info("Response cache cleared")

    def get_cache_status(self):
        """Return the number of cached entries and the configured TTL."""
        return {
            'total_cached': len(self.response_cache),
            'cache_ttl_seconds': self.cache_ttl,
        }
# Create the module-level shared instance. This runs at import time, so
# importing the module executes SambaNovaService.__init__ (loads the API keys
# and attempts to create the client).
sambanova_service = SambaNovaService()
| # # SambaNova AI service | |
| # import logging | |
| # from sambanova import SambaNova | |
| # from utils.config import Config | |
| # class SambaNovaService: | |
| # def __init__(self): | |
| # self.client = None | |
| # self.initialize_client() | |
| # def initialize_client(self): | |
| # try: | |
| # self.client = SambaNova( | |
| # api_key=Config.SAMBANOVA_API_KEY, | |
| # base_url=Config.SAMBANOVA_BASE_URL, | |
| # ) | |
| # logging.info("SambaNova client initialized") | |
| # except Exception as e: | |
| # logging.error(f"Failed to initialize SambaNova client: {e}") | |
| # self.client = None | |
| # def generate_response(self, prompt, model_name="Qwen3-235B", temperature=0.1, top_p=0.1): | |
| # """Generate content using SambaNova API""" | |
| # if not self.client: | |
| # return "SambaNova client not initialized." | |
| # try: | |
| # response = self.client.chat.completions.create( | |
| # model=model_name, | |
| # messages=[ | |
| # { | |
| # "role": "system", | |
| # "content": "Bạn là một chuyên gia tư vấn pháp luật Việt Nam. Hãy cung cấp câu trả lời chính xác, chi tiết và có căn cứ pháp lý." | |
| # }, | |
| # {"role": "user", "content": prompt} | |
| # ], | |
| # temperature=temperature, | |
| # top_p=top_p | |
| # ) | |
| # if response and response.choices and len(response.choices) > 0: | |
| # return response.choices[0].message.content | |
| # else: | |
| # logging.error("SambaNova API returned empty response") | |
| # return "Không có phản hồi từ mô hình." | |
| # except Exception as e: | |
| # logging.error(f"Error calling SambaNova API: {e}") | |
| # return f"Lỗi khi gọi API: {str(e)}" | |
| # # Create the global instance | |
| # sambanova_service = SambaNovaService() |