# Either the name of a built-in aggregation or a user-supplied reducer.
ConfidenceAggregation = Union[
    Literal["mean", "geometric_mean", "harmonic_mean", "min", "max"],
    Callable[[np.ndarray], float],
]


def _geometric_mean(probs: np.ndarray) -> float:
    """Return exp(mean(log(probs))), or 0.0 as soon as one probability is zero."""
    # log(0) is -inf; the geometric mean of a set containing 0 is exactly 0.
    if np.any(probs <= 0.0):
        return 0.0
    return float(np.exp(np.mean(np.log(probs))))


def _harmonic_mean(probs: np.ndarray) -> float:
    """Return n / sum(1 / probs), or 0.0 as soon as one probability is zero."""
    # 1 / 0 diverges; the harmonic mean tends to 0 when any term tends to 0.
    if np.any(probs <= 0.0):
        return 0.0
    return float(probs.size / np.sum(1.0 / probs))


# Built-in reducers, keyed by the names accepted in ``ConfidenceAggregation``.
_AGGREGATORS = {
    "mean": np.mean,
    "geometric_mean": _geometric_mean,
    "harmonic_mean": _harmonic_mean,
    "min": np.min,
    "max": np.max,
}


def aggregate_confidence(
    probs: np.ndarray,
    method: ConfidenceAggregation = "mean",
) -> float:
    """Aggregate character-level confidence scores into a word-level confidence score.

    Args:
        probs: character-level confidence scores (array, list or scalar); values are clipped to [0, 1]
            before aggregation
        method: aggregation method to use. Can be one of:
            - "mean": arithmetic mean (default)
            - "geometric_mean": geometric mean (more sensitive to low values)
            - "harmonic_mean": harmonic mean (even more sensitive to low values)
            - "min": minimum confidence (most conservative)
            - "max": maximum confidence (most optimistic)
            - a callable that takes the clipped float64 ndarray and returns a float

    Returns:
        Aggregated confidence score as a float; 0.0 when `probs` is empty. For the geometric and harmonic
        means, a single zero probability yields exactly 0.0.

    Raises:
        ValueError: if `method` is neither callable nor one of the supported names
    """
    # Resolve the reducer first so that a misconfigured method is reported even for empty input
    if callable(method):
        aggregator = method
    else:
        aggregator = _AGGREGATORS.get(method) if isinstance(method, str) else None
        if aggregator is None:
            raise ValueError(f"Unknown aggregation method: {method}. Expected one of 'mean', 'geometric_mean', 'harmonic_mean', 'min', 'max', or a callable.")

    # Convert before inspecting the size so that lists, scalars and 0-d arrays are all accepted
    probs = np.atleast_1d(np.asarray(probs, dtype=np.float64))
    if probs.size == 0:
        return 0.0

    # Clip to valid probability range
    probs = np.clip(probs, 0.0, 1.0)
    return float(aggregator(probs))
""" def __init__( self, vocab: str, + confidence_aggregation: ConfidenceAggregation = "mean", ) -> None: self.vocab = vocab + self.confidence_aggregation = confidence_aggregation self._embedding = list(self.vocab) + [""] def extra_repr(self) -> str: - return f"vocab_size={len(self.vocab)}" + agg_repr = self.confidence_aggregation if isinstance(self.confidence_aggregation, str) else "custom" + return f"vocab_size={len(self.vocab)}, confidence_aggregation='{agg_repr}'" diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py index 077da138c8..74d5fc4880 100644 --- a/doctr/models/recognition/crnn/pytorch.py +++ b/doctr/models/recognition/crnn/pytorch.py @@ -16,7 +16,7 @@ from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r from ...utils.pytorch import load_pretrained_params -from ..core import RecognitionModel, RecognitionPostProcessor +from ..core import ConfidenceAggregation, RecognitionModel, RecognitionPostProcessor, aggregate_confidence __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] @@ -50,10 +50,19 @@ class CTCPostProcessor(RecognitionPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min". 
""" - @staticmethod + def __init__( + self, + vocab: str, + confidence_aggregation: ConfidenceAggregation = "min", + ) -> None: + super().__init__(vocab, confidence_aggregation) + def ctc_best_path( + self, logits: torch.Tensor, vocab: str = VOCABS["french"], blank: int = 0, @@ -69,16 +78,38 @@ def ctc_best_path( Returns: A list of tuples: (word, confidence) """ - # Gather the most confident characters, and assign the smallest conf among those to the sequence prob - probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values - - # collapse best path (using itertools.groupby), map to chars, join char list to string - words = [ - decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab) - for seq in torch.argmax(logits, dim=-1) - ] - - return list(zip(words, probs.tolist())) + # Get softmax probabilities and best path indices + softmax_probs = F.softmax(logits, dim=-1) + best_path_indices = torch.argmax(logits, dim=-1) # N x T + # Get the probability of the best path at each time step + best_path_probs = softmax_probs.max(dim=-1).values # N x T + + results = [] + for batch_idx in range(logits.size(0)): + seq = best_path_indices[batch_idx].tolist() + probs_seq = best_path_probs[batch_idx] + + # Collapse best path: remove blanks and repeated characters + # Track which positions contribute to the final word + char_probs = [] + prev_char = None + for pos, char_idx in enumerate(seq): + if char_idx != blank and char_idx != prev_char: + char_probs.append(probs_seq[pos].item()) + prev_char = char_idx + + # Decode the word + word = decode_sequence([k for k, _ in groupby(seq) if k != blank], vocab) + + # Aggregate character probabilities + if char_probs: + conf = aggregate_confidence(char_probs, self.confidence_aggregation) + else: + conf = 0.0 + + results.append((word, conf)) + + return results def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]: """Performs decoding of raw output with CTC and decoding of CTC predictions diff --git 
a/doctr/models/recognition/master/base.py b/doctr/models/recognition/master/base.py index 55401aecda..c5cf076ac3 100644 --- a/doctr/models/recognition/master/base.py +++ b/doctr/models/recognition/master/base.py @@ -7,7 +7,7 @@ import numpy as np from ....datasets import encode_sequences -from ..core import RecognitionPostProcessor +from ..core import ConfidenceAggregation, RecognitionPostProcessor class _MASTER: @@ -44,11 +44,14 @@ class _MASTERPostProcessor(RecognitionPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min". """ def __init__( self, vocab: str, + confidence_aggregation: ConfidenceAggregation = "min", ) -> None: - super().__init__(vocab) + super().__init__(vocab, confidence_aggregation) self._embedding = list(vocab) + [""] + [""] + [""] diff --git a/doctr/models/recognition/master/pytorch.py b/doctr/models/recognition/master/pytorch.py index 46f87ec72e..b18681b912 100644 --- a/doctr/models/recognition/master/pytorch.py +++ b/doctr/models/recognition/master/pytorch.py @@ -17,6 +17,7 @@ from doctr.models.modules.transformer import Decoder, PositionalEncoding from ...utils.pytorch import _bf16_to_float32, load_pretrained_params +from ..core import aggregate_confidence from .base import _MASTER, _MASTERPostProcessor __all__ = ["MASTER", "master"] @@ -258,7 +259,13 @@ def decode(self, encoded: torch.Tensor) -> torch.Tensor: class MASTERPostProcessor(_MASTERPostProcessor): - """Post processor for MASTER architectures""" + """Post processor for MASTER architectures + + Args: + vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. 
+ Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min". + """ def __call__( self, @@ -267,9 +274,7 @@ def __call__( # compute pred with argmax for attention models out_idxs = logits.argmax(-1) # N x L - probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1) - # Take the minimum confidence of the sequence - probs = probs.min(dim=1).values.detach().cpu() + preds_prob = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1) # Manual decoding word_values = [ @@ -277,7 +282,15 @@ def __call__( for encoded_seq in out_idxs.cpu().numpy() ] - return list(zip(word_values, probs.numpy().clip(0, 1).tolist())) + # compute probabilities for each word up to the EOS token using configured aggregation method + probs = [ + aggregate_confidence(preds_prob[i, : len(word)].detach().cpu().numpy(), self.confidence_aggregation) + if word + else 0.0 + for i, word in enumerate(word_values) + ] + + return list(zip(word_values, probs)) def _master( diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index e445dd0791..57622cde05 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -14,6 +14,7 @@ from doctr.models.modules.transformer import Decoder, PositionalEncoding from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params +from ..core import aggregate_confidence from .base import _MASTER, _MASTERPostProcessor __all__ = ["MASTER", "master"] @@ -250,6 +251,8 @@ class MASTERPostProcessor(_MASTERPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min". 
""" def __call__( @@ -259,9 +262,7 @@ def __call__( # compute pred with argmax for attention models out_idxs = tf.math.argmax(logits, axis=2) # N x L - probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2) - # Take the minimum confidence of the sequence - probs = tf.math.reduce_min(probs, axis=1) + preds_prob = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2) # decode raw output of the model with tf_label_to_idx out_idxs = tf.cast(out_idxs, dtype="int32") @@ -271,7 +272,15 @@ def __call__( decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0] word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] - return list(zip(word_values, probs.numpy().clip(0, 1).tolist())) + # compute probabilities for each word up to the EOS token using configured aggregation method + probs = [ + aggregate_confidence(preds_prob[i, : len(word)].numpy(), self.confidence_aggregation) + if word + else 0.0 + for i, word in enumerate(word_values) + ] + + return list(zip(word_values, probs)) def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool = True, **kwargs: Any) -> MASTER: diff --git a/doctr/models/recognition/parseq/base.py b/doctr/models/recognition/parseq/base.py index ff445ccab4..8b3efa02c5 100644 --- a/doctr/models/recognition/parseq/base.py +++ b/doctr/models/recognition/parseq/base.py @@ -7,7 +7,7 @@ import numpy as np from ....datasets import encode_sequences -from ..core import RecognitionPostProcessor +from ..core import ConfidenceAggregation, RecognitionPostProcessor class _PARSeq: @@ -44,11 +44,14 @@ class _PARSeqPostProcessor(RecognitionPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. 
""" def __init__( self, vocab: str, + confidence_aggregation: ConfidenceAggregation = "mean", ) -> None: - super().__init__(vocab) + super().__init__(vocab, confidence_aggregation) self._embedding = list(vocab) + ["", "", ""] diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py index a088f4b776..d1fa54be4c 100644 --- a/doctr/models/recognition/parseq/pytorch.py +++ b/doctr/models/recognition/parseq/pytorch.py @@ -20,6 +20,7 @@ from ...classification import vit_s from ...utils.pytorch import _bf16_to_float32, load_pretrained_params +from ..core import aggregate_confidence from .base import _PARSeq, _PARSeqPostProcessor __all__ = ["PARSeq", "parseq"] @@ -409,6 +410,8 @@ class PARSeqPostProcessor(_PARSeqPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "mean". 
""" def __call__( @@ -424,9 +427,12 @@ def __call__( "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] for encoded_seq in out_idxs.cpu().numpy() ] - # compute probabilties for each word up to the EOS token + # compute probabilities for each word up to the EOS token using configured aggregation method probs = [ - preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values) + aggregate_confidence(preds_prob[i, : len(word)].cpu().numpy(), self.confidence_aggregation) + if word + else 0.0 + for i, word in enumerate(word_values) ] return list(zip(word_values, probs)) diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index 7e4153d55e..f677ec8dd1 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -17,6 +17,7 @@ from ...classification import vit_s from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params +from ..core import aggregate_confidence from .base import _PARSeq, _PARSeqPostProcessor __all__ = ["PARSeq", "parseq"] @@ -423,6 +424,8 @@ class PARSeqPostProcessor(_PARSeqPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "mean". 
""" def __call__( @@ -441,9 +444,11 @@ def __call__( decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0] word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] - # compute probabilties for each word up to the EOS token + # compute probabilities for each word up to the EOS token using configured aggregation method probs = [ - preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0 + aggregate_confidence(preds_prob[i, : len(word)].numpy(), self.confidence_aggregation) + if word + else 0.0 for i, word in enumerate(word_values) ] diff --git a/doctr/models/recognition/predictor/_utils.py b/doctr/models/recognition/predictor/_utils.py index 2758f7936e..02798d96bd 100644 --- a/doctr/models/recognition/predictor/_utils.py +++ b/doctr/models/recognition/predictor/_utils.py @@ -8,6 +8,7 @@ import numpy as np +from ..core import ConfidenceAggregation, aggregate_confidence from ..utils import merge_multi_strings __all__ = ["split_crops", "remap_preds"] @@ -120,6 +121,7 @@ def remap_preds( preds: list[tuple[str, float]], crop_map: list[int | tuple[int, int, float]], overlap_ratio: float, + confidence_aggregation: ConfidenceAggregation = "mean", ) -> list[tuple[str, float]]: """ Reconstruct predictions from possibly split crops. @@ -128,6 +130,9 @@ def remap_preds( preds: List of (text, confidence) tuples from each crop. crop_map: Map returned by `split_crops`. overlap_ratio: Overlap ratio used during splitting. + confidence_aggregation: Method to aggregate confidence scores from split crops. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. + Defaults to "mean". Returns: List of merged (text, confidence) tuples corresponding to original crops. 
@@ -140,6 +145,6 @@ def remap_preds( start_idx, end_idx, last_overlap = item text_parts, confidences = zip(*preds[start_idx:end_idx]) merged_text = merge_multi_strings(list(text_parts), overlap_ratio, last_overlap) - merged_conf = sum(confidences) / len(confidences) # average confidence + merged_conf = aggregate_confidence(list(confidences), confidence_aggregation) remapped.append((merged_text, merged_conf)) return remapped diff --git a/doctr/models/recognition/sar/pytorch.py b/doctr/models/recognition/sar/pytorch.py index 35242674ea..d23ff9f47c 100644 --- a/doctr/models/recognition/sar/pytorch.py +++ b/doctr/models/recognition/sar/pytorch.py @@ -16,7 +16,7 @@ from ...classification import resnet31 from ...utils.pytorch import _bf16_to_float32, load_pretrained_params -from ..core import RecognitionModel, RecognitionPostProcessor +from ..core import ConfidenceAggregation, RecognitionModel, RecognitionPostProcessor, aggregate_confidence __all__ = ["SAR", "sar_resnet31"] @@ -320,8 +320,17 @@ class SARPostProcessor(RecognitionPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min". 
""" + def __init__( + self, + vocab: str, + confidence_aggregation: ConfidenceAggregation = "min", + ) -> None: + super().__init__(vocab, confidence_aggregation) + def __call__( self, logits: torch.Tensor, @@ -329,9 +338,7 @@ def __call__( # compute pred with argmax for attention models out_idxs = logits.argmax(-1) # N x L - probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1) - # Take the minimum confidence of the sequence - probs = probs.min(dim=1).values.detach().cpu() + preds_prob = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1) # Manual decoding word_values = [ @@ -339,7 +346,15 @@ def __call__( for encoded_seq in out_idxs.detach().cpu().numpy() ] - return list(zip(word_values, probs.numpy().clip(0, 1).tolist())) + # compute probabilities for each word up to the EOS token using configured aggregation method + probs = [ + aggregate_confidence(preds_prob[i, : len(word)].detach().cpu().numpy(), self.confidence_aggregation) + if word + else 0.0 + for i, word in enumerate(word_values) + ] + + return list(zip(word_values, probs)) def _sar( diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index 5c2745f4f3..4de318fd5e 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -14,7 +14,7 @@ from ...classification import resnet31 from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params -from ..core import RecognitionModel, RecognitionPostProcessor +from ..core import ConfidenceAggregation, RecognitionModel, RecognitionPostProcessor, aggregate_confidence __all__ = ["SAR", "sar_resnet31"] @@ -344,8 +344,17 @@ class SARPostProcessor(RecognitionPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. 
+ Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min". """ + def __init__( + self, + vocab: str, + confidence_aggregation: ConfidenceAggregation = "min", + ) -> None: + super().__init__(vocab, confidence_aggregation) + def __call__( self, logits: tf.Tensor, @@ -353,9 +362,7 @@ def __call__( # compute pred with argmax for attention models out_idxs = tf.math.argmax(logits, axis=2) # N x L - probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2) - # Take the minimum confidence of the sequence - probs = tf.math.reduce_min(probs, axis=1) + preds_prob = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2) # decode raw output of the model with tf_label_to_idx out_idxs = tf.cast(out_idxs, dtype="int32") @@ -365,7 +372,15 @@ def __call__( decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0] word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] - return list(zip(word_values, probs.numpy().clip(0, 1).tolist())) + # compute probabilities for each word up to the EOS token using configured aggregation method + probs = [ + aggregate_confidence(preds_prob[i, : len(word)].numpy(), self.confidence_aggregation) + if word + else 0.0 + for i, word in enumerate(word_values) + ] + + return list(zip(word_values, probs)) def _sar( diff --git a/doctr/models/recognition/viptr/pytorch.py b/doctr/models/recognition/viptr/pytorch.py index e874148e65..ad085bc5ca 100644 --- a/doctr/models/recognition/viptr/pytorch.py +++ b/doctr/models/recognition/viptr/pytorch.py @@ -17,7 +17,7 @@ from ...classification import vip_base, vip_tiny from ...utils.pytorch import _bf16_to_float32, load_pretrained_params -from ..core import RecognitionModel, RecognitionPostProcessor +from ..core import ConfidenceAggregation, RecognitionModel, RecognitionPostProcessor, aggregate_confidence __all__ = ["VIPTR", "viptr_base", "viptr_tiny"] @@ 
-45,10 +45,19 @@ class VIPTRPostProcessor(RecognitionPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min". """ - @staticmethod + def __init__( + self, + vocab: str, + confidence_aggregation: ConfidenceAggregation = "min", + ) -> None: + super().__init__(vocab, confidence_aggregation) + def ctc_best_path( + self, logits: torch.Tensor, vocab: str = VOCABS["french"], blank: int = 0, @@ -64,16 +73,38 @@ def ctc_best_path( Returns: A list of tuples: (word, confidence) """ - # Gather the most confident characters, and assign the smallest conf among those to the sequence prob - probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values - - # collapse best path (using itertools.groupby), map to chars, join char list to string - words = [ - decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab) - for seq in torch.argmax(logits, dim=-1) - ] - - return list(zip(words, probs.tolist())) + # Get softmax probabilities and best path indices + softmax_probs = F.softmax(logits, dim=-1) + best_path_indices = torch.argmax(logits, dim=-1) # N x T + # Get the probability of the best path at each time step + best_path_probs = softmax_probs.max(dim=-1).values # N x T + + results = [] + for batch_idx in range(logits.size(0)): + seq = best_path_indices[batch_idx].tolist() + probs_seq = best_path_probs[batch_idx] + + # Collapse best path: remove blanks and repeated characters + # Track which positions contribute to the final word + char_probs = [] + prev_char = None + for pos, char_idx in enumerate(seq): + if char_idx != blank and char_idx != prev_char: + char_probs.append(probs_seq[pos].item()) + prev_char = char_idx + + # Decode the word + word = decode_sequence([k for k, _ in groupby(seq) if k != blank], 
vocab) + + # Aggregate character probabilities + if char_probs: + conf = aggregate_confidence(char_probs, self.confidence_aggregation) + else: + conf = 0.0 + + results.append((word, conf)) + + return results def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]: """Performs decoding of raw output with CTC and decoding of CTC predictions diff --git a/doctr/models/recognition/vitstr/base.py b/doctr/models/recognition/vitstr/base.py index adc78115f2..1c7c5e8002 100644 --- a/doctr/models/recognition/vitstr/base.py +++ b/doctr/models/recognition/vitstr/base.py @@ -7,7 +7,7 @@ import numpy as np from ....datasets import encode_sequences -from ..core import RecognitionPostProcessor +from ..core import ConfidenceAggregation, RecognitionPostProcessor class _ViTSTR: @@ -43,11 +43,14 @@ class _ViTSTRPostProcessor(RecognitionPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. 
""" def __init__( self, vocab: str, + confidence_aggregation: ConfidenceAggregation = "mean", ) -> None: - super().__init__(vocab) + super().__init__(vocab, confidence_aggregation) self._embedding = list(vocab) + ["", ""] diff --git a/doctr/models/recognition/vitstr/pytorch.py b/doctr/models/recognition/vitstr/pytorch.py index b438d91bb2..d673cb4282 100644 --- a/doctr/models/recognition/vitstr/pytorch.py +++ b/doctr/models/recognition/vitstr/pytorch.py @@ -16,6 +16,7 @@ from ...classification import vit_b, vit_s from ...utils.pytorch import _bf16_to_float32, load_pretrained_params +from ..core import aggregate_confidence from .base import _ViTSTR, _ViTSTRPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -166,6 +167,8 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "mean". 
""" def __call__( @@ -181,9 +184,12 @@ def __call__( "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] for encoded_seq in out_idxs.cpu().numpy() ] - # compute probabilties for each word up to the EOS token + # compute probabilities for each word up to the EOS token using configured aggregation method probs = [ - preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values) + aggregate_confidence(preds_prob[i, : len(word)].cpu().numpy(), self.confidence_aggregation) + if word + else 0.0 + for i, word in enumerate(word_values) ] return list(zip(word_values, probs)) diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index d2ff52f2f9..a579deace8 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -13,6 +13,7 @@ from ...classification import vit_b, vit_s from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params +from ..core import aggregate_confidence from .base import _ViTSTR, _ViTSTRPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -165,6 +166,8 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor): Args: vocab: string containing the ordered sequence of supported characters + confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence. + Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "mean". 
# Tests for confidence aggregation
class TestAggregateConfidence:
    """Tests for the aggregate_confidence function."""

    def test_empty_probs(self):
        """Empty probability array should return 0.0."""
        assert aggregate_confidence(np.array([]), "mean") == 0.0
        assert aggregate_confidence([], "min") == 0.0

    @pytest.mark.parametrize(
        "probs, method, expected",
        [
            # Arithmetic mean tests
            ([0.8, 0.9, 0.7], "mean", 0.8),
            ([0.5, 0.5, 0.5], "mean", 0.5),
            ([1.0, 1.0, 1.0], "mean", 1.0),
            ([0.0, 0.0, 0.0], "mean", 0.0),
            # Minimum tests
            ([0.8, 0.9, 0.7], "min", 0.7),
            ([0.5, 0.3, 0.9], "min", 0.3),
            ([1.0, 1.0, 1.0], "min", 1.0),
            # Maximum tests
            ([0.8, 0.9, 0.7], "max", 0.9),
            ([0.5, 0.3, 0.9], "max", 0.9),
            ([0.0, 0.0, 0.0], "max", 0.0),
        ],
    )
    def test_basic_aggregation_methods(self, probs, method, expected):
        """Test basic aggregation methods with simple inputs."""
        assert aggregate_confidence(probs, method) == pytest.approx(expected, abs=1e-6)

    def test_geometric_mean(self):
        """Test geometric mean calculation."""
        # geometric_mean([0.8, 0.8, 0.8]) = 0.8
        assert aggregate_confidence([0.8, 0.8, 0.8], "geometric_mean") == pytest.approx(0.8, abs=1e-6)

        # geometric_mean([1.0, 0.5]) = sqrt(0.5) ≈ 0.707
        assert aggregate_confidence([1.0, 0.5], "geometric_mean") == pytest.approx(np.sqrt(0.5), abs=1e-6)

        # a single zero must pull the geometric mean (almost) down to zero
        assert aggregate_confidence([0.0, 0.5, 0.5], "geometric_mean") < 0.01

    def test_harmonic_mean(self):
        """Test harmonic mean calculation."""
        # harmonic_mean([0.5, 0.5, 0.5]) = 0.5
        assert aggregate_confidence([0.5, 0.5, 0.5], "harmonic_mean") == pytest.approx(0.5, abs=1e-6)

        # harmonic_mean([1.0, 0.5]) = 2 / (1/1.0 + 1/0.5) = 2 / 3 ≈ 0.667
        assert aggregate_confidence([1.0, 0.5], "harmonic_mean") == pytest.approx(2 / 3, abs=1e-6)

        # a single zero must pull the harmonic mean (almost) down to zero
        assert aggregate_confidence([0.0, 0.5, 0.5], "harmonic_mean") < 0.01

    def test_clipping(self):
        """Test that values are clipped to [0, 1] range."""
        # After clipping: [1.0, 0.5, 0.0], mean = 0.5
        assert aggregate_confidence([1.5, 0.5, -0.5], "mean") == pytest.approx(0.5, abs=1e-6)

    @pytest.mark.parametrize("method", ["mean", "geometric_mean", "harmonic_mean", "min", "max"])
    def test_single_value(self, method):
        """Test with single value - all methods should return that value."""
        assert aggregate_confidence([0.75], method) == pytest.approx(0.75, abs=1e-6)

    def test_custom_callable(self):
        """Test with custom aggregation function."""

        def custom_median(probs):
            return float(np.median(probs))

        assert aggregate_confidence([0.1, 0.5, 0.9], custom_median) == pytest.approx(0.5, abs=1e-6)

    def test_invalid_method(self):
        """Test that invalid method raises ValueError."""
        with pytest.raises(ValueError, match="Unknown aggregation method"):
            aggregate_confidence([0.5, 0.5], "invalid_method")

    def test_numpy_array_input(self):
        """Test with numpy array input."""
        assert aggregate_confidence(np.array([0.8, 0.9, 0.7]), "mean") == pytest.approx(0.8, abs=1e-6)

    def test_list_input(self):
        """Test with Python list input."""
        assert aggregate_confidence([0.8, 0.9, 0.7], "mean") == pytest.approx(0.8, abs=1e-6)

    def test_ordering_sensitivity(self):
        """Test that methods sensitive to outliers are ordered as expected."""
        probs_with_low_outlier = [0.9, 0.9, 0.9, 0.1]

        mean_result = aggregate_confidence(probs_with_low_outlier, "mean")
        geometric_result = aggregate_confidence(probs_with_low_outlier, "geometric_mean")
        harmonic_result = aggregate_confidence(probs_with_low_outlier, "harmonic_mean")
        min_result = aggregate_confidence(probs_with_low_outlier, "min")

        # min should return the lowest value
        assert min_result == pytest.approx(0.1, abs=1e-6)
        # the more a method penalises low values, the lower its score:
        # min < harmonic mean < geometric mean < arithmetic mean
        assert min_result < harmonic_result < geometric_result < mean_result
[CTCPostProcessor, [2, 119, 30]], - [SARPostProcessor, [2, 119, 30]], - [ViTSTRPostProcessor, [2, 119, 30]], - [MASTERPostProcessor, [2, 119, 30]], - [PARSeqPostProcessor, [2, 119, 30]], - [VIPTRPostProcessor, [2, 119, 30]], + [CTCPostProcessor, [2, 119, 30], "min"], + [SARPostProcessor, [2, 119, 30], "min"], + [ViTSTRPostProcessor, [2, 119, 30], "mean"], + [MASTERPostProcessor, [2, 119, 30], "min"], + [PARSeqPostProcessor, [2, 119, 30], "mean"], + [VIPTRPostProcessor, [2, 119, 30], "min"], ], ) -def test_reco_postprocessors(post_processor, input_shape, mock_vocab): +def test_reco_postprocessors(post_processor, input_shape, default_aggregation, mock_vocab): processor = post_processor(mock_vocab) decoded = processor(torch.rand(*input_shape)) assert isinstance(decoded, list) @@ -85,7 +85,8 @@ def test_reco_postprocessors(post_processor, input_shape, mock_vocab): assert len(decoded) == input_shape[0] assert all(char in mock_vocab for word, _ in decoded for char in word) # Repr - assert repr(processor) == f"{post_processor.__name__}(vocab_size={len(mock_vocab)})" + expected_repr = f"{post_processor.__name__}(vocab_size={len(mock_vocab)}, confidence_aggregation='{default_aggregation}')" + assert repr(processor) == expected_repr @pytest.mark.parametrize( diff --git a/tests/tensorflow/test_models_recognition_tf.py b/tests/tensorflow/test_models_recognition_tf.py index 53cde36bbe..72ebfa880b 100644 --- a/tests/tensorflow/test_models_recognition_tf.py +++ b/tests/tensorflow/test_models_recognition_tf.py @@ -68,16 +68,16 @@ def test_recognition_models(arch_name, input_shape, train_mode, mock_vocab): @pytest.mark.parametrize( - "post_processor, input_shape", + "post_processor, input_shape, default_aggregation", [ - [SARPostProcessor, [2, 30, 119]], - [CTCPostProcessor, [2, 30, 119]], - [MASTERPostProcessor, [2, 30, 119]], - [ViTSTRPostProcessor, [2, 30, 119]], - [PARSeqPostProcessor, [2, 30, 119]], + [SARPostProcessor, [2, 30, 119], "min"], + [CTCPostProcessor, [2, 30, 119], 
"mean"], # TF CTC uses beam search, inherits default + [MASTERPostProcessor, [2, 30, 119], "min"], + [ViTSTRPostProcessor, [2, 30, 119], "mean"], + [PARSeqPostProcessor, [2, 30, 119], "mean"], ], ) -def test_reco_postprocessors(post_processor, input_shape, mock_vocab): +def test_reco_postprocessors(post_processor, input_shape, default_aggregation, mock_vocab): processor = post_processor(mock_vocab) decoded = processor(tf.random.uniform(shape=input_shape, minval=0, maxval=1, dtype=tf.float32)) assert isinstance(decoded, list) @@ -85,7 +85,8 @@ def test_reco_postprocessors(post_processor, input_shape, mock_vocab): assert len(decoded) == input_shape[0] assert all(char in mock_vocab for word, _ in decoded for char in word) # Repr - assert repr(processor) == f"{post_processor.__name__}(vocab_size={len(mock_vocab)})" + expected_repr = f"{post_processor.__name__}(vocab_size={len(mock_vocab)}, confidence_aggregation='{default_aggregation}')" + assert repr(processor) == expected_repr @pytest.fixture(scope="session")