File size: 2,578 Bytes
6ce5455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import sys
import torch
from transformers import AutoModelForMaskedLM,AutoTokenizer # type: ignore

# Convert prithivida/Splade_PP_en_v2 to ONNX,
# based on the following references:
# - https://github.com/naver/splade/issues/47
# - https://github.com/castorini/anserini/blob/master/docs/onnx-conversion.md


class TransformerRep(torch.nn.Module):
    """Thin wrapper around the Splade++ masked-LM that exposes its raw logits."""

    def __init__(self):
        super().__init__()
        checkpoint = 'prithivida/Splade_PP_en_v2'
        self.model = AutoModelForMaskedLM.from_pretrained(checkpoint)
        # Inference only: disable dropout and other training-time behaviour.
        self.model.eval() # type: ignore
        # Not read by any code in this file; kept so existing attribute access still works.
        self.fp16 = True

    def encode(self, input_ids, token_type_ids, attention_mask):
        """Run the wrapped model and return the first element of its output.

        For a Hugging Face masked-LM called without labels, that first element is
        the vocabulary logits tensor.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        first = outputs[0]
        return first



class SpladeModel(torch.nn.Module):
    """SPLADE sparse encoder on top of ``TransformerRep``.

    ``forward`` turns the masked-LM logits into a pooled, vocabulary-sized
    term-weight vector and returns only its non-zero entries; this is the
    graph that the script below traces and exports to ONNX.
    """

    def __init__(self):
        super().__init__()
        self.model = TransformerRep()
        # Name of the pooling over the sequence axis; forward() always uses max.
        self.agg = "max"
        self.model.eval()

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return the non-zero SPLADE term ids and their weights.

        Args:
            input_ids, token_type_ids, attention_mask: integer tensors shaped
                ``(batch, seq)`` (the export script traces with ``(1, 50)``).

        Returns:
            ``(term_ids, weights)``: two 1-D tensors of equal length holding the
            vocabulary ids whose pooled weight is non-zero and those weights.
            Entries are in row-major order over ``(batch, vocab)``. Batch
            membership is not returned, so the exported graph is intended to be
            run on a single query (batch size 1).
        """
        # Equivalent to the deprecated torch.cuda.amp.autocast(); only affects CUDA ops.
        with torch.autocast(device_type="cuda"):
            with torch.no_grad():
                # encode() already returns the (batch, seq, vocab) logits; indexing it
                # again with [0] would pick the first example only and broadcast it
                # against every row of the attention mask.
                logits = self.model.encode(input_ids, token_type_ids, attention_mask)
                # SPLADE activation log(1 + ReLU(x)); padded positions are zeroed by the mask.
                activations = torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1)
                # Max-pool over the sequence axis -> (batch, vocab).
                pooled, _ = torch.max(activations, dim=1)
                # (num_terms, 2) rows of [batch_index, vocab_id]. Not squeezed, so a
                # single non-zero term still yields a 2-D tensor.
                coords = torch.nonzero(pooled)
                term_ids = coords[:, 1]
                weights = pooled[coords[:, 0], term_ids]
        return term_ids, weights


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print('Usage:', sys.argv[0], '<output-file-name>', file=sys.stderr)
        sys.exit(1)
    output_path = sys.argv[1]

    # Convert the model to TorchScript by tracing it on a dummy single-query batch.
    model = SpladeModel()

    input_ids = torch.randint(1, 100, size=(1, 50))
    token_type_ids = torch.full((1, 50), 0)
    attention_mask = torch.full((1, 50), 1)
    traced_model = torch.jit.trace(model, (input_ids, token_type_ids, attention_mask))

    # Inputs are (batch, sequence). Both outputs are 1-D with a data-dependent
    # length (one entry per non-zero term), so they get their own symbolic
    # dimension instead of reusing 'batch_size'/'sequence', which would tie the
    # output length to the input shape in the exported graph.
    dyn_axis = {
        'input_ids': {0: 'batch_size', 1: 'sequence'},
        'attention_mask': {0: 'batch_size', 1: 'sequence'},
        'token_type_ids': {0: 'batch_size', 1: 'sequence'},
        'output_idx': {0: 'num_terms'},
        'output_weights': {0: 'num_terms'},
    }

    # torch.onnx.export writes the model to `f`; its return value is not needed.
    torch.onnx.export(
        traced_model,
        (input_ids, token_type_ids, attention_mask), # type: ignore
        f=output_path,
        input_names=['input_ids', 'token_type_ids', 'attention_mask'],
        output_names=['output_idx', 'output_weights'],
        dynamic_axes=dyn_axis,
        do_constant_folding=True,
        opset_version=15,
        verbose=False,
    )