import torch
import numpy as np
from PIL import Image
import clip
from transformers import BlipProcessor, BlipForConditionalGeneration
import logging
from sentence_transformers import SentenceTransformer, util

logger = logging.getLogger(__name__)


class PromptEvaluator:
    """Prompt-following assessment using CLIP and other vision-language models."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        self.fallback_mode = False  # set to True if model loading fails
        self.load_models()

    def load_models(self):
        """Load prompt evaluation models."""
        try:
            # Load CLIP model (primary)
            logger.info("Loading CLIP model...")
            self.load_clip()

            # Load the BLIP captioning model (secondary)
            logger.info("Loading BLIP captioning model...")
            self.load_blip2()

            # Load sentence transformer for text similarity
            logger.info("Loading sentence transformer...")
            self.load_sentence_transformer()
        except Exception as e:
            logger.error(f"Error loading prompt evaluation models: {str(e)}")
            self.use_fallback_implementation()

    def load_clip(self):
        """Load CLIP model."""
        try:
            model, preprocess = clip.load("ViT-B/32", device=self.device)
            self.models['clip'] = model
            self.processors['clip'] = preprocess
            logger.info("CLIP model loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load CLIP: {str(e)}")

    def load_blip2(self):
        """Load the BLIP captioning model (BLIP, not BLIP-2, despite the method name)."""
        try:
            processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
            model = model.to(self.device)
            self.models['blip2'] = model
            self.processors['blip2'] = processor
            logger.info("BLIP captioning model loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load BLIP captioning model: {str(e)}")

    def load_sentence_transformer(self):
        """Load sentence transformer for text similarity."""
        try:
            model = SentenceTransformer('all-MiniLM-L6-v2')
            self.models['sentence_transformer'] = model
            logger.info("Sentence transformer loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load sentence transformer: {str(e)}")

    def use_fallback_implementation(self):
        """Use simple fallback prompt evaluation."""
        logger.info("Using fallback prompt evaluation implementation")
        self.fallback_mode = True

    def evaluate_with_clip(self, image: Image.Image, prompt: str) -> float:
        """Evaluate prompt following using CLIP image-text similarity."""
        try:
            if 'clip' not in self.models:
                return self.fallback_prompt_score(image, prompt)

            model = self.models['clip']
            preprocess = self.processors['clip']

            # Preprocess image
            image_tensor = preprocess(image).unsqueeze(0).to(self.device)

            # Tokenize text
            text_tokens = clip.tokenize([prompt]).to(self.device)

            # Get features
            with torch.no_grad():
                image_features = model.encode_image(image_tensor)
                text_features = model.encode_text(text_tokens)

                # Normalize features
                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)

                # Calculate cosine similarity
                similarity = (image_features @ text_features.T).item()

            # Map similarity to a 0-10 scale. Cosine similarity lies in [-1, 1],
            # but CLIP scores for well-matched image-text pairs typically fall
            # around 0.2-0.35, so this mapping yields roughly 6-7 for good matches
            # (e.g. 0.30 -> (0.30 + 1) * 5 = 6.5).
            score = max(0.0, min(10.0, (similarity + 1) * 5))
            return score
        except Exception as e:
            logger.error(f"Error in CLIP evaluation: {str(e)}")
            return self.fallback_prompt_score(image, prompt)

    def evaluate_with_blip2(self, image: Image.Image, prompt: str) -> float:
        """Evaluate prompt following by captioning the image with BLIP and comparing the caption to the prompt."""
        try:
            if 'blip2' not in self.models:
                return self.fallback_prompt_score(image, prompt)

            model = self.models['blip2']
            processor = self.processors['blip2']

            # Generate a caption for the image
            inputs = processor(image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                out = model.generate(**inputs, max_length=50)
            generated_caption = processor.decode(out[0], skip_special_tokens=True)

            # Compare the generated caption with the original prompt using text similarity
            if 'sentence_transformer' in self.models:
                similarity_score = self.calculate_text_similarity(prompt, generated_caption)
            else:
                # Simple word-overlap fallback
                similarity_score = self.simple_text_similarity(prompt, generated_caption)

            return similarity_score
        except Exception as e:
            logger.error(f"Error in BLIP caption evaluation: {str(e)}")
            return self.fallback_prompt_score(image, prompt)

    def calculate_text_similarity(self, text1: str, text2: str) -> float:
        """Calculate semantic similarity between two texts."""
        try:
            model = self.models['sentence_transformer']

            # Encode texts
            embeddings = model.encode([text1, text2])

            # Calculate cosine similarity
            similarity = util.cos_sim(embeddings[0], embeddings[1]).item()

            # Convert to 0-10 scale
            score = max(0.0, min(10.0, (similarity + 1) * 5))
            return score
        except Exception as e:
            logger.error(f"Error calculating text similarity: {str(e)}")
            return self.simple_text_similarity(text1, text2)

    def simple_text_similarity(self, text1: str, text2: str) -> float:
        """Simple word-overlap (Jaccard) similarity."""
        try:
            # Convert to lowercase and split into words
            words1 = set(text1.lower().split())
            words2 = set(text2.lower().split())

            # Calculate Jaccard similarity
            intersection = len(words1.intersection(words2))
            union = len(words1.union(words2))
            if union == 0:
                return 0.0

            jaccard_similarity = intersection / union

            # Convert to 0-10 scale
            score = jaccard_similarity * 10
            return max(0.0, min(10.0, score))
        except Exception:
            return 5.0  # Default neutral score

    def extract_key_concepts(self, prompt: str) -> list:
        """Extract key concepts from the prompt for detailed analysis."""
        try:
            # Simple keyword extraction by removing common stop words.
            # In production, this could use more sophisticated NLP.
            stop_words = {
                'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
                'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
                'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should'
            }
            words = prompt.lower().split()
            key_concepts = [word for word in words if word not in stop_words and len(word) > 2]
            return key_concepts
        except Exception:
            return []

    def evaluate_concept_presence(self, image: Image.Image, concepts: list) -> float:
        """Evaluate the presence of specific concepts in the image."""
        try:
            if 'clip' not in self.models or not concepts:
                return 5.0

            model = self.models['clip']
            preprocess = self.processors['clip']

            # Preprocess image
            image_tensor = preprocess(image).unsqueeze(0).to(self.device)

            # Create one text query per concept
            concept_queries = [f"a photo of {concept}" for concept in concepts]

            # Tokenize concepts
            text_tokens = clip.tokenize(concept_queries).to(self.device)

            # Get features
            with torch.no_grad():
                image_features = model.encode_image(image_tensor)
                text_features = model.encode_text(text_tokens)

                # Normalize features
                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)

                # Calculate similarities between the image and each concept query
                similarities = (image_features @ text_features.T).squeeze(0)

            # Average similarity across concepts
            avg_similarity = similarities.mean().item()

            # Convert to 0-10 scale
            score = max(0.0, min(10.0, (avg_similarity + 1) * 5))
            return score
        except Exception as e:
            logger.error(f"Error in concept presence evaluation: {str(e)}")
            return 5.0

    def fallback_prompt_score(self, image: Image.Image, prompt: str) -> float:
        """Simple fallback prompt evaluation based only on prompt length (the image is not inspected)."""
        try:
            # Longer, more detailed prompts tend to be harder to follow perfectly,
            # so apply a small penalty that grows with prompt length.
            prompt_length = len(prompt.split())
            if prompt_length < 5:
                length_penalty = 0.0
            elif prompt_length < 15:
                length_penalty = 0.5
            else:
                length_penalty = 1.0

            # Base score minus the length penalty
            base_score = 7.0 - length_penalty
            return max(0.0, min(10.0, base_score))
        except Exception:
            return 5.0  # Default neutral score

    def evaluate(self, image: Image.Image, prompt: str) -> float:
        """
        Evaluate how well the image follows the given prompt.

        Args:
            image: PIL Image to evaluate
            prompt: Text prompt to compare against

        Returns:
            Prompt following score from 0-10
        """
        try:
            if not prompt or not prompt.strip():
                return 0.0  # No prompt to evaluate against

            scores = []

            # CLIP evaluation (primary)
            clip_score = self.evaluate_with_clip(image, prompt)
            scores.append(clip_score)

            # BLIP caption evaluation (secondary)
            blip2_score = self.evaluate_with_blip2(image, prompt)
            scores.append(blip2_score)

            # Concept presence evaluation
            key_concepts = self.extract_key_concepts(prompt)
            concept_score = self.evaluate_concept_presence(image, key_concepts)
            scores.append(concept_score)

            # Ensemble scoring: weighted average with CLIP weighted highest
            weights = [0.5, 0.3, 0.2]
            final_score = sum(score * weight for score, weight in zip(scores, weights))

            logger.info(f"Prompt scores - CLIP: {clip_score:.2f}, BLIP: {blip2_score:.2f}, "
                        f"Concepts: {concept_score:.2f}, Final: {final_score:.2f}")

            return max(0.0, min(10.0, final_score))
        except Exception as e:
            logger.error(f"Error in prompt evaluation: {str(e)}")
            return self.fallback_prompt_score(image, prompt)
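

# A minimal usage sketch, not part of the original module: it assumes this file
# is run directly and that a local test image exists at the hypothetical path
# "test_image.png". It instantiates the evaluator and scores one image against
# one prompt, printing the ensemble score on the 0-10 scale described above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    evaluator = PromptEvaluator()
    test_image = Image.open("test_image.png").convert("RGB")  # hypothetical example image
    test_prompt = "a photo of a red sports car parked on a city street"

    score = evaluator.evaluate(test_image, test_prompt)
    print(f"Prompt-following score: {score:.2f} / 10")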