Spaces:
Runtime error
Runtime error
harry-stark
commited on
Commit
·
17a8518
1
Parent(s):
a93f647
Added app files
Browse files- app.py +22 -0
- hf_model.py +16 -0
- requirements.txt +4 -0
- utils.py +17 -0
app.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Sequence
|
| 2 |
+
import streamlit as st
|
| 3 |
+
from hf_model import classifier_zero,load_model
|
| 4 |
+
from utils import plot_result
|
| 5 |
+
classifier=load_model()
|
| 6 |
+
if __name__ == '__main__':
|
| 7 |
+
st.header("Zero Shot Classification")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
sequence = st.text_area(label="Input Sequence")
|
| 12 |
+
labels = st.text_input('Possible topics (separated by `,`)', max_chars=1000)
|
| 13 |
+
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
| 14 |
+
if len(labels) == 0 or len(sequence) == 0:
|
| 15 |
+
st.write('Enter some text and at least one possible topic to see predictions.')
|
| 16 |
+
|
| 17 |
+
multi_class = st.checkbox('Allow multiple correct topics', value=True)
|
| 18 |
+
|
| 19 |
+
with st.spinner('Classifying...'):
|
| 20 |
+
top_topics, scores = classifier_zero(classifier,sequence=sequence,labels=labels,multi_class=multi_class)
|
| 21 |
+
plot_result(top_topics[::-1][-10:], scores[::-1][-10:])
|
| 22 |
+
|
hf_model.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification,pipeline
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def load_model():
|
| 5 |
+
|
| 6 |
+
model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
|
| 7 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 8 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
| 9 |
+
classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
|
| 10 |
+
return classifier
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def classifier_zero(classifier,sequence:str,labels:list,multi_class:bool):
|
| 14 |
+
outputs=classifier(sequence, labels,multi_label=multi_class)
|
| 15 |
+
return outputs['labels'], outputs['scores']
|
| 16 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers[sentencepiece]==4.11.0
|
| 2 |
+
streamlit
|
| 3 |
+
plotly
|
| 4 |
+
torch
|
utils.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import numpy as np
|
| 3 |
+
import plotly.express as px
|
| 4 |
+
def plot_result(top_topics, scores):
|
| 5 |
+
top_topics = np.array(top_topics)
|
| 6 |
+
scores = np.array(scores)
|
| 7 |
+
scores *= 100
|
| 8 |
+
fig = px.bar(x=scores, y=top_topics, orientation='h',
|
| 9 |
+
labels={'x': 'Confidence', 'y': 'Label'},
|
| 10 |
+
text=scores,
|
| 11 |
+
range_x=(0,115),
|
| 12 |
+
title='Top Predictions',
|
| 13 |
+
color=np.linspace(0,1,len(scores)),
|
| 14 |
+
color_continuous_scale='GnBu')
|
| 15 |
+
fig.update(layout_coloraxis_showscale=False)
|
| 16 |
+
fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
|
| 17 |
+
st.plotly_chart(fig)
|