import argparse
import random
from typing import Optional

from PIL import ImageDraw, Image
from huggingface_hub import hf_hub_download
from ultralytics import YOLO


def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """
    Draw one bounding box, and optionally a text label, onto ``img`` in place.

    :param x: indexable of four numbers; ``x[0], x[1]`` is used as one corner of
        the rectangle and ``x[2], x[3]`` as the opposite corner (values are
        truncated to ``int``)
    :param img: PIL image to draw on; it is modified in place
    :param color: outline / label-background colour; a random RGB tuple is
        picked when omitted
    :param label: optional text drawn at the ``(x[0], x[1])`` corner
    :param line_thickness: outline width in pixels; derived from the image size
        when omitted
    :return: None
    """
    width, height = img.size
    # line/font thickness: scales with the mean image side, never below 1
    thickness = line_thickness or round(0.002 * (width + height) / 2) + 1
    color = color or (
        random.randint(0, 255),
        random.randint(0, 255),
        random.randint(0, 255),
    )
    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))

    img_draw = ImageDraw.Draw(img)
    img_draw.rectangle(
        (top_left[0], top_left[1], bottom_right[0], bottom_right[1]),
        outline=color,
        width=thickness,
    )
    if label:
        font_thickness = max(thickness - 1, 1)
        # The stroke width is only used to size the filled background; the
        # text itself is drawn without a stroke.
        x1, y1, x2, y2 = img_draw.textbbox(
            top_left, label, stroke_width=font_thickness
        )
        img_draw.rectangle((x1, y1, x2, y2), fill=color)
        img_draw.text((x1, y1), label, fill=(255, 255, 255))


def add_bboxes(pil_img, result, confidence=0.6):
    """
    Draw every sufficiently confident detection of ``result`` onto ``pil_img``.

    :param pil_img: PIL image to annotate; it is modified in place
    :param result: a single detection result exposing ``boxes`` (each box with
        one-element ``cls``, ``conf`` and ``xyxy`` entries) and ``names``
    :param confidence: boxes whose confidence is below this value are skipped
    :return: the same ``pil_img`` object that was passed in
    """
    for box in result.boxes:
        # Each attribute holds exactly one entry per box; unpacking enforces it.
        [cl] = box.cls.tolist()
        [conf] = box.conf.tolist()
        if conf < confidence:
            continue
        [rect] = box.xyxy.tolist()
        text = f'{result.names[cl]}: {conf: 0.2f}'
        plot_one_box(x=rect, img=pil_img, label=text)
    return pil_img


class YoloModel:
    """YOLO model whose weight file is fetched through ``hf_hub_download``."""

    def __init__(self, repo_name: str, file_name: str):
        """
        Download the weight file and load it into a ``YOLO`` model.

        :param repo_name: repository name passed to ``hf_hub_download``
        :param file_name: file name inside that repository
        """
        weight_file = YoloModel.download_weight_file(repo_name, file_name)
        self.model = YOLO(weight_file)

    @staticmethod
    def download_weight_file(repo_name: str, file_name: str):
        """
        Fetch a weight file via ``hf_hub_download``.

        :param repo_name: repository name passed to ``hf_hub_download``
        :param file_name: file name inside that repository
        :return: whatever ``hf_hub_download`` returns for the file
        """
        return hf_hub_download(repo_name, file_name)

    def detect(self, im):
        """
        Run the model on ``im``.

        :param im: source handed to the model unchanged
        :return: the raw model output
        """
        return self.model(source=im)

    def preview_detect(self, im, confidence):
        """
        Run the model on ``im`` and draw the detections onto it.

        :param im: PIL image; the boxes are drawn onto this object in place
        :param confidence: minimum confidence for a box to be drawn
        :return: the annotated image (the same object as ``im``)
        """
        results = self.model(source=im)
        res_img = im
        for result in results:
            res_img = add_bboxes(res_img, result, confidence)
        return res_img


def test():
    """Smoke test: run detection on a sample image and print the found boxes."""
    model = YoloModel("SHOU-ISD/fire-and-smoke", "yolov8n.pt")
    im = Image.open("./tests/fire1.jpg")
    results = model.model(source=im)
    for result in results:
        im = add_bboxes(im, result, confidence=0.1)
        print(result.boxes)


def argument_parser():
    """
    Parse the command-line arguments.

    :return: parsed args with ``test`` (bool) and ``weight_files`` (list of
        strings, or None when the option is not given)
    """
    parser = argparse.ArgumentParser(description='Help for YoloModel')
    parser.add_argument('--test', '-t', action='store_true', help='Run test')
    # list of repo_name&file_name
    parser.add_argument('--weight_files', '-w', nargs='+', help='List of weight files')
    return parser.parse_args()


def pre_cache_weight_files(weight_files: Optional[list[str]]):
    """
    Pre-cache weight files.

    :param weight_files: entries of the form ``repo_name:file_name``; ``None``
        or an empty list means there is nothing to download
    :raises ValueError: if an entry does not contain a ``:`` separator
    :return: None
    """
    # argparse leaves ``weight_files`` as None when ``-w`` is omitted; treat
    # that as "nothing to cache" instead of failing while iterating None.
    if not weight_files:
        return
    for spec in weight_files:
        parts = spec.split(":")
        if len(parts) < 2:
            raise ValueError(
                f"Invalid weight file spec {spec!r}; expected 'repo_name:file_name'"
            )
        YoloModel.download_weight_file(parts[0], parts[1])


if __name__ == '__main__':
    args = argument_parser()
    if args.test:
        test()
    else:
        pre_cache_weight_files(args.weight_files)