import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from transformers import AutoModel, AutoProcessor

import logging

logger = logging.getLogger(__name__)


class AestheticsEvaluator:
    """Image aesthetics assessment using an ensemble of aesthetic models."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        self.fallback_mode = False  # set to True if model loading fails
        self.load_models()

    def load_models(self):
        """Load aesthetics assessment models."""
        try:
            # Load UNIAA model (primary)
            logger.info("Loading UNIAA model...")
            self.load_uniaa()

            # Load MUSIQ model (secondary)
            logger.info("Loading MUSIQ model...")
            self.load_musiq()

            # Load anime-specific aesthetic model
            logger.info("Loading anime aesthetic model...")
            self.load_anime_aesthetic_model()
        except Exception as e:
            logger.error(f"Error loading aesthetic models: {str(e)}")
            self.use_fallback_implementation()

    @staticmethod
    def _imagenet_transform():
        """Standard ImageNet preprocessing shared by all backbone models."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def load_uniaa(self):
        """Load UNIAA model."""
        try:
            # Placeholder implementation for UNIAA
            self.models['uniaa'] = self.create_mock_aesthetic_model()
            self.processors['uniaa'] = self._imagenet_transform()
        except Exception as e:
            logger.warning(f"Could not load UNIAA: {str(e)}")

    def load_musiq(self):
        """Load MUSIQ model."""
        try:
            # Placeholder implementation for MUSIQ
            self.models['musiq'] = self.create_mock_aesthetic_model()
            self.processors['musiq'] = self._imagenet_transform()
        except Exception as e:
            logger.warning(f"Could not load MUSIQ: {str(e)}")

    def load_anime_aesthetic_model(self):
        """Load anime-specific aesthetic model."""
        try:
            # Placeholder for anime-specific model
            self.models['anime_aesthetic'] = self.create_mock_aesthetic_model()
            self.processors['anime_aesthetic'] = self._imagenet_transform()
        except Exception as e:
            logger.warning(f"Could not load anime aesthetic model: {str(e)}")

    def create_mock_aesthetic_model(self):
        """Create a mock aesthetic model for demonstration."""
        class MockAestheticModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = nn.Sequential(
                    nn.Conv2d(3, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 128, 3, padding=1),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(128, 64),
                    nn.ReLU(),
                    nn.Linear(64, 1),
                    nn.Sigmoid()
                )

            def forward(self, x):
                return self.backbone(x) * 10  # Sigmoid output scaled to 0-10

        model = MockAestheticModel().to(self.device)
        model.eval()
        return model

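    # The mock above is only a stand-in. A common way to obtain a real
    # aesthetic scorer is a frozen CLIP image encoder with a small linear
    # regression head on top (the LAION-style aesthetic-predictor recipe).
    # The sketch below shows that wiring under two assumptions: the CLIP
    # checkpoint name is a real public one, but "aesthetic_head.pt" is a
    # hypothetical local weights file for the head, not something shipped here.
    def load_clip_aesthetic_head(self, head_weights_path: str = "aesthetic_head.pt"):
        """Sketch: score images with CLIP embeddings + a linear head."""
        clip = AutoModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        head = nn.Linear(clip.config.projection_dim, 1).to(self.device)
        head.load_state_dict(torch.load(head_weights_path, map_location=self.device))
        clip.eval()
        head.eval()

        def score(image: Image.Image) -> float:
            inputs = processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                feats = clip.get_image_features(**inputs)
                feats = feats / feats.norm(dim=-1, keepdim=True)  # unit-normalize embedding
                return float(head(feats).item())

        return score
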
    def use_fallback_implementation(self):
        """Fall back to the simple heuristic aesthetic assessment."""
        logger.info("Using fallback aesthetic assessment implementation")
        self.fallback_mode = True

    def evaluate_with_uniaa(self, image: Image.Image) -> float:
        """Evaluate aesthetics using UNIAA."""
        try:
            if 'uniaa' not in self.models:
                return self.fallback_aesthetic_score(image)

            # Preprocess image (force RGB so Normalize sees three channels)
            tensor = self.processors['uniaa'](image.convert('RGB')).unsqueeze(0).to(self.device)

            # Get prediction
            with torch.no_grad():
                score = self.models['uniaa'](tensor).item()

            return max(0.0, min(10.0, score))
        except Exception as e:
            logger.error(f"Error in UNIAA evaluation: {str(e)}")
            return self.fallback_aesthetic_score(image)

    def evaluate_with_musiq(self, image: Image.Image) -> float:
        """Evaluate aesthetics using MUSIQ."""
        try:
            if 'musiq' not in self.models:
                return self.fallback_aesthetic_score(image)

            # Preprocess image (force RGB so Normalize sees three channels)
            tensor = self.processors['musiq'](image.convert('RGB')).unsqueeze(0).to(self.device)

            # Get prediction
            with torch.no_grad():
                score = self.models['musiq'](tensor).item()

            return max(0.0, min(10.0, score))
        except Exception as e:
            logger.error(f"Error in MUSIQ evaluation: {str(e)}")
            return self.fallback_aesthetic_score(image)

    def evaluate_with_anime_model(self, image: Image.Image) -> float:
        """Evaluate aesthetics using the anime-specific model."""
        try:
            if 'anime_aesthetic' not in self.models:
                return self.fallback_aesthetic_score(image)

            # Preprocess image (force RGB so Normalize sees three channels)
            tensor = self.processors['anime_aesthetic'](image.convert('RGB')).unsqueeze(0).to(self.device)

            # Get prediction
            with torch.no_grad():
                score = self.models['anime_aesthetic'](tensor).item()

            return max(0.0, min(10.0, score))
        except Exception as e:
            logger.error(f"Error in anime aesthetic evaluation: {str(e)}")
            return self.fallback_aesthetic_score(image)

    def evaluate_composition_rules(self, image: Image.Image) -> float:
        """Evaluate based on composition rules (rule of thirds, color harmony)."""
        try:
            # Convert to numpy array
            img_array = np.array(image)
            height, width = img_array.shape[:2]

            # Convert to grayscale for analysis (ITU-R BT.601 luma weights)
            if img_array.ndim == 3:
                gray = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
            else:
                gray = img_array

            # Rule of thirds: look for detail at the four grid intersections
            third_h, third_w = height // 3, width // 3
            intersections = [
                (third_h, third_w), (third_h, 2 * third_w),
                (2 * third_h, third_w), (2 * third_h, 2 * third_w)
            ]

            composition_score = 0.0
            for y, x in intersections:
                # Local variance around each intersection as a proxy for detail
                region = gray[max(0, y - 10):min(height, y + 10),
                              max(0, x - 10):min(width, x + 10)]
                if region.size > 0:
                    composition_score += region.var()

            # Normalize the summed variance onto a 0-10 scale
            composition_score = min(10.0, composition_score / 1000.0)

            # Color harmony: spread of the per-channel color distribution
            if img_array.ndim == 3:
                colors = img_array[..., :3].reshape(-1, 3)  # drop alpha if present
                color_std = np.std(colors, axis=0).mean()
                color_harmony_score = min(10.0, color_std / 25.0)
            else:
                color_harmony_score = 5.0

            # Combine scores
            final_score = composition_score * 0.6 + color_harmony_score * 0.4
            return max(0.0, min(10.0, final_score))
        except Exception as e:
            logger.error(f"Error in composition analysis: {str(e)}")
            return 5.0

    def fallback_aesthetic_score(self, image: Image.Image) -> float:
        """Simple heuristic aesthetic assessment used when no model is available."""
        try:
            # Basic aesthetic assessment based on image properties
            width, height = image.size

            # Aspect ratio score (favor ratios close to the golden ratio)
            aspect_ratio = width / height
            golden_ratio = 1.618
            if abs(aspect_ratio - golden_ratio) < 0.1 or abs(aspect_ratio - 1 / golden_ratio) < 0.1:
                aspect_score = 9.0
            elif 0.7 <= aspect_ratio <= 1.4:  # roughly square
                aspect_score = 7.0
            elif 1.4 < aspect_ratio <= 2.0:  # landscape
                aspect_score = 8.0
            else:
                aspect_score = 5.0

            # Resolution score: scales linearly with pixel count, maxing out at 2 MP
            total_pixels = width * height
            resolution_score = min(10.0, total_pixels / 200000.0)

            # Color analysis
            img_array = np.array(image)
            if img_array.ndim == 3:
                rgb = img_array[..., :3]  # drop alpha if present

                # Color variety score
                unique_colors = len(np.unique(rgb.reshape(-1, 3), axis=0))
                color_variety_score = min(10.0, unique_colors / 1000.0)

                # Brightness score: best when mean brightness is near mid-gray (127.5)
                brightness = np.mean(rgb, axis=2)
                brightness_score = 10.0 - abs(brightness.mean() - 127.5) / 12.75
            else:
                color_variety_score = 5.0
                brightness_score = 5.0

            # Weighted combination of the heuristics
            aesthetic_score = (aspect_score * 0.3 +
                               resolution_score * 0.2 +
                               color_variety_score * 0.3 +
                               brightness_score * 0.2)
            return max(0.0, min(10.0, aesthetic_score))
        except Exception:
            return 5.0  # Default neutral score

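    # Worked example of the fallback heuristic (assuming a hypothetical
    # 1920x1080 RGB photo with mean brightness ~120 and ~50k distinct colors):
    #   aspect 1920/1080 = 1.78        -> aspect_score 8.0 (landscape band)
    #   2.07 MP / 0.2 MP = 10.4        -> resolution_score capped at 10.0
    #   50,000 colors / 1000 = 50      -> color_variety_score capped at 10.0
    #   10 - |120 - 127.5| / 12.75     -> brightness_score 9.41
    #   8.0*0.3 + 10.0*0.2 + 10.0*0.3 + 9.41*0.2 = 9.28
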
    def evaluate(self, image: Image.Image, anime_mode: bool = False) -> float:
        """
        Evaluate image aesthetics using an ensemble of models.

        Args:
            image: PIL Image to evaluate
            anime_mode: Whether to use anime-specific evaluation

        Returns:
            Aesthetic score from 0-10
        """
        try:
            scores = []

            if anime_mode:
                # For anime images, prioritize the anime-specific model
                scores.append(self.evaluate_with_anime_model(image))

                # Also use a general model, but with lower weight
                scores.append(self.evaluate_with_uniaa(image))

                # Composition rules
                scores.append(self.evaluate_composition_rules(image))

                # Weights for anime mode
                weights = [0.5, 0.3, 0.2]
            else:
                # For realistic images, use the general aesthetic models
                scores.append(self.evaluate_with_uniaa(image))
                scores.append(self.evaluate_with_musiq(image))

                # Composition rules
                scores.append(self.evaluate_composition_rules(image))

                # Weights for realistic mode
                weights = [0.4, 0.4, 0.2]

            # Weighted ensemble of the individual scores
            final_score = sum(score * weight for score, weight in zip(scores, weights))

            logger.info(f"Aesthetic evaluation - individual: {scores}, final: {final_score:.2f}")
            return max(0.0, min(10.0, final_score))
        except Exception as e:
            logger.error(f"Error in aesthetic evaluation: {str(e)}")
            return self.fallback_aesthetic_score(image)
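

# Minimal usage sketch. The image path ("sample.png") is a placeholder; with
# the mock models above, the printed scores are illustrative, not meaningful.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    evaluator = AestheticsEvaluator()
    img = Image.open("sample.png")  # hypothetical local file

    print(f"Realistic score: {evaluator.evaluate(img):.2f}")
    print(f"Anime score:     {evaluator.evaluate(img, anime_mode=True):.2f}")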