|
import streamlit as st |
|
import tokenizers |
|
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification |
|
import numpy as np |
|
import torch |
|
import nltk |
|
|
|
# Fetch the Punkt sentence-boundary model at import time so that
# sent_tokenize() below works on first use (no-op if already cached
# locally; requires network access on a cold start).
nltk.download("punkt")

from nltk.tokenize import sent_tokenize
|
|
|
|
|
class WorkerClassifier:
    """Multi-label classifier mapping free text to worker names.

    Pipeline (see predict()): keep only sentences that describe a
    personality trait (zero-shot NLI), replace person names with
    "PERSON" (NER), then score the anonymized text against each worker
    label with a fine-tuned multi-label sequence-classification model.
    """

    def __init__(
        self, worker_model_dir, zero_shot_model_type="facebook/bart-large-mnli"
    ):
        # Heavy models are loaded lazily via init_models(); calling
        # predict()/anonymize() before init_models() raises AttributeError.
        self.zero_shot = None
        self.zero_shot_model_type = zero_shot_model_type
        self.worker_model_dir = worker_model_dir
        # Logit index <-> worker-name mapping used to decode model output.
        self.id2label = {
            0: "lauren",
            1: "betty",
            2: "doris",
            3: "hailey",
        }
        self.label2id = {v: k for k, v in self.id2label.items()}

    def init_models(self):
        """Load all models: NER, zero-shot, worker classifier and tokenizer."""
        self.ner = self.init_anonymizer()
        self.zero_shot = self.init_zero_shot()
        self.worker_model = self.init_worker_model()
        self.worker_tokenizer = self.init_worker_tokenizer()

    # NOTE(review): st.cache is deprecated in recent Streamlit releases
    # (st.cache_resource is the replacement for model objects) -- kept
    # as-is because the pinned Streamlit version is not visible here.
    @st.cache(
        hash_funcs={
            # Torch/tokenizers objects are unhashable for Streamlit's
            # cache key; skip them when hashing.
            torch.nn.parameter.Parameter: lambda _: None,
            tokenizers.Tokenizer: lambda _: None,
            tokenizers.AddedToken: lambda _: None,
        },
        allow_output_mutation=True,
    )
    def init_worker_tokenizer(self):
        """Return the tokenizer matching the fine-tuned worker model."""
        return AutoTokenizer.from_pretrained(self.worker_model_dir)

    @st.cache(
        hash_funcs={
            torch.nn.parameter.Parameter: lambda _: None,
            tokenizers.Tokenizer: lambda _: None,
            tokenizers.AddedToken: lambda _: None,
        },
        allow_output_mutation=True,
    )
    def init_worker_model(self):
        """Load the fine-tuned multi-label worker classification model."""
        return AutoModelForSequenceClassification.from_pretrained(
            self.worker_model_dir, problem_type="multi_label_classification"
        )

    def predict_worker(self, text, threshold=0.5):
        """Score *text* against every worker label.

        Args:
            text: input string for the classifier.
            threshold: minimum per-label sigmoid probability to keep.

        Returns:
            List of ``[label, probability]`` pairs for every label whose
            probability is >= ``threshold``.
        """
        encoding = self.worker_tokenizer(text, return_tensors="pt")
        # Inference only: disable autograd to avoid building a graph.
        with torch.no_grad():
            outputs = self.worker_model(**encoding)
        # Multi-label setup: an independent sigmoid per label, not softmax.
        probs = torch.sigmoid(outputs["logits"].squeeze().cpu())
        return [
            [self.id2label[idx], prob.item()]
            for idx, prob in enumerate(probs)
            if prob >= threshold
        ]

    @st.cache(allow_output_mutation=True)
    def init_anonymizer(self):
        """Load the default Hugging Face NER pipeline."""
        return pipeline(task="ner")

    def anonymize(self, text: str):
        """Replace each person name detected by the NER model with "PERSON"."""
        new_sentences = []
        for sent in sent_tokenize(text):
            entities = self.ner(sent, aggregation_strategy="simple")
            # Splice right-to-left so earlier character offsets remain
            # valid after each replacement.
            for ent in reversed(entities):
                if ent["entity_group"] == "PER":
                    sent = sent[: ent["start"]] + "PERSON" + sent[ent["end"] :]
            new_sentences.append(sent)

        return " ".join(new_sentences)

    @st.cache(
        hash_funcs={
            tokenizers.Tokenizer: lambda _: None,
            tokenizers.AddedToken: lambda _: None,
            torch.nn.parameter.Parameter: lambda parameter: parameter.data.numpy(),
        },
        allow_output_mutation=True,
    )
    def init_zero_shot(self):
        """Load the zero-shot classification pipeline used as sentence filter."""
        return pipeline(
            task="zero-shot-classification", model=self.zero_shot_model_type
        )

    def get_personality_sentences(self, text):
        """Return only the sentences of *text* that describe a personality trait."""
        kept = [
            sent
            for sent in sent_tokenize(text)
            if self.personality_sent_classifier(sent)
        ]
        return " ".join(kept)

    def personality_sent_classifier(self, text, threshold=0.8):
        """Return True when the zero-shot model scores *text* as a
        personality-trait description with confidence above *threshold*."""
        candidate_labels = ["describing a personality trait."]
        hypothesis_template = "This example is {}"

        output = self.zero_shot(
            text,
            candidate_labels=candidate_labels,
            hypothesis_template=hypothesis_template,
        )

        # With a single candidate label, scores[0] is that label's score.
        return output["scores"][0] > threshold

    def predict(self, text):
        """Run the full pipeline: filter -> anonymize -> classify.

        Returns:
            Tuple ``(extracted_text, predictions)``: the personality-related
            sentences before anonymization, and the ``[label, probability]``
            pairs from predict_worker().
        """
        extracted_text = self.get_personality_sentences(text)
        anonymized = self.anonymize(extracted_text)
        predictions = self.predict_worker(anonymized)
        return extracted_text, predictions
|
|