File size: 1,881 Bytes
fa84113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
import torch
from fastai.vision.all import load_learner

from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model


def localize_trash(im, det_name, det_checkpoint, device, prob_threshold):
    """Detect candidate trash regions in ``im`` with the configured detector.

    The detector is built through ``set_model`` with one class and switched to
    eval mode. It is then run on the transformed image. Detections whose
    column-4 score is strictly greater than ``prob_threshold`` are kept, and
    their boxes are rescaled back to the size of ``im``.

    Args:
        im: Input image; ``im.size`` is forwarded to ``rescale_bboxes``.
        det_name: Detector architecture name passed to ``set_model``.
        det_checkpoint: Detector weights passed to ``set_model``.
        device: Device the transformed image is moved to before inference.
        prob_threshold: Exclusive lower bound on the column-4 confidence.

    Returns:
        ``(probas, bboxes_scaled)``:
        - ``probas`` holds columns 4 onward of every kept detection row.
        - ``bboxes_scaled`` is ``rescale_bboxes`` applied to columns 0-3.
    """
    model = set_model(det_name, 1, det_checkpoint, device)
    model.eval()

    # Mean-std normalised input batch. Presumably shape (1, C, H, W); TODO
    # confirm against get_transforms.
    batch = get_transforms(im)
    # Keep only the first image's detections (the batch holds one image).
    detections = model(batch.to(device))[0]

    # Boolean row mask: column 4 is treated as the confidence score.
    confident = detections[detections[:, 4] > prob_threshold]

    # Spatial size of the network input, used to map boxes back to im.size.
    net_hw = tuple(batch.size()[2:])
    boxes = rescale_bboxes(confident[:, :4], im.size, net_hw)
    return confident[:, 4:], boxes


def classify_trash(im, clas_checkpoint, cls_th, probas, bboxes_scaled):
    """Classify each detected crop and keep those the classifier is sure about.

    Every box in ``bboxes_scaled`` is cropped out of ``im`` and passed to the
    fastai learner loaded from ``clas_checkpoint``. The classifier score is
    the top class probability as a truncated percentage (e.g. 0.8765 -> 87.0).
    A box is kept when ``score >= cls_th * 100``.

    Args:
        im: Source image; must support ``crop((xmin, ymin, xmax, ymax))``
            (e.g. a PIL image).
        clas_checkpoint: Path handed to fastai ``load_learner``.
        cls_th: Classifier threshold as a fraction; compared against the
            truncated percentage score.
        probas: Per-box detector rows with at least two columns. It is not
            modified; kept rows are copied before being rewritten.
        bboxes_scaled: Tensor of boxes; row i pairs with ``probas[i]``.

    Returns:
        ``(cls_prob, bboxes_final)``:
        - ``cls_prob`` holds, for each kept box, a copy of its ``probas``
          row with column 0 set to the truncated percentage score and
          column 1 set to the predicted class index.
        - ``bboxes_final`` holds the matching ``(xmin, ymin, xmax, ymax)``
          tuples.
    """
    classifier = load_learner(clas_checkpoint)

    bboxes_final = []
    cls_prob = []
    for row, (xmin, ymin, xmax, ymax) in zip(probas, bboxes_scaled.tolist()):
        crop = im.crop((xmin, ymin, xmax, ymax))
        # The third element of Learner.predict's result holds the per-class
        # probabilities as a torch tensor.
        class_probs = classifier.predict(crop)[2]

        # Stay in torch: trunc(max(p) * 100) equals max(trunc(p * 100)),
        # because scaling and truncation are both monotonic.
        score = torch.trunc(class_probs.max() * 100).item()
        # Keep the positive `>=` test so NaN scores are still rejected.
        if score >= cls_th * 100:
            # Copy the row so the caller's `probas` tensor is left untouched.
            p = row.clone()
            # Python scalars avoid cross-device tensor assignment into `p`.
            p[0] = score
            p[1] = class_probs.argmax().item()
            bboxes_final.append((xmin, ymin, xmax, ymax))
            cls_prob.append(p)
    return cls_prob, bboxes_final


def detect_trash(
    im, det_name, det_checkpoint, clas_checkpoint, device, prob_threshold, cls_th
):
    """Localize trash in ``im`` with the detector, then classify each box.

    Args:
        im: Input image (see ``localize_trash`` / ``classify_trash``).
        det_name: Detector architecture name.
        det_checkpoint: Detector weights.
        clas_checkpoint: fastai classifier export loaded with ``load_learner``.
        device: Device used for detector inference.
        prob_threshold: Detector confidence threshold (exclusive).
        cls_th: Classifier confidence threshold as a fraction (inclusive,
            applied to the truncated percentage score).

    Returns:
        ``(cls_prob, bboxes_final)`` as returned by ``classify_trash``.
    """
    # Inference only. Scope no_grad to this call. A plain
    # `torch.set_grad_enabled(False)` call would leave autograd disabled
    # in the caller's thread after this function returns.
    with torch.no_grad():
        # 1) Localize
        probas, bboxes_scaled = localize_trash(
            im, det_name, det_checkpoint, device, prob_threshold
        )

        # 2) Classify
        cls_prob, bboxes_final = classify_trash(
            im, clas_checkpoint, cls_th, probas, bboxes_scaled
        )

    return cls_prob, bboxes_final