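# Gradio Space for Wan2.2-Animate-14B: animate a reference image with the motion of an
# input video, or replace the character in the video with the reference image.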
import spaces
from huggingface_hub import snapshot_download, hf_hub_download
import os
import subprocess
import importlib, site
from PIL import Image
import uuid
import shutil
import time
import cv2
import json
import gradio as gr
import sys
import gc
BASE = os.path.dirname(os.path.abspath(__file__))
PREPROCESS_DIR = os.path.join(BASE, "wan", "modules", "animate", "preprocess")
sys.path.append(PREPROCESS_DIR)
# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
site.addsitedir(sitedir)
# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()
def sh(cmd): subprocess.check_call(cmd, shell=True)
try:
sh("pip install flash-attn --no-build-isolation")
# print("Attempting to download and build sam2...")
# print("download sam")
# sam_dir = snapshot_download(repo_id="alexnasa/sam2")
# @spaces.GPU(duration=500)
# def install_sam():
# os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
# sh(f"cd {sam_dir} && python setup.py build_ext --inplace && pip install -e .")
# print("install sam")
# install_sam()
print("Attempting to download")
print("download sam")
snapshot_download(repo_id="alexnasa/sam2_C", local_dir=f"{os.getcwd()}" )
# tell Python to re-scan site-packages now that the egg-link exists
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
print("sam2 installed successfully.")
except Exception as e:
    raise gr.Error(f"sam2 installation failed: {e}")
import torch
from generate import generate, load_model
from preprocess_data import run as run_preprocess
from preprocess_data import load_preprocess_models
print(f"Torch version: {torch.__version__}")
os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results"
snapshot_download(repo_id="Wan-AI/Wan2.2-Animate-14B", local_dir="./Wan2.2-Animate-14B")
wan_animate = load_model(True)
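# Radio-button labels -> replace-character flag:
#   "Video β†’ Ref Image": animate the reference image with the video's motion (retarget_flag)
#   "Video ← Ref Image": replace the character in the video with the reference image (replace_flag)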
rc_mapping = {
"Video β†’ Ref Image" : False,
"Video ← Ref Image" : True
}
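# Trim the uploaded video to the requested duration (and normalise it to 30 fps) inside this
# session's workspace before any heavy processing.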
def preprocess_video(input_video_path, duration, session_id=None):
if session_id is None:
session_id = uuid.uuid4().hex
output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
os.makedirs(output_dir, exist_ok=True)
process_video_path = os.path.join(output_dir, 'input_video.mp4')
clip_and_set_fps(input_video_path, process_video_path, duration_s=duration)
return process_video_path
def extract_audio_from_video_ffmpeg(video_path, output_wav_path, sample_rate=None):
"""
Extracts the audio track from a video file and saves it as a WAV file.
Args:
video_path (str): Path to the input video file.
output_wav_path (str): Path to save the extracted WAV file.
sample_rate (int, optional): Output sample rate (e.g., 16000).
If None, keep the original.
"""
cmd = [
'ffmpeg',
'-i', video_path, # Input video
'-vn', # Disable video
'-acodec', 'pcm_s16le', # 16-bit PCM (WAV format)
'-ac', '1', # Mono channel (use '2' for stereo)
'-y', # Overwrite output
'-loglevel', 'error' # Cleaner output
]
# Only add the sample rate option if explicitly specified
if sample_rate is not None:
cmd.extend(['-ar', str(sample_rate)])
cmd.append(output_wav_path)
try:
subprocess.run(cmd, check=True, capture_output=True, text=True)
return True
except subprocess.CalledProcessError as e:
return False
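# Example (hypothetical paths): pull a 16 kHz mono WAV out of a clip for later remuxing:
#   extract_audio_from_video_ffmpeg("session/input_video.mp4", "session/input_audio.wav", sample_rate=16000)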
def combine_video_and_audio_ffmpeg(video_path, audio_path, output_video_path):
"""
Combines a silent MP4 video with a WAV audio file into a single MP4 with sound.
Args:
video_path (str): Path to the silent video file.
audio_path (str): Path to the WAV audio file.
output_video_path (str): Path to save the output MP4 with audio.
"""
cmd = [
'ffmpeg',
'-i', video_path, # Input video
'-i', audio_path, # Input audio
'-c:v', 'copy', # Copy video without re-encoding
'-c:a', 'aac', # Encode audio as AAC (MP4-compatible)
'-shortest', # Stop when the shortest stream ends
'-y', # Overwrite output
'-loglevel', 'error',
output_video_path
]
try:
subprocess.run(cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
def clip_and_set_fps(input_video_path, output_video_path, duration_s=2, target_fps=30):
"""
Trim to duration_s and (optionally) change FPS, without resizing.
- If target_fps is None, keeps the original FPS.
- Re-encodes video when changing FPS for predictable timing.
"""
vf = []
if target_fps is not None:
vf.append(f"fps={target_fps}")
vf_arg = ",".join(vf) if vf else None
cmd = [
"ffmpeg",
"-nostdin",
"-hide_banner",
"-y",
"-i", input_video_path,
"-t", str(duration_s),
]
if vf_arg:
cmd += ["-vf", vf_arg]
cmd += [
"-c:v", "libx264",
"-pix_fmt", "yuv420p",
"-preset", "veryfast",
"-crf", "18",
"-c:a", "aac", # use aac so MP4 stays compatible
"-movflags", "+faststart",
output_video_path,
]
try:
subprocess.run(cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
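# Output resolution is fixed at 832x480 (landscape) or 480x832 (portrait); the orientation is
# inferred from the source video's aspect ratio.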
def is_portrait(video_file):
# Get video information
cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        gr.Warning("Cannot open video file; assuming landscape orientation")
        cap.release()
        return False
    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
return orig_width < orig_height
def calculate_time_required(max_duration_s, rc_bool):
    # Approximate GPU seconds needed per clip length (rc_bool is currently unused).
    gpu_budgets = {2: 120, 4: 180, 6: 260, 8: 330, 10: 340}
    return gpu_budgets.get(max_duration_s, 340)
def get_display_time_required(max_duration_s, rc_bool):
    # The extra 30 seconds is a safety margin in case of an unexpected slowdown.
return calculate_time_required(max_duration_s, rc_bool) - 30
def update_time_required(max_duration_s, rc_str):
rc_bool = rc_mapping[rc_str]
duration_s = get_display_time_required(max_duration_s, rc_bool)
duration_m = duration_s / 60
return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")
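# ZeroGPU calls get_duration (same signature as _animate) to size the GPU reservation
# before dispatching _animate to a GPU worker.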
def get_duration(input_video, max_duration_s, edited_frame, rc_bool, session_id, progress):
return calculate_time_required(max_duration_s, rc_bool)
@spaces.GPU(duration=get_duration)
def _animate(input_video, max_duration_s, edited_frame, rc_bool, session_id = None, progress=gr.Progress(track_tqdm=True),):
if session_id is None:
session_id = uuid.uuid4().hex
output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
os.makedirs(output_dir, exist_ok=True)
preprocess_dir = os.path.join(output_dir, "preprocess_dir")
os.makedirs(preprocess_dir, exist_ok=True)
output_video_path = os.path.join(output_dir, 'result.mp4')
# --- Measure preprocess time ---
start_preprocess = time.time()
if is_portrait(input_video):
w = 480
h = 832
else:
w = 832
h = 480
tag_string = "retarget_flag"
if rc_bool:
tag_string = "replace_flag"
preprocess_model = load_preprocess_models()
run_preprocess(preprocess_model, input_video, edited_frame, preprocess_dir, w, h, tag_string)
preprocess_time = time.time() - start_preprocess
print(f"Preprocess took {preprocess_time:.2f} seconds")
# --- Measure generate time ---
start_generate = time.time()
generate(wan_animate, preprocess_dir, output_video_path, rc_bool)
generate_time = time.time() - start_generate
print(f"Generate took {generate_time:.2f} seconds")
# --- Optional total time ---
total_time = preprocess_time + generate_time
print(f"Total time: {total_time:.2f} seconds")
gc.collect()
torch.cuda.empty_cache()
return output_video_path
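# End-to-end driver for the UI: validate inputs, trim the clip, extract its audio, run
# preprocessing + generation on the GPU, then remux the original audio onto the result.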
def animate_scene(input_video, max_duration_s, edited_frame, rc_str, use_ai_image, ai_prompt, session_id = None, progress=gr.Progress(track_tqdm=True),):
if not input_video:
raise gr.Error("Please provide an video")
if not use_ai_image and not edited_frame:
raise gr.Error("Please provide an image or enable AI generation")
if use_ai_image and not ai_prompt:
raise gr.Error("Please provide a prompt for AI image generation")
if session_id is None:
session_id = uuid.uuid4().hex
input_video = preprocess_video(input_video, max_duration_s, session_id)
rc_bool = rc_mapping[rc_str]
output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
os.makedirs(output_dir, exist_ok=True)
input_audio_path = os.path.join(output_dir, 'input_audio.wav')
audio_extracted = extract_audio_from_video_ffmpeg(input_video, input_audio_path)
edited_frame_png = os.path.join(output_dir, 'edited_frame.png')
if use_ai_image:
# Generate image using AI model
generated_image = generate_ai_image(ai_prompt, session_id)
edited_frame_img = generated_image
else:
edited_frame_img = Image.open(edited_frame)
edited_frame_img.save(edited_frame_png)
print(f'{session_id} inference started')
output_video_path = _animate(input_video, max_duration_s, edited_frame_png, rc_bool, session_id, progress)
final_video_path = os.path.join(output_dir, 'final_result.mp4')
preprocess_dir = os.path.join(output_dir, "preprocess_dir")
pose_video = os.path.join(preprocess_dir, 'src_pose.mp4')
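    # Replacement mode produces separate mask/background/face videos; retarget mode only has
    # the pose video, so it is reused as a placeholder for those preview slots.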
if rc_bool:
mask_video = os.path.join(preprocess_dir, 'src_mask.mp4')
bg_video = os.path.join(preprocess_dir, 'src_bg.mp4')
face_video = os.path.join(preprocess_dir, 'src_face.mp4')
else:
mask_video = os.path.join(preprocess_dir, 'src_pose.mp4')
bg_video = os.path.join(preprocess_dir, 'src_pose.mp4')
face_video = os.path.join(preprocess_dir, 'src_pose.mp4')
if audio_extracted:
combine_video_and_audio_ffmpeg(output_video_path, input_audio_path, final_video_path)
else:
final_video_path = output_video_path
print(f"task for {session_id} finalised")
return final_video_path, pose_video, bg_video, mask_video, face_video
css = """
#col-container {
margin: 0 auto;
max-width: 1600px;
}
#step-column {
padding: 10px;
border-radius: 8px;
box-shadow: var(--card-shadow);
margin: 10px;
}
#col-showcase {
margin: 0 auto;
max-width: 1100px;
}
.button-gradient {
background: linear-gradient(45deg, rgb(255, 65, 108), rgb(255, 75, 43), rgb(255, 155, 0), rgb(255, 65, 108)) 0% 0% / 400% 400%;
border: none;
padding: 14px 28px;
font-size: 16px;
font-weight: bold;
color: white;
border-radius: 10px;
cursor: pointer;
transition: 0.3s ease-in-out;
animation: 2s linear 0s infinite normal none running gradientAnimation;
box-shadow: rgba(255, 65, 108, 0.6) 0px 4px 10px;
}
.toggle-container {
display: inline-flex;
background-color: #ffd6ff; /* light pink background */
border-radius: 9999px;
padding: 4px;
position: relative;
width: fit-content;
font-family: sans-serif;
}
.toggle-container input[type="radio"] {
display: none;
}
.toggle-container label {
position: relative;
z-index: 2;
flex: 1;
text-align: center;
font-weight: 700;
color: #4b2ab5; /* dark purple text for unselected */
padding: 6px 22px;
border-radius: 9999px;
cursor: pointer;
transition: color 0.25s ease;
}
/* Moving highlight */
.toggle-highlight {
position: absolute;
top: 4px;
left: 4px;
width: calc(50% - 4px);
height: calc(100% - 8px);
background-color: #4b2ab5; /* dark purple background */
border-radius: 9999px;
transition: transform 0.25s ease;
z-index: 1;
}
/* When "True" is checked */
#true:checked ~ label[for="true"] {
color: #ffd6ff; /* light pink text */
}
/* When "False" is checked */
#false:checked ~ label[for="false"] {
color: #ffd6ff; /* light pink text */
}
/* Move highlight to right side when False is checked */
#false:checked ~ .toggle-highlight {
transform: translateX(100%);
}
"""
def log_change(log_source, session_id, meta_data = None):
if not meta_data:
print(f'{session_id} changed {log_source}')
else:
print(f'{session_id} changed {log_source} with {meta_data}')
def generate_ai_image(prompt, session_id):
"""
Generate an image using an AI model based on the prompt.
This is a placeholder - implement with your preferred image generation model.
"""
# TODO: Implement actual AI image generation
# Example using a hypothetical image generation model:
# from diffusers import StableDiffusionPipeline
# pipe = StableDiffusionPipeline.from_pretrained("model_name")
# image = pipe(prompt).images[0]
# For now, return a placeholder
raise gr.Error("AI image generation not yet implemented. Please upload an image instead.")
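# Each browser session gets a workspace keyed by Gradio's session hash; cleanup() deletes it
# when the user disconnects (registered via demo.unload below).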
def start_session(request: gr.Request):
return request.session_hash
def cleanup(request: gr.Request):
sid = request.session_hash
if sid:
print(f"{sid} left")
d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
shutil.rmtree(d1, ignore_errors=True)
with gr.Blocks(css=css, title="Wan 2.2 Animate --replace", theme=gr.themes.Ocean()) as demo:
session_state = gr.State()
demo.load(start_session, outputs=[session_state])
with gr.Column(elem_id="col-container"):
with gr.Row():
gr.HTML(
"""
<div style="text-align: center;">
<p style="font-size:16px; display: inline; margin: 0;">
<strong>Wan2.2-Animate-14B </strong>
</p>
<a href="https://huggingface.co/Wan-AI/Wan2.2-Animate-14B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
[Model]
</a>
<p style="font-size:16px; display: inline; margin: 0;">
-- HF Space By:
</p>
<a href="https://huggingface.co/alexnasa" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
<img src="https://img.shields.io/badge/πŸ€—-Follow Me-yellow.svg">
</a>
</div>
"""
)
with gr.Row():
with gr.Column(elem_id="step-column"):
gr.HTML("""
<div>
<span style="font-size: 24px;">1. Upload a Video</span><br>
</div>
""")
input_video = gr.Video(label="Input Video", height=512)
max_duration_slider = gr.Slider(2, 10, 2, step=2, label="Max Duration", visible=False)
gr.Examples(
examples=[
[
"./examples/martialart.mp4",
],
[
"./examples/test_example.mp4",
],
],
inputs=[input_video],
cache_examples=False,
)
with gr.Column(elem_id="step-column"):
gr.HTML("""
<div>
<span style="font-size: 24px;">2. Upload or Generate Ref Image</span><br>
</div>
""")
use_ai_image = gr.Checkbox(label="Generate Image with AI", value=False)
with gr.Group() as upload_group:
edited_frame = gr.Image(label="Ref Image", type="filepath", height=512)
with gr.Group(visible=False) as ai_group:
ai_prompt = gr.Textbox(label="AI Image Prompt", placeholder="Describe the image you want to generate...")
generate_btn = gr.Button("Generate Image", variant="secondary")
ai_generated_preview = gr.Image(label="Generated Preview", type="pil", height=512)
default_replace_string = "Video ← Ref Image"
replace_character_string = gr.Radio(
["Video β†’ Ref Image", "Video ← Ref Image"], value=default_replace_string, show_label=False
)
def toggle_image_input(use_ai):
return gr.update(visible=not use_ai), gr.update(visible=use_ai)
use_ai_image.change(
toggle_image_input,
inputs=[use_ai_image],
outputs=[upload_group, ai_group]
)
gr.Examples(
examples=[
[
"./examples/ali.png",
],
[
"./examples/amber.png",
],
[
"./examples/ella.png",
],
[
"./examples/sydney.png",
],
],
inputs=[edited_frame],
cache_examples=False,
)
with gr.Column(elem_id="step-column"):
gr.HTML("""
<div>
<span style="font-size: 24px;">3. Wan Animate it!</span><br>
</div>
""")
output_video = gr.Video(label="Edited Video", height=512)
                    duration_s = get_display_time_required(2, rc_mapping[default_replace_string])
                    duration_m = duration_s / 60
                    time_required_str = f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)"
                    time_required = gr.Text(value=time_required_str, show_label=False, visible=False)
action_button = gr.Button("Wan Animate πŸ¦†", variant='primary', elem_classes="button-gradient")
with gr.Accordion("Preprocessed Data", open=False, visible=True):
with gr.Row():
pose_video = gr.Video(label="Pose Video")
bg_video = gr.Video(label="Background Video")
face_video = gr.Video(label="Face Video")
mask_video = gr.Video(label="Mask Video")
with gr.Row():
with gr.Column(elem_id="col-showcase"):
gr.Examples(
examples=[
[
"./examples/okay.mp4",
2,
"./examples/amber.png",
"Video ← Ref Image",
False,
""
],
[
"./examples/superman.mp4",
2,
"./examples/superman.png",
"Video ← Ref Image",
False,
""
],
[
"./examples/test_example.mp4",
2,
"./examples/ella.png",
"Video ← Ref Image",
False,
""
],
[
"./examples/paul.mp4",
2,
"./examples/man.png",
"Video β†’ Ref Image",
False,
""
],
[
"./examples/desi.mp4",
2,
"./examples/desi.png",
"Video ← Ref Image",
False,
""
],
],
inputs=[input_video, max_duration_slider, edited_frame, replace_character_string, use_ai_image, ai_prompt],
outputs=[output_video, pose_video, bg_video, mask_video, face_video],
fn=animate_scene,
cache_examples=True,
)
action_button.click(fn=animate_scene, inputs=[input_video, max_duration_slider, edited_frame, replace_character_string, use_ai_image, ai_prompt, session_state], outputs=[output_video, pose_video, bg_video, mask_video, face_video])
replace_character_string.change(update_time_required, inputs=[max_duration_slider, replace_character_string], outputs=[time_required])
max_duration_slider.change(log_change, inputs=[gr.State("slider"), session_state, max_duration_slider]).then(update_time_required, inputs=[max_duration_slider, replace_character_string], outputs=[time_required])
input_video.change(log_change, inputs=[gr.State("video"), session_state])
edited_frame.change(log_change, inputs=[gr.State("ref image"), session_state])
if __name__ == "__main__":
demo.queue()
demo.unload(cleanup)
demo.launch(ssr_mode=False, share=True)