# NOTE(review): the three lines that used to sit here ("File size: ...", a row of
# git blame hashes, and a column of line numbers) were web-scraper residue from a
# repository file listing, not part of the program; reconstructed as this comment.
# import spaces # ZeroGPU用 - 通常GPU使用時はコメントアウト
import gradio as gr
import torch
import os
import sys
from loadimg import load_img
from ben_base import BEN_Base
import random
import huggingface_hub
import numpy as np
def set_random_seed(seed):
    """Seed every RNG in play (random, NumPy, torch CPU and CUDA) for reproducibility."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Pin cuDNN to deterministic kernels and disable autotuning so repeated
    # runs with the same seed produce identical outputs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Fix all RNG seeds once at import time so inference output is repeatable.
set_random_seed(9)
# "high" permits TF32 matmuls on supporting GPUs — faster float32 inference
# at slightly reduced precision.
torch.set_float32_matmul_precision("high")
model = BEN_Base()
# Download the model file from Hugging Face Hub
model_path = huggingface_hub.hf_hub_download(
    repo_id="PramaLLC/BEN2",
    filename="BEN2_Base.pth"
)
# Check if CUDA is available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load the downloaded checkpoint and put the model in inference mode.
model.loadcheckpoints(model_path)
model.to(device)
model.eval()
# Directory where processed foreground images are written.
output_folder = 'output_images'
# exist_ok=True replaces the racy exists()-then-makedirs check: creation is
# idempotent even if another process creates the directory first.
os.makedirs(output_folder, exist_ok=True)
def fn(image):
    """Gradio image handler: return the extracted foreground and a saved PNG path."""
    pil_image = load_img(image, output_type="pil").convert("RGB")
    foreground = process(pil_image)
    saved_path = os.path.join(output_folder, "foreground.png")
    foreground.save(saved_path)
    return foreground, saved_path
# @spaces.GPU  # for ZeroGPU — commented out when running on a regular GPU
def process_video(video_path):
    """Remove the background from a video and return the output file path."""
    result_path = "./foreground.mp4"
    # max_frames is set huge to effectively lift the frame-count limit.
    model.segment_video(video_path, max_frames=999999)
    return result_path
# @spaces.GPU  # for ZeroGPU — commented out when running on a regular GPU
def process(image):
    """Run BEN2 inference on a PIL image and return the foreground cutout.

    The leftover debug print of the result's type was removed.
    """
    return model.inference(image)
def process_file(f):
    """Strip the background from image file *f*, save the result as a PNG, and return its path."""
    png_path = f.rsplit(".", 1)[0] + ".png"
    rgb_image = load_img(f, output_type="pil").convert("RGB")
    transparent = process(rgb_image)
    transparent.save(png_path)
    return png_path
# Interface components
# NOTE: all labels/titles are user-facing Japanese strings (e.g. "upload image",
# "upload video", "background removal tool") and are deliberately left as-is.
image = gr.Image(label="画像をアップロード")
video = gr.Video(label="動画をアップロード")
# Image processing tab: returns both a preview image and a downloadable PNG.
tab1 = gr.Interface(
    fn,
    inputs=image,
    outputs=[
        gr.Image(label="結果 前景"),
        gr.File(label="PNGをダウンロード")
    ],
    api_name="image"
)
# Video processing tab (labelled experimental in the UI title).
tab2 = gr.Interface(
    process_video,
    inputs=video,
    outputs=gr.Video(label="結果動画"),
    api_name="video",
    title="動画処理(実験的)"
)
# Combined interface: the two tools side by side in one tabbed app.
demo = gr.TabbedInterface(
    [tab1, tab2],
    ["画像処理", "動画処理"],
    title="背景除去ツール"
)
# Launch only when run as a script; show_error surfaces tracebacks in the UI.
if __name__ == "__main__":
    demo.launch(show_error=True)