ccm commited on
Commit
99ca1a9
·
1 Parent(s): 9573471

Separating agent

Browse files
Files changed (3) hide show
  1. agents/__init__.py +0 -0
  2. agents/code_writing_agent.py +19 -0
  3. proxy.py +5 -32
agents/__init__.py ADDED
File without changes
agents/code_writing_agent.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import smolagents
3
+ import smolagents.models
4
+
5
+
6
+ def create_code_writing_agent():
7
+ return smolagents.CodeAgent(
8
+ model=smolagents.models.OpenAIServerModel(
9
+ model_id=os.getenv("AGENT_MODEL", ""),
10
+ api_base=os.getenv("UPSTREAM_OPENAI_BASE", "").rstrip("/"),
11
+ api_key=os.getenv("OPENAI_API_KEY"),
12
+ ),
13
+ tools=[], # no extra tools
14
+ add_base_tools=False,
15
+ max_steps=4,
16
+ verbosity_level=int(
17
+ os.getenv("AGENT_VERBOSITY", "1")
18
+ ), # quieter by default; override via env
19
+ )
proxy.py CHANGED
@@ -16,13 +16,11 @@ import fastapi.responses
16
  import io
17
  import contextlib
18
 
19
- # smolagents + OpenAI-compatible model wrapper
20
- import smolagents
21
- import smolagents.models
22
-
23
  # Upstream pass-through
24
  import httpx
25
 
 
 
26
  # Logging setup
27
  logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
28
  log = logging.getLogger(__name__)
@@ -44,22 +42,6 @@ if not UPSTREAM_BASE:
44
  if not HF_TOKEN:
45
  log.warning("HF_TOKEN is empty; upstream may 401/403 if it requires auth.")
46
 
47
- # ================== Agent ====================
48
- llm = smolagents.models.OpenAIServerModel(
49
- model_id=AGENT_MODEL,
50
- api_base=UPSTREAM_BASE,
51
- api_key=HF_TOKEN,
52
- )
53
- agent = smolagents.CodeAgent(
54
- model=llm,
55
- tools=[], # no extra tools
56
- add_base_tools=False,
57
- max_steps=4,
58
- verbosity_level=int(
59
- os.getenv("AGENT_VERBOSITY", "1")
60
- ), # quieter by default; override via env
61
- )
62
-
63
  # ================== FastAPI ==================
64
  app = fastapi.FastAPI()
65
 
@@ -405,7 +387,7 @@ async def run_agent_stream(task: str, agent_obj: typing.Optional[typing.Any] = N
405
  """
406
  loop = asyncio.get_running_loop()
407
  q: asyncio.Queue = asyncio.Queue()
408
- agent_to_use = agent_obj or agent
409
 
410
  stop_evt = threading.Event()
411
 
@@ -712,16 +694,7 @@ async def chat_completions(req: fastapi.Request):
712
  AGENT_MODEL + "-nothink",
713
  ) and isinstance(model_name, str):
714
  try:
715
- req_llm = smolagents.models.OpenAIServerModel(
716
- model_id=model_name, api_base=UPSTREAM_BASE, api_key=HF_TOKEN
717
- )
718
- agent_for_request = smolagents.CodeAgent(
719
- model=req_llm,
720
- tools=[],
721
- add_base_tools=False,
722
- max_steps=4,
723
- verbosity_level=int(os.getenv("AGENT_VERBOSITY", "1")),
724
- )
725
  except Exception:
726
  log.exception(
727
  "Failed to construct agent for model '%s'; using default", model_name
@@ -934,4 +907,4 @@ if __name__ == "__main__":
934
 
935
  uvicorn.run(
936
  "app:app", host="0.0.0.0", port=int(os.getenv("PORT", "8000")), reload=False
937
- )
 
16
  import io
17
  import contextlib
18
 
 
 
 
 
19
  # Upstream pass-through
20
  import httpx
21
 
22
+ from agents.code_writing_agent import create_code_writing_agent
23
+
24
  # Logging setup
25
  logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
26
  log = logging.getLogger(__name__)
 
42
  if not HF_TOKEN:
43
  log.warning("HF_TOKEN is empty; upstream may 401/403 if it requires auth.")
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  # ================== FastAPI ==================
46
  app = fastapi.FastAPI()
47
 
 
387
  """
388
  loop = asyncio.get_running_loop()
389
  q: asyncio.Queue = asyncio.Queue()
390
+ agent_to_use = agent_obj or create_code_writing_agent
391
 
392
  stop_evt = threading.Event()
393
 
 
694
  AGENT_MODEL + "-nothink",
695
  ) and isinstance(model_name, str):
696
  try:
697
+ agent_for_request = create_code_writing_agent()
 
 
 
 
 
 
 
 
 
698
  except Exception:
699
  log.exception(
700
  "Failed to construct agent for model '%s'; using default", model_name
 
907
 
908
  uvicorn.run(
909
  "app:app", host="0.0.0.0", port=int(os.getenv("PORT", "8000")), reload=False
910
+ )