# cd2412's picture
# Update app.py
# a7b143d verified
# app.py (fixed, no concurrency_count)
import os, sys, time, traceback, subprocess
from typing import Tuple, Optional
from PIL import Image
# Import Gradio; when it is not installed, install it with the running
# interpreter's own pip (quiet mode) and import it again.
try:
    import gradio as gr
except ImportError:
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", "gradio"]
    )
    import gradio as gr
def _make_fallback():
def _fallback_answer_with_controller(image, question, source="auto", distilled_model="auto"):
return "Placeholder answer (wire your models in controller.py).", "baseline", 0
return _fallback_answer_with_controller
# Use the project's controller when it imports cleanly. Any exception raised
# while importing it is reported on stdout and the placeholder controller is
# installed under the same name instead, so the rest of the module still loads.
try:
    from controller import answer_with_controller
except Exception as import_err:
    print(
        f"[WARN] Using fallback controller because import failed: {import_err}",
        flush=True,
    )
    answer_with_controller = _make_fallback()
# Page title and the one-line description rendered beneath it.
TITLE = "VQA — Memory + RL Controller"
DESCRIPTION = (
    "Upload an image, enter a question, and the controller "
    "will choose the best decoding strategy."
)

# Option lists for the two radio groups in the UI.
CONTROLLER_SOURCES = [
    "auto",
    "distilled",
    "ppo",
    "baseline",
]
DISTILLED_CHOICES = [
    "auto",
    "logreg",
    "mlp32",
]
def vqa_demo_fn(
    image: "Optional[Image.Image]",  # string form: not evaluated when the function is defined
    question: str,
    source: str,
    distilled_model: str,
) -> Tuple[str, str, float]:
    """Answer a question about an image and report how long it took.

    Args:
        image: The uploaded picture, or ``None`` when nothing was provided.
        question: Free-text question; surrounding whitespace is ignored.
        source: Controller source, forwarded to ``answer_with_controller``.
        distilled_model: Distilled-gate choice, forwarded unchanged as well.

    Returns:
        ``(answer, strategy_label, latency_ms)``.

        * Missing image or blank question: a prompt message, an empty label
          and ``0.0`` (the controller is never called).
        * Success: the prediction as text, ``"<action_id>: <strategy_name>"``
          and the elapsed time.
        * Any exception during conversion or inference: ``"Error: <message>"``,
          the label ``"error"`` and the time spent before the failure.

        Latency is measured in milliseconds and rounded to one decimal.
    """
    if image is None:
        return "Please upload an image.", "", 0.0

    question = (question or "").strip()
    if not question:
        return "Please enter a question.", "", 0.0

    t0 = time.perf_counter()
    try:
        image_rgb = image.convert("RGB")
        pred, strategy_name, action_id = answer_with_controller(
            image_rgb, question, source=source, distilled_model=distilled_model
        )
        latency_ms = (time.perf_counter() - t0) * 1000.0
        # Separate id and name; fused they read ambiguously (e.g. "0baseline").
        return str(pred), f"{action_id}: {strategy_name}", round(latency_ms, 1)
    except Exception as err:
        # UI boundary: report the failure to the user instead of raising, and
        # keep the full traceback in the process output for debugging.
        latency_ms = (time.perf_counter() - t0) * 1000.0
        # format_exc() already returns one string; no join is needed.
        print("[ERROR] Inference failed:\n" + traceback.format_exc(), flush=True)
        return f"Error: {err}", "error", round(latency_ms, 1)
# UI definition. NOTE(review): the nesting below is reconstructed from the
# order of the components (inputs + button in the first column, results in the
# second) — confirm it matches the intended layout.
with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
    gr.Markdown(f"### {TITLE}\n{DESCRIPTION}")

    with gr.Row():
        # First column: everything the user supplies, plus the trigger button.
        with gr.Column():
            img_in = gr.Image(
                type="pil",  # the handler calls .convert() on the value it receives
                label="Image",
                height=320,
                sources=["upload", "webcam", "clipboard"],
            )
            q_in = gr.Textbox(
                label="Question",
                placeholder="e.g., What colour is the bus?",
                lines=2,
                max_lines=4,
            )
            source_in = gr.Radio(
                choices=CONTROLLER_SOURCES,
                value="auto",
                label="Controller Source",
            )
            dist_in = gr.Radio(
                choices=DISTILLED_CHOICES,
                value="auto",
                label="Distilled Gate (if used)",
            )
            run_btn = gr.Button("Predict", variant="primary")

        # Second column: read-only result fields.
        with gr.Column():
            ans_out = gr.Textbox(
                label="Answer",
                interactive=False,
                lines=3,
                max_lines=6,
            )
            strat_out = gr.Textbox(label="Chosen Strategy", interactive=False)
            lat_out = gr.Number(
                label="Latency (ms)",
                precision=1,
                interactive=False,
            )

    # Input order matches vqa_demo_fn's parameters; output order matches the
    # three values it returns.
    run_btn.click(
        fn=vqa_demo_fn,
        inputs=[img_in, q_in, source_in, dist_in],
        outputs=[ans_out, strat_out, lat_out],
        api_name="predict",
    )
if __name__ == "__main__":
    # Listen on $PORT when it is set, otherwise on 7860.
    listen_port = int(os.environ.get("PORT", "7860"))

    # Queue with default settings (no concurrency_count argument is passed).
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=listen_port,
        share=False,
        show_error=True,
    )