Hunjun committed on
Commit
747d60a
·
verified ·
1 Parent(s): 5ebce49

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. README.md +28 -6
  2. app.py +279 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -1,12 +1,34 @@
1
  ---
2
- title: Dpo Recipe Api
3
- emoji: 💻
4
- colorFrom: purple
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.0.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: DPO Recipe Generator
3
+ emoji: 🍳
4
+ colorFrom: red
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # DPO Recipe Generator
14
+
15
+ Generate personalized recipes using DPO-trained persona models.
16
+
17
+ ## Available Personas
18
+
19
+ - **Korean Spicy**: Korean cuisine with emphasis on spicy flavors (gochujang, kimchi)
20
+ - **Mexican Vegan**: Mexican cuisine, plant-based recipes (beans, avocado, salsa)
21
+
22
+ ## API Usage
23
+
24
+ ```bash
25
+ curl -X POST "https://hunjun-dpo-recipe-api.hf.space/api/predict" \
26
+ -H "Content-Type: application/json" \
27
+ -d '{"data": ["korean_spicy", "tofu, rice, gochujang", "Make something spicy", 512, 0.7]}'
28
+ ```
29
+
30
+ ## Models
31
+
32
+ - Base Model: `meta-llama/Llama-3.2-3B-Instruct`
33
+ - Korean Spicy Adapter: `Hunjun/korean-spicy-dpo-adapter`
34
+ - Mexican Vegan Adapter: `Hunjun/mexican-vegan-dpo-adapter`
app.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
DPO Recipe Generation API - HuggingFace Spaces

Generates personalized recipes using DPO-trained persona models.
"""

import os
import json
import re
import torch
import gradio as gr
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Configuration
# NOTE: the base repo is gated on the Hub; HF_TOKEN must be set as a Space
# secret for from_pretrained calls below to succeed.
BASE_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Available personas.
# Each entry maps a persona id (the API's first argument) to:
#   hf_adapter            - Hub repo of the DPO-trained LoRA adapter
#   name                  - human-readable label echoed back in responses
#   cuisine / flavor      - strings interpolated into the prompt
#   dietary_restrictions  - optional; added to the user message when present
PERSONAS = {
    "korean_spicy": {
        "hf_adapter": "Hunjun/korean-spicy-dpo-adapter",
        "name": "Korean Food Lover (Spicy)",
        "cuisine": "korean",
        "flavor": "spicy, umami, savory",
    },
    "mexican_vegan": {
        "hf_adapter": "Hunjun/mexican-vegan-dpo-adapter",
        "name": "Mexican Vegan",
        "cuisine": "mexican",
        "flavor": "spicy, bold, savory",
        "dietary_restrictions": "vegan",
    }
}

# Global model cache — populated lazily by load_base_model()/load_adapter()
# so the Space starts fast and only pays the load cost on first request.
_base_model = None           # shared AutoModelForCausalLM instance
_tokenizer = None            # shared tokenizer for BASE_MODEL_ID
_current_persona = None      # persona id whose adapter is currently active
_model_with_adapter = None   # PeftModel wrapping _base_model
43
+
44
def get_device():
    """Return the best available torch device name: "cuda" if present, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"
49
+
50
+
51
def load_base_model():
    """Load the base model and tokenizer into the module-level cache.

    Idempotent: returns immediately once _base_model is populated, so every
    request after the first skips the (slow) Hub download. Requires HF_TOKEN
    for the gated Llama repository.
    """
    global _base_model, _tokenizer

    if _base_model is not None:
        return

    print("Loading base model and tokenizer...")
    device = get_device()

    _tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN
    )
    # Llama tokenizers ship without a pad token; reuse EOS so padding works.
    _tokenizer.pad_token = _tokenizer.eos_token

    _base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.float32,  # full precision; safe default for CPU Spaces
        low_cpu_mem_usage=True,
        token=HF_TOKEN
    )

    # NOTE(review): `device` is only used in this log line — the model is
    # never moved with .to(device), so it stays where from_pretrained put it
    # (CPU by default). Confirm this is intentional if the Space gets a GPU.
    print(f"Base model loaded on {device}")
75
+
76
+
77
def load_adapter(persona_id: str):
    """Make the given persona's LoRA adapter the active one.

    Idempotent when the requested persona is already active. The first call
    wraps the cached base model in a PeftModel; later persona switches load
    the new adapter into that same PeftModel and activate it with
    set_adapter().

    FIX: the original called PeftModel.from_pretrained(_base_model, ...) on
    every persona switch, re-wrapping a base model whose modules had already
    been injected with the previous adapter — adapters could stack instead of
    replacing each other. Loading named adapters into one PeftModel and
    switching between them avoids that.

    Args:
        persona_id: Key into PERSONAS; callers validate it before calling.
    """
    global _model_with_adapter, _current_persona

    if _current_persona == persona_id:
        return

    load_base_model()

    print(f"Loading adapter for {persona_id}...")
    adapter_repo = PERSONAS[persona_id]["hf_adapter"]

    if _model_with_adapter is None:
        # First adapter: wrap the base model once, naming the adapter after
        # the persona so it can be re-activated later.
        _model_with_adapter = PeftModel.from_pretrained(
            _base_model,
            adapter_repo,
            adapter_name=persona_id,
            token=HF_TOKEN
        )
    else:
        # Subsequent personas: fetch the adapter only if unseen, then switch.
        if persona_id not in _model_with_adapter.peft_config:
            _model_with_adapter.load_adapter(
                adapter_repo,
                adapter_name=persona_id,
                token=HF_TOKEN
            )
        _model_with_adapter.set_adapter(persona_id)

    _model_with_adapter.eval()
    _current_persona = persona_id
    print(f"Adapter loaded: {persona_id}")
97
+
98
+
99
def build_prompt(persona_id: str, ingredients: str, user_request: str = "") -> str:
    """Assemble the ChatML-format prompt for the given persona.

    The user message always states the inventory, optionally echoes the
    free-form request, and ends with a cuisine (and dietary restriction,
    when the persona has one) preference line.
    """
    spec = PERSONAS[persona_id]

    system_msg = "You are a recipe generation AI that creates recipes based on user inventory and preferences."

    diet = spec.get("dietary_restrictions", "")
    request_part = f" {user_request}" if user_request else ""
    dietary_part = f"{diet} " if diet else ""

    user_msg = (
        f"I have {ingredients}.{request_part}"
        f" I want a {dietary_part}{spec['cuisine']} recipe."
    )

    return (
        f"<|im_start|>system\n{system_msg}<|im_end|>\n"
        f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
124
+
125
+
126
def parse_recipe_json(output: str) -> dict:
    """Extract a recipe dict from raw model text.

    Tries the whole string as JSON first, then the widest {...} span found
    inside it. When neither parses, returns a structured error payload
    carrying the first 500 characters of the raw output for debugging.
    """
    candidates = [output]
    embedded = re.search(r'\{[\s\S]*\}', output)
    if embedded is not None:
        candidates.append(embedded.group())

    for text in candidates:
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            continue

    return {
        "status": "error",
        "error": "Failed to parse recipe",
        "raw_output": output[:500]
    }
145
+
146
+
147
def generate_recipe(
    persona: str,
    ingredients: str,
    user_request: str = "",
    max_tokens: int = 512,
    temperature: float = 0.7
) -> dict:
    """Generate a recipe using the specified persona.

    Args:
        persona: Key into PERSONAS (e.g. "korean_spicy").
        ingredients: Comma-separated ingredient list from the user.
        user_request: Optional free-form instruction appended to the prompt.
        max_tokens: Cap on newly generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The parsed recipe dict augmented with "persona"/"persona_name", or a
        {"status": "error", "error": ...} payload on any failure — the API
        always returns JSON rather than raising into Gradio.
    """
    if persona not in PERSONAS:
        return {"status": "error", "error": f"Unknown persona: {persona}"}

    if not ingredients.strip():
        return {"status": "error", "error": "Please provide at least one ingredient"}

    try:
        # Ensure the right adapter is active (no-op if already loaded).
        load_adapter(persona)

        # Build prompt
        prompt = build_prompt(persona, ingredients, user_request)

        # Tokenize, truncating very long inventories to the model's budget.
        inputs = _tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )
        # FIX: move input tensors onto the model's device. Previously they
        # stayed on CPU, which crashes generation whenever the model sits on
        # CUDA; on CPU this is a no-op, so behavior there is unchanged.
        device = next(_model_with_adapter.parameters()).device
        inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = _model_with_adapter.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=_tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, skipping the echoed prompt.
        generated_text = _tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )

        # Parse the model output and attach persona metadata for the client.
        result = parse_recipe_json(generated_text)
        result["persona"] = persona
        result["persona_name"] = PERSONAS[persona]["name"]

        return result

    except Exception as e:
        # Broad catch is deliberate at this API boundary: any model/IO error
        # is surfaced as a structured error payload instead of a 500.
        return {
            "status": "error",
            "error": str(e),
            "persona": persona
        }
208
+
209
+
210
# Gradio Interface
# Layout: inputs (persona, ingredients, request, sampling sliders) in the
# left column, JSON output in the right. The same function also serves the
# /api/predict endpoint documented in the README.
with gr.Blocks(title="DPO Recipe Generator") as demo:
    gr.Markdown("""
    # DPO Recipe Generator

    Generate personalized recipes using DPO-trained persona models.

    **Available Personas:**
    - **Korean Spicy**: Korean cuisine with emphasis on spicy flavors
    - **Mexican Vegan**: Mexican cuisine, plant-based recipes
    """)

    with gr.Row():
        with gr.Column():
            # Persona choice selects which LoRA adapter generate_recipe loads.
            persona_input = gr.Dropdown(
                choices=list(PERSONAS.keys()),
                value="korean_spicy",
                label="Persona"
            )
            ingredients_input = gr.Textbox(
                label="Ingredients",
                placeholder="e.g., tofu, rice, gochujang, sesame oil",
                lines=2
            )
            request_input = gr.Textbox(
                label="Additional Request (optional)",
                placeholder="e.g., Make something quick and spicy",
                lines=2
            )

            with gr.Row():
                # Sliders map directly onto generate_recipe's max_tokens /
                # temperature parameters; defaults match its signature.
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )

            generate_btn = gr.Button("Generate Recipe", variant="primary")

        with gr.Column():
            output = gr.JSON(label="Generated Recipe")

    # Wire button → generate_recipe; input order must match its signature.
    generate_btn.click(
        fn=generate_recipe,
        inputs=[persona_input, ingredients_input, request_input, max_tokens, temperature],
        outputs=output
    )

    # Click-to-fill examples (persona, ingredients, request only; sliders
    # keep their defaults).
    gr.Examples(
        examples=[
            ["korean_spicy", "tofu, rice, gochujang, sesame oil, green onion", "Make something quick and spicy"],
            ["mexican_vegan", "black beans, avocado, lime, cilantro, tortillas", "Make fresh tacos"],
            ["korean_spicy", "chicken, kimchi, cheese, rice", "Make a fusion dish"],
            ["mexican_vegan", "quinoa, bell peppers, corn, black beans", "Make a healthy bowl"],
        ],
        inputs=[persona_input, ingredients_input, request_input]
    )


if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ transformers>=4.44.0
4
+ peft>=0.10.0
5
+ accelerate>=0.27.0
6
+ safetensors>=0.4.1
7
+ huggingface_hub>=0.20.0