@@ -112,8 +112,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
112112 self .sortformer_modules .encoder_proj = None
113113 self ._init_loss_weights ()
114114
115- self .eps = 1e-3
116- self .negative_init_val = - 99
115+ self .eps = self . _cfg . get ( "eps" , 1e-3 )
116+ self .negative_init_val = self . _cfg . get ( "negative_init_val" , - 99 )
117117 self .loss = instantiate (self ._cfg .loss )
118118
119119 self .async_streaming = self ._cfg .get ("async_streaming" , False )
@@ -832,6 +832,7 @@ def _get_aux_train_evaluations(self, preds, targets, target_lens) -> dict:
832832 Returns:
833833 (dict): A dictionary containing the following training metrics.
834834 """
835+ targets = targets .to (preds .dtype )
835836 if preds .shape [1 ] < targets .shape [1 ]:
836837 logging .info (
837838 f"WARNING! preds has less frames than targets ({ preds .shape [1 ]} < { targets .shape [1 ]} ). "
@@ -904,6 +905,7 @@ def _get_aux_validation_evaluations(self, preds, targets, target_lens) -> dict:
904905 Returns:
905906 val_metrics (dict): A dictionary containing the following validation metrics
906907 """
908+ targets = targets .to (preds .dtype )
907909 if preds .shape [1 ] < targets .shape [1 ]:
908910 logging .info (
909911 f"WARNING! preds has less frames than targets ({ preds .shape [1 ]} < { targets .shape [1 ]} ). "
@@ -1035,6 +1037,7 @@ def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target
10351037 target_lens (torch.Tensor): Lengths of target sequences.
10361038 Shape: (batch_size,)
10371039 """
1040+ targets = targets .to (preds .dtype )
10381041 if preds .shape [1 ] < targets .shape [1 ]:
10391042 logging .info (
10401043 f"WARNING! preds has less frames than targets ({ preds .shape [1 ]} < { targets .shape [1 ]} ). "
0 commit comments