from __future__ import annotations

import torch
import torch.nn.functional as F  # noqa: F401

import comfy.utils as _utils
import comfy.sample as _sample
import comfy.samplers as _samplers

import latent_preview  # ComfyUI's preview/progress callback helper


def _smoothstep01(x: torch.Tensor) -> torch.Tensor:
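    # Cubic Hermite smoothstep; callers clamp x to [0, 1] before calling.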
    return x * x * (3.0 - 2.0 * x)


def _build_hybrid_sigmas(model, steps: int, base_sampler: str, mode: str,
                         mix: float, denoise: float, jitter: float, seed: int,
                         _debug: bool = False, tail_smooth: float = 0.0,
                         auto_hybrid_tail: bool = True, auto_tail_strength: float = 0.35):
    """Return 1D tensor of sigmas (len == steps+1), strictly descending and ending with 0.

    mode: 'karras' | 'beta' | 'hybrid'. If 'hybrid', blend tail toward beta by `mix`.
    We DO NOT apply 'drop penultimate' until the very end to preserve denoise math.
    """
    ms = model.get_model_object("model_sampling")
    steps = int(steps)
    assert steps >= 1

    # --- base tracks ---
    sig_k = _samplers.calculate_sigmas(ms, "karras", steps)
    sig_b = _samplers.calculate_sigmas(ms, "beta",   steps)

    def _align_len(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Align two sigma schedules to the same length (use tail of longer)."""
        if a.shape[0] == b.shape[0]:
            return a, b
        m = min(a.shape[0], b.shape[0])
        return a[-m:], b[-m:]

    mode = str(mode).lower()
    sig_k, sig_b = _align_len(sig_k, sig_b)
    if mode == "karras":
        sig = sig_k
    elif mode == "beta":
        sig = sig_b
    else:
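        # Hybrid: keep karras for the head, then ease into beta across the
        # final `mix` fraction of the schedule; the smoothstep removes the
        # kink at the handoff point.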
        n = sig_k.shape[0]
        t = torch.linspace(0.0, 1.0, n, device=sig_k.device, dtype=sig_k.dtype)
        m = float(max(0.0, min(1.0, mix)))
        eps = 1e-6 if m < 1e-6 else m
        w = torch.clamp((t - (1.0 - m)) / eps, 0.0, 1.0)
        w = _smoothstep01(w)
        sig = sig_k * (1.0 - w) + sig_b * w

    # --- Comfy denoise semantics: recompute a "full" track and take the tail of desired length ---
    sig_k_base = sig_k
    sig_b_base = sig_b
    if denoise is not None and 0.0 < float(denoise) < 0.999999:
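        # Comfy denoise semantics: denoise=d means "run only the last `steps`
        # steps of a full schedule of steps/d steps", starting at a lower sigma.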
        new_steps = max(1, int(steps / max(1e-6, float(denoise))))
        sk = _samplers.calculate_sigmas(ms, "karras", new_steps)
        sb = _samplers.calculate_sigmas(ms, "beta",   new_steps)
        sk, sb = _align_len(sk, sb)
        if mode == "karras":
            sig_full = sk
        elif mode == "beta":
            sig_full = sb
        else:
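            # Same hybrid blend as above, applied to the longer full track.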
            n2 = sk.shape[0]
            t2 = torch.linspace(0.0, 1.0, n2, device=sk.device, dtype=sk.dtype)
            m = float(max(0.0, min(1.0, mix)))
            eps = 1e-6 if m < 1e-6 else m
            w2 = torch.clamp((t2 - (1.0 - m)) / eps, 0.0, 1.0)
            w2 = _smoothstep01(w2)
            sig_full = sk * (1.0 - w2) + sb * w2
        need = steps + 1
        if sig_full.shape[0] >= need:
            sig = sig_full[-need:]
            sig_k_base = sk[-need:]
            sig_b_base = sb[-need:]
        else:
            # Worst case: trust what we got; we will still guarantee the last sigma is zero later
            sig = sig_full
            tail = min(need, sk.shape[0])
            sig_k_base = sk[-tail:]
            sig_b_base = sb[-tail:]

    # --- auto-hybrid tail: blend beta into the tail when the steps become brittle ---
    if bool(auto_hybrid_tail) and sig.numel() > 2:
        n = sig.shape[0]
        t = torch.linspace(0.0, 1.0, n, device=sig.device, dtype=sig.dtype)
        m = float(max(0.0, min(1.0, mix)))
        if mode == "hybrid":
            eps = 1e-6 if m < 1e-6 else m
            w_m = torch.clamp((t - (1.0 - m)) / eps, 0.0, 1.0)
            w_m = _smoothstep01(w_m)
        elif mode == "beta":
            w_m = torch.ones_like(t)
        else:
            w_m = torch.zeros_like(t)
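        # Brittleness signal: relative per-step sigma change, min-max
        # normalized to [0, 1]; the ramp below confines it to the tail.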
        dif = (sig[1:] - sig[:-1]).abs() / sig[:-1].abs().clamp_min(1e-8)
        dif = torch.cat([dif, dif[-1:]], dim=0)
        dif = (dif - dif.min()) / (dif.max() - dif.min() + 1e-8)
        ramp = _smoothstep01(torch.clamp((t - 0.7) / 0.3, 0.0, 1.0))
        w_a = dif * ramp
        g = float(max(0.0, min(1.0, auto_tail_strength)))
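        # Fuzzy OR (a + b - a*b) of the manual mix weight and the gated auto
        # weight, so the combined blend saturates at 1 instead of overshooting.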
        u = w_m + g * w_a - w_m * g * w_a
        sig = sig_k_base * (1.0 - u) + sig_b_base * u

    # --- tiny schedule jitter ---
    j = float(max(0.0, min(0.1, float(jitter))))
    if j > 0.0 and sig.numel() > 1:
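        # Draw jitter on CPU with a fixed-seed generator so results are
        # reproducible and device-independent; the amplitude is jitter * 1e-3
        # of the full sigma range, so it stays tiny.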
        gen = torch.Generator(device='cpu')
        gen.manual_seed(int(seed) ^ 0x5EEDCAFE)
        noise = torch.randn(sig.shape, generator=gen, device='cpu').to(sig.device, sig.dtype)
        amp = j * float(sig[0].item() - sig[-1].item()) * 1e-3
        sig = sig + noise * amp
        sig, _ = torch.sort(sig, descending=True)

    # --- hard guarantee of ending with exact zero ---
    if sig[-1].abs() > 1e-12:
        sig = torch.cat([sig[:-1], sig.new_zeros(1)], dim=0)

    # --- gentle smoothing of the sigma tail (adaptive, preserves monotonic decrease) ---
    ts = float(max(0.0, min(1.0, tail_smooth)))
    if ts > 0.0 and sig.numel() > 2:
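        # A backward pass pulls each sigma slightly toward its successor
        # (a <= 0.5), which cannot break the strictly decreasing order.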
        s = sig.clone()
        n = int(s.shape[0])
        t = torch.linspace(0.0, 1.0, n, device=s.device, dtype=s.dtype)
        w = (t.pow(2) * ts).clamp(0.0, 1.0)
        for i in range(n - 2, -1, -1):
            a = float(min(0.5, 0.5 * w[i].item()))
            s[i] = (1.0 - a) * s[i] + a * s[i + 1]
        sig = s

    # --- only now drop the penultimate sigma for samplers that require it ---
    if base_sampler in _samplers.KSampler.DISCARD_PENULTIMATE_SIGMA_SAMPLERS and sig.numel() >= 2:
        sig = torch.cat([sig[:-2], sig[-1:]], dim=0)

    sig = sig.to(model.load_device)

    # Lightweight debug: schedule summary
    if _debug:
        try:
            desc_ok = bool((sig[:-1] > sig[1:]).all().item()) if sig.numel() > 1 else True
            head = ", ".join(f"{float(v):.4g}" for v in sig[:3].tolist()) if sig.numel() >= 3 else \
                   ", ".join(f"{float(v):.4g}" for v in sig.tolist())
            tail = ", ".join(f"{float(v):.4g}" for v in sig[-3:].tolist()) if sig.numel() >= 3 else head
            print(f"[ZeSmart][dbg] sigmas len={sig.numel()} desc={desc_ok} first={float(sig[0]):.6g} last={float(sig[-1]):.6g}")
            print(f"[ZeSmart][dbg] head: [{head}]  tail: [{tail}]")
        except Exception:
            pass

    return sig


class MG_ZeSmartSampler:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", {}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 2**63-1, "control_after_generate": True}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 4096}),
                "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 50.0, "step": 0.1}),
                "base_sampler": (_samplers.KSampler.SAMPLERS, {"default": "dpmpp_2m"}),
                "schedule": (["karras", "beta", "hybrid"], {"default": "hybrid", "tooltip": "Sigma curve: karras — soft start; beta — stable tail; hybrid — their mix."}),
                "positive": ("CONDITIONING", {}),
                "negative": ("CONDITIONING", {}),
                "latent": ("LATENT", {}),
            },
            "optional": {
                "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Path shortening: 1.0 = full; <1.0 = take the last steps only. Useful for inpaint/mixing."}),
                "hybrid_mix": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For schedule=hybrid: tail fraction blended toward beta (0=karras, 1=beta)."}),
                "jitter_sigma": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 0.1, "step": 0.001, "tooltip": "Tiny sigma jitter to kill moiré/banding on backgrounds. 0–0.02 is usually enough."}),
                "tail_smooth": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Smooth the sigma tail — reduces wobble/banding. Too high may soften details."}),
                "auto_hybrid_tail": ("BOOLEAN", {"default": True, "tooltip": "Auto‑blend beta on the tail when steps become brittle."}),
                "auto_tail_strength": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Strength of auto beta‑mix on the tail (0=off, 1=max)."}),
                "debug_probe": ("BOOLEAN", {"default": False, "tooltip": "Print sigma summary (length, first/last, head/tail)."}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("LATENT",)
    FUNCTION = "apply"
    CATEGORY = "MagicNodes/Experimental"

    def apply(self, model, seed, steps, cfg, base_sampler, schedule,
              positive, negative, latent, denoise=1.0, hybrid_mix=0.5,
              jitter_sigma=0.01, tail_smooth=0.15,
              auto_hybrid_tail=True, auto_tail_strength=0.4,
              debug_probe=False):
        # Prepare latent + noise
        lat_img = latent["samples"]
        lat_img = _sample.fix_empty_latent_channels(model, lat_img)
        batch_inds = latent.get("batch_index", None)
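        # Seed-deterministic noise; batch_index keeps per-item noise stable
        # when this batch is a slice of a larger one.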
        noise = _sample.prepare_noise(lat_img, seed, batch_inds)
        noise_mask = latent.get("noise_mask", None)

        # Custom sigmas
        sigmas = _build_hybrid_sigmas(model, int(steps), str(base_sampler), str(schedule),
                                      float(hybrid_mix), float(denoise), float(jitter_sigma), int(seed),
                                      _debug=bool(debug_probe), tail_smooth=float(tail_smooth),
                                      auto_hybrid_tail=bool(auto_hybrid_tail),
                                      auto_tail_strength=float(auto_tail_strength))

        # Use native sampler; all tweaks happen in sigma schedule only.
        sampler_obj = _samplers.sampler_object(str(base_sampler))
        # Size the preview/progress callback to the actual schedule length
        # (drop-penultimate may have removed one step).
        callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1)
        disable_pbar = not _utils.PROGRESS_BAR_ENABLED
        samples = _sample.sample_custom(model, noise, float(cfg), sampler_obj, sigmas,
                                        positive, negative, lat_img,
                                        noise_mask=noise_mask, callback=callback,
                                        disable_pbar=disable_pbar, seed=seed)
        out = {**latent}
        out["samples"] = samples
        return (out,)