import os
import random
import logging
import requests
import numpy as np
import torch
import spaces
from fastapi import FastAPI, HTTPException
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
from PIL import Image
import gradio as gr
import tempfile
import gc
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
import aoti
import re
import spacy
from datetime import date

logging.basicConfig(
    level=logging.INFO,
    filename="wan_i2v.log",
    filemode="a",
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# -------------------------------------------------
# DAILY QUOTA SETTINGS
# -------------------------------------------------
DAILY_LIMIT = 20
# In-memory counter: shared by all users and reset on process restart.
USAGE = {"count": 0, "date": date.today()}
PLACEHOLDER_IMG = Image.new("RGB", (512, 512), color=(0, 0, 0))

# -------------------------------------------------
# MODEL CONFIGURATION
# -------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
HF_TOKEN = os.environ.get("HF_TOKEN")

MAX_DIM = 832
MIN_DIM = 480
SQUARE_DIM = 640
MULTIPLE_OF = 16

MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 7720

# -------------------------------------------------
# PIPELINE BUILD
# -------------------------------------------------
print("Loading pipeline components...")
transformer = WanTransformer3DModel.from_pretrained(
    MODEL_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
transformer_2 = WanTransformer3DModel.from_pretrained(
    MODEL_ID,
    subfolder="transformer_2",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)

print("Assembling pipeline...")
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    transformer=transformer,
    transformer_2=transformer_2,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
pipe = pipe.to("cuda")

# -------------------------------------------------
# LoRA ADAPTERS
# -------------------------------------------------
print("Loading LoRA adapters...")
try:
    pipe.load_lora_weights(
        "Kijai/WanVideo_comfy",
        weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
        adapter_name="lightx2v",
    )
    pipe.load_lora_weights(
        "Kijai/WanVideo_comfy",
        weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
        adapter_name="lightx2v_2",
        load_into_transformer_2=True,
    )
    pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
    pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
    pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
    pipe.unload_lora_weights()
    print("LoRA loaded and fused successfully.")
except Exception as e:
    print(f"Warning: Failed to load LoRA. Continuing without it. Error: {e}")

# -------------------------------------------------
# QUANTISATION & AOTI
# -------------------------------------------------
print("Applying quantisation...")
torch.cuda.empty_cache()
gc.collect()
try:
    quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
    quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
    quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
    print("Loading AOTI blocks...")
    aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
    aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")
except Exception as e:
    print(f"Warning: Quantisation/AOTI failed - will run in standard mode. Error: {e}")

# -------------------------------------------------
# PROMPTS
# -------------------------------------------------
QUALITY_PROMPT = ", high quality, detailed, vibrant, professional lighting, smooth motion, cinematic"

default_negative_prompt = (
    "low quality, worst quality, motion artifacts, unstable motion, jitter, frame jitter, wobbling limbs, "
    "motion distortion, inconsistent movement, robotic movement, animation-like motion, awkward transitions, "
    "incorrect body mechanics, unnatural posing, off-balance poses, broken motion paths, frozen frames, "
    "duplicated frames, frame skipping, warped motion, stretching artifacts, bad anatomy, incorrect proportions, "
    "deformed body, twisted torso, broken joints, dislocated limbs, distorted neck, unnatural spine curvature, "
    "malformed hands, extra fingers, missing fingers, fused fingers, distorted legs, extra limbs, collapsed feet, "
    "floating feet, foot sliding, foot jitter, backward walking, unnatural gait, blurry details, long exposure blur, "
    "ghosting, shadow trails, smearing, washed-out colors, overexposure, underexposure, excessive contrast, "
    "blown highlights, poorly rendered clothing, fabric glitches, texture warping, clothing merging with body, "
    "incorrect cloth physics, ugly background, cluttered scene, crowded background, random objects, unwanted text, "
    "subtitles, logos, graffiti, grain, noise, static artifacts, compression noise, jpeg artifacts, image-like "
    "stillness, painting-like look, cartoon texture, low-resolution textures"
)

# -------------------------------------------------
# IMAGE RESIZING
# -------------------------------------------------
def resize_image(image: Image.Image) -> Image.Image:
    w, h = image.size
    if w == h:
        return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

    aspect = w / h
    max_ar = MAX_DIM / MIN_DIM
    min_ar = MIN_DIM / MAX_DIM
    img = image

    if aspect > max_ar:
        # Too wide: centre-crop the width down to the maximum aspect ratio.
        cw = int(round(h * max_ar))
        left = (w - cw) // 2
        img = image.crop((left, 0, left + cw, h))
        aspect = max_ar  # use the post-crop aspect ratio below
    elif aspect < min_ar:
        # Too tall: centre-crop the height down to the minimum aspect ratio.
        ch = int(round(w / min_ar))
        top = (h - ch) // 2
        img = image.crop((0, top, w, top + ch))
        aspect = min_ar  # use the post-crop aspect ratio below

    if w > h:
        tw = MAX_DIM
        th = int(round(tw / aspect))
    else:
        th = MAX_DIM
        tw = int(round(th * aspect))

    # Snap to the model's required multiple, then clamp to the supported range.
    tw = round(tw / MULTIPLE_OF) * MULTIPLE_OF
    th = round(th / MULTIPLE_OF) * MULTIPLE_OF
    tw = max(MIN_DIM, min(MAX_DIM, tw))
    th = max(MIN_DIM, min(MAX_DIM, th))
    return img.resize((tw, th), Image.LANCZOS)
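
# Quick sanity sketch for the resizing rules (assumption: gated behind a
# hypothetical WAN_SELFTEST env var so it never runs in production). A 16:9
# upload is slightly wider than MAX_DIM/MIN_DIM (832/480 ≈ 1.733), so it is
# centre-cropped, snapped to multiples of 16, and lands on 832x480; a square
# upload short-circuits to SQUARE_DIM x SQUARE_DIM.
if os.environ.get("WAN_SELFTEST"):
    assert resize_image(Image.new("RGB", (1920, 1080))).size == (832, 480)
    assert resize_image(Image.new("RGB", (1000, 1000))).size == (SQUARE_DIM, SQUARE_DIM)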

def get_num_frames(duration_seconds: float) -> int:
    # One more than the clamped frame count, i.e.
    # 1 + clamp(round(seconds * fps), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL).
    return 1 + int(np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
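
# Worked example (same hypothetical WAN_SELFTEST guard as above): 5.0 s at
# FIXED_FPS = 16 rounds to 80 frames, stays inside [8, 7720], and the extra
# leading frame gives 81, the familiar 5-second Wan clip length.
if os.environ.get("WAN_SELFTEST"):
    assert get_num_frames(5.0) == 81
    assert get_num_frames(0.1) == 1 + MIN_FRAMES_MODEL  # floored at the model minimum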

# -------------------------------------------------
# MDF TRANSLATOR
# -------------------------------------------------
@spaces.GPU
def translate_albanian_to_english(text: str, language: str = "en"):
    if not text.strip():
        raise gr.Error("Please enter a description.")
    for attempt in range(2):
        try:
            response = requests.post(
                "https://hal1993-mdftranslation1234567890abcdef1234567890-fc073a6.hf.space/v1/translate",
                json={"from_language": "sq", "to_language": "en", "input_text": text},
                headers={"accept": "application/json", "Content-Type": "application/json"},
                timeout=5,
            )
            response.raise_for_status()
            translated = response.json().get("translate", "")
            logger.info(f"Translation response: {translated}")
            return translated
        except Exception as e:
            logger.error(f"Translation error (attempt {attempt + 1}): {e}")
            if attempt == 1:
                raise gr.Error("Translation failed. Please try again.")

# -------------------------------------------------
# NSFW FILTER (identical to reference)
# -------------------------------------------------
NSFW_BLACKLIST = {
    "nude", "naked", "porn", "sex", "sexual", "erotic", "erotica", "nsfw",
    "explicit", "cum", "orgasm", "penis", "vagina", "breast", "boob", "butt",
    "ass", "dick", "cock", "pussy", "fuck", "fucking", "suck", "sucking",
    "masturb", "bdsm", "kink", "fetish", "hentai", "gore", "violence", "blood",
}

SAFE_CLOTH = {
    "thong", "lingerie", "bra", "panty", "stockings", "underwear", "bikini",
    "swimsuit", "dress", "skirt", "shorts", "jeans", "trousers", "pants",
    "leggings", "suit", "coat",
}

SAFE_PHRASE_PATTERNS = [
    re.compile(r"\bthong\b.*\b(?:butt|ass|booty|rear|rump|glutes)\b", re.I),
    re.compile(r"\b(?:lingerie|bra|panty|stockings|bikini|swimsuit)\b.*\b(?:butt|ass|booty|rear|rump|glutes)\b", re.I),
    re.compile(r"\b(?:butt|ass|booty|rear|rump|glutes)\b.*\bthong\b", re.I),
    re.compile(r"\b(?:butt|ass|booty|rear|rump|glutes)\b.*\b(?:lingerie|bra|panty|stockings|bikini|swimsuit)\b", re.I),
]


def is_safe_phrase(text: str) -> bool:
    return any(p.search(text) for p in SAFE_PHRASE_PATTERNS)


try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    print("spaCy model 'en_core_web_sm' not found. Downloading...")
    import subprocess
    subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
    nlp = spacy.load("en_core_web_sm")


def has_safe_modifier(token) -> bool:
    # A flagged token counts as safe when it depends on (or governs) a
    # clothing term in the parse tree, e.g. "butt" in "a thong on her butt".
    for child in token.children:
        if child.lemma_ in SAFE_CLOTH:
            return True
    if token.head.lemma_ in SAFE_CLOTH:
        return True
    for ancestor in token.ancestors:
        if ancestor.lemma_ in SAFE_CLOTH:
            return True
    return False


def _contains_nsfw(text: str) -> bool:
    lowered = text.lower()
    if is_safe_phrase(lowered):
        return False
    doc = nlp(lowered)
    for token in doc:
        if token.lemma_ in NSFW_BLACKLIST:
            if has_safe_modifier(token):
                continue
            return True
    return False


NSFW_ERROR_MSG = (
    "🚫 Your prompt contains content that is not allowed on this service. "
    "Repeated attempts may result in a permanent ban."
)
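
# Minimal smoke test for the filter (same hypothetical WAN_SELFTEST guard;
# lemma matching depends on the spaCy model, so treat this as a sketch, not
# a guarantee of coverage):
if os.environ.get("WAN_SELFTEST"):
    assert _contains_nsfw("a nude figure on the beach")
    assert not _contains_nsfw("a woman dancing in a red dress")
    assert not _contains_nsfw("a thong bikini on her butt")  # allowed via SAFE_PHRASE_PATTERNS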

# -------------------------------------------------
# CORE INFERENCE
# -------------------------------------------------
@spaces.GPU(duration=180)
def infer(image, prompt):
    global USAGE

    # Roll the daily quota over at midnight, then enforce it.
    today = date.today()
    if USAGE["date"] != today:
        USAGE["date"] = today
        USAGE["count"] = 0
    if USAGE["count"] >= DAILY_LIMIT:
        return None, gr.update(
            value="🚫 You have used all your free generations. Please come back tomorrow.",
            visible=True,
        )

    # Translate (sq -> en) and append the quality boosters
    prompt_en = translate_albanian_to_english(prompt.strip()) + QUALITY_PROMPT

    # NSFW check
    if _contains_nsfw(prompt_en):
        logger.warning(f"NSFW attempt detected (hashed): {hash(prompt)}")
        return None, gr.update(value=NSFW_ERROR_MSG, visible=True)

    # Preprocess image
    if image is None:
        raise gr.Error("Please upload an input image.")
    pil_img = image.convert("RGB") if isinstance(image, Image.Image) else Image.open(image).convert("RGB")
    img_resized = resize_image(pil_img)

    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    gc.collect()
    torch.cuda.empty_cache()

    out_frames = pipe(
        image=img_resized,
        prompt=prompt_en,
        negative_prompt=default_negative_prompt,
        height=img_resized.height,
        width=img_resized.width,
        num_frames=get_num_frames(30.0),  # fixed max duration
        guidance_scale=1.0,
        guidance_scale_2=1.0,
        num_inference_steps=4,
        generator=generator,
    ).frames[0]

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        video_path = tmp.name
    export_to_video(out_frames, video_path, fps=FIXED_FPS)

    USAGE["count"] += 1
    return video_path, gr.update(visible=False)


# -------------------------------------------------
# GRADIO DEMO (exact replica of reference UI)
# -------------------------------------------------
def create_demo():
    with gr.Blocks(css="", title="Wan Image to Video") as demo:
        gr.HTML(""" """)  # reserved for custom markup
        with gr.Row(elem_id="general_items"):
            gr.Markdown("# ")
            gr.Markdown(
                "**⚠️ This app is safe-for-work only.** "
                "Any attempt to generate adult or explicit content will be blocked and may result in a ban.",
                elem_id="top_warning",
            )
            gr.Markdown("Turn your image into a video with motion description", elem_id="subtitle")
            with gr.Column(elem_id="input_column"):
                input_image = gr.Image(
                    label="Input Image",
                    type="pil",
                    sources=["upload"],
                    show_download_button=False,
                    show_share_button=False,
                    interactive=True,
                    elem_classes=["gradio-component", "image-container"],
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    elem_classes=["gradio-component"],
                )
                warning = gr.Markdown("", visible=False, elem_id="nsfw_warning")
                run_button = gr.Button(
                    "Generate Video!",
                    variant="primary",
                    elem_classes=["gradio-component", "gr-button-primary"],
                )
                result_video = gr.Video(
                    label="Result Video",
                    interactive=False,
                    show_share_button=False,
                    show_download_button=True,
                    elem_classes=["gradio-component", "video-container"],
                )
        run_button.click(fn=infer, inputs=[input_image, prompt], outputs=[result_video, warning])
        prompt.submit(fn=infer, inputs=[input_image, prompt], outputs=[result_video, warning])
    return demo


# -------------------------------------------------
# FASTAPI MOUNT & 500 GUARD
# -------------------------------------------------
app = FastAPI()
demo = create_demo()

# Use gr.mount_gradio_app rather than app.mount on the raw Blocks object (a
# Blocks instance has no ASGI app before launch). The UI lives at an
# unguessable path; the catch-all below turns every other route into a 500.
app = gr.mount_gradio_app(
    app,
    demo,
    path="/b9v0c1x2z3a4s5d6f7g8h9j0k1l2m3n4b5v6c7x8z9a0s1d2f3g4h5j6k7l8m9n0",
)


@app.get("/{path:path}")
async def catch_all(path: str):
    raise HTTPException(status_code=500, detail="Internal Server Error")


if __name__ == "__main__":
    logger.info(f"Gradio version: {gr.__version__}")
    demo.queue().launch(share=True)
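
# Serving sketch (assumption: this module is saved as app.py). Running the
# file directly starts Gradio's own server via demo.launch(); serving the
# FastAPI wrapper instead exposes the 500 guard plus the mounted UI path:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860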