Spaces:
Paused
Paused
import json
import os
import traceback
import zipfile
from datetime import datetime

import gradio as gr
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from PIL import Image

from utils.planner import (
    extract_scene_plan,
    generate_prompt_variations_from_scene,
    generate_negative_prompt_from_scene,
)
# ----------------------------
# 💻 Device Configuration
# ----------------------------
# Run in half precision when a CUDA GPU is present, otherwise fall back to the
# CPU in full float32.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32
# ----------------------------
# 🧠 Load Stable Diffusion XL Img2Img Pipeline
# ----------------------------
# SDXL base 1.0 weights; the "fp16" weight variant is requested only on CUDA,
# where the pipeline runs in float16 (see `dtype` above).
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    use_safetensors=True,
    variant="fp16" if device == "cuda" else None,
)
if device == "cuda":
    # enable_model_cpu_offload() handles device placement itself and is meant
    # to replace `.to("cuda")`. Moving the whole pipeline to the GPU first
    # would load every weight there before offload moves them back, which
    # defeats the memory saving and can OOM a small GPU.
    pipe.enable_model_cpu_offload()
    # NOTE(review): the original indentation was lost; attention slicing is
    # assumed to belong to this CUDA-only branch. Confirm the intended scope.
    pipe.enable_attention_slicing()
else:
    pipe.to(device)
# ----------------------------
# 🎨 Core Generation Function
# ----------------------------
def process_image(prompt, image, num_variations):
    """Generate lifestyle variations of an uploaded product image.

    The planner helpers turn ``prompt`` and ``image`` into a scene plan, a set
    of enriched prompts and a negative prompt. Each enriched prompt is run
    through the module-level SDXL img2img ``pipe`` against a 1024x1024 RGB copy
    of the upload. Results are saved as
    ``outputs/session_<timestamp>/generated_<n>.png``, and one JSON line
    describing the run is appended to ``logs/generation_logs.jsonl``.

    Args:
        prompt: The user's description of the desired scene.
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.
        num_variations: Requested number of variations (slider value; coerced
            to ``int``).

    Returns:
        ``(images, status_message, download_path)`` for the Gradio outputs.
        ``download_path`` is the PNG itself when exactly one image was
        generated, and a ZIP of all PNGs otherwise. On any failure the tuple
        is ``([red 512x512 placeholder], "❌ Error: ...", None)``; errors are
        reported to the UI, never raised.
    """
    try:
        if image is None:
            raise ValueError("🚫 Please upload an image.")
        num_variations = int(num_variations)
        print("🧠 Prompt received:", prompt)

        scene_plan = extract_scene_plan(prompt, image)
        # Presumably a list of prompt strings; it is iterated and JSON-logged.
        enriched_prompts = generate_prompt_variations_from_scene(scene_plan, prompt, num_variations)
        if not enriched_prompts:
            # Without this guard an empty plan yields an empty ZIP, or a
            # download path pointing at a PNG that was never written.
            raise ValueError("🚫 The prompt planner returned no prompts to generate.")
        negative_prompt = generate_negative_prompt_from_scene(scene_plan)

        # Convert before resizing: Pillow always resizes "1"/"P" mode images
        # with nearest-neighbour, so converting first gives a smooth resample.
        init_image = image.convert("RGB").resize((1024, 1024))

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_dir = f"outputs/session_{timestamp}"
        os.makedirs(out_dir, exist_ok=True)

        outputs = []
        saved_paths = []
        for n, enriched_prompt in enumerate(enriched_prompts, start=1):
            print(f"✨ Generating Image {n}...")
            result = pipe(
                prompt=enriched_prompt,
                negative_prompt=negative_prompt,
                image=init_image,
                strength=0.7,
                guidance_scale=7.5,
                num_inference_steps=30,
            )
            output_img = result.images[0]
            img_path = f"{out_dir}/generated_{n}.png"
            output_img.save(img_path)
            outputs.append(output_img)
            saved_paths.append(img_path)

        _append_generation_log({
            "timestamp": timestamp,
            "prompt": prompt,
            "scene_plan": scene_plan,
            "enriched_prompts": enriched_prompts,
            "negative_prompt": negative_prompt,
            "device": device,
            "num_variations": num_variations,
        })

        # Pick the download format from what was actually generated, not from
        # the requested count (the planner decides how many prompts we get).
        if len(saved_paths) == 1:
            return outputs, "✅ Generated one image. Ready for download.", saved_paths[0]

        zip_path = f"{out_dir}/all_images.zip"
        with zipfile.ZipFile(zip_path, "w") as zipf:
            for img_path in saved_paths:
                zipf.write(img_path, os.path.basename(img_path))
        return outputs, f"✅ Generated {len(outputs)} images. Download below.", zip_path
    except Exception as e:
        # UI boundary: report every failure to the user instead of raising,
        # but keep the full traceback in the server log for debugging.
        print("❌ Generation failed:", e)
        traceback.print_exc()
        return [Image.new("RGB", (512, 512), color="red")], f"❌ Error: {e}", None


def _append_generation_log(log_data, log_path="logs/generation_logs.jsonl"):
    """Append ``log_data`` to ``log_path`` as a single JSON line (best effort).

    A logging problem must not turn an otherwise successful generation into a
    UI error. Serialization failures (``TypeError``/``ValueError`` from
    ``json.dumps``) and file-system failures (``OSError``) are printed and
    ignored.
    """
    try:
        # Serialize first so a bad payload never leaves a partial line behind.
        line = json.dumps(log_data)
        log_dir = os.path.dirname(log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(log_path, "a", encoding="utf-8") as log_file:
            log_file.write(line + "\n")
    except (OSError, TypeError, ValueError) as err:
        print("⚠️ Could not write generation log:", err)
# ----------------------------
# 🧪 Gradio Interface
# ----------------------------
with gr.Blocks(title="NewCrux Image-to-Image Generator") as demo:
    gr.Markdown(
        "### 🖼️ NewCrux: Product Lifestyle Visual Generator (SDXL + Prompt AI)\n"
        "Upload a product image and describe the visual you want. "
        "The system will generate realistic marketing images using AI."
    )
    # NOTE(review): the original indentation was lost; the prompt box and the
    # image upload are assumed to share this row. Confirm the intended layout.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="e.g., A person running on the beach wearing the product")
        input_image = gr.Image(type="pil", label="Product Image")
    num_outputs = gr.Slider(1, 5, value=3, step=1, label="Number of Variations")
    generate_btn = gr.Button("🚀 Generate Image(s)")
    output_gallery = gr.Gallery(label="Generated Images", show_label=True, columns=[2], height="auto")
    output_msg = gr.Textbox(label="Generation Status", interactive=False)
    # process_image returns the bare PNG for a single variation and a ZIP for
    # several, so the label must not promise a ZIP.
    download_zip = gr.File(label="⬇️ Download Image(s)", interactive=False)
    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image, num_outputs],
        outputs=[output_gallery, output_msg, download_zip],
    )
# ----------------------------
# 🚀 Launch App
# ----------------------------
def _main():
    """Start the Gradio server for the module-level ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    _main()