"""Gradio Space: Mistral-7B-Instruct chatbot + annual-report PDF finder.

Downloads the model snapshot into /tmp on first start, loads it with
accelerate-style device mapping (CPU/GPU offload), and exposes two tabs:
a free-form chat with the model and a SerpAPI-backed PDF search/download.
"""

import os

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, MistralForCausalLM

from mcp_serpapi import search_pdf_url
from utils import download_pdf

# ---------------------------
# Config
# ---------------------------
HF_TOKEN = os.environ.get("HF_TOKEN")  # stored in Spaces secrets
MODEL_REPO = "mistralai/Mistral-7B-Instruct-v0.3"
MODEL_PATH = "/tmp/mistral"
CACHE_DIR = "/tmp/.cache/huggingface"
OFFLOAD_DIR = "/tmp/offload"

os.makedirs(MODEL_PATH, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(OFFLOAD_DIR, exist_ok=True)

# Point the HF cache at /tmp so downloads survive within the container
# but never touch the (often read-only / small) home directory.
os.environ["HF_HOME"] = CACHE_DIR

# ---------------------------
# Download model if needed
# ---------------------------
if not os.listdir(MODEL_PATH):
    print("Downloading Mistral model snapshot to /tmp...")
    snapshot_download(
        repo_id=MODEL_REPO,
        repo_type="model",
        local_dir=MODEL_PATH,
        token=HF_TOKEN,  # `use_auth_token` is deprecated in huggingface_hub
    )
else:
    print("Model already exists in /tmp. Skipping download.")

# ---------------------------
# Load tokenizer and Mistral model
# ---------------------------
print("Loading tokenizer and MistralForCausalLM model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
model = MistralForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",          # shard across available devices
    offload_folder=OFFLOAD_DIR,  # spill layers that don't fit to disk
    torch_dtype=torch.float16,
)


# ---------------------------
# Model query function
# ---------------------------
def query_mistral(prompt: str) -> str:
    """Generate a sampled completion (up to 512 new tokens) for *prompt*.

    Returns the full decoded sequence (prompt + completion) with special
    tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# ---------------------------
# PDF search function
# ---------------------------
def find_and_download_annual_report(company: str) -> str:
    """Search for *company*'s annual report PDF and download it to /tmp/pdfs.

    Returns a human-readable status string for display in the UI.
    """
    query = f"site:example.com {company} annual report filetype:pdf"
    pdf_url = search_pdf_url(query)
    if pdf_url:
        local_path = download_pdf(pdf_url, save_dir="/tmp/pdfs")
        # Bug fix: in the original, `return` was separated from its value,
        # so the function returned None and the f-string was unreachable.
        return f"Downloaded: {local_path}"
    return "No PDF found"


# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Mistral-7B-Instruct Chatbot + Annual Report PDF Finder")

    with gr.Tab("Chat with Mistral"):
        prompt_box = gr.Textbox(label="Enter prompt")
        mistral_output = gr.Textbox(label="Model output")
        mistral_button = gr.Button("Run")
        mistral_button.click(
            fn=query_mistral,
            inputs=prompt_box,
            outputs=mistral_output,
        )

    with gr.Tab("Download Annual Report PDF"):
        company_box = gr.Textbox(label="Company Name")
        pdf_output = gr.Textbox(label="Result")
        pdf_button = gr.Button("Find PDF")
        pdf_button.click(
            fn=find_and_download_annual_report,
            inputs=company_box,
            outputs=pdf_output,
        )

demo.launch(server_name="0.0.0.0", server_port=7860)