import os
import json
import math
import urllib.request
from io import BytesIO
from typing import Any, Dict, List, Optional

import numpy as np
from PIL import Image

try:
    import torch
    from transformers import CLIPModel, CLIPProcessor
    import faiss  # type: ignore
    from huggingface_hub import hf_hub_download, InferenceClient
except Exception as import_error:  # pragma: no cover
    raise RuntimeError(
        "Required packages not found. Please install: torch, transformers, pillow, faiss-cpu, huggingface_hub"
    ) from import_error


class EndToEndRAG:
    """
    End-to-end multimodal RAG system using local CLIP + FAISS retrieval
    and remote generation via the Inference API.
    """

    def __init__(
        self,
        clip_model_name: str = "aaalaaa/multimodal-face-clip",
        generator_model_name: Optional[str] = "google/gemma-2b",
        index_path: Optional[str] = None,
        doc_embeddings_path: Optional[str] = None,
        doc_metadata_path: Optional[str] = None,
        device: Optional[str] = None,
        text_weight: float = 0.7,
        image_weight: float = 0.3,
        top_k: int = 1,
        max_new_tokens: int = 10,
        temperature: float = 0.1,
    ) -> None:
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.text_weight = float(text_weight)
        self.image_weight = float(image_weight)
        self.top_k = int(top_k)
        self.max_new_tokens = int(max_new_tokens)
        self.temperature = float(temperature)

        if not math.isclose(self.text_weight + self.image_weight, 1.0, rel_tol=1e-6):
            raise ValueError("text_weight + image_weight must equal 1.0")

        # Models: CLIP
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(self.device)
        self.clip_model.eval()

        # Inference client for generation (remote)
        self.inference_client: Optional[InferenceClient] = None
        if generator_model_name:
            hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
            model_name = os.environ.get("HF_INFERENCE_MODEL", generator_model_name)
            self.inference_client = InferenceClient(model=model_name, token=hf_token)

        # Two-index stores
        self.text_index: Optional[faiss.Index] = None
        self.image_index: Optional[faiss.Index] = None
        self.metadata: List[Dict[str, Any]] = []
        self.id_to_original: Dict[str, Dict[str, Any]] = {}

        # Single-index store
        self.index: Optional[faiss.Index] = None
        self.doc_embeddings: Optional[np.ndarray] = None
        self.doc_metadata: List[Dict[str, Any]] = []

        # Load local single-index mode if provided
        self._load_index(index_path, doc_embeddings_path, doc_metadata_path)

    @classmethod
    def default(
        cls,
        hf_token: Optional[str] = None,
        text_weight: float = 0.7,
        image_weight: float = 0.3,
        top_k: int = 1,
        max_new_tokens: int = 10,
        temperature: float = 0.1,
        device: Optional[str] = None,
    ) -> "EndToEndRAG":
        instance = cls(
            clip_model_name="aaalaaa/multimodal-face-clip",
            generator_model_name="google/gemma-2b",
            index_path=None,
            doc_embeddings_path=None,
            doc_metadata_path=None,
            device=device,
            text_weight=text_weight,
            image_weight=image_weight,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )

        # Download indices and metadata via HF Hub
        token = hf_token or os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
        text_index_path = hf_hub_download(
            repo_id="aaalaaa/multimodal-face-clip", filename="embeddings/text_index.faiss", token=token
        )
        image_index_path = hf_hub_download(
            repo_id="aaalaaa/multimodal-face-clip", filename="embeddings/image_index.faiss", token=token
        )
        metadata_path = hf_hub_download(
            repo_id="aaalaaa/multimodal-face-clip", filename="embeddings/metadata.json", token=token
        )
        original_path = hf_hub_download(
            repo_id="aaalaaa/multimodal-face-clip", filename="saved_data.json", token=token
        )

        instance.text_index = faiss.read_index(text_index_path)
        instance.image_index = faiss.read_index(image_index_path)
        with open(metadata_path, "r", encoding="utf-8") as f:
            instance.metadata = json.load(f)
        with open(original_path, "r", encoding="utf-8") as f:
            original_data = json.load(f)
        instance.id_to_original = {str(item.get("id")): item for item in original_data}
        return instance

    def query(self, text: Optional[str], image_url: Optional[str], options: Optional[List[str]] = None) -> str:
        if (text is None or text.strip() == "") and (image_url is None or image_url.strip() == ""):
            # Persian: "No valid input was provided. Please send the question text or an image."
            return "ورودی معتبری ارائه نشده است. لطفاً متن پرسش یا تصویر را ارسال کنید."
        retrieved = self._retrieve(text=text, image_url=image_url, top_k=self.top_k)
        prompt = self._build_prompt(text=text, image_url=image_url, retrieved=retrieved, options=options)
        answer = self._generate(prompt, is_mcq=bool(options), options=options)
        return answer

    def _load_index(
        self,
        index_path: Optional[str],
        doc_embeddings_path: Optional[str],
        doc_metadata_path: Optional[str],
    ) -> None:
        if index_path and os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
        if doc_embeddings_path and os.path.exists(doc_embeddings_path):
            self.doc_embeddings = np.load(doc_embeddings_path)
        if doc_metadata_path and os.path.exists(doc_metadata_path):
            with open(doc_metadata_path, "r", encoding="utf-8") as f:
                self.doc_metadata = json.load(f)

        # Build a flat inner-product index from raw embeddings if no index file was given
        if self.index is None and self.doc_embeddings is not None:
            self._normalize_inplace(self.doc_embeddings)
            dim = int(self.doc_embeddings.shape[1])
            self.index = faiss.IndexFlatIP(dim)
            self.index.add(self.doc_embeddings.astype(np.float32))

        # Without an index, single-index mode stays disabled
        if self.index is None:
            self.doc_embeddings = None
            self.doc_metadata = []

    @torch.no_grad()
    def _encode_text(self, text: str) -> np.ndarray:
        inputs = self.clip_processor(text=[text], images=None, return_tensors="pt", padding=True).to(self.device)
        text_features = self.clip_model.get_text_features(
            input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask")
        )
        text_features = torch.nn.functional.normalize(text_features, p=2, dim=-1)
        return text_features.detach().cpu().numpy()[0]

    @torch.no_grad()
    def _encode_image(self, image: Image.Image) -> np.ndarray:
        inputs = self.clip_processor(text=None, images=image, return_tensors="pt").to(self.device)
        image_features = self.clip_model.get_image_features(pixel_values=inputs["pixel_values"])
        image_features = torch.nn.functional.normalize(image_features, p=2, dim=-1)
        return image_features.detach().cpu().numpy()[0]

    def _retrieve(
        self,
        text: Optional[str],
        image_url: Optional[str],
        top_k: int,
    ) -> List[Dict[str, Any]]:
        has_two_indices = self.text_index is not None and self.image_index is not None and len(self.metadata) > 0

        query_vectors: List[np.ndarray] = []
        weights: List[float] = []
        if text and text.strip():
            query_vectors.append(self._encode_text(text.strip()))
            weights.append(self.text_weight)
        if image_url and image_url.strip():
            image = self._load_image(image_url.strip())
            if image is not None:
                query_vectors.append(self._encode_image(image))
                weights.append(self.image_weight)
        if not query_vectors:
            return []

        if has_two_indices:
            stacked = np.stack(query_vectors).astype(np.float32)
            weights_arr = np.array(weights, dtype=np.float32).reshape(-1, 1)
            combined = (stacked * weights_arr).sum(axis=0)
            combined = self._normalize(combined).reshape(1, -1).astype(np.float32)
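            # The weighted, re-normalized query vector is searched against both FAISS
            # indices below; each hit's similarity score is scaled by text_weight or
            # image_weight and accumulated per person id into combined_similarity.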
            text_scores, text_indices = self.text_index.search(combined, max(top_k * 3, top_k))
            image_scores, image_indices = self.image_index.search(combined, max(top_k * 3, top_k))

            results: Dict[str, Dict[str, Any]] = {}
            for score, idx in zip(text_scores[0], text_indices[0]):
                if idx < 0 or idx >= len(self.metadata):
                    continue
                meta = self.metadata[idx]
                if meta.get("type") != "text":
                    continue
                pid = str(meta.get("id"))
                entry = results.setdefault(
                    pid,
                    {"id": pid, "text_similarity": 0.0, "image_similarity": 0.0, "combined_similarity": 0.0},
                )
                entry["text_similarity"] = float(score)
                entry["combined_similarity"] += float(score) * self.text_weight
            for score, idx in zip(image_scores[0], image_indices[0]):
                if idx < 0 or idx >= len(self.metadata):
                    continue
                meta = self.metadata[idx]
                if meta.get("type") != "image":
                    continue
                pid = str(meta.get("id"))
                entry = results.setdefault(
                    pid,
                    {"id": pid, "text_similarity": 0.0, "image_similarity": 0.0, "combined_similarity": 0.0},
                )
                entry["image_similarity"] = float(score)
                entry["combined_similarity"] += float(score) * self.image_weight

            ranked = sorted(results.values(), key=lambda x: x["combined_similarity"], reverse=True)
            final: List[Dict[str, Any]] = []
            for rank, res in enumerate(ranked[:top_k], start=1):
                original = self.id_to_original.get(res["id"], {})
                final.append(
                    {
                        "id": res["id"],
                        "rank": rank,
                        "text_similarity": res["text_similarity"],
                        "image_similarity": res["image_similarity"],
                        "combined_similarity": res["combined_similarity"],
                        "biography": original.get("cleaned_bio", ""),
                        "image_urls": original.get("images", []),
                    }
                )
            return final

        if self.index is None or self.doc_embeddings is None or len(self.doc_metadata) == 0:
            return []

        stacked = np.stack(query_vectors).astype(np.float32)
        weights_arr = np.array(weights, dtype=np.float32).reshape(-1, 1)
        weighted = (stacked * weights_arr).sum(axis=0)
        weighted = self._normalize(weighted)
        query = weighted.reshape(1, -1).astype(np.float32)

        scores, indices = self.index.search(query, top_k)
        scores = scores[0]
        indices = indices[0]

        results: List[Dict[str, Any]] = []
        for rank, (idx, score) in enumerate(zip(indices, scores)):
            if idx < 0 or idx >= len(self.doc_metadata):
                continue
            meta = self.doc_metadata[idx]
            results.append(
                {
                    "id": meta.get("id", str(idx)),
                    "rank": int(rank + 1),
                    "score": float(score),
                    "title": meta.get("title", ""),
                    "text": meta.get("text", ""),
                    "image_path": meta.get("image_path"),
                    "metadata": meta,
                }
            )
        return results

    def _build_prompt(
        self,
        text: Optional[str],
        image_url: Optional[str],
        retrieved: List[Dict[str, Any]],
        options: Optional[List[str]] = None,
    ) -> str:
        # Notebook-style context formatting
        parts: List[str] = []
        for i, item in enumerate(retrieved, start=1):
            parts.append(f"Person {i}:")
            bio = item.get("biography") or item.get("text") or ""
            parts.append(f"Biography: {bio}")
            imgs = item.get("image_urls") or []
            if imgs:
                parts.append(f"Image URLs: {', '.join(imgs)}")
            score = item.get("combined_similarity")
            if score is not None:
                parts.append(f"Relevance Score: {float(score):.3f}")
            parts.append("---")
        context = "\n".join(parts) if parts else "(no retrieved content)"

        user_q = text.strip() if text else ""
        if options:
            options_text = "\n".join([f"{i}: {opt}" for i, opt in enumerate(options)])
            prompt = (
                f"Retrieved Information:\n{context}\n\n"
                f"Question: {user_q}\n\n"
                f"Options:\n{options_text}\n\n"
                "Output ONLY the chosen option number in the format \"Choice: [number]\". Do not include any other text.\n"
                "Choice:"
            )
            return prompt

        # Free-form answer
        prompt = (
            f"Retrieved Information:\n{context}\n\n"
            f"Question: {user_q}\n\n"
            "Answer in concise Persian:"
        )
        return prompt

    def _generate(self, prompt: str, is_mcq: bool, options: Optional[List[str]]) -> str:
        if self.inference_client is None:
            # Persian: "The text generation service is not configured. Please configure a model
            # via the Inference API or enable local generation."
            return (
                "سرویس تولید متن تنظیم نشده است. لطفاً یک مدل از طریق Inference API تنظیم کنید یا تولید محلی را فعال کنید."
            )

        max_new = 10 if is_mcq else self.max_new_tokens
        temp = 0.1 if is_mcq else self.temperature

        # Prefer chat completion
        try:
            chat = self.inference_client.chat_completion(
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=max_new,
                temperature=temp,
                stream=False,
            )
            if chat and getattr(chat, "choices", None):
                content = getattr(chat.choices[0].message, "content", "")
                if isinstance(content, str) and content.strip():
                    return content.strip()
        except Exception:
            pass

        # Fall back to plain text generation
        try:
            out = self.inference_client.text_generation(
                prompt,
                max_new_tokens=max_new,
                temperature=temp,
                do_sample=temp > 0,
                return_full_text=False,
                details=False,
                stream=False,
            )
            if isinstance(out, str) and out.strip():
                return out.strip()
            gen = getattr(out, "generated_text", None)
            if isinstance(gen, str) and gen.strip():
                return gen.strip()
            return ""
        except Exception as e:
            # Persian: "Error while generating the answer: ..."
            return f"خطا در تولید پاسخ: {type(e).__name__}: {e}"

    @staticmethod
    def _normalize(v: np.ndarray) -> np.ndarray:
        denom = np.linalg.norm(v) + 1e-12
        return (v / denom).astype(np.float32)

    @staticmethod
    def _normalize_inplace(mat: np.ndarray) -> None:
        norms = np.linalg.norm(mat, axis=1, keepdims=True) + 1e-12
        mat /= norms

    @staticmethod
    def _load_image(image_url: str) -> Optional[Image.Image]:
        try:
            if image_url.startswith("http://") or image_url.startswith("https://"):
                with urllib.request.urlopen(image_url, timeout=10) as resp:
                    data = resp.read()
                return Image.open(BytesIO(data)).convert("RGB")
            if os.path.exists(image_url):
                return Image.open(image_url).convert("RGB")
        except Exception:
            return None
        return None
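
# Example usage: a minimal sketch, assuming the Hub repository referenced in
# EndToEndRAG.default() is reachable and a token is available in the
# HUGGINGFACEHUB_API_TOKEN or HF_TOKEN environment variable. The question
# texts, image URL, and options below are placeholders, not values taken
# from the repository.
if __name__ == "__main__":
    rag = EndToEndRAG.default(top_k=1)

    # Free-form answer from a text + image query
    answer = rag.query(
        text="Who is the person in this photo?",  # placeholder question
        image_url="https://example.com/face.jpg",  # placeholder image URL
    )
    print(answer)

    # Multiple-choice mode: passing `options` switches the prompt to ask the
    # generator for an option number only.
    mcq_answer = rag.query(
        text="What is this person's profession?",  # placeholder question
        image_url=None,
        options=["Actor", "Politician", "Athlete"],  # placeholder options
    )
    print(mcq_answer)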