# LegalMind12 / services/sambanova.py
# Last commit by Nguyendat92929: 7c527c8 "Update services/sambanova.py" (verified)
import logging
import hashlib
import json
from datetime import datetime, timedelta
from collections import deque
from sambanova import SambaNova
from utils.config import Config
class SambaNovaService:
    """Chat-completion client for the SambaNova API with multi-key rotation.

    Features:
      * Multiple API keys, read from the comma-separated ``SAMBANOVA_API_KEYS``
        environment variable (falling back to ``Config.SAMBANOVA_API_KEY``).
      * In-memory, per-key, per-hour call/error counters used to choose a key
        and to skip keys that reached ``quota_limit_per_hour``.
      * An in-memory response cache keyed by (prompt, model, temperature,
        top_p) with a fixed TTL (``cache_ttl`` seconds).

    ``generate_response`` returns an error *string* (rather than raising) when
    every attempt fails.
    """

    def __init__(self):
        self.api_keys = self._load_api_keys()
        self.key_index = 0
        self.client = None
        # Round-robin order consumed by _rotate_key.
        self.key_rotation_queue = deque(self.api_keys)
        # Caching: cache_key -> {'response', 'expires_at', 'cached_at'}
        self.response_cache = {}
        self.cache_ttl = 3600  # seconds (1 hour)
        # Quota tracking: api_key -> {'calls', 'errors', 'last_reset'}
        self.quota_tracker = {key: self._new_tracker() for key in self.api_keys}
        self.quota_limit_per_hour = 100
        self.initialize_client()

    @staticmethod
    def _mask(api_key):
        """Return the abbreviated key used in logs/status: first 10 chars + '...'.

        Tolerates a None key (e.g. when no key is configured) instead of
        raising TypeError.
        NOTE(review): 10 characters of a secret is still a partial leak into
        logs; consider a hash fingerprint instead.
        """
        return f"{(api_key or '')[:10]}..."

    @staticmethod
    def _new_tracker():
        """Return a fresh quota record whose hourly window starts now."""
        return {'calls': 0, 'errors': 0, 'last_reset': datetime.now()}

    @staticmethod
    def _roll_window(tracker):
        """Zero *tracker*'s counters in place if its one-hour window has elapsed."""
        now = datetime.now()
        if now - tracker['last_reset'] > timedelta(hours=1):
            tracker['calls'] = 0
            tracker['errors'] = 0
            tracker['last_reset'] = now

    def _load_api_keys(self):
        """Load the list of API keys.

        Reads the comma-separated ``SAMBANOVA_API_KEYS`` environment variable;
        blank entries are dropped and duplicates removed (first occurrence
        wins, order preserved). Falls back to ``[Config.SAMBANOVA_API_KEY]``
        when the variable is unset or yields no keys.
        """
        import os
        keys_str = os.getenv('SAMBANOVA_API_KEYS', '')
        # De-duplicate so "have we tried every key?" in generate_response
        # compares against the number of *distinct* keys.
        keys_list = list(dict.fromkeys(
            key.strip() for key in keys_str.split(',') if key.strip()
        ))
        return keys_list if keys_list else [Config.SAMBANOVA_API_KEY]

    def initialize_client(self, api_key=None):
        """(Re)create ``self.client`` for *api_key*, or the default key if omitted.

        Any failure is logged and leaves ``self.client`` set to None.
        """
        try:
            key = api_key or self.api_keys[self.key_index]
            self.client = SambaNova(
                api_key=key,
                base_url=Config.SAMBANOVA_BASE_URL,
            )
            logging.info(f"SambaNova client initialized with key: {self._mask(key)}")
        except Exception as e:
            logging.error(f"Failed to initialize SambaNova client: {e}")
            self.client = None

    def _get_cache_key(self, prompt, model_name, temperature, top_p):
        """Return a hex digest that identifies one request's parameters.

        The fields are JSON-encoded as a list instead of joined with '_', so
        values that contain '_' cannot collide (e.g. prompt 'a_b' + model 'c'
        versus prompt 'a' + model 'b_c'). MD5 is used only as a cache key,
        not for security.
        """
        key_str = json.dumps(
            [prompt, model_name, temperature, top_p],
            ensure_ascii=False,
            default=str,  # hashing only: any non-JSON value is stringified
        )
        return hashlib.md5(key_str.encode('utf-8')).hexdigest()

    def _get_cached_response(self, cache_key):
        """Return the cached response for *cache_key*, or None if absent or expired.

        An expired entry found here is removed.
        """
        cached_data = self.response_cache.get(cache_key)
        if cached_data is None:
            return None
        if datetime.now() < cached_data['expires_at']:
            logging.info(f"Cache hit for key: {cache_key}")
            return cached_data['response']
        del self.response_cache[cache_key]
        return None

    def _cache_response(self, cache_key, response):
        """Store *response* under *cache_key* for ``self.cache_ttl`` seconds.

        Also evicts entries that have already expired; without this sweep,
        entries that are never looked up again would stay in memory forever.
        """
        now = datetime.now()
        expired = [k for k, v in self.response_cache.items() if v['expires_at'] <= now]
        for k in expired:
            del self.response_cache[k]
        self.response_cache[cache_key] = {
            'response': response,
            'expires_at': now + timedelta(seconds=self.cache_ttl),
            'cached_at': now,
        }

    def _update_quota(self, api_key, is_error=False):
        """Record one successful call (default) or one error against *api_key*.

        Creates the key's record if missing and rolls its hourly window first.
        """
        tracker = self.quota_tracker.setdefault(api_key, self._new_tracker())
        self._roll_window(tracker)
        tracker['errors' if is_error else 'calls'] += 1
        logging.debug(f"Quota for {self._mask(api_key)}: {tracker['calls']} calls, {tracker['errors']} errors")

    def _is_quota_exceeded(self, api_key):
        """Return True if *api_key* has used its calls for the current hour.

        Unknown keys are never considered exceeded. The hourly window must be
        rolled here as well as in _update_quota: key selection skips exhausted
        keys, so _update_quota would never run for them again and they would
        otherwise stay blocked for the life of the process.
        """
        tracker = self.quota_tracker.get(api_key)
        if tracker is None:
            return False
        self._roll_window(tracker)
        return tracker['calls'] >= self.quota_limit_per_hour

    def _rotate_key(self):
        """Advance the round-robin queue, rebuild the client for the new head key, and return it.

        With a single key nothing rotates; that key is returned as-is.
        """
        if len(self.api_keys) <= 1:
            return self.api_keys[0]
        self.key_rotation_queue.rotate(-1)
        next_key = self.key_rotation_queue[0]
        self.initialize_client(next_key)
        logging.info(f"Rotated to key: {self._mask(next_key)}")
        return next_key

    def _get_best_key(self):
        """Return the in-quota key with the fewest errors this window.

        Ties go to the earliest key in ``self.api_keys``. If every key is over
        quota, falls back to _rotate_key().
        """
        available_keys = [
            key for key in self.api_keys
            if not self._is_quota_exceeded(key)
        ]
        if not available_keys:
            logging.warning("All keys have exceeded quota, resetting rotation")
            return self._rotate_key()
        return min(
            available_keys,
            key=lambda k: self.quota_tracker.get(k, {}).get('errors', 0),
        )

    def _call_api(self, prompt, model_name, temperature, top_p):
        """Send one chat-completion request through ``self.client``; return the raw response."""
        return self.client.chat.completions.create(
            model=model_name,
            messages=[
                {
                    # System prompt (Vietnamese): "You are a Vietnamese legal
                    # consulting expert. Give accurate, detailed answers with
                    # a legal basis."
                    "role": "system",
                    "content": "Bạn là một chuyên gia tư vấn pháp luật Việt Nam. Hãy cung cấp câu trả lời chính xác, chi tiết và có căn cứ pháp lý."
                },
                {"role": "user", "content": prompt}
            ],
            temperature=temperature,
            top_p=top_p
        )

    def generate_response(self, prompt, model_name="Qwen3-235B",
                          temperature=0.1, top_p=0.1, max_retries=3):
        """Generate a reply to *prompt* using the cache, key selection and retries.

        Args:
            prompt: User message sent to the model.
            model_name: SambaNova model identifier.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.
            max_retries: Maximum number of attempts.

        Returns:
            The model's text (cached for ``cache_ttl`` seconds), or the
            Vietnamese message "Error calling the API after several attempts.
            Please try again later." if every attempt failed. Exceptions
            raised during an attempt are logged, not propagated.
        """
        cache_key = self._get_cache_key(prompt, model_name, temperature, top_p)
        cached_response = self._get_cached_response(cache_key)
        # `is not None` so a legitimately empty cached reply still counts as a hit.
        if cached_response is not None:
            return cached_response

        attempted_keys = set()
        retry_count = 0
        while retry_count < max_retries:
            current_key = None  # key used by this attempt, once chosen
            try:
                current_key = self._get_best_key()
                if current_key in attempted_keys and len(attempted_keys) < len(self.api_keys):
                    # Best key already tried for this request: move to the next one.
                    current_key = self._rotate_key()
                attempted_keys.add(current_key)
                self.initialize_client(current_key)

                # Happens when every key is over quota (fallback from _get_best_key).
                if self._is_quota_exceeded(current_key):
                    logging.warning(f"Key {self._mask(current_key)} quota exceeded, rotating")
                    retry_count += 1
                    continue

                response = self._call_api(prompt, model_name, temperature, top_p)
                self._update_quota(current_key, is_error=False)

                result = None
                if response and response.choices:
                    result = response.choices[0].message.content
                # A missing (None) content is treated like an empty response.
                if result is not None:
                    self._cache_response(cache_key, result)
                    logging.info(f"Successfully generated response using key: {self._mask(current_key)}")
                    return result

                logging.error("SambaNova API returned empty response")
                self._update_quota(current_key, is_error=True)
                retry_count += 1
            except Exception as e:
                # Charge the error to the key this attempt actually used;
                # re-selecting a key here would blame the wrong one after a rotation.
                if current_key is not None:
                    self._update_quota(current_key, is_error=True)
                logging.warning(f"Error with key {self._mask(current_key)}: {e}. Retry {retry_count + 1}/{max_retries}")
                retry_count += 1
                if retry_count < max_retries:
                    self._rotate_key()

        logging.error(f"Failed to generate response after {max_retries} retries")
        return "Lỗi khi gọi API sau nhiều lần thử. Vui lòng thử lại sau."

    def get_quota_status(self):
        """Return a per-key usage summary for the current hourly window.

        Keys are abbreviated to their first 10 characters plus '...'.
        """
        status = {}
        for key, tracker in self.quota_tracker.items():
            # Roll first so counts from an already-elapsed hour are not reported.
            self._roll_window(tracker)
            status[self._mask(key)] = {
                'calls': tracker['calls'],
                'errors': tracker['errors'],
                'limit': self.quota_limit_per_hour,
                'remaining': max(0, self.quota_limit_per_hour - tracker['calls']),
                'last_reset': tracker['last_reset'].isoformat(),
            }
        return status

    def clear_cache(self):
        """Remove every cached response."""
        self.response_cache.clear()
        logging.info("Response cache cleared")

    def get_cache_status(self):
        """Return the number of cached entries and the configured TTL in seconds."""
        return {
            'total_cached': len(self.response_cache),
            'cache_ttl_seconds': self.cache_ttl
        }
# Create the shared module-level instance. This runs SambaNovaService.__init__
# at import time: it reads SAMBANOVA_API_KEYS from the environment (falling back
# to Config.SAMBANOVA_API_KEY) and initializes a SambaNova client.
sambanova_service = SambaNovaService()
# # SambaNova AI service
# import logging
# from sambanova import SambaNova
# from utils.config import Config
# class SambaNovaService:
# def __init__(self):
# self.client = None
# self.initialize_client()
# def initialize_client(self):
# try:
# self.client = SambaNova(
# api_key=Config.SAMBANOVA_API_KEY,
# base_url=Config.SAMBANOVA_BASE_URL,
# )
# logging.info("SambaNova client initialized")
# except Exception as e:
# logging.error(f"Failed to initialize SambaNova client: {e}")
# self.client = None
# def generate_response(self, prompt, model_name="Qwen3-235B", temperature=0.1, top_p=0.1):
# """Generate content using SambaNova API"""
# if not self.client:
# return "SambaNova client not initialized."
# try:
# response = self.client.chat.completions.create(
# model=model_name,
# messages=[
# {
# "role": "system",
# "content": "Bạn là một chuyên gia tư vấn pháp luật Việt Nam. Hãy cung cấp câu trả lời chính xác, chi tiết và có căn cứ pháp lý."
# },
# {"role": "user", "content": prompt}
# ],
# temperature=temperature,
# top_p=top_p
# )
# if response and response.choices and len(response.choices) > 0:
# return response.choices[0].message.content
# else:
# logging.error("SambaNova API returned empty response")
# return "Không có phản hồi từ mô hình."
# except Exception as e:
# logging.error(f"Error calling SambaNova API: {e}")
# return f"Lỗi khi gọi API: {str(e)}"
# # Tạo instance toàn cục
# sambanova_service = SambaNovaService()