File size: 2,084 Bytes
a452638
 
dbf4eb2
7c7d7c6
57fe04c
 
 
dbf4eb2
7c7d7c6
 
 
 
 
 
 
dbf4eb2
57fe04c
 
 
 
 
 
 
 
a452638
57fe04c
 
7c7d7c6
57fe04c
 
 
aa13915
 
dbf4eb2
 
7c7d7c6
 
 
dbf4eb2
a452638
 
 
dbf4eb2
 
76b4c44
 
dbf4eb2
 
38f2cf5
 
 
a452638
dbf4eb2
38f2cf5
 
dbf4eb2
38f2cf5
 
7c7d7c6
dbf4eb2
38f2cf5
dbf4eb2
38f2cf5
 
a452638
 
 
 
 
 
 
38f2cf5
7c7d7c6
e6eb9cf
7c7d7c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from io import BytesIO

import pandas as pd
import streamlit as st
import tokenizers
import torch
from transformers import Pipeline, pipeline

# Browser-tab title and page layout for the app. This is deliberately the
# first Streamlit call in the script: Streamlit expects set_page_config to
# run before any other st.* command.
st.set_page_config(
    layout="wide",
    page_title="Zero-shot classification from tabular data",
    initial_sidebar_state="auto",
    page_icon=None,
    menu_items=None,
)


# Choose a caching decorator for the model. `st.cache_resource` (newer
# Streamlit) keeps one shared instance of an unhashable resource without
# hashing it. On older Streamlit releases that lack it, fall back to the
# deprecated `st.cache`. There, the hash_funcs map torch/tokenizers objects to
# a constant so st.cache does not try to hash them, and allow_output_mutation
# lets the cached pipeline object be returned as-is.
if hasattr(st, "cache_resource"):
    _cache_classifier = st.cache_resource(show_spinner=False)
else:
    _cache_classifier = st.cache(
        hash_funcs={
            torch.nn.parameter.Parameter: lambda _: None,
            tokenizers.Tokenizer: lambda _: None,
            tokenizers.AddedToken: lambda _: None,
        },
        allow_output_mutation=True,
        show_spinner=False,
    )


@_cache_classifier
def load_classifier() -> Pipeline:
    """Build (once per server process, via the cache) the zero-shot classifier.

    Returns:
        A transformers ``zero-shot-classification`` pipeline backed by the
        ``facebook/bart-large-mnli`` checkpoint.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


with st.spinner(text="Setting stuff up related to the inference engine..."):
    classifier = load_classifier()

st.title("Zero-shot classification from tabular data")
st.text(
    "Upload an Excel table and perform zero-shot classification on a set of custom labels"
)

data = st.file_uploader(
    "Upload Excel file (it should contain a column named `text` in its header):"
)
labels = st.text_input("Enter comma-separated labels:")

# Rows whose `text` has this many characters or fewer are dropped before
# inference; every remaining row is classified (there is no row cap).
MIN_TEXT_LENGTH = 10
XLSX_MIME_TYPE = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
INPUT_ERROR_MESSAGE = (
    "Something went wrong. Make sure you upload an Excel file containing a column "
    "named `text` and a set of comma-separated labels is provided"
)

if st.button("Calculate labels"):
    # Split on commas, trim surrounding whitespace and drop empty entries,
    # e.g. "sports, politics,," -> ["sports", "politics"].
    labels_list = [label.strip() for label in labels.split(",") if label.strip()]

    if data is None or not labels_list:
        # No file uploaded or no usable labels: nothing to classify, so report
        # it up front instead of failing inside pandas/transformers.
        st.error(INPUT_ERROR_MESSAGE)
    else:
        try:
            table = pd.read_excel(data)
            # Empty cells are read as NaN and numeric cells as numbers; coerce
            # the column to str so the length filter and the classifier accept
            # every row instead of aborting on the first non-string cell.
            table["text"] = table["text"].fillna("").astype(str)
            table = table.loc[table["text"].str.len() > MIN_TEXT_LENGTH].reset_index(
                drop=True
            )

            prog_bar = st.progress(0)
            n_rows = len(table)
            preds = []
            for done, text in enumerate(table["text"], start=1):
                # The pipeline returns candidate labels ranked by score; keep
                # the top-ranked one as this row's prediction.
                preds.append(classifier(text, labels_list)["labels"][0])
                prog_bar.progress(done / n_rows)
            table["label"] = preds

            result = table[["text", "label"]]
            buf = BytesIO()
            result.to_excel(buf)
        except Exception as err:
            # Top-level UI boundary: show the generic hint plus the actual
            # error. Catching Exception (not a bare except) lets
            # KeyboardInterrupt/SystemExit and other BaseException-derived
            # control flow propagate.
            st.error(INPUT_ERROR_MESSAGE)
            st.exception(err)
        else:
            st.table(result)
            st.download_button(
                label="Download table",
                data=buf.getvalue(),
                file_name="output.xlsx",
                mime=XLSX_MIME_TYPE,
            )