Spaces:
Runtime error
Runtime error
File size: 4,364 Bytes
3ec6971 41386c7 3ec6971 65e4324 3ec6971 29de8f2 3ec6971 65e4324 3ec6971 29de8f2 3ec6971 e4016f5 3ec6971 0f63c19 65e4324 3ec6971 a28b188 3ec6971 e4016f5 9fd37ff 3ec6971 a28b188 0b51283 a28b188 3ec6971 a28b188 0b51283 a28b188 3ec6971 0b51283 3ec6971 0f63c19 3ec6971 0f63c19 3ec6971 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer, DebertaV2Tokenizer, DebertaV2Model
import sentencepiece
import streamlit as st
import pandas as pd
import spacy
from spacy import displacy
import plotly.express as px
import numpy as np
# Sample news passages offered when the user picks "Select from examples".
example_list = [
    (
        "Hong Kong’s two-week flight ban has dashed the hopes of those planning "
        "family reunions as well as disrupted plans for incoming domestic helpers, "
        "with the Philippines, Britain and the United States among eight countries "
        "hit with tightened rules aimed at containing a Covid-19 surge."
    ),
    (
        "From Friday (Jan 7), all bars and entertainment venues will close for two "
        "weeks, and restaurants have to stop dine-in after 6pm, Chief Executive "
        "Carrie Lam Cheng Yuet-ngor announced on Wednesday. "
    ),
]
# Page configuration must be the first Streamlit call in the script.
st.set_page_config(layout="wide", page_title="Vocabulary Categorizer")
st.title("Vocabulary Categorizer")
st.write("This application identifies, highlights and categorizes nouns.")

# Hugging Face checkpoints the user can choose between.
model_list = ['xlm-roberta-large-finetuned-conll03-english', 'xlm-roberta-large']
st.sidebar.header("Vocabulary Categorizer")
model_checkpoint = st.sidebar.radio("", model_list)
st.sidebar.write("Which model highlights the most nouns? Which model highlights nouns the most accurately?")
st.sidebar.write("")

# Every entry in model_list is an XLM-RoBERTa checkpoint offering the same
# aggregation choices, so one selector serves all of them. A single
# unconditional selector also guarantees `aggregation` is always bound, even
# if a new checkpoint is added to model_list (per-model branches previously
# left it undefined for any unlisted model).
xlm_agg_strategy_info = "'aggregation_strategy' can be selected as 'simple' or 'none' for 'xlm-roberta'."
st.sidebar.header("Select Aggregation Strategy Type")
aggregation = st.sidebar.radio("", ('simple', 'none'))
st.sidebar.write(xlm_agg_strategy_info)
st.sidebar.write("")

st.subheader("Select Text Input Method")
input_method = st.radio("", ('Select from examples', 'Write or paste text'))
if input_method == 'Select from examples':
    selected_text = st.selectbox('Select example from list', example_list, index=0, key=1)
    st.subheader("Text to Run")
    # Pre-filled with the chosen example; the user may still edit it.
    input_text = st.text_area("Selected example", selected_text, height=128, max_chars=None, key=2)
else:  # 'Write or paste text' -- the only other radio option
    st.subheader("Text Input")
    input_text = st.text_area('Write or paste text below', value="", height=128, max_chars=None, key=2)
@st.cache(allow_output_mutation=True)
def setModel(model_checkpoint, aggregation):
    """Build a Hugging Face NER pipeline for *model_checkpoint*.

    The result is memoized by ``st.cache`` per (checkpoint, aggregation) pair,
    so reruns of the Streamlit script reuse the already-loaded model.
    ``allow_output_mutation=True`` tells Streamlit not to hash the returned
    pipeline object.

    Args:
        model_checkpoint: Hugging Face model identifier used for both the
            token-classification model and its tokenizer.
        aggregation: value passed through as the pipeline's
            ``aggregation_strategy`` (this app offers 'simple' or 'none').

    Returns:
        A ``transformers`` pipeline for the 'ner' task.
    """
    token_model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
    token_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    return pipeline(
        'ner',
        model=token_model,
        tokenizer=token_tokenizer,
        aggregation_strategy=aggregation,
    )
def get_html(html: str) -> str:
    """Wrap rendered displaCy markup in a bordered, horizontally scrollable div.

    Newlines in *html* are replaced with spaces. This is presumably so that
    Streamlit's Markdown renderer treats the markup as one HTML block
    (TODO confirm).

    This is a cheap, pure string operation, so it is deliberately NOT wrapped
    in ``st.cache``. An unbounded cache keyed on every distinct rendered
    document would only grow memory without saving any meaningful work.

    Args:
        html: HTML fragment, e.g. the output of ``displacy.render``.

    Returns:
        The fragment, with newlines flattened, inside a styled ``<div>``.
    """
    wrapper = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
    return wrapper.format(html.replace("\n", " "))
def _display_label(label):
    """Strip a BIO prefix ("B-"/"I-") so token-level labels match displaCy's entity filter."""
    # With aggregation_strategy='none' the pipeline reports raw per-token
    # labels such as "I-PER"; displaCy only highlights labels listed in
    # options["ents"] (PER/LOC/ORG/MISC), so unprefixed names are required.
    if label[:2] in ("B-", "I-"):
        return label[2:]
    return label


Run_Button = st.button("Run", key=None)
if Run_Button:
    if not input_text.strip():
        # Nothing to analyze; avoid loading/running the model for no output.
        st.warning("Please enter some text to categorize.")
    else:
        ner_pipeline = setModel(model_checkpoint, aggregation)
        output = ner_pipeline(input_text)

        # Aggregated results carry their label under 'entity_group';
        # token-level ('none') results carry it under 'entity'.
        label_key = "entity_group" if aggregation != "none" else "entity"
        cols_to_keep = ['word', label_key, 'score', 'start', 'end']
        # columns= selects these keys from each result dict and still yields a
        # well-formed (empty) table when no entities were found; selecting
        # columns from an empty DataFrame would raise KeyError.
        df_final = pd.DataFrame(output, columns=cols_to_keep)
        st.subheader("Categorized Nouns")
        st.dataframe(df_final)

        st.subheader("Highlighted Nouns")
        # Manual-mode input for displaCy: the raw text plus character spans.
        spacy_display = {
            "ents": [
                {
                    "start": entity["start"],
                    "end": entity["end"],
                    "label": _display_label(entity[label_key]),
                }
                for entity in output
            ],
            "text": input_text,
            "title": None,
        }
        entity_list = ["PER", "LOC", "ORG", "MISC"]
        colors = {'PER': '#85DCDF', 'LOC': '#DF85DC', 'ORG': '#DCDF85', 'MISC': '#85ABDF'}
        html = displacy.render(spacy_display, style="ent", minify=True, manual=True, options={"ents": entity_list, "colors": colors})
        # Keep each highlighted entity on one line inside the wrapper div.
        style = "<style>mark.entity { display: inline-block }</style>"
        st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)