|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import os |
|
|
import sys |
|
|
from loadimg import load_img |
|
|
from ben_base import BEN_Base |
|
|
import random |
|
|
import huggingface_hub |
|
|
import numpy as np |
|
|
|
|
|
def set_random_seed(seed):
    """Seed every RNG in use (stdlib, NumPy, Torch CPU/CUDA) and force
    deterministic cuDNN behavior so runs are reproducible."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,      # no-op when CUDA is absent
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Trade cuDNN autotuning speed for bit-for-bit reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
# ---- One-time model setup (runs at import) ----------------------------------
# Fix all RNG seeds for reproducible demo output, and allow TF32 matmuls on
# Ampere+ GPUs for faster inference at slightly reduced precision.
set_random_seed(9)
torch.set_float32_matmul_precision("high")

# Instantiate the BEN2 architecture and fetch the pretrained weights from the
# Hugging Face Hub (cached locally after the first download).
model = BEN_Base()

model_path = huggingface_hub.hf_hub_download(
    repo_id="PramaLLC/BEN2",
    filename="BEN2_Base.pth"
)

# Prefer GPU when available; everything below runs on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model.loadcheckpoints(model_path)
model.to(device)
model.eval()  # inference-only: disables dropout / batch-norm updates

# Directory where single-image results are written for the download widget.
# exist_ok=True avoids the check-then-create race of exists() + makedirs().
output_folder = 'output_images'
os.makedirs(output_folder, exist_ok=True)
|
|
|
|
|
def fn(image):
    """Gradio handler for the image tab: strip the background from *image*.

    Returns the cutout twice — once for the preview widget and once as a
    saved PNG path for the download widget.
    """
    pil_image = load_img(image, output_type="pil").convert("RGB")
    foreground = process(pil_image)
    saved_path = os.path.join(output_folder, "foreground.png")
    foreground.save(saved_path)
    return foreground, saved_path
|
|
|
|
|
|
|
|
def process_video(video_path):
    """Gradio handler for the video tab: run BEN2 segmentation on a video.

    NOTE(review): the returned path assumes segment_video writes its result
    to ./foreground.mp4 in the working directory — confirm against ben_base.
    """
    output_path = "./foreground.mp4"
    # max_frames is effectively "no limit" for any realistic upload.
    model.segment_video(video_path, max_frames=999999)
    return output_path
|
|
|
|
|
|
|
|
def process(image):
    """Run BEN2 foreground extraction on a PIL image.

    Args:
        image: RGB PIL image to segment.

    Returns:
        The model's inference output (the extracted foreground image).
    """
    # Removed leftover debug print(type(...)) that spammed stdout per call.
    return model.inference(image)
|
|
|
|
|
def process_file(f):
    """Strip the background from the image at path *f* and save it next to
    the original as a PNG; returns the new file's path."""
    # Replace the final extension with .png (keeps any earlier dots intact).
    png_path = f.rsplit(".", 1)[0] + ".png"
    rgb = load_img(f, output_type="pil").convert("RGB")
    cutout = process(rgb)
    cutout.save(png_path)
    return png_path
|
|
|
|
|
|
|
|
# ---- Gradio UI wiring -------------------------------------------------------
# Input widgets for the two tabs (Japanese labels: "upload an image/video").
image = gr.Image(label="画像をアップロード")
video = gr.Video(label="動画をアップロード")

# Tab 1: single-image background removal. `fn` returns (preview image, PNG
# path), matched positionally to the two output widgets below.
tab1 = gr.Interface(
    fn,
    inputs=image,
    outputs=[
        gr.Image(label="結果 前景"),        # result: foreground preview
        gr.File(label="PNGをダウンロード")  # download the saved PNG
    ],
    api_name="image"
)

# Tab 2: experimental video background removal.
tab2 = gr.Interface(
    process_video,
    inputs=video,
    outputs=gr.Video(label="結果動画"),  # result video
    api_name="video",
    title="動画処理(実験的)"  # "video processing (experimental)"
)

# Combine both interfaces into one tabbed app ("background removal tool").
demo = gr.TabbedInterface(
    [tab1, tab2],
    ["画像処理", "動画処理"],
    title="背景除去ツール"
)

if __name__ == "__main__":
    # show_error surfaces server-side exceptions in the browser UI.
    demo.launch(show_error=True)