Skip to content

Commit 6b6245a

Browse files
committed
Add callback to set rel_threshold for metric results
1 parent 0563314 commit 6b6245a

File tree

4 files changed

+9
-5
lines changed

4 files changed

+9
-5
lines changed

tools/accuracy_checker/accuracy_checker/evaluators/base_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
4242

4343
# extract metrics results values prepared for printing
4444
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
45-
ignore_metric_reference=False):
45+
ignore_metric_reference=False, threshold_callback=None):
4646
raise NotImplementedError
4747

4848
# destruction for entity, which can not be deleted automatically

tools/accuracy_checker/accuracy_checker/evaluators/model_evaluator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,13 +579,15 @@ def compute_metrics(self, print_results=True, ignore_results_formatting=False, i
579579
return self._metrics_results
580580

581581
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
582-
ignore_metric_reference=False):
582+
ignore_metric_reference=False, threshold_callback=None):
583583
if not self._metrics_results:
584584
self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference)
585585

586586
result_presenters = self.metric_executor.get_metric_presenters()
587587
extracted_results, extracted_meta = [], []
588588
for presenter, metric_result in zip(result_presenters, self._metrics_results):
589+
if threshold_callback:
590+
metric_result = metric_result._replace(rel_threshold=threshold_callback(metric_result))
589591
result, metadata = presenter.extract_result(metric_result)
590592
if isinstance(result, list):
591593
extracted_results.extend(result)

tools/accuracy_checker/accuracy_checker/evaluators/module_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
5151
self._internal_module.print_metrics(ignore_results_formatting, ignore_metric_reference)
5252

5353
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
54-
ignore_metric_reference=False):
54+
ignore_metric_reference=False, threshold_callback=None):
5555
return self._internal_module.extract_metrics_results(print_results, ignore_results_formatting,
56-
ignore_metric_reference)
56+
ignore_metric_reference, threshold_callback)
5757

5858
def release(self):
5959
self._internal_module.release()

tools/accuracy_checker/accuracy_checker/evaluators/quantization_model_evaluator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,13 +464,15 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
464464
presenter.write_result(metric_result, ignore_results_formatting, ignore_metric_reference)
465465

466466
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
467-
ignore_metric_reference=False):
467+
ignore_metric_reference=False, threshold_callback=None):
468468
if not self._metrics_results:
469469
self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference)
470470

471471
result_presenters = self.metric_executor.get_metric_presenters()
472472
extracted_results, extracted_meta = [], []
473473
for presenter, metric_result in zip(result_presenters, self._metrics_results):
474+
if threshold_callback:
475+
metric_result = metric_result._replace(rel_threshold=threshold_callback(metric_result))
474476
result, metadata = presenter.extract_result(metric_result)
475477
if isinstance(result, list):
476478
extracted_results.extend(result)

0 commit comments

Comments (0)