# Scraped from a Hugging Face Space page: author "hexware", commit 37173a8 (verified), message "Update app.py".
import os
import io
import json
import time
import uuid
import random
import tempfile
import zipfile
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import gradio as gr
import spaces
from PIL import Image
from pptx import Presentation
from diffusers import QwenImageLayeredPipeline
from huggingface_hub import HfApi, login
from huggingface_hub.utils import HfHubHTTPError
# Directory that is created at import time via ensure_dirname(LOG_DIR).
LOG_DIR = "/tmp/local"
# Largest signed 32-bit integer; used as the upper bound when a random seed is drawn.
MAX_SEED = np.iinfo(np.int32).max
# -------------------------
# HF auth (Spaces secrets)
# -------------------------
def _get_hf_token() -> Optional[str]:
# priority: HF_TOKEN -> hf -> HUGGINGFACEHUB_API_TOKEN
return (
os.environ.get("HF_TOKEN")
or os.environ.get("hf")
or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
def _get_dataset_repo() -> Optional[str]:
# priority: DATASET_REPO -> HF_DATASET_REPO
return os.environ.get("DATASET_REPO") or os.environ.get("HF_DATASET_REPO")
# Snapshot of the credentials taken once at import time. The ds_* helpers call
# _get_hf_token() / _get_dataset_repo() again on every use instead of reading these.
HF_TOKEN = _get_hf_token()
DATASET_REPO = _get_dataset_repo()
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception as e:
        # Best effort: a failed login is only reported, it does not abort start-up.
        print("HF login failed:", repr(e))
# -------------------------
# Helpers
# -------------------------
def ensure_dirname(path: str):
    """Create directory ``path`` (including parents) unless it already exists.

    An empty ``path`` is ignored. Nothing is done when ``path`` already exists,
    whatever kind of filesystem entry it is.
    """
    if not path:
        return
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)
def px_to_emu(px, dpi=96):
    """Convert a pixel length to EMU, truncated to an int.

    The value is first converted to inches (``px / dpi``) and then scaled by
    914400 EMU per inch.
    """
    emu_per_inch = 914400
    inches = px / dpi
    return int(inches * emu_per_inch)
def imagelist_to_pptx_from_pils(images: List[Image.Image]) -> str:
    """Stack all layers on one blank slide and save it as a temporary .pptx.

    The slide is sized after the first image; every image is placed at the
    top-left corner and stretched to that same size, in list order.

    Args:
        images: Layers to place on the slide.

    Returns:
        Path of the generated ``.pptx`` file. The file is created with
        ``delete=False``, so removing it is up to the caller.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("No images to export")
    w, h = images[0].size
    # Hoisted: the EMU size is identical for the slide and for every picture.
    slide_w, slide_h = px_to_emu(w), px_to_emu(h)
    prs = Presentation()
    prs.slide_width = slide_w
    prs.slide_height = slide_h
    slide = prs.slides.add_slide(prs.slide_layouts[6])
    left = top = 0
    for img in images:
        # Encode in memory instead of writing one never-deleted temp PNG per
        # layer (the previous version leaked both the file and its handle).
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        buf.seek(0)
        slide.shapes.add_picture(buf, left, top, width=slide_w, height=slide_h)
    # Only reserve a unique name here; close the handle before python-pptx
    # writes to that path so no descriptor stays open.
    with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as out:
        out_path = out.name
    prs.save(out_path)
    return out_path
def imagelist_to_zip_from_pils(images: List[Image.Image], prefix: str = "layer") -> str:
    """Write every image as a PNG entry of a temporary zip archive.

    Entries are named ``{prefix}_1.png``, ``{prefix}_2.png``, ... in list order.

    Args:
        images: Images to archive.
        prefix: Filename prefix used for each entry.

    Returns:
        Path of the generated ``.zip`` file (created with ``delete=False``;
        the caller owns its cleanup).
    """
    # Reserve a unique path and close the handle right away; previously the
    # handle was left open for the lifetime of the process.
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as handle:
        zip_path = handle.name
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for number, img in enumerate(images, start=1):
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            zipf.writestr(f"{prefix}_{number}.png", buf.getvalue())
    return zip_path
def _clamp_int(x, default: int, lo: int, hi: int) -> int:
try:
v = int(x)
except Exception:
v = default
return max(lo, min(hi, v))
def _normalize_resolution(resolution: Any) -> int:
    """Map any input to one of the two accepted resolutions, 640 or 1024.

    The value is clamped into [640, 1024] (640 when it cannot be parsed);
    anything that is not exactly 640 or 1024 afterwards becomes 640.
    """
    clamped = _clamp_int(resolution, default=640, lo=640, hi=1024)
    return clamped if clamped in (640, 1024) else 640
def _normalize_input_image(input_image: Any) -> Image.Image:
    """Turn the UI image input into an RGBA ``PIL.Image``.

    Accepts a file path, a PIL image, a numpy array, or a list whose first
    element is one of those. The image goes through RGB before RGBA, so any
    alpha channel of the input is discarded and the result is fully opaque.

    Args:
        input_image: Value received from the image input component.

    Returns:
        A new RGBA image.

    Raises:
        ValueError: If the list is empty or the type is not supported.
    """
    if isinstance(input_image, list):
        if not input_image:
            # Previously surfaced as a bare IndexError.
            raise ValueError("Empty input_image list")
        input_image = input_image[0]
    if isinstance(input_image, str):
        # convert() returns a fully loaded copy, so the file can be closed
        # as soon as the conversion is done.
        with Image.open(input_image) as opened:
            rgb = opened.convert("RGB")
    elif isinstance(input_image, Image.Image):
        rgb = input_image.convert("RGB")
    elif isinstance(input_image, np.ndarray):
        rgb = Image.fromarray(input_image).convert("RGB")
    else:
        raise ValueError(f"Unsupported input_image type: {type(input_image)}")
    return rgb.convert("RGBA")
# -------------------------
# Dataset persistence helpers
# -------------------------
def ds_enabled() -> bool:
    """Tell whether both a token and a dataset repo are configured."""
    if not _get_hf_token():
        return False
    return bool(_get_dataset_repo())
def ds_api() -> HfApi:
    """Build an ``HfApi`` client authenticated with the configured token.

    Raises:
        RuntimeError: If no token is available in the environment.
    """
    hf_token = _get_hf_token()
    if hf_token:
        return HfApi(token=hf_token)
    raise RuntimeError("HF token missing")
def ds_repo_id() -> str:
    """Return the configured dataset repo id.

    Raises:
        RuntimeError: If neither repo environment variable is set.
    """
    configured = _get_dataset_repo()
    if configured:
        return configured
    raise RuntimeError("DATASET_REPO/HF_DATASET_REPO missing")
def ds_ensure_repo() -> Tuple[bool, str]:
    """Create the private dataset repo if needed.

    Returns:
        ``(ok, message)``. Failures are reported through the message and never
        raised from the ``create_repo`` call.
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled: missing HF token and/or dataset repo env."
    client = ds_api()
    dataset_id = ds_repo_id()
    try:
        client.create_repo(repo_id=dataset_id, repo_type="dataset", exist_ok=True, private=True)
    except HfHubHTTPError as http_err:
        return False, f"Failed to create/ensure dataset repo: {http_err}"
    except Exception as err:
        return False, f"Failed to create/ensure dataset repo: {repr(err)}"
    return True, f"Dataset repo ready: {dataset_id}"
def ds_upload_bytes(path_in_repo: str, data: bytes, commit_message: str) -> Tuple[bool, str]:
    """Upload an in-memory payload as one file of the dataset repo.

    Args:
        path_in_repo: Destination path inside the dataset repo.
        data: File content.
        commit_message: Commit message for the upload.

    Returns:
        ``(ok, message)``. Upload errors are reported through the message
        instead of being raised.
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled: missing HF token and/or dataset repo env."
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        # upload_file accepts raw bytes, so no temp file is needed. The old
        # delete=False temp file was never removed and piled up in /tmp.
        api.upload_file(
            path_or_fileobj=bytes(data),
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
    except HfHubHTTPError as e:
        # A 403 "must use a write token" response ends up here.
        return False, f"Upload failed (HTTP): {e}"
    except Exception as e:
        return False, f"Upload failed: {repr(e)}"
    return True, f"Uploaded: {path_in_repo}"
def ds_download_bytes(path_in_repo: str) -> Tuple[Optional[bytes], str]:
    """Download one file of the dataset repo and return its content.

    Args:
        path_in_repo: Path of the file inside the dataset repo.

    Returns:
        ``(content, "OK")`` on success, otherwise ``(None, reason)``. Errors
        are reported through the message instead of being raised.
    """
    if not ds_enabled():
        return None, "Dataset persistence disabled"
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        # The scratch directory is removed once the bytes are in memory;
        # previously every call left a mkdtemp() directory behind.
        with tempfile.TemporaryDirectory() as tmpdir:
            local_path = api.hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=path_in_repo,
                local_dir=tmpdir,
            )
            with open(local_path, "rb") as f:
                payload = f.read()
    except HfHubHTTPError as e:
        return None, f"Download failed (HTTP): {e}"
    except Exception as e:
        return None, f"Download failed: {repr(e)}"
    return payload, "OK"
def _root_index_path() -> str:
return "index.json"
def ds_read_root_index() -> Dict[str, Any]:
    """Read the root ``index.json`` and normalise it.

    Backward-compatible layout::

        {
          "id": "<last_session_id>",
          "last_session_id": "<last_session_id>",
          "sessions": ["sess_...", ...],
          "updated_at": 123.0
        }

    ``id`` and ``last_session_id`` are mirrored into each other when only one
    is present, and ``sessions`` is forced to be a list. A blank index is
    returned when the file cannot be downloaded or processed.
    """

    def _blank() -> Dict[str, Any]:
        return {"id": None, "last_session_id": None, "sessions": [], "updated_at": time.time()}

    raw, _ = ds_download_bytes(_root_index_path())
    if raw is None:
        return _blank()
    try:
        index = json.loads(raw.decode("utf-8"))
        if "last_session_id" not in index and "id" in index:
            index["last_session_id"] = index.get("id")
        if "id" not in index:
            index["id"] = index.get("last_session_id")
        if not isinstance(index.get("sessions"), list):
            index["sessions"] = []
    except Exception:
        # Any decode / shape problem falls back to a blank index.
        return _blank()
    return index
def ds_write_root_index(last_session_id: Optional[str]) -> Tuple[bool, str]:
    """Record ``last_session_id`` in the root index and upload it.

    Both ``last_session_id`` and the legacy ``id`` key are written. A non-empty
    id is moved to the front of ``sessions`` without duplicates.

    Returns:
        The ``(ok, message)`` pair of the upload.
    """
    index = ds_read_root_index()
    index["last_session_id"] = last_session_id
    index["id"] = last_session_id  # legacy key, kept so readers expecting 'id' do not hit KeyError
    index["updated_at"] = time.time()
    if last_session_id:
        others = [sid for sid in index.get("sessions", []) if sid != last_session_id]
        index["sessions"] = [last_session_id, *others]
    payload = json.dumps(index, ensure_ascii=False, indent=2).encode("utf-8")
    return ds_upload_bytes(_root_index_path(), payload, f"update root index last_session_id={last_session_id}")
def ds_list_sessions(max_sessions: int = 50) -> Tuple[List[str], str]:
    """List known session ids, at most ``max_sessions`` of them.

    The root index is consulted first. If it yields no session ids, the repo
    file listing is scanned for ``sessions/<id>/index.json`` or
    ``sessions/<id>/session.json`` and the ids are returned in reverse sorted
    order.

    Returns:
        ``(session_ids, message)``; failures yield an empty list and a reason.
    """
    if not ds_enabled():
        return [], "Dataset persistence disabled"
    client = ds_api()
    dataset_id = ds_repo_id()
    try:
        # Fast path: the root index.
        try:
            root_index = ds_read_root_index()
            session_ids = [sid for sid in root_index.get("sessions", []) if isinstance(sid, str)]
        except Exception:
            session_ids = []
        # Slow path: scan the repo file listing.
        if not session_ids:
            discovered = set()
            for repo_path in client.list_repo_files(repo_id=dataset_id, repo_type="dataset"):
                if not repo_path.startswith("sessions/"):
                    continue
                if not repo_path.endswith(("/index.json", "/session.json")):
                    continue
                segments = repo_path.split("/")
                if len(segments) >= 3:
                    discovered.add(segments[1])
            session_ids = sorted(discovered, reverse=True)
        session_ids = session_ids[:max_sessions]
        return session_ids, f"Found {len(session_ids)} session(s)"
    except Exception as err:
        return [], f"List sessions failed: {repr(err)}"
# -------------------------
# Node / History model
# -------------------------
@dataclass
class NodeMeta:
    """Metadata of one history node; stored in state as a dict via ``asdict``."""

    node_id: str  # id produced by _new_id("node")
    name: str  # display name; can be changed through _rename_node
    parent_id: Optional[str]  # None for a node without a parent
    children: List[str]  # ids of nodes added with this node as parent
    op: str # "decompose" | "refine" | "duplicate"
    created_at: float  # time.time() at creation; history is sorted by it
    source_node_id: Optional[str] = None  # set for refine nodes: node the layer came from
    source_layer_idx: Optional[int] = None  # set for refine nodes: 0-based index of that layer
    sub_layers: Optional[int] = None  # set for refine nodes: requested number of sub-layers
    settings: Optional[Dict[str, Any]] = None  # snapshot of the generation settings
def _new_id(prefix: str) -> str:
return f"{prefix}_{uuid.uuid4().hex[:10]}"
def _make_chips(state: Dict[str, Any]) -> str:
node_id = state.get("selected_node_id")
nodes: Dict[str, Any] = state.get("nodes", {})
if not node_id or node_id not in nodes:
return "[root] [parent:-] [children:0]"
meta = nodes[node_id]["meta"]
parent = meta.get("parent_id") or "-"
children = meta.get("children") or []
root = state.get("root_node_id") or "-"
return f"[root:{root}] [parent:{parent}] [children:{len(children)}]"
def _history_choices(state: Dict[str, Any]) -> List[Tuple[str, str]]:
nodes: Dict[str, Any] = state.get("nodes", {})
items = []
for nid, obj in nodes.items():
meta = obj["meta"]
items.append((meta.get("created_at", 0.0), nid, meta.get("name", nid)))
items.sort(key=lambda x: x[0])
return [(f"{name}{nid}", nid) for _, nid, name in items]
def _get_node_images(state: Dict[str, Any], node_id: str) -> List[Image.Image]:
nodes: Dict[str, Any] = state.get("nodes", {})
if node_id not in nodes:
return []
return nodes[node_id].get("images", []) or []
def _add_node(
    state: Dict[str, Any],
    *,
    name: str,
    parent_id: Optional[str],
    op: str,
    images: List[Image.Image],
    settings: Optional[Dict[str, Any]] = None,
    source_node_id: Optional[str] = None,
    source_layer_idx: Optional[int] = None,
    sub_layers: Optional[int] = None,
) -> str:
    """Create a node, store it in ``state["nodes"]`` and link it to its parent.

    The node is stored as ``{"meta": <NodeMeta as dict>, "images": images}``.
    When ``parent_id`` names a known node, the new id is appended to that
    node's ``children`` list.

    Returns:
        The id of the new node.
    """
    new_node_id = _new_id("node")
    descriptor = NodeMeta(
        node_id=new_node_id,
        name=name,
        parent_id=parent_id,
        children=[],
        op=op,
        created_at=time.time(),
        source_node_id=source_node_id,
        source_layer_idx=source_layer_idx,
        sub_layers=sub_layers,
        settings=settings or {},
    )
    registry = state.setdefault("nodes", {})
    registry[new_node_id] = {"meta": asdict(descriptor), "images": images}
    if parent_id and parent_id in registry:
        registry[parent_id]["meta"].setdefault("children", []).append(new_node_id)
    return new_node_id
def _rename_node(state: Dict[str, Any], node_id: str, new_name: str):
if not new_name:
return
if node_id in state.get("nodes", {}):
state["nodes"][node_id]["meta"]["name"] = new_name
def _duplicate_node(state: Dict[str, Any], node_id: str) -> Optional[str]:
    """Add a sibling copy of ``node_id`` and return the new id.

    The copy gets the same parent, the same image list object and the same
    settings, is named ``"<name> (copy)"`` and has ``op="duplicate"``.
    Returns ``None`` when ``node_id`` is unknown.
    """
    if node_id not in state.get("nodes", {}):
        return None
    original = state["nodes"][node_id]
    original_meta = original["meta"]
    return _add_node(
        state,
        name=f"{original_meta.get('name','node')} (copy)",
        parent_id=original_meta.get("parent_id"),
        op="duplicate",
        images=original.get("images", []),
        settings=original_meta.get("settings") or {},
    )
# -------------------------
# GPU duration + GPU-only pipeline runner
# IMPORTANT: pipeline init is INSIDE GPU worker (ZeroGPU friendly)
# -------------------------
def get_duration(*args, **kwargs):
    """Return the GPU time budget for a call, clamped to [20, 1500].

    Only the ``gpu_duration`` keyword is read (default 1000); every other
    positional or keyword argument is ignored.
    """
    requested = kwargs["gpu_duration"] if "gpu_duration" in kwargs else 1000
    return _clamp_int(requested, default=1000, lo=20, hi=1500)
# Process-wide cache for the loaded pipeline; filled on first use.
_GPU_PIPE: Optional[QwenImageLayeredPipeline] = None


def _gpu_get_pipe() -> QwenImageLayeredPipeline:
    """Return the cached pipeline, loading it onto CUDA on first call.

    Meant to be called from the ``@spaces.GPU``-decorated runner.

    Raises:
        RuntimeError: If CUDA is not available when the pipeline has to be loaded.
    """
    global _GPU_PIPE
    if _GPU_PIPE is None:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA is not available inside GPU worker. "
                "Check Space hardware: it must be ZeroGPU/GPU, not CPU."
            )
        loaded = QwenImageLayeredPipeline.from_pretrained(
            "Qwen/Qwen-Image-Layered",
            torch_dtype=torch.bfloat16,
        )
        _GPU_PIPE = loaded.to("cuda")
    return _GPU_PIPE
@spaces.GPU(duration=get_duration)
def gpu_run_pipeline(
    pil_image_rgba: Image.Image,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    layer=4,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
):
    """Run the layered pipeline once on the GPU worker.

    Args:
        pil_image_rgba: Input image handed to the pipeline as ``image``.
        seed: Seed for the CUDA generator; replaced when ``randomize_seed`` is set.
        randomize_seed: Draw a random seed in ``[0, MAX_SEED]`` instead.
        prompt / neg_prompt: Forwarded as ``prompt`` / ``negative_prompt``.
        true_guidance_scale: Forwarded as ``true_cfg_scale``.
        num_inference_steps: Forwarded as an int.
        layer: Number of layers requested (``layers``).
        cfg_norm / use_en_prompt: Forwarded as ``cfg_normalize`` / ``use_en_prompt``.
        resolution: Normalised to 640 or 1024 before use.
        gpu_duration: Not used in the body; read by ``get_duration`` from the kwargs.

    Returns:
        ``(layers, seed, inputs)`` where ``layers`` is ``out.images[0]``,
        ``seed`` is the seed actually used and ``inputs`` is the kwargs dict
        passed to the pipeline.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    # Everything heavy here happens on GPU worker
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    resolution = _normalize_resolution(resolution)
    if not torch.cuda.is_available():
        raise RuntimeError("No CUDA GPUs are available (GPU worker not running).")
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    inputs = {
        "image": pil_image_rgba,
        "generator": generator,
        "true_cfg_scale": float(true_guidance_scale),
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "num_inference_steps": int(num_inference_steps),
        "num_images_per_prompt": 1,
        "layers": int(layer),
        "resolution": int(resolution),
        "cfg_normalize": bool(cfg_norm),
        "use_en_prompt": bool(use_en_prompt),
    }
    # reduce allocator hiccups
    try:
        torch.cuda.empty_cache()
    except Exception:
        # Deliberately best effort: a failed cache flush must not block inference.
        pass
    pipe = _gpu_get_pipe()
    with torch.inference_mode():
        out = pipe(**inputs)
    output_images = out.images[0] # list of PIL layers
    # NOTE(review): `inputs` carries the CUDA generator and the input image; both
    # callers in this file discard it. Confirm it is cheap to return from the worker.
    return output_images, int(seed), inputs
# -------------------------
# Dataset persistence: save/load nodes + session
# -------------------------
def _pil_to_png_bytes(img: Image.Image) -> bytes:
buf = io.BytesIO()
img.save(buf, format="PNG")
return buf.getvalue()
def _png_bytes_to_pil(b: bytes) -> Image.Image:
    """Decode image bytes and return the picture converted to RGBA."""
    stream = io.BytesIO(b)
    decoded = Image.open(stream)
    return decoded.convert("RGBA")
def _session_base(session_id: str) -> str:
return f"sessions/{session_id}"
def _node_base(session_id: str, node_id: str) -> str:
    """Return the repo folder of a node: ``<session folder>/nodes/<node_id>``."""
    return "{}/nodes/{}".format(_session_base(session_id), node_id)
def _persist_node_to_dataset(state: Dict[str, Any], node_id: str) -> Tuple[bool, str]:
    """Upload one node (``node.json`` plus one PNG per layer) to the dataset.

    Files go under ``sessions/<session_id>/nodes/<node_id>/``; layers are named
    ``layer_1.png``, ``layer_2.png``, ... Each file is uploaded separately and
    the first failure stops the process.

    Returns:
        ``(ok, message)``.
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled. Set DATASET_REPO and HF_TOKEN/hf."
    repo_ok, repo_msg = ds_ensure_repo()
    if not repo_ok:
        return False, repo_msg
    session_id = state.get("session_id")
    if not session_id:
        return False, "No session_id in state (run Decompose first)"
    all_nodes = state.get("nodes", {})
    if node_id not in all_nodes:
        return False, "Unknown node_id"
    record = all_nodes[node_id]
    folder = _node_base(session_id, node_id)

    # Metadata first, then the layers.
    meta_bytes = json.dumps(record["meta"], ensure_ascii=False, indent=2).encode("utf-8")
    uploaded, upload_msg = ds_upload_bytes(f"{folder}/node.json", meta_bytes, f"save node {node_id}")
    if not uploaded:
        return False, upload_msg
    layers: List[Image.Image] = record.get("images", []) or []
    for number, layer_img in enumerate(layers, start=1):
        png = _pil_to_png_bytes(layer_img)
        uploaded, upload_msg = ds_upload_bytes(
            f"{folder}/layer_{number}.png",
            png,
            f"save node {node_id} layer {number}",
        )
        if not uploaded:
            return False, upload_msg
    return True, f"Saved node {node_id} to dataset"
def _persist_session_manifest(state: Dict[str, Any]) -> Tuple[bool, str]:
    """Upload the session manifest and refresh the root index.

    The manifest lists every node's metadata and layer count (no image data).
    It is written twice, as ``index.json`` and as ``session.json`` inside the
    session folder, and the root index is then pointed at this session. The
    first failing step aborts and its message is returned.

    Returns:
        ``(ok, message)``.
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled"
    repo_ok, repo_msg = ds_ensure_repo()
    if not repo_ok:
        return False, repo_msg
    session_id = state.get("session_id")
    if not session_id:
        return False, "No session_id"

    node_summaries = {}
    for nid, entry in (state.get("nodes", {}) or {}).items():
        node_summaries[nid] = {"meta": entry["meta"], "num_layers": len(entry.get("images", []) or [])}
    manifest = {
        "session_id": session_id,
        "created_at": state.get("created_at"),
        "root_node_id": state.get("root_node_id"),
        "selected_node_id": state.get("selected_node_id"),
        "nodes": node_summaries,
    }
    payload = json.dumps(manifest, ensure_ascii=False, indent=2).encode("utf-8")

    folder = _session_base(session_id)
    targets = (
        # Primary location: sessions/<id>/index.json (the layout already used in the dataset).
        ("index.json", f"save session index {session_id}"),
        # Same content under a second name, kept for compatibility.
        ("session.json", f"save session manifest {session_id}"),
    )
    for filename, commit_message in targets:
        uploaded, upload_msg = ds_upload_bytes(f"{folder}/{filename}", payload, commit_message)
        if not uploaded:
            return False, upload_msg

    # Root index.json: remembers the last session and keeps the legacy 'id' key.
    index_ok, index_msg = ds_write_root_index(session_id)
    if not index_ok:
        return False, index_msg
    return True, "Saved session manifest + root index"
def _load_session_manifest(session_id: str) -> Tuple[Optional[Dict[str, Any]], str]:
    """Download and parse the manifest of ``session_id``.

    ``index.json`` is tried first, then ``session.json``. The first file that
    downloads decides the outcome: its parsed content, or a parse error.

    Returns:
        ``(manifest, "OK")`` or ``(None, reason)``.
    """
    folder = _session_base(session_id)
    for filename in ("index.json", "session.json"):
        raw, _ = ds_download_bytes(f"{folder}/{filename}")
        if raw is None:
            continue
        try:
            return json.loads(raw.decode("utf-8")), "OK"
        except Exception as err:
            return None, f"Failed to parse manifest: {repr(err)}"
    return None, f"Manifest not found for session {session_id}"
def _load_node_images(session_id: str, node_id: str, num_layers: int) -> Tuple[List[Image.Image], str]:
    """Download ``layer_1.png`` .. ``layer_<num_layers>.png`` of a node.

    Returns:
        ``(images, "OK")`` when every layer was fetched, otherwise
        ``([], reason)`` for the first layer that failed to download.
    """
    folder = _node_base(session_id, node_id)
    layers: List[Image.Image] = []
    for number in range(1, num_layers + 1):
        raw, reason = ds_download_bytes(f"{folder}/layer_{number}.png")
        if raw is None:
            return [], reason
        layers.append(_png_bytes_to_pil(raw))
    return layers, "OK"
# -------------------------
# UI callbacks
# -------------------------
def _init_state() -> Dict[str, Any]:
return {
"session_id": None,
"created_at": None,
"root_node_id": None,
"selected_node_id": None,
"nodes": {},
"last_refined_node_id": None,
}
def _persistence_status_text() -> str:
    """Return the banner text describing whether dataset persistence is configured."""
    token, repo = _get_hf_token(), _get_dataset_repo()
    if token and repo:
        return f"✅ Dataset persistence enabled: `{repo}`"
    if repo:
        # Repo configured, token missing.
        return "⚠️ Dataset repo set, but HF_TOKEN/hf missing"
    if token:
        # Token configured, repo missing.
        return "⚠️ HF_TOKEN/hf set, but DATASET_REPO missing"
    return "⚠️ Dataset persistence disabled (set HF_TOKEN + DATASET_REPO secrets to enable)"
def on_refresh_sessions():
    """Refresh the session dropdown: new choices, first one preselected, plus a status message."""
    session_ids, message = ds_list_sessions()
    preselected = session_ids[0] if session_ids else None
    return gr.update(choices=session_ids, value=preselected), message
def on_init_dataset():
    """Ensure the dataset repo exists and return only the status message."""
    _, message = ds_ensure_repo()
    return message
def _current_node_export(state: Dict[str, Any], node_id: str) -> Tuple[Optional[str], Optional[str], str]:
    """Export the layers of ``node_id`` as a .pptx and a .zip.

    Returns:
        ``(pptx_path, zip_path, "OK")``, or ``(None, None, reason)`` when the
        node has no images.
    """
    layers = _get_node_images(state, node_id)
    if layers:
        pptx_file = imagelist_to_pptx_from_pils(layers)
        zip_file = imagelist_to_zip_from_pils(layers, prefix=f"{node_id}_layer")
        return pptx_file, zip_file, "OK"
    return None, None, "No images to export"
def _build_layer_dropdown(n: int) -> Tuple[List[str], Optional[str]]:
if n <= 0:
return [], None
choices = [f"Layer {i+1}" for i in range(n)]
return choices, choices[0]
def _layer_label(idx: int, n: int) -> str:
if n <= 0:
return "Selected: -"
idx = max(0, min(n - 1, idx))
return f"Selected: Layer {idx+1} / {n}"
def on_decompose_click(
    state: Dict[str, Any],
    input_image,
    seed,
    randomize_seed,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    layer,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
):
    """Decompose the input image into layers and start a new history tree.

    Runs the GPU pipeline on the normalised input, discards all existing nodes
    of the state, stores the result as the root node and selects it. A session
    id is created only if the state does not have one yet.

    Returns:
        A 14-item tuple: state, the layers (twice), layer-dropdown update,
        selected-index update, selection label, history update, chips text,
        refined-panel visibility update, refined gallery content, pptx path,
        zip path, status text and the used seed as a string.
        NOTE(review): the component each slot feeds is wired outside this
        function — confirm against the event binding.
    """
    if state is None or not isinstance(state, dict):
        state = _init_state()
    pil_image = _normalize_input_image(input_image)
    if not state.get("session_id"):
        state["session_id"] = _new_id("sess")
        state["created_at"] = time.time()
    layers_out, used_seed, _used_inputs = gpu_run_pipeline(
        pil_image_rgba=pil_image,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=layer,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )
    # Snapshot of the effective settings, stored in the node metadata.
    settings_snapshot = {
        "seed": used_seed,
        "randomize_seed": bool(randomize_seed),
        "prompt": prompt,
        "neg_prompt": neg_prompt,
        "true_guidance_scale": float(true_guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "layers": int(layer),
        "resolution": int(_normalize_resolution(resolution)),
        "cfg_norm": bool(cfg_norm),
        "use_en_prompt": bool(use_en_prompt),
        "gpu_duration": int(_clamp_int(gpu_duration, 1000, 20, 1500)),
    }
    # A new decomposition replaces the whole tree; the session id is kept.
    state["nodes"] = {}
    state["last_refined_node_id"] = None
    root_id = _add_node(
        state,
        name="root (decompose)",
        parent_id=None,
        op="decompose",
        images=layers_out,
        settings=settings_snapshot,
    )
    state["root_node_id"] = root_id
    state["selected_node_id"] = root_id
    n_layers = len(layers_out)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    selected_label = _layer_label(0, n_layers)
    refined_visible = gr.update(visible=False)
    refined_gallery = []
    pptx_path, zip_path, exp_msg = _current_node_export(state, root_id)
    status = f"Decomposed into {n_layers} layer(s). Seed={used_seed}. {exp_msg}"
    return (
        state,
        layers_out,
        layers_out,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        # Only the node ids are offered as history choices; the labels are dropped.
        gr.update(choices=[c[1] for c in hist_choices], value=root_id),
        chips,
        refined_visible,
        refined_gallery,
        pptx_path,
        zip_path,
        status,
        str(used_seed),
    )
def on_layer_pick_from_dropdown(state: Dict[str, Any], layer_name: str):
    """Translate a "Layer N" dropdown choice into a 0-based index.

    Unparseable or missing names map to index 0; the index is clamped to the
    layers of the selected node.

    Returns:
        ``(index update, selection label)``.
    """
    current = state.get("selected_node_id")
    total = len(_get_node_images(state, current) if current else [])
    position = 0
    if layer_name and layer_name.startswith("Layer "):
        try:
            position = int(layer_name.replace("Layer ", "").strip()) - 1
        except Exception:
            position = 0
    position = max(0, min(total - 1, position)) if total > 0 else 0
    return gr.update(value=position), _layer_label(position, total)
def on_layer_pick_from_gallery(state: Dict[str, Any], evt: gr.SelectData):
    """Sync the index, the dropdown and the label with a gallery click.

    The clicked index (0 when the event carries none) is clamped to the layers
    of the selected node.

    Returns:
        ``(index update, dropdown update, selection label)``.
    """
    current = state.get("selected_node_id")
    total = len(_get_node_images(state, current) if current else [])
    clicked = int(evt.index) if evt and evt.index is not None else 0
    clicked = max(0, min(total - 1, clicked)) if total > 0 else 0
    dropdown_value = f"Layer {clicked+1}" if total > 0 else None
    return gr.update(value=clicked), gr.update(value=dropdown_value), _layer_label(clicked, total)
def _refine_from_source(
    state: Dict[str, Any],
    source_node_id: str,
    source_layer_idx: int,
    sub_layers: int,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
    seed,
    randomize_seed,
):
    """Run the pipeline on one layer of an existing node.

    The layer at ``source_layer_idx`` of ``source_node_id`` is used as input
    image and ``sub_layers`` as the requested layer count. The state is only
    read, never modified.

    Returns:
        ``(layers, used_seed, settings_snapshot)``; the snapshot also records
        where the input layer came from under ``"refined_from"``.

    Raises:
        ValueError: If the source node has no images or the index is out of range.
    """
    src_imgs = _get_node_images(state, source_node_id)
    if not src_imgs:
        raise ValueError("Source node has no images")
    if source_layer_idx < 0 or source_layer_idx >= len(src_imgs):
        raise ValueError("Invalid layer index")
    selected_layer_img = src_imgs[source_layer_idx]
    layers_out, used_seed, _used_inputs = gpu_run_pipeline(
        pil_image_rgba=selected_layer_img,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=sub_layers,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )
    # Effective settings of this run, stored with the resulting node.
    settings_snapshot = {
        "seed": used_seed,
        "randomize_seed": bool(randomize_seed),
        "prompt": prompt,
        "neg_prompt": neg_prompt,
        "true_guidance_scale": float(true_guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "layers": int(sub_layers),
        "resolution": int(_normalize_resolution(resolution)),
        "cfg_norm": bool(cfg_norm),
        "use_en_prompt": bool(use_en_prompt),
        "gpu_duration": int(_clamp_int(gpu_duration, 1000, 20, 1500)),
        "refined_from": {"source_node_id": source_node_id, "source_layer_idx": int(source_layer_idx)},
    }
    return layers_out, used_seed, settings_snapshot
def on_refine_click(
    state: Dict[str, Any],
    selected_layer_idx: int,
    sub_layers: int,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
    seed,
    randomize_seed,
):
    """Refine the selected layer of the selected node into sub-layers.

    The result is added as a child node of the selected node and becomes the
    new selection. ``selected_layer_idx`` is clamped to the available layers
    and ``sub_layers`` to [2, 10] (3 when it cannot be parsed).

    Returns:
        A 14-item tuple with the same slot order as ``on_decompose_click``,
        except that the last slot is a no-op update. When nothing can be
        refined, most slots are no-op updates and the status slot explains why.
    """
    if not state.get("selected_node_id"):
        return (
            state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), [], None, None, "No selected node. Run Decompose first.", gr.update()
        )
    source_node_id = state["selected_node_id"]
    src_imgs = _get_node_images(state, source_node_id)
    if not src_imgs:
        return (
            state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), [], None, None, "Selected node has no images.", gr.update()
        )
    n = len(src_imgs)
    idx = int(selected_layer_idx) if selected_layer_idx is not None else 0
    idx = max(0, min(n - 1, idx))
    sub_layers = _clamp_int(sub_layers, default=3, lo=2, hi=10)
    layers_out, used_seed, settings_snapshot = _refine_from_source(
        state,
        source_node_id=source_node_id,
        source_layer_idx=idx,
        sub_layers=sub_layers,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
        seed=seed,
        randomize_seed=randomize_seed,
    )
    # The child is named after its source node and the 1-based layer number.
    child_name = f"refine ({state['nodes'][source_node_id]['meta']['name']}) L{idx+1}"
    child_id = _add_node(
        state,
        name=child_name,
        parent_id=source_node_id,
        op="refine",
        images=layers_out,
        settings=settings_snapshot,
        source_node_id=source_node_id,
        source_layer_idx=idx,
        sub_layers=sub_layers,
    )
    state["selected_node_id"] = child_id
    state["last_refined_node_id"] = child_id
    n_layers = len(layers_out)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    selected_label = _layer_label(0, n_layers)
    pptx_path, zip_path, exp_msg = _current_node_export(state, child_id)
    status = f"Refined into {n_layers} sub-layer(s). Seed={used_seed}. {exp_msg}"
    return (
        state,
        layers_out,
        layers_out,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        gr.update(choices=[c[1] for c in hist_choices], value=child_id),
        chips,
        gr.update(visible=True),
        layers_out,
        pptx_path,
        zip_path,
        status,
        gr.update(),
    )
def on_history_select(state: Dict[str, Any], node_id: str):
    """Make ``node_id`` the selected node and rebuild the dependent widgets.

    Returns:
        A 13-item tuple: state, the node's images (twice), layer-dropdown
        update, selected-index update, selection label, history update, chips
        text, refined-panel visibility update (hidden), empty refined gallery,
        pptx path, zip path and a status text. For an unknown node most slots
        are no-op updates and the status is ``"Unknown node."``.
    """
    if not node_id or node_id not in state.get("nodes", {}):
        return (
            state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), [], None, None, "Unknown node."
        )
    state["selected_node_id"] = node_id
    layers = _get_node_images(state, node_id)
    layer_count = len(layers)
    dropdown_choices, dropdown_value = _build_layer_dropdown(layer_count)
    history_ids = [entry[1] for entry in _history_choices(state)]
    chips_text = _make_chips(state)
    caption = _layer_label(0, layer_count)
    pptx_file, zip_file, export_msg = _current_node_export(state, node_id)
    return (
        state,
        layers,
        layers,
        gr.update(choices=dropdown_choices, value=dropdown_value),
        gr.update(value=0),
        caption,
        gr.update(choices=history_ids, value=node_id),
        chips_text,
        gr.update(visible=False),
        [],
        pptx_file,
        zip_file,
        f"Selected node: {node_id}. {export_msg}",
    )
def on_back_to_parent(state: Dict[str, Any]):
    """Select the parent of the currently selected node.

    Returns:
        The 13-item tuple produced by ``on_history_select`` for the parent.
        When there is no selection or no parent, a tuple of the same length is
        returned whose middle slots are no-op updates and whose last slot
        carries the reason.
    """

    def _unchanged(message: str):
        # Same arity as on_history_select (state + 11 widget slots + status).
        # The early exits used to return only 3 values, which cannot feed the
        # same output list as the success path.
        return (state, *(gr.update() for _ in range(11)), message)

    node_id = state.get("selected_node_id")
    if not node_id or node_id not in state.get("nodes", {}):
        return _unchanged("No selected node.")
    parent = state["nodes"][node_id]["meta"].get("parent_id")
    if not parent:
        return _unchanged("Already at root.")
    return on_history_select(state, parent)
def on_duplicate_node(state: Dict[str, Any]):
    """Duplicate the selected node and select the copy.

    Returns:
        The 13-item tuple produced by ``on_history_select`` for the copy.
        When there is no selection or duplication fails, a tuple of the same
        length is returned whose middle slots are no-op updates and whose last
        slot carries the reason.
    """

    def _unchanged(message: str):
        # Same arity as on_history_select (state + 11 widget slots + status).
        # The early exits used to return only 3 values, which cannot feed the
        # same output list as the success path.
        return (state, *(gr.update() for _ in range(11)), message)

    node_id = state.get("selected_node_id")
    if not node_id:
        return _unchanged("No selected node.")
    new_id = _duplicate_node(state, node_id)
    if not new_id:
        return _unchanged("Duplicate failed.")
    return on_history_select(state, new_id)
def on_rename_node(state: Dict[str, Any], new_name: str):
    """Rename the selected node and refresh the history choices and chips.

    Returns:
        A 4-item tuple ``(state, history update, chips, status)`` on every
        path. When nothing is renamed, the two middle slots are no-op updates.
    """
    node_id = state.get("selected_node_id")
    if not node_id:
        # Previously 3 values were returned here versus 4 on success.
        return state, gr.update(), gr.update(), "No selected node."
    if not new_name:
        # _rename_node ignores an empty name, so do not claim success.
        return state, gr.update(), gr.update(), "Name is empty; nothing renamed."
    _rename_node(state, node_id, new_name)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    return state, gr.update(choices=[c[1] for c in hist_choices], value=node_id), chips, "Renamed."
def on_export_selected(state: Dict[str, Any]):
    """Export the selected node; returns ``(pptx_path, zip_path, message)``."""
    selected = state.get("selected_node_id")
    if selected:
        pptx_file, zip_file, message = _current_node_export(state, selected)
        return pptx_file, zip_file, message
    return None, None, "No selected node."
def on_save_current(state: Dict[str, Any]):
    """Persist the selected node and then the session manifest.

    Returns:
        A status string: the failure message of the first step that failed,
        or a success line.
    """
    selected = state.get("selected_node_id")
    if not selected:
        return "Nothing to save."
    node_ok, node_msg = _persist_node_to_dataset(state, selected)
    if not node_ok:
        return node_msg
    manifest_ok, manifest_msg = _persist_session_manifest(state)
    if not manifest_ok:
        return manifest_msg
    return f"✅ Saved node + session manifest. {node_msg}"
def on_load_session(state: Dict[str, Any], session_id: str):
    """Rebuild the state from a stored session manifest.

    All nodes are restored with their metadata but without images. Images are
    downloaded only for the selected node and, if different, for the root node.

    Returns:
        A 13-item tuple with the same slot order as ``on_history_select``. On
        failure most slots are no-op updates and the status slot holds the reason.
    """
    if not session_id:
        return (
            state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), [], None, None, "Pick a session id."
        )
    manifest, msg = _load_session_manifest(session_id)
    if manifest is None:
        return (
            state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), [], None, None, msg
        )
    new_state = _init_state()
    new_state["session_id"] = manifest.get("session_id") or session_id
    new_state["created_at"] = manifest.get("created_at")
    new_state["root_node_id"] = manifest.get("root_node_id")
    new_state["selected_node_id"] = manifest.get("selected_node_id") or manifest.get("root_node_id")
    nodes_meta = manifest.get("nodes", {}) or {}
    # Register every node with empty images; they are filled in lazily.
    for nid, obj in nodes_meta.items():
        meta = obj.get("meta") or {}
        new_state["nodes"][nid] = {"meta": meta, "images": []}
    # Fall back to the root when the stored selection is not a known node.
    sel = new_state["selected_node_id"]
    if not sel or sel not in nodes_meta:
        sel = new_state["root_node_id"]
        new_state["selected_node_id"] = sel
    if sel and sel in nodes_meta:
        num_layers = int(nodes_meta[sel].get("num_layers", 0))
        imgs, msg2 = _load_node_images(session_id, sel, num_layers)
        # NOTE(review): with num_layers == 0 the loader returns ([], "OK"), so this
        # branch reports a load failure with reason "OK" — confirm that is intended.
        if not imgs:
            return (
                new_state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
                gr.update(visible=False), [], None, None, f"Loaded manifest but failed to load images: {msg2}"
            )
        new_state["nodes"][sel]["images"] = imgs
    # Also fetch the root's images when the root is not the selected node;
    # a failure here is ignored.
    root = new_state.get("root_node_id")
    if root and root != sel and root in nodes_meta and not new_state["nodes"][root]["images"]:
        rl = int(nodes_meta[root].get("num_layers", 0))
        rimgs, _ = _load_node_images(session_id, root, rl)
        if rimgs:
            new_state["nodes"][root]["images"] = rimgs
    imgs = _get_node_images(new_state, sel) if sel else []
    n_layers = len(imgs)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(new_state)
    chips = _make_chips(new_state)
    selected_label = _layer_label(0, n_layers)
    pptx_path, zip_path, exp_msg = _current_node_export(new_state, sel) if sel else (None, None, "No node")
    return (
        new_state,
        imgs,
        imgs,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        gr.update(choices=[c[1] for c in hist_choices], value=sel),
        chips,
        gr.update(visible=False),
        [],
        pptx_path,
        zip_path,
        f"Loaded session {session_id}. {exp_msg}",
    )
def on_history_need_images(state: Dict[str, Any], node_id: str):
    """Make sure the images of history node ``node_id`` are present in ``state``.

    If the node already holds images nothing is fetched. Otherwise the session
    manifest is loaded via ``_load_session_manifest`` and the node's layers are
    loaded via ``_load_node_images``; on success they are stored in
    ``state["nodes"][node_id]["images"]`` (``state`` is mutated in place).

    Args:
        state: Session state dict; reads ``"nodes"`` and ``"session_id"``.
        node_id: Key of the node inside ``state["nodes"]``.

    Returns:
        A ``(state, message)`` tuple. ``message`` is ``"OK"`` when images were
        already present, ``"Loaded images."`` after a successful load, and a
        short human-readable reason otherwise. The function reports failures
        through the message rather than by raising.
    """
    nodes = state.get("nodes", {})
    if not node_id or node_id not in nodes:
        return state, "Unknown node."

    # Images already in memory: nothing to fetch.
    if nodes[node_id].get("images", []):
        return state, "OK"

    session_id = state.get("session_id")
    if not session_id:
        return state, "No session_id."

    manifest, msg = _load_session_manifest(session_id)
    if not manifest:
        return state, f"Cannot load manifest: {msg}"

    # The manifest is external JSON: a node entry may be null or not a mapping,
    # and "num_layers" may be null or non-numeric. Treat all of those as
    # "no layers" instead of letting .get()/int() raise inside the UI handler.
    node_obj = (manifest.get("nodes", {}) or {}).get(node_id)
    if not isinstance(node_obj, dict):
        node_obj = {}
    try:
        num_layers = int(node_obj.get("num_layers") or 0)
    except (TypeError, ValueError):
        num_layers = 0
    if num_layers <= 0:
        return state, "No layers in manifest for this node."

    loaded, load_msg = _load_node_images(session_id, node_id, num_layers)
    if not loaded:
        return state, f"Failed to load images: {load_msg}"

    nodes[node_id]["images"] = loaded
    return state, "Loaded images."
# -------------------------
# Build UI
# -------------------------
# Create LOG_DIR before the UI is built; ensure_dirname only calls
# os.makedirs when the path is non-empty and does not exist yet.
ensure_dirname(LOG_DIR)
# Sample input images offered in the UI: assets/test_images/1.png .. 13.png.
examples = [f"assets/test_images/{idx}.png" for idx in range(1, 14)]
# Build the Gradio app: all components are declared first, then the event
# handlers are wired to them. The resulting Blocks object is bound to `demo`.
with gr.Blocks() as demo:
# Session state component, initialised from _init_state(); it is passed into
# and returned from most handlers wired below.
state = gr.State(_init_state())
gr.HTML(
'<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/layered/qwen-image-layered-logo.png" '
'alt="Qwen-Image-Layered Logo" width="600" style="display: block; margin: 0 auto;">'
)
# Banner text is computed once, at build time, by _persistence_status_text().
persistence_banner = gr.Markdown(_persistence_status_text())
# Dataset controls; both buttons write their result message into ds_status.
with gr.Row():
btn_init_ds = gr.Button("Init dataset repo", variant="secondary")
btn_refresh_sessions = gr.Button("Refresh sessions", variant="secondary")
ds_status = gr.Markdown("")
# Session loader. The dropdown starts with no choices and is an output of
# on_refresh_sessions; allow_custom_value=True is set on it.
with gr.Row():
load_session_dd = gr.Dropdown(
label="Load session (from dataset)",
choices=[],
value=None,
allow_custom_value=True,
)
btn_load_session = gr.Button("Load session", variant="primary")
gr.Markdown(
"""
The text prompt describes the overall content of the input image.
It is not designed to control the semantic content of individual layers explicitly.
"""
)
# Input image and generation controls.
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(label="Input Image", image_mode="RGBA")
with gr.Accordion("Advanced Settings", open=False):
prompt = gr.Textbox(
label="Prompt (Optional)",
placeholder="Describe the image (optional)",
value="",
lines=2,
)
# NOTE(review): the default below is a single space, not an empty string —
# presumably intentional; confirm before normalising it.
neg_prompt = gr.Textbox(
label="Negative Prompt (Optional)",
placeholder="Negative prompt",
value=" ",
lines=2,
)
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
true_guidance_scale = gr.Slider(
label="True guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=4.0
)
num_inference_steps = gr.Slider(
label="Number of inference steps", minimum=1, maximum=100, step=1, value=50 # DO NOT CHANGE
)
layer = gr.Slider(label="Layers", minimum=2, maximum=10, step=1, value=7) # DO NOT CHANGE
resolution = gr.Radio(label="Processing resolution", choices=[640, 1024], value=640) # DO NOT CHANGE
cfg_norm = gr.Checkbox(label="Enable CFG normalization", value=True)
use_en_prompt = gr.Checkbox(label="Auto caption language: True=EN, False=ZH", value=True)
# Free-text value (a Textbox, not a Number); it is passed as-is to both
# on_decompose_click and on_refine_click. Parsing and the 20..1500 limit named
# in the label are not enforced here — presumably done in the handlers; confirm.
gpu_duration = gr.Textbox(
label="GPU duration override (seconds, 20..1500)",
value="1000",
lines=1,
placeholder="e.g. 120, 300, 1000, 1500",
)
btn_decompose = gr.Button("Decompose!", variant="primary")
# Refine controls: sub_layers is only an input of on_refine_click.
with gr.Group():
gr.Markdown("### Refine (Recursive Decomposition)")
sub_layers = gr.Slider(label="Sub-layers (Refine)", minimum=2, maximum=10, step=1, value=3)
btn_refine = gr.Button("Refine selected layer", variant="primary")
# Result widgets. `gallery` and `layer_picker` are both outputs of the same
# handlers; only `layer_picker` has a select listener attached.
with gr.Column(scale=2):
gr.Markdown("### Current node layers")
gallery = gr.Gallery(label="Layers", columns=4, rows=1, format="png")
gr.Markdown("### Layer picker (Photoshop-style)")
layer_picker = gr.Gallery(label="Pick a layer", columns=8, rows=1, format="png")
with gr.Row():
layer_dropdown = gr.Dropdown(label="Refine layer", choices=[], value=None)
# Not user-editable (interactive=False): written by the layer-pick handlers and
# read back as an input of on_refine_click.
selected_layer_idx = gr.Number(label="Selected layer index (0-based)", value=0, precision=0, interactive=False)
selected_layer_label = gr.Markdown("Selected: -")
# Created hidden (visible=False); it is an output of the node-view handlers,
# which lets them change its visibility.
with gr.Accordion("Refined layers (last refine)", open=True, visible=False) as refined_block:
refined_gallery = gr.Gallery(label="Refined layers", columns=4, rows=1, format="png")
# History navigation and node management widgets.
gr.Markdown("### History (nodes)")
with gr.Row():
history_dd = gr.Dropdown(label="Node id", choices=[], value=None)
chips_md = gr.Markdown("[root] [parent:-] [children:0]")
with gr.Row():
btn_back_parent = gr.Button("← back to parent", variant="secondary")
btn_duplicate = gr.Button("Duplicate node (branch)", variant="secondary")
with gr.Row():
rename_text = gr.Textbox(label="Branch name", value="", lines=1, placeholder="Type new name and click Rename")
btn_rename = gr.Button("Rename", variant="secondary")
# Export / persistence actions and their download slots.
with gr.Row():
btn_export = gr.Button("Export selected node (ZIP/PPTX)", variant="primary")
btn_save = gr.Button("Save selected node to dataset", variant="primary")
with gr.Row():
export_pptx = gr.File(label="Download PPTX")
export_zip = gr.File(label="Download ZIP")
status = gr.Markdown("")
seed_used = gr.Textbox(label="Seed used", value="", interactive=False)
# Example images feed input_image only. With run_on_click=False and
# cache_examples=False the placeholder fn (which returns empty outputs) is
# neither run on click nor pre-computed.
gr.Examples(
examples=examples,
inputs=[input_image],
outputs=[gallery, export_pptx, export_zip],
fn=lambda img: ([], None, None),
cache_examples=False,
run_on_click=False,
)
# ---- Event wiring ----
# Most handlers below share the same 13 outputs, in this order: state, gallery,
# layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
# history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip,
# status. Decompose and refine append seed_used as a 14th output.
btn_init_ds.click(fn=on_init_dataset, outputs=[ds_status])
btn_refresh_sessions.click(fn=on_refresh_sessions, outputs=[load_session_dd, ds_status])
btn_load_session.click(
fn=on_load_session,
inputs=[state, load_session_dd],
outputs=[
state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status
],
)
btn_decompose.click(
fn=on_decompose_click,
inputs=[
state, input_image, seed, randomize_seed, prompt, neg_prompt, true_guidance_scale,
num_inference_steps, layer, cfg_norm, use_en_prompt, resolution, gpu_duration
],
outputs=[
state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status, seed_used
],
)
# Only `state` is listed as an input here.
# NOTE(review): presumably on_layer_pick_from_gallery receives the clicked index
# through a gr.SelectData parameter — confirm against its definition.
layer_picker.select(
fn=on_layer_pick_from_gallery,
inputs=[state],
outputs=[selected_layer_idx, layer_dropdown, selected_layer_label],
)
layer_dropdown.change(
fn=on_layer_pick_from_dropdown,
inputs=[state, layer_dropdown],
outputs=[selected_layer_idx, selected_layer_label],
)
# Note the input order differs from decompose: seed and randomize_seed come
# last, and sub_layers takes the place of `layer`.
btn_refine.click(
fn=on_refine_click,
inputs=[
state, selected_layer_idx, sub_layers, prompt, neg_prompt, true_guidance_scale,
num_inference_steps, cfg_norm, use_en_prompt, resolution, gpu_duration, seed, randomize_seed
],
outputs=[
state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status, seed_used
],
)
# Runs on_history_need_images first so the node's images are in state, then
# delegates to on_history_select. The status message returned by the lazy load
# is discarded; only on_history_select's outputs reach the UI.
def _history_select_with_lazy(state_, node_id_):
state_, _ = on_history_need_images(state_, node_id_)
return on_history_select(state_, node_id_)
# NOTE(review): history_dd is both the trigger and one of the outputs of this
# listener — confirm that the handler's update cannot re-fire `change` in a loop.
history_dd.change(
fn=_history_select_with_lazy,
inputs=[state, history_dd],
outputs=[
state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status
],
)
btn_back_parent.click(
fn=on_back_to_parent,
inputs=[state],
outputs=[
state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status
],
)
btn_duplicate.click(
fn=on_duplicate_node,
inputs=[state],
outputs=[
state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status
],
)
# Rename, export and save update only a subset of the widgets.
btn_rename.click(fn=on_rename_node, inputs=[state, rename_text], outputs=[state, history_dd, chips_md, status])
btn_export.click(fn=on_export_selected, inputs=[state], outputs=[export_pptx, export_zip, status])
btn_save.click(fn=on_save_current, inputs=[state], outputs=[status])
# Turn on Gradio's queue for the app. This call sits before the __main__ guard,
# so it also runs when the module is imported rather than executed.
demo.queue()
# Start the server only when this file is run as a script.
if __name__ == "__main__":
demo.launch()