File size: 1,880 Bytes
684e6f5
095f7cc
 
 
684e6f5
095f7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
684e6f5
 
 
 
095f7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from PIL import Image
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

from models.tools.draw import add_bboxes


class YoloModel:
    """
    Wrapper that builds an ultralytics ``YOLO`` model from a weight file
    obtained through ``hf_hub_download``.
    """

    def __init__(self, repo_name: str, file_name: str):
        """
        Fetch the weight file and construct the model from it.

        :param repo_name: repository name forwarded to ``hf_hub_download``
        :param file_name: weight file name forwarded to ``hf_hub_download``
        """
        self.model = YOLO(YoloModel.download_weight_file(repo_name, file_name))

    @staticmethod
    def download_weight_file(repo_name: str, file_name: str):
        """
        Download a weight file

        :param repo_name: first positional argument for ``hf_hub_download``
        :param file_name: second positional argument for ``hf_hub_download``
        :return: whatever ``hf_hub_download`` returns (presumably the local
            path of the downloaded file; confirm against huggingface_hub docs)
        """
        downloaded = hf_hub_download(repo_name, file_name)
        return downloaded

    def detect(self, im):
        """
        Run the wrapped model on ``im``

        :param im: input handed to the model as its ``source`` argument
        :return: the value returned by the model call, unmodified
        """
        predictions = self.model(source=im)
        return predictions

    def preview_detect(self, filename, confidence):
        """
        Open an image, run the model on it and annotate it

        :param filename: passed to ``Image.open``
        :param confidence: forwarded to ``add_bboxes`` for every result
        :return: the image after each result has been passed through
            ``add_bboxes``; the opened image itself if there were no results
        """
        annotated = Image.open(filename)
        # The model call is evaluated once, on the freshly opened image,
        # before any annotation replaces ``annotated``.
        for prediction in self.model(source=annotated):
            annotated = add_bboxes(annotated, prediction, confidence)
        return annotated


def test():
    """
    Manual smoke test

    Builds a ``YoloModel`` from the "SHOU-ISD/fire-and-smoke" / "yolov8n.pt"
    weights, runs it on ./tests/fire1.jpg, passes every result through
    ``add_bboxes`` with confidence 0.1 and prints each result's ``boxes``.
    The annotated image is neither saved nor returned.
    :return: None
    """
    detector = YoloModel("SHOU-ISD/fire-and-smoke", "yolov8n.pt")
    picture = Image.open("./tests/fire1.jpg")
    # The model call is evaluated once, on the image as opened from disk.
    for detection in detector.model(source=picture):
        picture = add_bboxes(picture, detection, confidence=0.1)
        print(detection.boxes)


def argument_parser(argv=None):
    """
    Argument Parser

    :param argv: optional list of argument strings to parse; when ``None``
        (the default) argparse reads the process command line, so existing
        zero-argument callers behave exactly as before
    :return: args, a namespace with ``test`` (bool) and ``weight_files``
        (list of strings, or ``None`` when the option is not given)
    """
    import argparse
    parser = argparse.ArgumentParser(description='Help for YoloModel')
    parser.add_argument('--test', '-t', action='store_true', help='Run test')
    # Each value is a "repo_name:file_name" pair; the metavar makes that
    # format visible in --help.
    parser.add_argument('--weight_files', '-w', nargs='+', metavar='REPO:FILE',
                        help='List of weight files')
    return parser.parse_args(argv)


def pre_cache_weight_files(weight_files: list[str]):
    """
    Pre-cache weight files

    Every entry must look like ``repo_name:file_name``. It is split on the
    first colon and the two parts are handed to
    ``YoloModel.download_weight_file``. Entries are processed in order, so a
    malformed entry stops the run after the earlier ones were downloaded.

    :param weight_files: list of ``repo_name:file_name`` strings; ``None`` or
        an empty list means there is nothing to download
    :raises ValueError: if an entry has no colon, or an empty part on either
        side of it
    :return: None
    """
    # argparse leaves the option as None when --weight_files is not passed,
    # so treat None the same as an empty list instead of failing to iterate.
    for spec in weight_files or ():
        # partition keeps any further colons inside the file name rather than
        # silently dropping everything after the second one.
        repo_name, separator, file_name = spec.partition(":")
        if not (separator and repo_name and file_name):
            raise ValueError(
                f"Invalid weight file {spec!r}: expected 'repo_name:file_name'"
            )
        YoloModel.download_weight_file(repo_name, file_name)


if __name__ == '__main__':
    # Script entry point: run the smoke test, or pre-download weight files.
    args = argument_parser()
    if args.test:
        test()
    elif args.weight_files:
        pre_cache_weight_files(args.weight_files)
    else:
        # Neither option was given: weight_files is None here, so exit with a
        # usage hint (non-zero status) rather than trying to iterate it.
        raise SystemExit(
            "Nothing to do: pass --test or --weight_files REPO:FILE [REPO:FILE ...]"
        )