File size: 1,820 Bytes
eafb42b
c55ae80
 
 
 
 
 
eafb42b
c55ae80
eafb42b
1f63757
c55ae80
 
 
eafb42b
c55ae80
eafb42b
c55ae80
eafb42b
 
 
 
bc1ab2f
eafb42b
c55ae80
eafb42b
 
c55ae80
 
eafb42b
 
 
 
 
 
 
 
c55ae80
 
eafb42b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
import torch
from transformers import AutoProcessor, BlipForConditionalGeneration

app = Flask(__name__)
# Allow cross-origin requests from any origin on every route.
CORS(app, resources={r"/*": {"origins": "*"}})

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face checkpoint used for both the processor and the model.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"


def _load_vision_components():
    """Load the BLIP processor and captioning model, moving the model to ``device``.

    Any loading failure is printed and swallowed so the service can still
    start; a component that was not loaded is returned as ``None``.

    Returns:
        A ``(processor, model)`` tuple; either element may be ``None``.
    """
    processor, model = None, None
    try:
        processor = AutoProcessor.from_pretrained(_BLIP_CHECKPOINT)
        model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT).to(device)
        print("--- VISION SERVICE --- BLIP Vision model loaded successfully.")
    except Exception as exc:
        print(f"--- VISION SERVICE --- CRITICAL ERROR loading Vision model: {exc}")
    return processor, model


vision_processor, vision_model = _load_vision_components()

def _decode_upload(image_file):
    """Return the uploaded file as an RGB PIL image, or None if Pillow cannot read it."""
    try:
        return Image.open(image_file.stream).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # unrecognised formats, OSError for truncated data and
        # DecompressionBombError for oversized images -- all bad client input.
        app.logger.warning("Rejected unreadable image upload", exc_info=True)
        return None


def _generate_caption(image_obj, user_prompt):
    """Caption ``image_obj`` with BLIP, passing ``user_prompt`` as text only if non-empty."""
    processor_kwargs = {"images": image_obj, "return_tensors": "pt"}
    if user_prompt:
        processor_kwargs["text"] = user_prompt
    inputs = vision_processor(**processor_kwargs).to(device)
    output = vision_model.generate(**inputs, max_new_tokens=50)
    return vision_processor.decode(output[0], skip_special_tokens=True).strip()


@app.route("/describe_image", methods=["POST"])
def describe_image():
    """Describe an uploaded image with the BLIP captioning model.

    Form fields:
        image: the uploaded image file (required).
        prompt: optional text; when non-blank it is given to the processor
            as ``text`` together with the image, otherwise the image alone
            is captioned.

    Returns:
        200 with ``{"content": <caption>}`` on success.
        400 with ``{"error": ...}`` if no image was uploaded or it cannot be
        decoded as an image.
        500 with ``{"error": ...}`` if the model is unavailable or caption
        generation fails.
    """
    if vision_processor is None or vision_model is None:
        return jsonify({"error": "Vision model not available."}), 500

    # A whitespace-only prompt is treated the same as no prompt.
    user_prompt = request.form.get("prompt", "").strip()
    image_file = request.files.get("image")
    if not image_file:
        return jsonify({"error": "No image file found."}), 400

    try:
        image_obj = _decode_upload(image_file)
        if image_obj is None:
            return jsonify({"error": "Sorry, I could not read that image file."}), 400
        caption = _generate_caption(image_obj, user_prompt)
    except Exception:
        # Request boundary: keep the full traceback in the server log but
        # return only a generic message to the client.
        app.logger.exception("Error processing image")
        return jsonify({"error": "Sorry, I had trouble processing that image."}), 500
    return jsonify({"content": caption})

if __name__ == "__main__":
    # Local-testing entry point: Flask's built-in server on all interfaces,
    # on a different port (8081) for local testing.
    _run_options = {"host": "0.0.0.0", "port": 8081}
    app.run(**_run_options)