|
import sys |
|
import torch |
|
from transformers import AutoModelForMaskedLM,AutoTokenizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerRep(torch.nn.Module):
    """Wrap a Hugging Face masked-language model and expose its first output.

    The checkpoint is loaded through ``AutoModelForMaskedLM.from_pretrained``
    and switched to evaluation mode as soon as it is constructed.
    """

    def __init__(self, model_name: str = 'prithivida/Splade_PP_en_v2'):
        """Load the checkpoint and put it in eval mode.

        Args:
            model_name: Identifier or path passed to
                ``AutoModelForMaskedLM.from_pretrained``. The default is the
                checkpoint this class previously hard-coded, so calling
                ``TransformerRep()`` behaves as before.
        """
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.model.eval()
        # Not read anywhere in this file; kept so external code that inspects
        # the attribute keeps working.
        self.fp16 = True

    def encode(self, input_ids, token_type_ids, attention_mask):
        """Run the wrapped model and return element 0 of its output.

        Args:
            input_ids: Token ids, forwarded as the ``input_ids`` keyword.
            token_type_ids: Forwarded as the ``token_type_ids`` keyword.
            attention_mask: Forwarded as the ``attention_mask`` keyword.

        Returns:
            ``output[0]`` of the wrapped model. NOTE(review): for a masked-LM
            head this is presumably the logits tensor shaped
            (batch, sequence, vocab) -- confirm against the checkpoint.
        """
        outputs = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        return outputs[0]
|
|
|
|
|
|
|
class SpladeModel(torch.nn.Module):
    """Turn masked-LM output into a sparse (term index, weight) representation.

    ``forward`` applies ``log(1 + relu(logits))``, zeroes positions whose
    attention mask is 0, max-pools over the sequence axis and returns only the
    non-zero entries of the pooled result.
    """

    def __init__(self):
        """Build the underlying ``TransformerRep`` and put it in eval mode."""
        super().__init__()
        self.model = TransformerRep()
        # Records the pooling used by ``forward``; not read in this file.
        self.agg = "max"
        self.model.eval()

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Compute the non-zero term indices and their weights.

        Args:
            input_ids: Token ids, passed through to ``self.model.encode``.
            token_type_ids: Passed through to ``self.model.encode``.
            attention_mask: Passed through to ``self.model.encode`` and also
                multiplied into the activations so masked positions become 0
                before pooling. Indexed here as (batch, sequence).

        Returns:
            A pair ``(indices, weights)`` of 1-D tensors of equal length:
            ``indices`` holds the last-axis positions of the non-zero pooled
            activations and ``weights`` the matching values. With more than
            one row in the batch the entries of all rows are concatenated in
            row order and the row id is not returned, so the pair is only
            unambiguous for batch size 1.
        """
        # Kept as in the original: CUDA autocast outside, no_grad inside.
        with torch.cuda.amp.autocast(), torch.no_grad():
            # ``encode`` already returns element 0 of the model output.
            # Indexing ``[0]`` again (as the previous code did) dropped the
            # batch axis and only worked for batch size 1 via broadcasting.
            lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)
            activations = torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1)
            vec, _ = torch.max(activations, dim=1)

            # ``nonzero`` yields one (row, column) pair per non-zero entry.
            # No ``squeeze``: with exactly one hit it would collapse the
            # result to 1-D and break the column slicing below.
            positions = vec.nonzero()
            rows = positions[:, 0]
            cols = positions[:, 1]
            return cols, vec[rows, cols]
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: trace SpladeModel on a dummy batch and export the
    # traced module to the ONNX file named by the single command-line argument.
    if len(sys.argv) != 2:
        print('Usage:', sys.argv[0], '<output-file-name>')
        sys.exit(1)

    model = SpladeModel()

    # Dummy inputs used only to drive tracing: one row of 50 random token ids
    # with an all-zero segment id tensor and an all-one attention mask.
    input_ids = torch.randint(1, 100, size=(1, 50))
    token_type_ids = torch.full((1, 50), 0)
    attention_mask = torch.full((1, 50), 1)
    traced_model = torch.jit.trace(model, (input_ids, token_type_ids, attention_mask))

    dyn_axis = {
        'input_ids': {0: 'batch_size', 1: 'sequence'},
        'attention_mask': {0: 'batch_size', 1: 'sequence'},
        'token_type_ids': {0: 'batch_size', 1: 'sequence'},
        # Both outputs are 1-D: SpladeModel.forward returns one entry per
        # non-zero term, so their only axis is neither the batch nor the
        # sequence axis and gets its own symbolic name.
        'output_idx': {0: 'num_terms'},
        'output_weights': {0: 'num_terms'},
    }

    # The return value was bound to an unused name before; the export is
    # performed for its side effect of writing the file at sys.argv[1].
    torch.onnx.export(
        traced_model,
        (input_ids, token_type_ids, attention_mask),
        f=sys.argv[1],
        input_names=['input_ids', 'token_type_ids', 'attention_mask'],
        output_names=['output_idx', 'output_weights'],
        dynamic_axes=dyn_axis,
        do_constant_folding=True,
        opset_version=15,
        verbose=False,
    )
|
|