Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
class AnimeCNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel images.

    ``features`` applies two conv stages, each ending in a 2x2 max-pool.
    ``classifier`` expects the resulting 64 x 16 x 16 feature map, which
    corresponds to 64x64 input images.

    Args:
        num_classes: Number of output logits (default 4).
    """

    def __init__(self, num_classes=4):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # conv -> batch-norm -> ReLU -> 2x2 downsample -> dropout
            return [
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Dropout(0.25),
            ]

        # Keep this layer order: Sequential indices become state_dict keys
        # (e.g. "features.0.weight") that the saved weights are loaded against.
        self.features = nn.Sequential(*conv_stage(3, 32), *conv_stage(32, 64))

        hidden_units = 256
        self.classifier = nn.Sequential(
            nn.Linear(64 * 16 * 16, hidden_units),
            nn.BatchNorm1d(hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-softmax) class logits for a batch of images."""
        feature_map = self.features(x)
        flat = feature_map.view(feature_map.size(0), -1)
        return self.classifier(flat)
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Output labels; index i names the model's i-th output logit (see predict).
classes = ["usada_pekora", "aisaka_taiga", "megumin", "minato_aqua"]

# Size the classifier head from the label list so the two cannot drift apart,
# and place the network on `device` so inference actually runs there.
model = AnimeCNN(num_classes=len(classes)).to(device)
# weights_only=True restricts torch.load to plain tensor/container data.
model.load_state_dict(torch.load('model.pth', map_location=device, weights_only=True))
model.eval()  # disable Dropout; BatchNorm uses its running statistics

# Preprocessing: resize to 64x64 (the classifier head assumes a 64x16x16
# feature map after two 2x2 max-pools), convert to a [0, 1] tensor, then
# normalize each channel with mean 0.5 / std 0.5, i.e. to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


def predict(image):
    """Classify one image and return a probability for every class.

    Args:
        image: A PIL image (the Gradio input uses ``type="pil"``), or None
            when no image was provided.

    Returns:
        Dict mapping each name in ``classes`` to its softmax probability,
        or None if ``image`` is None.
    """
    if image is None:
        # Nothing to classify (e.g. submit pressed without an upload).
        return None
    # convert("RGB") drops an alpha channel / expands grayscale so the tensor
    # always has the 3 channels the first Conv2d expects.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    return {name: float(p) for name, p in zip(classes, probabilities)}
# Gradio UI: one image input mapped to a ranked label output.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="入力画像"),  # "Input image"
    # Show every class; derived from `classes` instead of a hard-coded 4.
    outputs=gr.Label(num_top_classes=len(classes), label="予測結果"),  # "Prediction result"
    title="アニメキャラクター分類器",  # "Anime character classifier"
    # "Classifies images of Usada Pekora, Aisaka Taiga, Megumin and Minato Aqua.
    #  Please upload an image."
    description="うさだぺこら・逢坂大河・めぐみん・湊あくあの画像を分類します。画像をアップロードしてください。",
    # One example per class, at examples/<class_name>.jpg.
    examples=[[f"examples/{name}.jpg"] for name in classes],
)
# Listen on all interfaces, port 7860.
# NOTE(review): share=True requests a public tunnel link when run locally;
# confirm that public exposure is intended.
interface.launch(server_name="0.0.0.0", server_port=7860, share=True)