# gfpgan-handler / handler.py
# Commit c8867e7 (author: mastari): "Fix color hue and add RGB output conversion"
import os
import io
import torch
import logging
import base64
import requests
import numpy as np
import cv2
from PIL import Image
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Face-restoration endpoint: GFPGAN v1.4 with a Real-ESRGAN x4v3 background upsampler.

    Weight files are downloaded (or found in the local cache) when the handler
    is constructed; the networks themselves are built lazily on the first call.
    """

    def __init__(self, path="."):
        """Pick the compute device and make sure both weight files exist under ``path``.

        Args:
            path: Directory in which the ``.pth`` weight files are cached.
        """
        logger.info("πŸš€ [INIT] GFPGAN + Real-ESRGAN handler starting...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision (passed to the Real-ESRGAN upsampler) only on CUDA.
        self.half = self.device == "cuda"
        self.path = path

        # Release-asset URLs of the pretrained weights.
        self.gfpgan_model_url = (
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
        )
        self.realesr_model_url = (
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
        )

        # Local cache locations of the weights.
        self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
        self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")

        # Built on demand by _init_models().
        self.bg_upsampler = None
        self.restorer = None

        self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
        self._ensure_model(self.realesr_model_url, self.realesr_model_path)
        logger.info(f"🧠 Device: {self.device}, half precision: {self.half}")

    def _ensure_model(self, url, path, timeout=60, chunk_size=1 << 20):
        """Download ``url`` to ``path`` unless ``path`` already exists.

        The payload is streamed into ``path + ".part"`` and only moved into
        place with ``os.replace`` once it is complete. An interrupted or failed
        download therefore never leaves a truncated file that a later run would
        mistake for a valid cached model.

        Args:
            url: Source URL of the weight file.
            path: Destination file path.
            timeout: Timeout in seconds passed to ``requests.get``.
            chunk_size: Size in bytes of each streamed chunk.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        if os.path.exists(path):
            logger.info(f"πŸ“ Found cached model: {path}")
            return

        logger.info(f"⬇️ Downloading model from {url}")
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = path + ".part"
        try:
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        if chunk:
                            f.write(chunk)
            # Atomic rename: `path` is either absent or complete.
            os.replace(tmp_path, path)
        except BaseException:
            # Remove any partial download, then propagate the original error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        logger.info(f"βœ… Model saved to {path}")

    def _init_models(self):
        """Build the Real-ESRGAN upsampler and the GFPGAN restorer if not built yet.

        Each model is constructed at most once; later calls only re-log readiness.
        """
        if self.bg_upsampler is None:
            logger.info("🧩 Initializing Real-ESRGAN upsampler...")
            # Network configuration used with realesr-general-x4v3.pth.
            model = SRVGGNetCompact(
                num_in_ch=3, num_out_ch=3, num_feat=64,
                num_conv=32, upscale=4, act_type="prelu",
            )
            self.bg_upsampler = RealESRGANer(
                scale=4,
                model_path=self.realesr_model_path,
                model=model,
                tile=400,  # tiled inference to limit memory on large images
                tile_pad=10,
                pre_pad=0,
                half=self.half,
                device=self.device,
            )
        if self.restorer is None:
            logger.info("🧬 Initializing GFPGAN restorer...")
            self.restorer = GFPGANer(
                model_path=self.gfpgan_model_path,
                upscale=2,
                arch="clean",
                channel_multiplier=2,
                bg_upsampler=self.bg_upsampler,
            )
        logger.info("βœ… Models ready!")

    def _load_image(self, data, timeout=30):
        """Decode a request payload into an RGB ``PIL.Image``.

        Accepted payloads (optionally wrapped as ``{"inputs": ...}``):
          * ``bytes`` / ``bytearray`` holding an encoded image;
          * an ``http://`` or ``https://`` URL, which is downloaded;
          * a base64 string, optionally prefixed by a ``data:<mime>;base64,`` header.

        Args:
            data: The request payload.
            timeout: Timeout in seconds for URL downloads.

        Raises:
            ValueError: Unsupported payload type, or an undecodable base64 image.
            requests.HTTPError: The image URL answered with an error status.
        """
        if isinstance(data, dict) and "inputs" in data:
            data = data["inputs"]

        if isinstance(data, (bytes, bytearray)):
            logger.info("πŸ“¦ Received raw bytes input")
            return Image.open(io.BytesIO(data)).convert("RGB")

        if not isinstance(data, str):
            raise ValueError("Unsupported input type")

        if data.startswith(("http://", "https://")):
            logger.info(f"🌐 Downloading image from URL: {data}")
            # The timeout stops a stalled remote from hanging the worker.
            # raise_for_status reports HTTP errors instead of handing an error
            # page to PIL.
            resp = requests.get(data, timeout=timeout)
            resp.raise_for_status()
            return Image.open(io.BytesIO(resp.content)).convert("RGB")

        logger.info("🧬 Decoding base64 image input")
        if data.startswith("data:") and "," in data:
            # Strip a data-URI header such as "data:image/png;base64,".
            data = data.split(",", 1)[1]
        try:
            decoded = base64.b64decode(data)
            return Image.open(io.BytesIO(decoded)).convert("RGB")
        except Exception as e:
            logger.error(f"❌ Failed to decode base64: {e}")
            raise ValueError("Invalid base64 image input") from e

    def __call__(self, data):
        """Restore the faces in the request image and return it as a base64 JPEG.

        Args:
            data: Request payload; see ``_load_image`` for the accepted forms.

        Returns:
            A dict with ``image`` (base64-encoded JPEG), ``status`` and ``info``.

        Raises:
            ValueError: The payload could not be turned into an image.
            RuntimeError: The restored image could not be JPEG-encoded.
        """
        logger.info("βš™οΈ Starting GFPGAN inference pipeline...")
        self._init_models()

        image = self._load_image(data)
        # PIL yields RGB. GFPGANer and RealESRGANer, like the OpenCV-based
        # pipeline they come from, take and return BGR arrays. Converting once
        # here lets the face detector and both networks see the channel order
        # they were built for.
        input_img = cv2.cvtColor(np.asarray(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
        logger.info(f"πŸ“ Input image shape: {input_img.shape}")

        _, _, restored_img = self.restorer.enhance(
            input_img, has_aligned=False, only_center_face=False, paste_back=True
        )
        logger.info("πŸ–ΌοΈ Restoration complete, preparing output...")

        # The restored image is BGR, which is what cv2.imencode expects, so no
        # further channel swap is needed.
        restored_img = np.clip(restored_img, 0, 255).astype(np.uint8)
        ok, buffer = cv2.imencode(".jpg", restored_img)
        if not ok:
            raise RuntimeError("Failed to encode restored image as JPEG")
        b64_output = base64.b64encode(buffer.tobytes()).decode("utf-8")

        logger.info("βœ… Returning base64-encoded image JSON response")
        return {
            "image": b64_output,
            "status": "success",
            "info": "Restored with GFPGAN v1.4 + Real-ESRGAN x4v3 (RGB fixed)",
        }