# image-evaluation-tool/models/prompt_evaluator.py
import torch
import numpy as np
from PIL import Image
import clip
from transformers import BlipProcessor, BlipForConditionalGeneration
import logging
from sentence_transformers import SentenceTransformer, util
logger = logging.getLogger(__name__)
class PromptEvaluator:
"""Prompt following assessment using CLIP and other vision-language models"""
def __init__(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        self.fallback_mode = False
        self.load_models()
def load_models(self):
"""Load prompt evaluation models"""
try:
# Load CLIP model (primary)
logger.info("Loading CLIP model...")
self.load_clip()
            # Load BLIP captioning model (secondary)
            logger.info("Loading BLIP captioning model...")
self.load_blip2()
# Load sentence transformer for text similarity
logger.info("Loading sentence transformer...")
self.load_sentence_transformer()
except Exception as e:
logger.error(f"Error loading prompt evaluation models: {str(e)}")
self.use_fallback_implementation()
def load_clip(self):
"""Load CLIP model"""
try:
model, preprocess = clip.load("ViT-B/32", device=self.device)
self.models['clip'] = model
self.processors['clip'] = preprocess
logger.info("CLIP model loaded successfully")
except Exception as e:
logger.warning(f"Could not load CLIP: {str(e)}")
    def load_blip2(self):
        """Load the BLIP image-captioning model (note: the checkpoint used here is BLIP, not BLIP-2)"""
try:
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
model = model.to(self.device)
self.models['blip2'] = model
self.processors['blip2'] = processor
logger.info("BLIP-2 model loaded successfully")
except Exception as e:
logger.warning(f"Could not load BLIP-2: {str(e)}")
def load_sentence_transformer(self):
"""Load sentence transformer for text similarity"""
try:
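            # 'all-MiniLM-L6-v2' is a small, general-purpose sentence-embedding model (384-dim vectors)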
model = SentenceTransformer('all-MiniLM-L6-v2')
self.models['sentence_transformer'] = model
logger.info("Sentence transformer loaded successfully")
except Exception as e:
logger.warning(f"Could not load sentence transformer: {str(e)}")
def use_fallback_implementation(self):
"""Use simple fallback prompt evaluation"""
logger.info("Using fallback prompt evaluation implementation")
self.fallback_mode = True
def evaluate_with_clip(self, image: Image.Image, prompt: str) -> float:
"""Evaluate prompt following using CLIP"""
try:
if 'clip' not in self.models:
return self.fallback_prompt_score(image, prompt)
model = self.models['clip']
preprocess = self.processors['clip']
# Preprocess image
image_tensor = preprocess(image).unsqueeze(0).to(self.device)
# Tokenize text
text_tokens = clip.tokenize([prompt]).to(self.device)
# Get features
with torch.no_grad():
image_features = model.encode_image(image_tensor)
text_features = model.encode_text(text_tokens)
# Normalize features
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
# Calculate similarity
similarity = (image_features @ text_features.T).item()
            # Map cosine similarity from [-1, 1] onto a 0-10 scale
            # (in practice, CLIP image-text similarities for related content land around 0.2-0.4)
            score = max(0.0, min(10.0, (similarity + 1) * 5))
return score
except Exception as e:
logger.error(f"Error in CLIP evaluation: {str(e)}")
return self.fallback_prompt_score(image, prompt)
    def evaluate_with_blip2(self, image: Image.Image, prompt: str) -> float:
        """Evaluate prompt following by captioning the image with BLIP and comparing the caption to the prompt"""
try:
if 'blip2' not in self.models:
return self.fallback_prompt_score(image, prompt)
model = self.models['blip2']
processor = self.processors['blip2']
# Generate caption for the image
inputs = processor(image, return_tensors="pt").to(self.device)
with torch.no_grad():
out = model.generate(**inputs, max_length=50)
generated_caption = processor.decode(out[0], skip_special_tokens=True)
# Compare generated caption with original prompt using text similarity
if 'sentence_transformer' in self.models:
similarity_score = self.calculate_text_similarity(prompt, generated_caption)
else:
# Simple word overlap fallback
similarity_score = self.simple_text_similarity(prompt, generated_caption)
return similarity_score
except Exception as e:
logger.error(f"Error in BLIP-2 evaluation: {str(e)}")
return self.fallback_prompt_score(image, prompt)
def calculate_text_similarity(self, text1: str, text2: str) -> float:
"""Calculate semantic similarity between two texts"""
try:
model = self.models['sentence_transformer']
# Encode texts
embeddings = model.encode([text1, text2])
# Calculate cosine similarity
similarity = util.cos_sim(embeddings[0], embeddings[1]).item()
# Convert to 0-10 scale
score = max(0.0, min(10.0, (similarity + 1) * 5))
return score
except Exception as e:
logger.error(f"Error calculating text similarity: {str(e)}")
return self.simple_text_similarity(text1, text2)
def simple_text_similarity(self, text1: str, text2: str) -> float:
"""Simple word overlap similarity"""
try:
# Convert to lowercase and split into words
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
# Calculate Jaccard similarity
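            # e.g. "a red cat" vs. "the red cat sleeping": intersection {"red", "cat"} = 2,
            # union = 5, Jaccard = 0.4 -> score 4.0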
intersection = len(words1.intersection(words2))
union = len(words1.union(words2))
if union == 0:
return 0.0
jaccard_similarity = intersection / union
# Convert to 0-10 scale
score = jaccard_similarity * 10
return max(0.0, min(10.0, score))
except Exception:
return 5.0 # Default neutral score
def extract_key_concepts(self, prompt: str) -> list:
"""Extract key concepts from prompt for detailed analysis"""
try:
# Simple keyword extraction
# In production, this could use more sophisticated NLP
# Remove common words
stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should'}
words = prompt.lower().split()
key_concepts = [word for word in words if word not in stop_words and len(word) > 2]
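            # e.g. "a red fox in the snow" -> ["red", "fox", "snow"]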
return key_concepts
except Exception:
return []
def evaluate_concept_presence(self, image: Image.Image, concepts: list) -> float:
"""Evaluate presence of specific concepts in image"""
try:
if 'clip' not in self.models or not concepts:
return 5.0
model = self.models['clip']
preprocess = self.processors['clip']
# Preprocess image
image_tensor = preprocess(image).unsqueeze(0).to(self.device)
# Create concept queries
concept_queries = [f"a photo of {concept}" for concept in concepts]
# Tokenize concepts
text_tokens = clip.tokenize(concept_queries).to(self.device)
# Get features
with torch.no_grad():
image_features = model.encode_image(image_tensor)
text_features = model.encode_text(text_tokens)
# Normalize features
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
# Calculate similarities
similarities = (image_features @ text_features.T).squeeze(0)
# Average similarity across concepts
avg_similarity = similarities.mean().item()
# Convert to 0-10 scale
score = max(0.0, min(10.0, (avg_similarity + 1) * 5))
return score
except Exception as e:
logger.error(f"Error in concept presence evaluation: {str(e)}")
return 5.0
def fallback_prompt_score(self, image: Image.Image, prompt: str) -> float:
"""Simple fallback prompt evaluation"""
try:
# Very basic evaluation based on prompt length and image properties
prompt_length = len(prompt.split())
# Longer, more detailed prompts might be harder to follow perfectly
if prompt_length < 5:
length_penalty = 0.0
elif prompt_length < 15:
length_penalty = 0.5
else:
length_penalty = 1.0
# Base score
base_score = 7.0 - length_penalty
return max(0.0, min(10.0, base_score))
except Exception:
return 5.0 # Default neutral score
def evaluate(self, image: Image.Image, prompt: str) -> float:
"""
Evaluate how well the image follows the given prompt
Args:
image: PIL Image to evaluate
prompt: Text prompt to compare against
Returns:
            Prompt-following score from 0 to 10
"""
try:
if not prompt or not prompt.strip():
return 0.0 # No prompt to evaluate against
scores = []
# CLIP evaluation (primary)
clip_score = self.evaluate_with_clip(image, prompt)
scores.append(clip_score)
            # BLIP caption evaluation (secondary)
blip2_score = self.evaluate_with_blip2(image, prompt)
scores.append(blip2_score)
# Concept presence evaluation
key_concepts = self.extract_key_concepts(prompt)
concept_score = self.evaluate_concept_presence(image, key_concepts)
scores.append(concept_score)
# Ensemble scoring
weights = [0.5, 0.3, 0.2] # CLIP gets highest weight
final_score = sum(score * weight for score, weight in zip(scores, weights))
logger.info(f"Prompt scores - CLIP: {clip_score:.2f}, BLIP-2: {blip2_score:.2f}, "
f"Concepts: {concept_score:.2f}, Final: {final_score:.2f}")
return max(0.0, min(10.0, final_score))
except Exception as e:
logger.error(f"Error in prompt evaluation: {str(e)}")
return self.fallback_prompt_score(image, prompt)
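

# Minimal usage sketch: score a local image against a prompt.
# The image path and prompt below are illustrative placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    evaluator = PromptEvaluator()
    test_image = Image.open("test_image.png").convert("RGB")  # placeholder path
    score = evaluator.evaluate(test_image, "a red bicycle leaning against a brick wall")
    print(f"Prompt-following score: {score:.2f}/10")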