import spaces
import os
import gradio as gr
import torch
import safetensors
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_cnet import FluxControlNetInpaintingPipeline
from PIL import Image, ImageDraw
import numpy as np
import subprocess
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
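# Clear any stale ZeroGPU offload cache left over from previous runs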
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
HF_TOKEN = os.getenv("HF_TOKEN")
# Ensure that the minimal version of diffusers is installed
check_min_version("0.30.2")
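# 4-bit double-quantized config for the T5 text encoder. The quantized encoder is
# loaded below but is not currently wired into the pipeline (its assignment is
# commented out further down), so the pipeline keeps the default bf16 encoder.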
quant_config = TransformersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
)
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="text_encoder_2",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
token=HF_TOKEN
)
# quant_config = DiffusersBitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True,
# )
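# Load the FLUX transformer in full bf16 precision (the 4-bit variant above is left disabled)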
transformerx = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
torch_dtype=torch.bfloat16,
token=HF_TOKEN
)
# text_encoder_8bit = T5EncoderModel.from_pretrained(
# "black-forest-labs/FLUX.1-dev",
# subfolder="text_encoder_2",
# quantization_config=quant_config,
# torch_dtype=torch.bfloat16,
# use_safetensors=True,
# token=HF_TOKEN
# )
# Build pipeline
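# Alimama inpainting ControlNet conditions generation on the diptych image and its mask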
controlnet = FluxControlNetModel.from_pretrained(
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
# subfolder="controlnet",
torch_dtype=torch.bfloat16,
token=HF_TOKEN
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
controlnet=controlnet,
# text_encoder_2=text_encoder_8bit,
transformer=transformerx,
torch_dtype=torch.bfloat16,
# device_map="balanced",
token=HF_TOKEN
)
# pipe.text_encoder_2 = text_encoder_2_4bit
# pipe.transformer = transformer_4bit
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)
pipe.to("cuda")
pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")
pipe.set_adapters(["turbo"], adapter_weights=[0.95])
pipe.fuse_lora(lora_scale=1)
pipe.unload_lora_weights()
# We can utilize the enable_group_offload method for Diffusers model implementations
# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
# For any other model implementations, the apply_group_offloading function can be used
# pipe.push_to_hub("FLUX.1-Inpainting-8step_uncensored", private=True, token=HF_TOKEN)
# pipe.enable_vae_tiling()
# pipe.enable_model_cpu_offload()
# hf_device_map is only populated when the pipeline is loaded with a device_map
print(getattr(pipe, "hf_device_map", None))
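# Note: the ImageEditor-based mask path below is kept for reference; the app currently
# builds its mask automatically over the right half of the diptych (see inpaint_image).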
def create_mask_from_editor(editor_value):
"""
Create a mask from the ImageEditor value.
Args:
editor_value: Dictionary from EditorValue with 'background', 'layers', and 'composite'
Returns:
PIL Image with white mask
"""
# The 'composite' key contains the final image with all layers applied
    composite_image = editor_value['composite'].convert("RGB")  # drop alpha so the RGB comparison below works
# Convert to numpy array
composite_array = np.array(composite_image)
# Create mask where the composite image is white
mask_array = np.all(composite_array == (255, 255, 255), axis=-1).astype(np.uint8) * 255
mask_image = Image.fromarray(mask_array)
return mask_image
def create_mask_on_image(image, xyxy):
"""
Create a white mask on the image given xyxy coordinates.
Args:
image: PIL Image
xyxy: List of [x1, y1, x2, y2] coordinates
Returns:
        Tuple of (mask as a PIL Image, copy of the image with the masked region filled white)
"""
# Convert to numpy array
img_array = np.array(image)
# Create mask
mask = Image.new('RGB', image.size, (0, 0, 0))
draw = ImageDraw.Draw(mask)
# Draw white rectangle
draw.rectangle(xyxy, fill=(255, 255, 255))
# Convert mask to array
mask_array = np.array(mask)
# Apply mask to image
masked_array = np.where(mask_array == 255, 255, img_array)
return Image.fromarray(mask_array), Image.fromarray(masked_array)
def create_diptych_image(image):
# Create a diptych image with original on left and black on right
width, height = image.size
diptych = Image.new('RGB', (width * 2, height), 'black')
diptych.paste(image, (0, 0))
return diptych
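# Main inference entry point: builds the diptych and its inpainting mask, constructs the
# attention-scale mask used for diptych prompting, and runs the ControlNet inpainting pipeline.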
@spaces.GPU(duration=120)
def inpaint_image(image, prompt, subject, editor_value):
# Load image and mask
size = (1536, 768)
image = load_image(image).convert("RGB").resize((768, 768))
diptych_image = create_diptych_image(image)
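    # Mask the right half of the diptych: that region will be generated, guided by the left half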
# mask = load_image(mask_path).convert("RGB").resize(size)
# mask, mask_image = create_mask_on_image(image, [250, 275, 500, 400])
mask, mask_image = create_mask_on_image(diptych_image, [768, 0, 1536, 768])
generator = torch.Generator(device="cuda").manual_seed(24)
# Load and preprocess image
# Calculate attention scale mask
attn_scale_factor = 1.5
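    # Diptych prompting: boost cross-attention between right-half (generated) query tokens
    # and left-half (reference) key tokens by attn_scale_factor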
    # Build a pixel-space mask over the diptych: 0 on the left (reference) half, 1 on the right (generated) half
    H, W = size[1]//16, size[0]//16  # latent token grid; FLUX packs latents into 16x16-pixel patches
attn_scale_mask = torch.zeros(size[1], size[0])
attn_scale_mask[:, 768:] = 1.0 # height, width
attn_scale_mask = torch.nn.functional.interpolate(attn_scale_mask[None, None, :, :], (H, W), mode='nearest-exact').flatten()
attn_scale_mask = attn_scale_mask[None, None, :, None].repeat(1, 24, 1, H*W)
# Get inverted attention mask by subtracting from 1.0
transposed_inverted_attn_scale_mask = (1.0 - attn_scale_mask).transpose(-1, -2)
cross_attn_region = torch.logical_and(attn_scale_mask, transposed_inverted_attn_scale_mask)
cross_attn_region = cross_attn_region * attn_scale_factor
cross_attn_region[cross_attn_region < 1.0] = 1.0
full_attn_scale_mask = torch.ones(1, 24, 512+H*W, 512+H*W)
full_attn_scale_mask[:, :, 512:, 512:] = cross_attn_region
# Convert to bfloat16 to match model dtype
full_attn_scale_mask = full_attn_scale_mask.to(device=pipe.transformer.device, dtype=torch.bfloat16)
    subject_name = subject
    target_text_prompt = prompt
    prompt_final = f'A two side-by-side image of {subject_name}. LEFT: a photo of {subject_name}; RIGHT: a photo of {subject_name} {target_text_prompt}.'
# Convert attention mask to PIL image format
    # Take the first head's full mask (prompt + image tokens); clone it so the in-place
    # thresholding below does not overwrite the scale values passed to the pipeline
    attn_vis = full_attn_scale_mask[0, 0].clone()
    attn_vis[attn_vis <= 1.0] = 0
    attn_vis[attn_vis > 1.0] = 255
attn_vis = attn_vis.cpu().float().numpy().astype(np.uint8)
    # Convert to PIL Image
attn_vis_img = Image.fromarray(attn_vis)
attn_vis_img.save('attention_mask_vis.png')
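    # Run the ControlNet inpainting pipeline over the full diptych; only the masked right
    # half is regenerated, with the custom attention-scale mask steering it toward the reference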
with torch.inference_mode():
result = pipe(
prompt=prompt_final,
height=size[1],
width=size[0],
control_image=diptych_image,
control_mask=mask,
num_inference_steps=12,
generator=generator,
controlnet_conditioning_scale=0.7,
guidance_scale=1,
negative_prompt="",
true_guidance_scale=1.0,
attn_scale_mask=full_attn_scale_mask,
).images[0]
return result, attn_vis_img
# Create Gradio interface with structured layout
with gr.Blocks() as iface:
gr.Markdown("## FLUX Inpainting with Diptych Prompting")
gr.Markdown("Upload an image, specify a prompt, and draw a mask on the image. The app will automatically generate the inpainted image.")
with gr.Row():
with gr.Column():
with gr.Row():
                with gr.Accordion("Input Image"):
input_image = gr.Image(type="filepath", label="Upload Image")
with gr.Row():
prompt_preview = gr.Textbox(value="A two side-by-side image of 'subject_name'. LEFT: a photo of 'subject_name'; RIGHT: a photo of 'subject_name' 'target_text_prompt'", interactive=False)
subject = gr.Textbox(lines=1, placeholder="Enter your subject", label="Subject")
prompt = gr.Textbox(lines=2, placeholder="Enter your prompt here (e.g., 'wearing a christmas hat, in a busy street')", label="Prompt")
with gr.Column():
            editor_value = gr.ImageEditor(type="pil", label="Image with Mask", sources=["upload"], visible=False)
inpainted_image = gr.Image(type="pil", label="Inpainted Image")
attn_vis_img = gr.Image(type="pil", label="Attn Vis Image")
with gr.Row():
inpaint_button = gr.Button("Inpaint")
inpaint_button.click(fn=inpaint_image, inputs=[input_image, prompt, subject, editor_value], outputs=[inpainted_image, attn_vis_img])
# Launch the app
iface.launch()