"""
DPO Recipe Generation API - HuggingFace Spaces
Generates personalized recipes using DPO-trained persona models.
"""

import os
import json
import re
import torch
import gradio as gr
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Configuration
BASE_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
# Optional Hugging Face access token, passed to every hub download below.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Available personas, keyed by persona id. Each entry names the LoRA adapter
# repo to load plus the cuisine / dietary fields used to build the prompt.
PERSONAS = {
    "korean_spicy": {
        "hf_adapter": "Hunjun/korean-spicy-dpo-adapter",
        "name": "Korean Food Lover (Spicy)",
        "cuisine": "korean",
        "flavor": "spicy, umami, savory",
    },
    "mexican_vegan": {
        "hf_adapter": "Hunjun/mexican-vegan-dpo-adapter",
        "name": "Mexican Vegan",
        "cuisine": "mexican",
        "flavor": "spicy, bold, savory",
        "dietary_restrictions": "vegan",
    },
}

# Global model cache, populated lazily on the first request.
_base_model = None  # shared base causal LM
_tokenizer = None
_current_persona = None  # persona id whose adapter is currently active
_model_with_adapter = None  # single PeftModel holding every adapter loaded so far


def get_device():
    """Return ``"cuda"`` when a CUDA GPU is available, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def load_base_model():
    """Load the base model and tokenizer once and place the model on the best device.

    Subsequent calls are no-ops. On GPU the weights are loaded in half precision
    (bfloat16 when supported, else float16) to cut memory use. On CPU they stay
    in float32.
    """
    global _base_model, _tokenizer
    if _base_model is not None:
        return

    print("Loading base model and tokenizer...")
    device = get_device()

    _tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN,
    )
    # Generation pads with EOS; the tokenizer needs a pad token for padding=True.
    _tokenizer.pad_token = _tokenizer.eos_token

    if device == "cuda":
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        token=HF_TOKEN,
    )
    # Previously the model was never moved, so it always ran on CPU even when
    # get_device() reported "cuda".
    model.to(device)
    # Publish to the cache only once the model is fully loaded and placed.
    _base_model = model
    print(f"Base model loaded on {device}")


def load_adapter(persona_id: str):
    """Make the LoRA adapter for ``persona_id`` the active one.

    Every adapter is registered once, under its persona id, on a single
    PeftModel that wraps the shared base model. Switching personas afterwards
    only calls ``set_adapter``. Re-running ``PeftModel.from_pretrained`` on the
    same base model for every switch (the previous approach) re-injects LoRA
    layers into a model that PEFT has already modified in place.

    Raises:
        KeyError: if ``persona_id`` is not in ``PERSONAS``.
    """
    global _model_with_adapter, _current_persona
    if _current_persona == persona_id:
        return

    load_base_model()
    adapter_repo = PERSONAS[persona_id]["hf_adapter"]

    if _model_with_adapter is None:
        print(f"Loading adapter for {persona_id}...")
        _model_with_adapter = PeftModel.from_pretrained(
            _base_model,
            adapter_repo,
            adapter_name=persona_id,
            token=HF_TOKEN,
        )
    elif persona_id not in _model_with_adapter.peft_config:
        print(f"Loading adapter for {persona_id}...")
        _model_with_adapter.load_adapter(
            adapter_repo,
            adapter_name=persona_id,
            token=HF_TOKEN,
        )

    _model_with_adapter.set_adapter(persona_id)
    _model_with_adapter.eval()
    _current_persona = persona_id
    print(f"Adapter loaded: {persona_id}")


def build_prompt(persona_id: str, ingredients: str, user_request: str = "") -> str:
    """Build ChatML format prompt.

    The user turn lists the ingredients, appends ``user_request`` if given,
    and asks for a recipe in the persona's cuisine. It is prefixed by the
    persona's dietary restriction when one is configured.
    """
    persona = PERSONAS[persona_id]
    system_msg = "You are a recipe generation AI that creates recipes based on user inventory and preferences."
    diet = persona.get("dietary_restrictions", "")

    if user_request:
        user_msg = f"I have {ingredients}. {user_request}"
    else:
        user_msg = f"I have {ingredients}."

    if diet:
        user_msg += f" I want a {diet} {persona['cuisine']} recipe."
    else:
        user_msg += f" I want a {persona['cuisine']} recipe."

    # NOTE(review): these are ChatML markers, not Llama 3.2's native chat
    # template. Presumably they match the format the DPO adapters were trained
    # on; confirm before switching to tokenizer.apply_chat_template.
    prompt = f"""<|im_start|>system
{system_msg}<|im_end|>
<|im_start|>user
{user_msg}<|im_end|>
<|im_start|>assistant
"""
    return prompt


def parse_recipe_json(output: str) -> dict:
    """Parse recipe JSON from model output.

    First tries the whole output, then the span from the first ``{`` to the
    last ``}``. Only a JSON *object* is accepted. Other JSON values (lists,
    strings, numbers) previously slipped through and crashed the caller when it
    added persona keys.

    Returns:
        The parsed dict, or an error dict holding the first 500 characters of
        the raw output.
    """
    candidates = [output]
    json_match = re.search(r'\{[\s\S]*\}', output)
    if json_match:
        candidates.append(json_match.group())

    for text in candidates:
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    return {
        "status": "error",
        "error": "Failed to parse recipe",
        "raw_output": output[:500],
    }


def generate_recipe(
    persona: str,
    ingredients: str,
    user_request: str = "",
    max_tokens: int = 512,
    temperature: float = 0.7,
) -> dict:
    """Generate a recipe using the specified persona.

    Args:
        persona: key into ``PERSONAS``.
        ingredients: free-text ingredient list; must be non-blank.
        user_request: optional extra instruction appended to the prompt.
        max_tokens: maximum number of new tokens to sample.
        temperature: sampling temperature (nucleus sampling with top_p=0.9).

    Returns:
        The parsed recipe dict with ``persona`` and ``persona_name`` added. On
        invalid input or any failure, returns a dict with ``"status": "error"``.
    """
    if persona not in PERSONAS:
        return {"status": "error", "error": f"Unknown persona: {persona}"}

    # Guard against None as well as blank strings (API callers may omit it).
    if not ingredients or not ingredients.strip():
        return {"status": "error", "error": "Please provide at least one ingredient"}

    try:
        # Load adapter
        load_adapter(persona)

        # Build prompt
        prompt = build_prompt(persona, ingredients, user_request)

        # Tokenize and move tensors to the same device as the model.
        inputs = _tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        ).to(_base_model.device)

        # Generate. Slider values may arrive as floats, so coerce them.
        with torch.no_grad():
            outputs = _model_with_adapter.generate(
                **inputs,
                max_new_tokens=int(max_tokens),
                temperature=float(temperature),
                top_p=0.9,
                do_sample=True,
                pad_token_id=_tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (skip the echoed prompt).
        generated_text = _tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        # Parse and return
        result = parse_recipe_json(generated_text)
        result["persona"] = persona
        result["persona_name"] = PERSONAS[persona]["name"]
        return result

    except Exception as e:
        # Top-level boundary for the UI/API: report the failure instead of raising.
        return {
            "status": "error",
            "error": str(e),
            "persona": persona,
        }


# Gradio Interface
with gr.Blocks(title="DPO Recipe Generator") as demo:
    gr.Markdown("""
    # DPO Recipe Generator

    Generate personalized recipes using DPO-trained persona models.

    **Available Personas:**
    - **Korean Spicy**: Korean cuisine with emphasis on spicy flavors
    - **Mexican Vegan**: Mexican cuisine, plant-based recipes
    """)

    with gr.Row():
        with gr.Column():
            persona_input = gr.Dropdown(
                choices=list(PERSONAS.keys()),
                value="korean_spicy",
                label="Persona",
            )
            ingredients_input = gr.Textbox(
                label="Ingredients",
                placeholder="e.g., tofu, rice, gochujang, sesame oil",
                lines=2,
            )
            request_input = gr.Textbox(
                label="Additional Request (optional)",
                placeholder="e.g., Make something quick and spicy",
                lines=2,
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )
            generate_btn = gr.Button("Generate Recipe", variant="primary")

        with gr.Column():
            output = gr.JSON(label="Generated Recipe")

    generate_btn.click(
        fn=generate_recipe,
        inputs=[persona_input, ingredients_input, request_input, max_tokens, temperature],
        outputs=output,
    )

    gr.Examples(
        examples=[
            ["korean_spicy", "tofu, rice, gochujang, sesame oil, green onion", "Make something quick and spicy"],
            ["mexican_vegan", "black beans, avocado, lime, cilantro, tortillas", "Make fresh tacos"],
            ["korean_spicy", "chicken, kimchi, cheese, rice", "Make a fusion dish"],
            ["mexican_vegan", "quinoa, bell peppers, corn, black beans", "Make a healthy bowl"],
        ],
        inputs=[persona_input, ingredients_input, request_input],
    )


if __name__ == "__main__":
    demo.launch()