Spaces:
Sleeping
Sleeping
File size: 1,881 Bytes
fa84113 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import numpy as np
import torch
from fastai.vision.all import load_learner
from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model
def localize_trash(im, det_name, det_checkpoint, device, prob_threshold):
    """Run the detector on one image and return confident boxes in image scale.

    Args:
        im: Input image. It is handed to ``get_transforms`` and its ``size``
            attribute is used as the target scale for the boxes
            (presumably a PIL image -- confirm against the caller).
        det_name: Detector name forwarded to ``set_model``.
        det_checkpoint: Detector checkpoint forwarded to ``set_model``.
        device: Device the model is built for and the input is moved to.
        prob_threshold: Predictions whose column 4 is not strictly greater
            than this value are discarded.

    Returns:
        ``(probas, bboxes_scaled)``: columns 4 and onward of the kept
        predictions, and their first four columns after ``rescale_bboxes``.
    """
    # Build the detector and switch it to evaluation mode.
    # NOTE(review): the literal 1 is presumably the number of classes --
    # confirm against set_model.
    model = set_model(det_name, 1, det_checkpoint, device)
    model.eval()

    # Preprocess the image into the model's input batch and run it.
    batch = get_transforms(im)
    outputs = model(batch.to(device))

    # Only the first batch entry is inspected; column 4 acts as the score.
    keep_mask = outputs[0, :, 4] > prob_threshold
    kept = outputs[0, keep_mask]

    # Map the boxes from the network input size back to the image size.
    input_hw = tuple(batch.size()[2:])
    boxes = rescale_bboxes(kept[:, :4], im.size, input_hw)
    scores = kept[:, 4:]
    return scores, boxes
def classify_trash(im, clas_checkpoint, cls_th, probas, bboxes_scaled):
    """Re-score every detected box with the classifier and keep confident ones.

    Each box is cropped out of ``im`` and passed to a fastai learner loaded
    from ``clas_checkpoint``. The top class and its probability replace the
    first two entries of the matching row of ``probas``.

    Args:
        im: Source image; must provide ``crop((xmin, ymin, xmax, ymax))``
            (presumably a PIL image -- confirm against the caller).
        clas_checkpoint: Exported learner file given to ``load_learner``.
        cls_th: Minimum top-class probability, as a fraction, for a box to
            be kept. It is compared against the percentage stored in the
            row, i.e. ``cls_th * 100``.
        probas: Per-box rows with at least two entries, paired with the
            boxes in order. The rows are overwritten IN PLACE: entry 0
            becomes the truncated top-class probability in percent and
            entry 1 becomes the top-class index.
        bboxes_scaled: Boxes exposing ``tolist()`` that yields
            ``(xmin, ymin, xmax, ymax)`` per box.

    Returns:
        ``(cls_prob, bboxes_final)``: the updated rows that reached the
        threshold and the matching boxes as tuples, in input order.
    """
    classifier = load_learner(clas_checkpoint)
    bboxes_final = []
    cls_prob = []
    for p, (xmin, ymin, xmax, ymax) in zip(probas, bboxes_scaled.tolist()):
        box = (xmin, ymin, xmax, ymax)
        # Third element of predict()'s result holds the per-class
        # probabilities (per fastai's documented return value).
        class_probs = classifier.predict(im.crop(box))[2]

        # A single top-1 query provides both the winning class and its
        # probability, so index and score cannot disagree.
        best = torch.topk(class_probs, k=1)
        p[1] = best.indices.squeeze(0).item()
        # torch.trunc keeps the computation in torch; np.trunc on a tensor
        # depends on an implicit NumPy conversion that fails for tensors
        # on another device or tracking gradients.
        p[0] = torch.trunc(best.values.squeeze(0) * 100).item()

        if p[0] >= cls_th * 100:
            bboxes_final.append(box)
            cls_prob.append(p)
    return cls_prob, bboxes_final
def detect_trash(
    im, det_name, det_checkpoint, clas_checkpoint, device, prob_threshold, cls_th
):
    """Run the two-stage pipeline: localize objects, then classify each box.

    Args:
        im: Input image, passed unchanged to both stages.
        det_name: Detector name forwarded to ``localize_trash``.
        det_checkpoint: Detector checkpoint forwarded to ``localize_trash``.
        clas_checkpoint: Classifier checkpoint forwarded to
            ``classify_trash``.
        device: Device used by the detector.
        prob_threshold: Detector confidence threshold.
        cls_th: Classifier probability threshold.

    Returns:
        ``(cls_prob, bboxes_final)`` exactly as returned by
        ``classify_trash``.
    """
    # Disable gradient tracking only for the duration of inference. The
    # context manager restores the caller's autograd mode on exit (even if
    # a stage raises) instead of switching it off permanently.
    with torch.no_grad():
        # 1) Localize
        probas, bboxes_scaled = localize_trash(
            im, det_name, det_checkpoint, device, prob_threshold
        )
        # 2) Classify
        cls_prob, bboxes_final = classify_trash(
            im, clas_checkpoint, cls_th, probas, bboxes_scaled
        )
    return cls_prob, bboxes_final
|