# PartEdit / model.py
import gc
import PIL.Image
import torch
from huggingface_hub import hf_hub_download

from stable_diffusion_xl_partedit import PartEditPipeline, DotDictExtra

available_pts = [
    "pt/torso_custom.pt",  # human torso only
    "pt/chair_custom.pt",  # seat of the chair only
    "pt/carhood_custom.pt",
    "pt/partimage_biped_head.pt",  # essentially monkeys
    "pt/partimage_carbody.pt",  # everything except the wheels
    "pt/partimage_human_hair.pt",
    "pt/partimage_human_head.pt",  # essentially faces
    "pt/partimage_human_torso.pt",  # prefer the custom torso embedding above over this one
    "pt/partimage_quadruped_head.pt",  # essentially four-legged animals
]


def download_part(index):
    """Fetch a pre-trained part embedding (.pt) from the Hugging Face Hub and return its local path."""
    return hf_hub_download(
        repo_id="Aleksandar/PartEdit-extra",
        repo_type="dataset",
        filename=available_pts[index],
    )


# Indices refer to positions in `available_pts` above.
PART_TOKENS = {
    "human_head": download_part(6),
    "human_hair": download_part(5),
    "human_torso_custom": download_part(0),  # custom embedding
    "chair_custom": download_part(1),
    "carhood_custom": download_part(2),
    "carbody": download_part(4),
    "biped_head": download_part(3),
    "quadruped_head": download_part(8),
    "human_torso": download_part(7),  # based on PartImage
}
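
# Hedged sanity check (illustrative, not part of the original demo): each entry in
# PART_TOKENS is a local file path returned by hf_hub_download, so a downloaded
# embedding can be inspected offline; the layout of the .pt payload is not assumed here.
#
#   emb = torch.load(PART_TOKENS["human_head"], map_location="cpu")
#   print(type(emb))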


class PartEditSDXLModel:
    MAX_NUM_INFERENCE_STEPS = 50

    def __init__(self):
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
            self.sd_pipe, self.partedit_pipe = PartEditPipeline.default_pipeline(self.device)
        else:
            # No GPU available: keep attribute names consistent so callers fail cleanly.
            self.device = torch.device("cpu")
            self.sd_pipe = None
            self.partedit_pipe = None

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: int = 0,
        eta: float = 0.0,
    ) -> PIL.Image.Image:
        if not torch.cuda.is_available():
            raise RuntimeError("This demo does not work on CPU!")
        # Enforce the class-level cap on denoising steps.
        num_inference_steps = min(num_inference_steps, self.MAX_NUM_INFERENCE_STEPS)
        out = self.sd_pipe(
            prompt=prompt,
            # negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=eta,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
        gc.collect()
        torch.cuda.empty_cache()
        return out

    def edit(
        self,
        prompt: str,
        subject: str,
        part: str,
        edit: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: int = 0,
        eta: float = 0.0,
        t_e: int = 50,
    ):
        """Run a PartEdit edit; returns ([edited, source] images, part-mask visualization)."""
        # Sanity checks
        if not torch.cuda.is_available():
            raise RuntimeError("This demo does not work on CPU!")
        if part not in PART_TOKENS:
            raise ValueError(f"Part `{part}` is not supported!")
        token_path = PART_TOKENS[part]
        if subject not in prompt:
            raise ValueError(f"The subject `{subject}` does not exist in the original prompt!")
        num_inference_steps = min(num_inference_steps, self.MAX_NUM_INFERENCE_STEPS)
        # Source prompt plus the edited prompt with the subject swapped for the edit target.
        prompts = [
            prompt,
            prompt.replace(subject, edit),
        ]
        # PartEdit parameters
        cross_attention_kwargs = {
            "edit_type": "replace",
            "n_self_replace": 0.0,
            "n_cross_replace": {"default_": 1.0, edit: 0.4},
        }
        extra_params = DotDictExtra()
        extra_params.update({"omega": 1.5, "edit_steps": t_e})
        out = self.partedit_pipe(
            prompt=prompts,
            # negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=eta,
            generator=torch.Generator().manual_seed(seed),
            cross_attention_kwargs=cross_attention_kwargs,
            extra_kwargs=extra_params,
            embedding_opt=token_path,
        ).images[:2][::-1]
        mask = self.partedit_pipe.visualize_map_across_time()
        gc.collect()
        torch.cuda.empty_cache()
        return out, mask
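

# --- Usage sketch ---
# A minimal, hedged example of driving this wrapper; the prompt, subject, part,
# and edit strings below are illustrative assumptions, not values from the demo
# itself. Requires a CUDA device.
if __name__ == "__main__":
    model = PartEditSDXLModel()

    # Generate a baseline image from a plain prompt.
    image = model.generate(prompt="a photo of a man standing in a park", seed=42)
    image.save("baseline.png")

    # Edit only the torso region: `subject` must occur in the prompt, and `part`
    # must be a key of PART_TOKENS; `edit` replaces `subject` in the edit prompt.
    images, mask = model.edit(
        prompt="a photo of a man standing in a park",
        subject="man",
        part="human_torso_custom",
        edit="man in a red leather jacket",
        seed=42,
    )
    images[0].save("edited.png")  # edited image comes first (see `edit` docstring)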