# AMontiB
# Commit 9c4b1c4 — original commit message (now includes LFS pointer)
# ----------------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------------
import os
import argparse
import json
import time
import yaml
import torch
from PIL import Image
import torchvision.transforms.v2 as Tv2
from networks import ImageClassifier
import sys
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(project_root)
from support.detect_utils import format_result, save_result, get_device
# ----------------------------------------------------------------------------
# IMAGE PREPROCESSING
# ----------------------------------------------------------------------------
def preprocess_image(image_path):
    """
    Load a single image from disk and turn it into a normalized float
    tensor ready for the model.

    Matches the test-split preprocessing in test.py: convert to RGB,
    scale to [0, 1], then normalize with ImageNet statistics.

    Args:
        image_path: Path to the image file.

    Returns:
        A float32 tensor of shape (1, 3, H, W).
    """
    rgb_image = Image.open(image_path).convert('RGB')

    # Same pipeline as the test split (no augmentation).
    pipeline = Tv2.Compose([
        Tv2.ToImage(),
        Tv2.ToDtype(torch.float32, scale=True),
        Tv2.Normalize(mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225]),
    ])

    # Leading batch dimension of 1 so the model sees (N, C, H, W).
    return pipeline(rgb_image).unsqueeze(0)
# ----------------------------------------------------------------------------
# CONFIG LOADING AND PARSING
# ----------------------------------------------------------------------------
def load_config(config_path):
    """Read a YAML configuration file and return its parsed contents."""
    with open(config_path, 'r') as handle:
        return yaml.safe_load(handle)
def parse_detector_args(detector_args, default_num_centers=1):
    """
    Translate a ``detector_args`` list from the YAML config (e.g.
    ``["--arch", "nodown", "--prototype", "--freeze"]``) into a simple
    settings object with ``arch``, ``freeze``, ``prototype`` and
    ``num_centers`` attributes.

    Unknown tokens are skipped, and a value-taking flag that appears at
    the end of the list without its value is ignored (tolerant parsing,
    matching the config-driven usage in this script).
    """
    class Settings:
        def __init__(self):
            self.arch = "nodown"
            self.freeze = False
            self.prototype = False
            self.num_centers = default_num_centers

    settings = Settings()

    # Flags that take no value simply flip the named attribute on.
    boolean_flags = {"--freeze": "freeze", "--prototype": "prototype"}

    pos = 0
    total = len(detector_args)
    while pos < total:
        token = detector_args[pos]
        if token in boolean_flags:
            setattr(settings, boolean_flags[token], True)
            pos += 1
        elif token == "--arch" and pos + 1 < total:
            settings.arch = detector_args[pos + 1]
            pos += 2
        elif token == "--num_centers" and pos + 1 < total:
            settings.num_centers = int(detector_args[pos + 1])
            pos += 2
        else:
            # Unknown token, or a value flag with no value left: skip it.
            pos += 1
    return settings
def resolve_config_path(config_path):
    """
    Return a usable path for the config file.

    Absolute paths are returned untouched. Relative paths are resolved
    against the project root, which sits two directories above this
    script (detectors/R50_TF/ -> project root).
    """
    if os.path.isabs(config_path):
        return config_path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    root_dir = os.path.dirname(os.path.dirname(script_dir))
    return os.path.join(root_dir, config_path)
# ----------------------------------------------------------------------------
# INFERENCE
# ----------------------------------------------------------------------------
def run_inference(model, image_path, device):
    """
    Classify a single image with the given model.

    The raw model output is treated as a logit and mapped through a
    sigmoid; a probability above 0.5 is labelled "fake", otherwise
    "real".

    Args:
        model: The classifier; expected to emit one score per sample.
        image_path: Path to the image file to classify.
        device: Torch device the model lives on.

    Returns:
        Tuple ``(probability, label, runtime_ms)`` where ``runtime_ms``
        covers preprocessing plus the forward pass.
    """
    t0 = time.time()

    batch = preprocess_image(image_path).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(batch).squeeze(1)  # shape [1]

    probability = torch.sigmoid(logits).item()
    label = "fake" if probability > 0.5 else "real"

    runtime_ms = int((time.time() - t0) * 1000)
    return probability, label, runtime_ms
# ----------------------------------------------------------------------------
# MAIN
# ----------------------------------------------------------------------------
def _resolve_checkpoint(args):
    """
    Resolve the checkpoint file from --checkpoint or --model.

    Resolution order: explicit --checkpoint; conventional layout
    checkpoint/<model>/weights/best.pt next to this script; --model as a
    direct absolute path; --model relative to the project root. Returns
    None when nothing matches.
    """
    if args.checkpoint:
        return args.checkpoint
    if not args.model:
        return None
    detect_dir = os.path.dirname(os.path.abspath(__file__))
    candidate = os.path.join(detect_dir, 'checkpoint', args.model, 'weights', 'best.pt')
    if os.path.exists(candidate):
        return candidate
    if os.path.isabs(args.model) and os.path.exists(args.model):
        return args.model
    project_root = os.path.dirname(os.path.dirname(detect_dir))
    candidate2 = os.path.join(project_root, args.model)
    if os.path.exists(candidate2):
        return candidate2
    return None


def _select_device_string(args, config):
    """
    Pick the device string: an explicit --device always wins, then a
    non-empty config global.device_override (the literal string "null"
    counts as unset), and finally the default 'cuda:0'.
    """
    if args.device is not None:
        return args.device
    override = config.get('global', {}).get('device_override')
    if override and override != "null":
        return override
    # BUGFIX: the original left device_str as None (and later crashed on
    # .startswith) when the override existed but was the string "null".
    return 'cuda:0'


def main():
    """
    CLI entry point for single-image inference with the R50_TF detector.

    Parses arguments, resolves checkpoint/config/device, loads the
    classifier, runs one forward pass on the requested image, and writes
    and prints the formatted result.
    """
    parser = argparse.ArgumentParser(description='Single image inference for R50_TF detector')
    parser.add_argument('--input', type=str, required=False, help='Path to input image (alias: --image)')
    parser.add_argument('--image', type=str, required=False, help='Path to input image (alias for --input)')
    parser.add_argument('--output', type=str, default='/tmp/result.json', help='Path to output JSON file')
    parser.add_argument('--checkpoint', type=str, required=False, help='Path to model checkpoint file')
    parser.add_argument('--model', type=str, required=False, help='Model name or checkpoint directory (alias for --checkpoint)')
    parser.add_argument('--config', type=str, default='configs/R50_TF.yaml', help='Path to YAML config file')
    parser.add_argument('--device', type=str, default=None, help='Device to use (cuda:0, cpu, etc.)')
    args = parser.parse_args()

    # Normalize image argument: prefer --image over --input if both given.
    if args.image:
        args.input = args.image
    if not args.input:
        # Fail early with a usage message instead of crashing in Image.open(None).
        parser.error('an input image is required (use --input or --image)')

    checkpoint_path = _resolve_checkpoint(args)
    if checkpoint_path:
        args.checkpoint = checkpoint_path
    if not args.checkpoint:
        # Fail early with a clear message instead of torch.load(None).
        raise FileNotFoundError(
            'No checkpoint found: pass --checkpoint or a resolvable --model name')

    # Resolve and load the YAML config.
    config_path = resolve_config_path(args.config)
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    config = load_config(config_path)

    # Build model settings from the detector_args list in the config.
    settings = parse_detector_args(config.get('detector_args', []))

    # Device selection, with CPU fallback when CUDA is unavailable.
    device_str = _select_device_string(args, config)
    if device_str.startswith('cuda') and not torch.cuda.is_available():
        print("Warning: CUDA requested but not available. Using CPU.")
        device = torch.device('cpu')
    else:
        device = torch.device(device_str if torch.cuda.is_available() else 'cpu')

    # Load model weights.
    print(f"Loading model from {args.checkpoint}")
    model = ImageClassifier(settings)
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.to(device)
    model.eval()

    # Run inference.
    print(f"Running inference on {args.input}")
    probability, label, runtime_ms = run_inference(model, args.input, device)

    # Format result to match other detectors (prediction/confidence/elapsed_time).
    elapsed_time = runtime_ms / 1000.0
    formatted = format_result(label, float(round(probability, 4)), elapsed_time)

    # Save using shared utility (if output path is provided).
    if args.output:
        save_result(formatted, args.output)
        print(f"Results saved to {args.output}")

    # Concise summary for the user.
    print(f"Prediction: {formatted['prediction']}")
    print(f"Confidence: {formatted['confidence']:.4f}")
    print(f"Time: {formatted['elapsed_time']:.3f}s")
# Standard script-entry guard: run the CLI only when executed directly.
if __name__ == '__main__':
    main()