import spaces
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import os

from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
from diffusers.models.attention_processor import Attention
from transformers import AutoProcessor, AutoModelForMaskGeneration, pipeline
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple
from huggingface_hub import hf_hub_download
import random

device = "cuda" if torch.cuda.is_available() else "cpu"

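# Containers for the grounded-segmentation results: the zero-shot detector
# returns scored bounding boxes, and SAM adds a per-detection mask that is
# later used to cut the reference subject out of the input image.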
@dataclass
class BoundingBox:
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        return cls(score=detection_dict['score'],
                   label=detection_dict['label'],
                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
                                   ymin=detection_dict['box']['ymin'],
                                   xmax=detection_dict['box']['xmax'],
                                   ymax=detection_dict['box']['ymax']))

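# Mask post-processing: converting a binary mask to its largest contour and
# back again keeps only the dominant region and smooths ragged SAM edges.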
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []
    largest_contour = max(contours, key=cv2.contourArea)
    return largest_contour.reshape(-1, 2).tolist()


def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    mask = np.zeros(image_shape, dtype=np.uint8)
    if not polygon:
        return mask
    pts = np.array(polygon, dtype=np.int32)
    cv2.fillPoly(mask, [pts], color=(255,))
    return mask


def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    boxes = [result.box.xyxy for result in results]
    return [boxes]


def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    masks = masks.cpu().float().permute(0, 2, 3, 1).mean(axis=-1)
    masks = (masks > 0).int().numpy().astype(np.uint8)
    masks = list(masks)

    if polygon_refinement:
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            refined_mask = polygon_to_mask(polygon, shape)
            masks[idx] = refined_mask
    return masks

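# detect() runs the zero-shot object detector on the text label, and segment()
# prompts SAM with the resulting boxes to obtain pixel-accurate masks.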
def detect(
    object_detector, image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None
) -> List[DetectionResult]:
    labels = [label if label.endswith(".") else label + "." for label in labels]
    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    return [DetectionResult.from_dict(result) for result in results]


def segment(
    segmentator, processor, image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False
) -> List[DetectionResult]:
    if not detection_results:
        return []
    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]
    masks = refine_masks(masks, polygon_refinement)
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask
    return detection_results


def grounded_segmentation(
    detect_pipeline, segmentator, segment_processor, image: Image.Image, labels: List[str],
) -> Tuple[np.ndarray, List[DetectionResult]]:
    detections = detect(detect_pipeline, image, labels, threshold=0.3)
    detections = segment(segmentator, segment_processor, image, detections, polygon_refinement=True)
    return np.array(image), detections

def segment_image(image, object_name, detector, segmentator, seg_processor):
    """
    Segments a specific object from an image and returns the segmented object on a white background.

    Args:
        image (PIL.Image.Image): The input image.
        object_name (str): The name of the object to segment.
        detector: The object detection pipeline.
        segmentator: The mask generation model.
        seg_processor: The processor for the mask generation model.

    Returns:
        PIL.Image.Image: The image with the segmented object on a white background.

    Raises:
        gr.Error: If the object cannot be segmented.
    """
    image_array, detections = grounded_segmentation(detector, segmentator, seg_processor, image, [object_name])
    if not detections or detections[0].mask is None:
        raise gr.Error(f"Could not segment the subject '{object_name}' in the image. Please try a clearer image or a more specific subject name.")

    mask_expanded = np.expand_dims(detections[0].mask / 255, axis=-1)
    segment_result = image_array * mask_expanded + np.ones_like(image_array) * (1 - mask_expanded) * 255
    return Image.fromarray(segment_result.astype(np.uint8))

def make_diptych(image):
    """
    Creates a diptych image by concatenating the input image with a black image of the same size.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        PIL.Image.Image: The diptych image.
    """
    ref_image_np = np.array(image)
    diptych_np = np.concatenate([ref_image_np, np.zeros_like(ref_image_np)], axis=1)
    return Image.fromarray(diptych_np)

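# Custom attention processor for diptych prompting. `height` and `width` are
# the diptych's latent-token grid dimensions (pixel size // 16). When
# attn_enforce > 1, the image-to-image attention map is reshaped to
# (h_q, w_q, h_k, w_k) and the weights flowing from right-panel queries to
# left-panel keys are scaled up, which is intended to make the generated right
# panel attend more strongly to the reference image on the left.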
class CustomFluxAttnProcessor2_0:
    def __init__(self, height=44, width=88, attn_enforce=1.0):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CustomFluxAttnProcessor2_0 requires PyTorch 2.0 or newer; please upgrade PyTorch to use F.scaled_dot_product_attention.")
        self.height = height
        self.width = width
        self.num_pixels = height * width
        self.step = 0
        self.attn_enforce = attn_enforce

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        self.step += 1
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        inner_dim, head_dim = key.shape[-1], key.shape[-1] // attn.heads
        query, key, value = [x.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for x in [query, key, value]]

        if attn.norm_q is not None: query = attn.norm_q(query)
        if attn.norm_k is not None: key = attn.norm_k(key)

        if encoder_hidden_states is not None:
            encoder_q = attn.add_q_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_k = attn.add_k_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_v = attn.add_v_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            if attn.norm_added_q is not None: encoder_q = attn.norm_added_q(encoder_q)
            if attn.norm_added_k is not None: encoder_k = attn.norm_added_k(encoder_k)
            query, key, value = [torch.cat([e, x], dim=2) for e, x in zip([encoder_q, encoder_k, encoder_v], [query, key, value])]

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if self.attn_enforce != 1.0:
            # Compute attention explicitly so the image-token block can be rescaled before weighting the values.
            attn_probs = (torch.einsum('bhqd,bhkd->bhqk', query, key) * attn.scale).softmax(dim=-1)
            img_attn_probs = attn_probs[:, :, -self.num_pixels:, -self.num_pixels:].reshape((batch_size, attn.heads, self.height, self.width, self.height, self.width))
            img_attn_probs[:, :, :, self.width//2:, :, :self.width//2] *= self.attn_enforce
            img_attn_probs = img_attn_probs.reshape((batch_size, attn.heads, self.num_pixels, self.num_pixels))
            attn_probs[:, :, -self.num_pixels:, -self.num_pixels:] = img_attn_probs
            hidden_states = torch.einsum('bhqk,bhkd->bhqd', attn_probs, value)
        else:
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hs, hs = hidden_states[:, : encoder_hidden_states.shape[1]], hidden_states[:, encoder_hidden_states.shape[1] :]
            hs = attn.to_out[0](hs)
            hs = attn.to_out[1](hs)
            encoder_hs = attn.to_add_out(encoder_hs)
            return hs, encoder_hs
        else:
            return hidden_states

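# Models are loaded once at startup. `base_attn_procs` snapshots the stock
# attention processors so each request can install freshly sized
# CustomFluxAttnProcessor2_0 instances for its own diptych resolution.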
print("--- Loading Models: This may take a few minutes and requires >40GB VRAM ---")
controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
pipe = FluxControlNetInpaintingPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)

pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)

base_attn_procs = pipe.transformer.attn_processors.copy()

print("Loading segmentation models...")
detector_id, segmenter_id = "IDEA-Research/grounding-dino-tiny", "facebook/sam-vit-base"
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
segment_processor = AutoProcessor.from_pretrained(segmenter_id)
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
print("--- All models loaded successfully! ---")

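# Passing this function as `duration` to @spaces.GPU lets ZeroGPU size the GPU
# allocation per request: larger target resolutions get a longer time budget.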
def get_duration(
    input_image: Image.Image,
    subject_name: str,
    do_segmentation: bool,
    full_prompt: str,
    attn_enforce: float,
    ctrl_scale: float,
    width: int,
    height: int,
    pixel_offset: int,
    num_steps: int,
    guidance: float,
    real_guidance: float,
    seed: int,
    randomize_seed: bool,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Calculates the estimated duration for the Spaces GPU based on image dimensions.

    Args:
        input_image (PIL.Image.Image): The input reference image.
        subject_name (str): Name of the subject for segmentation.
        do_segmentation (bool): Whether to perform segmentation.
        full_prompt (str): The full text prompt.
        attn_enforce (float): Attention enforcement value.
        ctrl_scale (float): ControlNet conditioning scale.
        width (int): Target width of the generated image.
        height (int): Target height of the generated image.
        pixel_offset (int): Padding offset in pixels.
        num_steps (int): Number of inference steps.
        guidance (float): Distilled guidance scale.
        real_guidance (float): Real guidance scale.
        seed (int): Random seed.
        randomize_seed (bool): Whether to randomize the seed.
        progress (gr.Progress): Gradio progress tracker.

    Returns:
        int: Estimated duration in seconds.
    """
    if width > 768 or height > 768:
        return 210
    else:
        return 120

@spaces.GPU(duration=get_duration)
def run_diptych_prompting(
    input_image: Image.Image,
    subject_name: str,
    do_segmentation: bool,
    full_prompt: str,
    attn_enforce: float,
    ctrl_scale: float,
    width: int,
    height: int,
    pixel_offset: int,
    num_steps: int,
    guidance: float,
    real_guidance: float,
    seed: int,
    randomize_seed: bool,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Runs the diptych prompting image generation process.

    Args:
        input_image (PIL.Image.Image): The input reference image.
        subject_name (str): The name of the subject for segmentation (if `do_segmentation` is True).
        do_segmentation (bool): If True, the subject will be segmented from the reference image.
        full_prompt (str): The complete text prompt used for image generation.
        attn_enforce (float): Controls the attention enforcement in the custom attention processor.
        ctrl_scale (float): The conditioning scale for ControlNet.
        width (int): The desired width of the final generated image.
        height (int): The desired height of the final generated image.
        pixel_offset (int): Padding added around the image during diptych creation.
        num_steps (int): The number of inference steps for the diffusion process.
        guidance (float): The distilled guidance scale for the diffusion process.
        real_guidance (float): The real guidance scale for the diffusion process.
        seed (int): The random seed for reproducibility.
        randomize_seed (bool): If True, a random seed will be used instead of the provided `seed`.
        progress (gr.Progress): Gradio progress tracker to update UI during execution.

    Returns:
        tuple: A tuple containing:
            - PIL.Image.Image: The final generated image.
            - PIL.Image.Image: The processed reference image (left panel of the diptych).
            - PIL.Image.Image: The full diptych image generated by the pipeline.
            - str: The final prompt used.
            - int: The actual seed used for generation.

    Raises:
        gr.Error: If a reference image is not uploaded, prompts are empty, or segmentation fails.
    """
    if randomize_seed:
        actual_seed = random.randint(0, 9223372036854775807)
    else:
        actual_seed = seed

    if input_image is None: raise gr.Error("Please upload a reference image.")
    if not full_prompt: raise gr.Error("Full Prompt is empty. Please fill out the prompt fields.")

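    # Diptych geometry: each panel is the target size plus `pixel_offset`
    # padding on every side, and the generation canvas is two panels wide.
    # The padding is cropped off again after generation.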
    padded_width = width + pixel_offset * 2
    padded_height = height + pixel_offset * 2
    diptych_size = (padded_width * 2, padded_height)
    reference_image = input_image.resize((padded_width, padded_height)).convert("RGB")

    progress(0, desc="Preparing reference image...")
    if do_segmentation:
        if not subject_name:
            raise gr.Error("Subject Name is required when 'Do Segmentation' is checked.")
        progress(0.05, desc="Segmenting reference image...")
        processed_image = segment_image(reference_image, subject_name, object_detector, segmentator, segment_processor)
    else:
        processed_image = reference_image

    progress(0.2, desc="Creating diptych and mask...")
    mask_image = np.concatenate([np.zeros((padded_height, padded_width, 3)), np.ones((padded_height, padded_width, 3)) * 255], axis=1)
    mask_image = Image.fromarray(mask_image.astype(np.uint8))
    diptych_image_prompt = make_diptych(processed_image)

    progress(0.3, desc="Setting up attention processors...")
    new_attn_procs = base_attn_procs.copy()
    for k in new_attn_procs:
        new_attn_procs[k] = CustomFluxAttnProcessor2_0(height=padded_height // 16, width=padded_width * 2 // 16, attn_enforce=attn_enforce)
    pipe.transformer.set_attn_processor(new_attn_procs)

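    # Inpainting setup: the control image is the diptych (reference on the
    # left, black on the right) and the mask covers only the right half, so
    # the pipeline generates the right panel while keeping the left one fixed.
    # `true_guidance_scale` appears to apply real classifier-free guidance on
    # top of FLUX's distilled `guidance_scale` in this custom pipeline.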
    progress(0.4, desc="Running diffusion process...")
    generator = torch.Generator(device=device).manual_seed(actual_seed)
    full_diptych_result = pipe(
        prompt=full_prompt,
        height=diptych_size[1],
        width=diptych_size[0],
        control_image=diptych_image_prompt,
        control_mask=mask_image,
        num_inference_steps=num_steps,
        generator=generator,
        controlnet_conditioning_scale=ctrl_scale,
        guidance_scale=guidance,
        negative_prompt="",
        true_guidance_scale=real_guidance
    ).images[0]

    progress(0.95, desc="Finalizing image...")
    final_image = full_diptych_result.crop((padded_width, 0, padded_width * 2, padded_height))
    final_image = final_image.crop((pixel_offset, pixel_offset, padded_width - pixel_offset, padded_height - pixel_offset))

    return final_image, processed_image, full_diptych_result, full_prompt, actual_seed

css = '''
.gradio-container{max-width: 960px;margin: 0 auto}
'''

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(
        """
        # Diptych Prompting: Zero-Shot Subject-Driven & Style-Driven Image Generation
        ### Demo for the paper "[Large-Scale Text-to-Image Model with Inpainting is a Zero-Shot Subject-Driven Image Generator](https://diptychprompting.github.io/)"
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Reference Image")

            with gr.Group() as subject_driven_group:
                subject_name = gr.Textbox(label="Subject Name", placeholder="e.g., a plush bear")

            target_prompt = gr.Textbox(label="Target Prompt", placeholder="e.g., riding a skateboard on the moon")

            run_button = gr.Button("Generate Image", variant="primary")

            with gr.Accordion("Advanced Settings", open=False):
                mode = gr.Radio(["Subject-Driven", "Style-Driven (unstable)"], label="Generation Mode", value="Subject-Driven")
                with gr.Group(visible=False) as style_driven_group:
                    original_style_description = gr.Textbox(label="Original Image Description", placeholder="e.g., in watercolor painting style")
                do_segmentation = gr.Checkbox(label="Do Segmentation", value=True)
                attn_enforce = gr.Slider(minimum=1.0, maximum=2.0, value=1.3, step=0.05, label="Attention Enforcement")
                full_prompt = gr.Textbox(label="Full Prompt (Auto-generated, editable)", lines=3)
                ctrl_scale = gr.Slider(minimum=0.5, maximum=1.0, value=0.95, step=0.01, label="ControlNet Scale")
                num_steps = gr.Slider(minimum=20, maximum=50, value=28, step=1, label="Inference Steps")
                guidance = gr.Slider(minimum=1.0, maximum=10.0, value=3.5, step=0.1, label="Distilled Guidance Scale")
                real_guidance = gr.Slider(minimum=1.0, maximum=10.0, value=4.5, step=0.1, label="Real Guidance Scale")
                width = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Width")
                height = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Height")
                pixel_offset = gr.Slider(minimum=0, maximum=32, value=8, step=1, label="Padding (Pixel Offset)")
                seed = gr.Slider(minimum=0, maximum=9223372036854775807, value=42, step=1, label="Seed")
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Generated Image")
            with gr.Accordion("Other Outputs", open=False) as other_outputs_accordion:
                processed_ref_image = gr.Image(label="Processed Reference (Left Panel)")
                full_diptych_image = gr.Image(label="Full Diptych Output")
                final_prompt_used = gr.Textbox(label="Final Prompt Used")

    def toggle_mode_visibility(mode_choice):
        """
        Hides/shows the relevant input textboxes based on the selected mode.

        Args:
            mode_choice (str): The selected generation mode ("Subject-Driven" or "Style-Driven").

        Returns:
            tuple: Gradio update objects for `subject_driven_group` and `style_driven_group` visibility.
        """
        if mode_choice == "Subject-Driven":
            return gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(visible=False), gr.update(visible=True)

    def update_derived_fields(mode_choice, subject, style_desc, target):
        """
        Updates the full prompt and segmentation checkbox based on other inputs.

        Args:
            mode_choice (str): The selected generation mode ("Subject-Driven" or "Style-Driven").
            subject (str): The subject name (relevant for "Subject-Driven" mode).
            style_desc (str): The original style description (relevant for "Style-Driven" mode).
            target (str): The target prompt.

        Returns:
            tuple: Gradio update objects for `full_prompt` value and `do_segmentation` checkbox value.
        """
        if mode_choice == "Subject-Driven":
            prompt = f"A diptych with two side-by-side images of same {subject}. On the left, a photo of {subject}. On the right, replicate this {subject} exactly but as {target}"
            return gr.update(value=prompt), gr.update(value=True)
        else:
            prompt = f"A diptych with two side-by-side images of same style. On the left, {style_desc}. On the right, replicate this style exactly but as {target}"
            return gr.update(value=prompt), gr.update(value=False)

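    # Event wiring: the mode radio toggles which input group is shown, and any
    # change to the prompt-related inputs regenerates the full prompt and the
    # default state of the segmentation checkbox.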
    mode.change(
        fn=toggle_mode_visibility,
        inputs=mode,
        outputs=[subject_driven_group, style_driven_group],
        queue=False
    )

    prompt_component_inputs = [mode, subject_name, original_style_description, target_prompt]
    derived_outputs = [full_prompt, do_segmentation]

    for component in prompt_component_inputs:
        component.change(update_derived_fields, inputs=prompt_component_inputs, outputs=derived_outputs, queue=False, show_progress="hidden")

    run_button.click(
        fn=run_diptych_prompting,
        inputs=[
            input_image, subject_name, do_segmentation, full_prompt, attn_enforce,
            ctrl_scale, width, height, pixel_offset, num_steps, guidance,
            real_guidance, seed, randomize_seed
        ],
        outputs=[output_image, processed_ref_image, full_diptych_image, final_prompt_used, seed]
    )

    def run_subject_driven_example(input_image, subject_name, target_prompt):
        """
        Helper function to run an example for the subject-driven mode.

        Args:
            input_image (PIL.Image.Image): The input reference image for the example.
            subject_name (str): The subject name for the example.
            target_prompt (str): The target prompt for the example.

        Returns:
            tuple: The outputs from `run_diptych_prompting`.
        """
        full_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, a photo of {subject_name}. On the right, replicate this {subject_name} exactly but as {target_prompt}"

        return run_diptych_prompting(
            input_image=input_image,
            subject_name=subject_name,
            do_segmentation=True,
            full_prompt=full_prompt,
            attn_enforce=1.3,
            ctrl_scale=0.95,
            width=768,
            height=768,
            pixel_offset=8,
            num_steps=28,
            guidance=3.5,
            real_guidance=4.5,
            seed=42,
            randomize_seed=False,
        )

    gr.Examples(
        examples=[
            ["./assets/cat_squished.png", "a cat toy", "a cat toy riding a skate"],
            ["./assets/hf.png", "hugging face logo", "a hugging face logo on a hat"],
            ["./assets/bear_plushie.jpg", "a bear plushie", "a bear plushie drinking bubble tea"]
        ],
        inputs=[input_image, subject_name, target_prompt],
        outputs=[output_image, processed_ref_image, full_diptych_image, final_prompt_used, seed],
        fn=run_subject_driven_example,
        cache_examples="lazy"
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True, mcp_server=True)