Spaces:
Paused
Paused
| import os | |
| import json | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| from PIL import Image | |
| import torch | |
| from transformers import ( | |
| BlipProcessor, | |
| BlipForConditionalGeneration, | |
| CLIPTokenizer | |
| ) | |
# ----------------------------
# Load API keys & runtime setup
# ----------------------------
# Run dotenv loading first so OPENAI_API_KEY can be read from the environment below.
load_dotenv()

# OpenAI client shared by every GPT call in this module.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Torch device string: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ----------------------------
# Load BLIP captioning model
# ----------------------------
# The processor prepares images for BLIP and decodes generated token ids back to text.
processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)
# Move the captioning model onto the device selected during setup.
blip_model = blip_model.to(device)

# ----------------------------
# Load CLIP tokenizer (for token check)
# ----------------------------
# Used only to count the tokens of generated prompts.
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14"
)
# ----------------------------
# Generate caption from product image
# ----------------------------
def generate_blip_caption(image: Image.Image) -> str:
    """Caption *image* with the module-level BLIP model.

    The decoded caption is de-duplicated word by word: only the first
    occurrence of each whitespace-separated word is kept, in original order.

    Args:
        image: Image to describe.

    Returns:
        The cleaned caption, or the literal ``"a product image"`` when any
        step (preprocessing, generation, decoding) raises.
    """
    try:
        model_inputs = processor(images=image, return_tensors="pt")
        model_inputs = model_inputs.to(device)
        generated_ids = blip_model.generate(**model_inputs, max_length=50)
        raw_caption = processor.decode(generated_ids[0], skip_special_tokens=True)

        # Keep each word once, preserving first-seen order.
        seen_words = []
        for word in raw_caption.split():
            if word not in seen_words:
                seen_words.append(word)
        caption = " ".join(seen_words)

        print(f"πΌοΈ BLIP Caption: {caption}")
    except Exception as e:
        # Best-effort: captioning failures must not stop the pipeline.
        print("β BLIP Captioning Error:", e)
        return "a product image"
    return caption
# ----------------------------
# GPT scene planning with caption + visual style
# ----------------------------
SCENE_SYSTEM_INSTRUCTIONS = """
You are a scene planning assistant for an AI image generation system.
Your job is to take a caption from a product image, a visual style hint, and a user prompt, then return a structured JSON with:
- scene (environment, setting)
- subject (main_actor)
- objects (main_product or items)
- layout (foreground/background elements and their placement)
- rules (validation rules to ensure visual correctness)
Respond ONLY in raw JSON format. Do NOT include explanations.
"""


def _parse_scene_plan(content) -> dict:
    """Turn the model reply into a dict, tolerating fences and stray text.

    Raises:
        ValueError: if the reply holds no parseable JSON object
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    text = (content or "").strip()
    # Chat models sometimes wrap the JSON in ```json fences or add a sentence
    # around it; keep only the outermost {...} span before parsing.
    first, last = text.find("{"), text.rfind("}")
    if first != -1 and last > first:
        text = text[first:last + 1]
    plan = json.loads(text)
    if not isinstance(plan, dict):
        raise ValueError("scene plan is not a JSON object")
    return plan


def _log_scene_plan(record: dict) -> None:
    """Append *record* as one JSON line to logs/scene_plans.jsonl (best-effort)."""
    try:
        os.makedirs("logs", exist_ok=True)
        with open("logs/scene_plans.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
    except OSError as e:
        # A failed log write must not discard an otherwise valid plan.
        print("Scene plan log write failed:", e)


def extract_scene_plan(prompt: str, image: Image.Image) -> dict:
    """Ask GPT for a structured scene plan for a product image and user prompt.

    The image is captioned with ``generate_blip_caption``; the caption, a
    visual style hint and the user prompt are sent to the chat model with
    ``SCENE_SYSTEM_INSTRUCTIONS``. The raw reply is printed and logged, then
    parsed as JSON.

    Args:
        prompt: The user's free-text request.
        image: Product image used to derive the caption.

    Returns:
        The parsed scene plan. On any error (captioning, API call, or a reply
        that is not a JSON object) a generic studio/product plan is returned
        instead of raising.
    """
    try:
        caption = generate_blip_caption(image)
        # Use the caption as the style hint only when it mentions the product;
        # otherwise fall back to a generic product-photo description.
        if "shoe" in caption or "product" in caption:
            visual_hint = caption
        else:
            visual_hint = "low-top product photo on white background"
        merged_prompt = (
            f"Image Caption: {caption}\n"
            f"Image Visual Style: {visual_hint}\n"
            f"User Prompt: {prompt}"
        )
        response = client.chat.completions.create(
            model="gpt-4o-mini-2024-07-18",
            messages=[
                {"role": "system", "content": SCENE_SYSTEM_INSTRUCTIONS},
                {"role": "user", "content": merged_prompt},
            ],
            temperature=0.3,
            max_tokens=500,
        )
        content = response.choices[0].message.content
        print("π§  Scene Plan (Raw):", content)
        # Log the raw reply before parsing so unparseable replies are recorded too.
        _log_scene_plan({
            "caption": caption,
            "visual_hint": visual_hint,
            "prompt": prompt,
            "scene_plan": content,
        })
        return _parse_scene_plan(content)
    except Exception as e:
        print("β extract_scene_plan() Error:", e)
        return {
            "scene": {"environment": "studio", "setting": "plain white background"},
            "subject": {"main_actor": "a product"},
            "objects": {"main_product": "product"},
            "layout": {},
            "rules": {},
        }
# ----------------------------
# Enriched prompt generation (GPT, 77-token safe)
# ----------------------------
ENRICHED_PROMPT_INSTRUCTIONS = """
You are a prompt engineer for an AI image generation model.
Given a structured scene plan and a user prompt, generate a single natural-language enriched prompt that:
1. Describes the subject, product, setting, and layout clearly
2. Uses natural, photo-realistic language
3. Stays strictly under 77 tokens (CLIP token limit)
Return ONLY the enriched prompt string. No explanations.
"""


def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
    """Produce *n* enriched prompts for a scene plan via independent GPT calls.

    Args:
        scene_plan: Structured plan, serialised to JSON for the model.
        base_prompt: The user's original prompt; also the per-item fallback.
        n: Number of variations to request.

    Returns:
        A list of ``n`` strings. Each entry is the stripped model reply, or
        ``base_prompt`` when that particular attempt raised. The CLIP token
        count of each reply is printed but not enforced.
    """

    def _one_variation() -> str:
        # One request/response round trip; any failure yields the base prompt.
        try:
            request_text = "Scene Plan:\n{}\n\nUser Prompt:\n{}".format(
                json.dumps(scene_plan), base_prompt
            )
            chat_messages = [
                {"role": "system", "content": ENRICHED_PROMPT_INSTRUCTIONS},
                {"role": "user", "content": request_text},
            ]
            completion = client.chat.completions.create(
                model="gpt-4o-mini-2024-07-18",
                messages=chat_messages,
                temperature=0.4,
                max_tokens=100,
            )
            reply = completion.choices[0].message.content
            enriched = reply.strip()
            token_count = len(tokenizer(enriched)["input_ids"])
            print(f"π Enriched Prompt ({token_count} tokens): {enriched}")
            return enriched
        except Exception as e:
            print("β οΈ Prompt fallback:", e)
            return base_prompt

    return [_one_variation() for _ in range(n)]
# ----------------------------
# Negative prompt generator
# ----------------------------
NEGATIVE_SYSTEM_PROMPT = """
You are a prompt engineer. Given a structured scene plan, generate a short negative prompt
to suppress unwanted visual elements such as: distortion, blurriness, poor anatomy,
logo errors, background noise, or low realism.
Return a single comma-separated list. No intro text.
"""


def generate_negative_prompt_from_scene(scene_plan: dict) -> str:
    """Ask GPT for a negative prompt matching *scene_plan*.

    Args:
        scene_plan: Structured plan, sent to the model as a JSON string.

    Returns:
        The stripped model reply, or the fixed default
        ``"blurry, distorted, low quality, deformed, watermark"`` when the
        request or reply handling raises.
    """
    default_negative = "blurry, distorted, low quality, deformed, watermark"
    try:
        conversation = [
            {"role": "system", "content": NEGATIVE_SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(scene_plan)},
        ]
        completion = client.chat.completions.create(
            model="gpt-4o-mini-2024-07-18",
            messages=conversation,
            temperature=0.2,
            max_tokens=100,
        )
        reply = completion.choices[0].message.content
        return reply.strip()
    except Exception as e:
        print("β Negative Prompt Error:", e)
    return default_negative