Skip to content

Commit d3354d2

Browse files
authored
Merge branch 'main' into pablo-garay/remove_exportDeploy
2 parents da6975d + b6547c3 commit d3354d2

File tree

2 files changed

+84
-82
lines changed

2 files changed

+84
-82
lines changed

nemo/collections/asr/models/sortformer_diar_models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
112112
self.sortformer_modules.encoder_proj = None
113113
self._init_loss_weights()
114114

115-
self.eps = 1e-3
116-
self.negative_init_val = -99
115+
self.eps = self._cfg.get("eps", 1e-3)
116+
self.negative_init_val = self._cfg.get("negative_init_val", -99)
117117
self.loss = instantiate(self._cfg.loss)
118118

119119
self.async_streaming = self._cfg.get("async_streaming", False)
@@ -832,6 +832,7 @@ def _get_aux_train_evaluations(self, preds, targets, target_lens) -> dict:
832832
Returns:
833833
(dict): A dictionary containing the following training metrics.
834834
"""
835+
targets = targets.to(preds.dtype)
835836
if preds.shape[1] < targets.shape[1]:
836837
logging.info(
837838
f"WARNING! preds has less frames than targets ({preds.shape[1]} < {targets.shape[1]}). "
@@ -904,6 +905,7 @@ def _get_aux_validation_evaluations(self, preds, targets, target_lens) -> dict:
904905
Returns:
905906
val_metrics (dict): A dictionary containing the following validation metrics
906907
"""
908+
targets = targets.to(preds.dtype)
907909
if preds.shape[1] < targets.shape[1]:
908910
logging.info(
909911
f"WARNING! preds has less frames than targets ({preds.shape[1]} < {targets.shape[1]}). "
@@ -1035,6 +1037,7 @@ def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target
10351037
target_lens (torch.Tensor): Lengths of target sequences.
10361038
Shape: (batch_size,)
10371039
"""
1040+
targets = targets.to(preds.dtype)
10381041
if preds.shape[1] < targets.shape[1]:
10391042
logging.info(
10401043
f"WARNING! preds has less frames than targets ({preds.shape[1]} < {targets.shape[1]}). "

0 commit comments

Comments
 (0)