Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BartTokenizer, BartForConditionalGeneration, pipeline | |
import numpy as np | |
import torch | |
import re | |
from textstat import textstat | |
# Generation settings read by generate_candidate_text.
MAX_LEN = 256
NUM_BEAMS = 4
EARLY_STOPPING = True
N_OUT = 4

# Hugging Face Hub checkpoints loaded below.
_CWI_CHECKPOINT = 'twigs/cwi-regressor'
_SIMPL_CHECKPOINT = 'twigs/bart-text2text-simplifier'

# Complex-word-identification model (its raw scores are thresholded in
# id_replace_complex).
cwi_tok = AutoTokenizer.from_pretrained(_CWI_CHECKPOINT)
cwi_model = AutoModelForSequenceClassification.from_pretrained(_CWI_CHECKPOINT)

# BART simplifier, used both for sentence generation and for mask filling.
simpl_tok = BartTokenizer.from_pretrained(_SIMPL_CHECKPOINT)
simpl_model = BartForConditionalGeneration.from_pretrained(_SIMPL_CHECKPOINT)

# function_to_apply='none' keeps the model's raw output as 'score' instead of
# squashing it through a sigmoid/softmax.
cwi_pipe = pipeline(
    'text-classification',
    model=cwi_model,
    tokenizer=cwi_tok,
    function_to_apply='none',
)
# top_k=1: callers only read the single best filler ([0]['sequence']).
fill_pipe = pipeline(
    'fill-mask',
    model=simpl_model,
    tokenizer=simpl_tok,
    top_k=1,
)
_WORD_RE = re.compile(r"\w+")


def id_replace_complex(s, threshold=0.4, cwi=None, fill=None):
    """Find complex words in ``s`` and replace each with a mask-filled word.

    Every word token is scored by the complex-word-identification model;
    tokens scoring at least ``threshold`` are masked one at a time (first
    whole-word occurrence) and replaced with the first prediction of the
    fill-mask model.

    Args:
        s: Input sentence.
        threshold: Minimum CWI score for a token to be treated as complex.
        cwi: Callable taking a list of strings and returning one dict with a
            ``'score'`` key per string. Defaults to the module's ``cwi_pipe``.
        fill: Callable taking a string containing ``<mask>`` and returning a
            list whose first item has a ``'sequence'`` key. Defaults to the
            module's ``fill_pipe``.

    Returns:
        ``(new_sentence, complex_tokens)`` where ``complex_tokens`` lists the
        tokens that met the threshold, in sentence order.
    """
    cwi = cwi_pipe if cwi is None else cwi
    fill = fill_pipe if fill is None else fill

    tokens = _WORD_RE.findall(s)
    if not tokens:
        return s, []

    # One candidate per token: "<token>. <sentence>" (presumably the input
    # format the CWI checkpoint was trained on -- confirm).
    cands = [f"{t}. {s}" for t in tokens]
    compl_tok = [t for t, res in zip(tokens, cwi(cands))
                 if res['score'] >= threshold]

    # Sequential on purpose: each fill sees the previous replacements.
    for t in compl_tok:
        # Whole-word match only; a plain substring search would hit the token
        # inside longer words (e.g. "a" inside "cat").
        match = re.search(rf"\b{re.escape(t)}\b", s)
        if match is None:
            # An earlier fill rewrote the text and the word is gone; skip it
            # instead of raising.
            continue
        s = s[:match.start()] + '<mask>' + s[match.end():]
        s = fill(s)[0]['sequence']
    return s, compl_tok
def generate_candidate_text(s, model, tokenizer, tokenized=False):
    """Generate beam-search rewrites of ``s`` with a seq2seq model.

    Args:
        s: Raw input text, or -- when ``tokenized`` is True -- an already
            encoded mapping providing ``'input_ids'`` and ``'attention_mask'``.
        model: Model exposing ``generate`` and ``config.pad_token_id``.
        tokenizer: Tokenizer used to encode ``s`` (when not pre-tokenized)
            and to decode every generated sequence.
        tokenized: Whether ``s`` is already encoded.

    Returns:
        A list of decoded strings, ``N_OUT`` per input sequence.
    """
    if tokenized:
        out = s
    else:
        # Use the tokenizer argument (not a module global) so any
        # model/tokenizer pair passed in is encoded consistently.
        out = tokenizer([s], max_length=MAX_LEN, padding="max_length",
                        truncation=True, return_tensors='pt')
    generated_ids = model.generate(
        input_ids=out['input_ids'],
        attention_mask=out['attention_mask'],
        use_cache=True,
        # Taken from the model being run, not from a module-level model.
        decoder_start_token_id=model.config.pad_token_id,
        num_beams=NUM_BEAMS,
        max_length=MAX_LEN,
        early_stopping=EARLY_STOPPING,
        num_return_sequences=N_OUT,
    )
    # NOTE(review): [1:] drops the first decoded character, presumably a
    # leading space the simplifier emits -- confirm; a real first letter
    # would be lost.
    return [
        tokenizer.decode(ids, skip_special_tokens=True,
                         clean_up_tokenization_spaces=True)[1:]
        for ids in generated_ids
    ]
def rank_candidate_text(sentences):
    """Pick the candidate with the lowest Flesch-Kincaid grade level.

    Ties resolve to the earliest candidate (``np.argmin`` semantics); an
    empty ``sentences`` raises ``ValueError`` from ``np.argmin``.
    """
    grades = list(map(textstat.flesch_kincaid_grade, sentences))
    easiest = np.argmin(grades)
    return sentences[easiest]
def full_pipeline(source, simpl_model, simpl_tok, tokens, lexical=False):
    """Simplify ``source`` and return the most readable candidate.

    Args:
        source: Sentence to simplify.
        simpl_model: Seq2seq model handed to ``generate_candidate_text``.
        simpl_tok: Tokenizer matching ``simpl_model``.
        tokens: Control-token prefix prepended verbatim to the sentence
            (``main`` builds it with a trailing space).
        lexical: If True, first replace words flagged as complex
            (threshold 0.2) via ``id_replace_complex``.

    Returns:
        ``(output_sentence, complex_words)``; ``complex_words`` is None when
        ``lexical`` is False.
    """
    # Separate branches on purpose: `a, b = f() if c else x, None` parses as
    # `a, b = (f() if c else x), None`, which breaks the lexical path.
    if lexical:
        modified, complex_words = id_replace_complex(source, threshold=0.2)
    else:
        modified, complex_words = source, None
    cands = generate_candidate_text(tokens + modified, simpl_model, simpl_tok)
    return rank_candidate_text(cands), complex_words
def main():
    """Render the Streamlit UI and run the simplifier when the form is submitted."""
    aug_tok = ['c_', 'lev_', 'dep_', 'rank_', 'rat_', 'n_syl_']
    base_tokens = ['CharRatio', 'LevSim', 'DependencyTreeDepth',
                   'WordComplexity', 'WordRatio', 'NumberOfSyllables']
    default_values = [0.8, 0.6, 0.9, 0.8, 0.9, 1.9]
    # Control name -> prompt prefix, and control name -> default intensity.
    prefix_of = dict(zip(base_tokens, aug_tok))
    tok_values = dict(zip(base_tokens, default_values))
    example_sentences = ["A matchbook is a small cardboard folder (matchcover) enclosing a quantity of matches and having a coarse striking surface on the exterior.",
                         "If there are no strong land use controls, buildings are built along a bypass, converting it into an ordinary town road, and the bypass may eventually become as congested as the local streets it was intended to avoid.",
                         "Plot Captain Caleb Holt (Kirk Cameron) is a firefighter in Albany, Georgia and firmly keeps the cardinal rule of all firemen, \"Never leave your partner behind\".",
                         "Britpop emerged from the British independent music scene of the early 1990s and was characterised by bands influenced by British guitar pop music of the 1960s and 1970s."]

    st.title("Make it Simple")
    with st.expander("Example sentences"):
        for s in example_sentences:
            st.code(body=s)

    with st.form(key="simplify"):
        input_sentence = st.text_area("Original sentence")
        tok = st.multiselect(
            label="Tokens to augment the sentence", options=base_tokens, default=base_tokens)
        # Keyed by control name so deselecting or reordering tokens in the
        # multiselect cannot shift a value onto the wrong control.
        user_values = {}
        if tok:
            st.text("Select the desired intensity")
            for t in tok:
                default = tok_values[t]
                # Streamlit rejects a default above max_value, and the
                # NumberOfSyllables default (1.9) exceeds 1.0.
                user_values[t] = st.slider(
                    t, min_value=0., max_value=max(1., default), value=default,
                    step=0.1, key=t)
        submit = st.form_submit_button("Process")

    if submit:
        if not input_sentence.strip():
            st.warning("Please enter a sentence to simplify.")
            return
        # Emit only the selected controls, in the fixed base_tokens order.
        # Round to 2 places to strip float noise (e.g. 0.7000000000000001).
        tokens = "".join(f"{prefix_of[t]}{round(user_values[t], 2)} "
                         for t in base_tokens if t in user_values)
        output, words = full_pipeline(input_sentence, simpl_model, simpl_tok, tokens)
        c1, c2 = st.columns([1, 2])
        with c1:
            st.markdown("#### Words identified as complex")
            if words:
                for w in words:
                    st.markdown(f"* {w}")
            else:
                st.markdown("None :smile:")
        with c2:
            st.markdown(f"#### Original Sentence:\n > {input_sentence}")
            st.markdown(f"#### Output Sentence:\n > {output}")


if __name__ == '__main__':
    main()