shndap commited on
Commit
ab1ae1b
·
1 Parent(s): f77aa91

Refactor query and prompt generation in EndToEndRAG class to support multiple-choice questions and improve context formatting. Simplify user input handling in app.py by passing raw user text to the retrieval function.

Browse files
Files changed (2) hide show
  1. app.py +2 -18
  2. end_to_end_class.py +44 -40
app.py CHANGED
@@ -35,25 +35,9 @@ def respond(
35
  user_text = message if isinstance(message, str) else str(message)
36
  img_input = image_url if isinstance(image_url, str) and image_url.strip() else None
37
 
38
- # Build a Persian prompt aligned with the notebook style
39
- sys_prefix = system_message if isinstance(system_message, str) and system_message.strip() else "تو یک دستیار پاسخ‌گوی دقیق به زبان فارسی هستی."
40
-
41
- user_desc_parts = []
42
- if user_text and user_text.strip():
43
- user_desc_parts.append(f"پرسش متنی: {user_text.strip()}")
44
- if img_input:
45
- user_desc_parts.append(f"لینک تصویر: {img_input}")
46
-
47
- prompt = (
48
- f"{sys_prefix} "
49
- "از زمینهٔ زیر برای پاسخ استفاده کن و اگر کافی نبود، صراحتاً اعلام کن. "
50
- "از حدس‌زدن بپرهیز و در صورت امکان به منبع اشاره کن.\n\n"
51
- f"جزئیات ورودی کاربر:\n- {' | '.join(user_desc_parts) if user_desc_parts else 'نامشخص'}\n\n"
52
- "پاسخ نهایی فارسی، موجز و مستدل:"
53
- )
54
-
55
  try:
56
- answer = rag_instance.query(text=prompt, image_url=img_input)
 
57
  yield answer
58
  except Exception as e:
59
  yield f"Error while generating answer: {e}"
 
35
  user_text = message if isinstance(message, str) else str(message)
36
  img_input = image_url if isinstance(image_url, str) and image_url.strip() else None
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  try:
39
+ # Pass only the raw user text to retrieval so CLIP stays within its 77-token limit
40
+ answer = rag_instance.query(text=user_text, image_url=img_input)
41
  yield answer
42
  except Exception as e:
43
  yield f"Error while generating answer: {e}"
end_to_end_class.py CHANGED
@@ -125,13 +125,13 @@ class EndToEndRAG:
125
 
126
  return instance
127
 
128
- def query(self, text: Optional[str], image_url: Optional[str]) -> str:
129
  if (text is None or text.strip() == "") and (image_url is None or image_url.strip() == ""):
130
  return "ورودی معتبری ارائه نشده است. لطفاً متن پرسش یا تصویر را ارسال کنید."
131
 
132
  retrieved = self._retrieve(text=text, image_url=image_url, top_k=self.top_k)
133
- prompt = self._build_prompt(text=text, image_url=image_url, retrieved=retrieved)
134
- answer = self._generate(prompt)
135
  return answer
136
 
137
  def _load_index(
@@ -288,56 +288,60 @@ class EndToEndRAG:
288
  text: Optional[str],
289
  image_url: Optional[str],
290
  retrieved: List[Dict[str, Any]],
 
291
  ) -> str:
292
- context_blocks: List[str] = []
293
- for item in retrieved:
294
- parts = []
295
- if item.get("title"):
296
- parts.append(f"عنوان: {item['title']}")
297
- if item.get("biography"):
298
- parts.append(f"متن: {item['biography']}")
299
- elif item.get("text"):
300
- parts.append(f"متن: {item['text']}")
301
- if item.get("image_urls"):
302
- parts.append(f"تصاویر: {', '.join(item['image_urls'])}")
303
- if item.get("image_path"):
304
- parts.append(f"تصویر: {item['image_path']}")
305
- if item.get("combined_similarity") is not None:
306
- parts.append(f"امتیاز شباهت: {item['combined_similarity']:.3f}")
307
- context_blocks.append("\n".join(parts))
308
-
309
- context_str = "\n\n".join(context_blocks) if context_blocks else "(بدون محتوای بازیابی‌شده)"
310
-
311
- user_query_desc = []
312
- if text and text.strip():
313
- user_query_desc.append(f"پرسش متنی: {text.strip()}")
314
- if image_url and image_url.strip():
315
- user_query_desc.append(f"لینک تصویر: {image_url.strip()}")
 
 
 
316
 
 
317
  prompt = (
318
- "تو یک دستیار پاسخ‌گوی دقیق به زبان فارسی هستی. "
319
- "از زمینهٔ زیر برای پاسخ استفاده کن و اگر کافی نبود، صراحتاً اعلام کن. "
320
- "از حدس‌زدن بپرهیز و به منابع اشاره کن.\n\n"
321
- f"اطلاعات بازیابی‌شده:\n{context_str}\n\n"
322
- f"جزئیات ورودی کاربر:\n- {' | '.join(user_query_desc) if user_query_desc else 'نامشخص'}\n\n"
323
- "پاسخ نهایی فارسی، موجز و مستدل:"
324
  )
325
  return prompt
326
 
327
- def _generate(self, prompt: str) -> str:
328
  if self.inference_client is None:
329
  return (
330
  "سرویس تولید متن تنظیم نشده است. لطفاً یک مدل از طریق Inference API تنظیم کنید یا تولید محلی را فعال کنید."
331
  )
 
 
 
332
  try:
333
- # Prefer chat completion when available
334
  chat = self.inference_client.chat_completion(
335
  messages=[
336
  {"role": "system", "content": "You are a helpful assistant."},
337
  {"role": "user", "content": prompt},
338
  ],
339
- max_tokens=self.max_new_tokens,
340
- temperature=self.temperature,
341
  stream=False,
342
  )
343
  if chat and getattr(chat, "choices", None):
@@ -350,9 +354,9 @@ class EndToEndRAG:
350
  try:
351
  out = self.inference_client.text_generation(
352
  prompt,
353
- max_new_tokens=self.max_new_tokens,
354
- temperature=self.temperature,
355
- do_sample=self.temperature > 0,
356
  return_full_text=False,
357
  details=False,
358
  stream=False,
 
125
 
126
  return instance
127
 
128
+ def query(self, text: Optional[str], image_url: Optional[str], options: Optional[List[str]] = None) -> str:
129
  if (text is None or text.strip() == "") and (image_url is None or image_url.strip() == ""):
130
  return "ورودی معتبری ارائه نشده است. لطفاً متن پرسش یا تصویر را ارسال کنید."
131
 
132
  retrieved = self._retrieve(text=text, image_url=image_url, top_k=self.top_k)
133
+ prompt = self._build_prompt(text=text, image_url=image_url, retrieved=retrieved, options=options)
134
+ answer = self._generate(prompt, is_mcq=bool(options), options=options)
135
  return answer
136
 
137
  def _load_index(
 
288
  text: Optional[str],
289
  image_url: Optional[str],
290
  retrieved: List[Dict[str, Any]],
291
+ options: Optional[List[str]] = None,
292
  ) -> str:
293
+ # Notebook-style context formatting
294
+ parts: List[str] = []
295
+ for i, item in enumerate(retrieved, start=1):
296
+ parts.append(f"Person {i}:")
297
+ bio = item.get("biography") or item.get("text") or ""
298
+ parts.append(f"Biography: {bio}")
299
+ imgs = item.get("image_urls") or []
300
+ if imgs:
301
+ parts.append(f"Image URLs: {', '.join(imgs)}")
302
+ score = item.get("combined_similarity")
303
+ if score is not None:
304
+ parts.append(f"Relevance Score: {float(score):.3f}")
305
+ parts.append("---")
306
+ context = "\n".join(parts) if parts else "(no retrieved content)"
307
+
308
+ user_q = text.strip() if text else ""
309
+
310
+ if options:
311
+ options_text = "\n".join([f"{i}: {opt}" for i, opt in enumerate(options)])
312
+ prompt = (
313
+ f"Retrieved Information:\n{context}\n\n"
314
+ f"Question: {user_q}\n\n"
315
+ f"Options:\n{options_text}\n\n"
316
+ "Output ONLY the chosen option number in the format \"Choice: [number]\". Do not include any other text.\n"
317
+ "Choice:"
318
+ )
319
+ return prompt
320
 
321
+ # Free-form answer
322
  prompt = (
323
+ f"Retrieved Information:\n{context}\n\n"
324
+ f"Question: {user_q}\n\n"
325
+ "Answer in concise Persian:"
 
 
 
326
  )
327
  return prompt
328
 
329
+ def _generate(self, prompt: str, is_mcq: bool, options: Optional[List[str]]) -> str:
330
  if self.inference_client is None:
331
  return (
332
  "سرویس تولید متن تنظیم نشده است. لطفاً یک مدل از طریق Inference API تنظیم کنید یا تولید محلی را فعال کنید."
333
  )
334
+ max_new = 10 if is_mcq else self.max_new_tokens
335
+ temp = 0.1 if is_mcq else self.temperature
336
+ # Prefer chat
337
  try:
 
338
  chat = self.inference_client.chat_completion(
339
  messages=[
340
  {"role": "system", "content": "You are a helpful assistant."},
341
  {"role": "user", "content": prompt},
342
  ],
343
+ max_tokens=max_new,
344
+ temperature=temp,
345
  stream=False,
346
  )
347
  if chat and getattr(chat, "choices", None):
 
354
  try:
355
  out = self.inference_client.text_generation(
356
  prompt,
357
+ max_new_tokens=max_new,
358
+ temperature=temp,
359
+ do_sample=temp > 0,
360
  return_full_text=False,
361
  details=False,
362
  stream=False,