# multimodal-persian-QA / end_to_end_class.py
import os
import json
import math
import urllib.request
from io import BytesIO
from typing import Any, Dict, List, Optional
import numpy as np
from PIL import Image
try:
import torch
from transformers import CLIPModel, CLIPProcessor
import faiss # type: ignore
from huggingface_hub import hf_hub_download, InferenceClient
except Exception as import_error: # pragma: no cover
raise RuntimeError(
"Required packages not found. Please install: torch, transformers, pillow, faiss-cpu, huggingface_hub"
) from import_error
class EndToEndRAG:
"""
    End-to-end multimodal RAG system using local CLIP + FAISS retrieval and remote generation via the Inference API.
"""
def __init__(
self,
clip_model_name: str = "aaalaaa/multimodal-face-clip",
generator_model_name: Optional[str] = "google/gemma-2b",
index_path: Optional[str] = None,
doc_embeddings_path: Optional[str] = None,
doc_metadata_path: Optional[str] = None,
device: Optional[str] = None,
text_weight: float = 0.7,
image_weight: float = 0.3,
top_k: int = 1,
max_new_tokens: int = 10,
temperature: float = 0.1,
) -> None:
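        """
        Load the CLIP encoder locally and, if a generator model name is given, set up a
        remote InferenceClient for generation. text_weight and image_weight must sum to
        1.0. Optionally loads a local single-index store from the given paths.
        """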
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.text_weight = float(text_weight)
self.image_weight = float(image_weight)
self.top_k = int(top_k)
self.max_new_tokens = int(max_new_tokens)
self.temperature = float(temperature)
if not math.isclose(self.text_weight + self.image_weight, 1.0, rel_tol=1e-6):
raise ValueError("text_weight + image_weight must equal 1.0")
# Models: CLIP
self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(self.device)
self.clip_model.eval()
# Inference client for generation (remote)
self.inference_client: Optional[InferenceClient] = None
if generator_model_name:
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
model_name = os.environ.get("HF_INFERENCE_MODEL", generator_model_name)
self.inference_client = InferenceClient(model=model_name, token=hf_token)
# Two-index stores
self.text_index: Optional[faiss.Index] = None
self.image_index: Optional[faiss.Index] = None
self.metadata: List[Dict[str, Any]] = []
self.id_to_original: Dict[str, Dict[str, Any]] = {}
# Single-index store
self.index: Optional[faiss.Index] = None
self.doc_embeddings: Optional[np.ndarray] = None
self.doc_metadata: List[Dict[str, Any]] = []
        # Load the local single-index store if paths are provided
self._load_index(index_path, doc_embeddings_path, doc_metadata_path)
@classmethod
def default(
cls,
hf_token: Optional[str] = None,
text_weight: float = 0.7,
image_weight: float = 0.3,
top_k: int = 1,
max_new_tokens: int = 10,
temperature: float = 0.1,
device: Optional[str] = None,
) -> "EndToEndRAG":
instance = cls(
clip_model_name="aaalaaa/multimodal-face-clip",
generator_model_name="google/gemma-2b",
index_path=None,
doc_embeddings_path=None,
doc_metadata_path=None,
device=device,
text_weight=text_weight,
image_weight=image_weight,
top_k=top_k,
max_new_tokens=max_new_tokens,
temperature=temperature,
)
# Download indices and metadata via HF Hub
token = hf_token or os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
text_index_path = hf_hub_download(
repo_id="aaalaaa/multimodal-face-clip", filename="embeddings/text_index.faiss", token=token
)
image_index_path = hf_hub_download(
repo_id="aaalaaa/multimodal-face-clip", filename="embeddings/image_index.faiss", token=token
)
metadata_path = hf_hub_download(
repo_id="aaalaaa/multimodal-face-clip", filename="embeddings/metadata.json", token=token
)
original_path = hf_hub_download(
repo_id="aaalaaa/multimodal-face-clip", filename="saved_data.json", token=token
)
instance.text_index = faiss.read_index(text_index_path)
instance.image_index = faiss.read_index(image_index_path)
with open(metadata_path, "r", encoding="utf-8") as f:
instance.metadata = json.load(f)
with open(original_path, "r", encoding="utf-8") as f:
original_data = json.load(f)
instance.id_to_original = {str(item.get("id")): item for item in original_data}
return instance
def query(self, text: Optional[str], image_url: Optional[str], options: Optional[List[str]] = None) -> str:
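        """
        Answer a question given optional Persian text and/or an image URL: retrieve the
        top_k most similar entries, build a prompt from them, and generate the answer
        remotely. If options are provided, the prompt asks for a multiple-choice number.
        """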
if (text is None or text.strip() == "") and (image_url is None or image_url.strip() == ""):
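            # Persian: "No valid input was provided. Please send the question text or an image."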
return "ورودی معتبری ارائه نشده است. لطفاً متن پرسش یا تصویر را ارسال کنید."
retrieved = self._retrieve(text=text, image_url=image_url, top_k=self.top_k)
prompt = self._build_prompt(text=text, image_url=image_url, retrieved=retrieved, options=options)
answer = self._generate(prompt, is_mcq=bool(options), options=options)
return answer
def _load_index(
self,
index_path: Optional[str],
doc_embeddings_path: Optional[str],
doc_metadata_path: Optional[str],
) -> None:
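        """
        Optionally load a single-index store from local files: a FAISS index, document
        embeddings (.npy), and document metadata (JSON). If only embeddings are given,
        normalize them and build an inner-product FAISS index in memory.
        """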
if index_path and os.path.exists(index_path):
self.index = faiss.read_index(index_path)
if doc_embeddings_path and os.path.exists(doc_embeddings_path):
self.doc_embeddings = np.load(doc_embeddings_path)
if doc_metadata_path and os.path.exists(doc_metadata_path):
with open(doc_metadata_path, "r", encoding="utf-8") as f:
self.doc_metadata = json.load(f)
if self.index is None and self.doc_embeddings is not None:
self._normalize_inplace(self.doc_embeddings)
dim = int(self.doc_embeddings.shape[1])
self.index = faiss.IndexFlatIP(dim)
self.index.add(self.doc_embeddings.astype(np.float32))
        if self.index is None:
            # No single-index store was provided or built; clear any partially loaded
            # state so retrieval relies on the two-index (Hub) store only.
            self.doc_embeddings = None
            self.doc_metadata = []
@torch.no_grad()
def _encode_text(self, text: str) -> np.ndarray:
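        """Encode a text query with CLIP and return an L2-normalized 1-D numpy embedding."""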
        inputs = self.clip_processor(text=[text], return_tensors="pt", padding=True).to(self.device)
        # Pass both input_ids and attention_mask to the text encoder.
        text_inputs = {k: v for k, v in inputs.items() if k in ("input_ids", "attention_mask")}
        text_features = self.clip_model.get_text_features(**text_inputs)
text_features = torch.nn.functional.normalize(text_features, p=2, dim=-1)
return text_features.detach().cpu().numpy()[0]
@torch.no_grad()
def _encode_image(self, image: Image.Image) -> np.ndarray:
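        """Encode a PIL image with CLIP and return an L2-normalized 1-D numpy embedding."""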
inputs = self.clip_processor(text=None, images=image, return_tensors="pt").to(self.device)
image_features = self.clip_model.get_image_features(**{k: v for k, v in inputs.items() if k.startswith("pixel_")})
image_features = torch.nn.functional.normalize(image_features, p=2, dim=-1)
return image_features.detach().cpu().numpy()[0]
def _retrieve(
self,
text: Optional[str],
image_url: Optional[str],
top_k: int,
) -> List[Dict[str, Any]]:
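        """
        Retrieve the top_k entries for the query. With the two-index store, search the
        text and image FAISS indices with a weighted, normalized query vector and fuse
        the per-modality scores (text_weight / image_weight) into a combined ranking.
        Otherwise fall back to the single-index store, if one was loaded.
        """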
has_two_indices = self.text_index is not None and self.image_index is not None and len(self.metadata) > 0
query_vectors: List[np.ndarray] = []
weights: List[float] = []
if text and text.strip():
query_vectors.append(self._encode_text(text.strip()))
weights.append(self.text_weight)
if image_url and image_url.strip():
image = self._load_image(image_url.strip())
if image is not None:
query_vectors.append(self._encode_image(image))
weights.append(self.image_weight)
if not query_vectors:
return []
if has_two_indices:
stacked = np.stack(query_vectors).astype(np.float32)
weights_arr = np.array(weights, dtype=np.float32).reshape(-1, 1)
combined = (stacked * weights_arr).sum(axis=0)
combined = self._normalize(combined).reshape(1, -1).astype(np.float32)
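            # Over-fetch from each index so cross-modal fusion still yields top_k fused results.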
text_scores, text_indices = self.text_index.search(combined, max(top_k * 3, top_k))
image_scores, image_indices = self.image_index.search(combined, max(top_k * 3, top_k))
results: Dict[str, Dict[str, Any]] = {}
for score, idx in zip(text_scores[0], text_indices[0]):
if idx < 0 or idx >= len(self.metadata):
continue
meta = self.metadata[idx]
if meta.get("type") != "text":
continue
pid = str(meta.get("id"))
entry = results.setdefault(
pid,
{"id": pid, "text_similarity": 0.0, "image_similarity": 0.0, "combined_similarity": 0.0},
)
entry["text_similarity"] = float(score)
entry["combined_similarity"] += float(score) * self.text_weight
for score, idx in zip(image_scores[0], image_indices[0]):
if idx < 0 or idx >= len(self.metadata):
continue
meta = self.metadata[idx]
if meta.get("type") != "image":
continue
pid = str(meta.get("id"))
entry = results.setdefault(
pid,
{"id": pid, "text_similarity": 0.0, "image_similarity": 0.0, "combined_similarity": 0.0},
)
entry["image_similarity"] = float(score)
entry["combined_similarity"] += float(score) * self.image_weight
ranked = sorted(results.values(), key=lambda x: x["combined_similarity"], reverse=True)
final: List[Dict[str, Any]] = []
for rank, res in enumerate(ranked[:top_k], start=1):
original = self.id_to_original.get(res["id"], {})
final.append(
{
"id": res["id"],
"rank": rank,
"text_similarity": res["text_similarity"],
"image_similarity": res["image_similarity"],
"combined_similarity": res["combined_similarity"],
"biography": original.get("cleaned_bio", ""),
"image_urls": original.get("images", []),
}
)
return final
if self.index is None or self.doc_embeddings is None or len(self.doc_metadata) == 0:
return []
stacked = np.stack(query_vectors).astype(np.float32)
weights_arr = np.array(weights, dtype=np.float32).reshape(-1, 1)
weighted = (stacked * weights_arr).sum(axis=0)
weighted = self._normalize(weighted)
query = weighted.reshape(1, -1).astype(np.float32)
scores, indices = self.index.search(query, top_k)
scores = scores[0]
indices = indices[0]
results: List[Dict[str, Any]] = []
for rank, (idx, score) in enumerate(zip(indices, scores)):
if idx < 0 or idx >= len(self.doc_metadata):
continue
meta = self.doc_metadata[idx]
results.append(
{
"id": meta.get("id", str(idx)),
"rank": int(rank + 1),
"score": float(score),
"title": meta.get("title", ""),
"text": meta.get("text", ""),
"image_path": meta.get("image_path"),
"metadata": meta,
}
)
return results
def _build_prompt(
self,
text: Optional[str],
image_url: Optional[str],
retrieved: List[Dict[str, Any]],
options: Optional[List[str]] = None,
) -> str:
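        """
        Format the retrieved entries (biography, image URLs, relevance score) into a
        context block, then append either a multiple-choice prompt ("Choice: [number]")
        or a free-form prompt asking for a concise Persian answer.
        """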
# Notebook-style context formatting
parts: List[str] = []
for i, item in enumerate(retrieved, start=1):
parts.append(f"Person {i}:")
bio = item.get("biography") or item.get("text") or ""
parts.append(f"Biography: {bio}")
imgs = item.get("image_urls") or []
if imgs:
parts.append(f"Image URLs: {', '.join(imgs)}")
score = item.get("combined_similarity")
if score is not None:
parts.append(f"Relevance Score: {float(score):.3f}")
parts.append("---")
context = "\n".join(parts) if parts else "(no retrieved content)"
user_q = text.strip() if text else ""
if options:
options_text = "\n".join([f"{i}: {opt}" for i, opt in enumerate(options)])
prompt = (
f"Retrieved Information:\n{context}\n\n"
f"Question: {user_q}\n\n"
f"Options:\n{options_text}\n\n"
"Output ONLY the chosen option number in the format \"Choice: [number]\". Do not include any other text.\n"
"Choice:"
)
return prompt
# Free-form answer
prompt = (
f"Retrieved Information:\n{context}\n\n"
f"Question: {user_q}\n\n"
"Answer in concise Persian:"
)
return prompt
def _generate(self, prompt: str, is_mcq: bool, options: Optional[List[str]]) -> str:
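        """
        Generate an answer for the prompt via the remote Inference API. Try the chat
        completion endpoint first, then fall back to plain text generation; for MCQ
        prompts use a short, low-temperature decode. Returns a Persian error message
        if no client is configured or generation fails.
        """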
if self.inference_client is None:
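            # Persian: "The text generation service is not configured. Please configure a
            # model via the Inference API or enable local generation."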
return (
"سرویس تولید متن تنظیم نشده است. لطفاً یک مدل از طریق Inference API تنظیم کنید یا تولید محلی را فعال کنید."
)
max_new = 10 if is_mcq else self.max_new_tokens
temp = 0.1 if is_mcq else self.temperature
        # Prefer the chat-completion endpoint
try:
chat = self.inference_client.chat_completion(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
max_tokens=max_new,
temperature=temp,
stream=False,
)
if chat and getattr(chat, "choices", None):
content = getattr(chat.choices[0].message, "content", "")
if isinstance(content, str) and content.strip():
return content.strip()
except Exception:
pass
        # Fall back to plain text generation
try:
out = self.inference_client.text_generation(
prompt,
max_new_tokens=max_new,
temperature=temp,
do_sample=temp > 0,
return_full_text=False,
details=False,
stream=False,
)
if isinstance(out, str) and out.strip():
return out.strip()
gen = getattr(out, "generated_text", None)
if isinstance(gen, str) and gen.strip():
return gen.strip()
return ""
except Exception as e:
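            # Persian: "Error generating answer: ..."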
return f"خطا در تولید پاسخ: {type(e).__name__}: {e}"
@staticmethod
def _normalize(v: np.ndarray) -> np.ndarray:
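        """Return v scaled to unit L2 norm (with a small epsilon for numerical stability)."""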
denom = np.linalg.norm(v) + 1e-12
return (v / denom).astype(np.float32)
@staticmethod
def _normalize_inplace(mat: np.ndarray) -> None:
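        """L2-normalize each row of mat in place."""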
norms = np.linalg.norm(mat, axis=1, keepdims=True) + 1e-12
mat /= norms
@staticmethod
def _load_image(image_url: str) -> Optional[Image.Image]:
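        """Load an RGB image from an http(s) URL or a local file path; return None on failure."""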
try:
if image_url.startswith("http://") or image_url.startswith("https://"):
with urllib.request.urlopen(image_url, timeout=10) as resp:
data = resp.read()
return Image.open(BytesIO(data)).convert("RGB")
if os.path.exists(image_url):
return Image.open(image_url).convert("RGB")
except Exception:
return None
return None
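
# Minimal usage sketch, assuming a valid HUGGINGFACEHUB_API_TOKEN / HF_TOKEN in the
# environment, network access to the aaalaaa/multimodal-face-clip repo, and an
# available remote generator; the image URL below is a hypothetical placeholder.
if __name__ == "__main__":
    rag = EndToEndRAG.default()
    answer = rag.query(
        text="این شخص کیست؟",  # Persian: "Who is this person?"
        image_url="https://example.com/face.jpg",
        options=None,
    )
    print(answer)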