# Copyright (c) OpenMMLab. All rights reserved.
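#
# Export a YOLO-World / MMYOLO detector to ONNX via easydeploy's DeployModel,
# optionally baking the bbox decoder and NMS into the exported graph.
#
# Example invocation (script name and paths are placeholders; substitute your
# own config, checkpoint and text json):
#   python export_onnx.py CONFIG.py CHECKPOINT.pth \
#       --custom-text class_texts.json --opset 12 --simplify --device cuda:0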
import os
import json
import warnings
import argparse
from io import BytesIO

import onnx
import torch
from mmdet.apis import init_detector
from mmengine.config import ConfigDict
from mmengine.logging import print_log
from mmengine.utils.path import mkdir_or_exist

from easydeploy.model import DeployModel, MMYOLOBackend  # noqa E402

warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
warnings.filterwarnings(action='ignore', category=UserWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning)
warnings.filterwarnings(action='ignore', category=ResourceWarning)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('--custom-text',
                        type=str,
                        help='custom text inputs (text json) for YOLO-World.')
    parser.add_argument('--add-padding',
                        action='store_true',
                        help='add an empty padding to texts.')
    parser.add_argument('--model-only',
                        action='store_true',
                        help='Export model only')
    parser.add_argument('--without-nms',
                        action='store_true',
                        help='Export model without NMS')
    parser.add_argument('--without-bbox-decoder',
                        action='store_true',
                        help='Export model without Bbox Decoder (for INT8 Quantization)')
    parser.add_argument('--work-dir',
                        default='./work_dirs',
                        help='Path to save export model')
    parser.add_argument('--img-size',
                        nargs='+',
                        type=int,
                        default=[640, 640],
                        help='Image size of height and width')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='Device used for inference')
    parser.add_argument('--simplify',
                        action='store_true',
                        help='Simplify onnx model by onnx-sim')
    parser.add_argument('--opset',
                        type=int,
                        default=11,
                        help='ONNX opset version')
    parser.add_argument('--backend',
                        type=str,
                        default='onnxruntime',
                        help='Backend for export onnx')
    parser.add_argument('--pre-topk',
                        type=int,
                        default=1000,
                        help='Postprocess pre topk bboxes fed into NMS')
    parser.add_argument('--keep-topk',
                        type=int,
                        default=100,
                        help='Postprocess keep topk bboxes out of NMS')
    parser.add_argument('--iou-threshold',
                        type=float,
                        default=0.65,
                        help='IoU threshold for NMS')
    parser.add_argument('--score-threshold',
                        type=float,
                        default=0.25,
                        help='Score threshold for NMS')
    args = parser.parse_args()

    # a single value means a square input, e.g. 640 -> [640, 640]
    args.img_size *= 2 if len(args.img_size) == 1 else 1
    return args


def build_model_from_cfg(config_path, checkpoint_path, device):
    model = init_detector(config_path, checkpoint_path, device=device)
    model.eval()
    return model


def main():
    args = parse_args()
    mkdir_or_exist(args.work_dir)

    backend = MMYOLOBackend(args.backend.lower())
    if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO,
                   MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
        if not args.model_only:
            print_log('Export ONNX with bbox decoder and NMS ...')
    else:
        args.model_only = True
        print_log(f'Cannot export postprocess for {args.backend.lower()}.\n'
                  f'Defaulting to "args.model_only=True".')

    if args.model_only:
        postprocess_cfg = None
        output_names = None
    else:
        postprocess_cfg = ConfigDict(pre_top_k=args.pre_topk,
                                     keep_top_k=args.keep_topk,
                                     iou_threshold=args.iou_threshold,
                                     score_threshold=args.score_threshold)
        output_names = ['num_dets', 'boxes', 'scores', 'labels']
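        # If the bbox decoder or NMS is stripped from the graph, only the raw
        # 'scores' and 'boxes' tensors are exported; decoding and/or NMS must
        # then be handled by the caller at runtime.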
        if args.without_bbox_decoder or args.without_nms:
            output_names = ['scores', 'boxes']

    if args.custom_text is not None and len(args.custom_text) > 0:
        with open(args.custom_text) as f:
            texts = json.load(f)
        texts = [x[0] for x in texts]
    else:
        from mmdet.datasets import CocoDataset
        texts = CocoDataset.METAINFO['classes']
    if args.add_padding:
        texts = texts + [' ']

    baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)
    if hasattr(baseModel, 'reparameterize'):
        # reparameterize text into YOLO-World
        baseModel.reparameterize([texts])
    deploy_model = DeployModel(baseModel=baseModel,
                               backend=backend,
                               postprocess_cfg=postprocess_cfg,
                               with_nms=not args.without_nms,
                               without_bbox_decoder=args.without_bbox_decoder)
    deploy_model.eval()

    fake_input = torch.randn(args.batch_size, 3,
                             *args.img_size).to(args.device)
    # dry run
    deploy_model(fake_input)

    save_onnx_path = os.path.join(
        args.work_dir,
        os.path.basename(args.checkpoint).replace('pth', 'onnx'))
    # export onnx
    with BytesIO() as f:
        torch.onnx.export(deploy_model,
                          fake_input,
                          f,
                          input_names=['images'],
                          output_names=output_names,
                          opset_version=args.opset)
        f.seek(0)
        onnx_model = onnx.load(f)
        onnx.checker.check_model(onnx_model)
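        # With NMS in the graph there are four outputs with static shapes:
        # num_dets (batch, 1), boxes (batch, keep_topk, 4),
        # scores (batch, keep_topk) and labels (batch, keep_topk).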
        # Fix tensorrt onnx output shape, just for view
        if not args.model_only and not args.without_nms and backend in (
                MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
            shapes = [
                args.batch_size, 1, args.batch_size, args.keep_topk, 4,
                args.batch_size, args.keep_topk, args.batch_size,
                args.keep_topk
            ]
            for i in onnx_model.graph.output:
                for j in i.type.tensor_type.shape.dim:
                    j.dim_param = str(shapes.pop(0))
        if args.simplify:
            try:
                import onnxsim
                onnx_model, check = onnxsim.simplify(onnx_model)
                assert check, 'assert check failed'
            except Exception as e:
                print_log(f'Simplify failure: {e}')
        onnx.save(onnx_model, save_onnx_path)
        print_log(f'ONNX export success, save into {save_onnx_path}')


if __name__ == '__main__':
    main()