Update app.py
app.py CHANGED
@@ -1,14 +1,25 @@
 import streamlit as st
-import torch
-from
-
+import torch
+from prediction_sinhala import MDFEND, TokenizerFromPreTrained
+
+
+
+# Set constants for model and tokenizer paths
+MODEL_SAVE_PATH = "models/last-epoch-model-2024-03-08-15_34_03_6.pth"
+BERT_MODEL_NAME = 'sinhala-nlp/sinbert-sold-si'
+DOMAIN_NUM = 3
+MAX_LEN = 160
+BATCH_SIZE = 100
 
 # Load model and tokenizer
 @st.cache(allow_output_mutation=True)
 def load_model():
-
-    tokenizer =
-    model
+    # Load the tokenizer from the pre-trained model name
+    tokenizer = TokenizerFromPreTrained(MAX_LEN, BERT_MODEL_NAME)
+    # Initialize and load the custom model from saved state
+    model = MDFEND(BERT_MODEL_NAME, DOMAIN_NUM, expert_num=18, mlp_dims=[5080, 4020, 3010, 2024, 1012, 606, 400])
+    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=torch.device('cpu')))
+    model.eval()  # Set the model to evaluation mode
     return model, tokenizer
 
 model, tokenizer = load_model()
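Note: the `st.cache(allow_output_mutation=True)` decorator kept by this hunk is deprecated in Streamlit 1.18+, where `st.cache_resource` is the recommended cache for unserializable objects such as models. A minimal sketch of the same loader on a newer Streamlit, reusing the constants and the `prediction_sinhala` API from this commit (an illustration, not tested against this repo):

import streamlit as st
import torch
from prediction_sinhala import MDFEND, TokenizerFromPreTrained

MODEL_SAVE_PATH = "models/last-epoch-model-2024-03-08-15_34_03_6.pth"
BERT_MODEL_NAME = 'sinhala-nlp/sinbert-sold-si'
DOMAIN_NUM = 3
MAX_LEN = 160

@st.cache_resource  # caches one instance per process; replaces st.cache(allow_output_mutation=True)
def load_model():
    tokenizer = TokenizerFromPreTrained(MAX_LEN, BERT_MODEL_NAME)
    model = MDFEND(BERT_MODEL_NAME, DOMAIN_NUM, expert_num=18,
                   mlp_dims=[5080, 4020, 3010, 2024, 1012, 606, 400])
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location="cpu"))
    model.eval()  # inference only; disables dropout
    return model, tokenizer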
@@ -18,8 +29,20 @@ text_input = st.text_area("Enter text here:")
 
 # Prediction
 if st.button("Predict"):
-
-
-
-
-
+    if text_input:  # Check if input is not empty
+        # Process the input text through the custom tokenizer
+        inputs = tokenizer.tokenize(text_input)
+
+        # Convert to tensor, add batch dimension, and send to same device as model
+        inputs = torch.tensor(inputs).unsqueeze(0).to(model.device)
+
+        with torch.no_grad():  # No gradient computation
+            # Get model prediction
+            output_prob = model.predict(inputs)
+
+        # Interpret the output probability
+        prediction = 1 if output_prob >= 0.5 else 0
+        result = "offensive" if prediction == 1 else "not offensive"
+        st.write(f"Prediction: {result}")
+    else:
+        st.error("Please enter some text to predict.")