Spaces:
Sleeping
Sleeping
| """ | |
| DPO Recipe Generation API - HuggingFace Spaces | |
| Generates personalized recipes using DPO-trained persona models. | |
| """ | |
| import os | |
| import json | |
| import re | |
| import torch | |
| import gradio as gr | |
| from typing import Optional | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
# Configuration
# Hugging Face Hub id of the base model that every persona adapter is applied to.
BASE_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
# Optional Hub access token taken from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Persona registry: persona id -> adapter repo and prompt/label metadata.
# "cuisine" and the optional "dietary_restrictions" feed the prompt text;
# "name" is the human-readable label returned alongside generated recipes.
PERSONAS = dict(
    korean_spicy=dict(
        hf_adapter="Hunjun/korean-spicy-dpo-adapter",
        name="Korean Food Lover (Spicy)",
        cuisine="korean",
        flavor="spicy, umami, savory",
    ),
    mexican_vegan=dict(
        hf_adapter="Hunjun/mexican-vegan-dpo-adapter",
        name="Mexican Vegan",
        cuisine="mexican",
        flavor="spicy, bold, savory",
        dietary_restrictions="vegan",
    ),
)

# Model cache, filled on first use by load_base_model() / load_adapter().
_base_model = None           # base causal LM shared by all personas
_tokenizer = None            # tokenizer matching BASE_MODEL_ID
_current_persona = None      # persona id whose adapter is currently active
_model_with_adapter = None   # PeftModel wrapping _base_model
def get_device():
    """Return ``"cuda"`` if torch can see a CUDA device, otherwise ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"
def load_base_model():
    """Load the base model and tokenizer into the module-level cache.

    No-op once ``_base_model`` is populated. Weights are loaded in float32 with
    ``low_cpu_mem_usage=True``. Nothing here moves the model to an accelerator,
    so it stays wherever ``from_pretrained`` placed it (CPU by default). The
    log line therefore reports the model's actual device rather than the best
    available one.
    """
    global _base_model, _tokenizer
    if _base_model is not None:
        return

    print("Loading base model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN
    )
    # Reuse EOS as the pad token so tokenizer(..., padding=True) is valid.
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        token=HF_TOKEN
    )

    # Publish both only after every load succeeded, so a failure part-way
    # never leaves a tokenizer cached without its model.
    _tokenizer = tokenizer
    _base_model = model
    print(f"Base model loaded on {model.device} (best available device: {get_device()})")
def load_adapter(persona_id: str):
    """Make the LoRA adapter for ``persona_id`` the active one.

    All personas share a single ``PeftModel`` wrapper around the cached base
    model:

    - The first call wraps the base model with ``PeftModel.from_pretrained``.
    - Later calls for a not-yet-seen persona attach its adapter with
      ``load_adapter`` under the persona id.
    - Every call then activates the persona with ``set_adapter``.

    ``PeftModel.from_pretrained`` modifies the base model in place, so it must
    run only once. Calling it again, as the previous version did on every
    persona switch, injects a second adapter on top of the first.

    Adapters that were already loaded stay in memory, so switching back is
    cheap.

    Raises:
        KeyError: if ``persona_id`` is not a key of ``PERSONAS``.
    """
    global _model_with_adapter, _current_persona
    if _current_persona == persona_id:
        return

    # Validate the persona before paying for the base-model load.
    adapter_repo = PERSONAS[persona_id]["hf_adapter"]
    load_base_model()

    if _model_with_adapter is None:
        print(f"Loading adapter for {persona_id}...")
        _model_with_adapter = PeftModel.from_pretrained(
            _base_model,
            adapter_repo,
            adapter_name=persona_id,
            token=HF_TOKEN
        )
    elif persona_id not in _model_with_adapter.peft_config:
        print(f"Loading adapter for {persona_id}...")
        _model_with_adapter.load_adapter(
            adapter_repo,
            adapter_name=persona_id,
            token=HF_TOKEN
        )

    # load_adapter() does not activate the new adapter; select it explicitly.
    _model_with_adapter.set_adapter(persona_id)
    _model_with_adapter.eval()
    _current_persona = persona_id
    print(f"Adapter loaded: {persona_id}")
def build_prompt(persona_id: str, ingredients: str, user_request: str = "") -> str:
    """Render the ChatML-style prompt for a single recipe request.

    The user turn is built in three parts:

    1. "I have <ingredients>."
    2. ``user_request``, appended when it is non-empty.
    3. A request for a recipe in the persona's cuisine. The cuisine is
       prefixed by the persona's dietary restriction when one is defined.

    The returned text ends with an open assistant turn for the model to
    complete.

    NOTE(review): <|im_start|>/<|im_end|> are ChatML tags, not the native
    chat template of the Llama 3.2 base model. Presumably they match the
    format the DPO adapters were trained on; confirm before changing them.

    Raises:
        KeyError: if ``persona_id`` is not in ``PERSONAS``.
    """
    profile = PERSONAS[persona_id]
    system_text = "You are a recipe generation AI that creates recipes based on user inventory and preferences."

    request_text = f"I have {ingredients}."
    if user_request:
        request_text = f"{request_text} {user_request}"

    restriction = profile.get("dietary_restrictions", "")
    descriptor = f"{restriction} {profile['cuisine']}" if restriction else profile["cuisine"]
    request_text += f" I want a {descriptor} recipe."

    return (
        "<|im_start|>system\n"
        f"{system_text}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{request_text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
def parse_recipe_json(output: str) -> dict:
    """Extract a JSON object (the recipe) from raw model output.

    Strategies, in order:

    1. Parse the whole string as JSON. The result is accepted only if it is
       an object; arrays and scalars fall through.
    2. Decode the JSON value that starts at the first ``{``, ignoring any
       prose, Markdown fences or other text before or after it.

    Strategy 2 succeeds on every input the previous greedy
    ``\\{[\\s\\S]*\\}`` regex handled. It also copes with trailing text that
    contains a ``}``.

    Returns:
        The decoded dict. If nothing parses, an error dict with ``status``,
        ``error`` and the first 500 characters of ``output`` under
        ``raw_output``.
    """
    try:
        parsed = json.loads(output)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    start = output.find("{")
    if start != -1:
        try:
            candidate, _ = json.JSONDecoder().raw_decode(output, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate

    return {
        "status": "error",
        "error": "Failed to parse recipe",
        "raw_output": output[:500]
    }
def generate_recipe(
    persona: str,
    ingredients: str,
    user_request: str = "",
    max_tokens: int = 512,
    temperature: float = 0.7
) -> dict:
    """Generate a recipe with the adapter for ``persona``.

    Args:
        persona: Key into ``PERSONAS`` selecting the adapter and prompt details.
        ingredients: Free-text ingredient list; must contain non-whitespace text.
        user_request: Optional extra instruction appended to the user turn.
        max_tokens: Maximum number of new tokens. Coerced to ``int`` in case
            the caller (e.g. a UI slider) passes a float.
        temperature: Sampling temperature, coerced to ``float``. Values
            ``<= 0`` select greedy decoding instead, because transformers
            rejects non-positive temperatures when sampling.

    Returns:
        On success, the dict from ``parse_recipe_json`` with ``persona`` and
        ``persona_name`` added. On bad input or any failure while loading or
        generating, an error dict with ``status="error"``; nothing is raised
        to the caller.
    """
    if persona not in PERSONAS:
        return {"status": "error", "error": f"Unknown persona: {persona}"}
    if not ingredients or not ingredients.strip():
        return {"status": "error", "error": "Please provide at least one ingredient"}

    try:
        load_adapter(persona)
        prompt = build_prompt(persona, ingredients, user_request)

        inputs = _tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )
        # Keep the input tensors on whatever device holds the model weights.
        inputs = inputs.to(_model_with_adapter.device)

        if temperature > 0:
            sampling = {"do_sample": True, "temperature": float(temperature), "top_p": 0.9}
        else:
            sampling = {"do_sample": False}

        with torch.no_grad():
            outputs = _model_with_adapter.generate(
                **inputs,
                max_new_tokens=int(max_tokens),
                pad_token_id=_tokenizer.eos_token_id,
                **sampling
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        generated_text = _tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True
        )

        result = parse_recipe_json(generated_text)
        result["persona"] = persona
        result["persona_name"] = PERSONAS[persona]["name"]
        return result
    except Exception as e:
        # Top-level boundary for the UI: report the failure instead of raising.
        return {
            "status": "error",
            "error": str(e),
            "persona": persona
        }
# Gradio Interface
# Example rows for the form: (persona id, ingredients, additional request).
_EXAMPLE_ROWS = [
    ["korean_spicy", "tofu, rice, gochujang, sesame oil, green onion", "Make something quick and spicy"],
    ["mexican_vegan", "black beans, avocado, lime, cilantro, tortillas", "Make fresh tacos"],
    ["korean_spicy", "chicken, kimchi, cheese, rice", "Make a fusion dish"],
    ["mexican_vegan", "quinoa, bell peppers, corn, black beans", "Make a healthy bowl"],
]

with gr.Blocks(title="DPO Recipe Generator") as demo:
    gr.Markdown("""
# DPO Recipe Generator
Generate personalized recipes using DPO-trained persona models.
**Available Personas:**
- **Korean Spicy**: Korean cuisine with emphasis on spicy flavors
- **Mexican Vegan**: Mexican cuisine, plant-based recipes
""")

    with gr.Row():
        # Left column: request form and generation controls.
        with gr.Column():
            persona_input = gr.Dropdown(
                label="Persona",
                choices=list(PERSONAS),
                value="korean_spicy",
            )
            ingredients_input = gr.Textbox(
                label="Ingredients",
                lines=2,
                placeholder="e.g., tofu, rice, gochujang, sesame oil",
            )
            request_input = gr.Textbox(
                label="Additional Request (optional)",
                lines=2,
                placeholder="e.g., Make something quick and spicy",
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=128,
                    maximum=1024,
                    step=64,
                    value=512,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=1.5,
                    step=0.1,
                    value=0.7,
                )
            generate_btn = gr.Button("Generate Recipe", variant="primary")
        # Right column: the parsed recipe (or error dict) rendered as JSON.
        with gr.Column():
            output = gr.JSON(label="Generated Recipe")

    # Component order must match generate_recipe's positional parameters.
    _generate_inputs = [persona_input, ingredients_input, request_input, max_tokens, temperature]
    generate_btn.click(fn=generate_recipe, inputs=_generate_inputs, outputs=output)

    # Example rows fill only the first three inputs (persona, ingredients, request).
    gr.Examples(examples=_EXAMPLE_ROWS, inputs=_generate_inputs[:3])

if __name__ == "__main__":
    demo.launch()