print("Loading Multi head pipeline")

from transformers.pipelines import PIPELINE_REGISTRY
from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification


class CustomTextClassificationPipeline(TextClassificationPipeline):
    """Text-classification pipeline that returns one predicted class index per category.

    The model's logits are reduced with ``argmax`` over the last axis, and the
    resulting indices are paired, in order, with ``CATEGORIES``.
    """

    # Category names, paired positionally with the model's predictions.
    # NOTE(review): the ordering must match the model's head ordering, which is
    # not visible here — confirm against the model's training setup.
    CATEGORIES = ["Race/Origin", "Gender/Sex", "Religion", "Ability", "Violence", "Other"]

    def __init__(self, model, tokenizer=None, **kwargs):
        """Build the pipeline, loading a tokenizer from the model's config if none is given.

        Args:
            model: The loaded model. Its ``config._name_or_path`` is used to
                locate a tokenizer when ``tokenizer`` is ``None``.
            tokenizer: Optional tokenizer instance.
            **kwargs: Forwarded unchanged to ``TextClassificationPipeline``.
        """
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty keyword dicts for preprocess, forward and postprocess.

        Every call-time keyword argument is ignored by this pipeline.
        """
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs):
        """Tokenize ``inputs`` into PyTorch tensors, without truncation."""
        # NOTE(review): with truncation disabled, inputs longer than the model's
        # maximum sequence length are not shortened and may fail inside the
        # model — confirm this is intended.
        return self.tokenizer(inputs, return_tensors="pt", truncation=False)

    def _forward(self, model_inputs):
        """Run the model on the tokenized inputs and return its raw outputs.

        The attention mask produced by the tokenizer is used when present.
        Otherwise a mask is derived from ``input_ids`` by excluding positions
        equal to the tokenizer's pad token id (0 if the tokenizer defines none).
        """
        input_ids = model_inputs["input_ids"]
        # Prefer the tokenizer's own mask: hard-coding "pad id == 0" masks the
        # wrong tokens for vocabularies where id 0 is not padding.
        attention_mask = model_inputs.get("attention_mask")
        if attention_mask is None:
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = 0
            attention_mask = (input_ids != pad_token_id).long()
        return self.model(input_ids=input_ids, attention_mask=attention_mask)

    def postprocess(self, model_outputs):
        """Return a dict mapping each category name to its predicted class index."""
        # NOTE(review): assumes logits for a single input are shaped
        # (1, num_categories, num_classes), so that argmax + squeeze leaves one
        # index per category — confirm against the model's output head.
        predictions = model_outputs.logits.argmax(dim=-1).squeeze().tolist()
        return dict(zip(self.CATEGORIES, predictions))


PIPELINE_REGISTRY.register_pipeline(
    "multi-head-text-classification",
    pipeline_class=CustomTextClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
)