Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import os | |
import json | |
import warnings | |
import argparse | |
from io import BytesIO | |
import onnx | |
import torch | |
from mmdet.apis import init_detector | |
from mmengine.config import ConfigDict | |
from mmengine.logging import print_log | |
from mmengine.utils.path import mkdir_or_exist | |
from easydeploy.model import DeployModel, MMYOLOBackend # noqa E402 | |
# Suppress these warning categories process-wide (presumably to keep the
# tracing/export log readable). Each call inserts an "ignore" filter at the
# front of the warnings filter list, in the order listed here.
for _category in (torch.jit.TracerWarning,
                  torch.jit.ScriptWarning,
                  UserWarning,
                  FutureWarning,
                  ResourceWarning):
    warnings.filterwarnings(action='ignore', category=_category)
del _category
def parse_args():
    """Parse the command-line options of the ONNX export script.

    Returns:
        argparse.Namespace: The parsed options. ``img_size`` is always a
        two-element ``[height, width]`` list: a single ``--img-size`` value
        is used for both sides.

    Raises:
        SystemExit: (via ``parser.error``) when ``--img-size`` is given with
            more than two values, which would otherwise produce an input
            tensor of the wrong rank much later during tracing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('--custom-text',
                        type=str,
                        help='custom text inputs (text json) for YOLO-World.')
    parser.add_argument('--add-padding',
                        action="store_true",
                        help="add an empty padding to texts.")
    parser.add_argument('--model-only',
                        action='store_true',
                        help='Export model only')
    parser.add_argument('--without-nms',
                        action='store_true',
                        help='Export model without NMS')
    parser.add_argument('--without-bbox-decoder',
                        action='store_true',
                        help='Export model without Bbox Decoder (for INT8 Quantization)')
    parser.add_argument('--work-dir',
                        default='./work_dirs',
                        help='Path to save export model')
    parser.add_argument('--img-size',
                        nargs='+',
                        type=int,
                        default=[640, 640],
                        help='Image size of height and width')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='Device used for inference')
    parser.add_argument('--simplify',
                        action='store_true',
                        help='Simplify onnx model by onnx-sim')
    parser.add_argument('--opset',
                        type=int,
                        default=11,
                        help='ONNX opset version')
    parser.add_argument('--backend',
                        type=str,
                        default='onnxruntime',
                        help='Backend for export onnx')
    parser.add_argument('--pre-topk',
                        type=int,
                        default=1000,
                        help='Postprocess pre topk bboxes feed into NMS')
    parser.add_argument('--keep-topk',
                        type=int,
                        default=100,
                        help='Postprocess keep topk bboxes out of NMS')
    parser.add_argument('--iou-threshold',
                        type=float,
                        default=0.65,
                        help='IoU threshold for NMS')
    parser.add_argument('--score-threshold',
                        type=float,
                        default=0.25,
                        help='Score threshold for NMS')
    args = parser.parse_args()
    if len(args.img_size) == 1:
        # A single value means a square input; rebind instead of mutating the
        # parsed list in place.
        args.img_size = args.img_size * 2
    elif len(args.img_size) != 2:
        # main() unpacks img_size into torch.randn(batch, 3, *img_size), so
        # anything but (height, width) yields a wrongly-ranked input.
        parser.error('--img-size expects 1 value (square) or 2 values '
                     f'(height width), got {len(args.img_size)}')
    return args
def build_model_from_cfg(config_path, checkpoint_path, device):
    """Load a detector from its config and checkpoint, ready for inference.

    Args:
        config_path: Path of the model config file, passed to
            ``init_detector``.
        checkpoint_path: Path of the weights file, passed to
            ``init_detector``.
        device: Device string forwarded to ``init_detector``
            (e.g. ``'cuda:0'``).

    Returns:
        The object returned by ``init_detector``, after ``.eval()`` has been
        called on it.
    """
    detector = init_detector(config_path, checkpoint_path, device=device)
    # Switch to inference mode before tracing / export.
    detector.eval()
    return detector
def _load_class_texts(custom_text, add_padding):
    """Return the class-name prompts to bake into a YOLO-World model.

    Args:
        custom_text: Path of a JSON file listing class texts, or ``None`` /
            empty string to fall back to the COCO class names. Each JSON
            entry is either a list whose first element is the class name
            (e.g. ``[["person"], ["car"]]``) or a plain string.
        add_padding: When true, append a single-space entry as an empty pad.

    Returns:
        list: A fresh list of class texts.
    """
    if custom_text:
        with open(custom_text, encoding='utf-8') as f:
            raw_texts = json.load(f)
        # Entries are usually single-element lists; keep plain strings whole
        # instead of silently reducing them to their first character.
        texts = [x if isinstance(x, str) else x[0] for x in raw_texts]
    else:
        from mmdet.datasets import CocoDataset
        # METAINFO['classes'] may be a tuple; copy it into a list so the
        # padding below (list concatenation) cannot raise TypeError.
        texts = list(CocoDataset.METAINFO['classes'])
    if add_padding:
        texts = texts + [' ']
    return texts


def _onnx_save_path(work_dir, checkpoint):
    """Return ``<work_dir>/<checkpoint file stem>.onnx``.

    Only the final extension is swapped, so names that merely contain "pth"
    are preserved and non-``.pth`` checkpoints still get an ``.onnx`` suffix.
    """
    stem, _ = os.path.splitext(os.path.basename(checkpoint))
    return os.path.join(work_dir, stem + '.onnx')


def _try_simplify(onnx_model):
    """Simplify *onnx_model* with onnx-simplifier on a best-effort basis.

    Any failure (onnxsim missing, simplification error, or a failed
    validation check) is logged and the ORIGINAL model is returned, so an
    unvalidated simplified graph is never handed back to the caller.
    """
    try:
        import onnxsim
        simplified, check = onnxsim.simplify(onnx_model)
    except Exception as e:
        print_log(f'Simplify failure: {e}')
        return onnx_model
    if not check:
        print_log('Simplify failure: assert check failed')
        return onnx_model
    return simplified


def main():
    """Export the configured detector to an ONNX file under ``--work-dir``.

    Flow: resolve the backend and post-processing settings, build the
    detector (baking the class texts in when it exposes ``reparameterize``),
    wrap it in ``DeployModel``, trace it with a random input, validate the
    graph with the ONNX checker, optionally simplify it, and save it as
    ``<checkpoint stem>.onnx``.
    """
    args = parse_args()
    mkdir_or_exist(args.work_dir)

    backend = MMYOLOBackend(args.backend.lower())
    if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO,
                   MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
        if not args.model_only:
            print_log('Export ONNX with bbox decoder and NMS ...')
    else:
        # Post-processing is only exported for the backends listed above.
        args.model_only = True
        print_log(f'Can not export postprocess for {args.backend.lower()}.\n'
                  f'Set "args.model_only=True" default.')

    if args.model_only:
        postprocess_cfg = None
        output_names = None
    else:
        postprocess_cfg = ConfigDict(pre_top_k=args.pre_topk,
                                     keep_top_k=args.keep_topk,
                                     iou_threshold=args.iou_threshold,
                                     score_threshold=args.score_threshold)
        output_names = ['num_dets', 'boxes', 'scores', 'labels']
        if args.without_bbox_decoder or args.without_nms:
            output_names = ['scores', 'boxes']

    texts = _load_class_texts(args.custom_text, args.add_padding)

    base_model = build_model_from_cfg(args.config, args.checkpoint,
                                      args.device)
    if hasattr(base_model, 'reparameterize'):
        # reparameterize text into YOLO-World
        base_model.reparameterize([texts])

    deploy_model = DeployModel(baseModel=base_model,
                               backend=backend,
                               postprocess_cfg=postprocess_cfg,
                               with_nms=not args.without_nms,
                               without_bbox_decoder=args.without_bbox_decoder)
    deploy_model.eval()

    # img_size is [height, width] (guaranteed by parse_args).
    fake_input = torch.randn(args.batch_size, 3,
                             *args.img_size).to(args.device)
    # dry run
    deploy_model(fake_input)

    save_onnx_path = _onnx_save_path(args.work_dir, args.checkpoint)

    # export onnx into an in-memory buffer, then load it back as a model
    with BytesIO() as f:
        torch.onnx.export(deploy_model,
                          fake_input,
                          f,
                          input_names=['images'],
                          output_names=output_names,
                          opset_version=args.opset)
        f.seek(0)
        onnx_model = onnx.load(f)
    onnx.checker.check_model(onnx_model)

    # Fix tensorrt onnx output shape, just for view. The shape list below
    # only matches the four NMS outputs, which exist only when neither
    # --without-nms nor --without-bbox-decoder was requested.
    if (not args.model_only and not args.without_nms
            and not args.without_bbox_decoder
            and backend in (MMYOLOBackend.TENSORRT8,
                            MMYOLOBackend.TENSORRT7)):
        # Assumes outputs follow output_names order:
        # num_dets (B, 1), boxes (B, K, 4), scores (B, K), labels (B, K).
        shapes = [
            args.batch_size, 1, args.batch_size, args.keep_topk, 4,
            args.batch_size, args.keep_topk, args.batch_size,
            args.keep_topk
        ]
        for output in onnx_model.graph.output:
            for dim in output.type.tensor_type.shape.dim:
                dim.dim_param = str(shapes.pop(0))

    if args.simplify:
        onnx_model = _try_simplify(onnx_model)
    onnx.save(onnx_model, save_onnx_path)
    print_log(f'ONNX export success, save into {save_onnx_path}')


if __name__ == '__main__':
    main()