65 changes: 63 additions & 2 deletions doctr/models/recognition/core.py
@@ -3,13 +3,69 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Callable, Literal, Union

import numpy as np

from doctr.datasets import encode_sequences
from doctr.utils.repr import NestedObject

__all__ = ["RecognitionPostProcessor", "RecognitionModel"]
__all__ = ["RecognitionPostProcessor", "RecognitionModel", "aggregate_confidence", "ConfidenceAggregation"]

# Type alias for confidence aggregation methods
ConfidenceAggregation = Union[Literal["mean", "geometric_mean", "harmonic_mean", "min", "max"], Callable[[np.ndarray], float]]


def aggregate_confidence(
probs: np.ndarray,
method: ConfidenceAggregation = "mean",
) -> float:
"""Aggregate character-level confidence scores into a word-level confidence score.

Args:
probs: Array of character-level confidence scores (values between 0 and 1)
method: Aggregation method to use. Can be one of:
- "mean": Arithmetic mean (default)
- "geometric_mean": Geometric mean (more sensitive to low values)
- "harmonic_mean": Harmonic mean (even more sensitive to low values)
- "min": Minimum confidence (most conservative)
- "max": Maximum confidence (most optimistic)
- A callable that takes an ndarray and returns a float

Returns:
Aggregated confidence score as a float between 0 and 1
"""
if len(probs) == 0:
return 0.0

# Convert to numpy if needed and ensure float type
probs = np.asarray(probs, dtype=np.float64)

# Clip to valid probability range
probs = np.clip(probs, 0.0, 1.0)

if callable(method):
return float(method(probs))

if method == "mean":
return float(np.mean(probs))
elif method == "geometric_mean":
        # Compute the geometric mean in log space for numerical stability:
        # geometric_mean = exp(mean(log(probs)))
        # Replace zeros with a small epsilon so log() stays finite
safe_probs = np.where(probs > 0, probs, 1e-10)
return float(np.exp(np.mean(np.log(safe_probs))))
elif method == "harmonic_mean":
# harmonic_mean = n / sum(1/probs)
        # Replace zeros with a small epsilon to avoid division by zero
safe_probs = np.where(probs > 0, probs, 1e-10)
return float(len(safe_probs) / np.sum(1.0 / safe_probs))
elif method == "min":
return float(np.min(probs))
elif method == "max":
return float(np.max(probs))
else:
raise ValueError(f"Unknown aggregation method: {method}. Expected one of 'mean', 'geometric_mean', 'harmonic_mean', 'min', 'max', or a callable.")


class RecognitionModel(NestedObject):
@@ -41,14 +97,19 @@ class RecognitionPostProcessor(NestedObject):

Args:
vocab: string containing the ordered sequence of supported characters
confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence.
Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable.
"""

def __init__(
self,
vocab: str,
confidence_aggregation: ConfidenceAggregation = "mean",
) -> None:
self.vocab = vocab
self.confidence_aggregation = confidence_aggregation
self._embedding = list(self.vocab) + ["<eos>"]

def extra_repr(self) -> str:
return f"vocab_size={len(self.vocab)}"
agg_repr = self.confidence_aggregation if isinstance(self.confidence_aggregation, str) else "custom"
return f"vocab_size={len(self.vocab)}, confidence_aggregation='{agg_repr}'"
55 changes: 43 additions & 12 deletions doctr/models/recognition/crnn/pytorch.py
@@ -16,7 +16,7 @@

from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
from ...utils.pytorch import load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor
from ..core import ConfidenceAggregation, RecognitionModel, RecognitionPostProcessor, aggregate_confidence

__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]

@@ -50,10 +50,19 @@ class CTCPostProcessor(RecognitionPostProcessor):

Args:
vocab: string containing the ordered sequence of supported characters
confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence.
Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min".
"""

@staticmethod
def __init__(
self,
vocab: str,
confidence_aggregation: ConfidenceAggregation = "min",
) -> None:
super().__init__(vocab, confidence_aggregation)

def ctc_best_path(
self,
logits: torch.Tensor,
vocab: str = VOCABS["french"],
blank: int = 0,
@@ -69,16 +78,38 @@ def ctc_best_path(
Returns:
A list of tuples: (word, confidence)
"""
# Gather the most confident characters, and assign the smallest conf among those to the sequence prob
probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values

# collapse best path (using itertools.groupby), map to chars, join char list to string
words = [
decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
for seq in torch.argmax(logits, dim=-1)
]

return list(zip(words, probs.tolist()))
# Get softmax probabilities and best path indices
softmax_probs = F.softmax(logits, dim=-1)
best_path_indices = torch.argmax(logits, dim=-1) # N x T
# Get the probability of the best path at each time step
best_path_probs = softmax_probs.max(dim=-1).values # N x T

results = []
for batch_idx in range(logits.size(0)):
seq = best_path_indices[batch_idx].tolist()
probs_seq = best_path_probs[batch_idx]

# Collapse best path: remove blanks and repeated characters
# Track which positions contribute to the final word
            char_probs = []
            prev_char = None
            for pos, char_idx in enumerate(seq):
                if char_idx != blank and char_idx != prev_char:
                    char_probs.append(probs_seq[pos].item())
                # Update on every step (not only on emission) so that a character
                # repeated after a blank is emitted again, matching the groupby decode below
                prev_char = char_idx

# Decode the word
word = decode_sequence([k for k, _ in groupby(seq) if k != blank], vocab)

# Aggregate character probabilities
if char_probs:
conf = aggregate_confidence(char_probs, self.confidence_aggregation)
else:
conf = 0.0

results.append((word, conf))

return results

def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
"""Performs decoding of raw output with CTC and decoding of CTC predictions
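To make the new collapse logic concrete, a toy walk-through in plain Python (indices and probabilities are made up for illustration):

from itertools import groupby

# Toy setup: index 0 is the CTC blank, 1 -> "a", 2 -> "b"
vocab, blank = "ab", 0
best_path = [1, 1, 0, 1, 2, 2]                     # argmax indices over 6 time steps
step_probs = [0.90, 0.80, 0.30, 0.95, 0.60, 0.85]  # max softmax prob per step

# Decode: collapse repeats, then drop blanks -> "aab"
word = "".join(vocab[k - 1] for k, _ in groupby(best_path) if k != blank)

# Keep the probability of the first time step of each emitted character,
# mirroring the loop in ctc_best_path above
char_probs, prev = [], None
for pos, idx in enumerate(best_path):
    if idx != blank and idx != prev:
        char_probs.append(step_probs[pos])
    prev = idx

print(word, char_probs)  # aab [0.9, 0.95, 0.6]

With the new default confidence_aggregation="min", the word confidence is 0.6; the previous implementation took the minimum over all time steps, blanks included, and would have reported 0.3 here, while "mean" gives ~0.82.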
7 changes: 5 additions & 2 deletions doctr/models/recognition/master/base.py
@@ -7,7 +7,7 @@
import numpy as np

from ....datasets import encode_sequences
from ..core import RecognitionPostProcessor
from ..core import ConfidenceAggregation, RecognitionPostProcessor


class _MASTER:
@@ -44,11 +44,14 @@ class _MASTERPostProcessor(RecognitionPostProcessor):

Args:
vocab: string containing the ordered sequence of supported characters
confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence.
Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min".
"""

def __init__(
self,
vocab: str,
confidence_aggregation: ConfidenceAggregation = "min",
) -> None:
super().__init__(vocab)
super().__init__(vocab, confidence_aggregation)
self._embedding = list(vocab) + ["<eos>"] + ["<sos>"] + ["<pad>"]
23 changes: 18 additions & 5 deletions doctr/models/recognition/master/pytorch.py
@@ -17,6 +17,7 @@
from doctr.models.modules.transformer import Decoder, PositionalEncoding

from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
from ..core import aggregate_confidence
from .base import _MASTER, _MASTERPostProcessor

__all__ = ["MASTER", "master"]
@@ -258,7 +259,13 @@ def decode(self, encoded: torch.Tensor) -> torch.Tensor:


class MASTERPostProcessor(_MASTERPostProcessor):
"""Post processor for MASTER architectures"""
"""Post processor for MASTER architectures

Args:
vocab: string containing the ordered sequence of supported characters
confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence.
Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min".
"""

def __call__(
self,
@@ -267,17 +274,23 @@ def __call__(
# compute pred with argmax for attention models
out_idxs = logits.argmax(-1)
# N x L
probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
# Take the minimum confidence of the sequence
probs = probs.min(dim=1).values.detach().cpu()
preds_prob = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)

# Manual decoding
word_values = [
"".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
for encoded_seq in out_idxs.cpu().numpy()
]

return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
# compute probabilities for each word up to the EOS token using configured aggregation method
probs = [
aggregate_confidence(preds_prob[i, : len(word)].detach().cpu().numpy(), self.confidence_aggregation)
if word
else 0.0
for i, word in enumerate(word_values)
]

return list(zip(word_values, probs))


def _master(
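A minimal usage sketch for the configurable post-processor, assuming the constructor signature from this diff and doctr's VOCABS mapping; with random logits the decoded strings are meaningless, but it shows the wiring:

import torch

from doctr.datasets import VOCABS
from doctr.models.recognition.master.pytorch import MASTERPostProcessor

# Use "geometric_mean" instead of the "min" default
postproc = MASTERPostProcessor(vocab=VOCABS["french"], confidence_aggregation="geometric_mean")

# N x L x (vocab size + <eos>/<sos>/<pad>) toy logits
logits = torch.randn(2, 30, len(VOCABS["french"]) + 3)
out = postproc(logits)  # [(word, confidence), ...]

# If a trained recognition model exposes its post-processor (attribute name
# assumed here, not shown in this diff), the aggregation can be swapped in place:
# model.postprocessor = postproc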
17 changes: 13 additions & 4 deletions doctr/models/recognition/master/tensorflow.py
@@ -14,6 +14,7 @@
from doctr.models.modules.transformer import Decoder, PositionalEncoding

from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
from ..core import aggregate_confidence
from .base import _MASTER, _MASTERPostProcessor

__all__ = ["MASTER", "master"]
@@ -250,6 +251,8 @@ class MASTERPostProcessor(_MASTERPostProcessor):

Args:
vocab: string containing the ordered sequence of supported characters
confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence.
Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "min".
"""

def __call__(
@@ -259,9 +262,7 @@ def __call__(
# compute pred with argmax for attention models
out_idxs = tf.math.argmax(logits, axis=2)
# N x L
probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
# Take the minimum confidence of the sequence
probs = tf.math.reduce_min(probs, axis=1)
preds_prob = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)

# decode raw output of the model with tf_label_to_idx
out_idxs = tf.cast(out_idxs, dtype="int32")
@@ -271,7 +272,15 @@
decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
# compute probabilities for each word up to the EOS token using configured aggregation method
probs = [
aggregate_confidence(preds_prob[i, : len(word)].numpy(), self.confidence_aggregation)
if word
else 0.0
for i, word in enumerate(word_values)
]

return list(zip(word_values, probs))


def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool = True, **kwargs: Any) -> MASTER:
7 changes: 5 additions & 2 deletions doctr/models/recognition/parseq/base.py
@@ -7,7 +7,7 @@
import numpy as np

from ....datasets import encode_sequences
from ..core import RecognitionPostProcessor
from ..core import ConfidenceAggregation, RecognitionPostProcessor


class _PARSeq:
@@ -44,11 +44,14 @@ class _PARSeqPostProcessor(RecognitionPostProcessor):

Args:
vocab: string containing the ordered sequence of supported characters
confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence.
Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable.
"""

def __init__(
self,
vocab: str,
confidence_aggregation: ConfidenceAggregation = "mean",
) -> None:
super().__init__(vocab)
super().__init__(vocab, confidence_aggregation)
self._embedding = list(vocab) + ["<eos>", "<sos>", "<pad>"]
10 changes: 8 additions & 2 deletions doctr/models/recognition/parseq/pytorch.py
@@ -20,6 +20,7 @@

from ...classification import vit_s
from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
from ..core import aggregate_confidence
from .base import _PARSeq, _PARSeqPostProcessor

__all__ = ["PARSeq", "parseq"]
@@ -409,6 +410,8 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):

Args:
vocab: string containing the ordered sequence of supported characters
confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence.
Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "mean".
"""

def __call__(
@@ -424,9 +427,12 @@ def __call__(
"".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
for encoded_seq in out_idxs.cpu().numpy()
]
# compute probabilties for each word up to the EOS token
# compute probabilities for each word up to the EOS token using configured aggregation method
probs = [
preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
aggregate_confidence(preds_prob[i, : len(word)].cpu().numpy(), self.confidence_aggregation)
if word
else 0.0
for i, word in enumerate(word_values)
]

return list(zip(word_values, probs))
9 changes: 7 additions & 2 deletions doctr/models/recognition/parseq/tensorflow.py
@@ -17,6 +17,7 @@

from ...classification import vit_s
from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
from ..core import aggregate_confidence
from .base import _PARSeq, _PARSeqPostProcessor

__all__ = ["PARSeq", "parseq"]
@@ -423,6 +424,8 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):

Args:
vocab: string containing the ordered sequence of supported characters
confidence_aggregation: method to aggregate character-level confidence scores into word-level confidence.
Can be "mean", "geometric_mean", "harmonic_mean", "min", "max", or a custom callable. Defaults to "mean".
"""

def __call__(
@@ -441,9 +444,11 @@ def __call__(
decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

# compute probabilties for each word up to the EOS token
# compute probabilities for each word up to the EOS token using configured aggregation method
probs = [
preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
aggregate_confidence(preds_prob[i, : len(word)].numpy(), self.confidence_aggregation)
if word
else 0.0
for i, word in enumerate(word_values)
]
