|
|
import os |
|
|
import io |
|
|
import torch |
|
|
import logging |
|
|
import base64 |
|
|
import requests |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
|
|
|
from gfpgan import GFPGANer |
|
|
from realesrgan import RealESRGANer |
|
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact |
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``; every log
# call in the handler below goes through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
class EndpointHandler:
    """Face-restoration endpoint built on GFPGAN with a Real-ESRGAN background upsampler.

    Construction makes sure both checkpoints exist under ``path`` (downloading
    them when missing). The networks themselves are built lazily on the first
    call. Calling the handler with an image (raw bytes, base64 text, a data
    URI, an http(s) URL, or a dict wrapping any of those under ``"inputs"``)
    returns a dict holding the restored image as a base64-encoded JPEG.
    """

    # Seconds to wait on any HTTP request (checkpoint or input image download).
    _HTTP_TIMEOUT = 60
    # Chunk size used when streaming a checkpoint to disk.
    _DOWNLOAD_CHUNK_BYTES = 1 << 20

    def __init__(self, path="."):
        """Record device settings and make sure both checkpoints are on disk.

        Args:
            path: Directory in which the model checkpoints are stored/cached.
        """
        logger.info("[INIT] GFPGAN + Real-ESRGAN handler starting...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is only enabled when running on CUDA.
        self.half = self.device == "cuda"
        self.path = path

        self.gfpgan_model_url = (
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
        )
        self.realesr_model_url = (
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
        )

        self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
        self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")

        # Both networks are created on first use by _init_models().
        self.bg_upsampler = None
        self.restorer = None

        self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
        self._ensure_model(self.realesr_model_url, self.realesr_model_path)

        logger.info("Device: %s, half precision: %s", self.device, self.half)

    def _ensure_model(self, url, path):
        """Download the checkpoint at ``url`` to ``path`` unless it already exists.

        The payload is streamed to ``<path>.part`` and renamed into place only
        after the transfer completes, so an interrupted download can never be
        mistaken for a valid cached checkpoint on the next start.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        if os.path.exists(path):
            logger.info("Found cached model: %s", path)
            return

        logger.info("Downloading model from %s", url)
        target_dir = os.path.dirname(path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        partial_path = f"{path}.part"
        try:
            with requests.get(url, stream=True, timeout=self._HTTP_TIMEOUT) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as out_file:
                    for chunk in response.iter_content(
                        chunk_size=self._DOWNLOAD_CHUNK_BYTES
                    ):
                        if chunk:
                            out_file.write(chunk)
            os.replace(partial_path, path)
        finally:
            # After a successful replace the partial file is gone; on any
            # failure this removes the truncated leftover.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        logger.info("Model saved to %s", path)

    def _init_models(self):
        """Build the Real-ESRGAN upsampler and the GFPGAN restorer if not built yet."""
        if self.bg_upsampler is None:
            logger.info("Initializing Real-ESRGAN upsampler...")
            model = SRVGGNetCompact(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_conv=32,
                upscale=4,
                act_type="prelu",
            )
            self.bg_upsampler = RealESRGANer(
                scale=4,
                model_path=self.realesr_model_path,
                model=model,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                half=self.half,
                device=self.device,
            )

        if self.restorer is None:
            logger.info("Initializing GFPGAN restorer...")
            self.restorer = GFPGANer(
                model_path=self.gfpgan_model_path,
                upscale=2,
                arch="clean",
                channel_multiplier=2,
                bg_upsampler=self.bg_upsampler,
            )
            logger.info("Models ready!")

    @staticmethod
    def _is_url(text):
        """Return True when ``text`` starts with an ``http://`` or ``https://`` scheme.

        Requiring the ``://`` separator keeps base64 payloads that merely begin
        with the letters "http" from being treated as URLs (``:`` is not part
        of the base64 alphabet).
        """
        return text[:8].lower().startswith(("http://", "https://"))

    @staticmethod
    def _strip_data_uri(text):
        """Drop a leading ``data:<mime>;base64,`` header, returning the bare payload."""
        if text.startswith("data:") and "," in text:
            return text.split(",", 1)[1]
        return text

    def _load_image(self, data):
        """Turn the request payload into an RGB PIL image.

        Args:
            data: Raw image bytes, a base64 string (optionally a data URI), an
                http(s) URL, or a dict carrying one of those under ``"inputs"``.

        Returns:
            The decoded image converted to RGB mode.

        Raises:
            ValueError: If the payload type is unsupported or the base64 text
                does not decode to an image.
            requests.HTTPError: If fetching a URL returns an error status.
        """
        if isinstance(data, dict) and "inputs" in data:
            data = data["inputs"]

        if isinstance(data, (bytes, bytearray)):
            logger.info("Received raw bytes input")
            return Image.open(io.BytesIO(data)).convert("RGB")

        if isinstance(data, str):
            text = data.strip()
            if self._is_url(text):
                # NOTE(review): the URL comes from the request; if callers are
                # untrusted, consider restricting reachable hosts (SSRF).
                logger.info("Downloading image from URL: %s", text)
                resp = requests.get(text, timeout=self._HTTP_TIMEOUT)
                resp.raise_for_status()
                return Image.open(io.BytesIO(resp.content)).convert("RGB")

            logger.info("Decoding base64 image input")
            try:
                decoded = base64.b64decode(self._strip_data_uri(text))
                return Image.open(io.BytesIO(decoded)).convert("RGB")
            except Exception as e:
                # Deliberately broad: both the base64 step and the image
                # parser may fail, and callers get one ValueError for either.
                logger.error("Failed to decode base64: %s", e)
                raise ValueError("Invalid base64 image input") from e

        raise ValueError("Unsupported input type")

    def __call__(self, data):
        """Restore the faces in the supplied image.

        Args:
            data: Any payload accepted by ``_load_image``.

        Returns:
            A dict with ``"image"`` (base64-encoded JPEG), ``"status"`` and
            ``"info"`` keys.

        Raises:
            ValueError: If the payload cannot be decoded into an image.
            RuntimeError: If restoration yields no image or JPEG encoding fails.
        """
        logger.info("Starting GFPGAN inference pipeline...")
        self._init_models()

        image = self._load_image(data)
        rgb_input = np.array(image, dtype=np.uint8)
        logger.info("Input image shape: %s", rgb_input.shape)

        # GFPGAN and Real-ESRGAN follow OpenCV's BGR channel convention, while
        # PIL yields RGB. Convert up front so the networks see true colours
        # instead of red/blue-swapped faces.
        bgr_input = cv2.cvtColor(rgb_input, cv2.COLOR_RGB2BGR)

        _cropped_faces, _restored_faces, restored_bgr = self.restorer.enhance(
            bgr_input, has_aligned=False, only_center_face=False, paste_back=True
        )
        if restored_bgr is None:
            raise RuntimeError("GFPGAN returned no restored image")

        logger.info("Restoration complete, preparing output...")

        # The result is already BGR, which is what cv2.imencode expects, so no
        # further channel swap is needed before encoding.
        restored_bgr = np.clip(restored_bgr, 0, 255).astype(np.uint8)
        encoded_ok, buffer = cv2.imencode(".jpg", restored_bgr)
        if not encoded_ok:
            raise RuntimeError("Failed to JPEG-encode the restored image")
        b64_output = base64.b64encode(buffer.tobytes()).decode("utf-8")

        logger.info("Returning base64-encoded image JSON response")

        return {
            "image": b64_output,
            "status": "success",
            "info": "Restored with GFPGAN v1.4 + Real-ESRGAN x4v3 (RGB fixed)",
        }
|
|
|
|
|
|