@@ -1,3 +1,4 @@
+import itertools
 import subprocess
 import warnings
 from dataclasses import dataclass
@@ -290,34 +291,41 @@ def add_word_timestamps( |
     if len(segments) == 0:
         return
 
-    text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
+    text_tokens_per_segment = [
+        [token for token in segment["tokens"] if token < tokenizer.eot]
+        for segment in segments
+    ]
+
+    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    segment_lengths = [len(s["tokens"]) for s in segments]
-    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
-
-    for segment in segments:
-        segment["words"] = []
-
-    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
-    for i, timing in enumerate(alignment):
-        if timing.word:
-            segment = segments[token_sources[word_boundaries[i]]]
-            start = round(time_offset + timing.start, 2)
-            end = round(time_offset + timing.end, 2)
-            segment["words"].append(
-                dict(
-                    word=timing.word,
-                    start=start,
-                    end=end,
-                    probability=timing.probability,
+    word_index = 0
+
+    for segment, text_tokens in zip(segments, text_tokens_per_segment):
+        saved_tokens = 0
+        words = []
+
+        while word_index < len(alignment) and saved_tokens < len(text_tokens):
+            timing = alignment[word_index]
+
+            if timing.word:
+                words.append(
+                    dict(
+                        word=timing.word,
+                        start=round(time_offset + timing.start, 2),
+                        end=round(time_offset + timing.end, 2),
+                        probability=timing.probability,
+                    )
                 )
-            )
 
-    for segment in segments:
-        if len(words := segment["words"]) > 0:
+            saved_tokens += len(timing.tokens)
+            word_index += 1
+
+        if len(words) > 0:
             # adjust the segment-level timestamps based on the word-level timestamps
             segment["start"] = words[0]["start"]
             segment["end"] = words[-1]["end"]
+
+        segment["words"] = words
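
For readers tracing the change: the removed code mapped each aligned word to a segment via a precomputed NumPy index (np.repeat over len(s["tokens"]), which also counts timestamp tokens that find_alignment never sees), while the added loop walks the flattened alignment once and hands each word to the segment whose text-token budget it consumes. Below is a minimal standalone sketch of that bookkeeping; the WordTiming stub and the sample tokens are illustrative stand-ins, not part of the commit.

# A minimal sketch (not from the commit) of the token-counting bookkeeping:
# walk the flattened alignment once and attribute each word to the segment
# whose text tokens it consumed.
from dataclasses import dataclass
from typing import List


@dataclass
class WordTiming:
    word: str
    tokens: List[int]
    start: float
    end: float
    probability: float


# Text tokens per segment, as produced by the new list comprehension.
text_tokens_per_segment = [[101, 102], [103, 104, 105]]
segments = [{"tokens": tokens} for tokens in text_tokens_per_segment]

# A fake alignment over the flattened token stream (3 words, 5 tokens).
alignment = [
    WordTiming(" hello", [101, 102], 0.00, 0.40, 0.95),
    WordTiming(" there", [103], 0.50, 0.80, 0.90),
    WordTiming(" friend", [104, 105], 0.90, 1.30, 0.85),
]

word_index = 0
for segment, text_tokens in zip(segments, text_tokens_per_segment):
    saved_tokens = 0
    words = []

    # Consume alignment entries until this segment's token budget is spent.
    while word_index < len(alignment) and saved_tokens < len(text_tokens):
        timing = alignment[word_index]
        if timing.word:
            words.append(dict(word=timing.word, start=timing.start, end=timing.end))
        saved_tokens += len(timing.tokens)
        word_index += 1

    segment["words"] = words

print([[w["word"] for w in s["words"]] for s in segments])
# [[' hello'], [' there', ' friend']]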