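"""ArcaneGAN demo app.

Stylizes portrait photos in the look of Netflix's Arcane series. Uses MTCNN to
find a face and pick a working resolution, then runs one of three TorchScript
ArcaneGAN checkpoints: half precision on GPU (Hugging Face Zero-GPU) or full
precision (FP32) on CPU.
"""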
import os
from huggingface_hub import hf_hub_download
import spaces
from facenet_pytorch import MTCNN
from torchvision import transforms
import torch
import PIL
from PIL import Image
import gradio as gr
# Download models
modelarcanev4 = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.4", filename="ArcaneGANv0.4.jit")
modelarcanev3 = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.3", filename="ArcaneGANv0.3.jit")
modelarcanev2 = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.2", filename="ArcaneGANv0.2.jit")
# Check if GPU is available
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
USE_GPU = DEVICE == 'cuda'
print(f"Running on: {DEVICE.upper()}")
mtcnn = MTCNN(image_size=256, margin=80, device=DEVICE)

# Face detection
def detect(img):
    batch_boxes, batch_probs, batch_points = mtcnn.detect(img, landmarks=True)
    if not mtcnn.keep_all:
        batch_boxes, batch_probs, batch_points = mtcnn.select_boxes(
            batch_boxes, batch_probs, batch_points, img, method=mtcnn.selection_method
        )
    return batch_boxes, batch_points
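
# Helper to keep resized dimensions even (odd sizes can trip up the half-precision JIT models)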
def makeEven(_x):
    return _x if (_x % 2 == 0) else _x + 1
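
# Choose a resize ratio so the first detected face is about target_face pixels
# across (capped at max_upscale, or overridden by fixed_ratio), then downscale
# if needed to keep the total pixel count under max_res.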
def scale(boxes, _img, max_res=1_500_000, target_face=256, fixed_ratio=0, max_upscale=2, VERBOSE=False):
    x, y = _img.size
    ratio = 2
    if boxes is not None:
        if len(boxes) > 0:
            ratio = target_face / max(boxes[0][2:] - boxes[0][:2])
            ratio = min(ratio, max_upscale)
    if fixed_ratio > 0:
        ratio = fixed_ratio
    x *= ratio
    y *= ratio
    res = x * y
    if res > max_res:
        ratio = pow(res / max_res, 1 / 2)
        x = int(x / ratio)
        y = int(y / ratio)
    x = makeEven(int(x))
    y = makeEven(int(y))
    size = (x, y)
    return _img.resize(size)
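
# Detect faces, then resize the image based on the first detected face.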
def scale_by_face_size(_img, max_res=1_500_000, target_face=256, fix_ratio=0, max_upscale=2, VERBOSE=False):
    boxes, _ = detect(_img)
    return scale(boxes, _img, max_res, target_face, fix_ratio, max_upscale, VERBOSE)

# Image processing setup
size = 256
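# ImageNet normalization statistics (mean/std per RGB channel)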
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]
img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])

def proc_pil_img_cpu(input_image, model_path):
    """High-quality CPU-based image processing"""
    # Load model on CPU with full precision
    model = torch.jit.load(model_path, map_location='cpu').eval()
    # Create tensors on CPU in full precision
    t_stds = torch.tensor(stds).view(3, 1, 1)
    t_means = torch.tensor(means).view(3, 1, 1)
    # Transform image (full precision on CPU)
    transformed_image = img_transforms(input_image).unsqueeze(0)
    with torch.no_grad():
        result_image = model(transformed_image)[0]
    # Post-process with full precision
    output_image = result_image.mul(t_stds).add(t_means).mul(255.).clamp(0, 255).permute(1, 2, 0)
    output_image = output_image.numpy().astype('uint8')
    output_image = PIL.Image.fromarray(output_image)
    # Clean up
    del model
    return output_image

@spaces.GPU
def proc_pil_img_gpu(input_image, model_path):
    """GPU-accelerated image processing with half precision support"""
    # Load model on GPU
    model = torch.jit.load(model_path, map_location='cuda').eval()
    # Create tensors on GPU in half precision to match model
    t_stds = torch.tensor(stds).cuda().half().view(3, 1, 1)
    t_means = torch.tensor(means).cuda().half().view(3, 1, 1)
    # Transform image and move to GPU with half precision
    transformed_image = img_transforms(input_image).unsqueeze(0).cuda().half()
    with torch.no_grad():
        result_image = model(transformed_image)[0]
    # Convert back to float for post-processing
    output_image = result_image.float().mul(t_stds.float()).add(t_means.float()).mul(255.).clamp(0, 255).permute(1, 2, 0)
    output_image = output_image.cpu().numpy().astype('uint8')
    output_image = PIL.Image.fromarray(output_image)
    # Clean up
    del model
    torch.cuda.empty_cache()
    return output_image

def process(im, version):
    """Main processing function with automatic GPU/CPU selection"""
    if im is None:
        return None
    try:
        # Ensure image is a PIL Image
        if not isinstance(im, Image.Image):
            im = Image.fromarray(im)
        # Convert to RGB if needed
        if im.mode != 'RGB':
            im = im.convert('RGB')
        # Scale image (CPU operation)
        im = scale_by_face_size(im, target_face=256, max_res=1_500_000, max_upscale=1)
        # Select model based on version
        if version == 'v0.4 (Recommended)':
            model_path = modelarcanev4
        elif version == 'v0.3':
            model_path = modelarcanev3
        else:
            model_path = modelarcanev2
        # Use GPU or CPU based on availability
        if USE_GPU:
            res = proc_pil_img_gpu(im, model_path)
        else:
            res = proc_pil_img_cpu(im, model_path)
        return res
    except Exception as e:
        print(f"Error processing image: {e}")
        return None

# Interface text: status line and compute description
device_info = "⚡ Zero-GPU" if USE_GPU else "🖥️ CPU (High Quality)"
compute_details = (
    "Uses Hugging Face's Zero-GPU infrastructure for efficient GPU allocation."
    if USE_GPU else
    "Running full-precision (FP32) inference on CPU for maximum quality."
)
title = "🎨 ArcaneGAN - Transform Your Photos into Arcane-Style Art"
description = f"""
Transform your portrait photos into the stunning visual style of Netflix's Arcane series!
**Status**: {device_info}
{compute_details}
**Tips for Best Results:**
- Use clear, well-lit portrait photos
- Face should be clearly visible and not too small
- Works best with frontal or slightly angled faces
- Try different model versions for varied artistic styles
"""
article = """
---
**ArcaneGAN** by [Alexander S](https://twitter.com/devdef) |
[GitHub Repository](https://github.com/Sxela/ArcaneGAN) |
[Original Space](https://huggingface.co/spaces/akhaliq/ArcaneGAN)
**Model Versions:**
- **v0.4**: Latest and recommended - best quality and style accuracy
- **v0.3**: Alternative style interpretation
- **v0.2**: Original version with unique characteristics
Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
"""

# Create the interface with the Gradio Interface API
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(
            choices=['v0.4 (Recommended)', 'v0.3', 'v0.2'],
            value='v0.4 (Recommended)',
            label="Model Version"
        ),
    ],
    outputs=gr.Image(type="pil", label="Arcane-Style Result"),
    title=title,
    description=description,
    article=article,
)
# Launch
demo.launch()