import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
# Define quantization config for 4-bit (NF4) loading
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
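# Optional fallback (a sketch, assuming the Space might run without a GPU; not part of
# the original app): bitsandbytes 4-bit loading requires CUDA, so a CPU-only run would
# need to skip the quantization config, e.g. pass
# `quantization_config=nf4_config if torch.cuda.is_available() else None`
# to from_pretrained below.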
# Use a recommended small, efficient model
# Try 'mistralai/Mistral-7B-Instruct-v0.2' or 'google/gemma-2b-it'
# For Gemma, ensure you've accepted its terms on Hugging Face.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
# Load the model and tokenizer once at startup so every Gradio request reuses them
def load_model():
    print(f"Loading model {MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=nf4_config,
        device_map="auto"  # Distributes model layers across available devices automatically
    )
    # Compile the model (the first generation is slow; subsequent ones are fast)
    model = torch.compile(model)
    print("Model loaded and compiled!")
    return tokenizer, model
tokenizer, model = load_model()
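# Optional warm-up (a sketch, not part of the original app): torch.compile is lazy,
# so the first generate() call pays the compilation cost. A tiny dummy generation at
# startup shifts that cost out of the first user request.
# with torch.inference_mode():
#     _warmup_inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
#     model.generate(**_warmup_inputs, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id)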
def generate_text_from_file(file_obj, prompt_text, max_length=200):
    if file_obj is None:
        return "Please upload a file."

    # gr.File passes either a file path string or an object exposing a .name path
    # (depending on the Gradio version), so handle both and read the file as UTF-8 text
    file_path = file_obj if isinstance(file_obj, str) else file_obj.name
    with open(file_path, "r", encoding="utf-8") as f:
        file_content = f.read()

    # Combine the file content with the user prompt (simple RAG-like approach)
    # For more sophisticated file reading and integration, consider LangChain/LlamaIndex
    full_prompt = f"The following is content from a file:\n\n{file_content}\n\nBased on this, and the following instruction:\n\n{prompt_text}"
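    # Optional (an assumption about preferred usage, not the original behavior):
    # Mistral-7B-Instruct is trained on the [INST] chat format, so wrapping the prompt
    # with the tokenizer's chat template usually improves instruction following:
    # full_prompt = tokenizer.apply_chat_template(
    #     [{"role": "user", "content": full_prompt}],
    #     tokenize=False, add_generation_prompt=True
    # )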
    # Tokenize the prompt, truncating to the model's maximum context length
    inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=tokenizer.model_max_length).to(model.device)

    # Generate text with deterministic, cache-friendly settings
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        pad_token_id=tokenizer.eos_token_id,  # Important for generation stability
        do_sample=False,  # Use greedy decoding for consistency and speed
        use_cache=True  # Ensures the KV cache is used
    )
    # Decode only the newly generated tokens (excluding the input prompt)
    generated_text = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
    return generated_text
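# Example local usage (hypothetical; assumes a plain-text file named notes.txt exists):
# print(generate_text_from_file("notes.txt", "Summarize the main points."))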
# Gradio interface
iface = gr.Interface(
    fn=generate_text_from_file,
    inputs=[
        gr.File(label="Upload Input File (.txt, .md, etc.)"),
        gr.Textbox(label="Your Prompt", placeholder="e.g., Summarize the main points or answer this question about the file.")
    ],
    outputs="textbox",
    title="Instant LLM Text Generation from Files on Hugging Face Free Space",
    description="Upload a text file and provide a prompt to get instant, accurate text generation. Optimized for Hugging Face's free T4 GPU."
)
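# Optional (an assumption, not in the original app): on a single shared GPU, enabling
# Gradio's request queue serializes generation calls instead of running them concurrently.
# iface.queue()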
iface.launch()