Source code for tests.metrics_utils

"""
Classes and methods to compute metrics on the RAM extraction method, \
    based on the ground truth vision extracted objects.
"""

import numpy as np
from termcolor import colored
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import euclidean, cdist


# Print numpy floats with a single decimal (affects the whole process).
np.set_printoptions(precision=1)
# Debug switch read by get_all_metrics: when True, an ipdb session is opened
# if the RAM and vision object counts differ by more than 10.
USE_IPDB = False
# Distance threshold read by detection_stats: a matched RAM/vision pair of
# centers counts as a true positive only when strictly closer than this value
# (presumably expressed in pixels -- TODO confirm).
MIN_DIST_DETECTION = 5


def format_values(dictionary):
    """
    Converts the values of a dictionary of ratios to percentages.

    :param dictionary: Mapping from a name to a numeric ratio
    :type dictionary: dict

    :return: A new dictionary with the same keys, each value multiplied by \
        100 and rounded to one decimal.
    :rtype: dict
    """
    # Scale every value at once; the array keeps the dictionary's order.
    percentages = np.array(list(dictionary.values())) * 100
    formatted = {}
    for name, percentage in zip(dictionary.keys(), percentages):
        formatted[name] = round(percentage, 1)
    return formatted


class DetectionScores():
    """
    Keeps track of the true positives, false positives and false negatives of
    the detected objects, per category. Allows to quickly compute precision,
    recall and F1 scores.
    """

    def __init__(self) -> None:
        # Per-category counters, keyed by category.
        self.true_pos = {}
        self.false_pos = {}
        self.false_neg = {}
        # Reported by `dict_summary`. This class never computes it: it stays
        # NaN ("unknown") until a caller assigns a value.
        self.iou = float("nan")

    def update(self, det_dict):
        """
        Accumulates new detection counts.

        :param det_dict: Mapping from a category to its \
            (true positives, false positives, false negatives) counts
        :type det_dict: dict
        """
        for cat, (true_pos, false_pos, false_neg) in det_dict.items():
            self.true_pos[cat] = self.true_pos.get(cat, 0) + true_pos
            self.false_pos[cat] = self.false_pos.get(cat, 0) + false_pos
            self.false_neg[cat] = self.false_neg.get(cat, 0) + false_neg

    @property
    def cat_precisions(self):
        """
        The per category dictionary of the precision. Categories without any
        true or false positive are left out (precision undefined).
        """
        return {
            cat: true_pos / (true_pos + self.false_pos[cat])
            for cat, true_pos in self.true_pos.items()
            if true_pos + self.false_pos[cat]
        }

    @property
    def cat_recalls(self):
        """
        The per category dictionary of the recall. Categories without any
        true positive or false negative are left out (recall undefined).
        """
        return {
            cat: true_pos / (true_pos + self.false_neg[cat])
            for cat, true_pos in self.true_pos.items()
            if true_pos + self.false_neg[cat]
        }

    @property
    def cat_f_scores(self):
        """
        The per category dictionary of the F-scores. A category gets a score
        as soon as its precision or its recall is defined; the F-score is 0
        when either of them is missing or when both are 0.
        """
        prec, rec = self.cat_precisions, self.cat_recalls
        f_scores = {}
        for cat in self.true_pos:
            if cat not in prec and cat not in rec:
                # No true/false positive nor false negative: nothing to score.
                continue
            cat_prec, cat_rec = prec.get(cat, 0), rec.get(cat, 0)
            if cat_prec + cat_rec:
                f_scores[cat] = 2 * cat_prec * cat_rec / (cat_prec + cat_rec)
            else:
                f_scores[cat] = 0
        return f_scores

    @property
    def mean_precision(self):
        """
        The mean precision on all objects (0.0 when there is no true or false
        positive at all).
        """
        true_pos = sum(self.true_pos.values())
        predicted = true_pos + sum(self.false_pos.values())
        return true_pos / predicted if predicted else 0.0

    @property
    def mean_recall(self):
        """
        The mean recall on all objects (0.0 when there is no true positive or
        false negative at all).
        """
        true_pos = sum(self.true_pos.values())
        expected = true_pos + sum(self.false_neg.values())
        return true_pos / expected if expected else 0.0

    @property
    def mean_f_score(self):
        """
        The mean F-score on all objects (0.0 when both the mean precision and
        the mean recall are 0).
        """
        prec, rec = self.mean_precision, self.mean_recall
        return 2 * prec * rec / (prec + rec) if prec + rec else 0.0

    def __repr__(self) -> str:
        return "Detection stats with Cat. F-scores: \n" + str(self.cat_f_scores)

    @property
    def dict_summary(self):
        """
        The mean precision, recall and F-score on all objects, together with
        the `iou` attribute (NaN unless it was assigned by the caller).
        """
        return {"precision": self.mean_precision,
                "recall": self.mean_recall,
                "f-score": self.mean_f_score,
                "iou": self.iou}
def get_iou(obj1, obj2):
    """
    Computes the intersection over union between two GameObjects. Each object
    must expose its bounding box through the `x`, `y`, `w` and `h` attributes.

    :param obj1: The detected object
    :type obj1: ocatari.ram.game_objects.GameObject or ocatari.vision.game_objects.GameObject
    :param obj2: The ground truth object
    :type obj2: ocatari.ram.game_objects.GameObject or ocatari.vision.game_objects.GameObject

    :return: The intersection area divided by the union area, or 0.0 when the \
        union is empty (e.g. two zero-area boxes).
    :rtype: float
    """
    # corners of the intersection rectangle
    left = max(obj1.x, obj2.x)
    top = max(obj1.y, obj2.y)
    right = min(obj1.x + obj1.w, obj2.x + obj2.w)
    bottom = min(obj1.y + obj1.h, obj2.y + obj2.h)
    # clamped at 0 so that disjoint boxes yield an empty intersection
    inter_area = max(0, right - left) * max(0, bottom - top)
    # union = sum of both areas minus the intersection counted twice
    union_area = obj1.w * obj1.h + obj2.w * obj2.h - inter_area
    if union_area <= 0:
        # degenerate boxes: nothing to overlap with, avoid dividing by zero
        return 0.0
    return inter_area / float(union_area)
def _make_class_lists(ram_list, vision_list):
    """
    Groups the centers of the RAM and of the vision objects by category.

    :param ram_list: The list of objects detected by the RAM extraction method
    :type ram_list: list of ocatari.ram.game_objects.GameObject
    :param vision_list: The list of objects detected by the vision extraction method
    :type vision_list: list of ocatari.vision.game_objects.GameObject

    :return: A dictionary mapping every category found in either list to a \
        pair (centers of the RAM objects, centers of the vision objects).
    :rtype: dict
    """
    categories = {obj.category for obj in ram_list + vision_list}

    def centers_of(objects, wanted):
        # centers of the objects belonging to the wanted category
        return [obj.center for obj in objects if obj.category == wanted]

    return {
        cat: (centers_of(ram_list, cat), centers_of(vision_list, cat))
        for cat in categories
    }
def detection_stats(ram_list, vision_list):
    """
    Returns the per category true positive, false positive and false negative
    counts obtained by comparing the RAM and vision lists. Within a category,
    the RAM and vision centers are optimally paired (minimal total distance);
    a pair counts as a true positive when its distance is strictly below
    MIN_DIST_DETECTION, otherwise as one false positive plus one false
    negative. Objects left unpaired because one list is longer count as false
    positives (RAM surplus) or false negatives (vision surplus).

    :param ram_list: The list of objects detected by the RAM extraction method
    :type ram_list: list of ocatari.ram.game_objects.GameObject
    :param vision_list: The list of objects detected by the vision extraction method
    :type vision_list: list of ocatari.vision.game_objects.GameObject

    :return: A dictionary mapping each category to its \
        (true positives, false positives, false negatives) counts.
    :rtype: dict
    """
    cat_lists = _make_class_lists(ram_list, vision_list)
    dets = {}
    for cat, (ram_centers, vision_centers) in cat_lists.items():
        surplus = len(ram_centers) - len(vision_centers)
        true_pos = 0
        false_pos = max(0, surplus)
        false_neg = max(0, -surplus)
        # Pair as soon as both sides are non-empty: with a single object on
        # one side and several on the other, pairing by list order could
        # compare it with a far away object instead of the closest one.
        if ram_centers and vision_centers:
            cost_matrix = cdist(ram_centers, vision_centers)
            # Rectangular matrices are supported: min(n_ram, n_vision) pairs.
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
            pair_dists = cost_matrix[row_ind, col_ind]
            true_pos = int(np.count_nonzero(pair_dists < MIN_DIST_DETECTION))
            too_far = len(pair_dists) - true_pos
            false_pos += too_far
            false_neg += too_far
        dets[cat] = (true_pos, false_pos, false_neg)
    return dets
def get_all_metrics(ram_list, vision_list):
    """
    Compares two lists of GameObjects (from the RAM extraction and from the
    vision extraction methods) and computes the:

    * mean_iou
    * per_class_ious
    * only_in_ram
    * only_in_vision
    * objs_in_ram
    * objs_in_vision
    * dets (the output of `detection_stats`)

    Each RAM object is matched with the first vision object of the same class
    name that overlaps it (IoU > 0). As a side effect, the `_is_in_image`
    attribute is set on every RAM object and the `_is_in_ram` attribute on
    every vision object.

    :param ram_list: The list of objects detected by the RAM extraction method
    :type ram_list: list of ocatari.ram.game_objects.GameObject
    :param vision_list: The list of objects detected by the vision extraction method
    :type vision_list: list of ocatari.vision.game_objects.GameObject

    :return: A dictionary containing the metrics listed above; `mean_iou` is \
        NaN when no RAM object overlaps a vision object.
    :rtype: dict
    """
    dets = detection_stats(ram_list, vision_list)
    if abs(len(vision_list) - len(ram_list)) > 10 and USE_IPDB:
        # debugging hook for large count mismatches
        import ipdb
        ipdb.set_trace()

    for vobj in vision_list:
        vobj._is_in_ram = False

    ious = []
    ious_by_class = {}
    for robj in ram_list:
        robj._is_in_image = False
        objname = robj.__class__.__name__
        for vobj in vision_list:
            if vobj.__class__.__name__ != objname:
                continue
            iou = get_iou(robj, vobj)
            if iou > 0:
                ious.append(iou)
                ious_by_class.setdefault(objname, []).append(iou)
                vobj._is_in_ram = True
                robj._is_in_image = True
                # NOTE(review): the first overlapping vision object wins and
                # is not reserved, so it can also match another RAM object.
                break

    per_class_ious = {name: np.mean(values)
                      for name, values in ious_by_class.items()}
    only_in_ram = [str(robj) for robj in ram_list if not robj._is_in_image]
    only_in_vision = [str(vobj) for vobj in vision_list if not vobj._is_in_ram]
    # np.mean([]) would emit a RuntimeWarning; return the same NaN explicitly.
    mean_iou = np.mean(ious) if ious else np.nan
    return {"mean_iou": mean_iou,
            "per_class_ious": per_class_ious,
            "only_in_ram": only_in_ram,
            "only_in_vision": only_in_vision,
            "objs_in_ram": [str(o) for o in ram_list],
            "objs_in_vision": [str(o) for o in vision_list],
            "dets": dets}