# workers/worker.py
import streamlit as st
import tokenizers
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch

import nltk

nltk.download("punkt")  # fetch the Punkt sentence-tokenizer data once, at import time
from nltk.tokenize import sent_tokenize


class WorkerClassifier:
    def __init__(
        self, worker_model_dir, zero_shot_model_type="facebook/bart-large-mnli"
    ):
        self.zero_shot = None
        self.zero_shot_model_type = zero_shot_model_type
        self.worker_model_dir = worker_model_dir
        # mapping between class indices and worker labels
        self.id2label = {
            0: "lauren",
            1: "betty",
            2: "doris",
            3: "hailey",
        }
        self.label2id = {v: k for k, v in self.id2label.items()}

    def init_models(self):
        self.ner = self.init_anonymizer()
        self.zero_shot = self.init_zero_shot()
        self.worker_model = self.init_worker_model()
        self.worker_tokenizer = self.init_worker_tokenizer()

    # st.cache cannot hash torch parameters or tokenizer objects,
    # so hash_funcs maps those types to a constant (None)
    @st.cache(
        hash_funcs={
            torch.nn.parameter.Parameter: lambda _: None,
            tokenizers.Tokenizer: lambda _: None,
            tokenizers.AddedToken: lambda _: None,
        },
        allow_output_mutation=True,
    )
    def init_worker_tokenizer(self):
        return AutoTokenizer.from_pretrained(self.worker_model_dir)

    @st.cache(
        hash_funcs={
            torch.nn.parameter.Parameter: lambda _: None,
            tokenizers.Tokenizer: lambda _: None,
            tokenizers.AddedToken: lambda _: None,
        },
        allow_output_mutation=True,
    )
    def init_worker_model(self):
        return AutoModelForSequenceClassification.from_pretrained(
            self.worker_model_dir, problem_type="multi_label_classification"
        )

    def predict_worker(self, text, threshold=0.5):
        encoding = self.worker_tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():  # inference only; skip building the autograd graph
            outputs = self.worker_model(**encoding)
        logits = outputs["logits"]
        # apply sigmoid, then threshold each label independently (multi-label)
        probs = torch.sigmoid(logits.squeeze().cpu())
        predictions = np.zeros(probs.shape)
        predictions[np.where(probs >= threshold)] = 1
        # turn predicted ids into (label name, probability) pairs
        predicted_labels = [
            [self.id2label[idx], probs[idx].item()]
            for idx, label in enumerate(predictions)
            if label == 1.0
        ]
        return predicted_labels

    @st.cache(allow_output_mutation=True)
    def init_anonymizer(self):
        return pipeline(task="ner")

    def anonymize(self, text: str):
        # replace every PERSON entity with a placeholder, sentence by sentence
        new_sentences = []
        sentences = sent_tokenize(text)
        for sent in sentences:
            result = self.ner(sent, aggregation_strategy="simple")
            # edit from the end of the sentence so earlier character offsets stay valid
            for r in reversed(result):
                if r["entity_group"] == "PER":
                    sent = sent[: r["start"]] + "PERSON" + sent[r["end"] :]
            new_sentences.append(sent)
        return " ".join(new_sentences)

    @st.cache(
        hash_funcs={
            tokenizers.Tokenizer: lambda _: None,
            tokenizers.AddedToken: lambda _: None,
            torch.nn.parameter.Parameter: lambda parameter: parameter.data.numpy(),
        },
        allow_output_mutation=True,
    )
    def init_zero_shot(self):
        return pipeline(
            task="zero-shot-classification", model=self.zero_shot_model_type
        )

    def get_personality_sentences(self, text):
        # keep only the sentences the zero-shot model flags as personality-related
        new_sentences = []
        sentences = sent_tokenize(text)
        for sent in sentences:
            if self.personality_sent_classifier(sent):
                new_sentences.append(sent)
        return " ".join(new_sentences)

    def personality_sent_classifier(self, text, threshold=0.8):
        candidate_labels = ["describing a personality trait."]
        hypothesis_template = "This example is {}"
        output = self.zero_shot(
            text,
            candidate_labels=candidate_labels,
            hypothesis_template=hypothesis_template,
        )
        # only one candidate label, so its score is always first
        return output["scores"][0] > threshold

    def predict(self, text):
        # first extract the sentences that are relevant to personalities
        extracted_text = self.get_personality_sentences(text)
        # next anonymize the extracted sentences
        anonymized_text = self.anonymize(extracted_text)
        # finally classify the anonymized text
        predictions = self.predict_worker(anonymized_text)
        return extracted_text, predictions
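

# A minimal usage sketch, not part of the original file: it assumes a fine-tuned
# multi-label checkpoint exists at "models/worker" -- that directory name and the
# sample text are placeholders, and loading the NER and zero-shot pipelines
# requires network access on first run.
if __name__ == "__main__":
    classifier = WorkerClassifier(worker_model_dir="models/worker")
    classifier.init_models()
    extracted, predictions = classifier.predict(
        "Doris is meticulous and never misses a deadline. She also loves coffee."
    )
    print("personality sentences:", extracted)
    # predictions is a list of [label, probability] pairs, e.g. [["doris", 0.91]]
    print("predicted workers:", predictions)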