Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from fastai.vision.all import load_learner | |
from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model | |
def localize_trash(im, det_name, det_checkpoint, device, prob_threshold):
    """Run the object detector on ``im`` and return the confident detections.

    Args:
        im: Input image. It is fed to ``get_transforms`` and its ``.size`` is
            passed to ``rescale_bboxes`` as the target size (presumably a PIL
            image, i.e. ``(width, height)`` — TODO confirm).
        det_name: Detector architecture name, forwarded to ``set_model``.
        det_checkpoint: Detector weights path, forwarded to ``set_model``.
        device: Torch device the input tensor is moved to (also given to
            ``set_model``).
        prob_threshold: A detection is kept only if its column-4 value is
            strictly greater than this.

    Returns:
        ``(probas, bboxes_scaled)``: ``probas`` holds columns ``4:`` of each
        kept detection row (presumably ``[score, class]`` — confirm against
        the detector's output format); ``bboxes_scaled`` holds the first four
        columns after ``rescale_bboxes``.
    """
    # The detector is rebuilt on every call; ``1`` is set_model's second
    # argument (presumably the number of classes — confirm).
    detector = set_model(det_name, 1, det_checkpoint, device)
    detector.eval()

    # Inference only: never build an autograd graph, even when the caller
    # has not disabled gradients globally.
    with torch.no_grad():
        # mean-std normalize the input image (batch-size: 1)
        img = get_transforms(im)
        outputs = detector(img.to(device))

        # Only the first batch item is used; keep its rows whose column-4
        # score is above the threshold.
        first = outputs[0]
        kept = first[first[:, 4] > prob_threshold]
        probas = kept[:, 4:]

        # Convert boxes from network-input scale (tensor H, W) to image scale.
        bboxes_scaled = rescale_bboxes(kept[:, :4], im.size, tuple(img.size()[2:]))
    return probas, bboxes_scaled
def classify_trash(im, clas_checkpoint, cls_th, probas, bboxes_scaled):
    """Classify each detected region and keep those the classifier trusts.

    Each box is cropped from ``im`` and passed to a fastai learner loaded from
    ``clas_checkpoint``. For every kept box, the returned score row is a copy
    of the matching ``probas`` row with two replacements:

    - column 0 becomes the classifier's top probability as a truncated
      percentage (e.g. ``87.0``);
    - column 1 becomes the index of the top class.

    Args:
        im: Source image; must support ``.crop((xmin, ymin, xmax, ymax))``
            (presumably a PIL image).
        clas_checkpoint: Path of the exported fastai learner.
        cls_th: Minimum top-class probability (0..1) for a box to be kept.
        probas: Per-box score rows from the detector, at least two columns.
            This tensor is not modified.
        bboxes_scaled: Tensor of ``(xmin, ymin, xmax, ymax)`` boxes aligned
            row-for-row with ``probas``.

    Returns:
        ``(cls_prob, bboxes_final)``: the updated score rows and the
        corresponding box tuples, for boxes whose top class probability is
        ``>= cls_th``.
    """
    # NOTE(review): load_learner unpickles the checkpoint; only load trusted files.
    classifier = load_learner(clas_checkpoint)
    bboxes_final = []
    cls_prob = []
    for row, (xmin, ymin, xmax, ymax) in zip(probas, bboxes_scaled.tolist()):
        crop = im.crop((xmin, ymin, xmax, ymax))
        # fastai ``predict`` returns (decoded, index, probabilities).
        class_probs = classifier.predict(crop)[2]
        top_prob = class_probs.max()
        # Threshold on the raw probability. Comparing the truncated percentage
        # with ``cls_th * 100`` wrongly rejected boxes whose probability met a
        # fractional-percent threshold (e.g. 0.506 vs 0.505).
        if top_prob < cls_th:
            continue
        # Work on a copy so the caller's ``probas`` tensor is left untouched.
        p = row.clone()
        p[1] = torch.topk(class_probs, k=1).indices.squeeze(0).item()
        p[0] = torch.trunc(class_probs * 100).max()
        bboxes_final.append((xmin, ymin, xmax, ymax))
        cls_prob.append(p)
    return cls_prob, bboxes_final
def detect_trash(
    im, det_name, det_checkpoint, clas_checkpoint, device, prob_threshold, cls_th
):
    """Detect trash in ``im``, then classify and filter each detected region.

    Runs ``localize_trash`` to obtain detector boxes scoring above
    ``prob_threshold``. It then passes them to ``classify_trash``, which
    re-labels each box with ``clas_checkpoint`` and filters the boxes using
    ``cls_th``.

    Returns:
        ``(cls_prob, bboxes_final)`` exactly as returned by ``classify_trash``.
    """
    # Disable autograd only for the duration of this call. A bare
    # ``torch.set_grad_enabled(False)`` switches gradients off for the rest
    # of the thread and never restores them.
    with torch.no_grad():
        # 1) Localize
        probas, bboxes_scaled = localize_trash(
            im, det_name, det_checkpoint, device, prob_threshold
        )
        # 2) Classify
        return classify_trash(im, clas_checkpoint, cls_th, probas, bboxes_scaled)