Spaces: Running on Zero
| import os | |
| import io | |
| import json | |
| import time | |
| import uuid | |
| import random | |
| import tempfile | |
| import zipfile | |
| from dataclasses import dataclass, asdict | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from PIL import Image | |
| from pptx import Presentation | |
| from diffusers import QwenImageLayeredPipeline | |
| from huggingface_hub import HfApi, login | |
| from huggingface_hub.utils import HfHubHTTPError | |
LOG_DIR = "/tmp/local"  # scratch directory for local logs (unused in the visible code — TODO confirm)
MAX_SEED = np.iinfo(np.int32).max  # largest RNG seed accepted by the UI (int32 max)
| # ------------------------- | |
| # HF auth (Spaces secrets) | |
| # ------------------------- | |
| def _get_hf_token() -> Optional[str]: | |
| # priority: HF_TOKEN -> hf -> HUGGINGFACEHUB_API_TOKEN | |
| return ( | |
| os.environ.get("HF_TOKEN") | |
| or os.environ.get("hf") | |
| or os.environ.get("HUGGINGFACEHUB_API_TOKEN") | |
| ) | |
| def _get_dataset_repo() -> Optional[str]: | |
| # priority: DATASET_REPO -> HF_DATASET_REPO | |
| return os.environ.get("DATASET_REPO") or os.environ.get("HF_DATASET_REPO") | |
# Resolve credentials once at import time so module-level constants reflect
# the environment at startup.
HF_TOKEN = _get_hf_token()
DATASET_REPO = _get_dataset_repo()
if HF_TOKEN:
    try:
        # Authenticate the process-wide huggingface_hub session; failure is
        # logged but non-fatal so the UI still starts without persistence.
        login(token=HF_TOKEN)
    except Exception as e:
        print("HF login failed:", repr(e))
| # ------------------------- | |
| # Helpers | |
| # ------------------------- | |
def ensure_dirname(path: str) -> None:
    """Create *path* (including parents) if missing; no-op for an empty path.

    os.makedirs(exist_ok=True) already tolerates an existing directory, so the
    previous os.path.exists() pre-check was a redundant TOCTOU race and has
    been removed.
    """
    if path:
        os.makedirs(path, exist_ok=True)
def px_to_emu(px, dpi=96):
    """Convert pixels to EMU (English Metric Units; 914400 EMU per inch)."""
    return int(px / dpi * 914400)
def imagelist_to_pptx_from_pils(images: List[Image.Image]) -> str:
    """Export *images* as a one-slide PPTX with every layer stacked full-bleed.

    All layers are placed at (0, 0) on a single blank slide, so they overlay
    each other in z-order. Returns the path of a temp .pptx file (caller owns
    cleanup). Raises ValueError when *images* is empty.

    Fix: the old code wrote one NamedTemporaryFile(delete=False) per layer and
    never closed or unlinked them, leaking a file + fd per layer per export.
    add_picture accepts a file-like object, so layers are now fed from memory.
    """
    if not images:
        raise ValueError("No images to export")
    w, h = images[0].size  # slide size is taken from the first layer
    prs = Presentation()
    prs.slide_width = px_to_emu(w)
    prs.slide_height = px_to_emu(h)
    slide = prs.slides.add_slide(prs.slide_layouts[6])  # layout 6 = blank
    left = top = 0
    for img in images:
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        buf.seek(0)
        slide.shapes.add_picture(
            buf,
            left,
            top,
            width=px_to_emu(w),
            height=px_to_emu(h),
        )
    # mkstemp instead of an open NamedTemporaryFile: pptx re-opens by path,
    # so holding the fd would just leak it.
    out_fd, out_path = tempfile.mkstemp(suffix=".pptx")  # /tmp
    os.close(out_fd)
    prs.save(out_path)
    return out_path
def imagelist_to_zip_from_pils(images: List[Image.Image], prefix: str = "layer") -> str:
    """Zip *images* as PNGs named f"{prefix}_<i>.png"; return the temp zip path.

    Fix: the old code kept the NamedTemporaryFile(delete=False) handle open
    forever, leaking an fd per export; mkstemp + immediate close avoids that
    (ZipFile re-opens the file by path).
    """
    fd, zip_path = tempfile.mkstemp(suffix=".zip")  # /tmp; caller owns cleanup
    os.close(fd)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for i, img in enumerate(images):
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            zipf.writestr(f"{prefix}_{i+1}.png", buf.getvalue())
    return zip_path
| def _clamp_int(x, default: int, lo: int, hi: int) -> int: | |
| try: | |
| v = int(x) | |
| except Exception: | |
| v = default | |
| return max(lo, min(hi, v)) | |
def _normalize_resolution(resolution: Any) -> int:
    """Coerce *resolution* to one of the two supported sizes, 640 or 1024.

    Anything unparsable or in between falls back to 640.
    """
    value = _clamp_int(resolution, default=640, lo=640, hi=1024)
    return value if value in (640, 1024) else 640
def _normalize_input_image(input_image: Any) -> Image.Image:
    """Coerce a path / PIL image / ndarray (or a one-element list of those) to RGBA PIL.

    The RGB->RGBA round-trip discards any incoming alpha and re-adds an opaque
    alpha channel.

    Raises ValueError for unsupported input types.
    """
    if isinstance(input_image, list):
        input_image = input_image[0]  # galleries may deliver a list
    if isinstance(input_image, str):
        pil = Image.open(input_image)
    elif isinstance(input_image, Image.Image):
        pil = input_image
    elif isinstance(input_image, np.ndarray):
        pil = Image.fromarray(input_image)
    else:
        raise ValueError(f"Unsupported input_image type: {type(input_image)}")
    return pil.convert("RGB").convert("RGBA")
| # ------------------------- | |
| # Dataset persistence helpers | |
| # ------------------------- | |
def ds_enabled() -> bool:
    """True when both an HF token and a dataset repo id are configured."""
    token = _get_hf_token()
    repo = _get_dataset_repo()
    return bool(token) and bool(repo)
def ds_api() -> HfApi:
    """Build an authenticated HfApi client; raise RuntimeError when no token is set."""
    token = _get_hf_token()
    if token:
        return HfApi(token=token)
    raise RuntimeError("HF token missing")
def ds_repo_id() -> str:
    """Return the configured dataset repo id; raise RuntimeError if unset."""
    repo = _get_dataset_repo()
    if repo:
        return repo
    raise RuntimeError("DATASET_REPO/HF_DATASET_REPO missing")
def ds_ensure_repo() -> Tuple[bool, str]:
    """Create the private dataset repo if it does not exist; return (ok, message)."""
    if not ds_enabled():
        return False, "Dataset persistence disabled: missing HF token and/or dataset repo env."
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)
    except HfHubHTTPError as e:
        return False, f"Failed to create/ensure dataset repo: {e}"
    except Exception as e:
        return False, f"Failed to create/ensure dataset repo: {repr(e)}"
    return True, f"Dataset repo ready: {repo_id}"
def ds_upload_bytes(path_in_repo: str, data: bytes, commit_message: str) -> Tuple[bool, str]:
    """Upload *data* to the dataset repo at *path_in_repo*; return (ok, message).

    Fix: the staging temp file is now unlinked after the upload — the old code
    used NamedTemporaryFile(delete=False) and never removed it, so every
    upload left a file behind in /tmp.
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled: missing HF token and/or dataset repo env."
    api = ds_api()
    repo_id = ds_repo_id()
    tmp_path = None
    try:
        # Stage to disk first: upload_file wants a path or file object.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp.write(data)
            tmp.flush()
            tmp_path = tmp.name
        api.upload_file(
            path_or_fileobj=tmp_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
        return True, f"Uploaded: {path_in_repo}"
    except HfHubHTTPError as e:
        # e.g. 403 "must use a write token" lands here
        return False, f"Upload failed (HTTP): {e}"
    except Exception as e:
        return False, f"Upload failed: {repr(e)}"
    finally:
        if tmp_path:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best-effort cleanup
def ds_download_bytes(path_in_repo: str) -> Tuple[Optional[bytes], str]:
    """Download one file from the dataset repo; return (bytes_or_None, message).

    Fix: the scratch download directory is now removed when done — the old
    code called tempfile.mkdtemp() and left the directory (and the downloaded
    file) behind on every call.
    """
    if not ds_enabled():
        return None, "Dataset persistence disabled"
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            local_path = api.hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=path_in_repo,
                local_dir=tmpdir,
            )
            # Read the payload into memory before the directory is deleted.
            with open(local_path, "rb") as f:
                return f.read(), "OK"
    except HfHubHTTPError as e:
        return None, f"Download failed (HTTP): {e}"
    except Exception as e:
        return None, f"Download failed: {repr(e)}"
| def _root_index_path() -> str: | |
| return "index.json" | |
def ds_read_root_index() -> Dict[str, Any]:
    """
    Root index.json (backward compatible):
    {
        "id": "<last_session_id>",
        "last_session_id": "<last_session_id>",
        "sessions": ["sess_...", ...],
        "updated_at": 123.0
    }
    """
    raw, _msg = ds_download_bytes(_root_index_path())
    if raw is None:
        return {"id": None, "last_session_id": None, "sessions": [], "updated_at": time.time()}
    try:
        obj = json.loads(raw.decode("utf-8"))
    except Exception:
        # Corrupt index: start fresh rather than crash.
        return {"id": None, "last_session_id": None, "sessions": [], "updated_at": time.time()}
    # Migrate legacy payloads: keep "id" and "last_session_id" mirrored.
    if "last_session_id" not in obj and "id" in obj:
        obj["last_session_id"] = obj.get("id")
    if "id" not in obj:
        obj["id"] = obj.get("last_session_id")
    if not isinstance(obj.get("sessions"), list):
        obj["sessions"] = []
    return obj
def ds_write_root_index(last_session_id: Optional[str]) -> Tuple[bool, str]:
    """Persist the root index, promoting *last_session_id* to the front of the list."""
    idx = ds_read_root_index()
    idx["last_session_id"] = last_session_id
    idx["id"] = last_session_id  # legacy readers key on "id" (KeyError fix)
    idx["updated_at"] = time.time()
    if last_session_id:
        others = [s for s in idx.get("sessions", []) if s != last_session_id]
        idx["sessions"] = [last_session_id, *others]
    payload = json.dumps(idx, ensure_ascii=False, indent=2).encode("utf-8")
    return ds_upload_bytes(_root_index_path(), payload, f"update root index last_session_id={last_session_id}")
def ds_list_sessions(max_sessions: int = 50) -> Tuple[List[str], str]:
    """List known session ids, capped at *max_sessions*.

    Tries the root index first (one small download); falls back to scanning
    the repo file tree for sessions/<id>/{index,session}.json markers.
    Returns ([], reason) when persistence is disabled or listing fails.
    """
    if not ds_enabled():
        return [], "Dataset persistence disabled"
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        # Prefer root index (fast)
        sess = []
        try:
            root = ds_read_root_index()
            sess = [s for s in root.get("sessions", []) if isinstance(s, str)]
        except Exception:
            sess = []
        # Fallback scan
        if not sess:
            files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
            found = set()
            for p in files:
                # A session exists iff a manifest file is present under it.
                if p.startswith("sessions/") and (p.endswith("/index.json") or p.endswith("/session.json")):
                    parts = p.split("/")
                    if len(parts) >= 3:
                        found.add(parts[1])  # "sessions/<id>/..." -> <id>
            sess = sorted(found, reverse=True)
        sess = sess[:max_sessions]
        return sess, f"Found {len(sess)} session(s)"
    except Exception as e:
        return [], f"List sessions failed: {repr(e)}"
| # ------------------------- | |
| # Node / History model | |
| # ------------------------- | |
@dataclass
class NodeMeta:
    """Metadata for one history node in the decompose/refine tree.

    Fix: the @dataclass decorator was missing, so the keyword construction
    NodeMeta(node_id=..., ...) and asdict(meta) used by _add_node raised at
    runtime (plain classes get neither a generated __init__ nor asdict
    support); dataclass/asdict were already imported at the top of the file.
    """
    node_id: str                                 # unique id ("node_<hex>")
    name: str                                    # display name shown in history
    parent_id: Optional[str]                     # None for the root node
    children: List[str]                          # child node ids
    op: str                                      # "decompose" | "refine" | "duplicate"
    created_at: float                            # unix timestamp
    source_node_id: Optional[str] = None         # refine: node the layer came from
    source_layer_idx: Optional[int] = None       # refine: index of the refined layer
    sub_layers: Optional[int] = None             # refine: requested sub-layer count
    settings: Optional[Dict[str, Any]] = None    # pipeline settings snapshot
| def _new_id(prefix: str) -> str: | |
| return f"{prefix}_{uuid.uuid4().hex[:10]}" | |
| def _make_chips(state: Dict[str, Any]) -> str: | |
| node_id = state.get("selected_node_id") | |
| nodes: Dict[str, Any] = state.get("nodes", {}) | |
| if not node_id or node_id not in nodes: | |
| return "[root] [parent:-] [children:0]" | |
| meta = nodes[node_id]["meta"] | |
| parent = meta.get("parent_id") or "-" | |
| children = meta.get("children") or [] | |
| root = state.get("root_node_id") or "-" | |
| return f"[root:{root}] [parent:{parent}] [children:{len(children)}]" | |
| def _history_choices(state: Dict[str, Any]) -> List[Tuple[str, str]]: | |
| nodes: Dict[str, Any] = state.get("nodes", {}) | |
| items = [] | |
| for nid, obj in nodes.items(): | |
| meta = obj["meta"] | |
| items.append((meta.get("created_at", 0.0), nid, meta.get("name", nid))) | |
| items.sort(key=lambda x: x[0]) | |
| return [(f"{name} — {nid}", nid) for _, nid, name in items] | |
def _get_node_images(state: Dict[str, Any], node_id: str) -> List[Image.Image]:
    """Return the stored layer images for *node_id* ([] when unknown or empty)."""
    node = state.get("nodes", {}).get(node_id)
    if node is None:
        return []
    return node.get("images", []) or []
def _add_node(
    state: Dict[str, Any],
    *,
    name: str,
    parent_id: Optional[str],
    op: str,
    images: List[Image.Image],
    settings: Optional[Dict[str, Any]] = None,
    source_node_id: Optional[str] = None,
    source_layer_idx: Optional[int] = None,
    sub_layers: Optional[int] = None,
) -> str:
    """Create a history node, link it under *parent_id* if known, and return its id."""
    node_id = _new_id("node")
    meta = NodeMeta(
        node_id=node_id,
        name=name,
        parent_id=parent_id,
        children=[],
        op=op,
        created_at=time.time(),
        source_node_id=source_node_id,
        source_layer_idx=source_layer_idx,
        sub_layers=sub_layers,
        settings=settings or {},
    )
    nodes = state.setdefault("nodes", {})
    nodes[node_id] = {"meta": asdict(meta), "images": images}
    parent = nodes.get(parent_id) if parent_id else None
    if parent is not None:
        # Register the new child on its parent so tree navigation works both ways.
        parent["meta"].setdefault("children", []).append(node_id)
    return node_id
| def _rename_node(state: Dict[str, Any], node_id: str, new_name: str): | |
| if not new_name: | |
| return | |
| if node_id in state.get("nodes", {}): | |
| state["nodes"][node_id]["meta"]["name"] = new_name | |
def _duplicate_node(state: Dict[str, Any], node_id: str) -> Optional[str]:
    """Clone *node_id* as a sibling named "<name> (copy)"; return the new id or None."""
    nodes = state.get("nodes", {})
    if node_id not in nodes:
        return None
    src = nodes[node_id]
    src_meta = src["meta"]
    # Images are shared by reference, matching the original behavior.
    return _add_node(
        state,
        name=f"{src_meta.get('name','node')} (copy)",
        parent_id=src_meta.get("parent_id"),
        op="duplicate",
        images=src.get("images", []),
        settings=src_meta.get("settings") or {},
    )
| # ------------------------- | |
| # GPU duration + GPU-only pipeline runner | |
| # IMPORTANT: pipeline init is INSIDE GPU worker (ZeroGPU friendly) | |
| # ------------------------- | |
def get_duration(*args, **kwargs):
    """ZeroGPU duration hook: clamp the requested gpu_duration into [20, 1500] seconds.

    The spaces wrapper forwards the decorated call's full args/kwargs here
    (pil_image_rgba etc.); only gpu_duration matters, the rest is ignored.
    """
    requested = kwargs.get("gpu_duration", 1000)
    return _clamp_int(requested, default=1000, lo=20, hi=1500)
# Lazily-created singleton pipeline; only ever populated inside the GPU worker.
_GPU_PIPE: Optional[QwenImageLayeredPipeline] = None
def _gpu_get_pipe() -> QwenImageLayeredPipeline:
    """Return the cached QwenImageLayeredPipeline, loading it on first use.

    Raises RuntimeError when CUDA is unavailable (i.e. when called outside a
    GPU worker or on CPU-only hardware).
    """
    global _GPU_PIPE
    if _GPU_PIPE is not None:
        return _GPU_PIPE
    # This function runs inside GPU worker (due to @spaces.GPU on caller)
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available inside GPU worker. "
            "Check Space hardware: it must be ZeroGPU/GPU, not CPU."
        )
    dtype = torch.bfloat16  # bf16 halves memory vs fp32 on supported GPUs
    _GPU_PIPE = QwenImageLayeredPipeline.from_pretrained(
        "Qwen/Qwen-Image-Layered",
        torch_dtype=dtype,
    ).to("cuda")
    return _GPU_PIPE
@spaces.GPU(duration=get_duration)
def gpu_run_pipeline(
    pil_image_rgba: Image.Image,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    layer=4,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
):
    """Run the layered-decomposition pipeline on the GPU worker.

    Returns (list_of_layer_PIL_images, used_seed, inputs_dict).

    Fix: the @spaces.GPU decorator was missing although everything around it
    assumes one (get_duration is a duration hook, _gpu_get_pipe's comment says
    "due to @spaces.GPU on caller", and no visible caller is decorated). On
    ZeroGPU hardware, without the decorator torch.cuda.is_available() is False
    in the main process and every call raised RuntimeError.
    """
    # Everything heavy here happens on GPU worker
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    resolution = _normalize_resolution(resolution)
    if not torch.cuda.is_available():
        raise RuntimeError("No CUDA GPUs are available (GPU worker not running).")
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    inputs = {
        "image": pil_image_rgba,
        "generator": generator,
        "true_cfg_scale": float(true_guidance_scale),
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "num_inference_steps": int(num_inference_steps),
        "num_images_per_prompt": 1,
        "layers": int(layer),
        "resolution": int(resolution),
        "cfg_normalize": bool(cfg_norm),
        "use_en_prompt": bool(use_en_prompt),
    }
    # reduce allocator hiccups
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass
    pipe = _gpu_get_pipe()
    with torch.inference_mode():
        out = pipe(**inputs)
    output_images = out.images[0]  # list of PIL layers
    return output_images, int(seed), inputs
| # ------------------------- | |
| # Dataset persistence: save/load nodes + session | |
| # ------------------------- | |
def _pil_to_png_bytes(img: Image.Image) -> bytes:
    """Encode a PIL image as PNG bytes."""
    with io.BytesIO() as buf:
        img.save(buf, format="PNG")
        return buf.getvalue()
def _png_bytes_to_pil(b: bytes) -> Image.Image:
    """Decode PNG bytes into an RGBA PIL image."""
    decoded = Image.open(io.BytesIO(b))
    return decoded.convert("RGBA")
| def _session_base(session_id: str) -> str: | |
| return f"sessions/{session_id}" | |
def _node_base(session_id: str, node_id: str) -> str:
    """Repo-relative folder of one node inside a session."""
    base = _session_base(session_id)
    return f"{base}/nodes/{node_id}"
def _persist_node_to_dataset(state: Dict[str, Any], node_id: str) -> Tuple[bool, str]:
    """Upload one node's metadata (node.json) and layer PNGs to the dataset repo.

    Returns (ok, message); aborts at the first failed upload.
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled. Set DATASET_REPO and HF_TOKEN/hf."
    ok, msg = ds_ensure_repo()
    if not ok:
        return False, msg
    session_id = state.get("session_id")
    if not session_id:
        return False, "No session_id in state (run Decompose first)"
    nodes = state.get("nodes", {})
    if node_id not in nodes:
        return False, "Unknown node_id"
    node = nodes[node_id]
    base = _node_base(session_id, node_id)
    # Metadata first, then one PNG per layer.
    meta_bytes = json.dumps(node["meta"], ensure_ascii=False, indent=2).encode("utf-8")
    ok, msg = ds_upload_bytes(f"{base}/node.json", meta_bytes, f"save node {node_id}")
    if not ok:
        return False, msg
    for i, img in enumerate(node.get("images", []) or []):
        ok, msg = ds_upload_bytes(
            f"{base}/layer_{i+1}.png",
            _pil_to_png_bytes(img),
            f"save node {node_id} layer {i+1}",
        )
        if not ok:
            return False, msg
    return True, f"Saved node {node_id} to dataset"
def _persist_session_manifest(state: Dict[str, Any]) -> Tuple[bool, str]:
    """Upload the session manifest (index.json + session.json) and the root index."""
    if not ds_enabled():
        return False, "Dataset persistence disabled"
    ok, msg = ds_ensure_repo()
    if not ok:
        return False, msg
    session_id = state.get("session_id")
    if not session_id:
        return False, "No session_id"
    manifest = {
        "session_id": session_id,
        "created_at": state.get("created_at"),
        "root_node_id": state.get("root_node_id"),
        "selected_node_id": state.get("selected_node_id"),
        "nodes": {
            nid: {"meta": obj["meta"], "num_layers": len(obj.get("images", []) or [])}
            for nid, obj in (state.get("nodes", {}) or {}).items()
        },
    }
    payload = json.dumps(manifest, ensure_ascii=False, indent=2).encode("utf-8")
    base = _session_base(session_id)
    # Primary manifest at sessions/<id>/index.json (matches the dataset layout).
    ok, msg = ds_upload_bytes(f"{base}/index.json", payload, f"save session index {session_id}")
    if not ok:
        return False, msg
    # Duplicate under session.json for compatibility with older readers.
    ok, msg = ds_upload_bytes(f"{base}/session.json", payload, f"save session manifest {session_id}")
    if not ok:
        return False, msg
    # Refresh root index.json (tracks last/all sessions; avoids the old KeyError('id')).
    ok, msg = ds_write_root_index(session_id)
    if not ok:
        return False, msg
    return True, "Saved session manifest + root index"
def _load_session_manifest(session_id: str) -> Tuple[Optional[Dict[str, Any]], str]:
    """Fetch and parse a session manifest, trying index.json then session.json."""
    base = _session_base(session_id)
    for candidate in (f"{base}/index.json", f"{base}/session.json"):
        raw, _msg = ds_download_bytes(candidate)
        if raw is None:
            continue  # try the next known manifest name
        try:
            return json.loads(raw.decode("utf-8")), "OK"
        except Exception as e:
            return None, f"Failed to parse manifest: {repr(e)}"
    return None, f"Manifest not found for session {session_id}"
def _load_node_images(session_id: str, node_id: str, num_layers: int) -> Tuple[List[Image.Image], str]:
    """Download a node's layer PNGs; returns ([], reason) on the first missing layer."""
    base = _node_base(session_id, node_id)
    images: List[Image.Image] = []
    for i in range(num_layers):
        raw, msg = ds_download_bytes(f"{base}/layer_{i+1}.png")
        if raw is None:
            return [], msg
        images.append(_png_bytes_to_pil(raw))
    return images, "OK"
| # ------------------------- | |
| # UI callbacks | |
| # ------------------------- | |
| def _init_state() -> Dict[str, Any]: | |
| return { | |
| "session_id": None, | |
| "created_at": None, | |
| "root_node_id": None, | |
| "selected_node_id": None, | |
| "nodes": {}, | |
| "last_refined_node_id": None, | |
| } | |
def _persistence_status_text() -> str:
    """Human-readable one-liner describing the persistence configuration."""
    has_token = bool(_get_hf_token())
    repo = _get_dataset_repo()
    if has_token and repo:
        return f"✅ Dataset persistence enabled: `{repo}`"
    if repo:
        return "⚠️ Dataset repo set, but HF_TOKEN/hf missing"
    if has_token:
        return "⚠️ HF_TOKEN/hf set, but DATASET_REPO missing"
    return "⚠️ Dataset persistence disabled (set HF_TOKEN + DATASET_REPO secrets to enable)"
def on_refresh_sessions():
    """Reload the session dropdown, preselecting the first session if any."""
    sessions, msg = ds_list_sessions()
    default = sessions[0] if sessions else None
    return gr.update(choices=sessions, value=default), msg
def on_init_dataset():
    """Ensure the dataset repo exists; return only the status message for the UI."""
    _ok, msg = ds_ensure_repo()
    return msg
def _current_node_export(state: Dict[str, Any], node_id: str) -> Tuple[Optional[str], Optional[str], str]:
    """Build PPTX + ZIP exports for a node; return (pptx_path, zip_path, message)."""
    images = _get_node_images(state, node_id)
    if not images:
        return None, None, "No images to export"
    pptx_path = imagelist_to_pptx_from_pils(images)
    zip_path = imagelist_to_zip_from_pils(images, prefix=f"{node_id}_layer")
    return pptx_path, zip_path, "OK"
| def _build_layer_dropdown(n: int) -> Tuple[List[str], Optional[str]]: | |
| if n <= 0: | |
| return [], None | |
| choices = [f"Layer {i+1}" for i in range(n)] | |
| return choices, choices[0] | |
| def _layer_label(idx: int, n: int) -> str: | |
| if n <= 0: | |
| return "Selected: -" | |
| idx = max(0, min(n - 1, idx)) | |
| return f"Selected: Layer {idx+1} / {n}" | |
def on_decompose_click(
    state: Dict[str, Any],
    input_image,
    seed,
    randomize_seed,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    layer,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
):
    """Run a fresh decomposition and rebuild the whole UI around the result.

    Starts a session lazily, wipes any previous node tree, stores the output
    as the root node, and returns the 14-tuple of component updates expected
    by the Decompose button wiring: (state, gallery, gallery, layer dropdown,
    layer slider, layer label, history dropdown, chips, refined-panel
    visibility, refined gallery, pptx path, zip path, status, seed string).
    """
    if state is None or not isinstance(state, dict):
        state = _init_state()
    pil_image = _normalize_input_image(input_image)
    # Lazily start a session on the first decompose.
    if not state.get("session_id"):
        state["session_id"] = _new_id("sess")
        state["created_at"] = time.time()
    layers_out, used_seed, _used_inputs = gpu_run_pipeline(
        pil_image_rgba=pil_image,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=layer,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )
    # Snapshot the effective (post-normalization) settings for reproducibility.
    settings_snapshot = {
        "seed": used_seed,
        "randomize_seed": bool(randomize_seed),
        "prompt": prompt,
        "neg_prompt": neg_prompt,
        "true_guidance_scale": float(true_guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "layers": int(layer),
        "resolution": int(_normalize_resolution(resolution)),
        "cfg_norm": bool(cfg_norm),
        "use_en_prompt": bool(use_en_prompt),
        "gpu_duration": int(_clamp_int(gpu_duration, 1000, 20, 1500)),
    }
    # A new decompose resets the tree: all previous nodes are discarded.
    state["nodes"] = {}
    state["last_refined_node_id"] = None
    root_id = _add_node(
        state,
        name="root (decompose)",
        parent_id=None,
        op="decompose",
        images=layers_out,
        settings=settings_snapshot,
    )
    state["root_node_id"] = root_id
    state["selected_node_id"] = root_id
    n_layers = len(layers_out)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    selected_label = _layer_label(0, n_layers)
    refined_visible = gr.update(visible=False)  # hide stale refine results
    refined_gallery = []
    pptx_path, zip_path, exp_msg = _current_node_export(state, root_id)
    status = f"Decomposed into {n_layers} layer(s). Seed={used_seed}. {exp_msg}"
    return (
        state,
        layers_out,
        layers_out,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        gr.update(choices=[c[1] for c in hist_choices], value=root_id),
        chips,
        refined_visible,
        refined_gallery,
        pptx_path,
        zip_path,
        status,
        str(used_seed),
    )
def on_layer_pick_from_dropdown(state: Dict[str, Any], layer_name: str):
    """Sync the layer slider and label when a layer is picked from the dropdown."""
    node_id = state.get("selected_node_id")
    imgs = _get_node_images(state, node_id) if node_id else []
    n = len(imgs)
    idx = 0
    if layer_name and layer_name.startswith("Layer "):
        try:
            # Dropdown labels look like "Layer 3" -> zero-based index 2.
            idx = int(layer_name.replace("Layer ", "").strip()) - 1
        except Exception:
            idx = 0
    idx = min(max(idx, 0), n - 1) if n > 0 else 0
    return gr.update(value=idx), _layer_label(idx, n)
def on_layer_pick_from_gallery(state: Dict[str, Any], evt: gr.SelectData):
    """Sync slider, dropdown and label when a gallery thumbnail is clicked."""
    node_id = state.get("selected_node_id")
    imgs = _get_node_images(state, node_id) if node_id else []
    n = len(imgs)
    idx = int(evt.index) if evt and evt.index is not None else 0
    idx = min(max(idx, 0), n - 1) if n > 0 else 0
    dd_value = f"Layer {idx+1}" if n > 0 else None
    return gr.update(value=idx), gr.update(value=dd_value), _layer_label(idx, n)
def _refine_from_source(
    state: Dict[str, Any],
    source_node_id: str,
    source_layer_idx: int,
    sub_layers: int,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
    seed,
    randomize_seed,
):
    """Re-decompose one layer of an existing node into *sub_layers* sub-layers.

    Returns (layer_images, used_seed, settings_snapshot); the snapshot embeds
    a "refined_from" record linking back to the source node and layer.

    Raises ValueError when the source node has no images or the layer index
    is out of range.
    """
    src_imgs = _get_node_images(state, source_node_id)
    if not src_imgs:
        raise ValueError("Source node has no images")
    if source_layer_idx < 0 or source_layer_idx >= len(src_imgs):
        raise ValueError("Invalid layer index")
    selected_layer_img = src_imgs[source_layer_idx]
    layers_out, used_seed, _used_inputs = gpu_run_pipeline(
        pil_image_rgba=selected_layer_img,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=sub_layers,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )
    # Snapshot the effective settings plus provenance of the refined layer.
    settings_snapshot = {
        "seed": used_seed,
        "randomize_seed": bool(randomize_seed),
        "prompt": prompt,
        "neg_prompt": neg_prompt,
        "true_guidance_scale": float(true_guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "layers": int(sub_layers),
        "resolution": int(_normalize_resolution(resolution)),
        "cfg_norm": bool(cfg_norm),
        "use_en_prompt": bool(use_en_prompt),
        "gpu_duration": int(_clamp_int(gpu_duration, 1000, 20, 1500)),
        "refined_from": {"source_node_id": source_node_id, "source_layer_idx": int(source_layer_idx)},
    }
    return layers_out, used_seed, settings_snapshot
def on_refine_click(
    state: Dict[str, Any],
    selected_layer_idx: int,
    sub_layers: int,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
    seed,
    randomize_seed,
):
    """Refine the currently selected layer of the selected node into a child node.

    Clamps the layer index and sub-layer count, runs the pipeline on that
    layer, attaches the result as a "refine" child, selects it, and returns
    the 14-tuple of component updates expected by the Refine button wiring
    (same layout as on_decompose_click, with the refined panel made visible).
    """
    if not state.get("selected_node_id"):
        return (
            state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), [], None, None, "No selected node. Run Decompose first.", gr.update()
        )
    source_node_id = state["selected_node_id"]
    src_imgs = _get_node_images(state, source_node_id)
    if not src_imgs:
        return (
            state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), [], None, None, "Selected node has no images.", gr.update()
        )
    n = len(src_imgs)
    # Clamp inputs defensively: slider/dropdown values may be stale or None.
    idx = int(selected_layer_idx) if selected_layer_idx is not None else 0
    idx = max(0, min(n - 1, idx))
    sub_layers = _clamp_int(sub_layers, default=3, lo=2, hi=10)
    layers_out, used_seed, settings_snapshot = _refine_from_source(
        state,
        source_node_id=source_node_id,
        source_layer_idx=idx,
        sub_layers=sub_layers,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
        seed=seed,
        randomize_seed=randomize_seed,
    )
    child_name = f"refine ({state['nodes'][source_node_id]['meta']['name']}) L{idx+1}"
    child_id = _add_node(
        state,
        name=child_name,
        parent_id=source_node_id,
        op="refine",
        images=layers_out,
        settings=settings_snapshot,
        source_node_id=source_node_id,
        source_layer_idx=idx,
        sub_layers=sub_layers,
    )
    # The freshly refined node becomes the current selection.
    state["selected_node_id"] = child_id
    state["last_refined_node_id"] = child_id
    n_layers = len(layers_out)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    selected_label = _layer_label(0, n_layers)
    pptx_path, zip_path, exp_msg = _current_node_export(state, child_id)
    status = f"Refined into {n_layers} sub-layer(s). Seed={used_seed}. {exp_msg}"
    return (
        state,
        layers_out,
        layers_out,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        gr.update(choices=[c[1] for c in hist_choices], value=child_id),
        chips,
        gr.update(visible=True),
        layers_out,
        pptx_path,
        zip_path,
        status,
        gr.update(),
    )
def on_history_select(state: Dict[str, Any], node_id: str):
    """Select *node_id* from the history dropdown and refresh all views.

    Returns a 13-tuple: (state, gallery, gallery, layer dropdown, layer
    slider, layer label, history dropdown, chips, refined-panel visibility,
    refined gallery, pptx path, zip path, status).
    """
    if not node_id or node_id not in state.get("nodes", {}):
        return (
            state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), [], None, None, "Unknown node."
        )
    state["selected_node_id"] = node_id
    imgs = _get_node_images(state, node_id)
    n_layers = len(imgs)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    selected_label = _layer_label(0, n_layers)
    pptx_path, zip_path, exp_msg = _current_node_export(state, node_id)
    return (
        state,
        imgs,
        imgs,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        gr.update(choices=[c[1] for c in hist_choices], value=node_id),
        chips,
        gr.update(visible=False),  # hide the refine panel when navigating history
        [],
        pptx_path,
        zip_path,
        f"Selected node: {node_id}. {exp_msg}",
    )
def on_back_to_parent(state: Dict[str, Any]):
    """Select the parent of the current node (delegates to on_history_select).

    NOTE(review): the early returns yield 3 values while on_history_select
    yields 13; if this handler's Gradio outputs list matches
    on_history_select's, the short paths will desync the UI — confirm the
    event wiring (not visible in this chunk).
    """
    node_id = state.get("selected_node_id")
    if not node_id or node_id not in state.get("nodes", {}):
        return state, gr.update(), "No selected node."
    parent = state["nodes"][node_id]["meta"].get("parent_id")
    if not parent:
        return state, gr.update(), "Already at root."
    return on_history_select(state, parent)
def on_duplicate_node(state: Dict[str, Any]):
    """Duplicate the selected node and select the copy.

    NOTE(review): the error paths return 3 values but on_history_select
    returns 13; if this handler's outputs list matches on_history_select's,
    the short paths will desync the UI — confirm the event wiring (not
    visible in this chunk).
    """
    node_id = state.get("selected_node_id")
    if not node_id:
        return state, gr.update(), "No selected node."
    new_id = _duplicate_node(state, node_id)
    if not new_id:
        return state, gr.update(), "Duplicate failed."
    return on_history_select(state, new_id)
def on_rename_node(state: Dict[str, Any], new_name: str):
    """Rename the selected node and refresh the history dropdown and chips.

    Returns (state, history dropdown update, chips text, status message).

    Fix: the early-return path previously yielded only 3 values while the
    success path yielded 4, which desyncs Gradio's output components; both
    paths now return the same 4-tuple shape (a no-op gr.update() stands in
    for the chips slot on error).
    """
    node_id = state.get("selected_node_id")
    if not node_id:
        return state, gr.update(), gr.update(), "No selected node."
    _rename_node(state, node_id, new_name)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    return state, gr.update(choices=[c[1] for c in hist_choices], value=node_id), chips, "Renamed."
def on_export_selected(state: Dict[str, Any]):
    """Export the selected node; return (pptx_path, zip_path, status message)."""
    node_id = state.get("selected_node_id")
    if not node_id:
        return None, None, "No selected node."
    return _current_node_export(state, node_id)
def on_save_current(state: Dict[str, Any]):
    """Persist the selected node and the session manifest; return a status string."""
    node_id = state.get("selected_node_id")
    if not node_id:
        return "Nothing to save."
    ok, node_msg = _persist_node_to_dataset(state, node_id)
    if not ok:
        return node_msg
    ok, manifest_msg = _persist_session_manifest(state)
    if not ok:
        return manifest_msg
    return f"✅ Saved node + session manifest. {node_msg}"
def on_load_session(state: Dict[str, Any], session_id: str):
    """Restore a previously saved session from the dataset repo.

    Loads the session manifest, rebuilds node metadata (images are hydrated
    only for the selected node and the root; other nodes load lazily), and
    returns the 13-tuple of UI updates wired to btn_load_session:
    (state, gallery, layer_picker, layer_dropdown, selected_layer_idx,
    selected_layer_label, history_dd, chips_md, refined_block,
    refined_gallery, export_pptx, export_zip, status).
    """
    # Every early return below must match the 13 wired outputs exactly.
    if not session_id:
        return (
            state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), [], None, None, "Pick a session id."
        )
    manifest, msg = _load_session_manifest(session_id)
    if manifest is None:
        return (
            state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(visible=False), [], None, None, msg
        )
    # Rebuild a fresh state from the manifest; each node starts with empty
    # images and is hydrated on demand.
    new_state = _init_state()
    new_state["session_id"] = manifest.get("session_id") or session_id
    new_state["created_at"] = manifest.get("created_at")
    new_state["root_node_id"] = manifest.get("root_node_id")
    new_state["selected_node_id"] = manifest.get("selected_node_id") or manifest.get("root_node_id")
    nodes_meta = manifest.get("nodes", {}) or {}
    for nid, obj in nodes_meta.items():
        meta = obj.get("meta") or {}
        new_state["nodes"][nid] = {"meta": meta, "images": []}
    # Fall back to the root node if the manifest's selection is missing or stale.
    sel = new_state["selected_node_id"]
    if not sel or sel not in nodes_meta:
        sel = new_state["root_node_id"]
        new_state["selected_node_id"] = sel
    if sel and sel in nodes_meta:
        # Hydrate the selected node's layer images from the dataset.
        num_layers = int(nodes_meta[sel].get("num_layers", 0))
        imgs, msg2 = _load_node_images(session_id, sel, num_layers)
        if not imgs:
            return (
                new_state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
                gr.update(visible=False), [], None, None, f"Loaded manifest but failed to load images: {msg2}"
            )
        new_state["nodes"][sel]["images"] = imgs
        # Also pre-load the root node so "back to parent" works immediately;
        # failure here is non-fatal (images stay lazy).
        root = new_state.get("root_node_id")
        if root and root != sel and root in nodes_meta and not new_state["nodes"][root]["images"]:
            rl = int(nodes_meta[root].get("num_layers", 0))
            rimgs, _ = _load_node_images(session_id, root, rl)
            if rimgs:
                new_state["nodes"][root]["images"] = rimgs
    # Build all UI updates for the restored selection.
    imgs = _get_node_images(new_state, sel) if sel else []
    n_layers = len(imgs)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(new_state)
    chips = _make_chips(new_state)
    selected_label = _layer_label(0, n_layers)
    pptx_path, zip_path, exp_msg = _current_node_export(new_state, sel) if sel else (None, None, "No node")
    return (
        new_state,
        imgs,
        imgs,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        # hist_choices entries are (label, node_id) pairs; the dropdown stores ids.
        gr.update(choices=[c[1] for c in hist_choices], value=sel),
        chips,
        gr.update(visible=False),
        [],
        pptx_path,
        zip_path,
        f"Loaded session {session_id}. {exp_msg}",
    )
def on_history_need_images(state: Dict[str, Any], node_id: str):
    """Lazily hydrate a node's layer images from the dataset if missing.

    Returns (state, status_message). The state is mutated in place when
    images are loaded; callers typically ignore the message.
    """
    known_nodes = state.get("nodes", {})
    if not node_id or node_id not in known_nodes:
        return state, "Unknown node."
    cached = known_nodes[node_id].get("images", [])
    if cached:
        # Already hydrated — nothing to do.
        return state, "OK"
    session_id = state.get("session_id")
    if not session_id:
        return state, "No session_id."
    manifest, manifest_msg = _load_session_manifest(session_id)
    if not manifest:
        return state, f"Cannot load manifest: {manifest_msg}"
    manifest_node = (manifest.get("nodes", {}) or {}).get(node_id, {})
    layer_count = int(manifest_node.get("num_layers", 0))
    if layer_count <= 0:
        return state, "No layers in manifest for this node."
    loaded, load_msg = _load_node_images(session_id, node_id, layer_count)
    if not loaded:
        return state, f"Failed to load images: {load_msg}"
    known_nodes[node_id]["images"] = loaded
    return state, "Loaded images."
| # ------------------------- | |
| # Build UI | |
| # ------------------------- | |
| ensure_dirname(LOG_DIR) | |
# Sample input images bundled with the Space: assets/test_images/1.png .. 13.png.
examples = [f"assets/test_images/{i}.png" for i in range(1, 14)]
with gr.Blocks() as demo:
    # Per-browser-session mutable app state (node tree, selection, session id).
    state = gr.State(_init_state())
    gr.HTML(
        '<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/layered/qwen-image-layered-logo.png" '
        'alt="Qwen-Image-Layered Logo" width="600" style="display: block; margin: 0 auto;">'
    )
    # Shows whether dataset persistence (HF token + dataset repo) is configured.
    persistence_banner = gr.Markdown(_persistence_status_text())
    with gr.Row():
        btn_init_ds = gr.Button("Init dataset repo", variant="secondary")
        btn_refresh_sessions = gr.Button("Refresh sessions", variant="secondary")
    ds_status = gr.Markdown("")
    with gr.Row():
        load_session_dd = gr.Dropdown(
            label="Load session (from dataset)",
            choices=[],
            value=None,
            allow_custom_value=True,
        )
        btn_load_session = gr.Button("Load session", variant="primary")
    gr.Markdown(
        """
        The text prompt describes the overall content of the input image.
        It is not designed to control the semantic content of individual layers explicitly.
        """
    )
    with gr.Row():
        # Left column: input image + generation settings.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", image_mode="RGBA")
            with gr.Accordion("Advanced Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt (Optional)",
                    placeholder="Describe the image (optional)",
                    value="",
                    lines=2,
                )
                neg_prompt = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="Negative prompt",
                    value=" ",
                    lines=2,
                )
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                true_guidance_scale = gr.Slider(
                    label="True guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=4.0
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps", minimum=1, maximum=100, step=1, value=50  # DO NOT CHANGE
                )
                layer = gr.Slider(label="Layers", minimum=2, maximum=10, step=1, value=7)  # DO NOT CHANGE
                resolution = gr.Radio(label="Processing resolution", choices=[640, 1024], value=640)  # DO NOT CHANGE
                cfg_norm = gr.Checkbox(label="Enable CFG normalization", value=True)
                use_en_prompt = gr.Checkbox(label="Auto caption language: True=EN, False=ZH", value=True)
                gpu_duration = gr.Textbox(
                    label="GPU duration override (seconds, 20..1500)",
                    value="1000",
                    lines=1,
                    placeholder="e.g. 120, 300, 1000, 1500",
                )
            btn_decompose = gr.Button("Decompose!", variant="primary")
            with gr.Group():
                gr.Markdown("### Refine (Recursive Decomposition)")
                sub_layers = gr.Slider(label="Sub-layers (Refine)", minimum=2, maximum=10, step=1, value=3)
                btn_refine = gr.Button("Refine selected layer", variant="primary")
        # Right column: results, layer picker, history, and export controls.
        with gr.Column(scale=2):
            gr.Markdown("### Current node layers")
            gallery = gr.Gallery(label="Layers", columns=4, rows=1, format="png")
            gr.Markdown("### Layer picker (Photoshop-style)")
            layer_picker = gr.Gallery(label="Pick a layer", columns=8, rows=1, format="png")
            with gr.Row():
                layer_dropdown = gr.Dropdown(label="Refine layer", choices=[], value=None)
                selected_layer_idx = gr.Number(label="Selected layer index (0-based)", value=0, precision=0, interactive=False)
            selected_layer_label = gr.Markdown("Selected: -")
            with gr.Accordion("Refined layers (last refine)", open=True, visible=False) as refined_block:
                refined_gallery = gr.Gallery(label="Refined layers", columns=4, rows=1, format="png")
            gr.Markdown("### History (nodes)")
            with gr.Row():
                history_dd = gr.Dropdown(label="Node id", choices=[], value=None)
                chips_md = gr.Markdown("[root] [parent:-] [children:0]")
            with gr.Row():
                btn_back_parent = gr.Button("← back to parent", variant="secondary")
                btn_duplicate = gr.Button("Duplicate node (branch)", variant="secondary")
            with gr.Row():
                rename_text = gr.Textbox(label="Branch name", value="", lines=1, placeholder="Type new name and click Rename")
                btn_rename = gr.Button("Rename", variant="secondary")
            with gr.Row():
                btn_export = gr.Button("Export selected node (ZIP/PPTX)", variant="primary")
                btn_save = gr.Button("Save selected node to dataset", variant="primary")
            with gr.Row():
                export_pptx = gr.File(label="Download PPTX")
                export_zip = gr.File(label="Download ZIP")
            status = gr.Markdown("")
            seed_used = gr.Textbox(label="Seed used", value="", interactive=False)
    # Examples only populate input_image; run_on_click=False so nothing executes.
    gr.Examples(
        examples=examples,
        inputs=[input_image],
        outputs=[gallery, export_pptx, export_zip],
        fn=lambda img: ([], None, None),
        cache_examples=False,
        run_on_click=False,
    )
    # -----------------------------------------------------------------
    # Event wiring. Note: several handlers share the same 13-component
    # output list and must return exactly that many values on every path.
    # -----------------------------------------------------------------
    btn_init_ds.click(fn=on_init_dataset, outputs=[ds_status])
    btn_refresh_sessions.click(fn=on_refresh_sessions, outputs=[load_session_dd, ds_status])
    btn_load_session.click(
        fn=on_load_session,
        inputs=[state, load_session_dd],
        outputs=[
            state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
            history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status
        ],
    )
    btn_decompose.click(
        fn=on_decompose_click,
        inputs=[
            state, input_image, seed, randomize_seed, prompt, neg_prompt, true_guidance_scale,
            num_inference_steps, layer, cfg_norm, use_en_prompt, resolution, gpu_duration
        ],
        outputs=[
            state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
            history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status, seed_used
        ],
    )
    layer_picker.select(
        fn=on_layer_pick_from_gallery,
        inputs=[state],
        outputs=[selected_layer_idx, layer_dropdown, selected_layer_label],
    )
    layer_dropdown.change(
        fn=on_layer_pick_from_dropdown,
        inputs=[state, layer_dropdown],
        outputs=[selected_layer_idx, selected_layer_label],
    )
    btn_refine.click(
        fn=on_refine_click,
        inputs=[
            state, selected_layer_idx, sub_layers, prompt, neg_prompt, true_guidance_scale,
            num_inference_steps, cfg_norm, use_en_prompt, resolution, gpu_duration, seed, randomize_seed
        ],
        outputs=[
            state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
            history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status, seed_used
        ],
    )
    def _history_select_with_lazy(state_, node_id_):
        """Hydrate the node's images from the dataset if needed, then select it."""
        state_, _ = on_history_need_images(state_, node_id_)
        return on_history_select(state_, node_id_)
    history_dd.change(
        fn=_history_select_with_lazy,
        inputs=[state, history_dd],
        outputs=[
            state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
            history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status
        ],
    )
    btn_back_parent.click(
        fn=on_back_to_parent,
        inputs=[state],
        outputs=[
            state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
            history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status
        ],
    )
    btn_duplicate.click(
        fn=on_duplicate_node,
        inputs=[state],
        outputs=[
            state, gallery, layer_picker, layer_dropdown, selected_layer_idx, selected_layer_label,
            history_dd, chips_md, refined_block, refined_gallery, export_pptx, export_zip, status
        ],
    )
    btn_rename.click(fn=on_rename_node, inputs=[state, rename_text], outputs=[state, history_dd, chips_md, status])
    btn_export.click(fn=on_export_selected, inputs=[state], outputs=[export_pptx, export_zip, status])
    btn_save.click(fn=on_save_current, inputs=[state], outputs=[status])
# Enable the request queue (required for long-running GPU jobs on Spaces).
demo.queue()
if __name__ == "__main__":
    demo.launch()