Ham01's picture
Update app.py
3cb8888 verified
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import aiohttp
from io import BytesIO
# Run on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint shared by the preprocessor and the model.
model_name = "laion/CLIP-ViT-g-14-laion2B-s12B-b42K"
processor = CLIPProcessor.from_pretrained(model_name)

# Load the weights, move them to the selected device, and switch the
# module to evaluation mode.
model = CLIPModel.from_pretrained(model_name)
model = model.to(device)
model.eval()

# Maps an image URL to the embedding already computed for it.
embedding_cache = dict()
async def get_embedding_async(image_url):
    """Return the unit-normalised CLIP image embedding for the image at a URL.

    Successful results are stored in the module-level ``embedding_cache``
    keyed by ``image_url``, so a repeated request for the same URL skips
    both the download and the model call.

    Args:
        image_url: URL of the image to download and embed.

    Returns:
        On success, the output of ``model.get_image_features`` divided by
        its norm along the last dimension, converted to nested Python
        lists. On failure, a dict of the form ``{"error": "<message>"}``.
    """
    if image_url in embedding_cache:
        return embedding_cache[image_url]

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(image_url) as resp:
                # Surface 4xx/5xx responses as HTTP errors instead of
                # handing an error page to PIL and getting a misleading
                # "cannot identify image file" message.
                resp.raise_for_status()
                data = await resp.read()

        # The context manager releases the decoder's handle; ``convert``
        # returns an independent image that remains usable afterwards.
        with Image.open(BytesIO(data)) as raw_image:
            image = raw_image.convert("RGB")

        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            embedding = model.get_image_features(**inputs)
        # Scale to unit length along the feature dimension.
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        embedding_list = embedding.cpu().numpy().tolist()
    except Exception as e:
        # Report failures as a JSON object: the Gradio JSON output tries to
        # parse a returned ``str`` as JSON text, so a bare error message
        # would itself fail to render. Some exceptions (e.g. timeouts)
        # have an empty message, hence the ``repr`` fallback.
        return {"error": str(e) or repr(e)}

    # Only successful results are cached, so a transient failure can be
    # retried with the same URL.
    embedding_cache[image_url] = embedding_list
    return embedding_list
# Single text field where the user supplies the image URL.
_url_box = gr.Textbox(label="Image URL", placeholder="Enter an image URL")

# Wire the embedding coroutine to a simple URL-in / JSON-out interface.
demo = gr.Interface(
    fn=get_embedding_async,
    inputs=_url_box,
    outputs="json",
    title="OpenCLIP ViT-g-14 Embeddings",
)

# Start the web server for the interface.
demo.launch()