import os
import torch
import numpy as np
import lightning.pytorch as pl
import gradio as gr
import imageio
import random
import matplotlib.pyplot as plt
import cv2
import skdim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from matplotlib import cm
from safetensors.torch import save_file, load_file
from sklearn.cluster import AgglomerativeClustering
from sklearn.manifold import TSNE
from sklearn.neighbors import KDTree
from sklearn.preprocessing import StandardScaler
from minimal_script import EmbeddingNetwork, closest_interval, adj_size, PLModule


# File suffixes (lower-case, with the dot) treated as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tif', '.tiff', '.webp')


def _load_image_tensor(path):
    """Read an image file and return the normalised CHW tensor fed to the model.

    The image is forced to 3 channels: grayscale is replicated and any alpha
    channel is dropped. It is then resized with ``adj_size(..., 1024)`` and
    ``closest_interval`` from ``minimal_script``, and finally rescaled from
    [0, 255] to [-1, 1].
    """
    image = np.asarray(imageio.v3.imread(path))
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.shape[-1] < 3:
        # Grayscale (optionally with alpha): replicate the luminance channel so
        # the network always receives 3 channels.
        image = np.repeat(image[..., :1], 3, axis=-1)
    # Drop alpha (RGBA) and take a writable copy for torch.from_numpy.
    image = np.array(image[..., :3])
    tensor = torch.from_numpy(image).permute(2, 0, 1)
    processed = closest_interval(adj_size(tensor, 1024))
    return 2 * (processed / 255) - 1


class PredictDataset(Dataset):
    """Image files from one folder, yielded as ``(tensor, path)`` pairs.

    Args:
        data_dir: Folder scanned (non-recursively) for image files.
        sample: If truthy, keep a random subset of at most this many images.
    """

    def __init__(self, data_dir, sample=None):
        self.image_paths = [
            os.path.join(data_dir, fname)
            for fname in sorted(os.listdir(data_dir))
            if fname.lower().endswith(IMAGE_EXTENSIONS)
        ]
        if sample:
            # random.sample raises ValueError when asked for more items than exist.
            self.image_paths = random.sample(
                self.image_paths, min(sample, len(self.image_paths))
            )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        return _load_image_tensor(path).detach(), path


def explore_embedding_space(embeddings, image_paths, model, show_gradients=False, n_neighbors=8):
    """Build a Gradio app for navigating N-dimensional image embeddings.

    One slider per embedding dimension defines a query point. The gallery shows
    the ``n_neighbors`` closest images (Euclidean, via a KDTree), each captioned
    with its distance.

    Args:
        embeddings: NumPy array of shape [B, N].
        image_paths: List of B image file paths, aligned with ``embeddings``.
        model: Embedding model. It is only used when ``show_gradients`` is True,
            to overlay the gradient of the embedding norm w.r.t. the input.
        show_gradients: Overlay gradient heatmaps on the returned images.
        n_neighbors: Maximum number of neighbours to show (capped at B).

    Returns:
        The (unlaunched) ``gr.Blocks`` demo.

    Raises:
        ValueError: If ``embeddings`` is not 2-D, is empty, or its length does
            not match ``image_paths``.
    """
    if embeddings.ndim != 2:
        raise ValueError("Embeddings should be 2-dimensional")
    if len(embeddings) != len(image_paths):
        raise ValueError("Mismatch between embeddings and image paths")
    if len(embeddings) == 0:
        raise ValueError("No embeddings to explore")

    # Per-dimension bounds drive the slider ranges.
    min_vals = embeddings.min(axis=0)
    max_vals = embeddings.max(axis=0)
    ranges = max_vals - min_vals

    tree = KDTree(embeddings)
    # KDTree.query raises if k exceeds the number of indexed points.
    k = min(n_neighbors, len(embeddings))

    # Start the sliders at the centroid of the embedding cloud.
    initial_point = embeddings.mean(axis=0)

    sliders = []
    for i in range(embeddings.shape[1]):
        step = float(ranges[i]) / 100
        if step <= 0:
            # A constant dimension has zero range; a zero step is not usable.
            step = 1e-3
        sliders.append(
            gr.Slider(
                float(min_vals[i]),
                float(max_vals[i]),
                value=float(initial_point[i]),
                step=step,
                label=f"Dimension {i + 1}",
            )
        )

    def compute_gradient_heatmap(image_path):
        """Return an H x W x 3 float RGB heatmap of |d ||model(x)||_2 / dx|."""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Convert first, then mark as requiring grad, so the tensor we
        # differentiate against is a leaf.
        img_tensor = _load_image_tensor(image_path).unsqueeze(0)
        img_tensor = img_tensor.to(device=device, dtype=torch.float16).requires_grad_(True)

        with torch.enable_grad():
            embd = model(img_tensor)
            norm = embd.norm(p=2, dim=1).sum()
            grad = torch.autograd.grad(norm, img_tensor, retain_graph=False)[0]

        # Per-pixel gradient magnitude across channels; float32 avoids fp16 overflow.
        grad_mag = grad.squeeze(0).float().norm(dim=0).detach().cpu().numpy()

        grad_min, grad_max = grad_mag.min(), grad_mag.max()
        if grad_max > grad_min:
            grad_norm = (grad_mag - grad_min) / (grad_max - grad_min)
        else:
            grad_norm = np.zeros_like(grad_mag)  # uniform gradient
        return cm.jet(grad_norm)[..., :3]  # drop the alpha channel from the colormap

    def overlay_heatmap(original_img, heatmap, alpha=0.4):
        """Blend ``heatmap`` (float RGB in [0, 1]) over ``original_img``."""
        heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8))
        heatmap_img = heatmap_img.resize(original_img.size)
        # Image.blend requires both images to share mode and size.
        return Image.blend(original_img.convert('RGB'), heatmap_img, alpha)

    def get_overlay_image(image_path):
        """Load an image as RGB, optionally with the gradient heatmap overlaid."""
        with Image.open(image_path) as src:
            img = src.convert('RGB')
        if show_gradients:
            return overlay_heatmap(img, compute_gradient_heatmap(image_path))
        return img

    def add_caption_to_image(image, caption):
        """Append a black bar with centred white ``caption`` text to the image bottom."""
        img = np.array(image) if isinstance(image, Image.Image) else image.copy()

        bar_height = 30
        img = cv2.copyMakeBorder(img, 0, bar_height, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])

        font = cv2.FONT_HERSHEY_SIMPLEX
        text_size = cv2.getTextSize(caption, font, 0.5, 1)[0]
        text_x = (img.shape[1] - text_size[0]) // 2
        text_y = img.shape[0] - 10
        cv2.putText(img, caption, (text_x, text_y), font, 0.5, (255, 255, 255), 1)
        return Image.fromarray(img)

    def find_nearby_images(*point):
        """Gradio callback: slider values -> (captioned neighbour images, warning)."""
        query = np.asarray(point, dtype=float).reshape(1, -1)
        distances, indices = tree.query(query, k=k)
        distances, indices = distances[0], indices[0]

        final_images = [
            add_caption_to_image(get_overlay_image(image_paths[i]), f"Dist: {dist:.2f}")
            for i, dist in zip(indices, distances)
        ]

        warning = ""
        if distances[0] > 5.0:  # warn if even the nearest image is far away
            warning = "⚠️ Nearest image is far (distance={:.2f}). Consider adjusting sliders.".format(distances[0])
        return final_images, warning

    description = "Adjust sliders to navigate."
    if show_gradients:
        description += " Images show gradient of embedding norm w.r.t. input."

    with gr.Blocks() as demo:
        gr.Markdown("## N-Dimensional Embedding Space Explorer")
        gr.Markdown(description)

        warning = gr.Textbox(label="Status", interactive=False)

        gallery = gr.Gallery(
            label="Nearest Images (Distance Ordered)",
            columns=4,
            object_fit="contain",
            height="auto",
            show_label=True,
        )

        with gr.Row():
            for slider in sliders:
                slider.render()

        # Any slider change re-queries with the full point.
        for slider in sliders:
            slider.change(find_nearby_images, inputs=sliders, outputs=[gallery, warning])

        # Populate the gallery on first page load.
        demo.load(find_nearby_images, inputs=sliders, outputs=[gallery, warning])

    return demo


def _print_intrinsic_dimensions(predictions):
    """Print global intrinsic-dimension estimates (TwoNN, CorrInt, DANCo)."""
    estimators = [skdim.id.TwoNN(), skdim.id.CorrInt(), skdim.id.DANCo()]
    results = {}
    for est in estimators:
        est.fit(predictions)
        results[type(est).__name__] = est.dimension_

    print("Intrinsic Dimension Estimates:")
    for name, dim in results.items():
        print(f"{name}: {dim:.2f}")


def _save_norm_plots(predictions):
    """Save 'Norms.png' (mean |value| per feature) and 'Groups.png' (t-SNE coloured by row norm)."""
    row_norms = np.linalg.norm(predictions, axis=1)
    average_norms = np.mean(np.abs(predictions), axis=0)
    n_features = predictions.shape[1]

    fig = plt.figure(figsize=(8, 5))
    plt.bar(range(n_features), average_norms, color='skyblue')
    plt.xlabel('Feature Index (C)')
    plt.ylabel('Average Norm')
    plt.title('Average Norm for Each Feature (Column)')
    plt.xticks(range(n_features))
    plt.savefig('Norms.png')
    plt.close(fig)

    fig = plt.figure(figsize=(8, 6))
    # sklearn's TSNE requires perplexity < n_samples; its default is 30.
    perplexity = min(30.0, max(len(predictions) - 1, 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    reduced_data = tsne.fit_transform(predictions)
    plt.scatter(reduced_data[:, 0], reduced_data[:, 1], c=row_norms, cmap='viridis',
                s=50, edgecolor='k', label="Data Points")
    plt.colorbar(label='Norm Value')
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.title('Scatter Plot of Data Points and Average Norm')
    plt.legend()
    plt.grid(True)
    plt.axis('equal')
    plt.savefig('Groups.png')
    plt.close(fig)


def _launch_cluster_explorer(paths, labels, max_images=50):
    """Launch a Gradio app that shows up to ``max_images`` images per selected cluster."""
    unique_clusters = np.unique(labels)

    with gr.Blocks() as demo:
        gr.Markdown("## Explore Image Clusters by Style")
        cluster_selector = gr.Dropdown(choices=unique_clusters.tolist(), label="Select Cluster to Explore")
        image_gallery = gr.Gallery(label="Sample Images from Selected Cluster")

        def explore_clusters(cluster_idx):
            """Return (fully loaded) images belonging to ``cluster_idx``."""
            if cluster_idx is None:  # dropdown cleared
                return []
            # Normalise in case the dropdown hands back the value as a string.
            members = np.flatnonzero(labels == int(cluster_idx))[:max_images]
            images = []
            for i in members:
                # copy() forces a load so the file handle can be closed immediately.
                with Image.open(paths[i]) as src:
                    images.append(src.copy())
            return images

        cluster_selector.change(fn=explore_clusters, inputs=cluster_selector, outputs=image_gallery)

    demo.launch()


def generate_embeddings(image_folder, mode, model, sample=5000, num_workers=5):
    """Embed the images in ``image_folder`` and launch an interactive explorer.

    Args:
        image_folder: Folder of images (see ``PredictDataset``).
        mode: 'Grouping' prints intrinsic-dimension estimates, saves
            'Norms.png'/'Groups.png' and launches a cluster browser. 'Explore'
            launches the slider-based nearest-neighbour explorer.
        model: LightningModule used with ``Trainer.predict``.
        sample: Maximum number of randomly chosen images to embed.
        num_workers: DataLoader worker processes.

    Raises:
        ValueError: On an unknown ``mode`` or when no images are found.
    """
    # Validate before the expensive prediction pass.
    if mode not in ('Grouping', 'Explore'):
        raise ValueError(f"mode must be 'Grouping' or 'Explore', got {mode!r}")

    predict_dataset = PredictDataset(image_folder, sample)
    if len(predict_dataset) == 0:
        raise ValueError(f"No images found in {image_folder!r}")
    predict_loader = DataLoader(predict_dataset, batch_size=1, num_workers=num_workers, pin_memory=True)

    trainer = pl.Trainer(accelerator="gpu", logger=False, enable_checkpointing=False, precision="16-mixed")
    batches = trainer.predict(model, predict_loader)

    # NOTE(review): assumes each predict_step output is (embeddings [B, N], paths),
    # matching the original indexing -- confirm against PLModule.predict_step.
    # float32 on CPU: mixed precision may yield fp16 and/or GPU tensors, which
    # sklearn / KDTree handle poorly or not at all.
    predictions = torch.cat([batch[0] for batch in batches], dim=0).detach().float().cpu().numpy()
    paths = [path for batch in batches for path in batch[1]]

    if mode == 'Grouping':
        _print_intrinsic_dimensions(predictions)
        labels = cluster_embeddings(predictions)
        _save_norm_plots(predictions)
        _launch_cluster_explorer(paths, labels)
    else:
        # Gradients (when enabled) should be taken in eval mode; Trainer.predict
        # may restore the module's previous train/eval mode when it finishes.
        model = model.to('cuda').to(torch.float16).eval()
        demo = explore_embedding_space(predictions, paths, model)
        demo.launch()


def cluster_embeddings(predictions, distance_threshold=32.0):
    """Ward agglomerative clustering cut at ``distance_threshold``; returns labels."""
    agg_clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=distance_threshold,
        linkage='ward',
    )
    return agg_clustering.fit_predict(predictions)


if __name__ == '__main__':
    folder = 'Enter Images folder name here'  # e.g. 'images_for_style_embedding'

    model = PLModule()
    state_dict = load_file("Style_Embedder_v3.safetensors")
    model.network.load_state_dict(state_dict)

    # mode: 'Grouping' or 'Explore'
    generate_embeddings(folder, 'Grouping', model)