ProximileAdmin committed on
Commit
82773d5
·
verified ·
1 Parent(s): 55e3552

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -27
app.py CHANGED
@@ -9,12 +9,18 @@ import time
9
  from typing import Dict, List, Optional
10
 
11
  ENDPOINT_URL = "https://api.hyperbolic.xyz/v1"
 
12
  OAI_API_KEY = os.getenv('HYPERBOLIC_XYZ_KEY')
 
13
  VERBOSE_SHELL = True
 
14
  todays_date_string = datetime.date.today().strftime("%d %B %Y")
15
 
 
16
  NAME_OF_SERVICE = "arXiv Paper Search"
17
- DESCRIPTION_OF_SERVICE = "a service that searches and retrieves academic papers from arXiv based on various criteria"
 
 
18
  PAPER_SEARCH_FUNCTION_NAME = "search_arxiv_papers"
19
 
20
  functions_list = [
@@ -27,8 +33,8 @@ functions_list = [
27
  "type": "object",
28
  "properties": {
29
  "query": {
30
- "type": "string",
31
- "description": "Search query (e.g., 'deep learning', 'quantum computing')"
32
  },
33
  "max_results": {
34
  "type": "integer",
@@ -63,9 +69,27 @@ After receiving the results back from a function (formatted as {{"name": functio
63
 
64
  If the user request does not necessitate a function call, simply respond to the user's query directly."""
65
 
66
- def search_arxiv_papers(query: str, max_results: int = 5, sort_by: str = 'relevance') -> Dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  try:
 
68
  search_query = f'all:{query}'
 
 
69
  base_url = 'http://export.arxiv.org/api/query?'
70
  params = {
71
  'search_query': search_query,
@@ -76,8 +100,12 @@ def search_arxiv_papers(query: str, max_results: int = 5, sort_by: str = 'releva
76
  }
77
  query_string = '&'.join([f'{k}={urllib.parse.quote(str(v))}' for k, v in params.items()])
78
  url = base_url + query_string
 
 
79
  response = urllib.request.urlopen(url)
80
  feed = feedparser.parse(response.read().decode('utf-8'))
 
 
81
  papers = []
82
  for entry in feed.entries:
83
  paper = {
@@ -90,12 +118,16 @@ def search_arxiv_papers(query: str, max_results: int = 5, sort_by: str = 'releva
90
  'primary_category': entry.tags[0]['term']
91
  }
92
  papers.append(paper)
 
 
93
  time.sleep(3)
 
94
  return {
95
  'status': 'success',
96
  'total_results': len(papers),
97
  'papers': papers
98
  }
 
99
  except Exception as e:
100
  return {
101
  'status': 'error',
@@ -104,6 +136,7 @@ def search_arxiv_papers(query: str, max_results: int = 5, sort_by: str = 'releva
104
 
105
  functions_dict = {f["function"]["name"]: f for f in functions_list}
106
  FUNCTION_BACKENDS = {
 
107
  PAPER_SEARCH_FUNCTION_NAME: search_arxiv_papers,
108
  }
109
 
@@ -116,6 +149,8 @@ class LLM:
116
  self.api_key = OAI_API_KEY
117
  self.max_model_len = max_model_len
118
  self.client = OpenAI(base_url=ENDPOINT_URL, api_key=self.api_key)
 
 
119
  self.model_name = "meta-llama/Llama-3.3-70B-Instruct"
120
 
121
  def generate(self, prompt: str, sampling_params: dict) -> dict:
@@ -128,15 +163,18 @@ class LLM:
128
  "n": sampling_params.get("n", 1),
129
  "stream": False,
130
  }
 
131
  if "stop" in sampling_params:
132
  completion_params["stop"] = sampling_params["stop"]
133
  if "presence_penalty" in sampling_params:
134
  completion_params["presence_penalty"] = sampling_params["presence_penalty"]
135
  if "frequency_penalty" in sampling_params:
136
  completion_params["frequency_penalty"] = sampling_params["frequency_penalty"]
 
137
  return self.client.completions.create(**completion_params)
138
 
139
  def form_chat_prompt(message_history, functions=functions_dict.keys()):
 
140
  functions_string = "\n\n".join([json.dumps(functions_dict[f], indent=4) for f in functions])
141
  full_prompt = (
142
  ROLE_HEADER.format(role="system")
@@ -155,6 +193,7 @@ def form_chat_prompt(message_history, functions=functions_dict.keys()):
155
  return full_prompt
156
 
157
  def check_assistant_response_for_tool_calls(response):
 
158
  response = response.split(FUNCTION_EOT_STRING)[0].split(EOT_STRING)[0]
159
  for tool_name in functions_dict.keys():
160
  if f"\"{tool_name}\"" in response and "{" in response:
@@ -168,17 +207,21 @@ def check_assistant_response_for_tool_calls(response):
168
  return None
169
 
170
  def process_tool_request(tool_request_data):
 
171
  tool_name = tool_request_data["name"]
172
  tool_parameters = tool_request_data["parameters"]
 
173
  if tool_name == PAPER_SEARCH_FUNCTION_NAME:
174
  query = tool_parameters["query"]
175
  max_results = tool_parameters.get("max_results", 5)
176
  sort_by = tool_parameters.get("sort_by", "relevance")
177
  search_results = FUNCTION_BACKENDS[tool_name](query, max_results, sort_by)
178
  return {"name": PAPER_SEARCH_FUNCTION_NAME, "results": search_results}
 
179
  return None
180
 
181
  def restore_message_history(full_history):
 
182
  restored = []
183
  for message in full_history:
184
  if message["role"] == "assistant" and "metadata" in message:
@@ -196,10 +239,13 @@ def restore_message_history(full_history):
196
  return restored
197
 
198
  def iterate_chat(llm, sampling_params, full_history):
 
199
  tool_interactions = []
 
200
  for _ in range(10):
201
  prompt = form_chat_prompt(restore_message_history(full_history) + tool_interactions)
202
  output = llm.generate(prompt, sampling_params)
 
203
  if VERBOSE_SHELL:
204
  print(f"Input prompt: {prompt}")
205
  print("-" * 50)
@@ -207,8 +253,10 @@ def iterate_chat(llm, sampling_params, full_history):
207
  print("=" * 50)
208
  if not output or not output.choices:
209
  raise ValueError("Invalid completion response")
 
210
  assistant_response = output.choices[0].text.strip()
211
  assistant_response = assistant_response.split(FUNCTION_EOT_STRING)[0].split(EOT_STRING)[0]
 
212
  tool_request_data = check_assistant_response_for_tool_calls(assistant_response)
213
  if not tool_request_data:
214
  final_message = {
@@ -227,41 +275,58 @@ def iterate_chat(llm, sampling_params, full_history):
227
  }
228
  tool_interactions.append(assistant_message)
229
  tool_return_data = process_tool_request(tool_request_data)
 
230
  tool_message = {
231
  "role": "function",
232
  "content": json.dumps(tool_return_data)
233
  }
234
  tool_interactions.append(tool_message)
 
235
  return full_history
236
 
237
- def respond(message, chat_history, system_message, max_tokens, temperature, top_p):
238
- if chat_history is None:
239
- chat_history = []
240
- full_history = chat_history.copy()
241
- full_history.append({"role": "user", "content": message})
242
- sampling_params = {
243
- "temperature": temperature,
244
- "top_p": top_p,
245
- "max_tokens": max_tokens,
246
- "stop_token_ids": [128001, 128008, 128009, 128006],
247
- }
248
  updated_history = iterate_chat(llm, sampling_params, full_history)
249
  assistant_answer = updated_history[-1]["content"]
250
- chat_history.append((message, assistant_answer))
251
- return chat_history
 
 
 
 
 
 
 
 
252
 
253
  # Initialize LLM
254
  llm = LLM(max_model_len=8096)
255
 
256
- demo = gr.ChatInterface(
257
- respond,
258
- additional_inputs=[
259
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
260
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
261
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
262
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
263
- ],
264
- )
265
 
266
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  demo.launch()
 
9
  from typing import Dict, List, Optional
10
 
11
  ENDPOINT_URL = "https://api.hyperbolic.xyz/v1"
12
+
13
  OAI_API_KEY = os.getenv('HYPERBOLIC_XYZ_KEY')
14
+
15
  VERBOSE_SHELL = True
16
+
17
  todays_date_string = datetime.date.today().strftime("%d %B %Y")
18
 
19
+
20
  NAME_OF_SERVICE = "arXiv Paper Search"
21
+ DESCRIPTION_OF_SERVICE = (
22
+ "a service that searches and retrieves academic papers from arXiv based on various criteria"
23
+ )
24
  PAPER_SEARCH_FUNCTION_NAME = "search_arxiv_papers"
25
 
26
  functions_list = [
 
33
  "type": "object",
34
  "properties": {
35
  "query": {
36
+ "type": "string", # function names for AI agents should be chosen carefully to avoid confusion
37
+ "description": "Search query (e.g., 'deep learning', 'quantum computing')" # descriptions help the AI agent's LLM backend understand the function
38
  },
39
  "max_results": {
40
  "type": "integer",
 
69
 
70
  If the user request does not necessitate a function call, simply respond to the user's query directly."""
71
 
72
+ def search_arxiv_papers(
73
+ query: str,
74
+ max_results: int = 5,
75
+ sort_by: str = 'relevance'
76
+ ) -> Dict:
77
+ """
78
+ Search for papers on arXiv using their API.
79
+
80
+ Args:
81
+ query: Search query string
82
+ max_results: Maximum number of results to return (default: 5)
83
+ sort_by: Sorting criteria (default: 'relevance')
84
+
85
+ Returns:
86
+ Dictionary containing search results and metadata
87
+ """
88
  try:
89
+ # Construct the search query
90
  search_query = f'all:{query}'
91
+
92
+ # Construct the API URL
93
  base_url = 'http://export.arxiv.org/api/query?'
94
  params = {
95
  'search_query': search_query,
 
100
  }
101
  query_string = '&'.join([f'{k}={urllib.parse.quote(str(v))}' for k, v in params.items()])
102
  url = base_url + query_string
103
+
104
+ # Make the API request
105
  response = urllib.request.urlopen(url)
106
  feed = feedparser.parse(response.read().decode('utf-8'))
107
+
108
+ # Process the results
109
  papers = []
110
  for entry in feed.entries:
111
  paper = {
 
118
  'primary_category': entry.tags[0]['term']
119
  }
120
  papers.append(paper)
121
+
122
+ # Add a delay to respect API rate limits
123
  time.sleep(3)
124
+
125
  return {
126
  'status': 'success',
127
  'total_results': len(papers),
128
  'papers': papers
129
  }
130
+
131
  except Exception as e:
132
  return {
133
  'status': 'error',
 
136
 
137
  functions_dict = {f["function"]["name"]: f for f in functions_list}
138
  FUNCTION_BACKENDS = {
139
+ #WALLET_CHECK_FUNCTION_NAME: check_wallet_balance,
140
  PAPER_SEARCH_FUNCTION_NAME: search_arxiv_papers,
141
  }
142
 
 
149
  self.api_key = OAI_API_KEY
150
  self.max_model_len = max_model_len
151
  self.client = OpenAI(base_url=ENDPOINT_URL, api_key=self.api_key)
152
+ #models_list = self.client.models.list()
153
+ #self.model_name = models_list.data[0].id
154
  self.model_name = "meta-llama/Llama-3.3-70B-Instruct"
155
 
156
  def generate(self, prompt: str, sampling_params: dict) -> dict:
 
163
  "n": sampling_params.get("n", 1),
164
  "stream": False,
165
  }
166
+
167
  if "stop" in sampling_params:
168
  completion_params["stop"] = sampling_params["stop"]
169
  if "presence_penalty" in sampling_params:
170
  completion_params["presence_penalty"] = sampling_params["presence_penalty"]
171
  if "frequency_penalty" in sampling_params:
172
  completion_params["frequency_penalty"] = sampling_params["frequency_penalty"]
173
+
174
  return self.client.completions.create(**completion_params)
175
 
176
  def form_chat_prompt(message_history, functions=functions_dict.keys()):
177
+ """Builds the chat prompt for the LLM."""
178
  functions_string = "\n\n".join([json.dumps(functions_dict[f], indent=4) for f in functions])
179
  full_prompt = (
180
  ROLE_HEADER.format(role="system")
 
193
  return full_prompt
194
 
195
  def check_assistant_response_for_tool_calls(response):
196
+ """Check if the LLM response contains a function call."""
197
  response = response.split(FUNCTION_EOT_STRING)[0].split(EOT_STRING)[0]
198
  for tool_name in functions_dict.keys():
199
  if f"\"{tool_name}\"" in response and "{" in response:
 
207
  return None
208
 
209
  def process_tool_request(tool_request_data):
210
+ """Process tool requests from the LLM."""
211
  tool_name = tool_request_data["name"]
212
  tool_parameters = tool_request_data["parameters"]
213
+
214
  if tool_name == PAPER_SEARCH_FUNCTION_NAME:
215
  query = tool_parameters["query"]
216
  max_results = tool_parameters.get("max_results", 5)
217
  sort_by = tool_parameters.get("sort_by", "relevance")
218
  search_results = FUNCTION_BACKENDS[tool_name](query, max_results, sort_by)
219
  return {"name": PAPER_SEARCH_FUNCTION_NAME, "results": search_results}
220
+
221
  return None
222
 
223
  def restore_message_history(full_history):
224
+ """Restore the complete message history including tool interactions."""
225
  restored = []
226
  for message in full_history:
227
  if message["role"] == "assistant" and "metadata" in message:
 
239
  return restored
240
 
241
  def iterate_chat(llm, sampling_params, full_history):
242
+ """Handle conversation turns with tool calling."""
243
  tool_interactions = []
244
+
245
  for _ in range(10):
246
  prompt = form_chat_prompt(restore_message_history(full_history) + tool_interactions)
247
  output = llm.generate(prompt, sampling_params)
248
+
249
  if VERBOSE_SHELL:
250
  print(f"Input prompt: {prompt}")
251
  print("-" * 50)
 
253
  print("=" * 50)
254
  if not output or not output.choices:
255
  raise ValueError("Invalid completion response")
256
+
257
  assistant_response = output.choices[0].text.strip()
258
  assistant_response = assistant_response.split(FUNCTION_EOT_STRING)[0].split(EOT_STRING)[0]
259
+
260
  tool_request_data = check_assistant_response_for_tool_calls(assistant_response)
261
  if not tool_request_data:
262
  final_message = {
 
275
  }
276
  tool_interactions.append(assistant_message)
277
  tool_return_data = process_tool_request(tool_request_data)
278
+
279
  tool_message = {
280
  "role": "function",
281
  "content": json.dumps(tool_return_data)
282
  }
283
  tool_interactions.append(tool_message)
284
+
285
  return full_history
286
 
287
+ def user_conversation(user_message, chat_history, full_history):
288
+ """Handle user input and maintain conversation state."""
289
+ if full_history is None:
290
+ full_history = []
291
+
292
+ full_history.append({"role": "user", "content": user_message})
 
 
 
 
 
293
  updated_history = iterate_chat(llm, sampling_params, full_history)
294
  assistant_answer = updated_history[-1]["content"]
295
+ chat_history.append((user_message, assistant_answer))
296
+
297
+ return "", chat_history, updated_history
298
+
299
+ sampling_params = {
300
+ "temperature": 0.8,
301
+ "top_p": 0.95,
302
+ "max_tokens": 512,
303
+ "stop_token_ids": [128001,128008,128009,128006],
304
+ }
305
 
306
  # Initialize LLM
307
  llm = LLM(max_model_len=8096)
308
 
309
+ with gr.Blocks() as demo:
310
+ gr.Markdown(f"<h2>{NAME_OF_SERVICE}</h2>")
311
+ chat_state = gr.State([])
312
+ chatbot = gr.Chatbot(label="Chat with the arXiv Paper Search Assistant")
313
+ user_input = gr.Textbox(
314
+ lines=1,
315
+ placeholder="Type your message here...",
316
+ )
 
317
 
318
+ user_input.submit(
319
+ fn=user_conversation,
320
+ inputs=[user_input, chatbot, chat_state],
321
+ outputs=[user_input, chatbot, chat_state],
322
+ queue=False
323
+ )
324
+
325
+ send_button = gr.Button("Send")
326
+ send_button.click(
327
+ fn=user_conversation,
328
+ inputs=[user_input, chatbot, chat_state],
329
+ outputs=[user_input, chatbot, chat_state],
330
+ queue=False
331
+ )
332
  demo.launch()