|
print("Loading Multi head pipeline") |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline
from transformers.pipelines import PIPELINE_REGISTRY
|
|
|
class CustomTextClassificationPipeline(TextClassificationPipeline):
    """Multi-head text-classification pipeline returning one label index per head."""

    def __init__(self, model, tokenizer=None, **kwargs):
        # Fall back to the tokenizer that matches the model checkpoint when none is supplied.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
|
|
|
    def _sanitize_parameters(self, **kwargs):
        # No extra parameters are forwarded to preprocess, _forward, or postprocess.
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}
|
|
|
    def preprocess(self, inputs):
        # Tokenize the raw text; truncation is disabled, so inputs longer than the
        # model's maximum sequence length are passed through unchanged.
        return self.tokenizer(inputs, return_tensors="pt", truncation=False)
|
|
|
    def _forward(self, model_inputs):
        input_ids = model_inputs["input_ids"]
        # Prefer the attention mask produced by the tokenizer; otherwise rebuild it,
        # assuming a padding token id of 0.
        attention_mask = model_inputs.get("attention_mask")
        if attention_mask is None:
            attention_mask = (input_ids != 0).long()
        return self.model(input_ids=input_ids, attention_mask=attention_mask)
|
|
|
    def postprocess(self, model_outputs):
        # One argmax per classification head; squeeze() drops the batch dimension
        # for single-example input.
        predictions = model_outputs.logits.argmax(dim=-1).squeeze().tolist()
        categories = ["Race/Origin", "Gender/Sex", "Religion", "Ability", "Violence", "Other"]
        return dict(zip(categories, predictions))
|
|
|
|
|
PIPELINE_REGISTRY.register_pipeline(
    "multi-head-text-classification",
    pipeline_class=CustomTextClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
)
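
# Example usage (a minimal sketch; "path/to/multi-head-checkpoint" is a placeholder
# for an actual fine-tuned multi-head model, not a real model id, and this module
# must have been imported so the task is registered):
#
#     from transformers import pipeline
#
#     classifier = pipeline(
#         "multi-head-text-classification",
#         model="path/to/multi-head-checkpoint",
#     )
#     result = classifier("Some text to classify.")
#     # result maps each category to its predicted label index,
#     # e.g. {"Race/Origin": 0, "Gender/Sex": 0, ...}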