Files changed (1) hide show
  1. app.py +20 -16
app.py CHANGED
@@ -74,26 +74,30 @@ class SpyAgent(BasicAgent):
74
  raise NotImplementedError
75
 
76
  def llm_caller(self, prompt):
77
- client = OpenAI(
78
- api_key=os.getenv('API_KEY'),
79
- base_url=os.getenv('BASE_URL')
 
 
 
 
 
80
  )
81
- completion = client.chat.completions.create(
82
- model=self.model_name,
83
- messages=[
84
- {'role': 'system', 'content': 'You are a helpful assistant.'},
85
- {'role': 'user', 'content': prompt}
86
- ],
87
- temperature=0
88
  )
89
- try:
90
- return completion.choices[0].message.content
91
- except Exception as e:
92
- print(e)
93
- return None
 
94
 
95
 
96
if __name__ == '__main__':
    # Build the spy agent (model chosen via the MODEL_NAME env var) and
    # hand it to the builder, which runs the agent service.
    agent_name = 'spy'
    spy = SpyAgent(agent_name, model_name=os.getenv('MODEL_NAME'))
    agent_builder = AgentBuilder(agent_name, agent=spy)
    agent_builder.start()
 
74
  raise NotImplementedError
75
 
76
  def llm_caller(self, prompt):
77
+ messages = [
78
+ {"role": "system", "content": "You are a helpful assistant."},
79
+ {"role": "user", "content": prompt}
80
+ ]
81
+ text = self.tokenizer.apply_chat_template(
82
+ messages,
83
+ tokenize=False,
84
+ add_generation_prompt=True
85
  )
86
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
87
+
88
+ generated_ids = self.model.generate(
89
+ model_inputs.input_ids,
90
+ max_new_tokens=512
 
 
91
  )
92
+ generated_ids = [
93
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
94
+ ]
95
+
96
+ response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
97
+ return response
98
 
99
 
100
if __name__ == '__main__':
    name = 'spy'
    # Keep the model configurable via MODEL_NAME (as the previous version of
    # this script was) while defaulting to 'qwen3-coder-plus' when unset —
    # behavior is unchanged for environments without MODEL_NAME.
    agent_builder = AgentBuilder(
        name,
        agent=SpyAgent(name, model_name=os.getenv('MODEL_NAME', 'qwen3-coder-plus'))
    )
    agent_builder.start()