# testfacebook / app.py
# Last commit: cc3cb31 "Update app.py" by TLH01 (verified)
import streamlit as st
import torch
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face Hub checkpoint that supplies both the tokenizer and the
# sequence-to-sequence model loaded below.
model_name = "facebook/bart-large-cnn"

# Browser tab title and icon. This call comes before every other Streamlit
# command in this script.
st.set_page_config(
    page_title="Review Keypoint Extractor (BART-Large-CNN)",
    page_icon="🔑",
)
# Cache the model and tokenizer to avoid reloading
@st.cache_resource
def load_model_and_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
return tokenizer, model, device
# Model outputs that mean "nothing worth extracting". They are compared
# case-insensitively after stripping whitespace and trailing periods, and are
# normalized to the single label "No key point".
_NO_KEYPOINT_OUTPUTS = frozenset(
    {"none", "no keypoint", "no key point", "n/a", "na", "", "nothing"}
)


# Keypoint generation function
def generate_keypoint(review, max_new_tokens=64):
    """Generate a key-point summary of *review* with the loaded BART model.

    Args:
        review: Free-text product review, used verbatim as the model input.
            Input longer than the tokenizer's maximum length is truncated.
        max_new_tokens: Upper bound on the number of generated tokens, passed
            straight to ``model.generate``.

    Returns:
        ``(keypoint, elapsed)``: the decoded, stripped model output (or
        ``"No key point"`` when the model produced an empty / "none"-style
        answer), and the seconds spent on tokenization, generation and
        decoding. Model loading is excluded from the timing.
    """
    tokenizer, model, device = load_model_and_tokenizer()
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # system clock and can jump (e.g. NTP adjustments), skewing the duration.
    start = time.perf_counter()
    # BART-specific prompt: the raw review is the input, with no extra prompt engineering.
    inputs = tokenizer(review, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    keypoint = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Tolerate trailing punctuation such as "None." when detecting a
    # "no key point" answer.
    if keypoint.lower().rstrip(".").strip() in _NO_KEYPOINT_OUTPUTS:
        keypoint = "No key point"
    elapsed = time.perf_counter() - start
    return keypoint, elapsed
# ---- Streamlit UI ----
st.title("🔑 Review Keypoint Extractor (BART-Large-CNN)")
st.write("Enter a product review below to extract its key points using the facebook/bart-large-cnn model.")

# Free-text review input (st.text_area returns the current text).
review = st.text_area(
    "Product Review",
    placeholder="e.g., The Jackery power station is lightweight and charges quickly, but the battery life could be longer.",
)

# Run extraction only when the user presses the button.
if st.button("Extract Keypoint"):
    if not review.strip():
        # Empty or whitespace-only input: nothing to send to the model.
        st.error("⚠️ Please enter a valid review.")
    else:
        with st.spinner("Generating keypoint..."):
            keypoint, elapsed = generate_keypoint(review)
        st.success(f"✅ Keypoint generated in {elapsed:.2f} seconds!")
        st.subheader("Results")
        st.write(f"**Review:** {review}")
        st.write(f"**Keypoint:** {keypoint}")

# ---- Footer ----
st.markdown("---")
st.markdown("Powered by [Hugging Face Transformers](https://huggingface.co/) and [Streamlit](https://streamlit.io/)")