Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torchvision.transforms as transforms | |
| from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler | |
| from sklearn.model_selection import train_test_split | |
def load_dataset(folder_path, max_images_per_class=60, allowed_classes=None):
    """Load up to ``max_images_per_class`` RGB images from each class folder.

    Every immediate sub-directory of ``folder_path`` is treated as one class.
    Files inside it whose names end in .png/.jpg/.jpeg (case-insensitive) are
    decoded with Pillow, converted to RGB and stored as numpy arrays.

    Args:
        folder_path: Directory whose sub-directories are the classes.
        max_images_per_class: Maximum number of images kept per class.
        allowed_classes: Optional iterable of class names to keep. When given,
            the result follows its order (duplicates collapsed, names without a
            matching folder skipped). When ``None``, every sub-directory is used
            in sorted order.

    Returns:
        dict mapping class name -> list of images as numpy arrays, with keys in
        the class order described above.
    """
    available = {
        name for name in os.listdir(folder_path)
        if os.path.isdir(os.path.join(folder_path, name))
    }

    if allowed_classes is None:
        # os.listdir order is filesystem dependent; sort so that label indices
        # derived from this dict's key order are reproducible.
        class_names = sorted(available)
    else:
        # Iterate the argument exactly once (so generators work) and keep the
        # caller's order; dict.fromkeys drops duplicates.
        class_names = [
            name for name in dict.fromkeys(allowed_classes) if name in available
        ]

    dataset = {}
    for class_name in class_names:
        class_path = os.path.join(folder_path, class_name)
        images = []
        # Sorted so the same subset of files is selected on every run.
        for file_name in sorted(os.listdir(class_path)):
            if len(images) >= max_images_per_class:
                break
            if not file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            img_path = os.path.join(class_path, file_name)
            # The context manager releases the file handle deterministically
            # instead of leaving it to garbage collection.
            with Image.open(img_path) as img:
                images.append(np.array(img.convert('RGB')))
        dataset[class_name] = images
    return dataset
class AnimeDataset(Dataset):
    """Flatten a ``{class_name: [image arrays]}`` mapping into a labelled dataset.

    Label ``i`` corresponds to ``classes[i]``. When ``classes`` is not given (or
    is empty), the mapping's key order is used. Class names missing from the
    mapping contribute no samples but still occupy their label index.

    Args:
        images: Mapping of class name -> sequence of image arrays.
        transform: Optional callable applied to each PIL image on access.
        classes: Optional explicit class order.
    """

    def __init__(self, images, transform=None, classes=None):
        self.transform = transform
        self.classes = classes or list(images.keys())
        # Build (array, label) pairs in class order, then split them into the
        # two parallel lists that the rest of the code indexes into.
        pairs = [
            (array, class_index)
            for class_index, class_name in enumerate(self.classes)
            for array in images.get(class_name, [])
        ]
        self.images = [array for array, _ in pairs]
        self.labels = [class_index for _, class_index in pairs]

    def __len__(self):
        """Return the total number of stored samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``, where image is transformed when a transform is set."""
        sample = Image.fromarray(self.images[idx])
        if self.transform:
            sample = self.transform(sample)
        return sample, self.labels[idx]
class AnimeCNN(nn.Module):
    """Small two-stage convolutional classifier for square RGB images.

    The model is built in two parts:

    * Feature extractor: two stages of Conv2d -> BatchNorm2d -> ReLU ->
      2x2 MaxPool -> Dropout. These map 3 input channels to 64 feature maps,
      each spatial side reduced to ``img_size // 4``.
    * Classifier head: Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear,
      producing ``num_classes`` raw logits.

    Args:
        num_classes: Number of output logits.
        img_size: Side length of the square input images. The head's input
            width depends on it. The default of 64 reproduces the original
            ``64 * 16 * 16`` layer, so existing checkpoints still load.

    Raises:
        ValueError: If ``img_size`` is too small to survive both poolings.
    """

    def __init__(self, num_classes=4, img_size=64):
        super().__init__()
        # Each MaxPool2d(2, 2) floors the side length by half; two of them
        # give img_size // 4. Below 4 the feature map would be empty.
        pooled_side = img_size // 4
        if pooled_side < 1:
            raise ValueError(f"img_size must be at least 4, got {img_size}")

        # Layer order/indices are unchanged so state_dict keys stay compatible.
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * pooled_side * pooled_side, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map a batch shaped (N, 3, img_size, img_size) to (N, num_classes) logits."""
        x = self.features(x)
        # flatten also works on non-contiguous tensors, unlike view.
        x = torch.flatten(x, 1)
        return self.classifier(x)
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    total_loss = 0.0
    seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # Assumes criterion uses reduction='mean' (nn.CrossEntropyLoss's
        # default, as used in main). Re-weight by batch size so a short final
        # batch does not skew the epoch average.
        batch = labels.size(0)
        total_loss += loss.item() * batch
        seen += batch
    return total_loss / seen if seen else 0.0


def _evaluate(model, loader, criterion, device):
    """Return ``(per-sample mean loss, accuracy in percent)`` over ``loader``."""
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)
            batch = labels.size(0)
            total_loss += criterion(outputs, labels).item() * batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch
    if seen == 0:
        return 0.0, 0.0
    return total_loss / seen, 100.0 * correct / seen


def main():
    """Train AnimeCNN on the ``dataset`` folder and save its weights to model.pth.

    Steps:

    1. Load up to 60 images per class (load_dataset's default) for each name
       in CLASSES.
    2. Make a stratified 80/20 train/validation split.
    3. Train with Adam for NUM_EPOCHS, printing loss and accuracy per epoch.
    4. Write the model's state_dict (on CPU) to ``model.pth``.

    Raises:
        ValueError: If no images were found for the requested classes.
    """
    SEED = 42
    CLASSES = ["usada_pekora", "aisaka_taiga", "megumin", "minato_aqua"]
    IMG_SIZE = 64  # AnimeCNN's classifier head is sized for 64x64 inputs by default.
    BATCH_SIZE = 16
    VAL_BATCH_SIZE = 40
    NUM_EPOCHS = 15

    torch.manual_seed(SEED)
    np.random.seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only helps host->GPU copies; on CPU-only machines
    # it merely triggers a warning.
    pin_memory = device.type == "cuda"

    dataset = load_dataset("dataset", allowed_classes=CLASSES)
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    anime_dataset = AnimeDataset(dataset, transform=transform, classes=CLASSES)
    if len(anime_dataset) == 0:
        raise ValueError(f"No images found under 'dataset' for classes {CLASSES}")

    indices = list(range(len(anime_dataset)))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=0.2,
        random_state=SEED,
        stratify=anime_dataset.labels,
    )

    train_loader = DataLoader(
        anime_dataset,
        batch_size=BATCH_SIZE,
        sampler=SubsetRandomSampler(train_indices),
        pin_memory=pin_memory,
        # BatchNorm1d in training mode raises on a batch of one sample, so drop
        # the final batch only when it would be exactly that size.
        drop_last=len(train_indices) % BATCH_SIZE == 1,
    )
    val_loader = DataLoader(
        anime_dataset,
        batch_size=VAL_BATCH_SIZE,
        sampler=SubsetRandomSampler(val_indices),
        pin_memory=pin_memory,
    )

    model = AnimeCNN(num_classes=len(CLASSES)).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=0.001,
        weight_decay=1e-4,
    )
    criterion = nn.CrossEntropyLoss()

    for epoch in range(NUM_EPOCHS):
        train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch+1:02d} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Accuracy: {val_acc:.2f}%")

    # Move to CPU so the checkpoint loads without map_location on any machine,
    # and save before announcing it so the message is never printed falsely.
    torch.save(model.cpu().state_dict(), "model.pth")
    print("Model saved as model.pth")


if __name__ == "__main__":
    main()