matthartman committed (verified)
Commit a9d4b99 · 1 Parent(s): 8074ca3

Deploy Gradio app with multiple files

app.py ADDED
@@ -0,0 +1,675 @@
1
+ import spaces
2
+ from huggingface_hub import snapshot_download, hf_hub_download
3
+ import os
4
+ import subprocess
5
+ import importlib, site
6
+ from PIL import Image
7
+ import uuid
8
+ import shutil
9
+ import time
10
+ import cv2
11
+ import json
12
+ import gradio as gr
13
+ import sys
14
+ import gc
15
+
16
+ BASE = os.path.dirname(os.path.abspath(__file__))
17
+ PREPROCESS_DIR = os.path.join(BASE, "wan", "modules", "animate", "preprocess")
18
+ sys.path.append(PREPROCESS_DIR)
19
+
20
+ # Re-discover all .pth/.egg-link files
21
+ for sitedir in site.getsitepackages():
22
+ site.addsitedir(sitedir)
23
+
24
+ # Clear caches so importlib will pick up new modules
25
+ importlib.invalidate_caches()
26
+
27
+ def sh(cmd): subprocess.check_call(cmd, shell=True)
28
+
29
+
30
+
31
+ try:
32
+
33
+ sh("pip install flash-attn --no-build-isolation")
34
+ # print("Attempting to download and build sam2...")
35
+
36
+ # print("download sam")
37
+ # sam_dir = snapshot_download(repo_id="alexnasa/sam2")
38
+
39
+ # @spaces.GPU(duration=500)
40
+ # def install_sam():
41
+
42
+ # os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
43
+ # sh(f"cd {sam_dir} && python setup.py build_ext --inplace && pip install -e .")
44
+
45
+ # print("install sam")
46
+ # install_sam()
47
+
48
+ print("Attempting to download")
49
+
50
+ print("download sam")
51
+ snapshot_download(repo_id="alexnasa/sam2_C", local_dir=f"{os.getcwd()}" )
52
+
53
+ # tell Python to re-scan site-packages now that the egg-link exists
54
+ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
55
+
56
+ print("sam2 installed successfully.")
57
+
58
+ except Exception as e:
59
+ raise gr.Error("sam2 installation failed")
60
+
61
+ import torch
62
+ from generate import generate, load_model
63
+ from preprocess_data import run as run_preprocess
64
+ from preprocess_data import load_preprocess_models
65
+ print(f"Torch version: {torch.__version__}")
66
+
67
+ os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results"
68
+
69
+ snapshot_download(repo_id="Wan-AI/Wan2.2-Animate-14B", local_dir="./Wan2.2-Animate-14B")
70
+ wan_animate = load_model(True)
71
+
72
+ rc_mapping = {
73
+ "Video → Ref Image" : False,
74
+ "Video ← Ref Image" : True
75
+ }
76
+
77
+
78
+ def preprocess_video(input_video_path, duration, session_id=None):
79
+
80
+ if session_id is None:
81
+ session_id = uuid.uuid4().hex
82
+
83
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
84
+ os.makedirs(output_dir, exist_ok=True)
85
+
86
+ process_video_path = os.path.join(output_dir, 'input_video.mp4')
87
+
88
+ clip_and_set_fps(input_video_path, process_video_path, duration_s=duration)
89
+
90
+ return process_video_path
91
+
92
+ def extract_audio_from_video_ffmpeg(video_path, output_wav_path, sample_rate=None):
93
+ """
94
+ Extracts the audio track from a video file and saves it as a WAV file.
95
+
96
+ Args:
97
+ video_path (str): Path to the input video file.
98
+ output_wav_path (str): Path to save the extracted WAV file.
99
+ sample_rate (int, optional): Output sample rate (e.g., 16000).
100
+ If None, keep the original.
101
+ """
102
+ cmd = [
103
+ 'ffmpeg',
104
+ '-i', video_path, # Input video
105
+ '-vn', # Disable video
106
+ '-acodec', 'pcm_s16le', # 16-bit PCM (WAV format)
107
+ '-ac', '1', # Mono channel (use '2' for stereo)
108
+ '-y', # Overwrite output
109
+ '-loglevel', 'error' # Cleaner output
110
+ ]
111
+
112
+ # Only add the sample rate option if explicitly specified
113
+ if sample_rate is not None:
114
+ cmd.extend(['-ar', str(sample_rate)])
115
+
116
+ cmd.append(output_wav_path)
117
+
118
+ try:
119
+ subprocess.run(cmd, check=True, capture_output=True, text=True)
120
+ return True
121
+ except subprocess.CalledProcessError as e:
122
+ return False
123
+
124
+
125
+ def combine_video_and_audio_ffmpeg(video_path, audio_path, output_video_path):
126
+ """
127
+ Combines a silent MP4 video with a WAV audio file into a single MP4 with sound.
128
+
129
+ Args:
130
+ video_path (str): Path to the silent video file.
131
+ audio_path (str): Path to the WAV audio file.
132
+ output_video_path (str): Path to save the output MP4 with audio.
133
+ """
134
+ cmd = [
135
+ 'ffmpeg',
136
+ '-i', video_path, # Input video
137
+ '-i', audio_path, # Input audio
138
+ '-c:v', 'copy', # Copy video without re-encoding
139
+ '-c:a', 'aac', # Encode audio as AAC (MP4-compatible)
140
+ '-shortest', # Stop when the shortest stream ends
141
+ '-y', # Overwrite output
142
+ '-loglevel', 'error',
143
+ output_video_path
144
+ ]
145
+
146
+ try:
147
+ subprocess.run(cmd, check=True, capture_output=True, text=True)
148
+ except subprocess.CalledProcessError as e:
149
+ raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
150
+
151
+
152
+ def clip_and_set_fps(input_video_path, output_video_path, duration_s=2, target_fps=30):
153
+ """
154
+ Trim to duration_s and (optionally) change FPS, without resizing.
155
+ - If target_fps is None, keeps the original FPS.
156
+ - Re-encodes video when changing FPS for predictable timing.
157
+ """
158
+ vf = []
159
+ if target_fps is not None:
160
+ vf.append(f"fps={target_fps}")
161
+ vf_arg = ",".join(vf) if vf else None
162
+
163
+ cmd = [
164
+ "ffmpeg",
165
+ "-nostdin",
166
+ "-hide_banner",
167
+ "-y",
168
+ "-i", input_video_path,
169
+ "-t", str(duration_s),
170
+ ]
171
+
172
+ if vf_arg:
173
+ cmd += ["-vf", vf_arg]
174
+
175
+ cmd += [
176
+ "-c:v", "libx264",
177
+ "-pix_fmt", "yuv420p",
178
+ "-preset", "veryfast",
179
+ "-crf", "18",
180
+ "-c:a", "aac", # use aac so MP4 stays compatible
181
+ "-movflags", "+faststart",
182
+ output_video_path,
183
+ ]
184
+
185
+ try:
186
+ subprocess.run(cmd, check=True, capture_output=True, text=True)
187
+ except subprocess.CalledProcessError as e:
188
+ raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
189
+
190
+
191
+ def is_portrait(video_file):
192
+
193
+ # Get video information
194
+ cap = cv2.VideoCapture(video_file)
195
+ if not cap.isOpened():
196
+ error_msg = "Cannot open video file"
197
+ gr.Warning(error_msg)
198
+
199
+ orig_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
200
+ orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
201
+ orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
202
+
203
+ cap.release()
204
+
205
+ return orig_width < orig_height
206
+
207
+ def calculate_time_required(max_duration_s, rc_bool):
208
+
209
+ if max_duration_s == 2:
210
+ return 120
211
+ elif max_duration_s == 4:
212
+ return 180
213
+ elif max_duration_s == 6:
214
+ return 260
215
+ elif max_duration_s == 8:
216
+ return 330
217
+ elif max_duration_s == 10:
218
+ return 340
219
+
220
+ def get_display_time_required(max_duration_s, rc_bool):
221
+ # the extra 30 seconds is just for safety in case of an unexpected slowdown
222
+ return calculate_time_required(max_duration_s, rc_bool) - 30
223
+
224
+ def update_time_required(max_duration_s, rc_str):
225
+
226
+ rc_bool = rc_mapping[rc_str]
227
+
228
+ duration_s = get_display_time_required(max_duration_s, rc_bool)
229
+ duration_m = duration_s / 60
230
+
231
+ return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")
232
+
233
+
234
+ def get_duration(input_video, max_duration_s, edited_frame, rc_bool, session_id, progress):
235
+ return calculate_time_required(max_duration_s, rc_bool)
236
+
237
+ @spaces.GPU(duration=get_duration)
238
+ def _animate(input_video, max_duration_s, edited_frame, rc_bool, session_id = None, progress=gr.Progress(track_tqdm=True),):
239
+
240
+ if session_id is None:
241
+ session_id = uuid.uuid4().hex
242
+
243
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
244
+ os.makedirs(output_dir, exist_ok=True)
245
+
246
+ preprocess_dir = os.path.join(output_dir, "preprocess_dir")
247
+ os.makedirs(preprocess_dir, exist_ok=True)
248
+
249
+ output_video_path = os.path.join(output_dir, 'result.mp4')
250
+
251
+ # --- Measure preprocess time ---
252
+ start_preprocess = time.time()
253
+
254
+ if is_portrait(input_video):
255
+ w = 480
256
+ h = 832
257
+ else:
258
+ w = 832
259
+ h = 480
260
+
261
+ tag_string = "retarget_flag"
262
+
263
+ if rc_bool:
264
+ tag_string = "replace_flag"
265
+
266
+ preprocess_model = load_preprocess_models()
267
+
268
+ run_preprocess(preprocess_model, input_video, edited_frame, preprocess_dir, w, h, tag_string)
269
+
270
+ preprocess_time = time.time() - start_preprocess
271
+ print(f"Preprocess took {preprocess_time:.2f} seconds")
272
+
273
+ # --- Measure generate time ---
274
+ start_generate = time.time()
275
+
276
+
277
+ generate(wan_animate, preprocess_dir, output_video_path, rc_bool)
278
+
279
+ generate_time = time.time() - start_generate
280
+ print(f"Generate took {generate_time:.2f} seconds")
281
+
282
+ # --- Optional total time ---
283
+ total_time = preprocess_time + generate_time
284
+ print(f"Total time: {total_time:.2f} seconds")
285
+
286
+ gc.collect()
287
+ torch.cuda.empty_cache()
288
+
289
+ return output_video_path
290
+
291
+ def animate_scene(input_video, max_duration_s, edited_frame, rc_str, use_ai_image, ai_prompt, session_id = None, progress=gr.Progress(track_tqdm=True),):
292
+
293
+ if not input_video:
294
+ raise gr.Error("Please provide an video")
295
+
296
+ if not use_ai_image and not edited_frame:
297
+ raise gr.Error("Please provide an image or enable AI generation")
298
+
299
+ if use_ai_image and not ai_prompt:
300
+ raise gr.Error("Please provide a prompt for AI image generation")
301
+
302
+ if session_id is None:
303
+ session_id = uuid.uuid4().hex
304
+
305
+ input_video = preprocess_video(input_video, max_duration_s, session_id)
306
+
307
+ rc_bool = rc_mapping[rc_str]
308
+
309
+
310
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
311
+ os.makedirs(output_dir, exist_ok=True)
312
+
313
+ input_audio_path = os.path.join(output_dir, 'input_audio.wav')
314
+
315
+ audio_extracted = extract_audio_from_video_ffmpeg(input_video, input_audio_path)
316
+
317
+ edited_frame_png = os.path.join(output_dir, 'edited_frame.png')
318
+
319
+ if use_ai_image:
320
+ # Generate image using AI model
321
+ generated_image = generate_ai_image(ai_prompt, session_id)
322
+ edited_frame_img = generated_image
323
+ else:
324
+ edited_frame_img = Image.open(edited_frame)
325
+
326
+ edited_frame_img.save(edited_frame_png)
327
+
328
+ print(f'{session_id} inference started')
329
+
330
+ output_video_path = _animate(input_video, max_duration_s, edited_frame_png, rc_bool, session_id, progress)
331
+
332
+ final_video_path = os.path.join(output_dir, 'final_result.mp4')
333
+
334
+ preprocess_dir = os.path.join(output_dir, "preprocess_dir")
335
+ pose_video = os.path.join(preprocess_dir, 'src_pose.mp4')
336
+
337
+ if rc_bool:
338
+ mask_video = os.path.join(preprocess_dir, 'src_mask.mp4')
339
+ bg_video = os.path.join(preprocess_dir, 'src_bg.mp4')
340
+ face_video = os.path.join(preprocess_dir, 'src_face.mp4')
341
+ else:
342
+ mask_video = os.path.join(preprocess_dir, 'src_pose.mp4')
343
+ bg_video = os.path.join(preprocess_dir, 'src_pose.mp4')
344
+ face_video = os.path.join(preprocess_dir, 'src_pose.mp4')
345
+
346
+ if audio_extracted:
347
+ combine_video_and_audio_ffmpeg(output_video_path, input_audio_path, final_video_path)
348
+ else:
349
+ final_video_path = output_video_path
350
+
351
+ print(f"task for {session_id} finalised")
352
+
353
+ return final_video_path, pose_video, bg_video, mask_video, face_video
354
+
355
+ css = """
356
+ #col-container {
357
+ margin: 0 auto;
358
+ max-width: 1600px;
359
+ }
360
+
361
+ #step-column {
362
+ padding: 10px;
363
+ border-radius: 8px;
364
+ box-shadow: var(--card-shadow);
365
+ margin: 10px;
366
+ }
367
+
368
+ #col-showcase {
369
+ margin: 0 auto;
370
+ max-width: 1100px;
371
+ }
372
+
373
+ .button-gradient {
374
+ background: linear-gradient(45deg, rgb(255, 65, 108), rgb(255, 75, 43), rgb(255, 155, 0), rgb(255, 65, 108)) 0% 0% / 400% 400%;
375
+ border: none;
376
+ padding: 14px 28px;
377
+ font-size: 16px;
378
+ font-weight: bold;
379
+ color: white;
380
+ border-radius: 10px;
381
+ cursor: pointer;
382
+ transition: 0.3s ease-in-out;
383
+ animation: 2s linear 0s infinite normal none running gradientAnimation;
384
+ box-shadow: rgba(255, 65, 108, 0.6) 0px 4px 10px;
385
+ }
386
+
387
+ .toggle-container {
388
+ display: inline-flex;
389
+ background-color: #ffd6ff; /* light pink background */
390
+ border-radius: 9999px;
391
+ padding: 4px;
392
+ position: relative;
393
+ width: fit-content;
394
+ font-family: sans-serif;
395
+ }
396
+
397
+ .toggle-container input[type="radio"] {
398
+ display: none;
399
+ }
400
+
401
+ .toggle-container label {
402
+ position: relative;
403
+ z-index: 2;
404
+ flex: 1;
405
+ text-align: center;
406
+ font-weight: 700;
407
+ color: #4b2ab5; /* dark purple text for unselected */
408
+ padding: 6px 22px;
409
+ border-radius: 9999px;
410
+ cursor: pointer;
411
+ transition: color 0.25s ease;
412
+ }
413
+
414
+ /* Moving highlight */
415
+ .toggle-highlight {
416
+ position: absolute;
417
+ top: 4px;
418
+ left: 4px;
419
+ width: calc(50% - 4px);
420
+ height: calc(100% - 8px);
421
+ background-color: #4b2ab5; /* dark purple background */
422
+ border-radius: 9999px;
423
+ transition: transform 0.25s ease;
424
+ z-index: 1;
425
+ }
426
+
427
+ /* When "True" is checked */
428
+ #true:checked ~ label[for="true"] {
429
+ color: #ffd6ff; /* light pink text */
430
+ }
431
+
432
+ /* When "False" is checked */
433
+ #false:checked ~ label[for="false"] {
434
+ color: #ffd6ff; /* light pink text */
435
+ }
436
+
437
+ /* Move highlight to right side when False is checked */
438
+ #false:checked ~ .toggle-highlight {
439
+ transform: translateX(100%);
440
+ }
441
+ """
442
+
443
+ def log_change(log_source, session_id, meta_data = None):
444
+
445
+ if not meta_data:
446
+ print(f'{session_id} changed {log_source}')
447
+ else:
448
+ print(f'{session_id} changed {log_source} with {meta_data}')
449
+
450
+ def generate_ai_image(prompt, session_id):
451
+ """
452
+ Generate an image using an AI model based on the prompt.
453
+ This is a placeholder - implement with your preferred image generation model.
454
+ """
455
+ # TODO: Implement actual AI image generation
456
+ # Example using a hypothetical image generation model:
457
+ # from diffusers import StableDiffusionPipeline
458
+ # pipe = StableDiffusionPipeline.from_pretrained("model_name")
459
+ # image = pipe(prompt).images[0]
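+ #
+ # A slightly fuller hedged sketch (hypothetical model id and settings, still not wired in):
+ # pipe = StableDiffusionPipeline.from_pretrained(
+ #     "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
+ # ).to("cuda")
+ # image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+ # The PIL image would then be returned, matching what animate_scene expects to save.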
460
+
461
+ # For now, return a placeholder
462
+ raise gr.Error("AI image generation not yet implemented. Please upload an image instead.")
463
+
464
+ def start_session(request: gr.Request):
465
+
466
+ return request.session_hash
467
+
468
+ def cleanup(request: gr.Request):
469
+
470
+ sid = request.session_hash
471
+
472
+ if sid:
473
+ print(f"{sid} left")
474
+ d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
475
+ shutil.rmtree(d1, ignore_errors=True)
476
+
477
+ with gr.Blocks(css=css, title="Wan 2.2 Animate --replace", theme=gr.themes.Ocean()) as demo:
478
+
479
+ session_state = gr.State()
480
+ demo.load(start_session, outputs=[session_state])
481
+
482
+ with gr.Column(elem_id="col-container"):
483
+ with gr.Row():
484
+ gr.HTML(
485
+ """
486
+ <div style="text-align: center;">
487
+ <p style="font-size:16px; display: inline; margin: 0;">
488
+ <strong>Wan2.2-Animate-14B </strong>
489
+ </p>
490
+ <a href="https://huggingface.co/Wan-AI/Wan2.2-Animate-14B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
491
+ [Model]
492
+ </a>
493
+ <p style="font-size:16px; display: inline; margin: 0;">
494
+ -- HF Space By:
495
+ </p>
496
+ <a href="https://huggingface.co/alexnasa" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
497
+ <img src="https://img.shields.io/badge/🤗-Follow Me-yellow.svg">
498
+ </a>
499
+ </div>
500
+ """
501
+ )
502
+ with gr.Row():
503
+ with gr.Column(elem_id="step-column"):
504
+ gr.HTML("""
505
+ <div>
506
+ <span style="font-size: 24px;">1. Upload a Video</span><br>
507
+ </div>
508
+ """)
509
+ input_video = gr.Video(label="Input Video", height=512)
510
+ max_duration_slider = gr.Slider(2, 10, 2, step=2, label="Max Duration", visible=False)
511
+
512
+ gr.Examples(
513
+ examples=[
514
+
515
+ [
516
+ "./examples/martialart.mp4",
517
+ ],
518
+
519
+ [
520
+ "./examples/test_example.mp4",
521
+ ],
522
+
523
+ ],
524
+ inputs=[input_video],
525
+ cache_examples=False,
526
+ )
527
+
528
+
529
+ with gr.Column(elem_id="step-column"):
530
+ gr.HTML("""
531
+ <div>
532
+ <span style="font-size: 24px;">2. Upload or Generate Ref Image</span><br>
533
+ </div>
534
+ """)
535
+
536
+ use_ai_image = gr.Checkbox(label="Generate Image with AI", value=False)
537
+
538
+ with gr.Group() as upload_group:
539
+ edited_frame = gr.Image(label="Ref Image", type="filepath", height=512)
540
+
541
+ with gr.Group(visible=False) as ai_group:
542
+ ai_prompt = gr.Textbox(label="AI Image Prompt", placeholder="Describe the image you want to generate...")
543
+ generate_btn = gr.Button("Generate Image", variant="secondary")
544
+ ai_generated_preview = gr.Image(label="Generated Preview", type="pil", height=512)
545
+
546
+ default_replace_string = "Video ← Ref Image"
547
+ replace_character_string = gr.Radio(
548
+ ["Video → Ref Image", "Video ← Ref Image"], value=default_replace_string, show_label=False
549
+ )
550
+
551
+ def toggle_image_input(use_ai):
552
+ return gr.update(visible=not use_ai), gr.update(visible=use_ai)
553
+
554
+ use_ai_image.change(
555
+ toggle_image_input,
556
+ inputs=[use_ai_image],
557
+ outputs=[upload_group, ai_group]
558
+ )
559
+
560
+ gr.Examples(
561
+ examples=[
562
+
563
+ [
564
+ "./examples/ali.png",
565
+ ],
566
+
567
+ [
568
+ "./examples/amber.png",
569
+ ],
570
+
571
+ [
572
+ "./examples/ella.png",
573
+ ],
574
+
575
+ [
576
+ "./examples/sydney.png",
577
+ ],
578
+
579
+ ],
580
+ inputs=[edited_frame],
581
+ cache_examples=False,
582
+ )
583
+
584
+ with gr.Column(elem_id="step-column"):
585
+ gr.HTML("""
586
+ <div>
587
+ <span style="font-size: 24px;">3. Wan Animate it!</span><br>
588
+ </div>
589
+ """)
590
+ output_video = gr.Video(label="Edited Video", height=512)
591
+
592
+ duration_s = get_display_time_required(2, rc_mapping[default_replace_string])
593
+ duration_m = duration_s / 60
594
+
595
+ time_required = f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)"
596
+
597
+ time_required = gr.Text(value=time_required, show_label=False, visible=False)
598
+ action_button = gr.Button("Wan Animate 🦆", variant='primary', elem_classes="button-gradient")
599
+
600
+ with gr.Accordion("Preprocessed Data", open=False, visible=True):
601
+ with gr.Row():
602
+ pose_video = gr.Video(label="Pose Video")
603
+ bg_video = gr.Video(label="Background Video")
604
+ face_video = gr.Video(label="Face Video")
605
+ mask_video = gr.Video(label="Mask Video")
606
+
607
+ with gr.Row():
608
+ with gr.Column(elem_id="col-showcase"):
609
+
610
+ gr.Examples(
611
+ examples=[
612
+
613
+ [
614
+ "./examples/okay.mp4",
615
+ 2,
616
+ "./examples/amber.png",
617
+ "Video ← Ref Image",
618
+ False,
619
+ ""
620
+ ],
621
+
622
+ [
623
+ "./examples/superman.mp4",
624
+ 2,
625
+ "./examples/superman.png",
626
+ "Video ← Ref Image",
627
+ False,
628
+ ""
629
+ ],
630
+
631
+ [
632
+ "./examples/test_example.mp4",
633
+ 2,
634
+ "./examples/ella.png",
635
+ "Video ← Ref Image",
636
+ False,
637
+ ""
638
+ ],
639
+
640
+ [
641
+ "./examples/paul.mp4",
642
+ 2,
643
+ "./examples/man.png",
644
+ "Video → Ref Image",
645
+ False,
646
+ ""
647
+ ],
648
+
649
+ [
650
+ "./examples/desi.mp4",
651
+ 2,
652
+ "./examples/desi.png",
653
+ "Video ← Ref Image",
654
+ False,
655
+ ""
656
+ ],
657
+
658
+ ],
659
+ inputs=[input_video, max_duration_slider, edited_frame, replace_character_string, use_ai_image, ai_prompt],
660
+ outputs=[output_video, pose_video, bg_video, mask_video, face_video],
661
+ fn=animate_scene,
662
+ cache_examples=True,
663
+ )
664
+
665
+ action_button.click(fn=animate_scene, inputs=[input_video, max_duration_slider, edited_frame, replace_character_string, use_ai_image, ai_prompt, session_state], outputs=[output_video, pose_video, bg_video, mask_video, face_video])
666
+ replace_character_string.change(update_time_required, inputs=[max_duration_slider, replace_character_string], outputs=[time_required])
667
+
668
+ max_duration_slider.change(log_change, inputs=[gr.State("slider"), session_state, max_duration_slider]).then(update_time_required, inputs=[max_duration_slider, replace_character_string], outputs=[time_required])
669
+ input_video.change(log_change, inputs=[gr.State("video"), session_state])
670
+ edited_frame.change(log_change, inputs=[gr.State("ref image"), session_state])
671
+
672
+ if __name__ == "__main__":
673
+ demo.queue()
674
+ demo.unload(cleanup)
675
+ demo.launch(ssr_mode=False, share=True)
generate.py ADDED
@@ -0,0 +1,235 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import sys
6
+ import warnings
7
+ from datetime import datetime
8
+
9
+ warnings.filterwarnings('ignore')
10
+
11
+ import random
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from PIL import Image
16
+
17
+ import wan
18
+ from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
19
+ from wan.distributed.util import init_distributed_group
20
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
21
+ from wan.utils.utils import merge_video_audio, save_video, str2bool
22
+
23
+
24
+ EXAMPLE_PROMPT = {
25
+ "t2v-A14B": {
26
+ "prompt":
27
+ "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
28
+ },
29
+ "i2v-A14B": {
30
+ "prompt":
31
+ "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
32
+ "image":
33
+ "examples/i2v_input.JPG",
34
+ },
35
+ "ti2v-5B": {
36
+ "prompt":
37
+ "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
38
+ },
39
+ "animate-14B": {
40
+ "prompt": "视频中的人在做动作",
41
+ "video": "",
42
+ "pose": "",
43
+ "mask": "",
44
+ },
45
+ "s2v-14B": {
46
+ "prompt":
47
+ "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
48
+ "image":
49
+ "examples/i2v_input.JPG",
50
+ "audio":
51
+ "examples/talk.wav",
52
+ "tts_prompt_audio":
53
+ "examples/zero_shot_prompt.wav",
54
+ "tts_prompt_text":
55
+ "希望你以后能够做的比我还好呦。",
56
+ "tts_text":
57
+ "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
58
+ },
59
+ }
60
+
61
+
62
+ def _validate_args(args):
63
+ # Basic check
64
+ assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
65
+ assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
66
+ assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
67
+
68
+ if args.prompt is None:
69
+ args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
70
+ if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
71
+ args.image = EXAMPLE_PROMPT[args.task]["image"]
72
+ if args.audio is None and args.enable_tts is False and "audio" in EXAMPLE_PROMPT[args.task]:
73
+ args.audio = EXAMPLE_PROMPT[args.task]["audio"]
74
+ if (args.tts_prompt_audio is None or args.tts_text is None) and args.enable_tts is True and "audio" in EXAMPLE_PROMPT[args.task]:
75
+ args.tts_prompt_audio = EXAMPLE_PROMPT[args.task]["tts_prompt_audio"]
76
+ args.tts_prompt_text = EXAMPLE_PROMPT[args.task]["tts_prompt_text"]
77
+ args.tts_text = EXAMPLE_PROMPT[args.task]["tts_text"]
78
+
79
+ if args.task == "i2v-A14B":
80
+ assert args.image is not None, "Please specify the image path for i2v."
81
+
82
+ cfg = WAN_CONFIGS[args.task]
83
+
84
+ if args.sample_steps is None:
85
+ args.sample_steps = cfg.sample_steps
86
+
87
+ if args.sample_shift is None:
88
+ args.sample_shift = cfg.sample_shift
89
+
90
+ if args.sample_guide_scale is None:
91
+ args.sample_guide_scale = cfg.sample_guide_scale
92
+
93
+ if args.frame_num is None:
94
+ args.frame_num = cfg.frame_num
95
+
96
+ args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
97
+ 0, sys.maxsize)
98
+ # Size check
99
+ if not 's2v' in args.task:
100
+ assert args.size in SUPPORTED_SIZES[
101
+ args.
102
+ task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
103
+
104
+
105
+ class _Args:
106
+ pass
107
+
108
+ def _parse_args():
109
+ args = _Args()
110
+
111
+ # core generation options
112
+ args.task = "animate-14B"
113
+ # args.size = "1280*720"
114
+ args.size = "720*1280"
115
+ args.frame_num = None
116
+ args.ckpt_dir = "./Wan2.2-Animate-14B/"
117
+ args.offload_model = False
118
+ args.ulysses_size = 1
119
+ args.t5_fsdp = False
120
+ args.t5_cpu = False
121
+ args.dit_fsdp = False
122
+ args.prompt = None
123
+ args.use_prompt_extend = False
124
+ args.prompt_extend_method = "local_qwen" # ["dashscope", "local_qwen"]
125
+ args.prompt_extend_model = None
126
+ args.prompt_extend_target_lang = "zh" # ["zh", "en"]
127
+ args.base_seed = 1234
128
+ args.image = None
129
+ args.sample_solver = "unipc" # ['unipc', 'dpm++']
130
+ args.sample_steps = None
131
+ args.sample_shift = None
132
+ args.sample_guide_scale = None
133
+ args.convert_model_dtype = True
134
+
135
+ # animate
136
+ args.refert_num = 1
137
+
138
+ # s2v-only
139
+ args.num_clip = None
140
+ args.audio = None
141
+ args.enable_tts = False
142
+ args.tts_prompt_audio = None
143
+ args.tts_prompt_text = None
144
+ args.tts_text = None
145
+ args.pose_video = None
146
+ args.start_from_ref = False
147
+ args.infer_frames = 80
148
+
149
+ _validate_args(args)
150
+ return args
151
+
152
+
153
+
154
+ def _init_logging(rank):
155
+ # logging
156
+ if rank == 0:
157
+ # set format
158
+ logging.basicConfig(
159
+ level=logging.INFO,
160
+ format="[%(asctime)s] %(levelname)s: %(message)s",
161
+ handlers=[logging.StreamHandler(stream=sys.stdout)])
162
+ else:
163
+ logging.basicConfig(level=logging.ERROR)
164
+
165
+ def load_model(use_relighting_lora = False):
166
+
167
+ cfg = WAN_CONFIGS["animate-14B"]
168
+
169
+ return wan.WanAnimate(
170
+ config=cfg,
171
+ checkpoint_dir="./Wan2.2-Animate-14B/",
172
+ device_id=0,
173
+ rank=0,
174
+ t5_fsdp=False,
175
+ dit_fsdp=False,
176
+ use_sp=False,
177
+ t5_cpu=False,
178
+ convert_model_dtype=False,
179
+ use_relighting_lora=use_relighting_lora
180
+ )
181
+
182
+ def generate(wan_animate, preprocess_dir, save_file, replace_flag = False):
183
+ args = _parse_args()
184
+ rank = int(os.getenv("RANK", 0))
185
+ world_size = int(os.getenv("WORLD_SIZE", 1))
186
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
187
+ device = local_rank
188
+ _init_logging(rank)
189
+
190
+ cfg = WAN_CONFIGS[args.task]
191
+
192
+ logging.info(f"Input prompt: {args.prompt}")
193
+ img = None
194
+ if args.image is not None:
195
+ img = Image.open(args.image).convert("RGB")
196
+ logging.info(f"Input image: {args.image}")
197
+
198
+ print(f'rank:{rank}')
199
+
200
+
201
+
202
+ logging.info(f"Generating video ...")
203
+ video = wan_animate.generate(
204
+ src_root_path=preprocess_dir,
205
+ replace_flag=replace_flag,
206
+ refert_num = args.refert_num,
207
+ clip_len=args.frame_num,
208
+ shift=args.sample_shift,
209
+ sample_solver=args.sample_solver,
210
+ sampling_steps=args.sample_steps,
211
+ guide_scale=args.sample_guide_scale,
212
+ seed=args.base_seed,
213
+ offload_model=args.offload_model)
214
+ if rank == 0:
215
+
216
+ save_video(
217
+ tensor=video[None],
218
+ save_file=save_file,
219
+ fps=cfg.sample_fps,
220
+ nrow=1,
221
+ normalize=True,
222
+ value_range=(-1, 1))
223
+ # if "s2v" in args.task:
224
+ # if args.enable_tts is False:
225
+ # merge_video_audio(video_path=args.save_file, audio_path=args.audio)
226
+ # else:
227
+ # merge_video_audio(video_path=args.save_file, audio_path="tts.wav")
228
+ del video
229
+
230
+ torch.cuda.synchronize()
231
+ if dist.is_initialized():
232
+ dist.barrier()
233
+ dist.destroy_process_group()
234
+
235
+ logging.info("Finished.")
requirements.txt ADDED
@@ -0,0 +1,29 @@
+ decord
+ peft
+ pandas
+ matplotlib
+ loguru
+ sentencepiece
+ dashscope
+ ftfy
+ diffusers
+ opencv-python
+ moviepy
+ torchvision==0.23.0
+ torchaudio==2.8.0
+ transformers
+ tokenizers
+ accelerate
+ tqdm
+ imageio[ffmpeg]
+ easydict
+ imageio-ffmpeg
+ numpy>=1.23.5,<2
+ hydra-core
+ iopath
+ pytest
+ pillow
+ librosa
+ fvcore
+ onnxruntime-gpu
+ flash-attn-3 @ https://huggingface.co/alexnasa/flash-attn-3/resolve/main/128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
wan/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+ from . import configs, distributed, modules
+ from .image2video import WanI2V
+ from .speech2video import WanS2V
+ from .text2video import WanT2V
+ from .textimage2video import WanTI2V
+ from .animate import WanAnimate
wan/animate.py ADDED
@@ -0,0 +1,663 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+ import math
4
+ import os
5
+ import cv2
6
+ import types
7
+ from copy import deepcopy
8
+ from functools import partial
9
+ from einops import rearrange
10
+ import numpy as np
11
+ import torch
12
+
13
+ import torch.distributed as dist
14
+ from peft import set_peft_model_state_dict
15
+ from decord import VideoReader
16
+ from tqdm import tqdm
17
+ import torch.nn.functional as F
18
+ from .distributed.fsdp import shard_model
19
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
20
+ from .distributed.util import get_world_size
21
+
22
+ from .modules.animate import WanAnimateModel
23
+ from .modules.animate import CLIPModel
24
+ from .modules.t5 import T5EncoderModel
25
+ from .modules.vae2_1 import Wan2_1_VAE
26
+ from .modules.animate.animate_utils import TensorList, get_loraconfig
27
+ from .utils.fm_solvers import (
28
+ FlowDPMSolverMultistepScheduler,
29
+ get_sampling_sigmas,
30
+ retrieve_timesteps,
31
+ )
32
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
33
+
34
+
35
+
36
+ class WanAnimate:
37
+
38
+ def __init__(
39
+ self,
40
+ config,
41
+ checkpoint_dir,
42
+ device_id=0,
43
+ rank=0,
44
+ t5_fsdp=False,
45
+ dit_fsdp=False,
46
+ use_sp=False,
47
+ t5_cpu=False,
48
+ init_on_cpu=True,
49
+ convert_model_dtype=False,
50
+ use_relighting_lora=False
51
+ ):
52
+ r"""
53
+ Initializes the generation model components.
54
+
55
+ Args:
56
+ config (EasyDict):
57
+ Object containing model parameters initialized from config.py
58
+ checkpoint_dir (`str`):
59
+ Path to directory containing model checkpoints
60
+ device_id (`int`, *optional*, defaults to 0):
61
+ Id of target GPU device
62
+ rank (`int`, *optional*, defaults to 0):
63
+ Process rank for distributed training
64
+ t5_fsdp (`bool`, *optional*, defaults to False):
65
+ Enable FSDP sharding for T5 model
66
+ dit_fsdp (`bool`, *optional*, defaults to False):
67
+ Enable FSDP sharding for DiT model
68
+ use_sp (`bool`, *optional*, defaults to False):
69
+ Enable distribution strategy of sequence parallel.
70
+ t5_cpu (`bool`, *optional*, defaults to False):
71
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
72
+ init_on_cpu (`bool`, *optional*, defaults to True):
73
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
74
+ convert_model_dtype (`bool`, *optional*, defaults to False):
75
+ Convert DiT model parameters dtype to 'config.param_dtype'.
76
+ Only works without FSDP.
77
+ use_relighting_lora (`bool`, *optional*, defaults to False):
78
+ Whether to use relighting lora for character replacement.
79
+ """
80
+ self.device = torch.device(f"cuda:{device_id}")
81
+ self.config = config
82
+ self.rank = rank
83
+ self.t5_cpu = t5_cpu
84
+ self.init_on_cpu = init_on_cpu
85
+
86
+ self.num_train_timesteps = config.num_train_timesteps
87
+ self.param_dtype = config.param_dtype
88
+
89
+ if t5_fsdp or dit_fsdp or use_sp:
90
+ self.init_on_cpu = False
91
+
92
+ shard_fn = partial(shard_model, device_id=device_id)
93
+ self.text_encoder = T5EncoderModel(
94
+ text_len=config.text_len,
95
+ dtype=config.t5_dtype,
96
+ device=torch.device('cpu'),
97
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
98
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
99
+ shard_fn=shard_fn if t5_fsdp else None,
100
+ )
101
+
102
+ self.clip = CLIPModel(
103
+ dtype=torch.float16,
104
+ device=self.device,
105
+ checkpoint_path=os.path.join(checkpoint_dir,
106
+ config.clip_checkpoint),
107
+ tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
108
+
109
+ self.vae = Wan2_1_VAE(
110
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
111
+ device=self.device)
112
+
113
+ logging.info(f"Creating WanAnimate from {checkpoint_dir}")
114
+
115
+ if not dit_fsdp:
116
+ self.noise_model = WanAnimateModel.from_pretrained(
117
+ checkpoint_dir,
118
+ torch_dtype=self.param_dtype,
119
+ device_map=self.device)
120
+ else:
121
+ self.noise_model = WanAnimateModel.from_pretrained(
122
+ checkpoint_dir, torch_dtype=self.param_dtype)
123
+
124
+ self.noise_model = self._configure_model(
125
+ model=self.noise_model,
126
+ use_sp=use_sp,
127
+ dit_fsdp=dit_fsdp,
128
+ shard_fn=shard_fn,
129
+ convert_model_dtype=convert_model_dtype,
130
+ use_lora=use_relighting_lora,
131
+ checkpoint_dir=checkpoint_dir,
132
+ config=config
133
+ )
134
+
135
+ # self.noise_model = torch.compile(self.noise_model)
136
+
137
+ if use_sp:
138
+ self.sp_size = get_world_size()
139
+ else:
140
+ self.sp_size = 1
141
+
142
+ self.sample_neg_prompt = config.sample_neg_prompt
143
+ self.sample_prompt = config.prompt
144
+
145
+
146
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
147
+ convert_model_dtype, use_lora, checkpoint_dir, config):
148
+ """
149
+ Configures a model object. This includes setting evaluation modes,
150
+ applying distributed parallel strategy, and handling device placement.
151
+
152
+ Args:
153
+ model (torch.nn.Module):
154
+ The model instance to configure.
155
+ use_sp (`bool`):
156
+ Enable distribution strategy of sequence parallel.
157
+ dit_fsdp (`bool`):
158
+ Enable FSDP sharding for DiT model.
159
+ shard_fn (callable):
160
+ The function to apply FSDP sharding.
161
+ convert_model_dtype (`bool`):
162
+ Convert DiT model parameters dtype to 'config.param_dtype'.
163
+ Only works without FSDP.
164
+
165
+ Returns:
166
+ torch.nn.Module:
167
+ The configured model.
168
+ """
169
+ model.eval().requires_grad_(False)
170
+
171
+ if use_sp:
172
+ for block in model.blocks:
173
+ block.self_attn.forward = types.MethodType(
174
+ sp_attn_forward, block.self_attn)
175
+
176
+ model.use_context_parallel = True
177
+
178
+ if dist.is_initialized():
179
+ dist.barrier()
180
+
181
+ if use_lora:
182
+ logging.info("Loading Relighting Lora. ")
183
+ lora_config = get_loraconfig(
184
+ transformer=model,
185
+ rank=128,
186
+ alpha=128
187
+ )
188
+ model.add_adapter(lora_config)
189
+ lora_path = os.path.join(checkpoint_dir, config.lora_checkpoint)
190
+ peft_state_dict = torch.load(lora_path)["state_dict"]
191
+ set_peft_model_state_dict(model, peft_state_dict)
192
+
193
+ if dit_fsdp:
194
+ model = shard_fn(model, use_lora=use_lora)
195
+ else:
196
+ if convert_model_dtype:
197
+ model.to(self.param_dtype)
198
+ if not self.init_on_cpu:
199
+ model.to(self.device)
200
+
201
+ return model
202
+
203
+ def inputs_padding(self, array, target_len):
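+ # Ping-pong padding: walks the frame list forward and backward (reflecting at both ends)
+ # until target_len frames have been collected, so padded clips stay temporally smooth.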
204
+ idx = 0
205
+ flip = False
206
+ target_array = []
207
+ while len(target_array) < target_len:
208
+ target_array.append(deepcopy(array[idx]))
209
+ if flip:
210
+ idx -= 1
211
+ else:
212
+ idx += 1
213
+ if idx == 0 or idx == len(array) - 1:
214
+ flip = not flip
215
+ return target_array[:target_len]
216
+
217
+ def get_valid_len(self, real_len, clip_len=81, overlap=1):
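+ # Rounds real_len up so the video splits into whole clips of clip_len frames that share
+ # `overlap` frames. e.g. real_len=100, clip_len=77, overlap=1: real_clip_len=76,
+ # last_clip_num=(100-1)%76=23, extra=76-23=53, target_len=153 (two 77-frame clips).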
218
+ real_clip_len = clip_len - overlap
219
+ last_clip_num = (real_len - overlap) % real_clip_len
220
+ if last_clip_num == 0:
221
+ extra = 0
222
+ else:
223
+ extra = real_clip_len - last_clip_num
224
+ target_len = real_len + extra
225
+ return target_len
226
+
227
+
228
+ def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
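+ # Builds the I2V conditioning mask in latent space: the first mask_len frames are set to 1
+ # (kept as conditioning), the first frame is repeated 4x to match the VAE's 4x temporal
+ # stride, and the result is reshaped to (4, lat_t, lat_h, lat_w).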
229
+ if mask_pixel_values is None:
230
+ msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
231
+ else:
232
+ msk = mask_pixel_values.clone()
233
+ msk[:, :mask_len] = 1
234
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
235
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
236
+ msk = msk.transpose(1, 2)[0]
237
+ return msk
238
+
239
+ def padding_resize(self, img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
240
+ ori_height = img_ori.shape[0]
241
+ ori_width = img_ori.shape[1]
242
+ channel = img_ori.shape[2]
243
+
244
+ img_pad = np.zeros((height, width, channel))
245
+ if channel == 1:
246
+ img_pad[:, :, 0] = padding_color[0]
247
+ else:
248
+ img_pad[:, :, 0] = padding_color[0]
249
+ img_pad[:, :, 1] = padding_color[1]
250
+ img_pad[:, :, 2] = padding_color[2]
251
+
252
+ if (ori_height / ori_width) > (height / width):
253
+ new_width = int(height / ori_height * ori_width)
254
+ img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
255
+ padding = int((width - new_width) / 2)
256
+ if len(img.shape) == 2:
257
+ img = img[:, :, np.newaxis]
258
+ img_pad[:, padding: padding + new_width, :] = img
259
+ else:
260
+ new_height = int(width / ori_width * ori_height)
261
+ img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
262
+ padding = int((height - new_height) / 2)
263
+ if len(img.shape) == 2:
264
+ img = img[:, :, np.newaxis]
265
+ img_pad[padding: padding + new_height, :, :] = img
266
+
267
+ img_pad = np.uint8(img_pad)
268
+
269
+ return img_pad
270
+
271
+ def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
272
+ pose_video_reader = VideoReader(src_pose_path)
273
+ pose_len = len(pose_video_reader)
274
+ pose_idxs = list(range(pose_len))
275
+ cond_images = pose_video_reader.get_batch(pose_idxs).asnumpy()
276
+
277
+ face_video_reader = VideoReader(src_face_path)
278
+ face_len = len(face_video_reader)
279
+ face_idxs = list(range(face_len))
280
+ face_images = face_video_reader.get_batch(face_idxs).asnumpy()
281
+ height, width = cond_images[0].shape[:2]
282
+ refer_images = cv2.imread(src_ref_path)[..., ::-1]
283
+ refer_images = self.padding_resize(refer_images, height=height, width=width)
284
+ return cond_images, face_images, refer_images
285
+
286
+ def prepare_source_for_replace(self, src_bg_path, src_mask_path):
287
+ bg_video_reader = VideoReader(src_bg_path)
288
+ bg_len = len(bg_video_reader)
289
+ bg_idxs = list(range(bg_len))
290
+ bg_images = bg_video_reader.get_batch(bg_idxs).asnumpy()
291
+
292
+ mask_video_reader = VideoReader(src_mask_path)
293
+ mask_len = len(mask_video_reader)
294
+ mask_idxs = list(range(mask_len))
295
+ mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy()
296
+ mask_images = mask_images[:, :, :, 0] / 255
297
+ return bg_images, mask_images
298
+
299
+ def generate(
300
+ self,
301
+ src_root_path,
302
+ replace_flag=False,
303
+ clip_len=77,
304
+ refert_num=1,
305
+ shift=5.0,
306
+ sample_solver='dpm++',
307
+ sampling_steps=20,
308
+ guide_scale=1,
309
+ input_prompt="",
310
+ n_prompt="",
311
+ seed=-1,
312
+ offload_model=True,
313
+ ):
314
+ r"""
315
+ Generates video frames from input image using diffusion process.
316
+
317
+ Args:
318
+ src_root_path ('str'):
319
+ Directory containing the preprocessed source files (src_pose.mp4, src_face.mp4, src_ref.png, and src_bg.mp4/src_mask.mp4 when replace_flag is set).
320
+ replace_flag (`bool`, *optional*, defaults to False):
321
+ Whether to use character replace.
322
+ clip_len (`int`, *optional*, defaults to 77):
323
+ How many frames to generate per clip. The number should be 4n+1.
324
+ refert_num (`int`, *optional*, defaults to 1):
325
+ How many frames used for temporal guidance. Recommended to be 1 or 5.
326
+ shift (`float`, *optional*, defaults to 5.0):
327
+ Noise schedule shift parameter.
328
+ sample_solver (`str`, *optional*, defaults to 'dpm++'):
329
+ Solver used to sample the video.
330
+ sampling_steps (`int`, *optional*, defaults to 20):
331
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
332
+ guide_scale (`float` or tuple[`float`], *optional*, defaults to 1.0):
333
+ Classifier-free guidance scale. We only use it for expression control.
334
+ In most cases, it's not necessary and faster generation can be achieved without it.
335
+ When expression adjustments are needed, you may consider using this feature.
336
+ input_prompt (`str`):
337
+ Text prompt for content generation. We don't recommend custom prompts (although they work)
338
+ n_prompt (`str`, *optional*, defaults to ""):
339
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
340
+ seed (`int`, *optional*, defaults to -1):
341
+ Random seed for noise generation. If -1, use random seed
342
+ offload_model (`bool`, *optional*, defaults to True):
343
+ If True, offloads models to CPU during generation to save VRAM
344
+
345
+ Returns:
346
+ torch.Tensor:
347
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
348
+ - C: Color channels (3 for RGB)
349
+ - N: Number of frames
350
+ - H: Frame height
351
+ - W: Frame width
352
+ """
353
+ assert refert_num == 1 or refert_num == 5, "refert_num should be 1 or 5."
354
+
355
+ seed_g = torch.Generator(device=self.device)
356
+ seed_g.manual_seed(seed)
357
+
358
+
359
+ if n_prompt == "":
360
+ n_prompt = self.sample_neg_prompt
361
+
362
+ if input_prompt == "":
363
+ input_prompt = self.sample_prompt
364
+
365
+ src_pose_path = os.path.join(src_root_path, "src_pose.mp4")
366
+ src_face_path = os.path.join(src_root_path, "src_face.mp4")
367
+ src_ref_path = os.path.join(src_root_path, "src_ref.png")
368
+
369
+ cond_images, face_images, refer_images = self.prepare_source(src_pose_path=src_pose_path, src_face_path=src_face_path, src_ref_path=src_ref_path)
370
+
371
+ if not self.t5_cpu:
372
+ self.text_encoder.model.to(self.device)
373
+ context = self.text_encoder([input_prompt], self.device)
374
+ context_null = self.text_encoder([n_prompt], self.device)
375
+ if offload_model:
376
+ self.text_encoder.model.cpu()
377
+ else:
378
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
379
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
380
+ context = [t.to(self.device) for t in context]
381
+ context_null = [t.to(self.device) for t in context_null]
382
+
383
+ real_frame_len = len(cond_images)
384
+ target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num)
385
+ logging.info('real frames: {} target frames: {}'.format(real_frame_len, target_len))
386
+ cond_images = self.inputs_padding(cond_images, target_len)
387
+ face_images = self.inputs_padding(face_images, target_len)
388
+
389
+ if replace_flag:
390
+ src_bg_path = os.path.join(src_root_path, "src_bg.mp4")
391
+ src_mask_path = os.path.join(src_root_path, "src_mask.mp4")
392
+ bg_images, mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
393
+ bg_images = self.inputs_padding(bg_images, target_len)
394
+ mask_images = self.inputs_padding(mask_images, target_len)
395
+ self.noise_model.enable_adapters()
396
+ else:
397
+ self.noise_model.disable_adapters()
398
+
399
+
400
+ height, width = refer_images.shape[:2]
401
+ start = 0
402
+ end = clip_len
403
+ all_out_frames = []
404
+ total_iterations = ((len(cond_images) - 1) // clip_len + 1) * sampling_steps
405
+
406
+ with tqdm(total=total_iterations) as pbar:
407
+
408
+ while True:
409
+ if start + refert_num >= len(cond_images):
410
+ break
411
+
412
+ if start == 0:
413
+ mask_reft_len = 0
414
+ else:
415
+ mask_reft_len = refert_num
416
+
417
+ batch = {
418
+ "conditioning_pixel_values": torch.zeros(1, 3, clip_len, height, width),
419
+ "bg_pixel_values": torch.zeros(1, 3, clip_len, height, width),
420
+ "mask_pixel_values": torch.zeros(1, 1, clip_len, height, width),
421
+ "face_pixel_values": torch.zeros(1, 3, clip_len, 512, 512),
422
+ "refer_pixel_values": torch.zeros(1, 3, height, width),
423
+ "refer_t_pixel_values": torch.zeros(refert_num, 3, height, width)
424
+ }
425
+
426
+ batch["conditioning_pixel_values"] = rearrange(
427
+ torch.tensor(np.stack(cond_images[start:end]) / 127.5 - 1),
428
+ "t h w c -> 1 c t h w",
429
+ )
430
+ batch["face_pixel_values"] = rearrange(
431
+ torch.tensor(np.stack(face_images[start:end]) / 127.5 - 1),
432
+ "t h w c -> 1 c t h w",
433
+ )
434
+
435
+ batch["refer_pixel_values"] = rearrange(
436
+ torch.tensor(refer_images / 127.5 - 1), "h w c -> 1 c h w"
437
+ )
438
+
439
+ if start > 0:
440
+ batch["refer_t_pixel_values"] = rearrange(
441
+ out_frames[0, :, -refert_num:].clone().detach(),
442
+ "c t h w -> t c h w",
443
+ )
444
+
445
+ batch["refer_t_pixel_values"] = rearrange(batch["refer_t_pixel_values"],
446
+ "t c h w -> 1 c t h w",
447
+ )
448
+
449
+ if replace_flag:
450
+ batch["bg_pixel_values"] = rearrange(
451
+ torch.tensor(np.stack(bg_images[start:end]) / 127.5 - 1),
452
+ "t h w c -> 1 c t h w",
453
+ )
454
+
455
+ batch["mask_pixel_values"] = rearrange(
456
+ torch.tensor(np.stack(mask_images[start:end])[:, :, :, None]),
457
+ "t h w c -> 1 t c h w",
458
+ )
459
+
460
+
461
+ for key, value in batch.items():
462
+ if isinstance(value, torch.Tensor):
463
+ batch[key] = value.to(device=self.device, dtype=torch.bfloat16)
464
+
465
+ ref_pixel_values = batch["refer_pixel_values"]
466
+ refer_t_pixel_values = batch["refer_t_pixel_values"]
467
+ conditioning_pixel_values = batch["conditioning_pixel_values"]
468
+ face_pixel_values = batch["face_pixel_values"]
469
+
470
+ B, _, H, W = ref_pixel_values.shape
471
+ T = clip_len
472
+ lat_h = H // 8
473
+ lat_w = W // 8
474
+ lat_t = T // 4 + 1
475
+ target_shape = [lat_t + 1, lat_h, lat_w]
476
+ noise = [
477
+ torch.randn(
478
+ 16,
479
+ target_shape[0],
480
+ target_shape[1],
481
+ target_shape[2],
482
+ dtype=torch.float32,
483
+ device=self.device,
484
+ generator=seed_g,
485
+ )
486
+ ]
487
+
488
+ max_seq_len = int(math.ceil(np.prod(target_shape) // 4 / self.sp_size)) * self.sp_size
489
+ if max_seq_len % self.sp_size != 0:
490
+ raise ValueError(f"max_seq_len {max_seq_len} is not divisible by sp_size {self.sp_size}")
491
+
492
+ with (
493
+ torch.autocast(device_type=str(self.device), dtype=torch.bfloat16, enabled=True),
494
+ torch.no_grad()
495
+ ):
496
+ if sample_solver == 'unipc':
497
+ sample_scheduler = FlowUniPCMultistepScheduler(
498
+ num_train_timesteps=self.num_train_timesteps,
499
+ shift=1,
500
+ use_dynamic_shifting=False)
501
+ sample_scheduler.set_timesteps(
502
+ sampling_steps, device=self.device, shift=shift)
503
+ timesteps = sample_scheduler.timesteps
504
+ elif sample_solver == 'dpm++':
505
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
506
+ num_train_timesteps=self.num_train_timesteps,
507
+ shift=1,
508
+ use_dynamic_shifting=False)
509
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
510
+ timesteps, _ = retrieve_timesteps(
511
+ sample_scheduler,
512
+ device=self.device,
513
+ sigmas=sampling_sigmas)
514
+ else:
515
+ raise NotImplementedError("Unsupported solver.")
516
+
517
+ latents = noise
518
+
519
+ pose_latents_no_ref = self.vae.encode(conditioning_pixel_values.to(torch.bfloat16))
520
+ pose_latents_no_ref = torch.stack(pose_latents_no_ref)
521
+ pose_latents = torch.cat([pose_latents_no_ref], dim=2)
522
+
523
+ ref_pixel_values = rearrange(ref_pixel_values, "t c h w -> 1 c t h w")
524
+ ref_latents = self.vae.encode(ref_pixel_values.to(torch.bfloat16))
525
+ ref_latents = torch.stack(ref_latents)
526
+
527
+ mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=self.device)
528
+ y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=self.device)
529
+
530
+ img = ref_pixel_values[0, :, 0]
531
+ clip_context = self.clip.visual([img[:, None, :, :]]).to(dtype=torch.bfloat16, device=self.device)
532
+
533
+ if mask_reft_len > 0:
534
+ if replace_flag:
535
+ bg_pixel_values = batch["bg_pixel_values"]
536
+ y_reft = self.vae.encode(
537
+ [
538
+ torch.concat([refer_t_pixel_values[0, :, :mask_reft_len], bg_pixel_values[0, :, mask_reft_len:]], dim=1).to(self.device)
539
+ ]
540
+ )[0]
541
+ mask_pixel_values = 1 - batch["mask_pixel_values"]
542
+ mask_pixel_values = rearrange(mask_pixel_values, "b t c h w -> (b t) c h w")
543
+ mask_pixel_values = F.interpolate(mask_pixel_values, size=(H//8, W//8), mode='nearest')
544
+ mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
545
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, mask_pixel_values=mask_pixel_values, device=self.device)
546
+ else:
547
+ y_reft = self.vae.encode(
548
+ [
549
+ torch.concat(
550
+ [
551
+ torch.nn.functional.interpolate(refer_t_pixel_values[0, :, :mask_reft_len].cpu(),
552
+ size=(H, W), mode="bicubic"),
553
+ torch.zeros(3, T - mask_reft_len, H, W),
554
+ ],
555
+ dim=1,
556
+ ).to(self.device)
557
+ ]
558
+ )[0]
559
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)
560
+ else:
561
+ if replace_flag:
562
+ bg_pixel_values = batch["bg_pixel_values"]
563
+ mask_pixel_values = 1 - batch["mask_pixel_values"]
564
+ mask_pixel_values = rearrange(mask_pixel_values, "b t c h w -> (b t) c h w")
565
+ mask_pixel_values = F.interpolate(mask_pixel_values, size=(H//8, W//8), mode='nearest')
566
+ mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
567
+ y_reft = self.vae.encode(
568
+ [
569
+ torch.concat(
570
+ [
571
+ bg_pixel_values[0],
572
+ ],
573
+ dim=1,
574
+ ).to(self.device)
575
+ ]
576
+ )[0]
577
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, mask_pixel_values=mask_pixel_values, device=self.device)
578
+ else:
579
+ y_reft = self.vae.encode(
580
+ [
581
+ torch.concat(
582
+ [
583
+ torch.zeros(3, T - mask_reft_len, H, W),
584
+ ],
585
+ dim=1,
586
+ ).to(self.device)
587
+ ]
588
+ )[0]
589
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)
590
+
591
+ y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=self.device)
592
+ y = torch.concat([y_ref, y_reft], dim=1)
593
+
594
+ arg_c = {
595
+ "context": context,
596
+ "seq_len": max_seq_len,
597
+ "clip_fea": clip_context.to(dtype=torch.bfloat16, device=self.device),
598
+ "y": [y],
599
+ "pose_latents": pose_latents,
600
+ "face_pixel_values": face_pixel_values,
601
+ }
602
+
603
+ if guide_scale > 1:
604
+ face_pixel_values_uncond = face_pixel_values * 0 - 1
605
+ arg_null = {
606
+ "context": context_null,
607
+ "seq_len": max_seq_len,
608
+ "clip_fea": clip_context.to(dtype=torch.bfloat16, device=self.device),
609
+ "y": [y],
610
+ "pose_latents": pose_latents,
611
+ "face_pixel_values": face_pixel_values_uncond,
612
+ }
613
+
614
+ for i, t in enumerate(timesteps):
615
+ latent_model_input = latents
616
+ timestep = [t]
617
+
618
+ timestep = torch.stack(timestep)
619
+
620
+ noise_pred_cond = TensorList(
621
+ self.noise_model(TensorList(latent_model_input), t=timestep, **arg_c)
622
+ )
623
+
624
+ if guide_scale > 1:
625
+ noise_pred_uncond = TensorList(
626
+ self.noise_model(
627
+ TensorList(latent_model_input), t=timestep, **arg_null
628
+ )
629
+ )
630
+ noise_pred = noise_pred_uncond + guide_scale * (
631
+ noise_pred_cond - noise_pred_uncond
632
+ )
633
+ else:
634
+ noise_pred = noise_pred_cond
635
+
636
+ temp_x0 = sample_scheduler.step(
637
+ noise_pred[0].unsqueeze(0),
638
+ t,
639
+ latents[0].unsqueeze(0),
640
+ return_dict=False,
641
+ generator=seed_g,
642
+ )[0]
643
+ latents[0] = temp_x0.squeeze(0)
644
+
645
+ x0 = latents
646
+
647
+ if pbar is not None:
648
+ pbar.update(1)
649
+
650
+
651
+ x0 = [x.to(dtype=torch.float32) for x in x0]
652
+ out_frames = torch.stack(self.vae.decode([x0[0][:, 1:]]))
653
+
654
+ if start != 0:
655
+ out_frames = out_frames[:, :, refert_num:]
656
+
657
+ all_out_frames.append(out_frames.cpu())
658
+
659
+ start += clip_len - refert_num
660
+ end += clip_len - refert_num
661
+
662
+ videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len]
663
+ return videos[0] if self.rank == 0 else None
wan/configs/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import copy
3
+ import os
4
+
5
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6
+
7
+ from .wan_i2v_A14B import i2v_A14B
8
+ from .wan_s2v_14B import s2v_14B
9
+ from .wan_t2v_A14B import t2v_A14B
10
+ from .wan_ti2v_5B import ti2v_5B
11
+ from .wan_animate_14B import animate_14B
12
+
13
+ WAN_CONFIGS = {
14
+ 't2v-A14B': t2v_A14B,
15
+ 'i2v-A14B': i2v_A14B,
16
+ 'ti2v-5B': ti2v_5B,
17
+ 'animate-14B': animate_14B,
18
+ 's2v-14B': s2v_14B,
19
+ }
20
+
21
+ SIZE_CONFIGS = {
22
+ '720*1280': (720, 1280),
23
+ '1280*720': (1280, 720),
24
+ '480*832': (480, 832),
25
+ '832*480': (832, 480),
26
+ '704*1280': (704, 1280),
27
+ '1280*704': (1280, 704),
28
+ '1024*704': (1024, 704),
29
+ '704*1024': (704, 1024),
30
+ }
31
+
32
+ MAX_AREA_CONFIGS = {
33
+ '720*1280': 720 * 1280,
34
+ '1280*720': 1280 * 720,
35
+ '480*832': 480 * 832,
36
+ '832*480': 832 * 480,
37
+ '704*1280': 704 * 1280,
38
+ '1280*704': 1280 * 704,
39
+ '1024*704': 1024 * 704,
40
+ '704*1024': 704 * 1024,
41
+ }
42
+
43
+ SUPPORTED_SIZES = {
44
+ 't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
45
+ 'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
46
+ 'ti2v-5B': ('704*1280', '1280*704'),
47
+ 's2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '1024*704',
48
+ '704*1024', '704*1280', '1280*704'),
49
+ 'animate-14B': ('720*1280', '1280*720')
50
+ }
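A short usage sketch of the lookup tables above (not part of the diff; the task and size strings come straight from the dictionaries, everything else is illustrative):

from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS

task, size = 'animate-14B', '1280*720'

# validate the requested resolution before loading any weights
assert size in SUPPORTED_SIZES[task], f"{size} is not supported for {task}"

cfg = WAN_CONFIGS[task]              # EasyDict of model + sampling hyperparameters
dims = SIZE_CONFIGS[size]            # tuple in the same order as the size string
max_area = MAX_AREA_CONFIGS[size]    # pixel budget used by the i2v pipeline

print(cfg.sample_steps, cfg.sample_fps, dims, max_area)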
wan/configs/shared_config.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ #------------------------ Wan shared config ------------------------#
6
+ wan_shared_cfg = EasyDict()
7
+
8
+ # t5
9
+ wan_shared_cfg.t5_model = 'umt5_xxl'
10
+ wan_shared_cfg.t5_dtype = torch.bfloat16
11
+ wan_shared_cfg.text_len = 512
12
+
13
+ # transformer
14
+ wan_shared_cfg.param_dtype = torch.bfloat16
15
+
16
+ # inference
17
+ wan_shared_cfg.num_train_timesteps = 1000
18
+ wan_shared_cfg.sample_fps = 16
19
+ wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
20
+ wan_shared_cfg.frame_num = 81
wan/configs/wan_animate_14B.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan animate 14B ------------------------#
7
+ animate_14B = EasyDict(__name__='Config: Wan animate 14B')
8
+ animate_14B.update(wan_shared_cfg)
9
+
10
+ animate_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
11
+ animate_14B.t5_tokenizer = 'google/umt5-xxl'
12
+
13
+ animate_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
14
+ animate_14B.clip_tokenizer = 'xlm-roberta-large'
15
+ animate_14B.lora_checkpoint = 'relighting_lora.ckpt'
16
+ # vae
17
+ animate_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
18
+ animate_14B.vae_stride = (4, 8, 8)
19
+
20
+ # transformer
21
+ animate_14B.patch_size = (1, 2, 2)
22
+ animate_14B.dim = 5120
23
+ animate_14B.ffn_dim = 13824
24
+ animate_14B.freq_dim = 256
25
+ animate_14B.num_heads = 40
26
+ animate_14B.num_layers = 40
27
+ animate_14B.window_size = (-1, -1)
28
+ animate_14B.qk_norm = True
29
+ animate_14B.cross_attn_norm = True
30
+ animate_14B.eps = 1e-6
31
+ animate_14B.use_face_encoder = True
32
+ animate_14B.motion_encoder_dim = 512
33
+
34
+ # inference
35
+ animate_14B.sample_shift = 5.0
36
+ animate_14B.sample_steps = 5
37
+ animate_14B.sample_guide_scale = 1.0
38
+ animate_14B.frame_num = 77
39
+ animate_14B.sample_fps = 30
40
+ animate_14B.prompt = '视频中的人在做动作'
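As a sanity check on the numbers above, a rough sketch (not part of the diff) of how frame_num and the strides translate into latent and token counts, reusing the latent-length and token-count arithmetic that appears in wan/image2video.py below; it ignores any reference frames the animate pipeline prepends.

frame_num = 77
vae_stride = (4, 8, 8)      # temporal, height, width strides of the VAE
patch_size = (1, 2, 2)      # DiT patchification
height, width = 720, 1280   # one of the supported sizes for animate-14B

lat_t = (frame_num - 1) // vae_stride[0] + 1                      # 20 latent frames
lat_h, lat_w = height // vae_stride[1], width // vae_stride[2]    # 90 x 160 latents
tokens = lat_t * lat_h * lat_w // (patch_size[1] * patch_size[2])
print(lat_t, lat_h, lat_w, tokens)                                # 20 90 160 72000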
wan/configs/wan_i2v_A14B.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ from .shared_config import wan_shared_cfg
6
+
7
+ #------------------------ Wan I2V A14B ------------------------#
8
+
9
+ i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
10
+ i2v_A14B.update(wan_shared_cfg)
11
+
12
+ i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ i2v_A14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ i2v_A14B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ i2v_A14B.patch_size = (1, 2, 2)
21
+ i2v_A14B.dim = 5120
22
+ i2v_A14B.ffn_dim = 13824
23
+ i2v_A14B.freq_dim = 256
24
+ i2v_A14B.num_heads = 40
25
+ i2v_A14B.num_layers = 40
26
+ i2v_A14B.window_size = (-1, -1)
27
+ i2v_A14B.qk_norm = True
28
+ i2v_A14B.cross_attn_norm = True
29
+ i2v_A14B.eps = 1e-6
30
+ i2v_A14B.low_noise_checkpoint = 'low_noise_model'
31
+ i2v_A14B.high_noise_checkpoint = 'high_noise_model'
32
+
33
+ # inference
34
+ i2v_A14B.sample_shift = 5.0
35
+ i2v_A14B.sample_steps = 40
36
+ i2v_A14B.boundary = 0.900
37
+ i2v_A14B.sample_guide_scale = (3.5, 3.5) # low noise, high noise
wan/configs/wan_s2v_14B.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan S2V 14B ------------------------#
7
+
8
+ s2v_14B = EasyDict(__name__='Config: Wan S2V 14B')
9
+ s2v_14B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ s2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ s2v_14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ s2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ s2v_14B.vae_stride = (4, 8, 8)
18
+
19
+ # wav2vec
20
+ s2v_14B.wav2vec = "wav2vec2-large-xlsr-53-english"
21
+
22
+ s2v_14B.num_heads = 40
23
+ # transformer
24
+ s2v_14B.transformer = EasyDict(
25
+ __name__="Config: Transformer config for WanModel_S2V")
26
+ s2v_14B.transformer.patch_size = (1, 2, 2)
27
+ s2v_14B.transformer.dim = 5120
28
+ s2v_14B.transformer.ffn_dim = 13824
29
+ s2v_14B.transformer.freq_dim = 256
30
+ s2v_14B.transformer.num_heads = 40
31
+ s2v_14B.transformer.num_layers = 40
32
+ s2v_14B.transformer.window_size = (-1, -1)
33
+ s2v_14B.transformer.qk_norm = True
34
+ s2v_14B.transformer.cross_attn_norm = True
35
+ s2v_14B.transformer.eps = 1e-6
36
+ s2v_14B.transformer.enable_adain = True
37
+ s2v_14B.transformer.adain_mode = "attn_norm"
38
+ s2v_14B.transformer.audio_inject_layers = [
39
+ 0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39
40
+ ]
41
+ s2v_14B.transformer.zero_init = True
42
+ s2v_14B.transformer.zero_timestep = True
43
+ s2v_14B.transformer.enable_motioner = False
44
+ s2v_14B.transformer.add_last_motion = True
45
+ s2v_14B.transformer.trainable_token = False
46
+ s2v_14B.transformer.enable_tsm = False
47
+ s2v_14B.transformer.enable_framepack = True
48
+ s2v_14B.transformer.framepack_drop_mode = 'padd'
49
+ s2v_14B.transformer.audio_dim = 1024
50
+
51
+ s2v_14B.transformer.motion_frames = 73
52
+ s2v_14B.transformer.cond_dim = 16
53
+
54
+ # inference
55
+ s2v_14B.sample_neg_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
56
+ s2v_14B.drop_first_motion = True
57
+ s2v_14B.sample_shift = 3
58
+ s2v_14B.sample_steps = 40
59
+ s2v_14B.sample_guide_scale = 4.5
wan/configs/wan_t2v_A14B.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan T2V A14B ------------------------#
7
+
8
+ t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
9
+ t2v_A14B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ t2v_A14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ t2v_A14B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ t2v_A14B.patch_size = (1, 2, 2)
21
+ t2v_A14B.dim = 5120
22
+ t2v_A14B.ffn_dim = 13824
23
+ t2v_A14B.freq_dim = 256
24
+ t2v_A14B.num_heads = 40
25
+ t2v_A14B.num_layers = 40
26
+ t2v_A14B.window_size = (-1, -1)
27
+ t2v_A14B.qk_norm = True
28
+ t2v_A14B.cross_attn_norm = True
29
+ t2v_A14B.eps = 1e-6
30
+ t2v_A14B.low_noise_checkpoint = 'low_noise_model'
31
+ t2v_A14B.high_noise_checkpoint = 'high_noise_model'
32
+
33
+ # inference
34
+ t2v_A14B.sample_shift = 12.0
35
+ t2v_A14B.sample_steps = 40
36
+ t2v_A14B.boundary = 0.875
37
+ t2v_A14B.sample_guide_scale = (3.0, 4.0) # low noise, high noise
wan/configs/wan_ti2v_5B.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan TI2V 5B ------------------------#
7
+
8
+ ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
9
+ ti2v_5B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ ti2v_5B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
17
+ ti2v_5B.vae_stride = (4, 16, 16)
18
+
19
+ # transformer
20
+ ti2v_5B.patch_size = (1, 2, 2)
21
+ ti2v_5B.dim = 3072
22
+ ti2v_5B.ffn_dim = 14336
23
+ ti2v_5B.freq_dim = 256
24
+ ti2v_5B.num_heads = 24
25
+ ti2v_5B.num_layers = 30
26
+ ti2v_5B.window_size = (-1, -1)
27
+ ti2v_5B.qk_norm = True
28
+ ti2v_5B.cross_attn_norm = True
29
+ ti2v_5B.eps = 1e-6
30
+
31
+ # inference
32
+ ti2v_5B.sample_fps = 24
33
+ ti2v_5B.sample_shift = 5.0
34
+ ti2v_5B.sample_steps = 50
35
+ ti2v_5B.sample_guide_scale = 5.0
36
+ ti2v_5B.frame_num = 121
wan/distributed/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
wan/distributed/fsdp.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9
+ from torch.distributed.utils import _free_storage
10
+
11
+
12
+ def shard_model(
13
+ model,
14
+ device_id,
15
+ param_dtype=torch.bfloat16,
16
+ reduce_dtype=torch.float32,
17
+ buffer_dtype=torch.float32,
18
+ process_group=None,
19
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
20
+ sync_module_states=True,
21
+ use_lora=False
22
+ ):
23
+ model = FSDP(
24
+ module=model,
25
+ process_group=process_group,
26
+ sharding_strategy=sharding_strategy,
27
+ auto_wrap_policy=partial(
28
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
29
+ mixed_precision=MixedPrecision(
30
+ param_dtype=param_dtype,
31
+ reduce_dtype=reduce_dtype,
32
+ buffer_dtype=buffer_dtype),
33
+ device_id=device_id,
34
+ sync_module_states=sync_module_states,
35
+ use_orig_params=True if use_lora else False)
36
+ return model
37
+
38
+
39
+ def free_model(model):
40
+ for m in model.modules():
41
+ if isinstance(m, FSDP):
42
+ _free_storage(m._handle.flat_param.data)
43
+ del model
44
+ gc.collect()
45
+ torch.cuda.empty_cache()
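Elsewhere in this commit (wan/image2video.py below), shard_model is consumed through functools.partial; a condensed sketch of that pattern follows (not part of the diff). build_dit() is a hypothetical stand-in for however the DiT is constructed, and an NCCL process group is assumed to be initialized already (see wan/distributed/util.py).

from functools import partial

from wan.distributed.fsdp import free_model, shard_model

def wrap_for_fsdp(model, device_id, dit_fsdp=True):
    # mirrors the pipeline code: bind the device once, then shard on demand
    shard_fn = partial(shard_model, device_id=device_id)
    return shard_fn(model) if dit_fsdp else model.to(f"cuda:{device_id}")

# model = wrap_for_fsdp(build_dit(), device_id=0)   # build_dit() is hypothetical
# ... run inference ...
# free_model(model)   # drops the flat FSDP parameters and empties the CUDA cache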
wan/distributed/sequence_parallel.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.cuda.amp as amp
4
+
5
+ from ..modules.model import sinusoidal_embedding_1d
6
+ from .ulysses import distributed_attention
7
+ from .util import gather_forward, get_rank, get_world_size
8
+
9
+
10
+ def pad_freqs(original_tensor, target_len):
11
+ seq_len, s1, s2 = original_tensor.shape
12
+ pad_size = target_len - seq_len
13
+ padding_tensor = torch.ones(
14
+ pad_size,
15
+ s1,
16
+ s2,
17
+ dtype=original_tensor.dtype,
18
+ device=original_tensor.device)
19
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
20
+ return padded_tensor
21
+
22
+
23
+ @torch.amp.autocast('cuda', enabled=False)
24
+ def rope_apply(x, grid_sizes, freqs):
25
+ """
26
+ x: [B, L, N, C].
27
+ grid_sizes: [B, 3].
28
+ freqs: [M, C // 2].
29
+ """
30
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
31
+ # split freqs
32
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
33
+
34
+ # loop over samples
35
+ output = []
36
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
37
+ seq_len = f * h * w
38
+
39
+ # precompute multipliers
40
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
41
+ s, n, -1, 2))
42
+ freqs_i = torch.cat([
43
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
44
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
45
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
46
+ ],
47
+ dim=-1).reshape(seq_len, 1, -1)
48
+
49
+ # apply rotary embedding
50
+ sp_size = get_world_size()
51
+ sp_rank = get_rank()
52
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
53
+ s_per_rank = s
54
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
55
+ s_per_rank), :, :]
56
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
57
+ x_i = torch.cat([x_i, x[i, s:]])
58
+
59
+ # append to collection
60
+ output.append(x_i)
61
+ return torch.stack(output).float()
62
+
63
+
64
+ def sp_dit_forward(
65
+ self,
66
+ x,
67
+ t,
68
+ context,
69
+ seq_len,
70
+ y=None,
71
+ ):
72
+ """
73
+ x: A list of videos each with shape [C, T, H, W].
74
+ t: [B].
75
+ context: A list of text embeddings each with shape [L, C].
76
+ """
77
+ if self.model_type == 'i2v':
78
+ assert y is not None
79
+ # params
80
+ device = self.patch_embedding.weight.device
81
+ if self.freqs.device != device:
82
+ self.freqs = self.freqs.to(device)
83
+
84
+ if y is not None:
85
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
86
+
87
+ # embeddings
88
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
89
+ grid_sizes = torch.stack(
90
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
91
+ x = [u.flatten(2).transpose(1, 2) for u in x]
92
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
93
+ assert seq_lens.max() <= seq_len
94
+ x = torch.cat([
95
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
96
+ for u in x
97
+ ])
98
+
99
+ # time embeddings
100
+ if t.dim() == 1:
101
+ t = t.expand(t.size(0), seq_len)
102
+ with torch.amp.autocast('cuda', dtype=torch.float32):
103
+ bt = t.size(0)
104
+ t = t.flatten()
105
+ e = self.time_embedding(
106
+ sinusoidal_embedding_1d(self.freq_dim,
107
+ t).unflatten(0, (bt, seq_len)).float())
108
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
109
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
110
+
111
+ # context
112
+ context_lens = None
113
+ context = self.text_embedding(
114
+ torch.stack([
115
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
116
+ for u in context
117
+ ]))
118
+
119
+ # Context Parallel
120
+ x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
121
+ e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
122
+ e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]
123
+
124
+ # arguments
125
+ kwargs = dict(
126
+ e=e0,
127
+ seq_lens=seq_lens,
128
+ grid_sizes=grid_sizes,
129
+ freqs=self.freqs,
130
+ context=context,
131
+ context_lens=context_lens)
132
+
133
+ for block in self.blocks:
134
+ x = block(x, **kwargs)
135
+
136
+ # head
137
+ x = self.head(x, e)
138
+
139
+ # Context Parallel
140
+ x = gather_forward(x, dim=1)
141
+
142
+ # unpatchify
143
+ x = self.unpatchify(x, grid_sizes)
144
+ return [u.float() for u in x]
145
+
146
+
147
+ def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):
148
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
149
+ half_dtypes = (torch.float16, torch.bfloat16)
150
+
151
+ def half(x):
152
+ return x if x.dtype in half_dtypes else x.to(dtype)
153
+
154
+ # query, key, value function
155
+ def qkv_fn(x):
156
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
157
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
158
+ v = self.v(x).view(b, s, n, d)
159
+ return q, k, v
160
+
161
+ q, k, v = qkv_fn(x)
162
+ q = rope_apply(q, grid_sizes, freqs)
163
+ k = rope_apply(k, grid_sizes, freqs)
164
+
165
+ x = distributed_attention(
166
+ half(q),
167
+ half(k),
168
+ half(v),
169
+ seq_lens,
170
+ window_size=self.window_size,
171
+ )
172
+
173
+ # output
174
+ x = x.flatten(2)
175
+ x = self.o(x)
176
+ return x
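The context-parallel plumbing in sp_dit_forward reduces to chunking the padded token sequence across ranks before the blocks and gathering it back before unpatchify. The sketch below (not part of the diff) replays that round trip on a single process with made-up sizes; no process group is created, and the even split relies on max_seq_len being padded to a multiple of the sequence-parallel world size, as the pipeline does.

import torch

world_size, seq_len, dim = 4, 128, 16   # illustrative; seq_len % world_size == 0
x = torch.randn(1, seq_len, dim)

# scatter: each rank keeps one contiguous chunk of the sequence dimension
shards = list(torch.chunk(x, world_size, dim=1))
local = shards[1]                        # what rank 1 would process in its blocks

# gather: concatenating the per-rank outputs restores the full sequence
restored = torch.cat(shards, dim=1)
assert torch.equal(restored, x)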
wan/distributed/ulysses.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+ from ..modules.attention import flash_attention
6
+ from .util import all_to_all
7
+
8
+
9
+ def distributed_attention(
10
+ q,
11
+ k,
12
+ v,
13
+ seq_lens,
14
+ window_size=(-1, -1),
15
+ ):
16
+ """
17
+ Performs distributed attention based on DeepSpeed Ulysses attention mechanism.
18
+ please refer to https://arxiv.org/pdf/2309.14509
19
+
20
+ Args:
21
+ q: [B, Lq // p, Nq, C1].
22
+ k: [B, Lk // p, Nk, C1].
23
+ v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
24
+ seq_lens: [B], length of each sequence in batch
25
+ window_size: (left right). If not (-1, -1), apply sliding window local attention.
26
+ """
27
+ if not dist.is_initialized():
28
+ raise ValueError("distributed group should be initialized.")
29
+ b = q.shape[0]
30
+
31
+ # gather q/k/v sequence
32
+ q = all_to_all(q, scatter_dim=2, gather_dim=1)
33
+ k = all_to_all(k, scatter_dim=2, gather_dim=1)
34
+ v = all_to_all(v, scatter_dim=2, gather_dim=1)
35
+
36
+ # apply attention
37
+ x = flash_attention(
38
+ q,
39
+ k,
40
+ v,
41
+ k_lens=seq_lens,
42
+ window_size=window_size,
43
+ )
44
+
45
+ # scatter q/k/v sequence
46
+ x = all_to_all(x, scatter_dim=1, gather_dim=2)
47
+ return x
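Per the docstring above, the two all_to_all calls trade heads for tokens and back; the following shape trace (not part of the diff, values illustrative) spells out what each rank holds at every stage.

# p ranks, B batch, L tokens in total, N heads, C per-head channels
B, L, N, C, p = 1, 32760, 40, 128, 4

local_in  = (B, L // p, N, C)   # q/k/v as they enter distributed_attention
after_a2a = (B, L, N // p, C)   # full sequence, 1/p of the heads per rank
after_att = (B, L, N // p, C)   # flash_attention keeps this layout
local_out = (B, L // p, N, C)   # second all_to_all restores the input layout
print(local_in, after_a2a, after_att, local_out)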
wan/distributed/util.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+
6
+ def init_distributed_group():
7
+ """r initialize sequence parallel group.
8
+ """
9
+ if not dist.is_initialized():
10
+ dist.init_process_group(backend='nccl')
11
+
12
+
13
+ def get_rank():
14
+ return dist.get_rank()
15
+
16
+
17
+ def get_world_size():
18
+ return dist.get_world_size()
19
+
20
+
21
+ def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
22
+ """
23
+ `scatter` along one dimension and `gather` along another.
24
+ """
25
+ world_size = get_world_size()
26
+ if world_size > 1:
27
+ inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
28
+ outputs = [torch.empty_like(u) for u in inputs]
29
+ dist.all_to_all(outputs, inputs, group=group, **kwargs)
30
+ x = torch.cat(outputs, dim=gather_dim).contiguous()
31
+ return x
32
+
33
+
34
+ def all_gather(tensor):
35
+ world_size = dist.get_world_size()
36
+ if world_size == 1:
37
+ return [tensor]
38
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
39
+ torch.distributed.all_gather(tensor_list, tensor)
40
+ return tensor_list
41
+
42
+
43
+ def gather_forward(input, dim):
44
+ # skip if world_size == 1
45
+ world_size = dist.get_world_size()
46
+ if world_size == 1:
47
+ return input
48
+
49
+ # gather sequence
50
+ output = all_gather(input)
51
+ return torch.cat(output, dim=dim).contiguous()
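all_to_all is the only non-obvious primitive here; the single-process model below (not part of the diff) reproduces its semantics locally, splitting every rank's tensor along scatter_dim and concatenating the pieces it would receive along gather_dim, then checks the head/sequence reshuffle used by the Ulysses attention above.

import torch

def simulate_all_to_all(per_rank, scatter_dim, gather_dim):
    """Local stand-in for dist.all_to_all as used above: rank dst ends up with
    the dst-th chunk (along scatter_dim) of every rank, concatenated along
    gather_dim in rank order."""
    world_size = len(per_rank)
    pieces = [t.chunk(world_size, dim=scatter_dim) for t in per_rank]
    return [torch.cat([pieces[src][dst] for src in range(world_size)], dim=gather_dim)
            for dst in range(world_size)]

# four "ranks", each holding L // p tokens of all N heads (illustrative sizes)
p, B, L, N, C = 4, 1, 16, 8, 4
full = torch.randn(B, L, N, C)
per_rank = list(full.chunk(p, dim=1))

gathered = simulate_all_to_all(per_rank, scatter_dim=2, gather_dim=1)
print(gathered[0].shape)                                # (1, 16, 2, 4)
assert torch.equal(gathered[0], full[:, :, : N // p])   # full sequence, head chunk 0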
wan/image2video.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.distributed as dist
16
+ import torchvision.transforms.functional as TF
17
+ from tqdm import tqdm
18
+
19
+ from .distributed.fsdp import shard_model
20
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
21
+ from .distributed.util import get_world_size
22
+ from .modules.model import WanModel
23
+ from .modules.t5 import T5EncoderModel
24
+ from .modules.vae2_1 import Wan2_1_VAE
25
+ from .utils.fm_solvers import (
26
+ FlowDPMSolverMultistepScheduler,
27
+ get_sampling_sigmas,
28
+ retrieve_timesteps,
29
+ )
30
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
31
+
32
+
33
+ class WanI2V:
34
+
35
+ def __init__(
36
+ self,
37
+ config,
38
+ checkpoint_dir,
39
+ device_id=0,
40
+ rank=0,
41
+ t5_fsdp=False,
42
+ dit_fsdp=False,
43
+ use_sp=False,
44
+ t5_cpu=False,
45
+ init_on_cpu=True,
46
+ convert_model_dtype=False,
47
+ ):
48
+ r"""
49
+ Initializes the image-to-video generation model components.
50
+
51
+ Args:
52
+ config (EasyDict):
53
+ Object containing model parameters initialized from config.py
54
+ checkpoint_dir (`str`):
55
+ Path to directory containing model checkpoints
56
+ device_id (`int`, *optional*, defaults to 0):
57
+ Id of target GPU device
58
+ rank (`int`, *optional*, defaults to 0):
59
+ Process rank for distributed training
60
+ t5_fsdp (`bool`, *optional*, defaults to False):
61
+ Enable FSDP sharding for T5 model
62
+ dit_fsdp (`bool`, *optional*, defaults to False):
63
+ Enable FSDP sharding for DiT model
64
+ use_sp (`bool`, *optional*, defaults to False):
65
+ Enable distribution strategy of sequence parallel.
66
+ t5_cpu (`bool`, *optional*, defaults to False):
67
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
68
+ init_on_cpu (`bool`, *optional*, defaults to True):
69
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
70
+ convert_model_dtype (`bool`, *optional*, defaults to False):
71
+ Convert DiT model parameters dtype to 'config.param_dtype'.
72
+ Only works without FSDP.
73
+ """
74
+ self.device = torch.device(f"cuda:{device_id}")
75
+ self.config = config
76
+ self.rank = rank
77
+ self.t5_cpu = t5_cpu
78
+ self.init_on_cpu = init_on_cpu
79
+
80
+ self.num_train_timesteps = config.num_train_timesteps
81
+ self.boundary = config.boundary
82
+ self.param_dtype = config.param_dtype
83
+
84
+ if t5_fsdp or dit_fsdp or use_sp:
85
+ self.init_on_cpu = False
86
+
87
+ shard_fn = partial(shard_model, device_id=device_id)
88
+ self.text_encoder = T5EncoderModel(
89
+ text_len=config.text_len,
90
+ dtype=config.t5_dtype,
91
+ device=torch.device('cpu'),
92
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
93
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
94
+ shard_fn=shard_fn if t5_fsdp else None,
95
+ )
96
+
97
+ self.vae_stride = config.vae_stride
98
+ self.patch_size = config.patch_size
99
+ self.vae = Wan2_1_VAE(
100
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
101
+ device=self.device)
102
+
103
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
104
+ self.low_noise_model = WanModel.from_pretrained(
105
+ checkpoint_dir, subfolder=config.low_noise_checkpoint)
106
+ self.low_noise_model = self._configure_model(
107
+ model=self.low_noise_model,
108
+ use_sp=use_sp,
109
+ dit_fsdp=dit_fsdp,
110
+ shard_fn=shard_fn,
111
+ convert_model_dtype=convert_model_dtype)
112
+
113
+ self.high_noise_model = WanModel.from_pretrained(
114
+ checkpoint_dir, subfolder=config.high_noise_checkpoint)
115
+ self.high_noise_model = self._configure_model(
116
+ model=self.high_noise_model,
117
+ use_sp=use_sp,
118
+ dit_fsdp=dit_fsdp,
119
+ shard_fn=shard_fn,
120
+ convert_model_dtype=convert_model_dtype)
121
+ if use_sp:
122
+ self.sp_size = get_world_size()
123
+ else:
124
+ self.sp_size = 1
125
+
126
+ self.sample_neg_prompt = config.sample_neg_prompt
127
+
128
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
129
+ convert_model_dtype):
130
+ """
131
+ Configures a model object. This includes setting evaluation modes,
132
+ applying distributed parallel strategy, and handling device placement.
133
+
134
+ Args:
135
+ model (torch.nn.Module):
136
+ The model instance to configure.
137
+ use_sp (`bool`):
138
+ Enable distribution strategy of sequence parallel.
139
+ dit_fsdp (`bool`):
140
+ Enable FSDP sharding for DiT model.
141
+ shard_fn (callable):
142
+ The function to apply FSDP sharding.
143
+ convert_model_dtype (`bool`):
144
+ Convert DiT model parameters dtype to 'config.param_dtype'.
145
+ Only works without FSDP.
146
+
147
+ Returns:
148
+ torch.nn.Module:
149
+ The configured model.
150
+ """
151
+ model.eval().requires_grad_(False)
152
+
153
+ if use_sp:
154
+ for block in model.blocks:
155
+ block.self_attn.forward = types.MethodType(
156
+ sp_attn_forward, block.self_attn)
157
+ model.forward = types.MethodType(sp_dit_forward, model)
158
+
159
+ if dist.is_initialized():
160
+ dist.barrier()
161
+
162
+ if dit_fsdp:
163
+ model = shard_fn(model)
164
+ else:
165
+ if convert_model_dtype:
166
+ model.to(self.param_dtype)
167
+ if not self.init_on_cpu:
168
+ model.to(self.device)
169
+
170
+ return model
171
+
172
+ def _prepare_model_for_timestep(self, t, boundary, offload_model):
173
+ r"""
174
+ Prepares and returns the required model for the current timestep.
175
+
176
+ Args:
177
+ t (torch.Tensor):
178
+ current timestep.
179
+ boundary (`int`):
180
+ The timestep threshold. If `t` is at or above this value,
181
+ the `high_noise_model` is considered as the required model.
182
+ offload_model (`bool`):
183
+ A flag intended to control the offloading behavior.
184
+
185
+ Returns:
186
+ torch.nn.Module:
187
+ The active model on the target device for the current timestep.
188
+ """
189
+ if t.item() >= boundary:
190
+ required_model_name = 'high_noise_model'
191
+ offload_model_name = 'low_noise_model'
192
+ else:
193
+ required_model_name = 'low_noise_model'
194
+ offload_model_name = 'high_noise_model'
195
+ if offload_model or self.init_on_cpu:
196
+ if next(getattr(
197
+ self,
198
+ offload_model_name).parameters()).device.type == 'cuda':
199
+ getattr(self, offload_model_name).to('cpu')
200
+ if next(getattr(
201
+ self,
202
+ required_model_name).parameters()).device.type == 'cpu':
203
+ getattr(self, required_model_name).to(self.device)
204
+ return getattr(self, required_model_name)
205
+
206
+ def generate(self,
207
+ input_prompt,
208
+ img,
209
+ max_area=720 * 1280,
210
+ frame_num=81,
211
+ shift=5.0,
212
+ sample_solver='unipc',
213
+ sampling_steps=40,
214
+ guide_scale=5.0,
215
+ n_prompt="",
216
+ seed=-1,
217
+ offload_model=True):
218
+ r"""
219
+ Generates video frames from input image and text prompt using diffusion process.
220
+
221
+ Args:
222
+ input_prompt (`str`):
223
+ Text prompt for content generation.
224
+ img (PIL.Image.Image):
225
+ Input image; converted internally to a tensor of shape [3, H, W].
226
+ max_area (`int`, *optional*, defaults to 720*1280):
227
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
228
+ frame_num (`int`, *optional*, defaults to 81):
229
+ How many frames to sample from a video. The number should be 4n+1
230
+ shift (`float`, *optional*, defaults to 5.0):
231
+ Noise schedule shift parameter. Affects temporal dynamics
232
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
233
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
234
+ Solver used to sample the video.
235
+ sampling_steps (`int`, *optional*, defaults to 40):
236
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
237
+ guide_scale (`float` or tuple[`float`], *optional*, defaults to 5.0):
238
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
239
+ If tuple, the first guide_scale will be used for low noise model and
240
+ the second guide_scale will be used for high noise model.
241
+ n_prompt (`str`, *optional*, defaults to ""):
242
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
243
+ seed (`int`, *optional*, defaults to -1):
244
+ Random seed for noise generation. If -1, use random seed
245
+ offload_model (`bool`, *optional*, defaults to True):
246
+ If True, offloads models to CPU during generation to save VRAM
247
+
248
+ Returns:
249
+ torch.Tensor:
250
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
251
+ - C: Color channels (3 for RGB)
252
+ - N: Number of frames (81)
253
+ - H: Frame height (from max_area)
254
+ - W: Frame width (from max_area)
255
+ """
256
+ # preprocess
257
+ guide_scale = (guide_scale, guide_scale) if isinstance(
258
+ guide_scale, float) else guide_scale
259
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
260
+
261
+ F = frame_num
262
+ h, w = img.shape[1:]
263
+ aspect_ratio = h / w
264
+ lat_h = round(
265
+ np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
266
+ self.patch_size[1] * self.patch_size[1])
267
+ lat_w = round(
268
+ np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
269
+ self.patch_size[2] * self.patch_size[2])
270
+ h = lat_h * self.vae_stride[1]
271
+ w = lat_w * self.vae_stride[2]
272
+
273
+ max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
274
+ self.patch_size[1] * self.patch_size[2])
275
+ max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
276
+
277
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
278
+ seed_g = torch.Generator(device=self.device)
279
+ seed_g.manual_seed(seed)
280
+ noise = torch.randn(
281
+ 16,
282
+ (F - 1) // self.vae_stride[0] + 1,
283
+ lat_h,
284
+ lat_w,
285
+ dtype=torch.float32,
286
+ generator=seed_g,
287
+ device=self.device)
288
+
289
+ msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
290
+ msk[:, 1:] = 0
291
+ msk = torch.concat([
292
+ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
293
+ ],
294
+ dim=1)
295
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
296
+ msk = msk.transpose(1, 2)[0]
297
+
298
+ if n_prompt == "":
299
+ n_prompt = self.sample_neg_prompt
300
+
301
+ # preprocess
302
+ if not self.t5_cpu:
303
+ self.text_encoder.model.to(self.device)
304
+ context = self.text_encoder([input_prompt], self.device)
305
+ context_null = self.text_encoder([n_prompt], self.device)
306
+ if offload_model:
307
+ self.text_encoder.model.cpu()
308
+ else:
309
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
310
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
311
+ context = [t.to(self.device) for t in context]
312
+ context_null = [t.to(self.device) for t in context_null]
313
+
314
+ y = self.vae.encode([
315
+ torch.concat([
316
+ torch.nn.functional.interpolate(
317
+ img[None].cpu(), size=(h, w), mode='bicubic').transpose(
318
+ 0, 1),
319
+ torch.zeros(3, F - 1, h, w)
320
+ ],
321
+ dim=1).to(self.device)
322
+ ])[0]
323
+ y = torch.concat([msk, y])
324
+
325
+ @contextmanager
326
+ def noop_no_sync():
327
+ yield
328
+
329
+ no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
330
+ noop_no_sync)
331
+ no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
332
+ noop_no_sync)
333
+
334
+ # evaluation mode
335
+ with (
336
+ torch.amp.autocast('cuda', dtype=self.param_dtype),
337
+ torch.no_grad(),
338
+ no_sync_low_noise(),
339
+ no_sync_high_noise(),
340
+ ):
341
+ boundary = self.boundary * self.num_train_timesteps
342
+
343
+ if sample_solver == 'unipc':
344
+ sample_scheduler = FlowUniPCMultistepScheduler(
345
+ num_train_timesteps=self.num_train_timesteps,
346
+ shift=1,
347
+ use_dynamic_shifting=False)
348
+ sample_scheduler.set_timesteps(
349
+ sampling_steps, device=self.device, shift=shift)
350
+ timesteps = sample_scheduler.timesteps
351
+ elif sample_solver == 'dpm++':
352
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
353
+ num_train_timesteps=self.num_train_timesteps,
354
+ shift=1,
355
+ use_dynamic_shifting=False)
356
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
357
+ timesteps, _ = retrieve_timesteps(
358
+ sample_scheduler,
359
+ device=self.device,
360
+ sigmas=sampling_sigmas)
361
+ else:
362
+ raise NotImplementedError("Unsupported solver.")
363
+
364
+ # sample videos
365
+ latent = noise
366
+
367
+ arg_c = {
368
+ 'context': [context[0]],
369
+ 'seq_len': max_seq_len,
370
+ 'y': [y],
371
+ }
372
+
373
+ arg_null = {
374
+ 'context': context_null,
375
+ 'seq_len': max_seq_len,
376
+ 'y': [y],
377
+ }
378
+
379
+ if offload_model:
380
+ torch.cuda.empty_cache()
381
+
382
+ for _, t in enumerate(tqdm(timesteps)):
383
+ latent_model_input = [latent.to(self.device)]
384
+ timestep = [t]
385
+
386
+ timestep = torch.stack(timestep).to(self.device)
387
+
388
+ model = self._prepare_model_for_timestep(
389
+ t, boundary, offload_model)
390
+ sample_guide_scale = guide_scale[1] if t.item(
391
+ ) >= boundary else guide_scale[0]
392
+
393
+ noise_pred_cond = model(
394
+ latent_model_input, t=timestep, **arg_c)[0]
395
+ if offload_model:
396
+ torch.cuda.empty_cache()
397
+ noise_pred_uncond = model(
398
+ latent_model_input, t=timestep, **arg_null)[0]
399
+ if offload_model:
400
+ torch.cuda.empty_cache()
401
+ noise_pred = noise_pred_uncond + sample_guide_scale * (
402
+ noise_pred_cond - noise_pred_uncond)
403
+
404
+ temp_x0 = sample_scheduler.step(
405
+ noise_pred.unsqueeze(0),
406
+ t,
407
+ latent.unsqueeze(0),
408
+ return_dict=False,
409
+ generator=seed_g)[0]
410
+ latent = temp_x0.squeeze(0)
411
+
412
+ x0 = [latent]
413
+ del latent_model_input, timestep
414
+
415
+ if offload_model:
416
+ self.low_noise_model.cpu()
417
+ self.high_noise_model.cpu()
418
+ torch.cuda.empty_cache()
419
+
420
+ if self.rank == 0:
421
+ videos = self.vae.decode(x0)
422
+
423
+ del noise, latent, x0
424
+ del sample_scheduler
425
+ if offload_model:
426
+ gc.collect()
427
+ torch.cuda.synchronize()
428
+ if dist.is_initialized():
429
+ dist.barrier()
430
+
431
+ return videos[0] if self.rank == 0 else None
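To tie the pieces together, a minimal single-GPU usage sketch (not part of the diff). The checkpoint directory and image path are placeholders, and the directory is assumed to contain the T5/VAE files plus the low/high-noise DiT subfolders named in wan/configs/wan_i2v_A14B.py.

from PIL import Image

from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.image2video import WanI2V

cfg = WAN_CONFIGS['i2v-A14B']
pipe = WanI2V(cfg, checkpoint_dir='./Wan2.2-I2V-A14B',   # placeholder path
              device_id=0, convert_model_dtype=True)

video = pipe.generate(
    'a cat surfing a wave at sunset',
    img=Image.open('input.jpg'),                 # placeholder image
    max_area=MAX_AREA_CONFIGS['832*480'],
    frame_num=cfg.frame_num,
    shift=3.0,                                   # docstring recommends 3.0 for 480p
    sampling_steps=cfg.sample_steps,
    guide_scale=cfg.sample_guide_scale,
    offload_model=True,
)   # float tensor of shape (C, N, H, W) on rank 0, otherwise None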
wan/modules/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .attention import flash_attention
3
+ from .model import WanModel
4
+ from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
5
+ from .tokenizers import HuggingfaceTokenizer
6
+ from .vae2_1 import Wan2_1_VAE
7
+ from .vae2_2 import Wan2_2_VAE
8
+
9
+ __all__ = [
10
+ 'Wan2_1_VAE',
11
+ 'Wan2_2_VAE',
12
+ 'WanModel',
13
+ 'T5Model',
14
+ 'T5Encoder',
15
+ 'T5Decoder',
16
+ 'T5EncoderModel',
17
+ 'HuggingfaceTokenizer',
18
+ 'flash_attention',
19
+ ]
wan/modules/animate/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .model_animate import WanAnimateModel
3
+ from .clip import CLIPModel
4
+ __all__ = ['WanAnimateModel', 'CLIPModel']