|
|
import os |
|
|
import torch |
|
|
import numpy |
|
|
from datetime import datetime |
|
|
from PIL import ImageFile |
|
|
|
|
|
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
|
|
|
|
|
|
import util as toolkit |
|
|
from loader import get_loader as fetch_train_data, get_val_loader as fetch_val_data |
|
|
from config import ConfigurationManager as Configurator |
|
|
from model import model as NeuralNetwork |
|
|
from util import bceLoss as compute_binary_loss |
|
|
|
|
|
|
|
|
def prepare_validation_config():
    """Build a configuration object tuned for validation runs.

    Parses a fresh configuration and flips the train/val flags so that
    downstream data loaders operate in validation mode.

    Returns:
        The parsed configuration with ``isTrain=False`` and ``isVal=True``.
    """
    cfg = Configurator().parse()
    cfg.isTrain, cfg.isVal = False, True
    return cfg
|
|
|
|
|
|
|
|
def execute_training_iteration(
    data_provider,
    network,
    optimizer,
    epoch_index,
    storage_location
):
    """Run one training epoch over `data_provider`.

    Args:
        data_provider: iterable of (inputs, targets) batches.
        network: the model being trained; switched to train mode here.
        optimizer: optimizer whose step is taken per batch.
        epoch_index: 1-based epoch number (used for logging/checkpoints).
        storage_location: directory where checkpoints are written.

    Side effects:
        Increments the module-level global ``iteration_counter``; reads the
        module-level globals ``config`` and ``total_batches`` set by
        ``main_execution``. Saves an epoch checkpoint every 50 epochs and an
        interrupt checkpoint on Ctrl-C.
    """
    network.train()
    global iteration_counter

    epoch_iterations = 0
    total_loss = 0.0

    # Hoisted out of the batch loop: the loss callable is stateless, so
    # re-constructing it per batch (as the original did) was wasted work.
    loss_function = compute_binary_loss()

    try:
        for batch_idx, (inputs, targets) in enumerate(data_provider, start=1):
            optimizer.zero_grad()

            # Assumes a single-GPU setup selected via CUDA_VISIBLE_DEVICES.
            inputs = inputs.cuda()
            targets = targets.cuda()

            # Flatten logits to match the 1-D binary targets.
            outputs = network(inputs).ravel()
            batch_loss = loss_function(outputs, targets)

            batch_loss.backward()
            optimizer.step()

            iteration_counter += 1
            epoch_iterations += 1
            total_loss += batch_loss.item()

            # Log on the first batch, every 500th batch, and the last batch.
            if batch_idx % 500 == 0 or batch_idx == total_batches or batch_idx == 1:
                # NOTE: the original also built an unused timestamp string
                # here; it has been removed.
                progress_percent = (batch_idx / total_batches) * 100
                status_report = (
                    f"π Epoch: {epoch_index:02d}/{config.epoch:02d} | "
                    f"π’ Iteration: {batch_idx:04d}/{total_batches:04d} "
                    f"({progress_percent:.1f}%) | "
                    f"π Loss Metric: {batch_loss.item():.6f}"
                )
                print(status_report)

        # Periodic end-of-epoch checkpoint. This is epoch-keyed, so it runs
        # once per epoch after the batch loop rather than once per batch.
        if epoch_index % 50 == 0:
            checkpoint_path = os.path.join(
                storage_location,
                f'Network_epoch_{epoch_index}.pth'
            )
            torch.save(network.state_dict(), checkpoint_path)

    except KeyboardInterrupt:
        # BUG FIX: the original handler only printed this message but neither
        # saved the model nor exited. Save an interrupt checkpoint and
        # re-raise so the interrupt actually terminates training.
        print("Training interrupted: saving model and exiting")
        interrupt_path = os.path.join(
            storage_location,
            f'Network_interrupt_epoch_{epoch_index}.pth'
        )
        torch.save(network.state_dict(), interrupt_path)
        raise
|
|
|
|
|
|
def _count_threshold_correct(loader, network):
    """Count samples whose 0.5-thresholded sigmoid output matches the 0/1 target.

    Shared by the AI and nature validation passes; the original duplicated
    this loop verbatim for each loader.
    """
    correct = 0
    for inputs, targets in loader:
        inputs = inputs.cuda()
        targets = targets.cuda()

        probabilities = torch.sigmoid(network(inputs)).ravel()

        # A prediction is correct when it lands strictly on the target's side
        # of 0.5; probabilities exactly equal to 0.5 count as wrong.
        hits = (
            ((probabilities > 0.5) & (targets == 1)) |
            ((probabilities < 0.5) & (targets == 0))
        )
        correct += hits.sum().item()
    return correct


def perform_validation(
    validation_sets,
    network,
    epoch_index,
    storage_location
):
    """Evaluate the model on every validation set and track the best epoch.

    Args:
        validation_sets: iterable of dicts with keys 'name', 'val_ai_loader',
            'ai_size', 'val_nature_loader', 'nature_size'.
        network: the model being evaluated; switched to eval mode here.
        epoch_index: 1-based epoch number.
        storage_location: directory where the best model is saved.

    Side effects:
        Updates the module-level globals ``best_performing_epoch`` and
        ``highest_accuracy``; writes 'Network_best.pth' whenever overall
        accuracy improves (and unconditionally on epoch 1).
    """
    network.eval()
    global best_performing_epoch, highest_accuracy

    total_correct = total_samples = 0

    with torch.no_grad():
        for dataset in validation_sets:
            ai_loader = dataset['val_ai_loader']
            ai_count = dataset['ai_size']
            nature_loader = dataset['val_nature_loader']
            nature_count = dataset['nature_size']

            print(f"||Validating||")

            # One pass per class-specific loader via the shared helper.
            correct_ai = _count_threshold_correct(ai_loader, network)
            correct_nature = _count_threshold_correct(nature_loader, network)

            # NOTE: the original also computed per-class accuracies
            # (correct/size) but never used them; they have been dropped.
            dataset_accuracy = (correct_ai + correct_nature) / (ai_count + nature_count)
            total_correct += correct_ai + correct_nature
            total_samples += ai_count + nature_count

            print(f"Epoch: {epoch_index}, Accuracy: {dataset_accuracy:.4f}")

    overall_accuracy = total_correct / total_samples

    # Epoch 1 always establishes the baseline; afterwards only a strict
    # improvement replaces the best model (same behavior as the original's
    # duplicated if/else branches, collapsed into one condition).
    if epoch_index == 1 or overall_accuracy > highest_accuracy:
        best_performing_epoch = epoch_index
        highest_accuracy = overall_accuracy
        best_model_path = os.path.join(storage_location, 'Network_best.pth')
        torch.save(network.state_dict(), best_model_path)
        print(f"Saved best model on Epoch: {epoch_index}")

    print(
        f"π Performance Report | "
        f"Current Epoch: {epoch_index:03d} | "
        f"Accuracy Score: {overall_accuracy:.2%} | "
        f"Peak Performance: Epoch {best_performing_epoch:03d} | "
        f"Highest Accuracy: {highest_accuracy:.2%}"
    )
|
|
|
|
|
|
|
|
def configure_gpu(gpu_id):
    """Restrict CUDA to the device(s) named by *gpu_id*.

    Exports CUDA_VISIBLE_DEVICES so any later CUDA initialization only
    sees the requested device id(s).
    """
    os.environ.update(CUDA_VISIBLE_DEVICES=gpu_id)
|
|
|
|
|
|
|
|
def main_execution():
    """Entry point: build configs, data, and model, then run train/validate.

    Side effects:
        Defines the module-level globals ``config``, ``total_batches``,
        ``iteration_counter``, ``best_performing_epoch`` and
        ``highest_accuracy`` that ``execute_training_iteration`` and
        ``perform_validation`` read and update.
    """
    torch.set_num_threads(2)
    toolkit.set_random_seed()

    # One config for training, plus a flag-flipped copy for validation.
    global config
    config = Configurator().parse()
    val_config = prepare_validation_config()

    # total_batches is consumed by the per-batch progress logging.
    global total_batches
    train_loader = fetch_train_data(config)
    total_batches = len(train_loader)
    val_loader = fetch_val_data(val_config)

    # Must run before the model is moved to CUDA.
    configure_gpu(config.gpu_id)

    model = NeuralNetwork().cuda()
    if config.load:
        model.load_state_dict(torch.load(config.load))
        print(f"Loaded model from {config.load}")

    optimizer = torch.optim.Adam(model.parameters(), config.lr)

    # FIX: idempotent directory creation; the original exists()/makedirs()
    # pair is race-prone and fails if the directory appears in between.
    output_dir = config.save_path
    os.makedirs(output_dir, exist_ok=True)

    # Shared training/validation state read by the helper functions.
    global iteration_counter, best_performing_epoch, highest_accuracy
    iteration_counter = 0
    best_performing_epoch = 0
    highest_accuracy = 0

    print("||Training||")

    for epoch in range(1, config.epoch + 1):
        # poly_lr adjusts the optimizer's learning rate in place; its return
        # value was captured but never used in the original, so it is not
        # bound here.
        toolkit.poly_lr(optimizer, config.lr, epoch, config.epoch)

        execute_training_iteration(
            train_loader, model, optimizer, epoch, output_dir
        )
        perform_validation(
            val_loader, model, epoch, output_dir
        )
|
|
|
# Run the full training pipeline only when executed as a script
# (not when this module is imported).
if __name__ == '__main__':
    main_execution()