File size: 1,333 Bytes
88359db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from typing import List, Sequence

import onnx
import torch
import torch.nn as nn
from onnxsim import simplify


class Preprocess(nn.Module):
    """Resize and normalize an image batch so it can be exported as an ONNX graph.

    ``forward`` resizes the two spatial dimensions to ``input_shape[2:]``,
    divides by 255 and then applies per-channel ``(x - mean) / std``.

    Args:
        input_shape: Target shape; only the trailing two entries (height,
            width) are used for resizing. Presumably ``[N, C, H, W]`` --
            the leading entries are stored but not read by this module.
        mean: Per-channel mean subtracted after dividing by 255. The default
            values look like rounded CLIP normalization constants
            (NOTE(review): confirm against the downstream model).
        std: Per-channel standard deviation divided out after the mean is
            subtracted.
        mode: Interpolation mode forwarded to
            ``torch.nn.functional.interpolate``. Defaults to ``"nearest"``,
            the previously hard-coded behaviour.
    """

    def __init__(
        self,
        input_shape: List[int],
        *,
        mean: Sequence[float] = (0.4815, 0.4578, 0.4082),
        std: Sequence[float] = (0.2686, 0.2613, 0.2758),
        mode: str = "nearest",
    ):
        super().__init__()
        self.input_shape = tuple(input_shape)
        self.mode = mode
        # Registered as buffers (not plain attributes) so that ``.to(device)``,
        # ``.to(dtype)`` and ``.half()`` move them together with the module;
        # plain tensor attributes would stay on CPU/float32 and break forward.
        # ``persistent=False`` keeps them out of ``state_dict``, as before.
        self.register_buffer(
            "mean", torch.tensor(mean).view(1, -1, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(std).view(1, -1, 1, 1), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` resized to ``input_shape[2:]`` and normalized.

        Args:
            x: Image batch; must have two spatial dims since ``size`` has two
                entries. Values presumably in ``[0, 255]`` (TODO confirm with
                the ONNX consumer).
        """
        x = torch.nn.functional.interpolate(
            input=x,
            size=self.input_shape[2:],
            mode=self.mode,
        )
        x = x / 255.0
        return (x - self.mean) / self.std


if __name__ == "__main__":
    # Export the preprocessing module as a standalone ONNX graph, then
    # simplify it in place with onnxsim.
    input_shape = [1, 3, 448, 448]
    output_onnx_file = "preprocessing.onnx"

    model = Preprocess(input_shape=input_shape)
    torch.onnx.export(
        model,
        # Dummy input used only for tracing; its values are irrelevant.
        torch.randn(input_shape),
        output_onnx_file,
        opset_version=20,
        input_names=["input_rgb"],
        output_names=["output_preprocessing"],
        dynamic_axes={
            # Batch and spatial size of the incoming image may vary.
            "input_rgb": {
                0: "batch_size",
                2: "height",
                3: "width",
            },
            # The output is always resized to input_shape[2:], so only its
            # batch dimension follows the input.
            "output_preprocessing": {
                0: "batch_size",
            },
        },
    )

    model_onnx = onnx.load(output_onnx_file)
    model_simplified, check_ok = simplify(model_onnx)
    # onnxsim reports whether the simplified graph was validated against the
    # original; refuse to overwrite the export with an unvalidated model.
    if not check_ok:
        raise RuntimeError(
            f"onnxsim could not validate the simplified model for {output_onnx_file}"
        )
    onnx.save(model_simplified, output_onnx_file)
|