import os
import json
import argparse
import os.path as osp

import cv2
import tqdm
import torch
import numpy as np
import tensorflow as tf
import supervision as sv
from torchvision.ops import nms

BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(thickness=1)
MASK_ANNOTATOR = sv.MaskAnnotator()


class LabelAnnotator(sv.LabelAnnotator):

    @staticmethod
    def resolve_text_background_xyxy(
        center_coordinates,
        text_wh,
        position,
    ):
        center_x, center_y = center_coordinates
        text_w, text_h = text_wh
        return center_x, center_y, center_x + text_w, center_y + text_h


LABEL_ANNOTATOR = LabelAnnotator(text_padding=4,
                                 text_scale=0.5,
                                 text_thickness=1)


def parse_args():
    parser = argparse.ArgumentParser('YOLO-World TFLite (INT8) Demo')
    parser.add_argument('path', help='TFLite model file (`.tflite`)')
    parser.add_argument('image',
                        help='image path, either a single image file or a directory')
    parser.add_argument(
        'text',
        help='detecting texts (str, txt, or json); should be consistent with '
        'the texts used to export the TFLite model')
    parser.add_argument('--output-dir',
                        default='./output',
                        help='directory to save output files')
    args = parser.parse_args()
    return args


def preprocess(image, size=(640, 640)):
    # pad the image to a square, then resize it to the model input size
    h, w = image.shape[:2]
    max_size = max(h, w)
    scale_factor = size[0] / max_size
    pad_h = (max_size - h) // 2
    pad_w = (max_size - w) // 2
    pad_image = np.zeros((max_size, max_size, 3), dtype=image.dtype)
    pad_image[pad_h:h + pad_h, pad_w:w + pad_w] = image
    image = cv2.resize(pad_image, size,
                       interpolation=cv2.INTER_LINEAR).astype('float32')
    image /= 255.0
    image = image[None]
    return image, scale_factor, (pad_h, pad_w)


def generate_anchors_per_level(feat_size, stride, offset=0.5):
    # grid of anchor points (x, y) for one feature-map level
    h, w = feat_size
    shift_x = (torch.arange(0, w) + offset) * stride
    shift_y = (torch.arange(0, h) + offset) * stride
    yy, xx = torch.meshgrid(shift_y, shift_x, indexing='ij')
    anchors = torch.stack([xx, yy]).reshape(2, -1).transpose(0, 1)
    return anchors


def generate_anchors(feat_sizes=[(80, 80), (40, 40), (20, 20)],
                     strides=[8, 16, 32],
                     offset=0.5):
    anchors = [
        generate_anchors_per_level(fs, s, offset)
        for fs, s in zip(feat_sizes, strides)
    ]
    anchors = torch.cat(anchors)
    return anchors


def simple_bbox_decode(points, pred_bboxes, stride):
    # decode (left, top, right, bottom) distances into (x1, y1, x2, y2) boxes
    pred_bboxes = pred_bboxes * stride[None, :, None]
    x1 = points[..., 0] - pred_bboxes[..., 0]
    y1 = points[..., 1] - pred_bboxes[..., 1]
    x2 = points[..., 0] + pred_bboxes[..., 2]
    y2 = points[..., 1] + pred_bboxes[..., 3]
    bboxes = torch.stack([x1, y1, x2, y2], -1)
    return bboxes


def visualize(image, bboxes, labels, scores, texts):
    detections = sv.Detections(xyxy=bboxes, class_id=labels, confidence=scores)
    labels = [
        f"{texts[class_id][0]} {confidence:0.2f}" for class_id, confidence in
        zip(detections.class_id, detections.confidence)
    ]
    image = BOUNDING_BOX_ANNOTATOR.annotate(image, detections)
    image = LABEL_ANNOTATOR.annotate(image, detections, labels=labels)
    return image


def inference_per_sample(interp,
                         image_path,
                         texts,
                         priors,
                         strides,
                         output_dir,
                         size=(640, 640),
                         vis=False,
                         score_thr=0.05,
                         nms_thr=0.3,
                         max_dets=300):
    # input / output details from TFLite
    input_details = interp.get_input_details()
    output_details = interp.get_output_details()
    # load image from path
    ori_image = cv2.imread(image_path)
    h, w = ori_image.shape[:2]
    # BGR -> RGB before preprocessing
    image, scale_factor, pad_param = preprocess(ori_image[:, :, [2, 1, 0]],
                                                size)
    # inference
    interp.set_tensor(input_details[0]['index'], image)
    interp.invoke()
    scores = interp.get_tensor(output_details[1]['index'])
    bboxes = interp.get_tensor(output_details[0]['index'])
    # outputs can be kept as NumPy arrays on other devices;
    # torch is used here only for reference.
    ori_scores = torch.from_numpy(scores[0])
    ori_bboxes = torch.from_numpy(bboxes)
    # decode bbox coordinates with priors
    decoded_bboxes = simple_bbox_decode(priors, ori_bboxes, strides)[0]

    # class-wise NMS
    scores_list = []
    labels_list = []
    bboxes_list = []
    for cls_id in range(len(texts)):
        cls_scores = ori_scores[:, cls_id]
        labels = torch.ones(cls_scores.shape[0], dtype=torch.long) * cls_id
        keep_idxs = nms(decoded_bboxes, cls_scores, iou_threshold=0.5)
        cur_bboxes = decoded_bboxes[keep_idxs]
        cls_scores = cls_scores[keep_idxs]
        labels = labels[keep_idxs]
        scores_list.append(cls_scores)
        labels_list.append(labels)
        bboxes_list.append(cur_bboxes)

    scores = torch.cat(scores_list, dim=0)
    labels = torch.cat(labels_list, dim=0)
    bboxes = torch.cat(bboxes_list, dim=0)

    # filter by score threshold
    keep_idxs = scores > score_thr
    scores = scores[keep_idxs]
    labels = labels[keep_idxs]
    bboxes = bboxes[keep_idxs]

    # only for visualization, add an extra class-agnostic NMS
    keep_idxs = nms(bboxes, scores, iou_threshold=nms_thr)
    num_dets = min(len(keep_idxs), max_dets)
    bboxes = bboxes[keep_idxs].unsqueeze(0)
    scores = scores[keep_idxs].unsqueeze(0)
    labels = labels[keep_idxs].unsqueeze(0)

    scores = scores[0, :num_dets].numpy()
    bboxes = bboxes[0, :num_dets].numpy()
    labels = labels[0, :num_dets].numpy()

    # undo padding and rescale boxes back to the original image size
    bboxes -= np.array(
        [pad_param[1], pad_param[0], pad_param[1], pad_param[0]])
    bboxes /= scale_factor
    bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, w)
    bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, h)

    if vis:
        image_out = visualize(ori_image, bboxes, labels, scores, texts)
        cv2.imwrite(osp.join(output_dir, osp.basename(image_path)), image_out)
        print(f"detecting {num_dets} objects.")
        return image_out, ori_scores, ori_bboxes[0]
    else:
        return bboxes, labels, scores


def main():
    args = parse_args()
    tflite_file = args.path
    # init TFLite interpreter
    interpreter = tf.lite.Interpreter(model_path=tflite_file,
                                      experimental_preserve_all_tensors=True)
    interpreter.allocate_tensors()
    print("Init TFLite Interpreter")

    output_dir = args.output_dir
    if not osp.exists(output_dir):
        os.mkdir(output_dir)

    # load images
    if not osp.isfile(args.image):
        images = [
            osp.join(args.image, img) for img in os.listdir(args.image)
            if img.endswith('.png') or img.endswith('.jpg')
        ]
    else:
        images = [args.image]

    # load texts
    if args.text.endswith('.txt'):
        with open(args.text) as f:
            lines = f.readlines()
        texts = [[t.rstrip('\r\n')] for t in lines]
    elif args.text.endswith('.json'):
        with open(args.text) as f:
            texts = json.load(f)
    else:
        texts = [[t.strip()] for t in args.text.split(',')]

    size = (640, 640)
    strides = [8, 16, 32]
    # prepare anchors here, since the TFLite model does not contain anchors
    # due to INT8 quantization.
    featmap_sizes = [(size[0] // s, size[1] // s) for s in strides]
    flatten_priors = generate_anchors(featmap_sizes, strides=strides)
    mlvl_strides = [
        flatten_priors.new_full((featmap_size[0] * featmap_size[1] * 1, ),
                                stride)
        for featmap_size, stride in zip(featmap_sizes, strides)
    ]
    flatten_strides = torch.cat(mlvl_strides)

    print("Starting inference.")
    for img in tqdm.tqdm(images):
        inference_per_sample(interpreter,
                             img,
                             texts,
                             flatten_priors[None],
                             flatten_strides,
                             output_dir=output_dir,
                             vis=True,
                             score_thr=0.3,
                             nms_thr=0.5)
    print("Finished inference.")


if __name__ == "__main__":
    main()
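# Example invocation (a sketch only; the model, image directory, and prompt
# strings below are placeholders, not files shipped with this script):
#
#   python tflite_demo.py yolo_world.tflite images/ "person,dog,car" \
#       --output-dir ./output
#
# The `text` argument may also point to a .txt file (one prompt per line) or a
# .json file with nested prompt lists, as handled in main().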