"""Inference endpoint handler: GFPGAN face restoration with a Real-ESRGAN background upsampler."""

import base64
import io
import logging
import os

# Third-party imports are kept in the file's original relative order.
import torch
import requests
import numpy as np
import cv2
from PIL import Image
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Callable handler that restores faces in an image and returns a base64 JPEG.

    Construction makes sure both weight files exist under ``path`` (downloading
    them when missing). The networks themselves are only built on the first
    call, see ``_init_models``.
    """

    def __init__(self, path="."):
        """Record device settings and make sure the model weights are cached.

        Args:
            path: Directory in which the weight files are looked up / stored.

        Raises:
            requests.RequestException: If a missing weight file cannot be
                downloaded.
        """
        logger.info("🚀 [INIT] GFPGAN + Real-ESRGAN handler starting...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is only enabled when running on CUDA.
        self.half = self.device == "cuda"
        self.path = path

        # Model URLs (GFPGAN + RealESRGAN)
        self.gfpgan_model_url = (
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
        )
        self.realesr_model_url = (
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
        )

        # Local cache paths
        self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
        self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")

        # Populated lazily by _init_models().
        self.bg_upsampler = None
        self.restorer = None

        # Ensure model weights exist
        self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
        self._ensure_model(self.realesr_model_url, self.realesr_model_path)

        logger.info("🧠 Device: %s, half precision: %s", self.device, self.half)

    def _ensure_model(self, url, path):
        """Download the file at ``url`` to ``path`` unless ``path`` already exists.

        The download is streamed into ``<path>.part`` and renamed into place
        only after it completes, so an interrupted transfer never leaves a
        truncated file that a later run would mistake for a valid cache entry.

        Raises:
            requests.RequestException: On network failure or a non-2xx status.
            OSError: If the file cannot be written.
        """
        if os.path.exists(path):
            logger.info("📁 Found cached model: %s", path)
            return

        logger.info("⬇️ Downloading model from %s", url)
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        partial_path = f"{path}.part"
        try:
            # stream=True avoids holding the whole weight file in memory.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as out_file:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        out_file.write(chunk)
            os.replace(partial_path, path)
        finally:
            # After a successful os.replace the partial file no longer exists;
            # this only fires when the download or the write failed midway.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        logger.info("✅ Model saved to %s", path)

    def _init_models(self):
        """Lazy-load ESRGAN + GFPGAN models (no-op for parts already built)."""
        if self.bg_upsampler is None:
            logger.info("🧩 Initializing Real-ESRGAN upsampler...")
            model = SRVGGNetCompact(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_conv=32,
                upscale=4,
                act_type="prelu",
            )
            self.bg_upsampler = RealESRGANer(
                scale=4,
                model_path=self.realesr_model_path,
                model=model,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                half=self.half,
                device=self.device,
            )

        if self.restorer is None:
            logger.info("🧬 Initializing GFPGAN restorer...")
            self.restorer = GFPGANer(
                model_path=self.gfpgan_model_path,
                upscale=2,
                arch="clean",
                channel_multiplier=2,
                bg_upsampler=self.bg_upsampler,
            )
        logger.info("✅ Models ready!")

    def _load_image(self, data):
        """Accept base64, raw bytes, or URL and return an RGB PIL image.

        Args:
            data: Raw image bytes, an ``http(s)://`` URL, a base64 string
                (optionally a ``data:...;base64,`` URI), or a dict carrying
                any of those under the ``"inputs"`` key.

        Returns:
            The decoded image converted to mode ``"RGB"``.

        Raises:
            ValueError: If the input type is unsupported or a base64 payload
                cannot be decoded into an image.
            requests.RequestException: If a URL cannot be fetched.
        """
        if isinstance(data, dict) and "inputs" in data:
            data = data["inputs"]

        if isinstance(data, (bytes, bytearray)):
            logger.info("📦 Received raw bytes input")
            return Image.open(io.BytesIO(data)).convert("RGB")

        if isinstance(data, str):
            if data.startswith(("http://", "https://")):
                # NOTE(review): the URL comes from the request payload; if the
                # endpoint is reachable by untrusted callers this is an SSRF
                # vector and should be restricted by an allow-list.
                logger.info("🌐 Downloading image from URL: %s", data)
                resp = requests.get(data, timeout=60)
                # Fail on 4xx/5xx instead of handing an error page to PIL.
                resp.raise_for_status()
                return Image.open(io.BytesIO(resp.content)).convert("RGB")

            # Base64
            logger.info("🧬 Decoding base64 image input")
            if data.startswith("data:") and "," in data:
                # Drop a "data:image/png;base64," style header; b64decode
                # would otherwise silently decode its letters as payload.
                data = data.split(",", 1)[1]
            try:
                decoded = base64.b64decode(data)
                return Image.open(io.BytesIO(decoded)).convert("RGB")
            except Exception as e:
                logger.error("❌ Failed to decode base64: %s", e)
                raise ValueError("Invalid base64 image input") from e

        raise ValueError("Unsupported input type")

    def __call__(self, data):
        """Restore the faces in the input image and return it as base64 JPEG.

        Args:
            data: Any input accepted by ``_load_image``.

        Returns:
            A dict with keys ``"image"`` (base64-encoded JPEG), ``"status"``
            and ``"info"``.

        Raises:
            ValueError: If the input cannot be decoded into an image.
            RuntimeError: If restoration yields no image or JPEG encoding
                fails.
        """
        logger.info("⚙️ Starting GFPGAN inference pipeline...")
        self._init_models()

        # Load input
        image = self._load_image(data)
        input_rgb = np.array(image, dtype=np.uint8)
        logger.info("📏 Input image shape: %s", input_rgb.shape)

        # GFPGAN / Real-ESRGAN follow the OpenCV convention and expect BGR,
        # while PIL yields RGB. Convert up front so the networks see the true
        # colours instead of a channel-swapped image.
        input_bgr = cv2.cvtColor(input_rgb, cv2.COLOR_RGB2BGR)

        # Restore face(s); only the pasted-back full image is used here.
        _cropped_faces, _restored_faces, restored_bgr = self.restorer.enhance(
            input_bgr,
            has_aligned=False,
            only_center_face=False,
            paste_back=True,
        )
        if restored_bgr is None:
            raise RuntimeError("GFPGAN returned no restored image")
        logger.info("🖼️ Restoration complete, preparing output...")

        restored_bgr = np.clip(restored_bgr, 0, 255).astype(np.uint8)

        # cv2.imencode also expects BGR, so the restored image is encoded
        # as-is; no further channel swap is needed.
        encoded_ok, buffer = cv2.imencode(".jpg", restored_bgr)
        if not encoded_ok:
            raise RuntimeError("Failed to encode restored image as JPEG")
        b64_output = base64.b64encode(buffer.tobytes()).decode("utf-8")

        logger.info("✅ Returning base64-encoded image JSON response")
        return {
            "image": b64_output,
            "status": "success",
            "info": "Restored with GFPGAN v1.4 + Real-ESRGAN x4v3 (RGB fixed)",
        }