Spaces:
Sleeping
Sleeping
Refactor query and prompt generation in EndToEndRAG class to support multiple-choice questions and improve context formatting. Simplify user input handling in app.py by passing raw user text to the retrieval function.
Browse files
- app.py +2 -18
- end_to_end_class.py +44 -40
app.py
CHANGED
|
@@ -35,25 +35,9 @@ def respond(
|
|
| 35 |
user_text = message if isinstance(message, str) else str(message)
|
| 36 |
img_input = image_url if isinstance(image_url, str) and image_url.strip() else None
|
| 37 |
|
| 38 |
-
# Build a Persian prompt aligned with the notebook style
|
| 39 |
-
sys_prefix = system_message if isinstance(system_message, str) and system_message.strip() else "تو یک دستیار پاسخگوی دقیق به زبان فارسی هستی."
|
| 40 |
-
|
| 41 |
-
user_desc_parts = []
|
| 42 |
-
if user_text and user_text.strip():
|
| 43 |
-
user_desc_parts.append(f"پرسش متنی: {user_text.strip()}")
|
| 44 |
-
if img_input:
|
| 45 |
-
user_desc_parts.append(f"لینک تصویر: {img_input}")
|
| 46 |
-
|
| 47 |
-
prompt = (
|
| 48 |
-
f"{sys_prefix} "
|
| 49 |
-
"از زمینهٔ زیر برای پاسخ استفاده کن و اگر کافی نبود، صراحتاً اعلام کن. "
|
| 50 |
-
"از حدسزدن بپرهیز و در صورت امکان به منبع اشاره کن.\n\n"
|
| 51 |
-
f"جزئیات ورودی کاربر:\n- {' | '.join(user_desc_parts) if user_desc_parts else 'نامشخص'}\n\n"
|
| 52 |
-
"پاسخ نهایی فارسی، موجز و مستدل:"
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
try:
|
| 56 |
-
|
|
|
|
| 57 |
yield answer
|
| 58 |
except Exception as e:
|
| 59 |
yield f"Error while generating answer: {e}"
|
|
|
|
| 35 |
user_text = message if isinstance(message, str) else str(message)
|
| 36 |
img_input = image_url if isinstance(image_url, str) and image_url.strip() else None
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
try:
|
| 39 |
+
# Pass only the raw user text to retrieval so CLIP stays within its 77-token limit
|
| 40 |
+
answer = rag_instance.query(text=user_text, image_url=img_input)
|
| 41 |
yield answer
|
| 42 |
except Exception as e:
|
| 43 |
yield f"Error while generating answer: {e}"
|
end_to_end_class.py
CHANGED
|
@@ -125,13 +125,13 @@ class EndToEndRAG:
|
|
| 125 |
|
| 126 |
return instance
|
| 127 |
|
| 128 |
-
def query(self, text: Optional[str], image_url: Optional[str]) -> str:
|
| 129 |
if (text is None or text.strip() == "") and (image_url is None or image_url.strip() == ""):
|
| 130 |
return "ورودی معتبری ارائه نشده است. لطفاً متن پرسش یا تصویر را ارسال کنید."
|
| 131 |
|
| 132 |
retrieved = self._retrieve(text=text, image_url=image_url, top_k=self.top_k)
|
| 133 |
-
prompt = self._build_prompt(text=text, image_url=image_url, retrieved=retrieved)
|
| 134 |
-
answer = self._generate(prompt)
|
| 135 |
return answer
|
| 136 |
|
| 137 |
def _load_index(
|
|
@@ -288,56 +288,60 @@ class EndToEndRAG:
|
|
| 288 |
text: Optional[str],
|
| 289 |
image_url: Optional[str],
|
| 290 |
retrieved: List[Dict[str, Any]],
|
|
|
|
| 291 |
) -> str:
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
parts.append(f"
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
| 316 |
|
|
|
|
| 317 |
prompt = (
|
| 318 |
-
"
|
| 319 |
-
"
|
| 320 |
-
"
|
| 321 |
-
f"اطلاعات بازیابیشده:\n{context_str}\n\n"
|
| 322 |
-
f"جزئیات ورودی کاربر:\n- {' | '.join(user_query_desc) if user_query_desc else 'نامشخص'}\n\n"
|
| 323 |
-
"پاسخ نهایی فارسی، موجز و مستدل:"
|
| 324 |
)
|
| 325 |
return prompt
|
| 326 |
|
| 327 |
-
def _generate(self, prompt: str) -> str:
|
| 328 |
if self.inference_client is None:
|
| 329 |
return (
|
| 330 |
"سرویس تولید متن تنظیم نشده است. لطفاً یک مدل از طریق Inference API تنظیم کنید یا تولید محلی را فعال کنید."
|
| 331 |
)
|
|
|
|
|
|
|
|
|
|
| 332 |
try:
|
| 333 |
-
# Prefer chat completion when available
|
| 334 |
chat = self.inference_client.chat_completion(
|
| 335 |
messages=[
|
| 336 |
{"role": "system", "content": "You are a helpful assistant."},
|
| 337 |
{"role": "user", "content": prompt},
|
| 338 |
],
|
| 339 |
-
max_tokens=
|
| 340 |
-
temperature=
|
| 341 |
stream=False,
|
| 342 |
)
|
| 343 |
if chat and getattr(chat, "choices", None):
|
|
@@ -350,9 +354,9 @@ class EndToEndRAG:
|
|
| 350 |
try:
|
| 351 |
out = self.inference_client.text_generation(
|
| 352 |
prompt,
|
| 353 |
-
max_new_tokens=
|
| 354 |
-
temperature=
|
| 355 |
-
do_sample=
|
| 356 |
return_full_text=False,
|
| 357 |
details=False,
|
| 358 |
stream=False,
|
|
|
|
| 125 |
|
| 126 |
return instance
|
| 127 |
|
| 128 |
+
def query(self, text: Optional[str], image_url: Optional[str], options: Optional[List[str]] = None) -> str:
|
| 129 |
if (text is None or text.strip() == "") and (image_url is None or image_url.strip() == ""):
|
| 130 |
return "ورودی معتبری ارائه نشده است. لطفاً متن پرسش یا تصویر را ارسال کنید."
|
| 131 |
|
| 132 |
retrieved = self._retrieve(text=text, image_url=image_url, top_k=self.top_k)
|
| 133 |
+
prompt = self._build_prompt(text=text, image_url=image_url, retrieved=retrieved, options=options)
|
| 134 |
+
answer = self._generate(prompt, is_mcq=bool(options), options=options)
|
| 135 |
return answer
|
| 136 |
|
| 137 |
def _load_index(
|
|
|
|
| 288 |
text: Optional[str],
|
| 289 |
image_url: Optional[str],
|
| 290 |
retrieved: List[Dict[str, Any]],
|
| 291 |
+
options: Optional[List[str]] = None,
|
| 292 |
) -> str:
|
| 293 |
+
# Notebook-style context formatting
|
| 294 |
+
parts: List[str] = []
|
| 295 |
+
for i, item in enumerate(retrieved, start=1):
|
| 296 |
+
parts.append(f"Person {i}:")
|
| 297 |
+
bio = item.get("biography") or item.get("text") or ""
|
| 298 |
+
parts.append(f"Biography: {bio}")
|
| 299 |
+
imgs = item.get("image_urls") or []
|
| 300 |
+
if imgs:
|
| 301 |
+
parts.append(f"Image URLs: {', '.join(imgs)}")
|
| 302 |
+
score = item.get("combined_similarity")
|
| 303 |
+
if score is not None:
|
| 304 |
+
parts.append(f"Relevance Score: {float(score):.3f}")
|
| 305 |
+
parts.append("---")
|
| 306 |
+
context = "\n".join(parts) if parts else "(no retrieved content)"
|
| 307 |
+
|
| 308 |
+
user_q = text.strip() if text else ""
|
| 309 |
+
|
| 310 |
+
if options:
|
| 311 |
+
options_text = "\n".join([f"{i}: {opt}" for i, opt in enumerate(options)])
|
| 312 |
+
prompt = (
|
| 313 |
+
f"Retrieved Information:\n{context}\n\n"
|
| 314 |
+
f"Question: {user_q}\n\n"
|
| 315 |
+
f"Options:\n{options_text}\n\n"
|
| 316 |
+
"Output ONLY the chosen option number in the format \"Choice: [number]\". Do not include any other text.\n"
|
| 317 |
+
"Choice:"
|
| 318 |
+
)
|
| 319 |
+
return prompt
|
| 320 |
|
| 321 |
+
# Free-form answer
|
| 322 |
prompt = (
|
| 323 |
+
f"Retrieved Information:\n{context}\n\n"
|
| 324 |
+
f"Question: {user_q}\n\n"
|
| 325 |
+
"Answer in concise Persian:"
|
|
|
|
|
|
|
|
|
|
| 326 |
)
|
| 327 |
return prompt
|
| 328 |
|
| 329 |
+
def _generate(self, prompt: str, is_mcq: bool, options: Optional[List[str]]) -> str:
|
| 330 |
if self.inference_client is None:
|
| 331 |
return (
|
| 332 |
"سرویس تولید متن تنظیم نشده است. لطفاً یک مدل از طریق Inference API تنظیم کنید یا تولید محلی را فعال کنید."
|
| 333 |
)
|
| 334 |
+
max_new = 10 if is_mcq else self.max_new_tokens
|
| 335 |
+
temp = 0.1 if is_mcq else self.temperature
|
| 336 |
+
# Prefer chat
|
| 337 |
try:
|
|
|
|
| 338 |
chat = self.inference_client.chat_completion(
|
| 339 |
messages=[
|
| 340 |
{"role": "system", "content": "You are a helpful assistant."},
|
| 341 |
{"role": "user", "content": prompt},
|
| 342 |
],
|
| 343 |
+
max_tokens=max_new,
|
| 344 |
+
temperature=temp,
|
| 345 |
stream=False,
|
| 346 |
)
|
| 347 |
if chat and getattr(chat, "choices", None):
|
|
|
|
| 354 |
try:
|
| 355 |
out = self.inference_client.text_generation(
|
| 356 |
prompt,
|
| 357 |
+
max_new_tokens=max_new,
|
| 358 |
+
temperature=temp,
|
| 359 |
+
do_sample=temp > 0,
|
| 360 |
return_full_text=False,
|
| 361 |
details=False,
|
| 362 |
stream=False,
|