curiouscurrent committed
Commit f08b6f0 · verified · 1 Parent(s): 38db332

Update AI_Agent/llm_adapters/hf_adapter.py

Files changed (1)
  1. AI_Agent/llm_adapters/hf_adapter.py +23 -6
AI_Agent/llm_adapters/hf_adapter.py CHANGED
@@ -4,19 +4,36 @@ import torch
 import asyncio
 
 class HuggingFaceAdapter:
-    def __init__(self, model_name="openai/gpt-oss-20b"):
+    def __init__(self, model_name="EleutherAI/gpt-neo-125M"):
         self.model_name = model_name
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            dtype=torch.float32,  # CPU-friendly
-            device_map=None  # CPU only
+            torch_dtype=torch.float32,  # CPU-friendly
+            device_map=None  # CPU only
         )
 
-    async def generate(self, prompt: str, max_tokens=300):
+    async def generate(self, prompt: str, max_tokens=300, temperature=0.7, top_p=0.9, repetition_penalty=1.2):
+        """
+        Generate text from prompt asynchronously.
+
+        Parameters:
+            prompt (str): Input text prompt.
+            max_tokens (int): Maximum number of new tokens.
+            temperature (float): Randomness, higher = more diverse.
+            top_p (float): Nucleus sampling.
+            repetition_penalty (float): >1 penalizes repeating tokens.
+        """
         def _sync_generate():
-            inputs = self.tokenizer(prompt, return_tensors="pt")  # no .to(self.model.device) needed
-            outputs = self.model.generate(**inputs, max_new_tokens=max_tokens)
+            inputs = self.tokenizer(prompt, return_tensors="pt")
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                do_sample=True  # enables sampling for more varied output
+            )
             text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
             return text
 
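Note that the hunk ends inside generate(): the code that actually dispatches _sync_generate off the event loop and returns its result is below the shown context, so it is not visible here. The sampling parameters (temperature, top_p, repetition_penalty) only take effect because do_sample=True is passed; without it, model.generate() would fall back to greedy decoding and ignore them. A minimal usage sketch follows, assuming generate() awaits _sync_generate in an executor (for example via asyncio.to_thread) and returns the decoded text; the import path is inferred from the repo layout and is an assumption.

# Hypothetical usage sketch -- assumes generate() runs _sync_generate in an
# executor and returns the decoded string; import path assumed from repo layout.
import asyncio

from AI_Agent.llm_adapters.hf_adapter import HuggingFaceAdapter


async def main():
    # Defaults to the small, CPU-friendly EleutherAI/gpt-neo-125M checkpoint
    adapter = HuggingFaceAdapter()
    reply = await adapter.generate(
        "Summarize the role of an LLM adapter in one sentence.",
        max_tokens=100,
        temperature=0.7,
    )
    print(reply)


if __name__ == "__main__":
    asyncio.run(main())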