Skip to content

Commit aa3d1cf

Browse files
naymaraq
authored and committed
Inference optimization for cache-aware pipelines (NVIDIA-NeMo#15035)
* optimize context manager and cache feature bufferer
  Signed-off-by: naymaraq <[email protected]>
* speedUp cache_feature_bufferer
  Signed-off-by: naymaraq <[email protected]>
* improved docstring in BatchedCacheFeatureBufferer
  Signed-off-by: naymaraq <[email protected]>
---------
Signed-off-by: naymaraq <[email protected]>
Co-authored-by: naymaraq <[email protected]>
Signed-off-by: genquan9 <[email protected]>
1 parent 7277489 commit aa3d1cf

File tree

6 files changed

+167
-201
lines changed

6 files changed

+167
-201
lines changed

nemo/collections/asr/inference/model_wrappers/asr_inference_wrapper.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,6 @@ def supports_capitalization(self) -> bool:
140140
Returns:
141141
(bool) True if the ASR model supports capitalization, False otherwise.
142142
"""
143-
if not hasattr(self, "asr_model") or self.asr_model is None:
144-
raise ValueError("ASR model is not initialized.")
145143
return self.tokenizer.supports_capitalization
146144

147145
def supports_punctuation(self) -> bool:
@@ -150,8 +148,6 @@ def supports_punctuation(self) -> bool:
150148
Returns:
151149
(bool) True if the ASR model supports punctuation, False otherwise.
152150
"""
153-
if not hasattr(self, "asr_model") or self.asr_model is None:
154-
raise ValueError("ASR model is not initialized.")
155151
return self.supported_punctuation() != set()
156152

157153
def supported_punctuation(self) -> set:

nemo/collections/asr/inference/pipelines/base_pipeline.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ def from_state(cls, state: StreamingState, request: Request, sep: str = ' ') ->
7979
"""
8080
final_transcript = state.final_transcript.strip()
8181
final_segments = [seg.copy() for seg in state.final_segments]
82+
if len(final_segments) > 0:
83+
final_segments[0].text = final_segments[0].text.lstrip(sep)
84+
final_segments[-1].text = final_segments[-1].text.rstrip(sep)
85+
8286
if final_transcript:
8387
separator = ''
8488
if not request.is_first and state.concat_with_space:
@@ -185,11 +189,12 @@ def transcribe_step(self, requests: list[Request]) -> list[TranscribeStepOutput]
185189

186190
# Create current step output for each request
187191
outputs = []
192+
sep = self.get_sep()
188193
for request in requests:
189194

190195
# Extract current step output from the state
191196
state = self.get_state(request.stream_id)
192-
step_output = TranscribeStepOutput.from_state(state=state, request=request, sep=self.get_sep())
197+
step_output = TranscribeStepOutput.from_state(state=state, request=request, sep=sep)
193198
outputs.append(step_output)
194199

195200
# Cleanup the state after the response is sent
@@ -344,6 +349,7 @@ def init_bufferer_for_cache_aware_streaming(self) -> None:
344349
check_existance_of_required_attributes(
345350
self,
346351
[
352+
'num_slots',
347353
'use_feat_cache',
348354
'chunk_size_in_secs',
349355
'buffer_size_in_secs',
@@ -361,6 +367,7 @@ def init_bufferer_for_cache_aware_streaming(self) -> None:
361367
chunk_size_for_feature_buffer = self.buffer_size_in_secs
362368

363369
self.bufferer = BatchedCacheFeatureBufferer(
370+
num_slots=self.num_slots,
364371
sample_rate=self.sample_rate,
365372
buffer_size_in_secs=self.buffer_size_in_secs,
366373
chunk_size_in_secs=chunk_size_for_feature_buffer,
@@ -406,6 +413,7 @@ def run(
406413
request_generator.set_progress_bar(progress_bar)
407414

408415
pipeline_output = {}
416+
sep = self.get_sep()
409417
self.open_session()
410418
for requests in request_generator:
411419
step_outputs = self.transcribe_step(requests)
@@ -417,7 +425,18 @@ def run(
417425
"segments": [],
418426
"audio_filepath": request_generator.get_audio_filepath(stream_id),
419427
}
420-
pipeline_output[stream_id]["text"] += step_output.final_transcript
421-
pipeline_output[stream_id]["segments"].extend(step_output.final_segments)
428+
429+
accumulated_text = pipeline_output[stream_id]["text"]
430+
final_transcript = step_output.final_transcript
431+
final_segments = step_output.final_segments
432+
if not accumulated_text:
433+
final_transcript = final_transcript.lstrip(sep)
434+
if len(final_segments) > 0:
435+
first_segment = final_segments[0]
436+
first_segment.text = first_segment.text.lstrip(sep)
437+
438+
accumulated_text += final_transcript
439+
pipeline_output[stream_id]["text"] = accumulated_text
440+
pipeline_output[stream_id]["segments"].extend(final_segments)
422441
self.close_session()
423442
return pipeline_output

nemo/collections/asr/inference/pipelines/cache_aware_ctc_pipeline.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,6 @@ def init_endpointer(self) -> None:
163163
residue_tokens_at_end=self.residue_tokens_at_end,
164164
)
165165

166-
def reset_session(self) -> None:
167-
"""Reset the context manager."""
168-
self.context_manager.reset()
169-
super().reset_session()
170-
171166
def create_state(self, options: ASRRequestOptions) -> CacheAwareCTCStreamingState:
172167
"""
173168
Create new empty state.

nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,6 @@ def init_endpointer(self) -> None:
165165
residue_tokens_at_end=self.residue_tokens_at_end,
166166
)
167167

168-
def reset_session(self) -> None:
169-
"""Reset the context manager."""
170-
self.context_manager.reset()
171-
super().reset_session()
172-
173168
def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingState:
174169
"""
175170
Create new empty state.

0 commit comments

Comments (0)