Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import spaces | |
| from torchao.quantization import autoquant | |
| from diffusers import FluxPipeline | |
# Load the merged FLUX.1 checkpoint with bfloat16 weights, then move the
# whole pipeline onto the CUDA device.
pipe = FluxPipeline.from_pretrained("sayakpaul/FLUX.1-merged", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# Ask the transformer to use the channels-last memory format.
pipe.transformer.to(memory_format=torch.channels_last)

# Compile the transformer first, then pass the compiled module to torchao's
# autoquant; the result replaces the pipeline's transformer.
pipe.transformer = autoquant(
    torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True),
    error_on_unseen=False,
)
# `spaces.GPU` marks this function as GPU-dependent so that ZeroGPU Spaces
# attach a GPU for the duration of the call; Hugging Face documents the
# decorator as effect-free on other hardware.
@spaces.GPU
def generate_images(prompt: str, guidance_scale: float, num_inference_steps: float):
    """Generate a single image for ``prompt`` with the optimized pipeline.

    Args:
        prompt: Text prompt forwarded to the module-level ``pipe``.
        guidance_scale: Guidance scale forwarded to ``pipe`` unchanged.
        num_inference_steps: Number of inference steps; converted with
            ``int()`` before being passed to ``pipe``.

    Returns:
        The first entry of the ``images`` attribute on the pipeline output.
    """
    result = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        # Presumably guards against the UI slider delivering a float.
        num_inference_steps=int(num_inference_steps),
    )
    return result.images[0]
# Set up the Gradio interface: one prompt box and two sliders feed
# `generate_images`, and the single returned image is shown in the output.
# The title and description describe only the optimized pipeline, because
# that is the only image this app produces (no baseline comparison is run).
demo = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
        gr.Slider(1.0, 10.0, step=0.5, value=3.5, label="Guidance Scale"),
        gr.Slider(10, 100, step=1, value=50, label="Number of Inference Steps"),
    ],
    outputs=[
        gr.Image(type="pil", label="Optimized FluxPipeline"),
    ],
    title="Optimized FluxPipeline",
    description="Generate images with a FluxPipeline optimized using torchao and torch.compile().",
)

# Start the web server for the interface.
demo.launch()