diff --git a/tools/accuracy_checker/accuracy_checker/adapters/__init__.py b/tools/accuracy_checker/accuracy_checker/adapters/__init__.py index 3526e3945f..10c1a0d6d2 100644 --- a/tools/accuracy_checker/accuracy_checker/adapters/__init__.py +++ b/tools/accuracy_checker/accuracy_checker/adapters/__init__.py @@ -50,7 +50,7 @@ PRNetAdapter ) -from .reidentification import ReidAdapter +from .reidentification import ReidAdapter, BertEmbeddingsAdapter from .detection import ( TFObjectDetectionAPIAdapter, ClassAgnosticDetectionAdapter, @@ -208,6 +208,7 @@ 'AnomalySegmentationAdapter', 'ReidAdapter', + 'BertEmbeddingsAdapter', 'ImageProcessingAdapter', 'SuperResolutionAdapter', diff --git a/tools/accuracy_checker/accuracy_checker/adapters/reidentification.py b/tools/accuracy_checker/accuracy_checker/adapters/reidentification.py index dea6a59799..f8a6abb61c 100644 --- a/tools/accuracy_checker/accuracy_checker/adapters/reidentification.py +++ b/tools/accuracy_checker/accuracy_checker/adapters/reidentification.py @@ -106,3 +106,22 @@ def select_output_blob(self, outputs): self.output_blob = self.check_output_name(self.target_out, outputs) if self.output_blob is None: self.output_blob = next(iter(outputs)) + + +class BertEmbeddingsAdapter(ReidAdapter): + __provider__ = 'bert_sentence_embedding' + + @classmethod + def parameters(cls): + parameters = super(ReidAdapter, cls).parameters() + parameters.update({ + 'target_out': StringField(optional=True, description='Target output layer name') + }) + return parameters + + def configure(self): + self.joining_method = 'sum' + self.target_out = self.get_value_from_config('target_out') + self.keep_shape = True + self.mean_pooling = False # use 'sentence_similarity_pooling' postprocessor for proper mean pooling with attention to the input_mask + self.grn_workaround = False \ No newline at end of file diff --git a/tools/accuracy_checker/accuracy_checker/evaluators/base_evaluator.py b/tools/accuracy_checker/accuracy_checker/evaluators/base_evaluator.py index f3d167772e..b00cc9c4c5 100644 --- a/tools/accuracy_checker/accuracy_checker/evaluators/base_evaluator.py +++ b/tools/accuracy_checker/accuracy_checker/evaluators/base_evaluator.py @@ -42,7 +42,7 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r # extract metrics results values prepared for printing def extract_metrics_results(self, print_results=True, ignore_results_formatting=False, - ignore_metric_reference=False): + ignore_metric_reference=False, threshold_callback=None): raise NotImplementedError # destruction for entity, which can not be deleted automatically diff --git a/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/base_custom_evaluator.py b/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/base_custom_evaluator.py index 2fc39d0b3f..dfeaf40d8b 100644 --- a/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/base_custom_evaluator.py +++ b/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/base_custom_evaluator.py @@ -169,12 +169,15 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r presenter.write_result(metric_result, ignore_results_formatting, ignore_metric_reference) def extract_metrics_results(self, print_results=True, ignore_results_formatting=False, - ignore_metric_reference=False): + ignore_metric_reference=False, threshold_callback=None): if not self._metrics_results: self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference) result_presenters = self.metric_executor.get_metric_presenters() extracted_results, extracted_meta = [], [] for presenter, metric_result in zip(result_presenters, self._metrics_results): + if threshold_callback: + abs_threshold, rel_threshold = threshold_callback(metric_result) + metric_result = metric_result._replace(abs_threshold=abs_threshold, rel_threshold=rel_threshold) result, metadata = presenter.extract_result(metric_result) if isinstance(result, list): extracted_results.extend(result) diff --git a/tools/accuracy_checker/accuracy_checker/evaluators/model_evaluator.py b/tools/accuracy_checker/accuracy_checker/evaluators/model_evaluator.py index 14070d7bd3..a6d17c90b6 100644 --- a/tools/accuracy_checker/accuracy_checker/evaluators/model_evaluator.py +++ b/tools/accuracy_checker/accuracy_checker/evaluators/model_evaluator.py @@ -579,13 +579,16 @@ def compute_metrics(self, print_results=True, ignore_results_formatting=False, i return self._metrics_results def extract_metrics_results(self, print_results=True, ignore_results_formatting=False, - ignore_metric_reference=False): + ignore_metric_reference=False, threshold_callback=None): if not self._metrics_results: self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference) result_presenters = self.metric_executor.get_metric_presenters() extracted_results, extracted_meta = [], [] for presenter, metric_result in zip(result_presenters, self._metrics_results): + if threshold_callback: + abs_threshold, rel_threshold = threshold_callback(metric_result) + metric_result = metric_result._replace(abs_threshold=abs_threshold, rel_threshold=rel_threshold) result, metadata = presenter.extract_result(metric_result) if isinstance(result, list): extracted_results.extend(result) diff --git a/tools/accuracy_checker/accuracy_checker/evaluators/module_evaluator.py b/tools/accuracy_checker/accuracy_checker/evaluators/module_evaluator.py index ca01343b1c..388d22e298 100644 --- a/tools/accuracy_checker/accuracy_checker/evaluators/module_evaluator.py +++ b/tools/accuracy_checker/accuracy_checker/evaluators/module_evaluator.py @@ -51,9 +51,9 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r self._internal_module.print_metrics(ignore_results_formatting, ignore_metric_reference) def extract_metrics_results(self, print_results=True, ignore_results_formatting=False, - ignore_metric_reference=False): + ignore_metric_reference=False, threshold_callback=None): return self._internal_module.extract_metrics_results(print_results, ignore_results_formatting, - ignore_metric_reference) + ignore_metric_reference, threshold_callback) def release(self): self._internal_module.release() diff --git a/tools/accuracy_checker/accuracy_checker/evaluators/quantization_model_evaluator.py b/tools/accuracy_checker/accuracy_checker/evaluators/quantization_model_evaluator.py index d9afff2084..cbcb2de518 100644 --- a/tools/accuracy_checker/accuracy_checker/evaluators/quantization_model_evaluator.py +++ b/tools/accuracy_checker/accuracy_checker/evaluators/quantization_model_evaluator.py @@ -464,13 +464,16 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r presenter.write_result(metric_result, ignore_results_formatting, ignore_metric_reference) def extract_metrics_results(self, print_results=True, ignore_results_formatting=False, - ignore_metric_reference=False): + ignore_metric_reference=False, threshold_callback=None): if not self._metrics_results: self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference) result_presenters = self.metric_executor.get_metric_presenters() extracted_results, extracted_meta = [], [] for presenter, metric_result in zip(result_presenters, self._metrics_results): + if threshold_callback: + abs_threshold, rel_threshold = threshold_callback(metric_result) + metric_result = metric_result._replace(abs_threshold=abs_threshold, rel_threshold=rel_threshold) result, metadata = presenter.extract_result(metric_result) if isinstance(result, list): extracted_results.extend(result) diff --git a/tools/accuracy_checker/accuracy_checker/presenters.py b/tools/accuracy_checker/accuracy_checker/presenters.py index 695c1efaee..6866666f94 100644 --- a/tools/accuracy_checker/accuracy_checker/presenters.py +++ b/tools/accuracy_checker/accuracy_checker/presenters.py @@ -200,22 +200,26 @@ def extract_result(self, evaluation_result, names_from_refs=False): def write_scalar_result( res_value, name, abs_threshold=None, rel_threshold=None, diff_with_ref=None, value_name=None, - postfix='%', scale=100, result_format='{:.2f}' + postfix='%', scale=100, result_format='{:.3f}' ): display_name = "{}@{}".format(name, value_name) if value_name else name display_result = result_format.format(res_value * scale) message = '{}: {}{}'.format(display_name, display_result, postfix) if diff_with_ref and (diff_with_ref[0] or diff_with_ref[1]): - abs_threshold = abs_threshold or 0 - rel_threshold = rel_threshold or 0 - if abs_threshold <= diff_with_ref[0] or rel_threshold <= diff_with_ref[1]: - fail_message = "[FAILED: abs error = {:.4} | relative error = {:.4}]".format( - diff_with_ref[0] * scale, diff_with_ref[1] - ) + abs_error = diff_with_ref[0] * scale + rel_error = diff_with_ref[1] + error_text = "abs error = {:.4} | relative error = {:.4}".format(abs_error, rel_error) + + if not abs_threshold or not rel_threshold: + result_message = "[RESULT: {}]".format(error_text) + message = "{} {}".format(message, result_message) + elif abs_threshold <= diff_with_ref[0] or rel_threshold <= diff_with_ref[1]: + fail_message = "[FAILED: {}]".format(error_text) message = "{} {}".format(message, color_format(fail_message, Color.FAILED)) else: - message = "{} {}".format(message, color_format("[OK]", Color.PASSED)) + pass_message = "[PASS: {}]".format(error_text) + message = "{} {}".format(message, color_format(pass_message, Color.PASSED)) print_info(message)