File size: 2,750 Bytes
70b5538
f782800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70b5538
f782800
 
 
 
 
 
dbc9d17
f782800
 
70b5538
f782800
 
 
 
 
 
 
 
 
 
 
 
 
 
e2ffc2b
 
f782800
 
 
 
 
 
e2ffc2b
 
f782800
 
 
 
 
 
 
 
e2ffc2b
f782800
dbc9d17
f782800
 
 
 
 
e2ffc2b
dbc9d17
f782800
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# import spaces  # ZeroGPU用 - 通常GPU使用時はコメントアウト
import gradio as gr
import torch
import os
import sys
from loadimg import load_img
from ben_base import BEN_Base
import random
import huggingface_hub
import numpy as np

def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed handed to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed every RNG up front so repeated runs produce reproducible results.
set_random_seed(9)
# "high" lets PyTorch use faster reduced-precision kernels (e.g. TF32) for
# float32 matmuls where the hardware supports them.
torch.set_float32_matmul_precision("high")

model = BEN_Base()
# Download the BEN2 checkpoint from the Hugging Face Hub.
model_path = huggingface_hub.hf_hub_download(
    repo_id="PramaLLC/BEN2",
    filename="BEN2_Base.pth",
)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the weights, move the model to the chosen device and switch it to
# evaluation mode.
model.loadcheckpoints(model_path)
model.to(device)
model.eval()

# Directory where fn() writes the downloadable PNG. exist_ok=True avoids the
# check-then-create race of testing os.path.exists before os.makedirs.
output_folder = 'output_images'
os.makedirs(output_folder, exist_ok=True)

def fn(image):
    """Gradio handler for the image tab.

    Converts the upload to an RGB PIL image and removes its background. The
    result is saved as ``foreground.png`` inside ``output_folder`` so it can
    be offered as a download.

    Args:
        image: Value from the ``gr.Image`` input; anything ``load_img`` accepts.

    Returns:
        Tuple ``(result_image, png_path)``, matching the tab's two outputs.
    """
    rgb_image = load_img(image, output_type="pil").convert("RGB")
    foreground = process(rgb_image)
    # NOTE(review): fixed filename, so concurrent requests would overwrite
    # each other's download if the app ever runs handlers in parallel.
    png_path = os.path.join(output_folder, "foreground.png")
    foreground.save(png_path)
    return foreground, png_path

# @spaces.GPU  # For ZeroGPU only - keep commented out when using a regular GPU
def process_video(video_path):
    """Gradio handler for the video tab: segment the foreground of a video.

    Args:
        video_path: Path of the uploaded video, as supplied by ``gr.Video``.

    Returns:
        The path ``"./foreground.mp4"``. The value returned by
        ``model.segment_video`` is ignored. NOTE(review): this assumes
        segment_video writes its result to that exact file; confirm against
        BEN_Base.segment_video.
    """
    # A very large max_frames effectively removes the frame limit.
    model.segment_video(video_path, max_frames=999999)
    return "./foreground.mp4"

# @spaces.GPU  # For ZeroGPU only - keep commented out when using a regular GPU
def process(image):
    """Run background removal on a single image.

    Args:
        image: PIL image; the callers in this module convert it to RGB first.

    Returns:
        The object returned by ``model.inference``. Callers save it with
        ``.save``, so it is used as a PIL image. It is presumably RGBA with a
        transparent background; TODO confirm against BEN_Base.inference.
    """
    # The former debug print(type(foreground)) was removed: it wrote to
    # stdout on every request.
    return model.inference(image)

def process_file(f):
    """Remove the background from the image file at path *f*.

    The result is saved next to the input, with the file extension replaced
    by ``.png``. An input that is already named ``*.png`` is therefore
    overwritten in place.

    Args:
        f: Filesystem path of an image readable by ``load_img``.

    Returns:
        The path of the written PNG.
    """
    # os.path.splitext only strips an extension from the final path
    # component. rsplit(".", 1) would also cut at a dot in a directory name
    # (e.g. "/tmp/v1.2/photo" -> "/tmp/v1.png") or in a dotfile name.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path

# Interface components.
# Input widgets shared with the Interfaces below. Their labels are Japanese:
# "Upload image" and "Upload video".
image = gr.Image(label="画像をアップロード")
video = gr.Video(label="動画をアップロード")

# Image processing tab. fn returns (foreground image, path of the saved PNG),
# which maps onto the two outputs: "Result foreground" and "Download PNG".
# api_name names this endpoint for API clients.
tab1 = gr.Interface(
    fn,
    inputs=image,
    outputs=[
        gr.Image(label="結果 前景"),
        gr.File(label="PNGをダウンロード")
    ],
    api_name="image"
)

# Video processing tab, titled "Video processing (experimental)".
# process_video returns a video file path, shown in the "Result video" output.
tab2 = gr.Interface(
    process_video,
    inputs=video,
    outputs=gr.Video(label="結果動画"),
    api_name="video",
    title="動画処理(実験的)"
)

# Combined interface: both tabs in one app. The tab names are
# "Image processing" / "Video processing"; the title is "Background removal tool".
demo = gr.TabbedInterface(
    [tab1, tab2],
    ["画像処理", "動画処理"],
    title="背景除去ツール"
)

# Start the server only when run as a script. show_error=True surfaces
# handler exceptions in the browser UI.
if __name__ == "__main__":
    demo.launch(show_error=True)