# dpo-recipe-api / app.py
# (HF Spaces page header captured with the file; kept as comments so the module parses)
# Hunjun's picture — Upload folder using huggingface_hub
# 747d60a verified
"""
DPO Recipe Generation API - HuggingFace Spaces
Generates personalized recipes using DPO-trained persona models.
"""
import os
import json
import re
import torch
import gradio as gr
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Configuration
# Base instruction-tuned model; persona LoRA adapters below are applied on top.
BASE_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
# Optional HF access token (the base model is gated); None when unset in the env.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Available personas
# Maps persona id -> adapter repo plus display/prompt metadata.
# "dietary_restrictions" is optional and only present for restricted diets;
# build_prompt() reads "cuisine" and "dietary_restrictions" to phrase the request.
PERSONAS = {
"korean_spicy": {
"hf_adapter": "Hunjun/korean-spicy-dpo-adapter",
"name": "Korean Food Lover (Spicy)",
"cuisine": "korean",
"flavor": "spicy, umami, savory",
},
"mexican_vegan": {
"hf_adapter": "Hunjun/mexican-vegan-dpo-adapter",
"name": "Mexican Vegan",
"cuisine": "mexican",
"flavor": "spicy, bold, savory",
"dietary_restrictions": "vegan",
}
}
# Global model cache
# Lazily populated by load_base_model()/load_adapter() so the Space starts fast
# and the 3B model is only downloaded on first request.
_base_model = None          # shared AutoModelForCausalLM instance
_tokenizer = None           # shared AutoTokenizer instance
_current_persona = None     # persona id whose adapter is currently active
_model_with_adapter = None  # PeftModel wrapping _base_model with the active adapter
def get_device() -> str:
    """Return "cuda" when a GPU is visible to torch, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"
def load_base_model():
    """Load the shared base model and tokenizer (idempotent).

    Populates the module-level ``_base_model`` / ``_tokenizer`` cache on the
    first call; subsequent calls return immediately.
    """
    global _base_model, _tokenizer
    if _base_model is not None:
        return
    print("Loading base model and tokenizer...")
    device = get_device()
    _tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN
    )
    # Llama tokenizers ship without a pad token; reuse EOS so padding works.
    _tokenizer.pad_token = _tokenizer.eos_token
    _base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        token=HF_TOKEN
    )
    # Bug fix: the model was previously left on CPU even when CUDA was
    # available — the log line below claimed otherwise. `.to("cpu")` is a no-op.
    _base_model = _base_model.to(device)
    print(f"Base model loaded on {device}")
def load_adapter(persona_id: str):
    """Load and activate the DPO adapter for ``persona_id`` (idempotent).

    Bug fix: the original re-wrapped ``_base_model`` with a fresh
    ``PeftModel.from_pretrained`` on every persona switch. PEFT injects LoRA
    modules into the base model in place, so repeated wrapping stacks adapters
    and corrupts outputs. We now wrap once and register/activate additional
    adapters on the existing wrapper via ``load_adapter`` / ``set_adapter``.
    """
    global _model_with_adapter, _current_persona
    if _current_persona == persona_id:
        return  # requested adapter is already active
    load_base_model()
    print(f"Loading adapter for {persona_id}...")
    adapter_repo = PERSONAS[persona_id]["hf_adapter"]
    if _model_with_adapter is None:
        # First persona: wrap the base model exactly once.
        _model_with_adapter = PeftModel.from_pretrained(
            _base_model,
            adapter_repo,
            adapter_name=persona_id,
            token=HF_TOKEN
        )
    else:
        # Subsequent personas: download into the existing wrapper if new,
        # then switch the active adapter.
        if persona_id not in _model_with_adapter.peft_config:
            _model_with_adapter.load_adapter(
                adapter_repo,
                adapter_name=persona_id,
                token=HF_TOKEN
            )
        _model_with_adapter.set_adapter(persona_id)
    _model_with_adapter.eval()
    _current_persona = persona_id
    print(f"Adapter loaded: {persona_id}")
def build_prompt(persona_id: str, ingredients: str, user_request: str = "") -> str:
    """Assemble a ChatML-style prompt for the persona's cuisine and diet.

    The user turn lists the ingredients, appends any free-form request, and
    closes with a cuisine (and optional dietary) recipe demand.
    """
    persona = PERSONAS[persona_id]
    system_msg = "You are a recipe generation AI that creates recipes based on user inventory and preferences."
    # Compose the user message from space-joined sentences.
    sentences = [f"I have {ingredients}."]
    if user_request:
        sentences.append(user_request)
    diet = persona.get("dietary_restrictions")
    style = f"{diet} {persona['cuisine']}" if diet else persona["cuisine"]
    sentences.append(f"I want a {style} recipe.")
    user_msg = " ".join(sentences)
    return (
        f"<|im_start|>system\n{system_msg}<|im_end|>\n"
        f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
def parse_recipe_json(output: str) -> dict:
    """Extract a recipe dict from raw model text.

    Tries the whole string first, then the widest ``{...}`` span; falls back
    to an error payload carrying the first 500 chars of the raw output.
    """
    candidates = [output]
    embedded = re.search(r'\{[\s\S]*\}', output)
    if embedded:
        candidates.append(embedded.group())
    for text in candidates:
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            continue
    return {
        "status": "error",
        "error": "Failed to parse recipe",
        "raw_output": output[:500],
    }
def generate_recipe(
    persona: str,
    ingredients: str,
    user_request: str = "",
    max_tokens: int = 512,
    temperature: float = 0.7
) -> dict:
    """Generate a recipe dict using the specified persona adapter.

    Returns the parsed recipe JSON augmented with ``persona`` /
    ``persona_name``, or a payload with ``status == "error"`` on any failure.
    """
    if persona not in PERSONAS:
        return {"status": "error", "error": f"Unknown persona: {persona}"}
    if not ingredients.strip():
        return {"status": "error", "error": "Please provide at least one ingredient"}
    try:
        # Ensure the requested persona adapter is active (cached no-op if so).
        load_adapter(persona)
        prompt = build_prompt(persona, ingredients, user_request)
        inputs = _tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )
        # Bug fix: move input tensors to the model's device; previously they
        # stayed on CPU, which raises a device-mismatch error on CUDA Spaces.
        device = _model_with_adapter.device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = _model_with_adapter.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=_tokenizer.eos_token_id
            )
        # Decode only the newly generated continuation (skip the prompt tokens).
        generated_text = _tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )
        result = parse_recipe_json(generated_text)
        # Mirror the error payloads: mark successful parses explicitly
        # (setdefault keeps any status the model itself emitted).
        result.setdefault("status", "success")
        result["persona"] = persona
        result["persona_name"] = PERSONAS[persona]["name"]
        return result
    except Exception as e:
        # Boundary handler: surface any failure as a structured error for the UI.
        return {
            "status": "error",
            "error": str(e),
            "persona": persona
        }
# Gradio Interface
# Two-column layout: inputs + sampling controls on the left, JSON output on
# the right. Component order matters inside the Blocks context manager.
with gr.Blocks(title="DPO Recipe Generator") as demo:
    gr.Markdown("""
    # DPO Recipe Generator
    Generate personalized recipes using DPO-trained persona models.
    **Available Personas:**
    - **Korean Spicy**: Korean cuisine with emphasis on spicy flavors
    - **Mexican Vegan**: Mexican cuisine, plant-based recipes
    """)
    with gr.Row():
        with gr.Column():
            # Persona choices come straight from the PERSONAS registry keys.
            persona_input = gr.Dropdown(
                choices=list(PERSONAS.keys()),
                value="korean_spicy",
                label="Persona"
            )
            ingredients_input = gr.Textbox(
                label="Ingredients",
                placeholder="e.g., tofu, rice, gochujang, sesame oil",
                lines=2
            )
            request_input = gr.Textbox(
                label="Additional Request (optional)",
                placeholder="e.g., Make something quick and spicy",
                lines=2
            )
            # Sampling controls; ranges mirror generate_recipe()'s defaults.
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max Tokens"
                )
                # Minimum of 0.1 avoids temperature=0, which is invalid
                # when do_sample=True in generate_recipe().
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            generate_btn = gr.Button("Generate Recipe", variant="primary")
        with gr.Column():
            output = gr.JSON(label="Generated Recipe")
    # Wire the button to the generator; inputs are passed positionally in
    # generate_recipe()'s parameter order.
    generate_btn.click(
        fn=generate_recipe,
        inputs=[persona_input, ingredients_input, request_input, max_tokens, temperature],
        outputs=output
    )
    # Clickable examples fill only the three text inputs; sliders keep defaults.
    gr.Examples(
        examples=[
            ["korean_spicy", "tofu, rice, gochujang, sesame oil, green onion", "Make something quick and spicy"],
            ["mexican_vegan", "black beans, avocado, lime, cilantro, tortillas", "Make fresh tacos"],
            ["korean_spicy", "chicken, kimchi, cheese, rice", "Make a fusion dish"],
            ["mexican_vegan", "quinoa, bell peppers, corn, black beans", "Make a healthy bowl"],
        ],
        inputs=[persona_input, ingredients_input, request_input]
    )
if __name__ == "__main__":
    # Launch the Gradio server (HF Spaces runs this module directly).
    demo.launch()