# Hugging Face Space: Mistral-7B-Instruct chatbot + annual-report PDF finder.
import os

# huggingface_hub (and therefore transformers) resolves HF_HOME into its cache
# constants at import time, so the cache location must be exported *before*
# those libraries are imported — setting it afterwards has no effect.
CACHE_DIR = "/tmp/.cache/huggingface"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ["HF_HOME"] = CACHE_DIR

import gradio as gr
import torch
from transformers import AutoTokenizer, MistralForCausalLM
from huggingface_hub import snapshot_download

from mcp_serpapi import search_pdf_url
from utils import download_pdf

# ---------------------------
# Config
# ---------------------------
HF_TOKEN = os.environ.get("HF_TOKEN")  # stored in Spaces secrets
MODEL_REPO = "mistralai/Mistral-7B-Instruct-v0.3"
MODEL_PATH = "/tmp/mistral"   # local snapshot location (Spaces containers are writable under /tmp)
OFFLOAD_DIR = "/tmp/offload"  # spill-over dir for weights that do not fit in GPU memory

# exist_ok makes container restarts idempotent.
os.makedirs(MODEL_PATH, exist_ok=True)
os.makedirs(OFFLOAD_DIR, exist_ok=True)
# ---------------------------
# Download model if needed
# ---------------------------
# An empty MODEL_PATH means a fresh container; otherwise reuse the snapshot
# left behind by a previous run of this process.
if not os.listdir(MODEL_PATH):
    print("Downloading Mistral model snapshot to /tmp...")
    snapshot_download(
        repo_id=MODEL_REPO,
        repo_type="model",
        local_dir=MODEL_PATH,
        token=HF_TOKEN,  # `use_auth_token` is deprecated in huggingface_hub; `token` is the current name
    )
else:
    print("Model already exists in /tmp. Skipping download.")
# ---------------------------
# Load tokenizer and Mistral model
# ---------------------------
print("Loading tokenizer and MistralForCausalLM model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
model = MistralForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",           # let accelerate place layers across GPU/CPU automatically
    offload_folder=OFFLOAD_DIR,  # disk offload for layers that do not fit in memory
    torch_dtype=torch.float16,   # half precision: ~half the memory of float32
)
# ---------------------------
# Model query function
# ---------------------------
def query_mistral(prompt: str) -> str:
    """Generate a sampled completion for *prompt*.

    Returns the full decoded sequence (prompt plus continuation) with
    special tokens stripped. Sampling is stochastic (do_sample=True),
    so repeated calls with the same prompt may differ.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # inference_mode avoids building autograd state — less memory, faster,
    # and identical generation results.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# ---------------------------
# PDF search function
# ---------------------------
def find_and_download_annual_report(company: str) -> str:
    """Search the web for *company*'s annual-report PDF and download it.

    The PDF is saved under /tmp/pdfs; returns a human-readable status
    string for the UI ("Downloaded: <path>" or "No PDF found").
    """
    # NOTE(review): `site:example.com` looks like a placeholder domain —
    # confirm the intended site before relying on this search.
    query = f"site:example.com {company} annual report filetype:pdf"
    pdf_url = search_pdf_url(query)
    if pdf_url:
        local_path = download_pdf(pdf_url, save_dir="/tmp/pdfs")
        return f"Downloaded: {local_path}"
    return "No PDF found"
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Mistral-7B-Instruct Chatbot + Annual Report PDF Finder")

    # Tab 1: free-form chat against the loaded Mistral model.
    with gr.Tab("Chat with Mistral"):
        prompt_box = gr.Textbox(label="Enter prompt")
        mistral_output = gr.Textbox(label="Model output")
        mistral_button = gr.Button("Run")
        mistral_button.click(
            fn=query_mistral,
            inputs=prompt_box,
            outputs=mistral_output,
        )

    # Tab 2: search for and download a company's annual-report PDF.
    with gr.Tab("Download Annual Report PDF"):
        company_box = gr.Textbox(label="Company Name")
        pdf_output = gr.Textbox(label="Result")
        pdf_button = gr.Button("Find PDF")
        pdf_button.click(
            fn=find_and_download_annual_report,
            inputs=company_box,
            outputs=pdf_output,
        )

# 0.0.0.0:7860 is the address/port Hugging Face Spaces expects the app to bind.
demo.launch(server_name="0.0.0.0", server_port=7860)