Skip to content

Commit 44d42ef

Browse files
authored
Add bert embeddings adapter and callback for setting relative threshold for printed results (#4026)
* Add Bert sentence embedding adapter * Add callback to set rel_threshold for metric results * Update threshold setting * Suppress displaying FAILED status if no threshold passed
1 parent 9f1f1a6 commit 44d42ef

File tree

8 files changed

+48
-15
lines changed

8 files changed

+48
-15
lines changed

tools/accuracy_checker/accuracy_checker/adapters/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
PRNetAdapter
5151
)
5252

53-
from .reidentification import ReidAdapter
53+
from .reidentification import ReidAdapter, BertEmbeddingsAdapter
5454
from .detection import (
5555
TFObjectDetectionAPIAdapter,
5656
ClassAgnosticDetectionAdapter,
@@ -208,6 +208,7 @@
208208
'AnomalySegmentationAdapter',
209209

210210
'ReidAdapter',
211+
'BertEmbeddingsAdapter',
211212

212213
'ImageProcessingAdapter',
213214
'SuperResolutionAdapter',

tools/accuracy_checker/accuracy_checker/adapters/reidentification.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,22 @@ def select_output_blob(self, outputs):
106106
self.output_blob = self.check_output_name(self.target_out, outputs)
107107
if self.output_blob is None:
108108
self.output_blob = next(iter(outputs))
109+
110+
111+
class BertEmbeddingsAdapter(ReidAdapter):
    """Adapter for BERT sentence-embedding models.

    Reuses ReidAdapter's output-blob selection/extraction machinery while
    hard-coding the embedding-specific settings in ``configure``.
    """
    __provider__ = 'bert_sentence_embedding'

    @classmethod
    def parameters(cls):
        # NOTE(review): super(ReidAdapter, cls) deliberately skips
        # ReidAdapter's own parameters() and goes to the grandparent —
        # presumably because configure() below hard-codes the options
        # ReidAdapter would otherwise expose; confirm against ReidAdapter.
        parameters = super(ReidAdapter, cls).parameters()
        parameters.update({
            'target_out': StringField(optional=True, description='Target output layer name')
        })
        return parameters

    def configure(self):
        # Fixed configuration: only 'target_out' comes from user config,
        # everything else is pinned for the sentence-embedding use case.
        self.joining_method = 'sum'
        self.target_out = self.get_value_from_config('target_out')
        self.keep_shape = True
        # use 'sentence_similarity_pooling' postprocessor for proper mean
        # pooling with attention to the input_mask
        self.mean_pooling = False
        self.grn_workaround = False

tools/accuracy_checker/accuracy_checker/evaluators/base_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
4242

4343
# extract metrics results values prepared for printing
4444
# extract metrics results values prepared for printing
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
                            ignore_metric_reference=False, threshold_callback=None):
    """Abstract hook: subclasses return (results, meta) ready for printing.

    threshold_callback, when given, maps a metric result to an
    (abs_threshold, rel_threshold) pair used to override the thresholds
    attached to that result before presentation.
    """
    raise NotImplementedError
4747

4848
# destruction for entity, which can not be deleted automatically

tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/base_custom_evaluator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,15 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
169169
presenter.write_result(metric_result, ignore_results_formatting, ignore_metric_reference)
170170

171171
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
172-
ignore_metric_reference=False):
172+
ignore_metric_reference=False, threshold_callback=None):
173173
if not self._metrics_results:
174174
self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference)
175175
result_presenters = self.metric_executor.get_metric_presenters()
176176
extracted_results, extracted_meta = [], []
177177
for presenter, metric_result in zip(result_presenters, self._metrics_results):
178+
if threshold_callback:
179+
abs_threshold, rel_threshold = threshold_callback(metric_result)
180+
metric_result = metric_result._replace(abs_threshold=abs_threshold, rel_threshold=rel_threshold)
178181
result, metadata = presenter.extract_result(metric_result)
179182
if isinstance(result, list):
180183
extracted_results.extend(result)

tools/accuracy_checker/accuracy_checker/evaluators/model_evaluator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,13 +579,16 @@ def compute_metrics(self, print_results=True, ignore_results_formatting=False, i
579579
return self._metrics_results
580580

581581
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
582-
ignore_metric_reference=False):
582+
ignore_metric_reference=False, threshold_callback=None):
583583
if not self._metrics_results:
584584
self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference)
585585

586586
result_presenters = self.metric_executor.get_metric_presenters()
587587
extracted_results, extracted_meta = [], []
588588
for presenter, metric_result in zip(result_presenters, self._metrics_results):
589+
if threshold_callback:
590+
abs_threshold, rel_threshold = threshold_callback(metric_result)
591+
metric_result = metric_result._replace(abs_threshold=abs_threshold, rel_threshold=rel_threshold)
589592
result, metadata = presenter.extract_result(metric_result)
590593
if isinstance(result, list):
591594
extracted_results.extend(result)

tools/accuracy_checker/accuracy_checker/evaluators/module_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
5151
self._internal_module.print_metrics(ignore_results_formatting, ignore_metric_reference)
5252

5353
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
                            ignore_metric_reference=False, threshold_callback=None):
    # Thin delegation: forwards all arguments (including the optional
    # threshold_callback) to the wrapped evaluator module unchanged.
    return self._internal_module.extract_metrics_results(print_results, ignore_results_formatting,
                                                         ignore_metric_reference, threshold_callback)
5757

5858
def release(self):
5959
self._internal_module.release()

tools/accuracy_checker/accuracy_checker/evaluators/quantization_model_evaluator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,13 +464,16 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
464464
presenter.write_result(metric_result, ignore_results_formatting, ignore_metric_reference)
465465

466466
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
467-
ignore_metric_reference=False):
467+
ignore_metric_reference=False, threshold_callback=None):
468468
if not self._metrics_results:
469469
self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference)
470470

471471
result_presenters = self.metric_executor.get_metric_presenters()
472472
extracted_results, extracted_meta = [], []
473473
for presenter, metric_result in zip(result_presenters, self._metrics_results):
474+
if threshold_callback:
475+
abs_threshold, rel_threshold = threshold_callback(metric_result)
476+
metric_result = metric_result._replace(abs_threshold=abs_threshold, rel_threshold=rel_threshold)
474477
result, metadata = presenter.extract_result(metric_result)
475478
if isinstance(result, list):
476479
extracted_results.extend(result)

tools/accuracy_checker/accuracy_checker/presenters.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,22 +200,26 @@ def extract_result(self, evaluation_result, names_from_refs=False):
200200

201201
def write_scalar_result(
        res_value, name, abs_threshold=None, rel_threshold=None, diff_with_ref=None, value_name=None,
        postfix='%', scale=100, result_format='{:.3f}'
):
    """Print a single scalar metric result, optionally annotated with its
    deviation from a reference value.

    res_value is scaled by ``scale`` and rendered with ``result_format``.
    diff_with_ref, when truthy, is an (abs_diff, rel_diff) pair; the status
    tag depends on the thresholds:
      - either threshold missing/zero -> neutral "[RESULT: ...]" (no
        pass/fail judgement is made),
      - a diff at or above its threshold -> red "[FAILED: ...]",
      - otherwise -> green "[PASS: ...]".
    """
    display_name = "{}@{}".format(name, value_name) if value_name else name
    display_result = result_format.format(res_value * scale)
    message = '{}: {}{}'.format(display_name, display_result, postfix)

    if diff_with_ref and (diff_with_ref[0] or diff_with_ref[1]):
        # abs error is displayed on the same scale as the metric value;
        # rel error is kept as-is.
        abs_error = diff_with_ref[0] * scale
        rel_error = diff_with_ref[1]
        error_text = "abs error = {:.4} | relative error = {:.4}".format(abs_error, rel_error)

        if not abs_threshold or not rel_threshold:
            # No complete threshold pair -> report the deviation without a
            # FAILED/PASS verdict (suppresses spurious failures).
            result_message = "[RESULT: {}]".format(error_text)
            message = "{} {}".format(message, result_message)
        elif abs_threshold <= diff_with_ref[0] or rel_threshold <= diff_with_ref[1]:
            # Note: thresholds compare against the *unscaled* abs diff.
            fail_message = "[FAILED: {}]".format(error_text)
            message = "{} {}".format(message, color_format(fail_message, Color.FAILED))
        else:
            pass_message = "[PASS: {}]".format(error_text)
            message = "{} {}".format(message, color_format(pass_message, Color.PASSED))

    print_info(message)
221225

0 commit comments

Comments
 (0)