import os import io import json import time import uuid import random import tempfile import zipfile from dataclasses import dataclass, asdict from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch import gradio as gr import spaces from PIL import Image from pptx import Presentation from diffusers import QwenImageLayeredPipeline from huggingface_hub import HfApi, login from huggingface_hub.utils import HfHubHTTPError LOG_DIR = "/tmp/local" MAX_SEED = np.iinfo(np.int32).max # ------------------------- # HF auth (Spaces secrets) # ------------------------- def _get_hf_token() -> Optional[str]: # priority: HF_TOKEN -> hf -> HUGGINGFACEHUB_API_TOKEN return ( os.environ.get("HF_TOKEN") or os.environ.get("hf") or os.environ.get("HUGGINGFACEHUB_API_TOKEN") ) def _get_dataset_repo() -> Optional[str]: # priority: DATASET_REPO -> HF_DATASET_REPO return os.environ.get("DATASET_REPO") or os.environ.get("HF_DATASET_REPO") HF_TOKEN = _get_hf_token() DATASET_REPO = _get_dataset_repo() if HF_TOKEN: try: login(token=HF_TOKEN) except Exception as e: print("HF login failed:", repr(e)) # ------------------------- # Helpers # ------------------------- def ensure_dirname(path: str): if path and not os.path.exists(path): os.makedirs(path, exist_ok=True) def px_to_emu(px, dpi=96): inch = px / dpi emu = inch * 914400 return int(emu) def imagelist_to_pptx_from_pils(images: List[Image.Image]) -> str: if not images: raise ValueError("No images to export") w, h = images[0].size prs = Presentation() prs.slide_width = px_to_emu(w) prs.slide_height = px_to_emu(h) slide = prs.slides.add_slide(prs.slide_layouts[6]) left = top = 0 for img in images: tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) # /tmp img.save(tmp.name) slide.shapes.add_picture( tmp.name, left, top, width=px_to_emu(w), height=px_to_emu(h), ) out = tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) # /tmp prs.save(out.name) return out.name def imagelist_to_zip_from_pils(images: 
List[Image.Image], prefix: str = "layer") -> str: outzip = tempfile.NamedTemporaryFile(suffix=".zip", delete=False) # /tmp with zipfile.ZipFile(outzip.name, "w", zipfile.ZIP_DEFLATED) as zipf: for i, img in enumerate(images): buf = io.BytesIO() img.save(buf, format="PNG") zipf.writestr(f"{prefix}_{i+1}.png", buf.getvalue()) return outzip.name def _clamp_int(x, default: int, lo: int, hi: int) -> int: try: v = int(x) except Exception: v = default return max(lo, min(hi, v)) def _normalize_resolution(resolution: Any) -> int: resolution = _clamp_int(resolution, default=640, lo=640, hi=1024) if resolution not in (640, 1024): resolution = 640 return resolution def _normalize_input_image(input_image: Any) -> Image.Image: if isinstance(input_image, list): input_image = input_image[0] if isinstance(input_image, str): pil_image = Image.open(input_image).convert("RGB").convert("RGBA") elif isinstance(input_image, Image.Image): pil_image = input_image.convert("RGB").convert("RGBA") elif isinstance(input_image, np.ndarray): pil_image = Image.fromarray(input_image).convert("RGB").convert("RGBA") else: raise ValueError(f"Unsupported input_image type: {type(input_image)}") return pil_image # ------------------------- # Dataset persistence helpers # ------------------------- def ds_enabled() -> bool: return bool(_get_hf_token()) and bool(_get_dataset_repo()) def ds_api() -> HfApi: token = _get_hf_token() if not token: raise RuntimeError("HF token missing") return HfApi(token=token) def ds_repo_id() -> str: repo = _get_dataset_repo() if not repo: raise RuntimeError("DATASET_REPO/HF_DATASET_REPO missing") return repo def ds_ensure_repo() -> Tuple[bool, str]: if not ds_enabled(): return False, "Dataset persistence disabled: missing HF token and/or dataset repo env." 
api = ds_api() repo_id = ds_repo_id() try: api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True) return True, f"Dataset repo ready: {repo_id}" except HfHubHTTPError as e: return False, f"Failed to create/ensure dataset repo: {e}" except Exception as e: return False, f"Failed to create/ensure dataset repo: {repr(e)}" def ds_upload_bytes(path_in_repo: str, data: bytes, commit_message: str) -> Tuple[bool, str]: if not ds_enabled(): return False, "Dataset persistence disabled: missing HF token and/or dataset repo env." api = ds_api() repo_id = ds_repo_id() try: with tempfile.NamedTemporaryFile(delete=False) as tmp: tmp.write(data) tmp.flush() api.upload_file( path_or_fileobj=tmp.name, path_in_repo=path_in_repo, repo_id=repo_id, repo_type="dataset", commit_message=commit_message, ) return True, f"Uploaded: {path_in_repo}" except HfHubHTTPError as e: # 403 "must use a write token" — это сюда return False, f"Upload failed (HTTP): {e}" except Exception as e: return False, f"Upload failed: {repr(e)}" def ds_download_bytes(path_in_repo: str) -> Tuple[Optional[bytes], str]: if not ds_enabled(): return None, "Dataset persistence disabled" api = ds_api() repo_id = ds_repo_id() try: tmpdir = tempfile.mkdtemp() local_path = api.hf_hub_download( repo_id=repo_id, repo_type="dataset", filename=path_in_repo, local_dir=tmpdir, ) with open(local_path, "rb") as f: return f.read(), "OK" except HfHubHTTPError as e: return None, f"Download failed (HTTP): {e}" except Exception as e: return None, f"Download failed: {repr(e)}" def _root_index_path() -> str: return "index.json" def ds_read_root_index() -> Dict[str, Any]: """ Root index.json (backward compatible): { "id": "", "last_session_id": "", "sessions": ["sess_...", ...], "updated_at": 123.0 } """ b, _ = ds_download_bytes(_root_index_path()) if b is None: return {"id": None, "last_session_id": None, "sessions": [], "updated_at": time.time()} try: obj = json.loads(b.decode("utf-8")) if "last_session_id" not in 
obj and "id" in obj: obj["last_session_id"] = obj.get("id") if "id" not in obj: obj["id"] = obj.get("last_session_id") if "sessions" not in obj or not isinstance(obj["sessions"], list): obj["sessions"] = [] return obj except Exception: return {"id": None, "last_session_id": None, "sessions": [], "updated_at": time.time()} def ds_write_root_index(last_session_id: Optional[str]) -> Tuple[bool, str]: idx = ds_read_root_index() idx["last_session_id"] = last_session_id idx["id"] = last_session_id # FIX: KeyError('id') idx["updated_at"] = time.time() if last_session_id: idx["sessions"] = [last_session_id] + [s for s in idx.get("sessions", []) if s != last_session_id] b = json.dumps(idx, ensure_ascii=False, indent=2).encode("utf-8") return ds_upload_bytes(_root_index_path(), b, f"update root index last_session_id={last_session_id}") def ds_list_sessions(max_sessions: int = 50) -> Tuple[List[str], str]: if not ds_enabled(): return [], "Dataset persistence disabled" api = ds_api() repo_id = ds_repo_id() try: # Prefer root index (fast) sess = [] try: root = ds_read_root_index() sess = [s for s in root.get("sessions", []) if isinstance(s, str)] except Exception: sess = [] # Fallback scan if not sess: files = api.list_repo_files(repo_id=repo_id, repo_type="dataset") found = set() for p in files: if p.startswith("sessions/") and (p.endswith("/index.json") or p.endswith("/session.json")): parts = p.split("/") if len(parts) >= 3: found.add(parts[1]) sess = sorted(found, reverse=True) sess = sess[:max_sessions] return sess, f"Found {len(sess)} session(s)" except Exception as e: return [], f"List sessions failed: {repr(e)}" # ------------------------- # Node / History model # ------------------------- @dataclass class NodeMeta: node_id: str name: str parent_id: Optional[str] children: List[str] op: str # "decompose" | "refine" | "duplicate" created_at: float source_node_id: Optional[str] = None source_layer_idx: Optional[int] = None sub_layers: Optional[int] = None settings: 
Optional[Dict[str, Any]] = None def _new_id(prefix: str) -> str: return f"{prefix}_{uuid.uuid4().hex[:10]}" def _make_chips(state: Dict[str, Any]) -> str: node_id = state.get("selected_node_id") nodes: Dict[str, Any] = state.get("nodes", {}) if not node_id or node_id not in nodes: return "[root] [parent:-] [children:0]" meta = nodes[node_id]["meta"] parent = meta.get("parent_id") or "-" children = meta.get("children") or [] root = state.get("root_node_id") or "-" return f"[root:{root}] [parent:{parent}] [children:{len(children)}]" def _history_choices(state: Dict[str, Any]) -> List[Tuple[str, str]]: nodes: Dict[str, Any] = state.get("nodes", {}) items = [] for nid, obj in nodes.items(): meta = obj["meta"] items.append((meta.get("created_at", 0.0), nid, meta.get("name", nid))) items.sort(key=lambda x: x[0]) return [(f"{name} — {nid}", nid) for _, nid, name in items] def _get_node_images(state: Dict[str, Any], node_id: str) -> List[Image.Image]: nodes: Dict[str, Any] = state.get("nodes", {}) if node_id not in nodes: return [] return nodes[node_id].get("images", []) or [] def _add_node( state: Dict[str, Any], *, name: str, parent_id: Optional[str], op: str, images: List[Image.Image], settings: Optional[Dict[str, Any]] = None, source_node_id: Optional[str] = None, source_layer_idx: Optional[int] = None, sub_layers: Optional[int] = None, ) -> str: node_id = _new_id("node") meta = NodeMeta( node_id=node_id, name=name, parent_id=parent_id, children=[], op=op, created_at=time.time(), source_node_id=source_node_id, source_layer_idx=source_layer_idx, sub_layers=sub_layers, settings=settings or {}, ) state.setdefault("nodes", {}) state["nodes"][node_id] = {"meta": asdict(meta), "images": images} if parent_id and parent_id in state["nodes"]: state["nodes"][parent_id]["meta"].setdefault("children", []) state["nodes"][parent_id]["meta"]["children"].append(node_id) return node_id def _rename_node(state: Dict[str, Any], node_id: str, new_name: str): if not new_name: return if 
node_id in state.get("nodes", {}): state["nodes"][node_id]["meta"]["name"] = new_name def _duplicate_node(state: Dict[str, Any], node_id: str) -> Optional[str]: if node_id not in state.get("nodes", {}): return None src = state["nodes"][node_id] meta = src["meta"] parent_id = meta.get("parent_id") images = src.get("images", []) name = f"{meta.get('name','node')} (copy)" return _add_node( state, name=name, parent_id=parent_id, op="duplicate", images=images, settings=meta.get("settings") or {}, ) # ------------------------- # GPU duration + GPU-only pipeline runner # IMPORTANT: pipeline init is INSIDE GPU worker (ZeroGPU friendly) # ------------------------- def get_duration(*args, **kwargs): # wrapper may pass random kwargs like pil_image_rgba etc; ignore gpu_duration = kwargs.get("gpu_duration", 1000) return _clamp_int(gpu_duration, default=1000, lo=20, hi=1500) _GPU_PIPE: Optional[QwenImageLayeredPipeline] = None def _gpu_get_pipe() -> QwenImageLayeredPipeline: global _GPU_PIPE if _GPU_PIPE is not None: return _GPU_PIPE # This function runs inside GPU worker (due to @spaces.GPU on caller) if not torch.cuda.is_available(): raise RuntimeError( "CUDA is not available inside GPU worker. " "Check Space hardware: it must be ZeroGPU/GPU, not CPU." 
) dtype = torch.bfloat16 _GPU_PIPE = QwenImageLayeredPipeline.from_pretrained( "Qwen/Qwen-Image-Layered", torch_dtype=dtype, ).to("cuda") return _GPU_PIPE @spaces.GPU(duration=get_duration) def gpu_run_pipeline( pil_image_rgba: Image.Image, seed=777, randomize_seed=False, prompt=None, neg_prompt=" ", true_guidance_scale=4.0, num_inference_steps=50, layer=4, cfg_norm=True, use_en_prompt=True, resolution=640, gpu_duration=1000, ): # Everything heavy here happens on GPU worker if randomize_seed: seed = random.randint(0, MAX_SEED) resolution = _normalize_resolution(resolution) if not torch.cuda.is_available(): raise RuntimeError("No CUDA GPUs are available (GPU worker not running).") generator = torch.Generator(device="cuda").manual_seed(int(seed)) inputs = { "image": pil_image_rgba, "generator": generator, "true_cfg_scale": float(true_guidance_scale), "prompt": prompt, "negative_prompt": neg_prompt, "num_inference_steps": int(num_inference_steps), "num_images_per_prompt": 1, "layers": int(layer), "resolution": int(resolution), "cfg_normalize": bool(cfg_norm), "use_en_prompt": bool(use_en_prompt), } # reduce allocator hiccups try: torch.cuda.empty_cache() except Exception: pass pipe = _gpu_get_pipe() with torch.inference_mode(): out = pipe(**inputs) output_images = out.images[0] # list of PIL layers return output_images, int(seed), inputs # ------------------------- # Dataset persistence: save/load nodes + session # ------------------------- def _pil_to_png_bytes(img: Image.Image) -> bytes: buf = io.BytesIO() img.save(buf, format="PNG") return buf.getvalue() def _png_bytes_to_pil(b: bytes) -> Image.Image: return Image.open(io.BytesIO(b)).convert("RGBA") def _session_base(session_id: str) -> str: return f"sessions/{session_id}" def _node_base(session_id: str, node_id: str) -> str: return f"{_session_base(session_id)}/nodes/{node_id}" def _persist_node_to_dataset(state: Dict[str, Any], node_id: str) -> Tuple[bool, str]: if not ds_enabled(): return False, "Dataset 
persistence disabled. Set DATASET_REPO and HF_TOKEN/hf." ok, msg = ds_ensure_repo() if not ok: return False, msg session_id = state.get("session_id") if not session_id: return False, "No session_id in state (run Decompose first)" nodes = state.get("nodes", {}) if node_id not in nodes: return False, "Unknown node_id" node = nodes[node_id] meta = node["meta"] imgs: List[Image.Image] = node.get("images", []) or [] node_json = json.dumps(meta, ensure_ascii=False, indent=2).encode("utf-8") path_node_json = f"{_node_base(session_id, node_id)}/node.json" ok1, msg1 = ds_upload_bytes(path_node_json, node_json, f"save node {node_id}") if not ok1: return False, msg1 for i, img in enumerate(imgs): b = _pil_to_png_bytes(img) path_img = f"{_node_base(session_id, node_id)}/layer_{i+1}.png" ok2, msg2 = ds_upload_bytes(path_img, b, f"save node {node_id} layer {i+1}") if not ok2: return False, msg2 return True, f"Saved node {node_id} to dataset" def _persist_session_manifest(state: Dict[str, Any]) -> Tuple[bool, str]: if not ds_enabled(): return False, "Dataset persistence disabled" ok, msg = ds_ensure_repo() if not ok: return False, msg session_id = state.get("session_id") if not session_id: return False, "No session_id" manifest = { "session_id": session_id, "created_at": state.get("created_at"), "root_node_id": state.get("root_node_id"), "selected_node_id": state.get("selected_node_id"), "nodes": { nid: {"meta": obj["meta"], "num_layers": len(obj.get("images", []) or [])} for nid, obj in (state.get("nodes", {}) or {}).items() }, } b = json.dumps(manifest, ensure_ascii=False, indent=2).encode("utf-8") # Save under sessions//index.json (как у тебя в датасете на скрине) ok1, msg1 = ds_upload_bytes(f"{_session_base(session_id)}/index.json", b, f"save session index {session_id}") if not ok1: return False, msg1 # Optional duplicate name for compatibility ok2, msg2 = ds_upload_bytes(f"{_session_base(session_id)}/session.json", b, f"save session manifest {session_id}") if not ok2: return 
False, msg2 # Root index.json (fix KeyError('id') + last session) ok3, msg3 = ds_write_root_index(session_id) if not ok3: return False, msg3 return True, "Saved session manifest + root index" def _load_session_manifest(session_id: str) -> Tuple[Optional[Dict[str, Any]], str]: for p in (f"{_session_base(session_id)}/index.json", f"{_session_base(session_id)}/session.json"): b, msg = ds_download_bytes(p) if b is None: continue try: return json.loads(b.decode("utf-8")), "OK" except Exception as e: return None, f"Failed to parse manifest: {repr(e)}" return None, f"Manifest not found for session {session_id}" def _load_node_images(session_id: str, node_id: str, num_layers: int) -> Tuple[List[Image.Image], str]: imgs: List[Image.Image] = [] for i in range(num_layers): b, msg = ds_download_bytes(f"{_node_base(session_id, node_id)}/layer_{i+1}.png") if b is None: return [], msg imgs.append(_png_bytes_to_pil(b)) return imgs, "OK" # ------------------------- # UI callbacks # ------------------------- def _init_state() -> Dict[str, Any]: return { "session_id": None, "created_at": None, "root_node_id": None, "selected_node_id": None, "nodes": {}, "last_refined_node_id": None, } def _persistence_status_text() -> str: tok = _get_hf_token() repo = _get_dataset_repo() if tok and repo: return f"✅ Dataset persistence enabled: `{repo}`" if repo and not tok: return "⚠️ Dataset repo set, but HF_TOKEN/hf missing" if tok and not repo: return "⚠️ HF_TOKEN/hf set, but DATASET_REPO missing" return "⚠️ Dataset persistence disabled (set HF_TOKEN + DATASET_REPO secrets to enable)" def on_refresh_sessions(): sessions, msg = ds_list_sessions() return gr.update(choices=sessions, value=(sessions[0] if sessions else None)), msg def on_init_dataset(): ok, msg = ds_ensure_repo() return msg def _current_node_export(state: Dict[str, Any], node_id: str) -> Tuple[Optional[str], Optional[str], str]: imgs = _get_node_images(state, node_id) if not imgs: return None, None, "No images to export" pptx_path = 
imagelist_to_pptx_from_pils(imgs) zip_path = imagelist_to_zip_from_pils(imgs, prefix=f"{node_id}_layer") return pptx_path, zip_path, "OK" def _build_layer_dropdown(n: int) -> Tuple[List[str], Optional[str]]: if n <= 0: return [], None choices = [f"Layer {i+1}" for i in range(n)] return choices, choices[0] def _layer_label(idx: int, n: int) -> str: if n <= 0: return "Selected: -" idx = max(0, min(n - 1, idx)) return f"Selected: Layer {idx+1} / {n}" def on_decompose_click( state: Dict[str, Any], input_image, seed, randomize_seed, prompt, neg_prompt, true_guidance_scale, num_inference_steps, layer, cfg_norm, use_en_prompt, resolution, gpu_duration, ): if state is None or not isinstance(state, dict): state = _init_state() pil_image = _normalize_input_image(input_image) if not state.get("session_id"): state["session_id"] = _new_id("sess") state["created_at"] = time.time() layers_out, used_seed, _used_inputs = gpu_run_pipeline( pil_image_rgba=pil_image, seed=seed, randomize_seed=randomize_seed, prompt=prompt, neg_prompt=neg_prompt, true_guidance_scale=true_guidance_scale, num_inference_steps=num_inference_steps, layer=layer, cfg_norm=cfg_norm, use_en_prompt=use_en_prompt, resolution=resolution, gpu_duration=gpu_duration, ) settings_snapshot = { "seed": used_seed, "randomize_seed": bool(randomize_seed), "prompt": prompt, "neg_prompt": neg_prompt, "true_guidance_scale": float(true_guidance_scale), "num_inference_steps": int(num_inference_steps), "layers": int(layer), "resolution": int(_normalize_resolution(resolution)), "cfg_norm": bool(cfg_norm), "use_en_prompt": bool(use_en_prompt), "gpu_duration": int(_clamp_int(gpu_duration, 1000, 20, 1500)), } state["nodes"] = {} state["last_refined_node_id"] = None root_id = _add_node( state, name="root (decompose)", parent_id=None, op="decompose", images=layers_out, settings=settings_snapshot, ) state["root_node_id"] = root_id state["selected_node_id"] = root_id n_layers = len(layers_out) layer_choices, layer_value = 
_build_layer_dropdown(n_layers) hist_choices = _history_choices(state) chips = _make_chips(state) selected_label = _layer_label(0, n_layers) refined_visible = gr.update(visible=False) refined_gallery = [] pptx_path, zip_path, exp_msg = _current_node_export(state, root_id) status = f"Decomposed into {n_layers} layer(s). Seed={used_seed}. {exp_msg}" return ( state, layers_out, layers_out, gr.update(choices=layer_choices, value=layer_value), gr.update(value=0), selected_label, gr.update(choices=[c[1] for c in hist_choices], value=root_id), chips, refined_visible, refined_gallery, pptx_path, zip_path, status, str(used_seed), ) def on_layer_pick_from_dropdown(state: Dict[str, Any], layer_name: str): node_id = state.get("selected_node_id") imgs = _get_node_images(state, node_id) if node_id else [] n = len(imgs) if not layer_name or not layer_name.startswith("Layer "): idx = 0 else: try: idx = int(layer_name.replace("Layer ", "").strip()) - 1 except Exception: idx = 0 idx = max(0, min(n - 1, idx)) if n > 0 else 0 return gr.update(value=idx), _layer_label(idx, n) def on_layer_pick_from_gallery(state: Dict[str, Any], evt: gr.SelectData): node_id = state.get("selected_node_id") imgs = _get_node_images(state, node_id) if node_id else [] n = len(imgs) idx = int(evt.index) if evt and evt.index is not None else 0 idx = max(0, min(n - 1, idx)) if n > 0 else 0 dd_value = f"Layer {idx+1}" if n > 0 else None return gr.update(value=idx), gr.update(value=dd_value), _layer_label(idx, n) def _refine_from_source( state: Dict[str, Any], source_node_id: str, source_layer_idx: int, sub_layers: int, prompt, neg_prompt, true_guidance_scale, num_inference_steps, cfg_norm, use_en_prompt, resolution, gpu_duration, seed, randomize_seed, ): src_imgs = _get_node_images(state, source_node_id) if not src_imgs: raise ValueError("Source node has no images") if source_layer_idx < 0 or source_layer_idx >= len(src_imgs): raise ValueError("Invalid layer index") selected_layer_img = 
src_imgs[source_layer_idx] layers_out, used_seed, _used_inputs = gpu_run_pipeline( pil_image_rgba=selected_layer_img, seed=seed, randomize_seed=randomize_seed, prompt=prompt, neg_prompt=neg_prompt, true_guidance_scale=true_guidance_scale, num_inference_steps=num_inference_steps, layer=sub_layers, cfg_norm=cfg_norm, use_en_prompt=use_en_prompt, resolution=resolution, gpu_duration=gpu_duration, ) settings_snapshot = { "seed": used_seed, "randomize_seed": bool(randomize_seed), "prompt": prompt, "neg_prompt": neg_prompt, "true_guidance_scale": float(true_guidance_scale), "num_inference_steps": int(num_inference_steps), "layers": int(sub_layers), "resolution": int(_normalize_resolution(resolution)), "cfg_norm": bool(cfg_norm), "use_en_prompt": bool(use_en_prompt), "gpu_duration": int(_clamp_int(gpu_duration, 1000, 20, 1500)), "refined_from": {"source_node_id": source_node_id, "source_layer_idx": int(source_layer_idx)}, } return layers_out, used_seed, settings_snapshot def on_refine_click( state: Dict[str, Any], selected_layer_idx: int, sub_layers: int, prompt, neg_prompt, true_guidance_scale, num_inference_steps, cfg_norm, use_en_prompt, resolution, gpu_duration, seed, randomize_seed, ): if not state.get("selected_node_id"): return ( state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), [], None, None, "No selected node. 
Run Decompose first.", gr.update() ) source_node_id = state["selected_node_id"] src_imgs = _get_node_images(state, source_node_id) if not src_imgs: return ( state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), [], None, None, "Selected node has no images.", gr.update() ) n = len(src_imgs) idx = int(selected_layer_idx) if selected_layer_idx is not None else 0 idx = max(0, min(n - 1, idx)) sub_layers = _clamp_int(sub_layers, default=3, lo=2, hi=10) layers_out, used_seed, settings_snapshot = _refine_from_source( state, source_node_id=source_node_id, source_layer_idx=idx, sub_layers=sub_layers, prompt=prompt, neg_prompt=neg_prompt, true_guidance_scale=true_guidance_scale, num_inference_steps=num_inference_steps, cfg_norm=cfg_norm, use_en_prompt=use_en_prompt, resolution=resolution, gpu_duration=gpu_duration, seed=seed, randomize_seed=randomize_seed, ) child_name = f"refine ({state['nodes'][source_node_id]['meta']['name']}) L{idx+1}" child_id = _add_node( state, name=child_name, parent_id=source_node_id, op="refine", images=layers_out, settings=settings_snapshot, source_node_id=source_node_id, source_layer_idx=idx, sub_layers=sub_layers, ) state["selected_node_id"] = child_id state["last_refined_node_id"] = child_id n_layers = len(layers_out) layer_choices, layer_value = _build_layer_dropdown(n_layers) hist_choices = _history_choices(state) chips = _make_chips(state) selected_label = _layer_label(0, n_layers) pptx_path, zip_path, exp_msg = _current_node_export(state, child_id) status = f"Refined into {n_layers} sub-layer(s). Seed={used_seed}. 
{exp_msg}" return ( state, layers_out, layers_out, gr.update(choices=layer_choices, value=layer_value), gr.update(value=0), selected_label, gr.update(choices=[c[1] for c in hist_choices], value=child_id), chips, gr.update(visible=True), layers_out, pptx_path, zip_path, status, gr.update(), ) def on_history_select(state: Dict[str, Any], node_id: str): if not node_id or node_id not in state.get("nodes", {}): return ( state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), [], None, None, "Unknown node." ) state["selected_node_id"] = node_id imgs = _get_node_images(state, node_id) n_layers = len(imgs) layer_choices, layer_value = _build_layer_dropdown(n_layers) hist_choices = _history_choices(state) chips = _make_chips(state) selected_label = _layer_label(0, n_layers) pptx_path, zip_path, exp_msg = _current_node_export(state, node_id) return ( state, imgs, imgs, gr.update(choices=layer_choices, value=layer_value), gr.update(value=0), selected_label, gr.update(choices=[c[1] for c in hist_choices], value=node_id), chips, gr.update(visible=False), [], pptx_path, zip_path, f"Selected node: {node_id}. {exp_msg}", ) def on_back_to_parent(state: Dict[str, Any]): node_id = state.get("selected_node_id") if not node_id or node_id not in state.get("nodes", {}): return state, gr.update(), "No selected node." parent = state["nodes"][node_id]["meta"].get("parent_id") if not parent: return state, gr.update(), "Already at root." return on_history_select(state, parent) def on_duplicate_node(state: Dict[str, Any]): node_id = state.get("selected_node_id") if not node_id: return state, gr.update(), "No selected node." new_id = _duplicate_node(state, node_id) if not new_id: return state, gr.update(), "Duplicate failed." return on_history_select(state, new_id) def on_rename_node(state: Dict[str, Any], new_name: str): node_id = state.get("selected_node_id") if not node_id: return state, gr.update(), "No selected node." 
_rename_node(state, node_id, new_name) hist_choices = _history_choices(state) chips = _make_chips(state) return state, gr.update(choices=[c[1] for c in hist_choices], value=node_id), chips, "Renamed." def on_export_selected(state: Dict[str, Any]): node_id = state.get("selected_node_id") if not node_id: return None, None, "No selected node." pptx_path, zip_path, msg = _current_node_export(state, node_id) return pptx_path, zip_path, msg def on_save_current(state: Dict[str, Any]): node_id = state.get("selected_node_id") if not node_id: return "Nothing to save." ok1, msg1 = _persist_node_to_dataset(state, node_id) if not ok1: return msg1 ok2, msg2 = _persist_session_manifest(state) if not ok2: return msg2 return f"✅ Saved node + session manifest. {msg1}" def on_load_session(state: Dict[str, Any], session_id: str): if not session_id: return ( state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), [], None, None, "Pick a session id." 
) manifest, msg = _load_session_manifest(session_id) if manifest is None: return ( state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), [], None, None, msg ) new_state = _init_state() new_state["session_id"] = manifest.get("session_id") or session_id new_state["created_at"] = manifest.get("created_at") new_state["root_node_id"] = manifest.get("root_node_id") new_state["selected_node_id"] = manifest.get("selected_node_id") or manifest.get("root_node_id") nodes_meta = manifest.get("nodes", {}) or {} for nid, obj in nodes_meta.items(): meta = obj.get("meta") or {} new_state["nodes"][nid] = {"meta": meta, "images": []} sel = new_state["selected_node_id"] if not sel or sel not in nodes_meta: sel = new_state["root_node_id"] new_state["selected_node_id"] = sel if sel and sel in nodes_meta: num_layers = int(nodes_meta[sel].get("num_layers", 0)) imgs, msg2 = _load_node_images(session_id, sel, num_layers) if not imgs: return ( new_state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), [], None, None, f"Loaded manifest but failed to load images: {msg2}" ) new_state["nodes"][sel]["images"] = imgs root = new_state.get("root_node_id") if root and root != sel and root in nodes_meta and not new_state["nodes"][root]["images"]: rl = int(nodes_meta[root].get("num_layers", 0)) rimgs, _ = _load_node_images(session_id, root, rl) if rimgs: new_state["nodes"][root]["images"] = rimgs imgs = _get_node_images(new_state, sel) if sel else [] n_layers = len(imgs) layer_choices, layer_value = _build_layer_dropdown(n_layers) hist_choices = _history_choices(new_state) chips = _make_chips(new_state) selected_label = _layer_label(0, n_layers) pptx_path, zip_path, exp_msg = _current_node_export(new_state, sel) if sel else (None, None, "No node") return ( new_state, imgs, imgs, gr.update(choices=layer_choices, value=layer_value), gr.update(value=0), 
selected_label, gr.update(choices=[c[1] for c in hist_choices], value=sel), chips, gr.update(visible=False), [], pptx_path, zip_path, f"Loaded session {session_id}. {exp_msg}", ) def on_history_need_images(state: Dict[str, Any], node_id: str): if not node_id or node_id not in state.get("nodes", {}): return state, "Unknown node." imgs = state["nodes"][node_id].get("images", []) if imgs: return state, "OK" session_id = state.get("session_id") if not session_id: return state, "No session_id." manifest, msg = _load_session_manifest(session_id) if not manifest: return state, f"Cannot load manifest: {msg}" node_obj = (manifest.get("nodes", {}) or {}).get(node_id, {}) num_layers = int(node_obj.get("num_layers", 0)) if num_layers <= 0: return state, "No layers in manifest for this node." imgs2, msg2 = _load_node_images(session_id, node_id, num_layers) if not imgs2: return state, f"Failed to load images: {msg2}" state["nodes"][node_id]["images"] = imgs2 return state, "Loaded images." # ------------------------- # Build UI # ------------------------- ensure_dirname(LOG_DIR) examples = [ "assets/test_images/1.png", "assets/test_images/2.png", "assets/test_images/3.png", "assets/test_images/4.png", "assets/test_images/5.png", "assets/test_images/6.png", "assets/test_images/7.png", "assets/test_images/8.png", "assets/test_images/9.png", "assets/test_images/10.png", "assets/test_images/11.png", "assets/test_images/12.png", "assets/test_images/13.png", ] with gr.Blocks() as demo: state = gr.State(_init_state()) gr.HTML( '' ) persistence_banner = gr.Markdown(_persistence_status_text()) with gr.Row(): btn_init_ds = gr.Button("Init dataset repo", variant="secondary") btn_refresh_sessions = gr.Button("Refresh sessions", variant="secondary") ds_status = gr.Markdown("") with gr.Row(): load_session_dd = gr.Dropdown( label="Load session (from dataset)", choices=[], value=None, allow_custom_value=True, ) btn_load_session = gr.Button("Load session", variant="primary") gr.Markdown( """ The 
# -------------------------
# Build UI
# -------------------------
# NOTE(review): layout nesting below is reconstructed from whitespace-collapsed
# text; confirm Row/Column membership against the rendered UI.
ensure_dirname(LOG_DIR)

# Bundled sample inputs: assets/test_images/1.png .. 13.png
examples = [f"assets/test_images/{idx}.png" for idx in range(1, 14)]

with gr.Blocks() as demo:
    state = gr.State(_init_state())

    gr.HTML('')

    persistence_banner = gr.Markdown(_persistence_status_text())

    # --- Dataset persistence controls ---
    with gr.Row():
        btn_init_ds = gr.Button("Init dataset repo", variant="secondary")
        btn_refresh_sessions = gr.Button("Refresh sessions", variant="secondary")
    ds_status = gr.Markdown("")

    with gr.Row():
        load_session_dd = gr.Dropdown(
            label="Load session (from dataset)",
            choices=[],
            value=None,
            allow_custom_value=True,
        )
        btn_load_session = gr.Button("Load session", variant="primary")

    gr.Markdown(
        """ The text prompt describes the overall content of the input image. It is not designed to control the semantic content of individual layers explicitly. """
    )

    with gr.Row():
        # --- Left column: input image, generation settings, refine controls ---
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", image_mode="RGBA")

            with gr.Accordion("Advanced Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt (Optional)",
                    placeholder="Describe the image (optional)",
                    value="",
                    lines=2,
                )
                neg_prompt = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="Negative prompt",
                    value=" ",
                    lines=2,
                )

                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                true_guidance_scale = gr.Slider(
                    label="True guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=4.0
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps", minimum=1, maximum=100, step=1, value=50  # DO NOT CHANGE
                )

                layer = gr.Slider(label="Layers", minimum=2, maximum=10, step=1, value=7)  # DO NOT CHANGE
                resolution = gr.Radio(label="Processing resolution", choices=[640, 1024], value=640)  # DO NOT CHANGE

                cfg_norm = gr.Checkbox(label="Enable CFG normalization", value=True)
                use_en_prompt = gr.Checkbox(label="Auto caption language: True=EN, False=ZH", value=True)

                gpu_duration = gr.Textbox(
                    label="GPU duration override (seconds, 20..1500)",
                    value="1000",
                    lines=1,
                    placeholder="e.g. 120, 300, 1000, 1500",
                )

            btn_decompose = gr.Button("Decompose!", variant="primary")

            with gr.Group():
                gr.Markdown("### Refine (Recursive Decomposition)")
                sub_layers = gr.Slider(label="Sub-layers (Refine)", minimum=2, maximum=10, step=1, value=3)
                btn_refine = gr.Button("Refine selected layer", variant="primary")

        # --- Right column: layer views, history navigation, export ---
        with gr.Column(scale=2):
            gr.Markdown("### Current node layers")
            gallery = gr.Gallery(label="Layers", columns=4, rows=1, format="png")

            gr.Markdown("### Layer picker (Photoshop-style)")
            layer_picker = gr.Gallery(label="Pick a layer", columns=8, rows=1, format="png")

            with gr.Row():
                layer_dropdown = gr.Dropdown(label="Refine layer", choices=[], value=None)
                selected_layer_idx = gr.Number(label="Selected layer index (0-based)", value=0, precision=0, interactive=False)
            selected_layer_label = gr.Markdown("Selected: -")

            with gr.Accordion("Refined layers (last refine)", open=True, visible=False) as refined_block:
                refined_gallery = gr.Gallery(label="Refined layers", columns=4, rows=1, format="png")

            gr.Markdown("### History (nodes)")
            with gr.Row():
                history_dd = gr.Dropdown(label="Node id", choices=[], value=None)
            chips_md = gr.Markdown("[root] [parent:-] [children:0]")

            with gr.Row():
                btn_back_parent = gr.Button("← back to parent", variant="secondary")
                btn_duplicate = gr.Button("Duplicate node (branch)", variant="secondary")

            with gr.Row():
                rename_text = gr.Textbox(label="Branch name", value="", lines=1, placeholder="Type new name and click Rename")
                btn_rename = gr.Button("Rename", variant="secondary")

            with gr.Row():
                btn_export = gr.Button("Export selected node (ZIP/PPTX)", variant="primary")
                btn_save = gr.Button("Save selected node to dataset", variant="primary")

            with gr.Row():
                export_pptx = gr.File(label="Download PPTX")
                export_zip = gr.File(label="Download ZIP")

            status = gr.Markdown("")
            seed_used = gr.Textbox(label="Seed used", value="", interactive=False)

    gr.Examples(
        examples=examples,
        inputs=[input_image],
        outputs=[gallery, export_pptx, export_zip],
        fn=lambda img: ([], None, None),
        cache_examples=False,
        run_on_click=False,
    )

    # Components refreshed by every handler that redraws the current node view.
    view_outputs = [
        state,
        gallery,
        layer_picker,
        layer_dropdown,
        selected_layer_idx,
        selected_layer_label,
        history_dd,
        chips_md,
        refined_block,
        refined_gallery,
        export_pptx,
        export_zip,
        status,
    ]
    # Generation handlers additionally report the seed that was used.
    run_outputs = [*view_outputs, seed_used]

    # --- Event wiring (each handler gets its own copy of the output list) ---
    btn_init_ds.click(fn=on_init_dataset, outputs=[ds_status])
    btn_refresh_sessions.click(fn=on_refresh_sessions, outputs=[load_session_dd, ds_status])

    btn_load_session.click(
        fn=on_load_session,
        inputs=[state, load_session_dd],
        outputs=list(view_outputs),
    )

    btn_decompose.click(
        fn=on_decompose_click,
        inputs=[
            state,
            input_image,
            seed,
            randomize_seed,
            prompt,
            neg_prompt,
            true_guidance_scale,
            num_inference_steps,
            layer,
            cfg_norm,
            use_en_prompt,
            resolution,
            gpu_duration,
        ],
        outputs=list(run_outputs),
    )

    layer_picker.select(
        fn=on_layer_pick_from_gallery,
        inputs=[state],
        outputs=[selected_layer_idx, layer_dropdown, selected_layer_label],
    )

    layer_dropdown.change(
        fn=on_layer_pick_from_dropdown,
        inputs=[state, layer_dropdown],
        outputs=[selected_layer_idx, selected_layer_label],
    )

    btn_refine.click(
        fn=on_refine_click,
        inputs=[
            state,
            selected_layer_idx,
            sub_layers,
            prompt,
            neg_prompt,
            true_guidance_scale,
            num_inference_steps,
            cfg_norm,
            use_en_prompt,
            resolution,
            gpu_duration,
            seed,
            randomize_seed,
        ],
        outputs=list(run_outputs),
    )

    def _history_select_with_lazy(current_state, chosen_node):
        # Make sure the node's images are present before rendering it; the
        # status message from the lazy load is discarded.
        current_state, _ = on_history_need_images(current_state, chosen_node)
        return on_history_select(current_state, chosen_node)

    history_dd.change(
        fn=_history_select_with_lazy,
        inputs=[state, history_dd],
        outputs=list(view_outputs),
    )

    btn_back_parent.click(
        fn=on_back_to_parent,
        inputs=[state],
        outputs=list(view_outputs),
    )

    btn_duplicate.click(
        fn=on_duplicate_node,
        inputs=[state],
        outputs=list(view_outputs),
    )

    btn_rename.click(fn=on_rename_node, inputs=[state, rename_text], outputs=[state, history_dd, chips_md, status])
    btn_export.click(fn=on_export_selected, inputs=[state], outputs=[export_pptx, export_zip, status])
    btn_save.click(fn=on_save_current, inputs=[state], outputs=[status])

demo.queue()

if __name__ == "__main__":
    demo.launch()