import os
import random
import logging
import requests
import numpy as np
import torch
import spaces
from fastapi import FastAPI, HTTPException
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
from PIL import Image
import gradio as gr
import tempfile
import gc
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
import aoti
import re
import spacy
from datetime import date
logging.basicConfig(
level=logging.INFO,
filename="wan_i2v.log",
filemode="a",
format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# -------------------------------------------------
# DAILY QUOTA SETTINGS
# -------------------------------------------------
DAILY_LIMIT = 20
USAGE = {"count": 0, "date": date.today()}
PLACEHOLDER_IMG = Image.new("RGB", (512, 512), color=(0, 0, 0))
# -------------------------------------------------
# MODEL CONFIGURATION
# -------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
HF_TOKEN = os.environ.get("HF_TOKEN")
MAX_DIM = 832
MIN_DIM = 480
SQUARE_DIM = 640
MULTIPLE_OF = 16
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 7720
# -------------------------------------------------
# PIPELINE BUILD
# -------------------------------------------------
print("Loading pipeline components...")
transformer = WanTransformer3DModel.from_pretrained(
MODEL_ID,
subfolder="transformer",
torch_dtype=torch.bfloat16,
token=HF_TOKEN,
)
transformer_2 = WanTransformer3DModel.from_pretrained(
MODEL_ID,
subfolder="transformer_2",
torch_dtype=torch.bfloat16,
token=HF_TOKEN,
)
print("Assembling pipeline...")
pipe = WanImageToVideoPipeline.from_pretrained(
MODEL_ID,
transformer=transformer,
transformer_2=transformer_2,
torch_dtype=torch.bfloat16,
token=HF_TOKEN,
)
pipe = pipe.to("cuda")
# -------------------------------------------------
# LoRA ADAPTERS
# -------------------------------------------------
print("Loading LoRA adapters...")
try:
pipe.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
adapter_name="lightx2v",
)
pipe.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
adapter_name="lightx2v_2",
load_into_transformer_2=True,
)
pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
pipe.unload_lora_weights()
print("LoRA loaded and fused successfully.")
except Exception as e:
print(f"Warning: Failed to load LoRA. Continuing without it. Error: {e}")
# -------------------------------------------------
# QUANTISATION & AOTI
# -------------------------------------------------
print("Applying quantisation...")
torch.cuda.empty_cache()
gc.collect()
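# Float8 dynamic-activation quantisation generally requires an H100/Ada-class
# GPU (compute capability >= 8.9); on older hardware the except branch below
# leaves the pipeline in plain bf16.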
try:
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
print("Loading AOTI blocks...")
aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")
except Exception as e:
print(f"Warning: Quantisation/AOTI failed – will run in standard mode. Error: {e}")
# -------------------------------------------------
# PROMPTS
# -------------------------------------------------
QUALITY_PROMPT = ", high quality, detailed, vibrant, professional lighting, smooth motion, cinematic"
default_negative_prompt = (
"low quality, worst quality, motion artifacts, unstable motion, jitter, frame jitter, wobbling limbs, "
"motion distortion, inconsistent movement, robotic movement, animation‑like motion, awkward transitions, "
"incorrect body mechanics, unnatural posing, off‑balance poses, broken motion paths, frozen frames, "
"duplicated frames, frame skipping, warped motion, stretching artifacts, bad anatomy, incorrect proportions, "
"deformed body, twisted torso, broken joints, dislocated limbs, distorted neck, unnatural spine curvature, "
"malformed hands, extra fingers, missing fingers, fused fingers, distorted legs, extra limbs, collapsed feet, "
"floating feet, foot sliding, foot jitter, backward walking, unnatural gait, blurry details, long exposure blur, "
"ghosting, shadow trails, smearing, washed‑out colors, overexposure, underexposure, excessive contrast, "
"blown highlights, poorly rendered clothing, fabric glitches, texture warping, clothing merging with body, "
"incorrect cloth physics, ugly background, cluttered scene, crowded background, random objects, unwanted text, "
"subtitles, logos, graffiti, grain, noise, static artifacts, compression noise, jpeg artifacts, image‑like "
"stillness, painting‑like look, cartoon texture, low‑resolution textures"
)
# -------------------------------------------------
# IMAGE RESIZING
# -------------------------------------------------
def resize_image(image: Image.Image) -> Image.Image:
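    """Centre-crop (if the aspect ratio is extreme) and resize so that both
    sides fall within [MIN_DIM, MAX_DIM] and are multiples of MULTIPLE_OF;
    square inputs map directly to SQUARE_DIM x SQUARE_DIM."""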
w, h = image.size
if w == h:
return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
aspect = w / h
max_ar = MAX_DIM / MIN_DIM
min_ar = MIN_DIM / MAX_DIM
img = image
if aspect > max_ar:
cw = int(round(h * max_ar))
left = (w - cw) // 2
img = image.crop((left, 0, left + cw, h))
elif aspect < min_ar:
ch = int(round(w / min_ar))
top = (h - ch) // 2
img = image.crop((0, top, w, top + ch))
    # Recompute size after any crop so the scaling uses the true aspect ratio.
    w, h = img.size
    aspect = w / h
    if w > h:
tw = MAX_DIM
th = int(round(tw / aspect))
else:
th = MAX_DIM
tw = int(round(th * aspect))
tw = round(tw / MULTIPLE_OF) * MULTIPLE_OF
th = round(th / MULTIPLE_OF) * MULTIPLE_OF
tw = max(MIN_DIM, min(MAX_DIM, tw))
th = max(MIN_DIM, min(MAX_DIM, th))
return img.resize((tw, th), Image.LANCZOS)
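# Wan-family models generally expect num_frames of the form 4k + 1 (the
# temporal VAE compresses time by a factor of 4). With FIXED_FPS = 16 the
# "1 +" below yields 4k + 1 for whole-second durations; other durations may
# need snapping to the nearest valid count (assumption based on Wan 2.x usage).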
def get_num_frames(duration_seconds: float) -> int:
return 1 + int(np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
# -------------------------------------------------
# MDF TRANSLATOR
# -------------------------------------------------
def translate_albanian_to_english(text: str) -> str:
if not text.strip():
raise gr.Error("Please enter a description.")
for attempt in range(2):
try:
response = requests.post(
"https://hal1993-mdftranslation1234567890abcdef1234567890-fc073a6.hf.space/v1/translate",
json={"from_language": "sq", "to_language": "en", "input_text": text},
headers={"accept": "application/json", "Content-Type": "application/json"},
timeout=5,
)
response.raise_for_status()
translated = response.json().get("translate", "")
logger.info(f"Translation response: {translated}")
return translated
except Exception as e:
logger.error(f"Translation error (attempt {attempt + 1}): {e}")
if attempt == 1:
raise gr.Error("Translation failed. Please try again.")
raise gr.Error("Translation failed. Please try again.")
# -------------------------------------------------
# NSFW FILTER (identical to reference)
# -------------------------------------------------
NSFW_BLACKLIST = {
"nude", "naked", "porn", "sex", "sexual", "erotic", "erotica",
"nsfw", "explicit", "cum", "orgasm", "penis", "vagina",
"breast", "boob", "butt", "ass", "dick", "cock", "pussy",
"fuck", "fucking", "suck", "sucking", "masturb", "bdsm",
"kink", "fetish", "hentai", "gore", "violence", "blood",
}
SAFE_CLOTH = {
"thong", "lingerie", "bra", "panty", "stockings",
"underwear", "bikini", "swimsuit", "dress", "skirt", "shorts",
"jeans", "trousers", "pants", "leggings", "suit", "coat",
}
SAFE_PHRASE_PATTERNS = [
re.compile(r"\bthong\b.*\b(?:butt|ass|booty|rear|rump|glutes)\b", re.I),
re.compile(r"\b(?:lingerie|bra|panty|stockings|bikini|swimsuit)\b.*\b(?:butt|ass|booty|rear|rump|glutes)\b", re.I),
re.compile(r"\b(?:butt|ass|booty|rear|rump|glutes)\b.*\bthong\b", re.I),
re.compile(r"\b(?:butt|ass|booty|rear|rump|glutes)\b.*\b(?:lingerie|bra|panty|stockings|bikini|swimsuit)\b", re.I),
]
def is_safe_phrase(text: str) -> bool:
return any(p.search(text) for p in SAFE_PHRASE_PATTERNS)
try:
nlp = spacy.load("en_core_web_sm")
except OSError:
print("spaCy model 'en_core_web_sm' not found. Downloading...")
    import subprocess
    import sys
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
nlp = spacy.load("en_core_web_sm")
def has_safe_modifier(token) -> bool:
for child in token.children:
if child.lemma_ in SAFE_CLOTH:
return True
if token.head.lemma_ in SAFE_CLOTH:
return True
for ancestor in token.ancestors:
if ancestor.lemma_ in SAFE_CLOTH:
return True
return False
def _contains_nsfw(text: str) -> bool:
lowered = text.lower()
if is_safe_phrase(lowered):
return False
doc = nlp(lowered)
for token in doc:
if token.lemma_ in NSFW_BLACKLIST:
if has_safe_modifier(token):
continue
return True
return False
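# Illustrative behaviour of the filter (hypothetical inputs):
#   _contains_nsfw("a woman in a bikini walking on the beach")  -> False
#       (no blacklisted lemma; "bikini" is listed as safe clothing)
#   _contains_nsfw("nude portrait")                             -> True
#       ("nude" is blacklisted and carries no safe clothing modifier)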
NSFW_ERROR_MSG = (
"🚫 Your prompt contains content that is not allowed on this service. "
"Repeated attempts may result in a permanent ban."
)
# -------------------------------------------------
# CORE INFERENCE
# -------------------------------------------------
@spaces.GPU(duration=180)
def infer(image, prompt):
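    """Generate one video from an input image and an Albanian prompt.

    Returns (video_path, warning_update); video_path is None when the daily
    quota is exhausted or the translated prompt fails the NSFW check."""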
global USAGE
today = date.today()
if USAGE["date"] != today:
USAGE["date"] = today
USAGE["count"] = 0
if USAGE["count"] >= DAILY_LIMIT:
return None, gr.update(value="🚫 You have used all your free generations. Please come back tomorrow.", visible=True)
# Translate
prompt_en = translate_albanian_to_english(prompt.strip()) + QUALITY_PROMPT
# NSFW check
if _contains_nsfw(prompt_en):
logger.warning(f"NSFW attempt detected (hashed): {hash(prompt)}")
return None, gr.update(value=NSFW_ERROR_MSG, visible=True)
# Preprocess image
if image is None:
raise gr.Error("Please upload an input image.")
pil_img = image.convert("RGB") if isinstance(image, Image.Image) else Image.open(image).convert("RGB")
img_resized = resize_image(pil_img)
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device="cuda").manual_seed(seed)
gc.collect()
torch.cuda.empty_cache()
out_frames = pipe(
image=img_resized,
prompt=prompt_en,
negative_prompt=default_negative_prompt,
height=img_resized.height,
width=img_resized.width,
        num_frames=get_num_frames(30.0),  # fixed 30 s clip at FIXED_FPS
guidance_scale=1.0,
guidance_scale_2=1.0,
num_inference_steps=4,
generator=generator,
).frames[0]
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
video_path = tmp.name
export_to_video(out_frames, video_path, fps=FIXED_FPS)
USAGE["count"] += 1
return video_path, gr.update(visible=False)
# -------------------------------------------------
# GRADIO DEMO (exact replica of reference UI)
# -------------------------------------------------
def create_demo():
with gr.Blocks(css="", title="Wan Image to Video") as demo:
gr.HTML(
"""
<style>
@import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;600;700&display=swap');
@keyframes glow {0%{box-shadow:0 0 14px rgba(0,255,128,0.5);}50%{box-shadow:0 0 14px rgba(0,255,128,0.7);}100%{box-shadow:0 0 14px rgba(0,255,128,0.5);}}
@keyframes glow-hover {0%{box-shadow:0 0 20px rgba(0,255,128,0.7);}50%{box-shadow:0 0 20px rgba(0,255,128,0.9);}100%{box-shadow:0 0 20px rgba(0,255,128,0.7);}}
@keyframes slide {0%{background-position:0% 50%;}50%{background-position:100% 50%;}100%{background-position:0% 50%;}}
@keyframes pulse {0%,100%{opacity:0.7;}50%{opacity:1;}}
body{
background:#000000 !important;
color:#FFFFFF !important;
font-family:'Orbitron',sans-serif;
min-height:100vh;
margin:0 !important;
padding:0 !important;
width:100% !important;
max-width:100vw !important;
overflow-x:hidden !important;
display:flex !important;
justify-content:center;
align-items:center;
flex-direction:column;
}
body::before{
content:"";
display:block;
height:600px;
background:#000000 !important;
}
.gr-blocks,.container{
width:100% !important;
max-width:100vw !important;
margin:0 !important;
padding:0 !important;
box-sizing:border-box !important;
overflow-x:hidden !important;
background:#000000 !important;
color:#FFFFFF !important;
}
.gr-row,.gr-column{
width:100% !important;
max-width:100vw !important;
margin:0 !important;
padding:0 !important;
box-sizing:border-box !important;
}
.gradio-container,.gradio-app,.gradio-interface{
width:100% !important;
max-width:100vw !important;
margin:0 !important;
padding:0 !important;
box-sizing:border-box !important;
}
#general_items{
width:100% !important;
max-width:100vw !important;
margin:2rem 0 !important;
display:flex !important;
flex-direction:column;
align-items:center;
justify-content:center;
background:#000000 !important;
color:#FFFFFF !important;
}
#input_column{
background:#000000 !important;
border:none !important;
border-radius:8px;
padding:1rem !important;
box-shadow:0 0 10px rgba(255,255,255,0.3) !important;
width:100% !important;
max-width:100vw !important;
box-sizing:border-box !important;
color:#FFFFFF !important;
}
h1{
font-size:5rem;
font-weight:700;
text-align:center;
color:#FFFFFF !important;
text-shadow:0 0 8px rgba(255,255,255,0.3) !important;
margin:0 auto .5rem;
display:block;
max-width:100%;
}
#subtitle{
font-size:1rem;
text-align:center;
color:#FFFFFF !important;
opacity:0.8;
margin-bottom:1rem;
display:block;
max-width:100%;
}
.gradio-component{
background:#000000 !important;
border:none;
margin:0.75rem 0;
width:100% !important;
max-width:100vw !important;
color:#FFFFFF !important;
}
.image-container,.video-container{
aspect-ratio:1/1;
width:100% !important;
max-width:100vw !important;
min-height:500px;
height:auto;
border:0.5px solid #FFFFFF !important;
border-radius:4px;
box-sizing:border-box !important;
background:#000000 !important;
box-shadow:0 0 10px rgba(255,255,255,0.3) !important;
position:relative;
color:#FFFFFF !important;
}
.image-container img,.video-container video{
width:100% !important;
height:auto;
box-sizing:border-box !important;
display:block !important;
}
.image-container[aria-label="Input Image"] .file-upload,
.image-container[aria-label="Input Image"] .file-preview,
.image-container[aria-label="Input Image"] .image-actions,
.video-container .file-upload,
.video-container .file-preview,
.video-container .image-actions{
display:none !important;
}
.video-container.processing{
background:#000000 !important;
position:relative !important;
}
.video-container.processing::before{
content:"PROCESSING...";
position:absolute !important;
top:50% !important;
left:50% !important;
transform:translate(-50%,-50%) !important;
color:#FFFFFF !important;
font-family:'Orbitron',sans-serif !important;
font-size:1.8rem !important;
font-weight:700 !important;
text-align:center !important;
text-shadow:0 0 10px rgba(0,255,128,0.8) !important;
animation:pulse 1.5s ease-in-out infinite,glow 2s ease-in-out infinite !important;
z-index:9999 !important;
width:100% !important;
height:100% !important;
display:flex !important;
align-items:center !important;
justify-content:center !important;
pointer-events:none !important;
background:#000000 !important;
border-radius:4px !important;
box-sizing:border-box !important;
}
.video-container.processing *{
display:none !important;
}
input,textarea{
background:#000000 !important;
color:#FFFFFF !important;
border:1px solid #FFFFFF !important;
border-radius:4px;
padding:0.5rem;
width:100% !important;
max-width:100vw !important;
box-sizing:border-box !important;
}
input:hover,textarea:hover{
box-shadow:0 0 8px rgba(255,255,255,0.3) !important;
transition:box-shadow 0.3s;
}
.gr-button-primary{
background:linear-gradient(90deg,rgba(0,255,128,0.3),rgba(0,200,100,0.3),rgba(0,255,128,0.3)) !important;
background-size:200% 100%;
animation:slide 4s ease-in-out infinite,glow 3s ease-in-out infinite;
color:#FFFFFF !important;
border:1px solid #FFFFFF !important;
border-radius:6px;
padding:0.75rem 1.5rem;
font-size:1.1rem;
font-weight:600;
box-shadow:0 0 14px rgba(0,255,128,0.7) !important;
transition:box-shadow 0.3s,transform 0.3s;
width:100% !important;
max-width:100vw !important;
min-height:48px;
cursor:pointer;
}
.gr-button-primary:hover{
box-shadow:0 0 20px rgba(0,255,128,0.9) !important;
animation:slide 4s ease-in-out infinite,glow-hover 3s ease-in-out infinite;
transform:scale(1.05);
}
button[aria-label="Fullscreen"],button[aria-label="Share"]{
display:none !important;
}
button[aria-label="Download"]{
transform:scale(3);
transform-origin:top right;
background:#000000 !important;
color:#FFFFFF !important;
border:1px solid #FFFFFF !important;
border-radius:4px;
padding:0.4rem !important;
margin:0.5rem !important;
box-shadow:0 0 8px rgba(255,255,255,0.3) !important;
transition:box-shadow 0.3s;
}
button[aria-label="Download"]:hover{
box-shadow:0 0 12px rgba(255,255,255,0.5) !important;
}
.progress-text,.gr-progress,.gr-prose,.gr-log{
display:none !important;
}
footer,.gr-button-secondary{
display:none !important;
}
.gr-group{
background:#000000 !important;
border:none !important;
width:100% !important;
max-width:100vw !important;
}
@media (max-width:768px){
h1{font-size:4rem;}
#subtitle{font-size:0.9rem;}
.gr-button-primary{
padding:0.6rem 1rem;
font-size:1rem;
box-shadow:0 0 10px rgba(0,255,128,0.7) !important;
animation:slide 4s ease-in-out infinite,glow 3s ease-in-out infinite;
}
.image-container,.video-container{min-height:300px;box-shadow:0 0 8px rgba(255,255,255,0.3) !important;}
.video-container.processing::before{font-size:1.2rem !important;}
}
#top_warning{
color:#ffdd00;
font-weight:600;
text-align:center;
margin-bottom:0.5rem;
}
#nsfw_warning{
color:#ff4d4d;
font-weight:600;
text-align:center;
margin-top:0.5rem;
}
</style>
<script>
const allowedPath = /^\\/b9v0c1x2z3a4s5d6f7g8h9j0k1l2m3n4b5v6c7x8z9a0s1d2f3g4h5j6k7l8m9n0(\\/.*)?$/;
if (!allowedPath.test(window.location.pathname)) {
document.body.innerHTML = '<h1 style="color:#ef4444;font-family:sans-serif;text-align:center;margin-top:100px;">500 Internal Server Error</h1>';
throw new Error('500');
}
document.addEventListener('DOMContentLoaded', () => {
const generateBtn = document.querySelector('.gr-button-primary');
const resultContainer = document.querySelector('.video-container');
if (generateBtn && resultContainer) {
generateBtn.addEventListener('click', () => {
resultContainer.classList.add('processing');
resultContainer.querySelectorAll('*').forEach(child => {
if (child.tagName !== 'VIDEO') child.style.display = 'none';
});
});
const vidObserver = new MutationObserver(muts => {
muts.forEach(m => {
m.addedNodes.forEach(node => {
if (node.nodeType === 1 && (node.tagName === 'VIDEO' || node.querySelector('video'))) {
resultContainer.classList.remove('processing');
vidObserver.disconnect();
}
});
});
});
vidObserver.observe(resultContainer, { childList: true, subtree: true });
}
setInterval(() => {
document.querySelectorAll('.progress-text,.gr-progress,[class*="progress"]').forEach(el => el.remove());
}, 500);
});
</script>
"""
)
with gr.Row(elem_id="general_items"):
gr.Markdown("# ")
gr.Markdown(
"**⚠️ This app is safe‑for‑work only.** "
"Any attempt to generate adult or explicit content will be blocked and may result in a ban.",
elem_id="top_warning",
)
gr.Markdown("Turn your image into a video with motion description", elem_id="subtitle")
with gr.Column(elem_id="input_column"):
input_image = gr.Image(
label="Input Image",
type="pil",
sources=["upload"],
show_download_button=False,
show_share_button=False,
interactive=True,
elem_classes=["gradio-component", "image-container"]
)
prompt = gr.Textbox(
label="Prompt",
lines=3,
elem_classes=["gradio-component"]
)
warning = gr.Markdown("", visible=False, elem_id="nsfw_warning")
run_button = gr.Button(
"Generate Video!",
variant="primary",
elem_classes=["gradio-component", "gr-button-primary"]
)
result_video = gr.Video(
label="Result Video",
interactive=False,
show_share_button=False,
show_download_button=True,
elem_classes=["gradio-component", "video-container"]
)
run_button.click(fn=infer, inputs=[input_image, prompt], outputs=[result_video, warning])
prompt.submit(fn=infer, inputs=[input_image, prompt], outputs=[result_video, warning])
return demo
# -------------------------------------------------
# FASTAPI MOUNT & 500 GUARD
# -------------------------------------------------
app = FastAPI()
demo = create_demo()
# Blocks.app is only populated after launch(), so mounting demo.app here would
# fail; gr.mount_gradio_app is the supported way to serve Blocks under FastAPI.
app = gr.mount_gradio_app(
    app,
    demo,
    path="/b9v0c1x2z3a4s5d6f7g8h9j0k1l2m3n4b5v6c7x8z9a0s1d2f3g4h5j6k7l8m9n0",
)
@app.get("/{path:path}")
async def catch_all(path: str):
raise HTTPException(status_code=500, detail="Internal Server Error")
if __name__ == "__main__":
logger.info(f"Gradio version: {gr.__version__}")
demo.queue().launch(share=True)
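# To serve the FastAPI wrapper (obscured mount path plus the 500 catch-all)
# instead of the bare Gradio app, run it under an ASGI server, e.g.:
#   uvicorn app:app --host 0.0.0.0 --port 7860
# (the module name "app" is assumed to match this file's name).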