Skip to content

Commit 671ac5a

Browse files
authored
Fix alignment between the segments and the list of words (#1087)
* Fix alignment between the segments and the list of words
* Ensure the word index does not overflow
1 parent 839639a commit 671ac5a

File tree

1 file changed

+30
-22
lines changed

1 file changed

+30
-22
lines changed

whisper/timing.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import subprocess
23
import warnings
34
from dataclasses import dataclass
@@ -290,34 +291,41 @@ def add_word_timestamps(
290291
if len(segments) == 0:
291292
return
292293

293-
text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
294+
text_tokens_per_segment = [
295+
[token for token in segment["tokens"] if token < tokenizer.eot]
296+
for segment in segments
297+
]
298+
299+
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
294300
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
295301
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
296302

297303
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
298-
segment_lengths = [len(s["tokens"]) for s in segments]
299-
token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
300-
301-
for segment in segments:
302-
segment["words"] = []
303-
304-
word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
305-
for i, timing in enumerate(alignment):
306-
if timing.word:
307-
segment = segments[token_sources[word_boundaries[i]]]
308-
start = round(time_offset + timing.start, 2)
309-
end = round(time_offset + timing.end, 2)
310-
segment["words"].append(
311-
dict(
312-
word=timing.word,
313-
start=start,
314-
end=end,
315-
probability=timing.probability,
304+
word_index = 0
305+
306+
for segment, text_tokens in zip(segments, text_tokens_per_segment):
307+
saved_tokens = 0
308+
words = []
309+
310+
while word_index < len(alignment) and saved_tokens < len(text_tokens):
311+
timing = alignment[word_index]
312+
313+
if timing.word:
314+
words.append(
315+
dict(
316+
word=timing.word,
317+
start=round(time_offset + timing.start, 2),
318+
end=round(time_offset + timing.end, 2),
319+
probability=timing.probability,
320+
)
316321
)
317-
)
318322

319-
for segment in segments:
320-
if len(words := segment["words"]) > 0:
323+
saved_tokens += len(timing.tokens)
324+
word_index += 1
325+
326+
if len(words) > 0:
321327
# adjust the segment-level timestamps based on the word-level timestamps
322328
segment["start"] = words[0]["start"]
323329
segment["end"] = words[-1]["end"]
330+
331+
segment["words"] = words

0 commit comments

Comments (0)