# File size: 925 bytes (revision 8595e5b)
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
def model_fn(model_dir):
    """Load the artefacts needed for inference from ``model_dir``.

    The tokenizer is loaded first, then the sequence-classification model.
    Each is loaded through its ``from_pretrained`` class method on the same
    ``model_dir``.

    :param model_dir: location passed unchanged to both ``from_pretrained``
        calls.
    :return: a two-item tuple ``(model, tokenizer)``, in that order.
    """
    tok = AutoTokenizer.from_pretrained(model_dir)
    clf = AutoModelForSequenceClassification.from_pretrained(model_dir)
    return clf, tok
# Maximum number of tokens fed to the model per input text. Kept out of
# predict_fn's signature so that it still takes exactly
# (data, model_and_tokenizer).
_MAX_TOKENS = 512


def predict_fn(data, model_and_tokenizer):
    """Run text classification on ``data['inputs']`` and return all label scores.

    Each input text is first cut to its first ``_MAX_TOKENS`` tokens (special
    tokens excluded) and decoded back to a string. It is then classified by a
    ``text-classification`` pipeline, which truncates again to ``_MAX_TOKENS``
    tokens.

    :param data: request payload; ``data['inputs']`` is expected to be a
        single string or a list of strings.
    :param model_and_tokenizer: ``(model, tokenizer)`` tuple, as returned by
        ``model_fn``.
    :return: the pipeline output with ``return_all_scores=True``. A list
        input produces one entry per text, in input order.
    :raises KeyError: if ``data`` has no ``'inputs'`` key.
    """
    # unpack model and tokenizer
    model, tokenizer = model_and_tokenizer
    # NOTE(review): recent transformers releases deprecate ``return_all_scores``
    # in favour of ``top_k=None``. Switching may change the nesting of the
    # returned list, so it is left as-is here.
    bert_pipe = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        truncation=True,
        max_length=_MAX_TOKENS,
        return_all_scores=True,
    )

    texts = data["inputs"]
    if isinstance(texts, str):
        return bert_pipe(_truncate_text(texts, tokenizer))
    # ``tokenizer.encode`` cannot take a list of raw strings as a batch (it may
    # read the list as pre-split tokens or as a text pair). Truncate each text
    # separately and hand the pipeline a batch instead.
    return bert_pipe([_truncate_text(text, tokenizer) for text in texts])


def _truncate_text(text, tokenizer):
    """Return ``text`` reduced to its first ``_MAX_TOKENS`` tokens."""
    # Special tokens are left out here; the pipeline adds them and truncates
    # again to _MAX_TOKENS.
    # NOTE: decode() need not reproduce the original text exactly (casing or
    # whitespace may be normalised, depending on the tokenizer).
    token_ids = tokenizer.encode(
        text, add_special_tokens=False, max_length=_MAX_TOKENS, truncation=True
    )
    return tokenizer.decode(token_ids)
|