Diffusers
Nap commited on
Commit
0abc420
·
verified ·
1 Parent(s): 127fa0a

Delete modified code. Since Kijai added support for the giant model, no modifications are needed.

Browse files
Files changed (1) hide show
  1. nodes.py +0 -190
nodes.py DELETED
@@ -1,190 +0,0 @@
1
-
2
- import torch
3
- import torch.nn.functional as F
4
- from torchvision import transforms
5
- import os
6
- from contextlib import nullcontext
7
-
8
- import comfy.model_management as mm
9
- from comfy.utils import ProgressBar, load_torch_file
10
- import folder_paths
11
-
12
- from .depth_anything_v2.dpt import DepthAnythingV2
13
-
14
- from contextlib import nullcontext
15
- try:
16
- from accelerate import init_empty_weights
17
- from accelerate.utils import set_module_tensor_to_device
18
- is_accelerate_available = True
19
- except:
20
- pass
21
-
22
- class DownloadAndLoadDepthAnythingV2Model:
23
- @classmethod
24
- def INPUT_TYPES(s):
25
- return {"required": {
26
- "model": (
27
- [
28
- 'depth_anything_v2_vits_fp16.safetensors',
29
- 'depth_anything_v2_vits_fp32.safetensors',
30
- 'depth_anything_v2_vitb_fp16.safetensors',
31
- 'depth_anything_v2_vitb_fp32.safetensors',
32
- 'depth_anything_v2_vitl_fp16.safetensors',
33
- 'depth_anything_v2_vitl_fp32.safetensors',
34
- 'depth_anything_v2_vitg_fp32.safetensors',
35
- 'depth_anything_v2_metric_hypersim_vitl_fp32.safetensors',
36
- 'depth_anything_v2_metric_vkitti_vitl_fp32.safetensors'
37
- ],
38
- {
39
- "default": 'depth_anything_v2_vitl_fp32.safetensors'
40
- }),
41
- },
42
- }
43
-
44
- RETURN_TYPES = ("DAMODEL",)
45
- RETURN_NAMES = ("da_v2_model",)
46
- FUNCTION = "loadmodel"
47
- CATEGORY = "DepthAnythingV2"
48
- DESCRIPTION = """
49
- Models autodownload to `ComfyUI\models\depthanything` from
50
- https://huggingface.co/Kijai/DepthAnythingV2-safetensors/tree/main
51
-
52
- fp16 reduces quality by a LOT, not recommended.
53
- """
54
-
55
- def loadmodel(self, model):
56
- device = mm.get_torch_device()
57
- dtype = torch.float16 if "fp16" in model else torch.float32
58
- model_configs = {
59
- 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
60
- 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
61
- 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
62
- 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
63
- }
64
- custom_config = {
65
- 'model_name': model,
66
- }
67
- if not hasattr(self, 'model') or self.model == None or custom_config != self.current_config:
68
- self.current_config = custom_config
69
- download_path = os.path.join(folder_paths.models_dir, "depthanything")
70
- model_path = os.path.join(download_path, model)
71
-
72
- if not os.path.exists(model_path):
73
- print(f"Downloading model to: {model_path}")
74
- from huggingface_hub import snapshot_download
75
- snapshot_download(repo_id="Kijai/DepthAnythingV2-safetensors",
76
- allow_patterns=[f"*{model}*"],
77
- local_dir=download_path,
78
- local_dir_use_symlinks=False)
79
-
80
- print(f"Loading model from: {model_path}")
81
-
82
- if "vitg" in model:
83
- encoder = "vitg"
84
- elif "vitl" in model:
85
- encoder = "vitl"
86
- elif "vitb" in model:
87
- encoder = "vitb"
88
- elif "vits" in model:
89
- encoder = "vits"
90
-
91
- if "hypersim" in model:
92
- max_depth = 20.0
93
- else:
94
- max_depth = 80.0
95
-
96
- with (init_empty_weights() if is_accelerate_available else nullcontext()):
97
- if 'metric' in model:
98
- self.model = DepthAnythingV2(**{**model_configs[encoder], 'is_metric': True, 'max_depth': max_depth})
99
- else:
100
- self.model = DepthAnythingV2(**model_configs[encoder])
101
-
102
- state_dict = load_torch_file(model_path)
103
- if is_accelerate_available:
104
- for key in state_dict:
105
- set_module_tensor_to_device(self.model, key, device=device, dtype=dtype, value=state_dict[key])
106
- else:
107
- self.model.load_state_dict(state_dict)
108
-
109
- self.model.eval()
110
- da_model = {
111
- "model": self.model,
112
- "dtype": dtype,
113
- "is_metric": self.model.is_metric
114
- }
115
-
116
- return (da_model,)
117
-
118
- class DepthAnything_V2:
119
- @classmethod
120
- def INPUT_TYPES(s):
121
- return {"required": {
122
- "da_model": ("DAMODEL", ),
123
- "images": ("IMAGE", ),
124
- },
125
- }
126
-
127
- RETURN_TYPES = ("IMAGE",)
128
- RETURN_NAMES =("image",)
129
- FUNCTION = "process"
130
- CATEGORY = "DepthAnythingV2"
131
- DESCRIPTION = """
132
- https://depth-anything-v2.github.io
133
- """
134
-
135
- def process(self, da_model, images):
136
- device = mm.get_torch_device()
137
- offload_device = mm.unet_offload_device()
138
- model = da_model['model']
139
- dtype=da_model['dtype']
140
-
141
- B, H, W, C = images.shape
142
-
143
- #images = images.to(device)
144
- images = images.permute(0, 3, 1, 2)
145
-
146
- orig_H, orig_W = H, W
147
- if W % 14 != 0:
148
- W = W - (W % 14)
149
- if H % 14 != 0:
150
- H = H - (H % 14)
151
- if orig_H % 14 != 0 or orig_W % 14 != 0:
152
- images = F.interpolate(images, size=(H, W), mode="bilinear")
153
-
154
- normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
155
- normalized_images = normalize(images)
156
- pbar = ProgressBar(B)
157
- out = []
158
- model.to(device)
159
- autocast_condition = (dtype != torch.float32) and not mm.is_device_mps(device)
160
- with torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
161
- for img in normalized_images:
162
- depth = model(img.unsqueeze(0).to(device))
163
- depth = (depth - depth.min()) / (depth.max() - depth.min())
164
- out.append(depth.cpu())
165
- pbar.update(1)
166
- model.to(offload_device)
167
- depth_out = torch.cat(out, dim=0)
168
- depth_out = depth_out.unsqueeze(-1).repeat(1, 1, 1, 3).cpu().float()
169
-
170
- final_H = (orig_H // 2) * 2
171
- final_W = (orig_W // 2) * 2
172
-
173
-
174
-
175
- if depth_out.shape[1] != final_H or depth_out.shape[2] != final_W:
176
- depth_out = F.interpolate(depth_out.permute(0, 3, 1, 2), size=(final_H, final_W), mode="bilinear").permute(0, 2, 3, 1)
177
- depth_out = (depth_out - depth_out.min()) / (depth_out.max() - depth_out.min())
178
- depth_out = torch.clamp(depth_out, 0, 1)
179
- if da_model['is_metric']:
180
- depth_out = 1 - depth_out
181
- return (depth_out,)
182
-
183
- NODE_CLASS_MAPPINGS = {
184
- "DepthAnything_V2": DepthAnything_V2,
185
- "DownloadAndLoadDepthAnythingV2Model": DownloadAndLoadDepthAnythingV2Model
186
- }
187
- NODE_DISPLAY_NAME_MAPPINGS = {
188
- "DepthAnything_V2": "Depth Anything V2",
189
- "DownloadAndLoadDepthAnythingV2Model": "DownloadAndLoadDepthAnythingV2Model"
190
- }