| """A generic training wrapper.""" | |
| from copy import deepcopy | |
| import logging | |
| from typing import Callable, List, Optional | |
| import torch | |
| from torch.utils.data import DataLoader | |
| LOGGER = logging.getLogger(__name__) | |


class Trainer:
    def __init__(
        self,
        epochs: int = 20,
        batch_size: int = 32,
        device: str = "cpu",
        optimizer_fn: Callable = torch.optim.Adam,
        optimizer_kwargs: dict = {"lr": 1e-3},
        use_scheduler: bool = False,
    ) -> None:
        self.epochs = epochs
        self.batch_size = batch_size
        self.device = device
        self.optimizer_fn = optimizer_fn
        self.optimizer_kwargs = optimizer_kwargs
        self.epoch_test_losses: List[float] = []
        self.use_scheduler = use_scheduler


def forward_and_loss(model, criterion, batch_x, batch_y, **kwargs):
    batch_out = model(batch_x)
    batch_loss = criterion(batch_out, batch_y)
    return batch_out, batch_loss
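

# Illustration only (not part of the original module): GDTrainer.train() below
# binds `forward_and_loss_fn = forward_and_loss`, so any callable with the same
# (model, criterion, batch_x, batch_y, **kwargs) signature could be substituted
# by editing the trainer or a subclass. The helper below is a hypothetical
# sketch that applies label smoothing before computing the loss; it assumes
# binary 0/1 targets and a BCEWithLogitsLoss-style criterion, as used below.
def forward_and_loss_label_smoothing(
    model, criterion, batch_x, batch_y, smoothing: float = 0.05, **kwargs
):
    batch_out = model(batch_x)
    # Pull hard 0/1 targets slightly towards 0.5 before the BCE-with-logits loss.
    smoothed_y = batch_y * (1.0 - smoothing) + 0.5 * smoothing
    batch_loss = criterion(batch_out, smoothed_y)
    return batch_out, batch_loss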


class GDTrainer(Trainer):
    def train(
        self,
        dataset: torch.utils.data.Dataset,
        model: torch.nn.Module,
        test_len: Optional[float] = None,
        test_dataset: Optional[torch.utils.data.Dataset] = None,
    ):
        if test_dataset is not None:
            # An explicit test set takes precedence over splitting the dataset.
            train = dataset
            test = test_dataset
        else:
            # Otherwise split off a test subset; test_len is the fraction of
            # the dataset to hold out and must be provided in this case.
            test_len = int(len(dataset) * test_len)
            train_len = len(dataset) - test_len
            lengths = [train_len, test_len]
            train, test = torch.utils.data.random_split(dataset, lengths)
        train_loader = DataLoader(
            train,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=6,
        )
        test_loader = DataLoader(
            test,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=6,
        )
        criterion = torch.nn.BCEWithLogitsLoss()
        optim = self.optimizer_fn(model.parameters(), **self.optimizer_kwargs)

        best_model = None
        best_acc = 0

        LOGGER.info(f"Starting training for {self.epochs} epochs!")

        forward_and_loss_fn = forward_and_loss

        if self.use_scheduler:
            # The scheduler is stepped once per batch (see the training loop),
            # so a restart period of 2 * len(train_loader) means the learning
            # rate restarts every 2nd epoch.
            batches_per_epoch = len(train_loader) * 2  # every 2nd epoch
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer=optim,
                T_0=batches_per_epoch,
                T_mult=1,
                eta_min=5e-6,
                # verbose=True,
            )
        use_cuda = self.device != "cpu"

        for epoch in range(self.epochs):
            LOGGER.info(f"Epoch num: {epoch}")

            running_loss = 0
            num_correct = 0.0
            num_total = 0.0
            model.train()

            # Batches are expected to be (input, _, label) triples; the middle
            # element is unused here.
            for i, (batch_x, _, batch_y) in enumerate(train_loader):
                batch_size = batch_x.size(0)
                num_total += batch_size
                batch_x = batch_x.to(self.device)

                batch_y = batch_y.unsqueeze(1).type(torch.float32).to(self.device)

                batch_out, batch_loss = forward_and_loss_fn(
                    model, criterion, batch_x, batch_y, use_cuda=use_cuda
                )
                # Adding 0.5 and truncating to int rounds the sigmoid output
                # to a hard 0/1 prediction.
                batch_pred = (torch.sigmoid(batch_out) + 0.5).int()
                num_correct += (batch_pred == batch_y.int()).sum(dim=0).item()

                running_loss += batch_loss.item() * batch_size

                if i % 100 == 0:
                    LOGGER.info(
                        f"[{epoch:04d}][{i:05d}]: {running_loss / num_total} {num_correct / num_total * 100}"
                    )

                optim.zero_grad()
                batch_loss.backward()
                optim.step()
                if self.use_scheduler:
                    scheduler.step()

            running_loss /= num_total
            train_accuracy = (num_correct / num_total) * 100
            LOGGER.info(
                f"Epoch [{epoch+1}/{self.epochs}]: train/loss: {running_loss}, train/accuracy: {train_accuracy}"
            )

            test_running_loss = 0.0
            num_correct = 0.0
            num_total = 0.0
            model.eval()
            eer_val = 0  # placeholder only; EER is never computed in this loop

            for batch_x, _, batch_y in test_loader:
                batch_size = batch_x.size(0)
                num_total += batch_size
                batch_x = batch_x.to(self.device)

                with torch.no_grad():
                    batch_pred = model(batch_x)

                batch_y = batch_y.unsqueeze(1).type(torch.float32).to(self.device)
                batch_loss = criterion(batch_pred, batch_y)

                test_running_loss += batch_loss.item() * batch_size

                batch_pred = torch.sigmoid(batch_pred)
                batch_pred_label = (batch_pred + 0.5).int()
                num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item()

            if num_total == 0:
                num_total = 1

            test_running_loss /= num_total
            test_acc = 100 * (num_correct / num_total)
            LOGGER.info(
                f"Epoch [{epoch+1}/{self.epochs}]: test/loss: {test_running_loss}, test/accuracy: {test_acc}, test/eer: {eer_val}"
            )

            # Keep a copy of the weights from the best-scoring epoch so far.
            if best_model is None or test_acc > best_acc:
                best_acc = test_acc
                best_model = deepcopy(model.state_dict())

            LOGGER.info(
                f"[{epoch:04d}]: {running_loss} - train acc: {train_accuracy} - test_acc: {test_acc}"
            )

        # Restore the best-performing weights before returning the model.
        model.load_state_dict(best_model)
        return model
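

# Usage sketch (assumptions, not part of the original module): the training
# loop unpacks batches as (batch_x, _, batch_y), so the dataset is assumed to
# yield 3-tuples whose last element is a binary label, and the model is assumed
# to output one logit per sample for BCEWithLogitsLoss. The trainer only moves
# batches to `device`, so the caller is assumed to place the model there first.
# Class names below are hypothetical.
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model = MyBinaryClassifier().to(device)   # hypothetical model
#     dataset = MyTripletDataset()              # yields (input, meta, label)
#     trainer = GDTrainer(
#         epochs=10,
#         batch_size=16,
#         device=device,
#         optimizer_kwargs={"lr": 1e-4},
#         use_scheduler=True,
#     )
#     model = trainer.train(dataset, model, test_len=0.2)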