|
from typing import List |
|
|
|
import onnx |
|
import torch |
|
import torch.nn as nn |
|
from onnxsim import simplify |
|
|
|
|
|
class Preprocess(nn.Module):
    """Resize, rescale and normalize an image batch.

    The forward pass resizes the input to the spatial size taken from
    ``input_shape[2:]``, divides by 255 and then applies a per-channel
    ``(x - mean) / std`` normalization over three channels.

    Args:
        input_shape: Target shape; only the entries from index 2 onwards
            (the spatial size) are used, as the ``size`` argument of
            ``interpolate``.
    """

    def __init__(self, input_shape: List[int]):
        super().__init__()
        self.input_shape = tuple(input_shape)

        # Registered as buffers (instead of plain tensor attributes) so that
        # ``.to(device)``, ``.cuda()``, ``.half()`` etc. move/cast them
        # together with the module.  ``persistent=False`` keeps them out of
        # ``state_dict()``, which therefore stays empty as before.
        # NOTE(review): the values look like rounded CLIP normalization
        # statistics -- confirm against the downstream model.
        self.register_buffer(
            "mean",
            torch.tensor([0.4815, 0.4578, 0.4082]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.tensor([0.2686, 0.2613, 0.2758]).view(1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the resized and normalized version of ``x``.

        Args:
            x: Image batch with 3 channels on dim 1 (the statistics are
                shaped ``(1, 3, 1, 1)`` and broadcast against it).
                Presumably pixel values are in the 0..255 range, given the
                division by 255 -- verify against the caller.

        Returns:
            Tensor whose trailing spatial dims equal ``input_shape[2:]``.
        """
        # Resize with ``interpolate``'s default mode to the fixed target size.
        x = torch.nn.functional.interpolate(
            input=x,
            size=self.input_shape[2:],
        )
        x = x / 255.0
        x = (x - self.mean) / self.std

        return x
|
|
|
|
|
if __name__ == "__main__":
    # Export the preprocessing module to ONNX, then simplify the graph in
    # place with onnx-simplifier.
    input_shape = [1, 3, 448, 448]
    output_onnx_file = "preprocessing.onnx"
    model = Preprocess(input_shape=input_shape)

    torch.onnx.export(
        model,
        torch.randn(input_shape),  # dummy input used only for tracing
        output_onnx_file,
        opset_version=20,
        input_names=["input_rgb"],
        output_names=["output_preprocessing"],
        # Batch size and input resolution are free; the module itself
        # resizes to the fixed spatial size from ``input_shape``.
        dynamic_axes={
            "input_rgb": {
                0: "batch_size",
                2: "height",
                3: "width",
            },
        },
    )

    model_onnx = onnx.load(output_onnx_file)
    # ``simplify`` returns the simplified model plus a validation flag.
    model_simplified, check = simplify(model_onnx)
    if not check:
        # Do not overwrite the valid export with an unvalidated graph.
        raise RuntimeError(
            "Simplified ONNX model could not be validated; "
            f"keeping the unsimplified export at {output_onnx_file}"
        )
    onnx.save(model_simplified, output_onnx_file)
|
|