# app.py — Mistral-7B-Instruct chatbot + annual-report PDF finder (Hugging Face Space)
import os
import gradio as gr
import torch
from transformers import AutoTokenizer, MistralForCausalLM
from huggingface_hub import snapshot_download
from mcp_serpapi import search_pdf_url
from utils import download_pdf
# ---------------------------
# Config
# ---------------------------
HF_TOKEN = os.environ.get("HF_TOKEN")  # provided via Spaces secrets; None if unset
MODEL_REPO = "mistralai/Mistral-7B-Instruct-v0.3"
MODEL_PATH = "/tmp/mistral"
CACHE_DIR = "/tmp/.cache/huggingface"
OFFLOAD_DIR = "/tmp/offload"

# Make sure every working directory exists before any download or offload runs.
for _dir in (MODEL_PATH, CACHE_DIR, OFFLOAD_DIR):
    os.makedirs(_dir, exist_ok=True)

# Redirect the Hugging Face cache into the writable /tmp area of the Space.
os.environ["HF_HOME"] = CACHE_DIR
# ---------------------------
# Download model if needed
# ---------------------------
# Fetch the snapshot only when /tmp/mistral is empty, so restarts within the
# same Space session skip the multi-GB re-download.
if not os.listdir(MODEL_PATH):
    print("Downloading Mistral model snapshot to /tmp...")
    snapshot_download(
        repo_id=MODEL_REPO,
        repo_type="model",
        local_dir=MODEL_PATH,
        # `use_auth_token` is deprecated in huggingface_hub; `token` is the
        # supported parameter name.
        token=HF_TOKEN,
    )
else:
    print("Model already exists in /tmp. Skipping download.")
# ---------------------------
# Load tokenizer and Mistral model
# ---------------------------
print("Loading tokenizer and MistralForCausalLM model...")

# Slow (SentencePiece-based) tokenizer; the fast variant is not required here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)

# fp16 weights with automatic device placement; layers that do not fit on the
# accelerator are offloaded to disk under OFFLOAD_DIR.
model = MistralForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="auto",
    offload_folder=OFFLOAD_DIR,
)
# ---------------------------
# Model query function
# ---------------------------
def query_mistral(prompt: str) -> str:
    """Generate a sampled completion for ``prompt`` and return only the new text.

    Args:
        prompt: Raw user prompt to feed the model.

    Returns:
        The decoded continuation (special tokens stripped). The prompt tokens
        are sliced off before decoding so the UI does not echo the user's
        input back as part of the model output.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # inference_mode disables autograd bookkeeping — lower memory use and
    # faster decoding; generation never needs gradients.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
        )
    # generate() returns prompt + continuation; keep only the new tokens.
    generated = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)
# ---------------------------
# PDF search function
# ---------------------------
def find_and_download_annual_report(company: str, site: str | None = None) -> str:
    """Search the web for ``company``'s annual-report PDF and download it.

    Args:
        company: Company name to search for.
        site: Optional domain to restrict the search to (e.g. ``"sec.gov"``).
            Defaults to no restriction. The previous hard-coded
            ``site:example.com`` placeholder limited every query to a dummy
            domain, so no real report could ever be found.

    Returns:
        ``"Downloaded: <path>"`` on success, otherwise ``"No PDF found"``.
    """
    site_filter = f"site:{site} " if site else ""
    query = f"{site_filter}{company} annual report filetype:pdf"
    pdf_url = search_pdf_url(query)
    if pdf_url:
        local_path = download_pdf(pdf_url, save_dir="/tmp/pdfs")
        return f"Downloaded: {local_path}"
    return "No PDF found"
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Mistral-7B-Instruct Chatbot + Annual Report PDF Finder")

    # Tab 1: free-form chat against the locally loaded Mistral model.
    with gr.Tab("Chat with Mistral"):
        chat_prompt = gr.Textbox(label="Enter prompt")
        chat_result = gr.Textbox(label="Model output")
        chat_run = gr.Button("Run")
        chat_run.click(fn=query_mistral, inputs=chat_prompt, outputs=chat_result)

    # Tab 2: look up and download a company's annual-report PDF.
    with gr.Tab("Download Annual Report PDF"):
        company_input = gr.Textbox(label="Company Name")
        search_result = gr.Textbox(label="Result")
        search_run = gr.Button("Find PDF")
        search_run.click(
            fn=find_and_download_annual_report,
            inputs=company_input,
            outputs=search_result,
        )

# Bind to all interfaces on the standard Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)