Spaces:
Sleeping
Sleeping
# -------------------------------------------------------- | |
# Python Single Object Tracking Evaluation | |
# Licensed under The MIT License [see LICENSE for details] | |
# Written by Fangyi Zhang | |
# @author [email protected] | |
# @project https://github.com/StrangerZhang/pysot-toolkit.git | |
# Revised for SiamMask by foolwood | |
# -------------------------------------------------------- | |
import warnings | |
import itertools | |
import numpy as np | |
from colorama import Style, Fore | |
from ..utils import calculate_failures, calculate_accuracy | |
class AccuracyRobustnessBenchmark: | |
""" | |
Args: | |
dataset: | |
burnin: | |
""" | |
def __init__(self, dataset, burnin=10): | |
self.dataset = dataset | |
self.burnin = burnin | |
def eval(self, eval_trackers=None): | |
""" | |
Args: | |
eval_tags: list of tag | |
eval_trackers: list of tracker name | |
Returns: | |
ret: dict of results | |
""" | |
if eval_trackers is None: | |
eval_trackers = self.dataset.tracker_names | |
if isinstance(eval_trackers, str): | |
eval_trackers = [eval_trackers] | |
result = {} | |
for tracker_name in eval_trackers: | |
accuracy, failures = self._calculate_accuracy_robustness(tracker_name) | |
result[tracker_name] = {'overlaps': accuracy, | |
'failures': failures} | |
return result | |
def show_result(self, result, eao_result=None, show_video_level=False, helight_threshold=0.5): | |
"""pretty print result | |
Args: | |
result: returned dict from function eval | |
""" | |
tracker_name_len = max((max([len(x) for x in result.keys()])+2), 12) | |
if eao_result is not None: | |
header = "|{:^"+str(tracker_name_len)+"}|{:^10}|{:^12}|{:^13}|{:^7}|" | |
header = header.format('Tracker Name', | |
'Accuracy', 'Robustness', 'Lost Number', 'EAO') | |
formatter = "|{:^"+str(tracker_name_len)+"}|{:^10.3f}|{:^12.3f}|{:^13.1f}|{:^7.3f}|" | |
else: | |
header = "|{:^"+str(tracker_name_len)+"}|{:^10}|{:^12}|{:^13}|" | |
header = header.format('Tracker Name', | |
'Accuracy', 'Robustness', 'Lost Number') | |
formatter = "|{:^"+str(tracker_name_len)+"}|{:^10.3f}|{:^12.3f}|{:^13.1f}|" | |
bar = '-'*len(header) | |
print(bar) | |
print(header) | |
print(bar) | |
if eao_result is not None: | |
tracker_eao = sorted(eao_result.items(), | |
key=lambda x:x[1]['all'], | |
reverse=True)[:20] | |
tracker_names = [x[0] for x in tracker_eao] | |
else: | |
tracker_names = list(result.keys()) | |
for tracker_name in tracker_names: | |
ret = result[tracker_name] | |
overlaps = list(itertools.chain(*ret['overlaps'].values())) | |
accuracy = np.nanmean(overlaps) | |
length = sum([len(x) for x in ret['overlaps'].values()]) | |
failures = list(ret['failures'].values()) | |
lost_number = np.mean(np.sum(failures, axis=0)) | |
robustness = np.mean(np.sum(np.array(failures), axis=0) / length) * 100 | |
if eao_result is None: | |
print(formatter.format(tracker_name, accuracy, robustness, lost_number)) | |
else: | |
print(formatter.format(tracker_name, accuracy, robustness, lost_number, eao_result[tracker_name]['all'])) | |
print(bar) | |
if show_video_level and len(result) < 10: | |
print('\n\n') | |
header1 = "|{:^14}|".format("Tracker name") | |
header2 = "|{:^14}|".format("Video name") | |
for tracker_name in result.keys(): | |
header1 += ("{:^17}|").format(tracker_name) | |
header2 += "{:^8}|{:^8}|".format("Acc", "LN") | |
print('-'*len(header1)) | |
print(header1) | |
print('-'*len(header1)) | |
print(header2) | |
print('-'*len(header1)) | |
videos = list(result[tracker_name]['overlaps'].keys()) | |
for video in videos: | |
row = "|{:^14}|".format(video) | |
for tracker_name in result.keys(): | |
overlaps = result[tracker_name]['overlaps'][video] | |
accuracy = np.nanmean(overlaps) | |
failures = result[tracker_name]['failures'][video] | |
lost_number = np.mean(failures) | |
accuracy_str = "{:^8.3f}".format(accuracy) | |
if accuracy < helight_threshold: | |
row += f'{Fore.RED}{accuracy_str}{Style.RESET_ALL}|' | |
else: | |
row += accuracy_str+'|' | |
lost_num_str = "{:^8.3f}".format(lost_number) | |
if lost_number > 0: | |
row += f'{Fore.RED}{lost_num_str}{Style.RESET_ALL}|' | |
else: | |
row += lost_num_str+'|' | |
print(row) | |
print('-'*len(header1)) | |
def _calculate_accuracy_robustness(self, tracker_name): | |
overlaps = {} | |
failures = {} | |
all_length = {} | |
for i in range(len(self.dataset)): | |
video = self.dataset[i] | |
gt_traj = video.gt_traj | |
if tracker_name not in video.pred_trajs: | |
tracker_trajs = video.load_tracker(self.dataset.tracker_path, tracker_name, False) | |
else: | |
tracker_trajs = video.pred_trajs[tracker_name] | |
overlaps_group = [] | |
num_failures_group = [] | |
for tracker_traj in tracker_trajs: | |
num_failures = calculate_failures(tracker_traj)[0] | |
overlaps_ = calculate_accuracy(tracker_traj, gt_traj, | |
burnin=10, bound=(video.width, video.height))[1] | |
overlaps_group.append(overlaps_) | |
num_failures_group.append(num_failures) | |
with warnings.catch_warnings(): | |
warnings.simplefilter("ignore", category=RuntimeWarning) | |
overlaps[video.name] = np.nanmean(overlaps_group, axis=0).tolist() | |
failures[video.name] = num_failures_group | |
return overlaps, failures | |