ccm committed on
Commit 5f7408d · verified · 1 Parent(s): 4f9a5bb

Create proxy.py

Files changed (1)
  1. proxy.py +937 -0
proxy.py ADDED
@@ -0,0 +1,937 @@
1
+ """
2
+ OpenAI-compatible FastAPI proxy that wraps a smolagents CodeAgent
3
+ """
4
+
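+ # Hedged run sketch: the proxy is configured entirely through the env vars read
+ # below (UPSTREAM_OPENAI_BASE, HF_TOKEN / HUGGINGFACEHUB_API_TOKEN / API_TOKEN,
+ # AGENT_MODEL, AGENT_VERBOSITY, LOG_LEVEL, PORT). With those set, running
+ # `uvicorn proxy:app --host 0.0.0.0 --port 8000` (or `python proxy.py`, see the
+ # __main__ block at the bottom) starts the server.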
5
+ import os # For dealing with env vars
6
+ import re # For tag stripping
7
+ import json # For JSON handling
8
+ import time # For timestamps and sleeps
9
+ import asyncio # For async operations
10
+ import typing # For type annotations
11
+ import logging # For logging
12
+ import threading # For threading operations
13
+
14
+ import fastapi
15
+ import fastapi.responses
16
+ import io
17
+ import contextlib
18
+
19
+ # smolagents + OpenAI-compatible model wrapper
20
+ import smolagents
21
+ import smolagents.models
22
+
23
+ # Upstream pass-through
24
+ import httpx
25
+
26
+ # Logging setup
27
+ logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
28
+ log = logging.getLogger(__name__)
29
+
30
+ # Config from env vars
31
+ UPSTREAM_BASE = os.getenv("UPSTREAM_OPENAI_BASE", "").rstrip("/")
32
+ HF_TOKEN = (
33
+ os.getenv("HF_TOKEN")
34
+ or os.getenv("HUGGINGFACEHUB_API_TOKEN")
35
+ or os.getenv("API_TOKEN")
36
+ or ""
37
+ )
38
+ AGENT_MODEL = os.getenv("AGENT_MODEL", "Qwen/Qwen3-1.7B")
39
+
40
+ if not UPSTREAM_BASE:
41
+ log.warning(
42
+ "UPSTREAM_OPENAI_BASE is empty; OpenAI-compatible upstream calls will fail."
43
+ )
44
+ if not HF_TOKEN:
45
+ log.warning("HF_TOKEN is empty; upstream may 401/403 if it requires auth.")
46
+
47
+ # ================== Agent ====================
48
+ llm = smolagents.models.OpenAIServerModel(
49
+ model_id=AGENT_MODEL,
50
+ api_base=UPSTREAM_BASE,
51
+ api_key=HF_TOKEN,
52
+ )
53
+ agent = smolagents.CodeAgent(
54
+ model=llm,
55
+ tools=[], # no extra tools
56
+ add_base_tools=False,
57
+ max_steps=4,
58
+ verbosity_level=int(
59
+ os.getenv("AGENT_VERBOSITY", "1")
60
+ ), # quieter by default; override via env
61
+ )
62
+
63
+ # ================== FastAPI ==================
64
+ app = fastapi.FastAPI()
65
+
66
+
67
+ @app.get("/healthz")
68
+ async def healthz():
69
+ return {"ok": True}
70
+
71
+
72
+ # ---------- OpenAI-compatible minimal schemas ----------
73
+ class ChatMessage(typing.TypedDict, total=False):
74
+ role: str
75
+ content: typing.Any # str or multimodal list
76
+
77
+
78
+ class ChatCompletionRequest(typing.TypedDict, total=False):
79
+ model: typing.Optional[str]
80
+ messages: typing.List[ChatMessage]
81
+ temperature: typing.Optional[float]
82
+ stream: typing.Optional[bool]
83
+ max_tokens: typing.Optional[int]
84
+
85
+
86
+ # ---------- Helpers ----------
87
+ def normalize_content_to_text(content: typing.Any) -> str:
88
+ if isinstance(content, str):
89
+ return content
90
+ if isinstance(content, (bytes, bytearray)):
91
+ try:
92
+ return content.decode("utf-8", errors="ignore")
93
+ except Exception:
94
+ return str(content)
95
+ if isinstance(content, list):
96
+ parts = []
97
+ for item in content:
98
+ if (
99
+ isinstance(item, dict)
100
+ and item.get("type") == "text"
101
+ and isinstance(item.get("text"), str)
102
+ ):
103
+ parts.append(item["text"])
104
+ else:
105
+ try:
106
+ parts.append(json.dumps(item, ensure_ascii=False))
107
+ except Exception:
108
+ parts.append(str(item))
109
+ return "\n".join(parts)
110
+ if isinstance(content, dict):
111
+ try:
112
+ return json.dumps(content, ensure_ascii=False)
113
+ except Exception:
114
+ return str(content)
115
+ return str(content)
116
+
117
+
118
+ def _messages_to_task(messages: typing.List[ChatMessage]) -> str:
119
+ system_parts = [
120
+ normalize_content_to_text(m.get("content", ""))
121
+ for m in messages
122
+ if m.get("role") == "system"
123
+ ]
124
+ user_parts = [
125
+ normalize_content_to_text(m.get("content", ""))
126
+ for m in messages
127
+ if m.get("role") == "user"
128
+ ]
129
+ assistant_parts = [
130
+ normalize_content_to_text(m.get("content", ""))
131
+ for m in messages
132
+ if m.get("role") == "assistant"
133
+ ]
134
+
135
+ sys_txt = "\n".join([s for s in system_parts if s]).strip()
136
+ history = ""
137
+ if assistant_parts:
138
+ history = "\n\nPrevious assistant replies (for context):\n" + "\n---\n".join(
139
+ assistant_parts
140
+ )
141
+
142
+ last_user = user_parts[-1] if user_parts else ""
143
+ prefix = (
144
+ "You are a very small agent with only a Python REPL tool.\n"
145
+ "Prefer short, correct answers. If Python is unnecessary, just answer plainly.\n"
146
+ "If you do use Python, print only final results—no extra logs.\n"
147
+ )
148
+ if sys_txt:
149
+ prefix = f"{sys_txt}\n\n{prefix}"
150
+ return f"{prefix}\nTask:\n{last_user}\n{history}".strip()
151
+
152
+
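+ # Rough illustration of the flattening above (values are hypothetical, middle of
+ # the prompt elided with "..."):
+ #   _messages_to_task([
+ #       {"role": "system", "content": "Be terse."},
+ #       {"role": "user", "content": "What is 2**10?"},
+ #   ])
+ #   -> "Be terse.\n\nYou are a very small agent with only a Python REPL tool.\n...\nTask:\nWhat is 2**10?"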
153
+ def _openai_response(
154
+ message_text: str, model_name: str
155
+ ) -> typing.Dict[str, typing.Any]:
156
+ now = int(time.time())
157
+ return {
158
+ "id": f"chatcmpl-smol-{now}",
159
+ "object": "chat.completion",
160
+ "created": now,
161
+ "model": model_name,
162
+ "choices": [
163
+ {
164
+ "index": 0,
165
+ "message": {"role": "assistant", "content": message_text},
166
+ "finish_reason": "stop",
167
+ }
168
+ ],
169
+ "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
170
+ }
171
+
172
+
173
+ def _sse_headers() -> dict:
174
+ return {
175
+ "Cache-Control": "no-cache, no-transform",
176
+ "Connection": "keep-alive",
177
+ "X-Accel-Buffering": "no",
178
+ }
179
+
180
+
181
+ # ---------- Sanitizer: remove think/thank tags from LLM-originated text ----------
182
+ _THINK_TAG_RE = re.compile(r"</?\s*think\b[^>]*>", flags=re.IGNORECASE)
183
+ _THANK_TAG_RE = re.compile(r"</?\s*thank\b[^>]*>", flags=re.IGNORECASE) # typo safety
184
+ _ESC_THINK_TAG_RE = re.compile(r"&lt;/?\s*think\b[^&]*&gt;", flags=re.IGNORECASE)
185
+ _ESC_THANK_TAG_RE = re.compile(r"&lt;/?\s*thank\b[^&]*&gt;", flags=re.IGNORECASE)
186
+
187
+
188
+ def scrub_think_tags(text: typing.Any) -> str:
189
+ """
190
+ Remove literal and HTML-escaped <think> / </think> (and <thank> variants) tags.
191
+ Content inside the tags is preserved; only the tags are stripped.
192
+ """
193
+ if not isinstance(text, str):
194
+ try:
195
+ text = str(text)
196
+ except Exception:
197
+ return ""
198
+ t = _THINK_TAG_RE.sub("", text)
199
+ t = _THANK_TAG_RE.sub("", t)
200
+ t = _ESC_THINK_TAG_RE.sub("", t)
201
+ t = _ESC_THANK_TAG_RE.sub("", t)
202
+ return t
203
+
204
+
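+ # For instance (illustrative): only the tags are removed, whether literal or escaped:
+ #   scrub_think_tags("<think>plan</think> 4")               -> "plan 4"
+ #   scrub_think_tags("&lt;think&gt;plan&lt;/think&gt; 4")   -> "plan 4"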
205
+ # ---------- Reasoning formatting for Chat-UI ----------
206
+ def _format_reasoning_chunk(text: str, tag: str, idx: int) -> str:
207
+ """
208
+ Lightweight formatter for reasoning stream. Avoid huge code fences;
209
+ make it readable and incremental. Also filters out ASCII/box-drawing noise.
210
+ """
211
+ text = scrub_think_tags(text).rstrip("\n")
212
+ if not text:
213
+ return ""
214
+ noisy_prefixes = (
215
+ "OpenAIServerModel",
216
+ "Output message of the LLM",
217
+ "─ Executing parsed code",
218
+ "New run",
219
+ "╭",
220
+ "╰",
221
+ "│",
222
+ "━",
223
+ "─",
224
+ )
225
+ stripped = text.strip()
226
+ if not stripped:
227
+ return ""
228
+ # Lines made mostly of box drawing/separators
229
+ if all(ch in " ─━╭╮╰╯│═·—-_=+•" for ch in stripped):
230
+ return ""
231
+ if any(stripped.startswith(p) for p in noisy_prefixes):
232
+ return ""
233
+ # Excessively long lines with little signal (no alphanumerics)
234
+ if len(stripped) > 240 and not re.search(r"[A-Za-z0-9]{3,}", stripped):
235
+ return ""
236
+ # No tag/idx prefix; add a trailing blank line for readability in markdown
237
+ return f"{stripped}\n\n"
238
+
239
+
240
+ def _extract_final_text(item: typing.Any) -> typing.Optional[str]:
241
+ if isinstance(item, dict) and ("__stdout__" in item or "__step__" in item):
242
+ return None
243
+ if isinstance(item, (bytes, bytearray)):
244
+ try:
245
+ item = item.decode("utf-8", errors="ignore")
246
+ except Exception:
247
+ item = str(item)
248
+ if isinstance(item, str):
249
+ s = scrub_think_tags(item.strip())
250
+ return s or None
251
+ # If it's a step-like object with an 'output' attribute, use that
252
+ try:
253
+ if not isinstance(item, (dict, list, bytes, bytearray)):
254
+ out = getattr(item, "output", None)
255
+ if out is not None:
256
+ s = scrub_think_tags(str(out)).strip()
257
+ if s:
258
+ return s
259
+ except Exception:
260
+ pass
261
+ if isinstance(item, dict):
262
+ for key in ("content", "text", "message", "output", "final", "answer"):
263
+ if key in item:
264
+ val = item[key]
265
+ if isinstance(val, (dict, list)):
266
+ try:
267
+ return scrub_think_tags(json.dumps(val, ensure_ascii=False))
268
+ except Exception:
269
+ return scrub_think_tags(str(val))
270
+ if isinstance(val, (bytes, bytearray)):
271
+ try:
272
+ val = val.decode("utf-8", errors="ignore")
273
+ except Exception:
274
+ val = str(val)
275
+ s = scrub_think_tags(str(val).strip())
276
+ return s or None
277
+ try:
278
+ return scrub_think_tags(json.dumps(item, ensure_ascii=False))
279
+ except Exception:
280
+ return scrub_think_tags(str(item))
281
+ try:
282
+ return scrub_think_tags(str(item))
283
+ except Exception:
284
+ return None
285
+
286
+
287
+ # Helper to parse explicit "Final answer:" from stdout lines
288
+ _FINAL_RE = re.compile(r"(?:^|\b)Final\s+answer:\s*(.+)$", flags=re.IGNORECASE)
289
+
290
+
291
+ def _maybe_parse_final_from_stdout(line: str) -> typing.Optional[str]:
292
+ if not isinstance(line, str):
293
+ return None
294
+ m = _FINAL_RE.search(line.strip())
295
+ if not m:
296
+ return None
297
+ return scrub_think_tags(m.group(1)).strip() or None
298
+
299
+
300
+ # ---------- Live stdout/stderr tee ----------
301
+ class QueueWriter(io.TextIOBase):
302
+ """
303
+ File-like object that buffers writes and pushes them to an asyncio.Queue on each newline (and on flush).
304
+ """
305
+
306
+ def __init__(self, q: "asyncio.Queue"):
307
+ self.q = q
308
+ self._lock = threading.Lock()
309
+ self._buf = [] # accumulate until newline to reduce spam
310
+
311
+ def write(self, s: str):
312
+ if not s:
313
+ return 0
314
+ with self._lock:
315
+ self._buf.append(s)
316
+ # flush on newline to keep granularity reasonable
317
+ if "\n" in s:
318
+ chunk = "".join(self._buf)
319
+ self._buf.clear()
320
+ try:
321
+ self.q.put_nowait({"__stdout__": chunk})
322
+ except Exception:
323
+ pass
324
+ return len(s)
325
+
326
+ def flush(self):
327
+ with self._lock:
328
+ if self._buf:
329
+ chunk = "".join(self._buf)
330
+ self._buf.clear()
331
+ try:
332
+ self.q.put_nowait({"__stdout__": chunk})
333
+ except Exception:
334
+ pass
335
+
336
+
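+ # Usage sketch (assumes an existing asyncio.Queue named q):
+ #   w = QueueWriter(q)
+ #   with contextlib.redirect_stdout(w):
+ #       print("hello")      # the newline triggers a {"__stdout__": "hello\n"} item
+ #   w.flush()               # pushes any remainder that never saw a newline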
337
+ def _serialize_step(step) -> str:
338
+ """
339
+ Best-effort pretty string for a smolagents MemoryStep / ActionStep.
340
+ Works even if attributes are missing on some versions.
341
+ """
342
+ parts = []
343
+ sn = getattr(step, "step_number", None)
344
+ if sn is not None:
345
+ parts.append(f"Step {sn}")
346
+ thought_val = getattr(step, "thought", None)
347
+ if thought_val:
348
+ parts.append(f"Thought: {scrub_think_tags(str(thought_val))}")
349
+ tool_val = getattr(step, "tool", None)
350
+ if tool_val:
351
+ parts.append(f"Tool: {scrub_think_tags(str(tool_val))}")
352
+ code_val = getattr(step, "code", None)
353
+ if code_val:
354
+ code_str = scrub_think_tags(str(code_val)).strip()
355
+ parts.append("```python\n" + code_str + "\n```")
356
+ args = getattr(step, "args", None)
357
+ if args:
358
+ try:
359
+ parts.append(
360
+ "Args: " + scrub_think_tags(json.dumps(args, ensure_ascii=False))
361
+ )
362
+ except Exception:
363
+ parts.append("Args: " + scrub_think_tags(str(args)))
364
+ error = getattr(step, "error", None)
365
+ if error:
366
+ parts.append(f"Error: {scrub_think_tags(str(error))}")
367
+ obs = getattr(step, "observations", None)
368
+ if obs is not None:
369
+ if isinstance(obs, (list, tuple)):
370
+ obs_str = "\n".join(map(str, obs))
371
+ else:
372
+ obs_str = str(obs)
373
+ parts.append("Observation:\n" + scrub_think_tags(obs_str).strip())
374
+ # If this looks like a FinalAnswer step object, surface a clean final answer
375
+ try:
376
+ tname = type(step).__name__
377
+ except Exception:
378
+ tname = ""
379
+ if tname.lower().startswith("finalanswer"):
380
+ out = getattr(step, "output", None)
381
+ if out is not None:
382
+ return f"Final answer: {scrub_think_tags(str(out)).strip()}"
383
+ # Fallback: try to parse from string repr "FinalAnswerStep(output=...)"
384
+ s = scrub_think_tags(str(step))
385
+ m = re.search(r"FinalAnswer[^()]*\(\s*output\s*=\s*([^,)]+)", s)
386
+ if m:
387
+ return f"Final answer: {m.group(1).strip()}"
388
+ # If the only content would be an object repr like FinalAnswerStep(...), drop it;
389
+ # a cleaner "Final answer: ..." will come from the rule above or stdout.
390
+ joined = "\n".join(parts).strip()
391
+ if re.match(r"^FinalAnswer[^\n]+\)$", joined):
392
+ return ""
393
+ return joined or scrub_think_tags(str(step))
394
+
395
+
396
+ # ---------- Agent streaming bridge (truly live) ----------
397
+ async def run_agent_stream(task: str, agent_obj: typing.Optional[typing.Any] = None):
398
+ """
399
+ Start the agent in a worker thread.
400
+ Stream THREE sources of incremental data into the async generator:
401
+ (1) live stdout/stderr lines,
402
+ (2) newly appended memory steps (polled),
403
+ (3) any iterable the agent may yield (if supported).
404
+ Finally emit a __final__ item with the last answer.
405
+ """
406
+ loop = asyncio.get_running_loop()
407
+ q: asyncio.Queue = asyncio.Queue()
408
+ agent_to_use = agent_obj or agent
409
+
410
+ stop_evt = threading.Event()
411
+
412
+ # 1) stdout/stderr live tee
413
+ qwriter = QueueWriter(q)
414
+
415
+ # 2) memory poller
416
+ def poll_memory():
417
+ last_len = 0
418
+ while not stop_evt.is_set():
419
+ try:
420
+ steps = []
421
+ try:
422
+ # Common API: agent.memory.get_full_steps()
423
+ steps = agent_to_use.memory.get_full_steps() # type: ignore[attr-defined]
424
+ except Exception:
425
+ # Fallbacks: different names across versions
426
+ steps = (
427
+ getattr(agent_to_use, "steps", [])
428
+ or getattr(agent_to_use, "memory", [])
429
+ or []
430
+ )
431
+ if steps is None:
432
+ steps = []
433
+ curr_len = len(steps)
434
+ if curr_len > last_len:
435
+ new = steps[last_len:curr_len]
436
+ last_len = curr_len
437
+ for s in new:
438
+ s_text = _serialize_step(s)
439
+ if s_text:
440
+ try:
441
+ q.put_nowait({"__step__": s_text})
442
+ except Exception:
443
+ pass
444
+ except Exception:
445
+ pass
446
+ time.sleep(0.10) # 100 ms cadence
447
+
448
+ # 3) agent runner (may or may not yield)
449
+ def run_agent():
450
+ final_result = None
451
+ try:
452
+ with contextlib.redirect_stdout(qwriter), contextlib.redirect_stderr(
453
+ qwriter
454
+ ):
455
+ used_iterable = False
456
+ if hasattr(agent_to_use, "run") and callable(
457
+ getattr(agent_to_use, "run")
458
+ ):
459
+ try:
460
+ res = agent_to_use.run(task, stream=True)
461
+ if hasattr(res, "__iter__") and not isinstance(
462
+ res, (str, bytes)
463
+ ):
464
+ used_iterable = True
465
+ for it in res:
466
+ try:
467
+ q.put_nowait(it)
468
+ except Exception:
469
+ pass
470
+ final_result = (
471
+ None # iterable may already contain the answer
472
+ )
473
+ else:
474
+ final_result = res
475
+ except TypeError:
476
+ # run(stream=True) not supported -> fall back
477
+ pass
478
+
479
+ if final_result is None and not used_iterable:
480
+ # Try other common streaming signatures
481
+ for name in (
482
+ "run_stream",
483
+ "stream",
484
+ "stream_run",
485
+ "run_with_callback",
486
+ ):
487
+ fn = getattr(agent_to_use, name, None)
488
+ if callable(fn):
489
+ try:
490
+ res = fn(task)
491
+ if hasattr(res, "__iter__") and not isinstance(
492
+ res, (str, bytes)
493
+ ):
494
+ for it in res:
495
+ q.put_nowait(it)
496
+ final_result = None
497
+ else:
498
+ final_result = res
499
+ break
500
+ except TypeError:
501
+ # maybe callback signature
502
+ def cb(item):
503
+ try:
504
+ q.put_nowait(item)
505
+ except Exception:
506
+ pass
507
+
508
+ try:
509
+ fn(task, cb)
510
+ final_result = None
511
+ break
512
+ except Exception:
513
+ continue
514
+
515
+ if final_result is None and not used_iterable:
516
+ pass # (typo guard removed below)
517
+
518
+ if final_result is None and not used_iterable:
519
+ # Last resort: synchronous run()/generate()/callable
520
+ if hasattr(agent_to_use, "run") and callable(
521
+ getattr(agent_to_use, "run")
522
+ ):
523
+ final_result = agent_to_use.run(task)
524
+ elif hasattr(agent_to_use, "generate") and callable(
525
+ getattr(agent_to_use, "generate")
526
+ ):
527
+ final_result = agent_to_use.generate(task)
528
+ elif callable(agent_to_use):
529
+ final_result = agent_to_use(task)
530
+
531
+ except Exception as e:
532
+ try:
533
+ qwriter.flush()
534
+ except Exception:
535
+ pass
536
+ try:
537
+ q.put_nowait({"__error__": str(e)})
538
+ except Exception:
539
+ pass
540
+ finally:
541
+ try:
542
+ qwriter.flush()
543
+ except Exception:
544
+ pass
545
+ try:
546
+ q.put_nowait({"__final__": final_result})
547
+ except Exception:
548
+ pass
549
+ stop_evt.set()
550
+
551
+ # Kick off threads
552
+ mem_thread = threading.Thread(target=poll_memory, daemon=True)
553
+ run_thread = threading.Thread(target=run_agent, daemon=True)
554
+ mem_thread.start()
555
+ run_thread.start()
556
+
557
+ # Async consumer
558
+ while True:
559
+ item = await q.get()
560
+ yield item
561
+ if isinstance(item, dict) and "__final__" in item:
562
+ break
563
+
564
+
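+ # Note: asyncio.Queue is not thread-safe, so the put_nowait() calls issued from the
+ # worker threads above are best-effort. A stricter variant (sketch, reusing the
+ # `loop` captured in run_agent_stream) would hand items to the event loop instead:
+ #
+ #   def safe_put(loop: asyncio.AbstractEventLoop, q: asyncio.Queue, item) -> None:
+ #       # call_soon_threadsafe is the documented way to reach a running loop from a thread
+ #       loop.call_soon_threadsafe(q.put_nowait, item)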
565
+ def _recursively_scrub(obj):
566
+ if isinstance(obj, str):
567
+ return scrub_think_tags(obj)
568
+ if isinstance(obj, dict):
569
+ return {k: _recursively_scrub(v) for k, v in obj.items()}
570
+ if isinstance(obj, list):
571
+ return [_recursively_scrub(v) for v in obj]
572
+ return obj
573
+
574
+
575
+ async def _proxy_upstream_chat_completions(
576
+ body: dict, stream: bool, scrub_think: bool = False
577
+ ):
578
+ if not UPSTREAM_BASE:
579
+ return fastapi.responses.JSONResponse(
580
+ {"error": {"message": "UPSTREAM_OPENAI_BASE not configured"}},
581
+ status_code=500,
582
+ )
583
+ headers = {
584
+ "Authorization": f"Bearer {HF_TOKEN}" if HF_TOKEN else "",
585
+ "Content-Type": "application/json",
586
+ }
587
+ url = f"{UPSTREAM_BASE}/chat/completions"
588
+
589
+ if stream:
590
+
591
+ async def proxy_stream():
592
+ async with httpx.AsyncClient(timeout=None) as client:
593
+ async with client.stream(
594
+ "POST", url, headers=headers, json=body
595
+ ) as resp:
596
+ resp.raise_for_status()
597
+ if scrub_think:
598
+ # Pull text segments, scrub tags, and yield bytes
599
+ async for txt in resp.aiter_text():
600
+ try:
601
+ cleaned = scrub_think_tags(txt)
602
+ yield cleaned.encode("utf-8")
603
+ except Exception:
604
+ yield txt.encode("utf-8")
605
+ else:
606
+ async for chunk in resp.aiter_bytes():
607
+ yield chunk
608
+
609
+ return fastapi.responses.StreamingResponse(
610
+ proxy_stream(), media_type="text/event-stream", headers=_sse_headers()
611
+ )
612
+ else:
613
+ async with httpx.AsyncClient(timeout=None) as client:
614
+ r = await client.post(url, headers=headers, json=body)
615
+ try:
616
+ payload = r.json()
617
+ except Exception:
618
+ payload = {"status_code": r.status_code, "text": r.text}
619
+
620
+ if scrub_think:
621
+ try:
622
+ payload = _recursively_scrub(payload)
623
+ except Exception:
624
+ pass
625
+
626
+ return fastapi.responses.JSONResponse(
627
+ status_code=r.status_code, content=payload
628
+ )
629
+
630
+
631
+ # ---------- Endpoints ----------
632
+ @app.get("/v1/models")
633
+ async def list_models():
634
+ now = int(time.time())
635
+ return {
636
+ "object": "list",
637
+ "data": [
638
+ {
639
+ "id": "code-writing-agent",
640
+ "object": "model",
641
+ "created": now,
642
+ "owned_by": "you",
643
+ },
644
+ {
645
+ "id": AGENT_MODEL,
646
+ "object": "model",
647
+ "created": now,
648
+ "owned_by": "upstream",
649
+ },
650
+ {
651
+ "id": AGENT_MODEL + "-nothink",
652
+ "object": "model",
653
+ "created": now,
654
+ "owned_by": "upstream",
655
+ },
656
+ ],
657
+ }
658
+
659
+
660
+ @app.post("/v1/chat/completions")
661
+ async def chat_completions(req: fastapi.Request):
662
+ try:
663
+ body: ChatCompletionRequest = typing.cast(
664
+ ChatCompletionRequest, await req.json()
665
+ )
666
+ except Exception as e:
667
+ return fastapi.responses.JSONResponse(
668
+ {"error": {"message": f"Invalid JSON: {e}"}}, status_code=400
669
+ )
670
+
671
+ messages = body.get("messages") or []
672
+ stream = bool(body.get("stream", False))
673
+ raw_model = body.get("model")
674
+ model_name = (
675
+ raw_model.get("id")
676
+ if isinstance(raw_model, dict)
677
+ else (raw_model or "code-writing-agent")
678
+ )
679
+ # Pure pass-through if the user selects the upstream model id
680
+ if model_name == AGENT_MODEL:
681
+ return await _proxy_upstream_chat_completions(dict(body), stream)
682
+ if model_name == AGENT_MODEL + "-nothink":
683
+ # Remove "-nothink" from the model name in body
684
+ body["model"] = AGENT_MODEL
685
+
686
+ # Add /nothink to the end of the message contents to disable think tags
687
+ new_messages = []
688
+ for msg in messages:
689
+ if msg.get("role") == "user":
690
+ content = normalize_content_to_text(msg.get("content", ""))
691
+ content += "\n/nothink"
692
+ new_msg: ChatMessage = {
693
+ "role": "user",
694
+ "content": content,
695
+ }
696
+ new_messages.append(new_msg)
697
+ else:
698
+ new_messages.append(msg)
699
+ body["messages"] = new_messages
700
+ return await _proxy_upstream_chat_completions(
701
+ dict(body), stream, scrub_think=True
702
+ )
703
+
704
+ # Otherwise, reasoning-aware wrapper
705
+ task = _messages_to_task(messages)
706
+
707
+ # Per-request agent override if a custom model id was provided (different from defaults)
708
+ agent_for_request = None
709
+ if model_name not in (
710
+ "code-writing-agent",
711
+ AGENT_MODEL,
712
+ AGENT_MODEL + "-nothink",
713
+ ) and isinstance(model_name, str):
714
+ try:
715
+ req_llm = smolagents.models.OpenAIServerModel(
716
+ model_id=model_name, api_base=UPSTREAM_BASE, api_key=HF_TOKEN
717
+ )
718
+ agent_for_request = smolagents.CodeAgent(
719
+ model=req_llm,
720
+ tools=[],
721
+ add_base_tools=False,
722
+ max_steps=4,
723
+ verbosity_level=int(os.getenv("AGENT_VERBOSITY", "1")),
724
+ )
725
+ except Exception:
726
+ log.exception(
727
+ "Failed to construct agent for model '%s'; using default", model_name
728
+ )
729
+ agent_for_request = None
730
+
731
+ try:
732
+ if stream:
733
+
734
+ async def sse_streamer():
735
+ base = {
736
+ "id": f"chatcmpl-smol-{int(time.time())}",
737
+ "object": "chat.completion.chunk",
738
+ "created": int(time.time()),
739
+ "model": model_name,
740
+ "choices": [
741
+ {
742
+ "index": 0,
743
+ "delta": {"role": "assistant"},
744
+ "finish_reason": None,
745
+ }
746
+ ],
747
+ }
748
+ yield f"data: {json.dumps(base)}\n\n"
749
+
750
+ reasoning_idx = 0
751
+ final_candidate: typing.Optional[str] = None
752
+
753
+ async for item in run_agent_stream(task, agent_for_request):
754
+ # Error short-circuit
755
+ if isinstance(item, dict) and "__error__" in item:
756
+ error_chunk = {
757
+ **base,
758
+ "choices": [
759
+ {"index": 0, "delta": {}, "finish_reason": "error"}
760
+ ],
761
+ }
762
+ yield f"data: {json.dumps(error_chunk)}\n\n"
763
+ yield f"data: {json.dumps({'error': item['__error__']})}\n\n"
764
+ break
765
+
766
+ # Explicit final result from the agent
767
+ if isinstance(item, dict) and "__final__" in item:
768
+ val = item["__final__"]
769
+ cand = _extract_final_text(val)
770
+ # Only update if the agent actually provided a non-empty answer
771
+ if cand and cand.strip().lower() != "none":
772
+ final_candidate = cand
773
+ # do not emit anything yet; we'll send a single final chunk below
774
+ continue
775
+
776
+ # Live stdout -> reasoning_content
777
+ if (
778
+ isinstance(item, dict)
779
+ and "__stdout__" in item
780
+ and isinstance(item["__stdout__"], str)
781
+ ):
782
+ for line in item["__stdout__"].splitlines():
783
+ parsed = _maybe_parse_final_from_stdout(line)
784
+ if parsed:
785
+ final_candidate = parsed
786
+ rt = _format_reasoning_chunk(
787
+ line, "stdout", reasoning_idx := reasoning_idx + 1
788
+ )
789
+ if rt:
790
+ r_chunk = {
791
+ **base,
792
+ "choices": [
793
+ {"index": 0, "delta": {"reasoning_content": rt}}
794
+ ],
795
+ }
796
+ yield f"data: {json.dumps(r_chunk, ensure_ascii=False)}\n\n"
797
+ continue
798
+
799
+ # Newly observed step -> reasoning_content
800
+ if (
801
+ isinstance(item, dict)
802
+ and "__step__" in item
803
+ and isinstance(item["__step__"], str)
804
+ ):
805
+ for line in item["__step__"].splitlines():
806
+ parsed = _maybe_parse_final_from_stdout(line)
807
+ if parsed:
808
+ final_candidate = parsed
809
+ rt = _format_reasoning_chunk(
810
+ line, "step", reasoning_idx := reasoning_idx + 1
811
+ )
812
+ if rt:
813
+ r_chunk = {
814
+ **base,
815
+ "choices": [
816
+ {"index": 0, "delta": {"reasoning_content": rt}}
817
+ ],
818
+ }
819
+ yield f"data: {json.dumps(r_chunk, ensure_ascii=False)}\n\n"
820
+ continue
821
+
822
+ # Any iterable output from the agent (rare) — treat as candidate answer
823
+ cand = _extract_final_text(item)
824
+ if cand:
825
+ final_candidate = cand
826
+
827
+ await asyncio.sleep(0) # keep the loop fair
828
+
829
+ # Emit the visible answer once at the end (scrub any stray tags)
830
+ visible = scrub_think_tags(final_candidate or "")
831
+ if not visible or visible.strip().lower() == "none":
832
+ visible = "Done."
833
+ final_chunk = {
834
+ **base,
835
+ "choices": [{"index": 0, "delta": {"content": visible}}],
836
+ }
837
+ yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
838
+
839
+ stop_chunk = {
840
+ **base,
841
+ "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
842
+ }
843
+ yield f"data: {json.dumps(stop_chunk)}\n\n"
844
+ yield "data: [DONE]\n\n"
845
+
846
+ return fastapi.responses.StreamingResponse(
847
+ sse_streamer(), media_type="text/event-stream", headers=_sse_headers()
848
+ )
849
+
850
+ else:
851
+ # Non-streaming: collect into <think>…</think> + final
852
+ reasoning_lines: typing.List[str] = []
853
+ final_candidate: typing.Optional[str] = None
854
+
855
+ async for item in run_agent_stream(task, agent_for_request):
856
+ if isinstance(item, dict) and "__error__" in item:
857
+ raise Exception(item["__error__"])
858
+
859
+ if isinstance(item, dict) and "__final__" in item:
860
+ val = item["__final__"]
861
+ cand = _extract_final_text(val)
862
+ if cand and cand.strip().lower() != "none":
863
+ final_candidate = cand
864
+ continue
865
+
866
+ if isinstance(item, dict) and "__stdout__" in item:
867
+ lines = (
868
+ scrub_think_tags(item["__stdout__"]).rstrip("\n").splitlines()
869
+ )
870
+ for line in lines:
871
+ parsed = _maybe_parse_final_from_stdout(line)
872
+ if parsed:
873
+ final_candidate = parsed
874
+ rt = _format_reasoning_chunk(
875
+ line, "stdout", len(reasoning_lines) + 1
876
+ )
877
+ if rt:
878
+ reasoning_lines.append(rt)
879
+ continue
880
+
881
+ if isinstance(item, dict) and "__step__" in item:
882
+ lines = scrub_think_tags(item["__step__"]).rstrip("\n").splitlines()
883
+ for line in lines:
884
+ parsed = _maybe_parse_final_from_stdout(line)
885
+ if parsed:
886
+ final_candidate = parsed
887
+ rt = _format_reasoning_chunk(
888
+ line, "step", len(reasoning_lines) + 1
889
+ )
890
+ if rt:
891
+ reasoning_lines.append(rt)
892
+ continue
893
+
894
+ cand = _extract_final_text(item)
895
+ if cand:
896
+ final_candidate = cand
897
+
898
+ reasoning_blob = "\n".join(reasoning_lines).strip()
899
+ if len(reasoning_blob) > 24000:
900
+ reasoning_blob = reasoning_blob[:24000] + "\n… [truncated]"
901
+ think_block = (
902
+ f"<think>\n{reasoning_blob}\n</think>\n" if reasoning_blob else ""
903
+ )
904
+ final_text = scrub_think_tags(final_candidate or "")
905
+ if not final_text or final_text.strip().lower() == "none":
906
+ final_text = "Done."
907
+ result_text = f"{think_block}{final_text}"
908
+
909
+ except Exception as e:
910
+ msg = str(e)
911
+ status = 503 if "503" in msg or "Service Unavailable" in msg else 500
912
+ log.error("Agent error (%s): %s", status, msg)
913
+ return fastapi.responses.JSONResponse(
914
+ status_code=status,
915
+ content={
916
+ "error": {"message": f"Agent error: {msg}", "type": "agent_error"}
917
+ },
918
+ )
919
+
920
+ # Non-streaming response
921
+ if result_text is None:
922
+ result_text = ""
923
+ if not isinstance(result_text, str):
924
+ try:
925
+ result_text = json.dumps(result_text, ensure_ascii=False)
926
+ except Exception:
927
+ result_text = str(result_text)
928
+ return fastapi.responses.JSONResponse(_openai_response(result_text, model_name))
929
+
930
+
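+ # Hedged client-side sketch: any OpenAI-compatible client can target this proxy.
+ # For example, with httpx (host/port assumed from the uvicorn defaults below):
+ #
+ #   resp = httpx.post(
+ #       "http://localhost:8000/v1/chat/completions",
+ #       json={
+ #           "model": "code-writing-agent",
+ #           "stream": False,
+ #           "messages": [{"role": "user", "content": "What is 2**10?"}],
+ #       },
+ #       timeout=None,
+ #   )
+ #   print(resp.json()["choices"][0]["message"]["content"])
+ #
+ # Setting "model" to AGENT_MODEL selects the pure pass-through; AGENT_MODEL + "-nothink"
+ # selects the pass-through with <think> tags scrubbed (see chat_completions above).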
931
+ # Optional: local run
932
+ if __name__ == "__main__":
933
+ import uvicorn
934
+
935
+ uvicorn.run(
936
+ "proxy:app", host="0.0.0.0", port=int(os.getenv("PORT", "8000")), reload=False
937
+ )