# Splade_PP_en_v2_onnx / splade_pp_en_v2_to_onnx.py
import sys
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer  # type: ignore

# Convert prithivida/Splade_PP_en_v2 to ONNX.
# Based on:
# - https://github.com/naver/splade/issues/47
# - https://github.com/castorini/anserini/blob/master/docs/onnx-conversion.md
class TransformerRep(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v2')
self.model.eval() # type: ignore
self.fp16 = True
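    # encode() returns the masked-LM logits, shape (batch_size, seq_len, vocab_size).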
def encode(self, input_ids, token_type_ids, attention_mask):
return self.model(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask
)[0]
class SpladeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = TransformerRep()
self.agg = "max"
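        # "max" aggregation: SPLADE max pooling over token positions (see forward()).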
self.model.eval()
    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.cuda.amp.autocast():  # type: ignore
            with torch.no_grad():
                # Logits for the first (and only) sequence: (seq_len, vocab_size).
                # The tracing below uses batch size 1, which this indexing assumes.
                lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)[0]
                # SPLADE pooling: log(1 + ReLU(logits)), masked by attention_mask,
                # then max-pooled over the sequence dimension -> (1, vocab_size).
                vec, _ = torch.max(torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1), dim=1)
                # Non-zero vocabulary term ids (column 1 of nonzero()) and their weights.
                indices = vec.nonzero().squeeze()
                weights = vec.squeeze()[indices]
        return indices[:, 1], weights[:, 1]
if __name__ == '__main__':
if len(sys.argv) != 2:
print('Usage:', sys.argv[0], '<output-file-name>')
sys.exit(1)
# Convert the model to TorchScript
model = SpladeModel()
    # Dummy inputs for tracing: batch size 1, sequence length 50.
    input_ids = torch.randint(1, 100, size=(1, 50))
    token_type_ids = torch.full((1, 50), 0)
    attention_mask = torch.full((1, 50), 1)
traced_model = torch.jit.trace(model, (input_ids, token_type_ids, attention_mask))
    dyn_axis = {
        'input_ids': {0: 'batch_size', 1: 'sequence'},
        'attention_mask': {0: 'batch_size', 1: 'sequence'},
        'token_type_ids': {0: 'batch_size', 1: 'sequence'},
        # The outputs are 1-D (one entry per non-zero vocabulary term),
        # so only axis 0 is dynamic.
        'output_idx': {0: 'output_terms'},
        'output_weights': {0: 'output_terms'}
    }
    # Export the traced model to ONNX (torch.onnx.export writes the file and returns None).
    torch.onnx.export(
        traced_model,
        (input_ids, token_type_ids, attention_mask),  # type: ignore
        f=sys.argv[1],
        input_names=['input_ids', 'token_type_ids', 'attention_mask'],
        output_names=['output_idx', 'output_weights'],
        dynamic_axes=dyn_axis,
        do_constant_folding=True,
        opset_version=15,
        verbose=False,
    )
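    # Optional sanity check: a minimal sketch, assuming onnxruntime is installed.
    # It runs the exported graph on a sample query and prints the non-zero SPLADE
    # terms; the query text and the CPUExecutionProvider choice are arbitrary.
    try:
        import onnxruntime as ort

        tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v2')
        encoded = tokenizer('what is a splade sparse vector?', return_tensors='np')
        session = ort.InferenceSession(sys.argv[1], providers=['CPUExecutionProvider'])
        idx, weights = session.run(
            ['output_idx', 'output_weights'],
            {
                'input_ids': encoded['input_ids'],
                'token_type_ids': encoded['token_type_ids'],
                'attention_mask': encoded['attention_mask'],
            }
        )
        print(f'{len(idx)} non-zero terms; first few (id, weight) pairs:',
              list(zip(idx[:5].tolist(), weights[:5].tolist())))
    except ImportError:
        pass  # onnxruntime not installed; skip the check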