Skip to content

Commit c4b50c0

Browse files
authored
kwargs in decode() for convenience (#1061)
* kwargs in decode() for convenience * formatting fix
1 parent 38f2f4d commit c4b50c0

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

whisper/decoding.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass, field
1+
from dataclasses import dataclass, field, replace
22
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
33

44
import numpy as np
@@ -778,7 +778,10 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
778778

779779
@torch.no_grad()
780780
def decode(
781-
model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
781+
model: "Whisper",
782+
mel: Tensor,
783+
options: DecodingOptions = DecodingOptions(),
784+
**kwargs,
782785
) -> Union[DecodingResult, List[DecodingResult]]:
783786
"""
784787
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@@ -802,6 +805,9 @@ def decode(
802805
if single := mel.ndim == 2:
803806
mel = mel.unsqueeze(0)
804807

808+
if kwargs:
809+
options = replace(options, **kwargs)
810+
805811
result = DecodingTask(model, options).run(mel)
806812

807813
return result[0] if single else result

0 commit comments

Comments
 (0)