# File size: 7,557 Bytes
# d972a8d
import os
import torch
import numpy
from datetime import datetime
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Import original modules with aliases
import util as toolkit
from loader import get_loader as fetch_train_data, get_val_loader as fetch_val_data
from config import ConfigurationManager as Configurator
from model import model as NeuralNetwork
from util import bceLoss as compute_binary_loss
def prepare_validation_config():
    """Parse a fresh configuration and flag it for validation use.

    Returns:
        The object produced by ``Configurator().parse()`` with ``isTrain``
        set to ``False`` and ``isVal`` set to ``True``.
    """
    validation_options = Configurator().parse()
    # Flip the mode flags on the freshly parsed options.
    for flag_name, flag_value in (("isTrain", False), ("isVal", True)):
        setattr(validation_options, flag_name, flag_value)
    return validation_options
def execute_training_iteration(
    data_provider,
    network,
    optimizer,
    epoch_index,
    storage_location
):
    """Train ``network`` for one pass over ``data_provider``.

    For every batch the gradients are cleared, the batch is moved to the
    network's device, the flattened network output is scored with the
    criterion built by ``compute_binary_loss()`` and one optimizer step is
    taken. A status line is printed on the first batch, on every 500th batch
    and on the last batch. After the pass, the weights are written to
    ``Network_epoch_<epoch_index>.pth`` inside ``storage_location`` when
    ``epoch_index`` is a multiple of 50.

    The function reads the module-level ``config`` (for ``config.epoch`` in
    the status line) and increments the module-level ``iteration_counter``
    once per batch; both must exist before the call.

    Args:
        data_provider: Sized iterable yielding ``(inputs, targets)`` batches.
        network: Model being trained; called as ``network(inputs)``.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are invoked.
        epoch_index: Number of the current epoch, used for logging and for
            the checkpoint schedule.
        storage_location: Directory that receives checkpoint files.

    Raises:
        KeyboardInterrupt: Re-raised after the current weights have been
            saved to ``Network_interrupted_epoch_<epoch_index>.pth``.
    """
    global iteration_counter

    network.train()

    # Follow the network's device instead of hard-coding CUDA; fall back to
    # CUDA (the previous behaviour) when the network exposes no parameters.
    first_param = next(network.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    # Derived from the loader itself so the function does not depend on a
    # module-level batch count being kept in sync.
    total_batches = len(data_provider)

    # Built once per pass rather than once per batch.
    # NOTE(review): assumes the returned criterion carries no per-call state.
    loss_function = compute_binary_loss()

    try:
        for batch_idx, (inputs, targets) in enumerate(data_provider, start=1):
            optimizer.zero_grad()

            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass; the output is flattened before it is compared
            # with the targets.
            outputs = network(inputs).ravel()
            batch_loss = loss_function(outputs, targets)

            # Backward pass and parameter update.
            batch_loss.backward()
            optimizer.step()

            iteration_counter += 1

            if batch_idx % 500 == 0 or batch_idx == total_batches or batch_idx == 1:
                progress_percent = (batch_idx / total_batches) * 100
                # NOTE(review): the leading glyphs look like mis-decoded
                # emoji; they are kept byte-for-byte.
                status_report = (
                    f"π Epoch: {epoch_index:02d}/{config.epoch:02d} | "
                    f"π’ Iteration: {batch_idx:04d}/{total_batches:04d} "
                    f"({progress_percent:.1f}%) | "
                    f"π Loss Metric: {batch_loss.item():.6f}"
                )
                print(status_report)

        # Periodic checkpoint, written once the whole pass has finished.
        if epoch_index % 50 == 0:
            checkpoint_path = os.path.join(
                storage_location,
                f'Network_epoch_{epoch_index}.pth'
            )
            torch.save(network.state_dict(), checkpoint_path)
    except KeyboardInterrupt:
        print("Training interrupted: saving model and exiting")
        # Actually do what the message says: keep the partially trained
        # weights, then let the interrupt propagate so the run stops.
        interrupt_path = os.path.join(
            storage_location,
            f'Network_interrupted_epoch_{epoch_index}.pth'
        )
        torch.save(network.state_dict(), interrupt_path)
        raise
def perform_validation(
    validation_sets,
    network,
    epoch_index,
    storage_location
):
    """Evaluate ``network`` on every validation set and track the best epoch.

    Each entry of ``validation_sets`` is a mapping that provides two loaders
    (``'val_ai_loader'`` and ``'val_nature_loader'``) and their sample counts
    (``'ai_size'`` and ``'nature_size'``). Accuracy is the number of correct
    predictions divided by those declared counts; a set (or a whole run)
    with a zero count reports ``0.0`` instead of dividing by zero.

    The module-level ``best_performing_epoch`` and ``highest_accuracy`` are
    updated, and the weights are written to ``Network_best.pth`` inside
    ``storage_location``, when ``epoch_index`` is 1 or when the overall
    accuracy is strictly higher than the stored best. A summary line is
    printed at the end of every call.

    Args:
        validation_sets: Iterable of mappings with the keys listed above.
        network: Model to evaluate; called as ``network(inputs)``.
        epoch_index: Number of the epoch being evaluated.
        storage_location: Directory that receives ``Network_best.pth``.
    """
    global best_performing_epoch, highest_accuracy

    network.eval()

    # Follow the network's device instead of hard-coding CUDA; fall back to
    # CUDA (the previous behaviour) when the network exposes no parameters.
    first_param = next(network.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    total_correct = total_samples = 0
    with torch.no_grad():
        for dataset in validation_sets:
            ai_loader = dataset['val_ai_loader']
            ai_count = dataset['ai_size']
            nature_loader = dataset['val_nature_loader']
            nature_count = dataset['nature_size']
            print("||Validating||")

            correct_ai = _count_correct_predictions(network, ai_loader, device)
            correct_nature = _count_correct_predictions(network, nature_loader, device)

            dataset_correct = correct_ai + correct_nature
            dataset_samples = ai_count + nature_count
            dataset_accuracy = (
                dataset_correct / dataset_samples if dataset_samples else 0.0
            )
            total_correct += dataset_correct
            total_samples += dataset_samples
            print(f"Epoch: {epoch_index}, Accuracy: {dataset_accuracy:.4f}")

    # Guard against an empty validation collection.
    overall_accuracy = total_correct / total_samples if total_samples else 0.0

    # Epoch 1 always seeds the best-so-far record; later epochs must beat it.
    if epoch_index == 1 or overall_accuracy > highest_accuracy:
        best_performing_epoch = epoch_index
        highest_accuracy = overall_accuracy
        best_model_path = os.path.join(storage_location, 'Network_best.pth')
        torch.save(network.state_dict(), best_model_path)
        print(f"Saved best model on Epoch: {epoch_index}")

    print(
        f"π Performance Report | "
        f"Current Epoch: {epoch_index:03d} | "
        f"Accuracy Score: {overall_accuracy:.2%} | "
        f"Peak Performance: Epoch {best_performing_epoch:03d} | "
        f"Highest Accuracy: {highest_accuracy:.2%}"
    )


def _count_correct_predictions(network, loader, device):
    """Return how many samples yielded by ``loader`` are classified correctly.

    A sample is correct when the sigmoid of the network output is above 0.5
    and its target equals 1, or below 0.5 and its target equals 0. An output
    of exactly 0.5 is never counted as correct.
    """
    correct_total = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        probabilities = torch.sigmoid(network(inputs)).ravel()
        correct = (
            ((probabilities > 0.5) & (targets == 1)) |
            ((probabilities < 0.5) & (targets == 0))
        )
        correct_total += correct.sum().item()
    return correct_total
def configure_gpu(gpu_id):
    """Expose only the requested GPU(s) through ``CUDA_VISIBLE_DEVICES``.

    Args:
        gpu_id: Device selection, either a string such as ``"0"`` or
            ``"0,2"`` or a plain integer index. The value is converted with
            ``str()`` because environment variables must be strings.

    Raises:
        TypeError: If ``gpu_id`` is ``None``.
    """
    if gpu_id is None:
        # str(None) would silently write the literal "None" into the
        # environment, so keep rejecting it explicitly.
        raise TypeError("gpu_id must be a device index or string, not None")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
def main_execution():
    """Run the full training procedure: set-up, then train/validate per epoch.

    The function parses the training and validation configurations, selects
    the GPU, builds the data loaders, creates the model (optionally restoring
    weights from ``config.load``), creates an Adam optimizer and the output
    directory, and then, for each epoch from 1 to ``config.epoch``, adjusts
    the learning rate, trains for one pass and validates.

    It publishes the module-level names ``config``, ``total_batches``,
    ``iteration_counter``, ``best_performing_epoch`` and ``highest_accuracy``
    that the training and validation functions read and update.
    """
    global config, total_batches
    global iteration_counter, best_performing_epoch, highest_accuracy

    # Initialize environment
    torch.set_num_threads(2)
    toolkit.set_random_seed()

    # Load configurations
    config = Configurator().parse()
    val_config = prepare_validation_config()

    # Select the GPU before the loaders and the model are created, so the
    # restriction is in place before any later step can initialise CUDA.
    configure_gpu(config.gpu_id)

    # Prepare data
    train_loader = fetch_train_data(config)
    total_batches = len(train_loader)
    val_loader = fetch_val_data(val_config)

    # Initialize model
    model = NeuralNetwork().cuda()
    if config.load:
        # Load onto the CPU first so a checkpoint saved on a device that is
        # not visible here can still be restored; load_state_dict copies the
        # values into the model's own parameters.
        # NOTE(security): torch.load unpickles the file - only point
        # config.load at checkpoints from a trusted source.
        model.load_state_dict(torch.load(config.load, map_location="cpu"))
        print(f"Loaded model from {config.load}")

    # Prepare optimizer
    optimizer = torch.optim.Adam(model.parameters(), config.lr)

    # Create output directory (no error if it already exists)
    output_dir = config.save_path
    os.makedirs(output_dir, exist_ok=True)

    # Initialize training state
    iteration_counter = 0
    best_performing_epoch = 0
    highest_accuracy = 0
    print("||Training||")

    # Training loop
    for epoch in range(1, config.epoch + 1):
        # Adjust learning rate; the return value is not needed here.
        toolkit.poly_lr(optimizer, config.lr, epoch, config.epoch)

        execute_training_iteration(
            train_loader, model, optimizer, epoch, output_dir
        )
        perform_validation(
            val_loader, model, epoch, output_dir
        )
# Run the training procedure only when the file is executed as a script.
if __name__ == '__main__':
    main_execution()