Spaces:
Sleeping
Sleeping
File size: 2,084 Bytes
a452638 dbf4eb2 7c7d7c6 57fe04c dbf4eb2 7c7d7c6 dbf4eb2 57fe04c a452638 57fe04c 7c7d7c6 57fe04c aa13915 dbf4eb2 7c7d7c6 dbf4eb2 a452638 dbf4eb2 76b4c44 dbf4eb2 38f2cf5 a452638 dbf4eb2 38f2cf5 dbf4eb2 38f2cf5 7c7d7c6 dbf4eb2 38f2cf5 dbf4eb2 38f2cf5 a452638 38f2cf5 7c7d7c6 e6eb9cf 7c7d7c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from io import BytesIO
import pandas as pd
import streamlit as st
import tokenizers
import torch
from transformers import Pipeline, pipeline
# Page-level settings (browser tab title, wide layout). This is the first
# st.* call the script makes; icon and menu items are left at None.
st.set_page_config(
    page_title="Zero-shot classification from tabular data",
    layout="wide",
    initial_sidebar_state="auto",
    page_icon=None,
    menu_items=None,
)
@st.cache(
    allow_output_mutation=True,
    show_spinner=False,
    hash_funcs={
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None,
    },
)
def load_classifier() -> Pipeline:
    """Return a zero-shot classification pipeline using facebook/bart-large-mnli.

    The result is cached with ``st.cache``. The ``hash_funcs`` entries make
    torch parameters and tokenizer objects hash to ``None``, and the
    decorator's own spinner is disabled (the caller shows its own).
    """
    return pipeline(
        task="zero-shot-classification",
        model="facebook/bart-large-mnli",
    )
# Build (or fetch from load_classifier's st.cache) the inference pipeline,
# showing a spinner while it loads.
with st.spinner(text="Setting stuff up related to the inference engine..."):
    classifier = load_classifier()

st.title("Zero-shot classification from tabular data")
st.text(
    "Upload an Excel table and perform zero-shot classification on a set of custom labels"
)
data = st.file_uploader(
    "Upload Excel file (it should contain a column named `text` in its header):"
)
labels = st.text_input("Enter comma-separated labels:")

# Rows whose `text` has MIN_TEXT_LENGTH characters or fewer are skipped.
MIN_TEXT_LENGTH = 10
GENERIC_ERROR_MESSAGE = (
    "Something went wrong. Make sure you upload an Excel file containing a column "
    "named `text` and a set of comma-separated labels is provided"
)


def _parse_labels(raw_labels):
    """Split a comma-separated string into stripped, non-empty labels.

    >>> _parse_labels(" a, b ,, c,")
    ['a', 'b', 'c']
    """
    return [label.strip() for label in raw_labels.split(",") if label.strip()]


def _handle_submit(uploaded_file, raw_labels):
    """Label every sufficiently long row of the uploaded table and offer a download.

    Args:
        uploaded_file: value returned by ``st.file_uploader`` (``None`` when
            nothing has been uploaded).
        raw_labels: the comma-separated label string typed by the user.

    Problems are reported in the page via ``st.error``/``st.warning`` and
    ``st.exception`` instead of being raised.
    """
    candidate_labels = _parse_labels(raw_labels)
    if uploaded_file is None:
        st.error("Please upload an Excel file first.")
        return
    if not candidate_labels:
        st.error("Please enter at least one non-empty label.")
        return

    # Top-level UI boundary: report any failure (bad file, model error,
    # export error) to the user together with its details.
    try:
        table = pd.read_excel(uploaded_file)
        if "text" not in table.columns:
            st.error("The uploaded table has no column named `text` in its header.")
            return

        # Blank cells are read as NaN (a float with no len()); treat them as "",
        # and coerce any non-string cells so every row can be measured.
        texts = table["text"].fillna("").astype(str)
        table = table.assign(text=texts).loc[texts.str.len() > MIN_TEXT_LENGTH]
        table = table.reset_index(drop=True)
        if table.empty:
            st.warning(
                f"No rows with more than {MIN_TEXT_LENGTH} characters "
                "in the `text` column were found."
            )
            return

        progress_bar = st.progress(0)
        total = len(table)
        predictions = []
        for done, text in enumerate(table["text"], start=1):
            # Element [0] of "labels" is kept as the prediction; this relies on
            # the pipeline listing labels best-first — confirm on upgrades.
            predictions.append(classifier(text, candidate_labels)["labels"][0])
            progress_bar.progress(done / total)
        table["label"] = predictions

        result = table[["text", "label"]]
        st.table(result)
        buffer = BytesIO()
        result.to_excel(buffer)
        st.download_button(
            label="Download table", data=buffer.getvalue(), file_name="output.xlsx"
        )
    except Exception as err:
        st.error(GENERIC_ERROR_MESSAGE)
        st.exception(err)


if st.button("Calculate labels"):
    _handle_submit(data, labels)
|