# File size: 1,333 Bytes
# 88359db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import List

import onnx
import torch
import torch.nn as nn
from onnxsim import simplify


class Preprocess(nn.Module):
    """Resize an image batch, rescale it by 1/255 and normalise per channel.

    The module is parameter-free; it only holds the target spatial size and
    the per-channel normalisation constants.

    Args:
        input_shape: Target tensor shape. Only the entries from index 2
            onwards (the spatial size) are used, as the resize target.
    """

    def __init__(self, input_shape: List[int]):
        super().__init__()
        self.input_shape = tuple(input_shape)
        # Registered as buffers (not plain attributes) so that ``.to()``,
        # ``.cuda()``, ``.double()`` etc. move/cast them with the module.
        # ``persistent=False`` keeps them out of ``state_dict`` so the
        # checkpoint format stays unchanged (empty).
        # NOTE(review): these values look like the CLIP image statistics
        # rounded to four decimals -- confirm against the target model.
        self.register_buffer(
            "mean",
            torch.tensor([0.4815, 0.4578, 0.4082]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.tensor([0.2686, 0.2613, 0.2758]).view(1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, x: torch.Tensor):
        """Return the resized and normalised version of ``x``.

        Args:
            x: Image batch; the constants are shaped ``(1, 3, 1, 1)``, so
                they broadcast over a 3-channel dimension at index 1.
                Presumably pixel values are in ``[0, 255]`` (they are
                divided by 255 below) -- confirm against the caller.

        Returns:
            Tensor resized to ``input_shape[2:]`` and normalised as
            ``(x / 255 - mean) / std``.
        """
        # No ``mode`` is passed, so ``interpolate`` uses its default
        # (nearest-neighbour) resampling.
        x = torch.nn.functional.interpolate(
            input=x,
            size=self.input_shape[2:],
        )
        x = x / 255.0
        x = (x - self.mean) / self.std

        return x


def main() -> None:
    """Export ``Preprocess`` to ONNX and simplify the exported graph in place.

    Writes ``preprocessing.onnx`` to the current working directory.

    Raises:
        RuntimeError: If onnxsim reports that the simplified model failed
            its validation check; the exported file is then left
            un-simplified rather than overwritten.
    """
    input_shape = [1, 3, 448, 448]
    output_onnx_file = "preprocessing.onnx"
    model = Preprocess(input_shape=input_shape)

    torch.onnx.export(
        model,
        torch.randn(input_shape),
        output_onnx_file,
        opset_version=20,
        input_names=["input_rgb"],
        output_names=["output_preprocessing"],
        # Batch size and input resolution are left symbolic; the module
        # itself resizes to the fixed spatial size from ``input_shape``.
        dynamic_axes={
            "input_rgb": {
                0: "batch_size",
                2: "height",
                3: "width",
            },
        },
    )

    model_onnx = onnx.load(output_onnx_file)
    # ``simplify`` returns the simplified model plus a flag telling whether
    # it passed onnxsim's validation against the original graph.
    model_simplified, is_valid = simplify(model_onnx)
    if not is_valid:
        raise RuntimeError(
            f"Simplified ONNX model could not be validated: {output_onnx_file}"
        )
    onnx.save(model_simplified, output_onnx_file)


if __name__ == "__main__":
    main()