Skip to content

Commit 15e6c73

Browse files
committed
Present model status based on abs_threshold
1 parent 2b9ca75 commit 15e6c73

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/base_custom_evaluator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,16 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
183183
presenter.write_result(metric_result, ignore_results_formatting, ignore_metric_reference)
184184

185185
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
186-
ignore_metric_reference=False):
186+
ignore_metric_reference=False, threshold_callback=None):
187187
if not self._metrics_results:
188188
self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference)
189189
result_presenters = self.metric_executor.get_metric_presenters()
190190
extracted_results, extracted_meta = [], []
191191
for presenter, metric_result in zip(result_presenters, self._metrics_results):
192+
if threshold_callback:
193+
abs_threshold, rel_threshold = threshold_callback(metric_result)
194+
metric_result = metric_result._replace(abs_threshold=abs_threshold, rel_threshold=rel_threshold)
195+
192196
result, metadata = presenter.extract_result(metric_result)
193197
if isinstance(result, list):
194198
extracted_results.extend(result)

tools/accuracy_checker/accuracy_checker/presenters.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -207,18 +207,20 @@ def write_scalar_result(
207207
message = '{}: {}{}'.format(display_name, display_result, postfix)
208208

209209
if diff_with_ref and (diff_with_ref[0] or diff_with_ref[1]):
210-
abs_error = diff_with_ref[0] * scale
211-
rel_error = diff_with_ref[1]
212-
error_text = "abs error = {:.4} | relative error = {:.4}".format(abs_error, rel_error)
213-
214-
if not abs_threshold or not rel_threshold:
215-
result_message = "[RESULT: {}]".format(error_text)
210+
if not abs_threshold:
211+
result_message = "[abs error = {:.4} | relative error = {:.4}]".format(
212+
diff_with_ref[0] * scale, diff_with_ref[1]
213+
)
216214
message = "{} {}".format(message, result_message)
217-
elif abs_threshold <= diff_with_ref[0] or rel_threshold <= diff_with_ref[1]:
218-
fail_message = "[FAILED: {}]".format(error_text)
215+
elif abs_threshold <= diff_with_ref[0] or (rel_threshold and rel_threshold <= diff_with_ref[1]):
216+
fail_message = "FAILED: [abs error = {:.4} | relative error = {:.4}]".format(
217+
diff_with_ref[0] * scale, diff_with_ref[1]
218+
)
219219
message = "{} {}".format(message, color_format(fail_message, Color.FAILED))
220220
else:
221-
pass_message = "[PASS: {}]".format(error_text)
221+
pass_message = "PASSED: [abs error = {:.4} | relative error = {:.4}]".format(
222+
diff_with_ref[0] * scale, diff_with_ref[1]
223+
)
222224
message = "{} {}".format(message, color_format(pass_message, Color.PASSED))
223225

224226
print_info(message)

0 commit comments

Comments
 (0)