# eva02_large_patch14_448 / 06_export_preprocessing_onnx.py
# Uploaded by dnth in commit 88359db ("Upload 11 files"); original size 1.33 kB.
from typing import List, Sequence

import onnx
import torch
import torch.nn as nn
from onnxsim import simplify
class Preprocess(nn.Module):
    """Resize and normalize an image batch; used to build an ONNX preprocessing graph.

    The forward pass:
      1. resizes the input to the spatial size ``input_shape[2:]`` with
         ``torch.nn.functional.interpolate`` (``mode`` selects the algorithm),
      2. divides by 255 (pixel values presumably arrive in [0, 255] — the
         caller must ensure this),
      3. applies per-channel ``(x - mean) / std``.

    Args:
        input_shape: Target shape ``[batch, channels, height, width]``; only
            ``height`` and ``width`` are used for the resize.
        mode: Interpolation mode forwarded to ``interpolate``. Defaults to
            ``"nearest"``, which was the previous hard-coded behaviour.
        mean: Per-channel mean subtracted after scaling to [0, 1].
        std: Per-channel standard deviation divided out after centering.

    Raises:
        ValueError: If ``input_shape`` does not have exactly 4 entries, or if
            ``mean`` and ``std`` differ in length.
    """

    def __init__(
        self,
        input_shape: List[int],
        mode: str = "nearest",
        mean: Sequence[float] = (0.4815, 0.4578, 0.4082),
        std: Sequence[float] = (0.2686, 0.2613, 0.2758),
    ):
        super().__init__()
        self.input_shape = tuple(input_shape)
        # forward() slices input_shape[2:] as (height, width), so anything
        # other than a 4-D NCHW shape would resize the wrong dimensions.
        if len(self.input_shape) != 4:
            raise ValueError(
                "input_shape must be [batch, channels, height, width], "
                f"got {list(input_shape)}"
            )
        if len(mean) != len(std):
            raise ValueError(
                f"mean and std must have the same length, got {len(mean)} and {len(std)}"
            )
        self.mode = mode
        # Registered as buffers (not plain attributes) so they follow the
        # module through .to()/.cuda()/.double(); persistent=False keeps them
        # out of state_dict, matching the previous behaviour.
        self.register_buffer(
            "mean", torch.tensor(list(mean)).view(1, -1, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(list(std)).view(1, -1, 1, 1), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` to ``input_shape[2:]``, scale by 1/255, then normalize per channel."""
        x = torch.nn.functional.interpolate(
            input=x,
            size=self.input_shape[2:],
            mode=self.mode,
        )
        x = x / 255.0
        return (x - self.mean) / self.std
if __name__ == "__main__":
    # Export Preprocess as a standalone ONNX graph. The input accepts a dynamic
    # batch/height/width (see dynamic_axes below); the output is resized to
    # input_shape[2:].
    input_shape = [1, 3, 448, 448]
    output_onnx_file = "preprocessing.onnx"

    model = Preprocess(input_shape=input_shape)

    torch.onnx.export(
        model,
        torch.randn(input_shape),  # dummy input; only its shape/dtype matter for tracing
        output_onnx_file,
        opset_version=20,
        input_names=["input_rgb"],
        output_names=["output_preprocessing"],
        dynamic_axes={
            "input_rgb": {
                0: "batch_size",
                2: "height",
                3: "width",
            },
            # The output spatial size is fixed by input_shape, but its batch
            # dimension follows the input's, so declare it dynamic too.
            "output_preprocessing": {
                0: "batch_size",
            },
        },
    )

    model_onnx = onnx.load(output_onnx_file)
    # simplify() returns (model, check_ok); check_ok is False when onnxsim
    # could not validate that the simplified graph matches the original.
    model_simplified, check_ok = simplify(model_onnx)
    if not check_ok:
        raise RuntimeError(
            f"onnxsim could not validate the simplified model for {output_onnx_file}"
        )
    onnx.save(model_simplified, output_onnx_file)