import os
import random
import logging
import requests
import numpy as np
import torch
import spaces
from fastapi import FastAPI, HTTPException
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
from PIL import Image
import gradio as gr
import tempfile
import gc
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
import aoti
import re
import spacy
from datetime import date
logging.basicConfig(
level=logging.INFO,
filename="wan_i2v.log",
filemode="a",
format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# -------------------------------------------------
# DAILY QUOTA SETTINGS
# -------------------------------------------------
DAILY_LIMIT = 20
USAGE = {"count": 0, "date": date.today()}
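# The quota is kept in process memory: it is shared by every visitor,
# resets whenever the Space restarts, and is not guarded by a lock, so
# concurrent requests can race. Good enough for a soft daily cap.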
PLACEHOLDER_IMG = Image.new("RGB", (512, 512), color=(0, 0, 0))
# -------------------------------------------------
# MODEL CONFIGURATION
# -------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
HF_TOKEN = os.environ.get("HF_TOKEN")
MAX_DIM = 832
MIN_DIM = 480
SQUARE_DIM = 640
MULTIPLE_OF = 16
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81  # ~5 s at 16 fps; the usual clip-length cap for Wan 2.2 I2V
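# Output dimensions must be multiples of 16 to match the model's spatial
# compression/patching; frame counts are clamped to the supported range,
# and the pipeline expects lengths of the form 4k + 1 (see get_num_frames).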
# -------------------------------------------------
# PIPELINE BUILD
# -------------------------------------------------
print("Loading pipeline components...")
transformer = WanTransformer3DModel.from_pretrained(
MODEL_ID,
subfolder="transformer",
torch_dtype=torch.bfloat16,
token=HF_TOKEN,
)
transformer_2 = WanTransformer3DModel.from_pretrained(
MODEL_ID,
subfolder="transformer_2",
torch_dtype=torch.bfloat16,
token=HF_TOKEN,
)
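# Wan 2.2 A14B is a two-expert design: `transformer` handles the
# high-noise portion of the denoising schedule and `transformer_2` the
# low-noise portion, so both must be loaded for the full pipeline.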
print("Assembling pipeline...")
pipe = WanImageToVideoPipeline.from_pretrained(
MODEL_ID,
transformer=transformer,
transformer_2=transformer_2,
torch_dtype=torch.bfloat16,
token=HF_TOKEN,
)
pipe = pipe.to("cuda")
# -------------------------------------------------
# LoRA ADAPTERS
# -------------------------------------------------
print("Loading LoRA adapters...")
try:
pipe.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
adapter_name="lightx2v",
)
pipe.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
adapter_name="lightx2v_2",
load_into_transformer_2=True,
)
pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
pipe.unload_lora_weights()
print("LoRA loaded and fused successfully.")
except Exception as e:
print(f"Warning: Failed to load LoRA. Continuing without it. Error: {e}")
# -------------------------------------------------
# QUANTISATION & AOTI
# -------------------------------------------------
print("Applying quantisation...")
torch.cuda.empty_cache()
gc.collect()
try:
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
print("Loading AOTI blocks...")
aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")
except Exception as e:
print(f"Warning: Quantisation/AOTI failed – will run in standard mode. Error: {e}")
# -------------------------------------------------
# PROMPTS
# -------------------------------------------------
QUALITY_PROMPT = ", high quality, detailed, vibrant, professional lighting, smooth motion, cinematic"
default_negative_prompt = (
    "low quality, worst quality, motion artifacts, unstable motion, jitter, frame jitter, wobbling limbs, "
    "motion distortion, inconsistent movement, robotic movement, animation-like motion, awkward transitions, "
    "incorrect body mechanics, unnatural posing, off-balance poses, broken motion paths, frozen frames, "
    "duplicated frames, frame skipping, warped motion, stretching artifacts, bad anatomy, incorrect proportions, "
    "deformed body, twisted torso, broken joints, dislocated limbs, distorted neck, unnatural spine curvature, "
    "malformed hands, extra fingers, missing fingers, fused fingers, distorted legs, extra limbs, collapsed feet, "
    "floating feet, foot sliding, foot jitter, backward walking, unnatural gait, blurry details, long exposure blur, "
    "ghosting, shadow trails, smearing, washed-out colors, overexposure, underexposure, excessive contrast, "
    "blown highlights, poorly rendered clothing, fabric glitches, texture warping, clothing merging with body, "
    "incorrect cloth physics, ugly background, cluttered scene, crowded background, random objects, unwanted text, "
    "subtitles, logos, graffiti, grain, noise, static artifacts, compression noise, jpeg artifacts, image-like "
    "stillness, painting-like look, cartoon texture, low-resolution textures"
)
# -------------------------------------------------
# IMAGE RESIZING
# -------------------------------------------------
def resize_image(image: Image.Image) -> Image.Image:
w, h = image.size
if w == h:
return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
aspect = w / h
max_ar = MAX_DIM / MIN_DIM
min_ar = MIN_DIM / MAX_DIM
    img = image
    if aspect > max_ar:
        # Too wide: center-crop the width down to the maximum aspect ratio,
        # then recompute dimensions so the resize below uses the cropped aspect.
        cw = int(round(h * max_ar))
        left = (w - cw) // 2
        img = image.crop((left, 0, left + cw, h))
        w, h = img.size
        aspect = w / h
    elif aspect < min_ar:
        # Too tall: center-crop the height down to the minimum aspect ratio.
        ch = int(round(w / min_ar))
        top = (h - ch) // 2
        img = image.crop((0, top, w, top + ch))
        w, h = img.size
        aspect = w / h
if w > h:
tw = MAX_DIM
th = int(round(tw / aspect))
else:
th = MAX_DIM
tw = int(round(th * aspect))
tw = round(tw / MULTIPLE_OF) * MULTIPLE_OF
th = round(th / MULTIPLE_OF) * MULTIPLE_OF
tw = max(MIN_DIM, min(MAX_DIM, tw))
th = max(MIN_DIM, min(MAX_DIM, th))
return img.resize((tw, th), Image.LANCZOS)
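# Example of the resize policy: a 1920x1080 upload has aspect ~1.78 >
# 832/480 ~= 1.73, so it is first center-cropped to 1872x1080 and then
# resized to 832x480; a 1024x768 upload already fits the ratio window
# and maps directly to 832x624 (both multiples of 16).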
def get_num_frames(duration_seconds: float) -> int:
    # Convert a duration to a frame count, clamped to the supported range.
    frames = int(round(duration_seconds * FIXED_FPS))
    return 1 + int(np.clip(frames, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
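# Note: Wan expects num_frames of the form 4k + 1 (the VAE's temporal
# stride is 4); the diffusers pipeline adjusts mismatched counts itself,
# with a warning, so the +1 above is only a best-effort fit.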
# -------------------------------------------------
# MDF TRANSLATOR
# -------------------------------------------------
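# The prompt is translated sq -> en by an external MDF translation Space
# over plain HTTP (one retry, 5 s timeout). This is CPU-only work, so it
# needs no GPU reservation.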
def translate_albanian_to_english(text: str) -> str:
if not text.strip():
raise gr.Error("Please enter a description.")
for attempt in range(2):
try:
response = requests.post(
"https://hal1993-mdftranslation1234567890abcdef1234567890-fc073a6.hf.space/v1/translate",
json={"from_language": "sq", "to_language": "en", "input_text": text},
headers={"accept": "application/json", "Content-Type": "application/json"},
timeout=5,
)
response.raise_for_status()
translated = response.json().get("translate", "")
logger.info(f"Translation response: {translated}")
return translated
        except Exception as e:
            logger.error(f"Translation error (attempt {attempt + 1}): {e}")
    raise gr.Error("Translation failed. Please try again.")
# -------------------------------------------------
# NSFW FILTER (identical to reference)
# -------------------------------------------------
NSFW_BLACKLIST = {
"nude", "naked", "porn", "sex", "sexual", "erotic", "erotica",
"nsfw", "explicit", "cum", "orgasm", "penis", "vagina",
"breast", "boob", "butt", "ass", "dick", "cock", "pussy",
"fuck", "fucking", "suck", "sucking", "masturb", "bdsm",
"kink", "fetish", "hentai", "gore", "violence", "blood",
}
SAFE_CLOTH = {
"thong", "lingerie", "bra", "panty", "stockings",
"underwear", "bikini", "swimsuit", "dress", "skirt", "shorts",
"jeans", "trousers", "pants", "leggings", "suit", "coat",
}
SAFE_PHRASE_PATTERNS = [
re.compile(r"\bthong\b.*\b(?:butt|ass|booty|rear|rump|glutes)\b", re.I),
re.compile(r"\b(?:lingerie|bra|panty|stockings|bikini|swimsuit)\b.*\b(?:butt|ass|booty|rear|rump|glutes)\b", re.I),
re.compile(r"\b(?:butt|ass|booty|rear|rump|glutes)\b.*\bthong\b", re.I),
re.compile(r"\b(?:butt|ass|booty|rear|rump|glutes)\b.*\b(?:lingerie|bra|panty|stockings|bikini|swimsuit)\b", re.I),
]
def is_safe_phrase(text: str) -> bool:
return any(p.search(text) for p in SAFE_PHRASE_PATTERNS)
try:
nlp = spacy.load("en_core_web_sm")
except OSError:
print("spaCy model 'en_core_web_sm' not found. Downloading...")
import subprocess
subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
nlp = spacy.load("en_core_web_sm")
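# Filtering strategy: the whole-phrase whitelist runs first (clothing
# terms near body-part words are allowed), then each token's lemma is
# checked against the blacklist, with a dependency-parse escape hatch
# for tokens modified by or attached to safe clothing terms.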
def has_safe_modifier(token) -> bool:
for child in token.children:
if child.lemma_ in SAFE_CLOTH:
return True
if token.head.lemma_ in SAFE_CLOTH:
return True
for ancestor in token.ancestors:
if ancestor.lemma_ in SAFE_CLOTH:
return True
return False
def _contains_nsfw(text: str) -> bool:
lowered = text.lower()
if is_safe_phrase(lowered):
return False
doc = nlp(lowered)
for token in doc:
if token.lemma_ in NSFW_BLACKLIST:
if has_safe_modifier(token):
continue
return True
return False
NSFW_ERROR_MSG = (
"🚫 Your prompt contains content that is not allowed on this service. "
"Repeated attempts may result in a permanent ban."
)
# -------------------------------------------------
# CORE INFERENCE
# -------------------------------------------------
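# infer() runs the whole request path: quota check -> translation ->
# NSFW screen -> image resize -> 4-step distilled sampling -> mp4 export.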
@spaces.GPU(duration=180)
def infer(image, prompt):
global USAGE
today = date.today()
if USAGE["date"] != today:
USAGE["date"] = today
USAGE["count"] = 0
if USAGE["count"] >= DAILY_LIMIT:
return None, gr.update(value="🚫 You have used all your free generations. Please come back tomorrow.", visible=True)
# Translate
prompt_en = translate_albanian_to_english(prompt.strip()) + QUALITY_PROMPT
# NSFW check
if _contains_nsfw(prompt_en):
logger.warning(f"NSFW attempt detected (hashed): {hash(prompt)}")
return None, gr.update(value=NSFW_ERROR_MSG, visible=True)
# Preprocess image
if image is None:
raise gr.Error("Please upload an input image.")
pil_img = image.convert("RGB") if isinstance(image, Image.Image) else Image.open(image).convert("RGB")
img_resized = resize_image(pil_img)
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device="cuda").manual_seed(seed)
gc.collect()
torch.cuda.empty_cache()
out_frames = pipe(
image=img_resized,
prompt=prompt_en,
negative_prompt=default_negative_prompt,
height=img_resized.height,
width=img_resized.width,
        num_frames=get_num_frames(30.0),  # request above the cap; clipped to MAX_FRAMES_MODEL (~5 s)
guidance_scale=1.0,
guidance_scale_2=1.0,
num_inference_steps=4,
generator=generator,
).frames[0]
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
video_path = tmp.name
export_to_video(out_frames, video_path, fps=FIXED_FPS)
USAGE["count"] += 1
return video_path, gr.update(visible=False)
# -------------------------------------------------
# GRADIO DEMO (exact replica of reference UI)
# -------------------------------------------------
def create_demo():
with gr.Blocks(css="", title="Wan Image to Video") as demo:
with gr.Row(elem_id="general_items"):
gr.Markdown("# ")
gr.Markdown(
"**⚠️ This app is safe‑for‑work only.** "
"Any attempt to generate adult or explicit content will be blocked and may result in a ban.",
elem_id="top_warning",
)
gr.Markdown("Turn your image into a video with motion description", elem_id="subtitle")
with gr.Column(elem_id="input_column"):
input_image = gr.Image(
label="Input Image",
type="pil",
sources=["upload"],
show_download_button=False,
show_share_button=False,
interactive=True,
elem_classes=["gradio-component", "image-container"]
)
prompt = gr.Textbox(
label="Prompt",
lines=3,
elem_classes=["gradio-component"]
)
warning = gr.Markdown("", visible=False, elem_id="nsfw_warning")
run_button = gr.Button(
"Generate Video!",
variant="primary",
elem_classes=["gradio-component", "gr-button-primary"]
)
result_video = gr.Video(
label="Result Video",
interactive=False,
show_share_button=False,
show_download_button=True,
elem_classes=["gradio-component", "video-container"]
)
run_button.click(fn=infer, inputs=[input_image, prompt], outputs=[result_video, warning])
prompt.submit(fn=infer, inputs=[input_image, prompt], outputs=[result_video, warning])
return demo
# -------------------------------------------------
# FASTAPI MOUNT & 500 GUARD
# -------------------------------------------------
app = FastAPI()
demo = create_demo()
demo.queue()
# Use gr.mount_gradio_app: mounting `demo.app` directly does not work,
# because Blocks only builds its internal FastAPI app on launch().
app = gr.mount_gradio_app(
    app, demo, path="/b9v0c1x2z3a4s5d6f7g8h9j0k1l2m3n4b5v6c7x8z9a0s1d2f3g4h5j6k7l8m9n0"
)
@app.get("/{path:path}")
async def catch_all(path: str):
raise HTTPException(status_code=500, detail="Internal Server Error")
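# Example client call (a sketch, not part of the app). The endpoint name
# is an assumption: Gradio derives api_name from the handler function
# name, and the path must match the mount path used above.
#
#   from gradio_client import Client, handle_file
#   client = Client("https://<your-space>.hf.space/b9v0c1x2z3a4s5d6f7g8h9j0k1l2m3n4b5v6c7x8z9a0s1d2f3g4h5j6k7l8m9n0/")
#   video_path, _ = client.predict(
#       handle_file("input.jpg"),     # input_image
#       "the person starts to wave",  # prompt
#       api_name="/infer",
#   )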
if __name__ == "__main__":
    logger.info(f"Gradio version: {gr.__version__}")
    # Serve the FastAPI wrapper (not demo.launch()) so the catch-all
    # 500 guard above stays in front of every non-Gradio path.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)