Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tools/accuracy_checker/accuracy_checker/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
PRNetAdapter
)

from .reidentification import ReidAdapter
from .reidentification import ReidAdapter, BertEmbeddingsAdapter
from .detection import (
TFObjectDetectionAPIAdapter,
ClassAgnosticDetectionAdapter,
Expand Down Expand Up @@ -208,6 +208,7 @@
'AnomalySegmentationAdapter',

'ReidAdapter',
'BertEmbeddingsAdapter',

'ImageProcessingAdapter',
'SuperResolutionAdapter',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,22 @@ def select_output_blob(self, outputs):
self.output_blob = self.check_output_name(self.target_out, outputs)
if self.output_blob is None:
self.output_blob = next(iter(outputs))


class BertEmbeddingsAdapter(ReidAdapter):
__provider__ = 'bert_sentence_embedding'

@classmethod
def parameters(cls):
parameters = super(ReidAdapter, cls).parameters()
parameters.update({
'target_out': StringField(optional=True, description='Target output layer name')
})
return parameters

def configure(self):
self.joining_method = 'sum'
self.target_out = self.get_value_from_config('target_out')
self.keep_shape = True
self.mean_pooling = False # use 'sentence_similarity_pooling' postprocessor for proper mean pooling with attention to the input_mask
self.grn_workaround = False
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r

# extract metrics results values prepared for printing
def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
ignore_metric_reference=False):
ignore_metric_reference=False, threshold_callback=None):
raise NotImplementedError

# destruction for entity, which can not be deleted automatically
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,15 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
presenter.write_result(metric_result, ignore_results_formatting, ignore_metric_reference)

def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
ignore_metric_reference=False):
ignore_metric_reference=False, threshold_callback=None):
if not self._metrics_results:
self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference)
result_presenters = self.metric_executor.get_metric_presenters()
extracted_results, extracted_meta = [], []
for presenter, metric_result in zip(result_presenters, self._metrics_results):
if threshold_callback:
abs_threshold, rel_threshold = threshold_callback(metric_result)
metric_result = metric_result._replace(abs_threshold=abs_threshold, rel_threshold=rel_threshold)
result, metadata = presenter.extract_result(metric_result)
if isinstance(result, list):
extracted_results.extend(result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,13 +579,16 @@ def compute_metrics(self, print_results=True, ignore_results_formatting=False, i
return self._metrics_results

def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
ignore_metric_reference=False):
ignore_metric_reference=False, threshold_callback=None):
if not self._metrics_results:
self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference)

result_presenters = self.metric_executor.get_metric_presenters()
extracted_results, extracted_meta = [], []
for presenter, metric_result in zip(result_presenters, self._metrics_results):
if threshold_callback:
abs_threshold, rel_threshold = threshold_callback(metric_result)
metric_result = metric_result._replace(abs_threshold=abs_threshold, rel_threshold=rel_threshold)
result, metadata = presenter.extract_result(metric_result)
if isinstance(result, list):
extracted_results.extend(result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
self._internal_module.print_metrics(ignore_results_formatting, ignore_metric_reference)

def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
ignore_metric_reference=False):
ignore_metric_reference=False, threshold_callback=None):
return self._internal_module.extract_metrics_results(print_results, ignore_results_formatting,
ignore_metric_reference)
ignore_metric_reference, threshold_callback)

def release(self):
self._internal_module.release()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,13 +464,16 @@ def print_metrics_results(self, ignore_results_formatting=False, ignore_metric_r
presenter.write_result(metric_result, ignore_results_formatting, ignore_metric_reference)

def extract_metrics_results(self, print_results=True, ignore_results_formatting=False,
ignore_metric_reference=False):
ignore_metric_reference=False, threshold_callback=None):
if not self._metrics_results:
self.compute_metrics(False, ignore_results_formatting, ignore_metric_reference)

result_presenters = self.metric_executor.get_metric_presenters()
extracted_results, extracted_meta = [], []
for presenter, metric_result in zip(result_presenters, self._metrics_results):
if threshold_callback:
abs_threshold, rel_threshold = threshold_callback(metric_result)
metric_result = metric_result._replace(abs_threshold=abs_threshold, rel_threshold=rel_threshold)
result, metadata = presenter.extract_result(metric_result)
if isinstance(result, list):
extracted_results.extend(result)
Expand Down
20 changes: 12 additions & 8 deletions tools/accuracy_checker/accuracy_checker/presenters.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,22 +200,26 @@ def extract_result(self, evaluation_result, names_from_refs=False):

def write_scalar_result(
res_value, name, abs_threshold=None, rel_threshold=None, diff_with_ref=None, value_name=None,
postfix='%', scale=100, result_format='{:.2f}'
postfix='%', scale=100, result_format='{:.3f}'
):
display_name = "{}@{}".format(name, value_name) if value_name else name
display_result = result_format.format(res_value * scale)
message = '{}: {}{}'.format(display_name, display_result, postfix)

if diff_with_ref and (diff_with_ref[0] or diff_with_ref[1]):
abs_threshold = abs_threshold or 0
rel_threshold = rel_threshold or 0
if abs_threshold <= diff_with_ref[0] or rel_threshold <= diff_with_ref[1]:
fail_message = "[FAILED: abs error = {:.4} | relative error = {:.4}]".format(
diff_with_ref[0] * scale, diff_with_ref[1]
)
abs_error = diff_with_ref[0] * scale
rel_error = diff_with_ref[1]
error_text = "abs error = {:.4} | relative error = {:.4}".format(abs_error, rel_error)

if not abs_threshold or not rel_threshold:
result_message = "[RESULT: {}]".format(error_text)
message = "{} {}".format(message, result_message)
elif abs_threshold <= diff_with_ref[0] or rel_threshold <= diff_with_ref[1]:
fail_message = "[FAILED: {}]".format(error_text)
message = "{} {}".format(message, color_format(fail_message, Color.FAILED))
else:
message = "{} {}".format(message, color_format("[OK]", Color.PASSED))
pass_message = "[PASS: {}]".format(error_text)
message = "{} {}".format(message, color_format(pass_message, Color.PASSED))

print_info(message)

Expand Down
Loading