# object-detection / yolo_model.py
# Last commit 095f7cc "Update UI" (verified) by mingyang91 — 3.17 kB.
import random
from PIL import ImageDraw, Image
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """
    Draw one bounding box, and optionally a text label, onto a PIL image in place.

    :param x: box corners as a sequence ``(x1, y1, x2, y2)``; values are truncated to int
    :param img: PIL image to draw on; it is modified in place
    :param color: outline / label background colour as an RGB tuple; a random
        colour is used when this is falsy
    :param label: optional text drawn on a filled background anchored at the
        box's top-left corner
    :param line_thickness: outline width in pixels; derived from the image size
        when falsy
    :return: None
    """
    width, height = img.size
    # Scale the stroke with the image: 0.2% of the mean side length, plus 1px.
    tl = line_thickness or round(0.002 * (width + height) / 2) + 1  # line/font thickness
    color = color or (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    img_draw = ImageDraw.Draw(img)
    img_draw.rectangle((c1[0], c1[1], c2[0], c2[1]), outline=color, width=tl)
    if label:
        tf = max(tl - 1, 1)  # margin around the label text
        # textbbox() measures the text as if drawn at c1; stroke_width only
        # enlarges the measured box so the filled background has a margin.
        label_box = img_draw.textbbox(c1, label, stroke_width=tf)
        img_draw.rectangle(label_box, fill=color)
        # Draw at the same anchor that was measured. Drawing at the measured
        # box's corner instead would shift the text off its background.
        img_draw.text(c1, label, fill=(255, 255, 255))
def add_bboxes(pil_img, result, confidence=0.6):
    """
    Draw the detections of one result onto an image.

    A detection is drawn unless its score is below ``confidence``. Each drawn box
    is labelled "<class name>: <score>".

    :param pil_img: PIL image to draw on; it is modified in place by plot_one_box
    :param result: one ultralytics result; its ``boxes`` (with ``cls``, ``conf``,
        ``xyxy``) and ``names`` mapping are read
    :param confidence: score threshold below which a detection is skipped
    :return: the same ``pil_img`` object that was passed in
    """
    for detection in result.boxes:
        # Each per-box tensor holds exactly one entry; unpacking enforces that.
        [class_id] = detection.cls.tolist()
        [score] = detection.conf.tolist()
        if score < confidence:
            continue  # too uncertain to show
        [corners] = detection.xyxy.tolist()
        caption = f'{result.names[class_id]}: {score: 0.2f}'
        plot_one_box(x=corners, img=pil_img, label=caption)
    return pil_img
class YoloModel:
    """Wrapper around an ultralytics ``YOLO`` model whose weights come from a Hugging Face Hub repo."""

    def __init__(self, repo_name: str, file_name: str):
        """
        Fetch the weight file and load it into a YOLO model.

        :param repo_name: Hugging Face Hub repository id, e.g. "SHOU-ISD/fire-and-smoke"
        :param file_name: weight file inside that repository, e.g. "yolov8n.pt"
        """
        weight_file = YoloModel.download_weight_file(repo_name, file_name)
        self.model = YOLO(weight_file)

    @staticmethod
    def download_weight_file(repo_name: str, file_name: str):
        """Return the local path of ``file_name`` from ``repo_name`` as given by ``hf_hub_download``."""
        return hf_hub_download(repo_name, file_name)

    def detect(self, im):
        """Run the model on ``im`` and return its raw results, unfiltered."""
        return self.model(source=im)

    def preview_detect(self, im, confidence):
        """
        Run detection and return a copy of ``im`` with the confident boxes drawn on it.

        :param im: PIL image to run detection on; it is left unmodified
        :param confidence: detections scoring below this are not drawn
        :return: a new PIL image with bounding boxes and labels drawn
        """
        results = self.detect(im)
        # add_bboxes draws in place. Work on a copy so the caller's image is
        # not silently altered.
        res_img = im.copy()
        for result in results:
            res_img = add_bboxes(res_img, result, confidence)
        return res_img
def test():
    """
    Smoke test: load the fire-and-smoke weights, run them on ./tests/fire1.jpg,
    draw every detection that is not below a 0.1 score, and print each result's boxes.
    """
    detector = YoloModel("SHOU-ISD/fire-and-smoke", "yolov8n.pt")
    image = Image.open("./tests/fire1.jpg")
    for res in detector.model(source=image):
        image = add_bboxes(image, res, confidence=0.1)
        print(res.boxes)
def argument_parser():
    """
    Build the command-line interface and parse the process arguments.

    Options:
      --test / -t                   run the built-in smoke test
      --weight_files / -w SPEC ...  one or more "repo_name:file_name" specs

    :return: argparse.Namespace with ``test`` (bool) and ``weight_files``
        (list of str, or None when the option is absent)
    """
    import argparse

    cli = argparse.ArgumentParser(description='Help for YoloModel')
    # (flags, keyword settings) in registration order. Weight file specs are
    # "repo_name:file_name" strings, split on ':' by pre_cache_weight_files.
    option_table = (
        (('--test', '-t'), dict(action='store_true', help='Run test')),
        (('--weight_files', '-w'), dict(nargs='+', help='List of weight files')),
    )
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    namespace = cli.parse_args()
    return namespace
def pre_cache_weight_files(weight_files: list[str]):
    """
    Pre-cache weight files by downloading each one with YoloModel.download_weight_file.

    Every spec is validated before any download starts. A malformed entry is
    therefore reported immediately, not after the earlier downloads finish.

    :param weight_files: specs of the form "repo_name:file_name", e.g.
        "SHOU-ISD/fire-and-smoke:yolov8n.pt". The repo id is everything before
        the first ':' and the file name is everything after it.
    :raises ValueError: if a spec has no ':' or an empty repo/file part
    :return: None
    """
    pairs = []
    for spec in weight_files:
        repo_name, sep, file_name = spec.partition(":")
        if not sep or not repo_name or not file_name:
            raise ValueError(
                f"Invalid weight file spec {spec!r}: expected 'repo_name:file_name'"
            )
        pairs.append((repo_name, file_name))
    for repo_name, file_name in pairs:
        YoloModel.download_weight_file(repo_name, file_name)
if __name__ == '__main__':
    # Script entry: either run the smoke test, or pre-download the weight files
    # named with -w so a later run does not have to fetch them.
    args = argument_parser()
    if args.test:
        test()
    elif args.weight_files:
        pre_cache_weight_files(args.weight_files)
    else:
        # argparse leaves weight_files as None when -w is absent. Passing that on
        # would crash with an unhelpful TypeError.
        raise SystemExit('Nothing to do: pass --test or --weight_files REPO:FILE [...]')