Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from transformers import AutoModel, AutoProcessor | |
| import cv2 | |
| import logging | |
| from scipy import ndimage | |
# Module-level logger shared by the detector methods in this file.
logger = logging.getLogger(name=__name__)
class AIDetectionEvaluator:
    """AI-generated image detection using multiple approaches.

    Blends two neural scorers ("sentry" and "ensemble") with a set of
    hand-written image heuristics into one probability in [0, 1]
    (0 = likely real, 1 = likely AI-generated).

    NOTE: both neural scorers are currently placeholders produced by
    ``create_mock_detection_model``; their weights are randomly initialised,
    so their outputs carry no real signal until actual models are loaded.
    """

    # Blend weights for (sentry, ensemble, heuristics) used by ``evaluate``.
    _SCORE_WEIGHTS = (0.5, 0.3, 0.2)

    # Resolutions treated as typical generator output sizes by
    # ``analyze_metadata_indicators``.
    _AI_RESOLUTIONS = frozenset({
        (512, 512), (768, 768), (1024, 1024),  # Square formats
        (512, 768), (768, 512),                # 2:3 ratios
        (1024, 768), (768, 1024),              # 4:3 ratios
    })

    # Aspect ratios (width / height) treated as "perfect" ratios.
    _AI_ASPECT_RATIOS = (1.0, 1.5, 0.67, 1.33, 0.75, 1.25)

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        # Set before loading so these attributes always exist, even if a
        # loader fails before assigning them.
        self.fallback_mode = False
        self.artifact_detection_available = False
        self.load_models()

    def load_models(self):
        """Load all detectors; switch to fallback mode if loading raises.

        Each individual loader already catches and logs its own errors, so
        the outer handler only triggers on unexpected failures.
        """
        try:
            # Load Sentry-Image model (primary)
            logger.info("Loading Sentry-Image model...")
            self.load_sentry_image()
            # Load custom ensemble model (secondary)
            logger.info("Loading custom ensemble model...")
            self.load_custom_ensemble()
            # Load traditional artifact detection
            logger.info("Loading traditional artifact detection...")
            self.load_artifact_detection()
        except Exception as e:
            logger.error(f"Error loading AI detection models: {str(e)}")
            self.use_fallback_implementation()

    @staticmethod
    def _build_preprocess():
        """Return the shared 224x224, ImageNet-normalised tensor pipeline."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def load_sentry_image(self):
        """Register the Sentry-Image scorer under the key ``'sentry'``.

        Currently a placeholder: a randomly initialised mock model.
        """
        try:
            # In production, this would load the actual Sentry-Image model.
            self.models['sentry'] = self.create_mock_detection_model()
            self.processors['sentry'] = self._build_preprocess()
        except Exception as e:
            logger.warning(f"Could not load Sentry-Image: {str(e)}")

    def load_custom_ensemble(self):
        """Register the custom ensemble scorer under the key ``'ensemble'``.

        Currently a placeholder: a randomly initialised mock model.
        """
        try:
            self.models['ensemble'] = self.create_mock_detection_model()
            self.processors['ensemble'] = self._build_preprocess()
        except Exception as e:
            logger.warning(f"Could not load custom ensemble: {str(e)}")

    def load_artifact_detection(self):
        """Mark the OpenCV/SciPy heuristic detectors as available."""
        try:
            self.artifact_detection_available = True
        except Exception as e:
            logger.warning(f"Could not load artifact detection: {str(e)}")
            self.artifact_detection_available = False

    def create_mock_detection_model(self):
        """Create a small randomly initialised CNN that outputs a value in (0, 1).

        The sigmoid head bounds the output, but with untrained weights the
        value is not a meaningful AI-generation probability.
        """
        class MockDetectionModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = torch.nn.Sequential(
                    torch.nn.Conv2d(3, 64, 3, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(64, 128, 3, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.AdaptiveAvgPool2d((1, 1)),
                    torch.nn.Flatten(),
                    torch.nn.Linear(128, 64),
                    torch.nn.ReLU(),
                    torch.nn.Linear(64, 1),
                    torch.nn.Sigmoid(),
                )

            def forward(self, x):
                return self.backbone(x)  # Sigmoid output in (0, 1)

        logger.warning("Using placeholder detection model with random weights; "
                       "its scores are not meaningful")
        model = MockDetectionModel().to(self.device)
        model.eval()
        return model

    def use_fallback_implementation(self):
        """Flag that model loading failed.

        ``fallback_mode`` is set here but nothing in this class reads it.
        """
        logger.info("Using fallback AI detection implementation")
        self.fallback_mode = True

    def _score_with_model(self, key, image, label, fallback_score=None):
        """Run the model registered under ``key`` on ``image``.

        Args:
            key: Key into ``self.models`` / ``self.processors``.
            image: PIL image of any mode.
            label: Name used in the error log message.
            fallback_score: Value returned when the model is unavailable or
                fails. If None, ``fallback_detection_score(image)`` is computed.

        Returns:
            Model probability clamped to [0, 1], or the fallback score.
        """
        def fallback():
            if fallback_score is not None:
                return fallback_score
            return self.fallback_detection_score(image)

        if key not in self.models:
            return fallback()
        try:
            # The pipeline and the 3-input-channel backbone need RGB; without
            # this, RGBA/L/P images raise and silently use the fallback.
            rgb = image if image.mode == 'RGB' else image.convert('RGB')
            tensor = self.processors[key](rgb).unsqueeze(0).to(self.device)
            with torch.no_grad():
                probability = self.models[key](tensor).item()
            return max(0.0, min(1.0, probability))
        except Exception as e:
            logger.error(f"Error in {label} evaluation: {str(e)}")
            return fallback()

    def evaluate_with_sentry(self, image: Image.Image) -> float:
        """Return the Sentry-Image AI-generation probability for ``image``.

        Falls back to ``fallback_detection_score`` if the model is missing
        or inference fails.
        """
        return self._score_with_model('sentry', image, 'Sentry')

    def evaluate_with_ensemble(self, image: Image.Image) -> float:
        """Return the custom-ensemble AI-generation probability for ``image``.

        Falls back to ``fallback_detection_score`` if the model is missing
        or inference fails.
        """
        return self._score_with_model('ensemble', image, 'ensemble')

    @staticmethod
    def _to_gray_array(image):
        """Return ``image`` as a 2-D grayscale numpy array.

        RGB images use OpenCV's RGB->gray conversion. Images that already
        decode to a 2-D array (L, I, F, 1, ...) are returned unchanged.
        Palette and other multi-channel modes (P, LA, RGBA, CMYK, YCbCr, ...)
        are first converted to RGB by PIL so channels are interpreted correctly.
        """
        arr = np.array(image)
        if arr.ndim == 2 and image.mode != 'P':
            return arr
        if image.mode != 'RGB':
            arr = np.array(image.convert('RGB'))
        return cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)

    def detect_compression_artifacts(self, image: Image.Image) -> float:
        """Score the mean local (8x8 box) variance of the grayscale image.

        NOTE: despite the name, this measures overall texture energy; it does
        not look for 8x8 block boundaries. The score is that mean variance
        divided by 1000, capped at 1.0. Returns 0.5 on error.
        """
        try:
            gray = self._to_gray_array(image).astype(np.float32)
            kernel = np.ones((8, 8), np.float32) / 64
            local_mean = cv2.filter2D(gray, -1, kernel)
            local_variance = cv2.filter2D((gray - local_mean) ** 2, -1, kernel)
            block_variance = np.mean(local_variance)
            return float(min(1.0, block_variance / 1000.0))
        except Exception as e:
            logger.error(f"Error in compression artifact detection: {str(e)}")
            return 0.5

    def detect_frequency_anomalies(self, image: Image.Image) -> float:
        """Score how far the low/high-frequency log-magnitude ratio is from 10.

        Low frequencies are within radius min(h, w)//8 of the spectrum
        centre; high frequencies are at or beyond radius min(h, w)//4.
        Score = |ratio - 10| / 20, capped at 1.0. Returns 0.5 on error.
        """
        try:
            gray = self._to_gray_array(image)

            f_shift = np.fft.fftshift(np.fft.fft2(gray))
            magnitude_spectrum = np.log(np.abs(f_shift) + 1)

            h, w = magnitude_spectrum.shape
            center_y, center_x = h // 2, w // 2
            y, x = np.ogrid[:h, :w]
            # Squared distance of each frequency bin from the spectrum centre.
            dist_sq = (x - center_x) ** 2 + (y - center_y) ** 2

            low_freq_energy = np.mean(magnitude_spectrum[dist_sq <= (min(h, w) // 8) ** 2])
            high_freq_energy = np.mean(magnitude_spectrum[dist_sq >= (min(h, w) // 4) ** 2])

            if high_freq_energy > 0:
                freq_ratio = low_freq_energy / high_freq_energy
                return float(min(1.0, abs(freq_ratio - 10.0) / 20.0))
            return 0.5
        except Exception as e:
            logger.error(f"Error in frequency analysis: {str(e)}")
            return 0.5

    def detect_pixel_patterns(self, image: Image.Image) -> float:
        """Average a colour-diversity score and a smoothness score.

        Colour diversity (multi-channel images only) is the fraction of
        distinct pixel values. Smoothness is the mean 3x3 local standard
        deviation of the grayscale image. Returns 0.5 on error.
        """
        try:
            img_array = np.array(image)
            if img_array.ndim == 3:
                # One row per pixel using the real channel count, so
                # 4-channel (RGBA) pixels are not split across rows.
                pixels = img_array.reshape(-1, img_array.shape[-1])
                pixel_diversity = len(np.unique(pixels, axis=0)) / len(pixels)
                if pixel_diversity < 0.1:
                    pattern_probability = 0.8
                elif pixel_diversity < 0.3:
                    pattern_probability = 0.6
                else:
                    pattern_probability = 0.2
            else:
                pattern_probability = 0.5

            gray = self._to_gray_array(image).astype(np.float64)
            # Vectorised 3x3 population std: sqrt(E[x^2] - E[x]^2) per window.
            # Same 'reflect' borders as generic_filter(..., np.std, size=3),
            # but computed in C instead of one Python call per pixel.
            window_mean = ndimage.uniform_filter(gray, size=3)
            window_mean_sq = ndimage.uniform_filter(gray * gray, size=3)
            local_std = np.sqrt(np.maximum(window_mean_sq - window_mean ** 2, 0.0))
            avg_local_std = float(np.mean(local_std))

            if avg_local_std < 5.0:
                smoothness_probability = 0.7
            elif avg_local_std < 15.0:
                smoothness_probability = 0.4
            else:
                smoothness_probability = 0.2

            combined_probability = (pattern_probability + smoothness_probability) / 2
            return max(0.0, min(1.0, combined_probability))
        except Exception as e:
            logger.error(f"Error in pixel pattern detection: {str(e)}")
            return 0.5

    def analyze_metadata_indicators(self, image: Image.Image) -> float:
        """Score the source format, dimensions and aspect ratio.

        +0.3 if the image was loaded from a PNG file, +0.4 for a listed
        resolution, +0.2 for a listed aspect ratio (within 0.01); capped at
        1.0. Returns 0.5 on error.
        """
        try:
            format_probability = 0.0
            # ``image.format`` is only set for images opened from a file.
            if image.format == 'PNG':
                format_probability += 0.3

            width, height = image.size
            if (width, height) in self._AI_RESOLUTIONS:
                format_probability += 0.4

            aspect_ratio = width / height
            if any(abs(aspect_ratio - r) < 0.01 for r in self._AI_ASPECT_RATIOS):
                format_probability += 0.2

            return max(0.0, min(1.0, format_probability))
        except Exception as e:
            logger.error(f"Error in metadata analysis: {str(e)}")
            return 0.5

    def fallback_detection_score(self, image: Image.Image) -> float:
        """Return the unweighted mean of the four heuristic detectors.

        Returns 0.5 (neutral) if combining the scores fails.
        """
        try:
            scores = [
                self.detect_compression_artifacts(image),
                self.detect_frequency_anomalies(image),
                self.detect_pixel_patterns(image),
                self.analyze_metadata_indicators(image),
            ]
            return float(max(0.0, min(1.0, np.mean(scores))))
        except Exception as e:
            logger.error(f"Error in fallback detection: {str(e)}")
            return 0.5  # Default neutral probability

    def evaluate(self, image: Image.Image) -> float:
        """
        Evaluate probability that image is AI-generated.

        Args:
            image: PIL Image to evaluate.
        Returns:
            AI generation probability from 0-1 (0 = likely real, 1 = likely AI),
            a 0.5/0.3/0.2 weighted blend of sentry, ensemble and heuristic scores.
        """
        try:
            # The heuristic score is part of the blend anyway, so compute it
            # once and reuse it as the fallback for any missing or failing
            # model instead of re-running all heuristics per model.
            artifact_score = self.fallback_detection_score(image)
            sentry_score = self._score_with_model('sentry', image, 'Sentry', artifact_score)
            ensemble_score = self._score_with_model('ensemble', image, 'ensemble', artifact_score)

            scores = (sentry_score, ensemble_score, artifact_score)
            final_score = sum(s * w for s, w in zip(scores, self._SCORE_WEIGHTS))

            logger.info(f"AI detection scores - Sentry: {sentry_score:.3f}, "
                        f"Ensemble: {ensemble_score:.3f}, Artifacts: {artifact_score:.3f}, "
                        f"Final: {final_score:.3f}")
            return max(0.0, min(1.0, final_score))
        except Exception as e:
            logger.error(f"Error in AI detection evaluation: {str(e)}")
            return self.fallback_detection_score(image)