diff --git a/examples/voice_agent/client/package-lock.json b/examples/voice_agent/client/package-lock.json index 36293625008b..710b74e52476 100644 --- a/examples/voice_agent/client/package-lock.json +++ b/examples/voice_agent/client/package-lock.json @@ -11,11 +11,15 @@ "dependencies": { "@pipecat-ai/client-js": "^0.4.0", "@pipecat-ai/websocket-transport": "^0.4.1", - "protobufjs": "^7.4.0" + "protobufjs": "^7.4.0", + "react": "^19.2.0", + "react-dom": "^19.2.0" }, "devDependencies": { "@types/node": "^22.15.30", "@types/protobufjs": "^6.0.0", + "@types/react": "^19.2.2", + "@types/react-dom": "^19.2.2", "@vitejs/plugin-react-swc": "^3.10.1", "typescript": "^5.8.3", "vite": "^6.3.5" @@ -1185,6 +1189,26 @@ "protobufjs": "*" } }, + "node_modules/@types/react": { + "version": "19.2.2", + "resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.2.tgz", + "integrity": "sha512-6mDvHUFSjyT2B2yeNx2nUgMxh9LtOWvkhIU3uePn2I2oyNymUAX1NIsdgviM4CH+JSrp2D2hsMvJOkxY+0wNRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "csstype": "^3.0.2" + } + }, + "node_modules/@types/react-dom": { + "version": "19.2.2", + "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-19.2.2.tgz", + "integrity": "sha512-9KQPoO6mZCi7jcIStSnlOWn2nEF3mNmyr3rIAsGnAbQKYbRLyqmeSc39EVgtxXVia+LMT8j3knZLAZAh+xLmrw==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "@types/react": "^19.2.0" + } + }, "node_modules/@typescript/vfs": { "version": "1.6.1", "resolved": "https://registry.npmjs.org/@typescript/vfs/-/vfs-1.6.1.tgz", @@ -1227,6 +1251,13 @@ "node": ">=6" } }, + "node_modules/csstype": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz", + "integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==", + "dev": true, + "license": "MIT" + }, "node_modules/debug": { "version": "4.4.1", "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.1.tgz", @@ -1459,6 +1490,27 @@ "undici-types": "~7.8.0" } }, + "node_modules/react": { + "version": "19.2.0", + "resolved": "https://registry.npmjs.org/react/-/react-19.2.0.tgz", + "integrity": "sha512-tmbWg6W31tQLeB5cdIBOicJDJRR2KzXsV7uSK9iNfLWQ5bIZfxuPEHp7M8wiHyHnn0DD1i7w3Zmin0FtkrwoCQ==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/react-dom": { + "version": "19.2.0", + "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.0.tgz", + "integrity": "sha512-UlbRu4cAiGaIewkPyiRGJk0imDN2T3JjieT6spoL2UeSf5od4n5LB/mQ4ejmxhCFT1tYe8IvaFulzynWovsEFQ==", + "license": "MIT", + "dependencies": { + "scheduler": "^0.27.0" + }, + "peerDependencies": { + "react": "^19.2.0" + } + }, "node_modules/rollup": { "version": "4.43.0", "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.43.0.tgz", @@ -1507,6 +1559,12 @@ "tslib": "^2.1.0" } }, + "node_modules/scheduler": { + "version": "0.27.0", + "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.27.0.tgz", + "integrity": "sha512-eNv+WrVbKu1f3vbYJT/xtiF5syA5HPIMtf9IgY/nKg0sWqzAUEvqY/xm7OcZc/qafLx/iO9FgOmeSAp4v5ti/Q==", + "license": "MIT" + }, "node_modules/shallow-clone": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/shallow-clone/-/shallow-clone-3.0.1.tgz", diff --git a/examples/voice_agent/client/package.json b/examples/voice_agent/client/package.json index d2df048f50f8..857f33a90815 100644 --- a/examples/voice_agent/client/package.json +++ b/examples/voice_agent/client/package.json @@ -14,6 +14,8 @@ "devDependencies": { "@types/node": 
"^22.15.30", "@types/protobufjs": "^6.0.0", + "@types/react": "^19.2.2", + "@types/react-dom": "^19.2.2", "@vitejs/plugin-react-swc": "^3.10.1", "typescript": "^5.8.3", "vite": "^6.3.5" @@ -21,6 +23,8 @@ "dependencies": { "@pipecat-ai/client-js": "^0.4.0", "@pipecat-ai/websocket-transport": "^0.4.1", - "protobufjs": "^7.4.0" + "protobufjs": "^7.4.0", + "react": "^19.2.0", + "react-dom": "^19.2.0" } } diff --git a/examples/voice_agent/client/src/app.ts b/examples/voice_agent/client/src/app.ts index c9809fa69c8a..82741cdb8bdd 100644 --- a/examples/voice_agent/client/src/app.ts +++ b/examples/voice_agent/client/src/app.ts @@ -46,12 +46,12 @@ class WebsocketClientApp { private readonly serverConfigs = { websocket: { name: 'WebSocket Server', - baseUrl: 'http://localhost:7860', + baseUrl: `http://${window.location.hostname}:7860`, port: 8765 }, fastapi: { name: 'FastAPI Server', - baseUrl: 'http://localhost:8000', + baseUrl: `http://${window.location.hostname}:8000`, port: 8000 } }; @@ -257,6 +257,7 @@ class WebsocketClientApp { this.log('Initializing devices...'); await this.rtviClient.initDevices(); + this.log('Devices initialized successfully'); this.log('Connecting to bot...'); await this.rtviClient.connect(); diff --git a/examples/voice_agent/client/tsconfig.json b/examples/voice_agent/client/tsconfig.json index c9c555d96f35..8e78ed030a2d 100644 --- a/examples/voice_agent/client/tsconfig.json +++ b/examples/voice_agent/client/tsconfig.json @@ -11,9 +11,9 @@ // "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */ /* Language and Environment */ - "target": "es2016", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */ - // "lib": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */ - // "jsx": "preserve", /* Specify what JSX code is generated. */ + "target": "ES2020", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */ + "lib": ["ES2020", "DOM", "DOM.Iterable"], /* Specify a set of bundled library declaration files that describe the target runtime environment. */ + "jsx": "react-jsx", /* Specify what JSX code is generated. */ // "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */ // "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */ // "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h'. */ @@ -25,9 +25,9 @@ // "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */ /* Modules */ - "module": "commonjs", /* Specify what module code is generated. */ + "module": "ESNext", /* Specify what module code is generated. */ // "rootDir": "./", /* Specify the root folder within your source files. */ - // "moduleResolution": "node10", /* Specify how TypeScript looks up a file from a given module specifier. */ + "moduleResolution": "bundler", /* Specify how TypeScript looks up a file from a given module specifier. */ // "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */ // "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */ // "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. 
*/ @@ -41,7 +41,7 @@ // "resolvePackageJsonImports": true, /* Use the package.json 'imports' field when resolving imports. */ // "customConditions": [], /* Conditions to set in addition to the resolver-specific defaults when resolving imports. */ // "noUncheckedSideEffectImports": true, /* Check side effect imports. */ - // "resolveJsonModule": true, /* Enable importing .json files. */ + "resolveJsonModule": true, /* Enable importing .json files. */ // "allowArbitraryExtensions": true, /* Enable importing files with any extension, provided a declaration file is present. */ // "noResolve": true, /* Disallow 'import's, 'require's or ''s from expanding the number of files TypeScript should add to a project. */ @@ -74,10 +74,10 @@ // "declarationDir": "./", /* Specify the output directory for generated declaration files. */ /* Interop Constraints */ - // "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */ + "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */ // "verbatimModuleSyntax": true, /* Do not transform or elide any imports or exports not marked as type-only, ensuring they are written in the output file's format based on the 'module' setting. */ // "isolatedDeclarations": true, /* Require sufficient annotation on exports so other tools can trivially generate declaration files. */ - // "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */ + "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */ "esModuleInterop": true, /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */ // "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */ "forceConsistentCasingInFileNames": true, /* Ensure that casing is correct in imports. 
*/ diff --git a/examples/voice_agent/server/backchannel_phrases.yaml b/examples/voice_agent/server/backchannel_phrases.yaml index 38c7523a7153..ac8b7cebf28b 100644 --- a/examples/voice_agent/server/backchannel_phrases.yaml +++ b/examples/voice_agent/server/backchannel_phrases.yaml @@ -11,7 +11,6 @@ - "great" - "great thanks" - "ha ha" -- "hi" - "hmm" - "humm" - "huh" diff --git a/examples/voice_agent/server/bot_websocket_server.py b/examples/voice_agent/server/bot_websocket_server.py index 572e322576ea..b622f2fddc60 100644 --- a/examples/voice_agent/server/bot_websocket_server.py +++ b/examples/voice_agent/server/bot_websocket_server.py @@ -28,12 +28,13 @@ from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.processors.frameworks.rtvi import RTVIAction, RTVIConfig, RTVIObserver, RTVIProcessor +from pipecat.processors.frameworks.rtvi import RTVIAction, RTVIConfig, RTVIProcessor from pipecat.serializers.protobuf import ProtobufFrameSerializer - -from nemo.agents.voice_agent.pipecat.services.nemo.diar import NeMoDiarInputParams, NemoDiarService +from nemo.agents.voice_agent.pipecat.services.nemo.audio_logger import AudioLogger +from nemo.agents.voice_agent.pipecat.processors.frameworks.rtvi import RTVIObserver +from nemo.agents.voice_agent.pipecat.services.nemo.diar import NemoDiarService from nemo.agents.voice_agent.pipecat.services.nemo.llm import get_llm_service_from_config -from nemo.agents.voice_agent.pipecat.services.nemo.stt import NeMoSTTInputParams, NemoSTTService +from nemo.agents.voice_agent.pipecat.services.nemo.stt import NemoSTTService from nemo.agents.voice_agent.pipecat.services.nemo.tts import KokoroTTSService, NeMoFastPitchHiFiGANTTSService from nemo.agents.voice_agent.pipecat.services.nemo.turn_taking import NeMoTurnTakingService from nemo.agents.voice_agent.pipecat.transports.network.websocket_server import ( @@ -77,6 +78,8 @@ def setup_logging(): # Transport configuration TRANSPORT_AUDIO_OUT_10MS_CHUNKS = config_manager.TRANSPORT_AUDIO_OUT_10MS_CHUNKS +RECORD_AUDIO_DATA = server_config.transport.get("record_audio_data", False) +AUDIO_LOG_DIR = server_config.transport.get("audio_log_dir", "./audio_logs") # VAD configuration vad_params = config_manager.get_vad_params() @@ -127,6 +130,21 @@ async def run_bot_websocket_server(): - Server will run indefinitely until manually stopped (Ctrl+C) """ + # Initialize AudioLogger if recording is enabled + audio_logger = None + if RECORD_AUDIO_DATA: + from datetime import datetime + + session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + audio_logger = AudioLogger( + log_dir=AUDIO_LOG_DIR, + session_id=session_id, + enabled=True, + ) + logger.info(f"AudioLogger initialized for session: {session_id} at {AUDIO_LOG_DIR}") + else: + logger.info("Audio logging is disabled") + vad_analyzer = SileroVADAnalyzer( sample_rate=SAMPLE_RATE, params=vad_params, @@ -161,6 +179,8 @@ async def run_bot_websocket_server(): has_turn_taking=True, backend="legacy", decoder_type="rnnt", + record_audio_data=RECORD_AUDIO_DATA, + audio_logger=audio_logger, ) logger.info("STT service initialized") @@ -183,6 +203,7 @@ async def run_bot_websocket_server(): max_buffer_size=TURN_TAKING_MAX_BUFFER_SIZE, bot_stop_delay=TURN_TAKING_BOT_STOP_DELAY, backchannel_phrases=TURN_TAKING_BACKCHANNEL_PHRASES_PATH, + audio_logger=audio_logger, ) logger.info("Turn taking service initialized") @@ -200,6 +221,8 @@ 
async def run_bot_websocket_server(): device=TTS_DEVICE, text_aggregator=text_aggregator, think_tokens=TTS_THINK_TOKENS, + record_audio_data=RECORD_AUDIO_DATA, + audio_logger=audio_logger, ) elif TTS_TYPE == "kokoro": tts = KokoroTTSService( @@ -208,6 +231,8 @@ async def run_bot_websocket_server(): speed=config_manager.server_config.tts.speed, text_aggregator=text_aggregator, think_tokens=TTS_THINK_TOKENS, + record_audio_data=RECORD_AUDIO_DATA, + audio_logger=audio_logger, ) else: raise ValueError(f"Invalid TTS type: {TTS_TYPE}") @@ -243,7 +268,9 @@ async def reset_context_handler(rtvi_processor: RTVIProcessor, service: str, arg assistant_context_aggregator.reset() user_context_aggregator.set_messages(copy.deepcopy(original_messages)) assistant_context_aggregator.set_messages(copy.deepcopy(original_messages)) - + text_aggregator.reset() + if diar is not None: + diar.reset() logger.info("Conversation context reset successfully") return True except Exception as e: @@ -276,6 +303,7 @@ async def reset_context_handler(rtvi_processor: RTVIProcessor, service: str, arg pipeline = Pipeline(pipeline) + rtvi_text_aggregator = SimpleSegmentedTextAggregator("\n?!.", min_sentence_length=5) task = PipelineTask( pipeline, params=PipelineParams( @@ -286,7 +314,7 @@ async def reset_context_handler(rtvi_processor: RTVIProcessor, service: str, arg report_only_initial_ttfb=True, idle_timeout=None, # Disable idle timeout ), - observers=[RTVIObserver(rtvi)], + observers=[RTVIObserver(rtvi, text_aggregator=rtvi_text_aggregator)], idle_timeout_secs=None, cancel_on_idle_timeout=False, ) @@ -317,6 +345,10 @@ async def on_client_connected(transport, client): @ws_transport.event_handler("on_client_disconnected") async def on_client_disconnected(transport, client): logger.info(f"Pipecat Client disconnected from {client.remote_address}") + # Finalize audio logger session if enabled + if audio_logger: + audio_logger.finalize_session() + logger.info("Audio logger session finalized") # Don't cancel the task immediately - let it handle the disconnection gracefully # The task will continue running and can accept new connections # Only send an EndTaskFrame to clean up the current session @@ -349,6 +381,10 @@ async def on_session_timeout(transport, client): logger.error(f"Pipeline runner error: {e}") task_running = False finally: + # Finalize audio logger on shutdown + if audio_logger: + audio_logger.finalize_session() + logger.info("Audio logger session finalized on shutdown") logger.info("Pipeline runner stopped") diff --git a/examples/voice_agent/server/example_prompts/fast-bite.txt b/examples/voice_agent/server/example_prompts/fast-bite.txt index c1529e45232c..593097292946 100644 --- a/examples/voice_agent/server/example_prompts/fast-bite.txt +++ b/examples/voice_agent/server/example_prompts/fast-bite.txt @@ -1,6 +1,6 @@ Fast Bites Lunch Menu -Burgers and Sandwiches +Burgers and Sandwiches: 1. Classic Cheeseburger – $5.99 Juicy beef patty, cheddar cheese, pickles, ketchup & mustard on a toasted bun. - Make it a double cheeseburger by adding another patty - $1.50 @@ -14,18 +14,18 @@ Combo Deals (includes small fries and fountain soda) 5. Chicken Sandwich Combo – $9.49 6. Veggie Wrap Combo – $8.49 -Sides +Sides: 7. French Fries - Small - $2.49 - Medium - $3.49 - Large - $4.49 8. Chicken Nuggets - - 4 pcs - $3.29 - - 8 pcs - $5.99 - - 12 pcs - $8.99 -9. Side Salad - $2.99 + - 4 pieces - $3.29 + - 8 pieces - $5.99 + - 12 pieces - $8.99 +9. Side Salad - $2.99 -Drinks +Drinks: 10. 
Fountain Soda (16 oz, choices: Coke, Diet Coke, Sprite, Fanta) – $1.99 11. Iced Tea or Lemonade – $2.29 12. Bottled Water – $1.49 diff --git a/examples/voice_agent/server/server_configs/default.yaml b/examples/voice_agent/server/server_configs/default.yaml index ea4005ea5c2f..f10ddcf774b2 100644 --- a/examples/voice_agent/server/server_configs/default.yaml +++ b/examples/voice_agent/server/server_configs/default.yaml @@ -5,6 +5,8 @@ transport: audio_out_10ms_chunks: 10 # use 4 as websocket default, but increasing to a larger number might have less glitches in TTS output + # record_audio_data: false + record_audio_data: true vad: type: silero @@ -15,7 +17,8 @@ vad: stt: type: nemo # choices in ['nemo'] currently only NeMo is supported - model: "stt_en_fastconformer_hybrid_large_streaming_80ms" + # model: "stt_en_fastconformer_hybrid_large_streaming_80ms" + model: "nvidia/parakeet_realtime_eou_120m-v1" model_config: "./server_configs/stt_configs/nemo_cache_aware_streaming.yaml" device: "cuda" @@ -41,12 +44,12 @@ llm: # model_config: "./server_configs/llm_configs/qwen2.5-7B.yaml" # model: "Qwen/Qwen3-8B" # model_config: "./server_configs/llm_configs/qwen3-8B.yaml" + # model: meta-llama/Llama-3.1-8B-Instruct device: "cuda" enable_reasoning: false # it's best to turn-off reasoning for lowest latency # `system_prompt` is used as the sytem prompt to the LLM, please refer to differnt LLM webpage for spcial functions like enabling/disabling thinking # system_prompt: /path/to/prompt.txt # or use path to a txt file that contains a long prompt, for example in `../example_prompts/fast_bite.txt` system_prompt: "You are a helpful AI agent named Lisa. Start by greeting the user warmly and introducing yourself within one sentence. Your answer should be concise and to the point. You might also see speaker tags (, , etc.) in the user context. You should respond to the user based on the speaker tag and the context of that speaker. Do not include the speaker tags in your response, use them only to identify the speaker. Do not include any emoji in response." - tts: type: kokoro # choices in ['nemo', 'kokoro'] model: "hexgrad/Kokoro-82M" diff --git a/examples/voice_agent/server/server_configs/tts_configs/kokoro_82M.yaml b/examples/voice_agent/server/server_configs/tts_configs/kokoro_82M.yaml index 48f72b726c65..949f8eced7cb 100644 --- a/examples/voice_agent/server/server_configs/tts_configs/kokoro_82M.yaml +++ b/examples/voice_agent/server/server_configs/tts_configs/kokoro_82M.yaml @@ -6,7 +6,11 @@ sub_model_id: "af_heart" # "af_heart" "af_bella" "am_fenrir" "am_michael" device: "cuda" speed: 1.25 # Speaking rate extra_separator: # a list of additional punctuations to chunk LLM response into segments for faster TTS output, e.g., ",". Set to `null` to use default behavior - - "," + - ',' + - '\n' + - "." - "?" - "!" 
+ - ";" + - ":" think_tokens: ["", ""] # specify them to avoid TTS for thinking process, set to `null` to allow thinking out loud diff --git a/examples/voice_agent/server/server_configs/tts_configs/nemo_fastpitch-hifigan.yaml b/examples/voice_agent/server/server_configs/tts_configs/nemo_fastpitch-hifigan.yaml index ab9bfa36d95d..e4539364949c 100644 --- a/examples/voice_agent/server/server_configs/tts_configs/nemo_fastpitch-hifigan.yaml +++ b/examples/voice_agent/server/server_configs/tts_configs/nemo_fastpitch-hifigan.yaml @@ -5,7 +5,11 @@ main_model_id: "nvidia/tts_en_fastpitch" sub_model_id: "nvidia/tts_hifigan" device: "cuda" extra_separator: # a list of additional punctuations to chunk LLM response into segments for faster TTS output, e.g., ",". Set to `null` to use default behavior - - "," + - ',' + - '\n' + - "." - "?" - "!" + - ";" + - ":" think_tokens: ["", ""] # specify them to avoid TTS for thinking process, set to `null` to allow thinking out loud diff --git a/nemo/agents/voice_agent/pipecat/processors/__init__.py b/nemo/agents/voice_agent/pipecat/processors/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/agents/voice_agent/pipecat/processors/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/agents/voice_agent/pipecat/processors/frameworks/__init__.py b/nemo/agents/voice_agent/pipecat/processors/frameworks/__init__.py new file mode 100644 index 000000000000..341a77c5bc66 --- /dev/null +++ b/nemo/agents/voice_agent/pipecat/processors/frameworks/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/agents/voice_agent/pipecat/processors/frameworks/rtvi.py b/nemo/agents/voice_agent/pipecat/processors/frameworks/rtvi.py new file mode 100644 index 000000000000..755c34dc93fc --- /dev/null +++ b/nemo/agents/voice_agent/pipecat/processors/frameworks/rtvi.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from loguru import logger +from pipecat.frames.frames import LLMTextFrame +from pipecat.processors.frameworks.rtvi import RTVIBotLLMTextMessage, RTVIBotTranscriptionMessage +from pipecat.processors.frameworks.rtvi import RTVIObserver as _RTVIObserver +from pipecat.processors.frameworks.rtvi import RTVIProcessor, RTVITextMessageData + +from nemo.agents.voice_agent.pipecat.utils.text.simple_text_aggregator import SimpleSegmentedTextAggregator + + +class RTVIObserver(_RTVIObserver): + def __init__( + self, rtvi: RTVIProcessor, text_aggregator: Optional[SimpleSegmentedTextAggregator] = None, *args, **kwargs + ): + super().__init__(rtvi, *args, **kwargs) + self._text_aggregator = text_aggregator if text_aggregator else SimpleSegmentedTextAggregator("?!:.") + + async def _handle_llm_text_frame(self, frame: LLMTextFrame): + """Handle LLM text output frames.""" + message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text)) + await self.push_transport_message_urgent(message) + + completed_text = await self._text_aggregator.aggregate(frame.text) + if completed_text: + await self._push_bot_transcription(completed_text) + + async def _push_bot_transcription(self, text: str): + """Push accumulated bot transcription as a message.""" + if len(text.strip()) > 0: + message = RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=text)) + logger.debug(f"Pushing bot transcription: `{text}`") + await self.push_transport_message_urgent(message) + self._bot_transcription = "" diff --git a/nemo/agents/voice_agent/pipecat/services/nemo/audio_logger.py b/nemo/agents/voice_agent/pipecat/services/nemo/audio_logger.py new file mode 100644 index 000000000000..b2a56c259752 --- /dev/null +++ b/nemo/agents/voice_agent/pipecat/services/nemo/audio_logger.py @@ -0,0 +1,359 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import threading +import wave +from datetime import datetime +from pathlib import Path +from typing import Optional, Union + +import numpy as np +from loguru import logger + + +class AudioLogger: + """ + Utility class for logging audio data and transcriptions during voice agent interactions. 
+ + This logger saves: + - Audio files in WAV format + - Transcriptions with metadata in JSON format + - Session information and metadata + + File structure: + log_dir/ + ├── session_YYYYMMDD_HHMMSS/ + │ ├── user/ + │ │ ├── 00001_HHMMSS.wav + │ │ ├── 00001_HHMMSS.json + │ │ ├── 00002_HHMMSS.wav + │ │ └── 00002_HHMMSS.json + │ ├── agent/ + │ │ ├── 00001_HHMMSS.wav + │ │ ├── 00001_HHMMSS.json + │ └── session_metadata.json + + Args: + log_dir: Base directory for storing logs (default: "./audio_logs") + session_id: Optional custom session ID. If None, auto-generated from timestamp + enabled: Whether logging is enabled (default: True) + """ + + def __init__( + self, + log_dir: Union[str, Path] = "./audio_logs", + session_id: Optional[str] = None, + enabled: bool = True, + ): + self.enabled = enabled + if not self.enabled: + logger.info("AudioLogger is disabled") + return + + self.log_dir = Path(log_dir) + + # Generate session ID if not provided + if session_id is None: + session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + self.session_id = session_id + self.session_dir = self.log_dir / session_id + + # Create directories + self.user_dir = self.session_dir / "user" + self.agent_dir = self.session_dir / "agent" + + self.user_dir.mkdir(parents=True, exist_ok=True) + self.agent_dir.mkdir(parents=True, exist_ok=True) + + # Counters for file naming (thread-safe) + self._user_counter = 0 + self._agent_counter = 0 + self._lock = threading.Lock() + self._staged_metadata = None + self._staged_audio_data = None + + # Session metadata + self.session_metadata = { + "session_id": session_id, + "start_time": datetime.now().isoformat(), + "user_entries": [], + "agent_entries": [], + } + + logger.info(f"AudioLogger initialized: {self.session_dir}") + + def _get_next_counter(self, speaker: str) -> int: + """Get the next counter value for a speaker in a thread-safe manner.""" + with self._lock: + if speaker == "user": + self._user_counter += 1 + return self._user_counter + else: + self._agent_counter += 1 + return self._agent_counter + + def _save_audio_wav( + self, + audio_data: Union[bytes, np.ndarray], + file_path: Path, + sample_rate: int, + num_channels: int = 1, + ): + """ + Save audio data to a WAV file. 
+ + Args: + audio_data: Audio data as bytes or numpy array + file_path: Path to save the WAV file + sample_rate: Audio sample rate in Hz + num_channels: Number of audio channels (default: 1) + """ + try: + # Convert audio data to bytes if it's a numpy array + if isinstance(audio_data, np.ndarray): + if audio_data.dtype in [np.float32, np.float64]: + # Convert float [-1, 1] to int16 [-32768, 32767] + audio_data = np.clip(audio_data, -1.0, 1.0) + audio_data = (audio_data * 32767).astype(np.int16) + elif audio_data.dtype != np.int16: + audio_data = audio_data.astype(np.int16) + audio_bytes = audio_data.tobytes() + else: + audio_bytes = audio_data + + # Write WAV file + with wave.open(str(file_path), 'wb') as wav_file: + wav_file.setnchannels(num_channels) + wav_file.setsampwidth(2) # 16-bit audio + wav_file.setframerate(sample_rate) + wav_file.writeframes(audio_bytes) + + logger.debug(f"Saved audio to {file_path}") + except Exception as e: + logger.error(f"Error saving audio to {file_path}: {e}") + raise + + def _save_metadata_json(self, metadata: dict, file_path: Path): + """Save metadata to a JSON file.""" + try: + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + logger.debug(f"Saved metadata to {file_path}") + except Exception as e: + logger.error(f"Error saving metadata to {file_path}: {e}") + raise + + def stage_user_audio( + self, + audio_data: Union[bytes, np.ndarray], + transcription: str, + sample_rate: int = 16000, + num_channels: int = 1, + is_final: bool = True, + additional_metadata: Optional[dict] = None, + ) -> Optional[dict]: + """ + Stage user audio and its transcription (from STT) for logging. + This data is written to disk by the `save_user_audio` method once the turn is complete. + + Args: + audio_data: Raw audio data as bytes or numpy array + transcription: Transcribed text + sample_rate: Audio sample rate in Hz (default: 16000) + num_channels: Number of audio channels (default: 1) + is_final: Whether this is a final transcription (default: True) + additional_metadata: Additional metadata to include + + Returns: + Dictionary with logged file paths, or None if logging is disabled + """ + if not self.enabled: + return None + + try: + # Get counter and generate filenames + counter = self._get_next_counter("user") + timestamp = datetime.now().strftime('%H%M%S') + base_name = f"{counter:05d}_{timestamp}" + + audio_file = self.user_dir / f"{base_name}.wav" + metadata_file = self.user_dir / f"{base_name}.json" + + # Stage audio (written to disk later by `save_user_audio`) + # self._save_audio_wav(audio_data, audio_file, sample_rate, num_channels) + self._staged_audio_data = audio_data + + # Prepare metadata + self._staged_metadata = { + "base_name": base_name, + "counter": counter, + "speaker": "user", + "timestamp": datetime.now().isoformat(), + "transcription": transcription, + "is_final": is_final, + "audio_file": audio_file.name, + "sample_rate": sample_rate, + "num_channels": num_channels, + "audio_duration_sec": ( + len(audio_data) / (sample_rate * num_channels * 2) + if isinstance(audio_data, bytes) + else len(audio_data) / sample_rate + ), + } + + if additional_metadata: + self._staged_metadata.update(additional_metadata) + + # Stage metadata (written to disk later by `save_user_audio`) + # self._save_metadata_json(metadata, metadata_file) + + # logger.info(f"Logged user audio #{counter}: '{transcription[:50]}{'...'
if len(transcription) > 50 else ''}'") + + return { + "audio_file": str(audio_file), + "metadata_file": str(metadata_file), + "counter": counter, + } + + except Exception as e: + logger.error(f"Error logging user audio: {e}") + return None + + def save_user_audio(self): + """Save the user audio to the disk.""" + audio_file = self.user_dir / f"{self._staged_metadata['base_name']}.wav" + metadata_file = self.user_dir / f"{self._staged_metadata['base_name']}.json" + + self._save_audio_wav( + audio_data=self._staged_audio_data, file_path=audio_file, sample_rate=self._staged_metadata["sample_rate"] + ) + + self._save_metadata_json(metadata=self._staged_metadata, file_path=metadata_file) + logger.info( + f"Saved user audio #{self._staged_metadata['counter']}: '{self._staged_metadata['transcription'][:50]}{'...' if len(self._staged_metadata['transcription']) > 50 else ''}'" + ) + # Update session metadata + with self._lock: + self.session_metadata["user_entries"].append(self._staged_metadata) + self._save_session_metadata() + + def log_agent_audio( + self, + audio_data: Union[bytes, np.ndarray], + text: str, + sample_rate: int = 22050, + num_channels: int = 1, + additional_metadata: Optional[dict] = None, + ) -> Optional[dict]: + """ + Log agent audio and text (from TTS). + + Args: + audio_data: Generated audio data as bytes or numpy array + text: Input text that was synthesized + sample_rate: Audio sample rate in Hz (default: 22050) + num_channels: Number of audio channels (default: 1) + additional_metadata: Additional metadata to include + + Returns: + Dictionary with logged file paths, or None if logging is disabled + """ + if not self.enabled: + return None + + try: + # Get counter and generate filenames + counter = self._get_next_counter("agent") + timestamp = datetime.now().strftime('%H%M%S') + base_name = f"{counter:05d}_{timestamp}" + + audio_file = self.agent_dir / f"{base_name}.wav" + metadata_file = self.agent_dir / f"{base_name}.json" + + # Save audio + self._save_audio_wav(audio_data, audio_file, sample_rate, num_channels) + + # Prepare metadata + metadata = { + "counter": counter, + "speaker": "agent", + "timestamp": datetime.now().isoformat(), + "text": text, + "audio_file": audio_file.name, + "sample_rate": sample_rate, + "num_channels": num_channels, + "audio_duration_sec": ( + len(audio_data) / (sample_rate * num_channels * 2) + if isinstance(audio_data, bytes) + else len(audio_data) / sample_rate + ), + } + + if additional_metadata: + metadata.update(additional_metadata) + + # Save metadata + self._save_metadata_json(metadata, metadata_file) + + # Update session metadata + with self._lock: + self.session_metadata["agent_entries"].append(metadata) + self._save_session_metadata() + + logger.info(f"Logged agent audio #{counter}: '{text[:50]}{'...' 
if len(text) > 50 else ''}'") + + return { + "audio_file": str(audio_file), + "metadata_file": str(metadata_file), + "counter": counter, + } + + except Exception as e: + logger.error(f"Error logging agent audio: {e}") + return None + + def _save_session_metadata(self): + """Save the session metadata to disk.""" + if not self.enabled: + return + + try: + metadata_file = self.session_dir / "session_metadata.json" + self.session_metadata["last_updated"] = datetime.now().isoformat() + self._save_metadata_json(self.session_metadata, metadata_file) + except Exception as e: + logger.error(f"Error saving session metadata: {e}") + + def finalize_session(self): + """Finalize the session and save final metadata.""" + if not self.enabled: + return + + self.session_metadata["end_time"] = datetime.now().isoformat() + self.session_metadata["total_user_entries"] = self._user_counter + self.session_metadata["total_agent_entries"] = self._agent_counter + self._save_session_metadata() + logger.info(f"Session finalized: {self.session_id} (User: {self._user_counter}, Agent: {self._agent_counter})") + + def get_session_info(self) -> dict: + """Get current session information.""" + return { + "session_id": self.session_id, + "session_dir": str(self.session_dir), + "user_entries": self._user_counter, + "agent_entries": self._agent_counter, + "enabled": self.enabled, + } diff --git a/nemo/agents/voice_agent/pipecat/services/nemo/diar.py b/nemo/agents/voice_agent/pipecat/services/nemo/diar.py index 912179fd93e0..4be79e4b1997 100644 --- a/nemo/agents/voice_agent/pipecat/services/nemo/diar.py +++ b/nemo/agents/voice_agent/pipecat/services/nemo/diar.py @@ -324,6 +324,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): self._audio_buffer = [] def reset(self): + """Reset the diarization service.""" self._current_speaker_id = None self._audio_buffer = [] self._vad_user_speaking = False diff --git a/nemo/agents/voice_agent/pipecat/services/nemo/stt.py b/nemo/agents/voice_agent/pipecat/services/nemo/stt.py index 63ef595d2b00..1127e033b8ca 100644 --- a/nemo/agents/voice_agent/pipecat/services/nemo/stt.py +++ b/nemo/agents/voice_agent/pipecat/services/nemo/stt.py @@ -24,6 +24,7 @@ InterimTranscriptionFrame, StartFrame, TranscriptionFrame, + VADUserStartedSpeakingFrame, VADUserStoppedSpeakingFrame, ) from pipecat.processors.frame_processor import FrameDirection @@ -33,6 +34,7 @@ from pipecat.utils.tracing.service_decorators import traced_stt from pydantic import BaseModel +from nemo.agents.voice_agent.pipecat.services.nemo.audio_logger import AudioLogger from nemo.agents.voice_agent.pipecat.services.nemo.legacy_asr import NemoLegacyASRService try: @@ -69,6 +71,8 @@ def __init__( has_turn_taking: bool = False, backend: Optional[str] = "legacy", decoder_type: Optional[str] = "rnnt", + record_audio_data: Optional[bool] = False, + audio_logger: Optional[AudioLogger] = None, **kwargs, ): super().__init__(**kwargs) @@ -81,6 +85,9 @@ def __init__( self._has_turn_taking = has_turn_taking self._backend = backend self._decoder_type = decoder_type + self._record_audio_data = record_audio_data + self._audio_logger = audio_logger + self._is_vad_active = False if not params: raise ValueError("params is required") @@ -89,6 +96,9 @@ def __init__( self._load_model() self.audio_buffer = [] + # Buffers for accumulating audio and transcriptions for a complete turn (for logging) + self._turn_audio_buffer = [] + self._turn_transcription_buffer = [] def _load_model(self): if self._backend == "legacy": @@ -158,6 +168,13 @@ 
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: self.audio_buffer = [] transcription, is_final = self._model.transcribe(audio) + if self._record_audio_data and self._audio_logger and self._is_vad_active: + self._turn_audio_buffer.append(audio) + # Accumulate transcriptions for turn-based logging + if transcription: + self._turn_transcription_buffer.append(transcription) + self._stage_turn_audio_and_transcription() + await self.stop_ttfb_metrics() await self.stop_processing_metrics() @@ -201,6 +218,47 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: ) yield None + def _stage_turn_audio_and_transcription(self): + """ + Stage the complete turn audio and accumulated transcriptions. + + This method is called when a final transcription is received. + It joins all accumulated audio and transcription chunks and logs them together. + """ + if not self._turn_audio_buffer or not self._turn_transcription_buffer: + logger.debug("No audio or transcription to log") + return + + try: + # Join all accumulated audio and transcriptions for this turn + complete_turn_audio = b"".join(self._turn_audio_buffer) + complete_transcription = "".join(self._turn_transcription_buffer) + + logger.debug( + f"Staging a turn with: {len(self._turn_audio_buffer)} audio chunks, " + f"{len(self._turn_transcription_buffer)} transcription chunks" + ) + + self._audio_logger.stage_user_audio( + audio_data=complete_turn_audio, + transcription=complete_transcription, + sample_rate=self._sample_rate, + num_channels=1, + is_final=True, + additional_metadata={ + "model": self._model_name, + "backend": self._backend, + "audio_duration_sec": len(complete_turn_audio) / (self._sample_rate * 2), + "num_transcription_chunks": len(self._turn_transcription_buffer), + "num_audio_chunks": len(self._turn_audio_buffer), + }, + ) + + logger.info(f"Staged the audio and transcription for turn: '{complete_transcription[:50]}...'") + + except Exception as e: + logger.warning(f"Failed to log user audio: {e}") + @traced_stt async def _handle_transcription(self, transcript: str, is_final: bool, language: Optional[str] = None): """Handle a transcription result. 
@@ -236,8 +294,18 @@ async def set_model(self, model: str): self._load_model() async def process_frame(self, frame: Frame, direction: FrameDirection): - if isinstance(frame, VADUserStoppedSpeakingFrame) and isinstance(self._model, NemoLegacyASRService): - # manualy reset the state of the model when end of utterance is detected by VAD - logger.debug("Resetting state of the model due to VADUserStoppedSpeakingFrame") - self._model.reset_state() + if isinstance(self._model, NemoLegacyASRService): + if isinstance(frame, VADUserStoppedSpeakingFrame): + self._is_vad_active = False + # manualy reset the state of the model when end of utterance is detected by VAD + logger.debug("Resetting state of the model due to VADUserStoppedSpeakingFrame") + self._model.reset_state() + # Clear turn buffers if logging wasn't completed (e.g., no final transcription) + if len(self._turn_audio_buffer) > 0 or len(self._turn_transcription_buffer) > 0: + logger.debug("Clearing turn audio and transcription buffers due to VAD user stopped speaking") + self._turn_audio_buffer = [] + self._turn_transcription_buffer = [] + elif isinstance(frame, VADUserStartedSpeakingFrame): + self._is_vad_active = True + await super().process_frame(frame, direction) diff --git a/nemo/agents/voice_agent/pipecat/services/nemo/tts.py b/nemo/agents/voice_agent/pipecat/services/nemo/tts.py index 269838b8f140..6885dbabb835 100644 --- a/nemo/agents/voice_agent/pipecat/services/nemo/tts.py +++ b/nemo/agents/voice_agent/pipecat/services/nemo/tts.py @@ -32,6 +32,7 @@ ) from pipecat.services.tts_service import TTSService +from nemo.agents.voice_agent.pipecat.services.nemo.audio_logger import AudioLogger from nemo.collections.tts.models import FastPitchModel, HifiGanModel @@ -55,6 +56,8 @@ def __init__( device: str = "cuda", sample_rate: int = 22050, think_tokens: Optional[List[str]] = None, + record_audio_data: Optional[bool] = False, + audio_logger: Optional[AudioLogger] = None, **kwargs, ): super().__init__(sample_rate=sample_rate, **kwargs) @@ -62,6 +65,8 @@ def __init__( self._device = device self._model = self._setup_model() self._think_tokens = think_tokens + self._record_audio_data = record_audio_data + self._audio_logger = audio_logger if think_tokens is not None: assert ( isinstance(think_tokens, list) and len(think_tokens) == 2 @@ -275,6 +280,9 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: await self.start_tts_usage_metrics(text) + # Collect all audio for logging + all_audio_bytes = b"" + # Process the audio result (same as before) if ( inspect.isgenerator(audio_result) @@ -292,6 +300,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: break audio_bytes = self._convert_to_bytes(audio_chunk) + all_audio_bytes += audio_bytes chunk_size = self.chunk_size for i in range(0, len(audio_bytes), chunk_size): audio_chunk_bytes = audio_bytes[i : i + chunk_size] @@ -306,6 +315,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: # Handle single result case await self.stop_ttfb_metrics() audio_bytes = self._convert_to_bytes(audio_result) + all_audio_bytes = audio_bytes chunk_size = self.chunk_size for i in range(0, len(audio_bytes), chunk_size): @@ -316,6 +326,21 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: frame = TTSAudioRawFrame(audio=chunk, sample_rate=self.sample_rate, num_channels=1) yield frame + # Log the complete audio if logger is available + if self._record_audio_data and self._audio_logger and all_audio_bytes: + try: + self._audio_logger.log_agent_audio( 
+ audio_data=all_audio_bytes, + text=text, + sample_rate=self.sample_rate, + num_channels=1, + additional_metadata={ + "model": self._model_name, + }, + ) + except Exception as e: + logger.warning(f"Failed to log agent audio: {e}") + yield TTSStoppedFrame() finally: @@ -473,8 +498,7 @@ def _generate_audio(self, text: str) -> Iterator[np.ndarray]: # We only need the audio component for i, (gs, ps, audio) in enumerate(generator): logger.debug( - f"Kokoro generated audio chunk {i}: gs={gs}, ps={ps}," - f"audio_shape={audio.shape if hasattr(audio, 'shape') else len(audio)}" + f"Kokoro generated audio chunk {i}: gs={gs}, ps={ps}, audio_shape={audio.shape if hasattr(audio, 'shape') else len(audio)}" ) if isinstance(audio, torch.Tensor): audio = audio.detach().cpu().numpy() diff --git a/nemo/agents/voice_agent/pipecat/services/nemo/turn_taking.py b/nemo/agents/voice_agent/pipecat/services/nemo/turn_taking.py index 48e78f20e5cb..9e62537ff67d 100644 --- a/nemo/agents/voice_agent/pipecat/services/nemo/turn_taking.py +++ b/nemo/agents/voice_agent/pipecat/services/nemo/turn_taking.py @@ -35,6 +35,7 @@ from pipecat.utils.time import time_now_iso8601 from nemo.agents.voice_agent.pipecat.frames.frames import DiarResultFrame +from nemo.agents.voice_agent.pipecat.services.nemo.audio_logger import AudioLogger class NeMoTurnTakingService(FrameProcessor): @@ -48,6 +49,7 @@ def __init__( self, use_diar: bool = False, max_buffer_size: int = 3, bot_stop_delay: float = 0.5, + audio_logger: Optional[AudioLogger] = None, **kwargs, ): super().__init__(**kwargs) @@ -69,6 +71,7 @@ def __init__( self._vad_user_speaking = False self._have_sent_user_started_speaking = False self._user_speaking_buffer = "" + self._audio_logger = audio_logger if not self.use_vad: # if vad is not used, we assume the user is always speaking self._vad_user_speaking = True @@ -151,6 +154,20 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): else: await self.push_frame(frame, direction) + async def _handle_backchannel_text(self, text: str): + # ignore the backchannel string while bot is speaking + # push the backchannel string upstream, not downstream + await self.push_frame( + TranscriptionFrame( + text=f"({text})", + user_id="", + timestamp=time_now_iso8601(), + language=self.language if self.language else Language.EN_US, + result={"text": f"Backchannel detected: {text}"}, + ), + direction=FrameDirection.UPSTREAM, + ) + async def _handle_transcription( self, frame: TranscriptionFrame | InterimTranscriptionFrame, direction: FrameDirection ): @@ -163,29 +180,19 @@ async def _handle_transcription( # EOU detected, we assume the user is done speaking, so we push the completed text and interrupt the bot logger.debug(f" Detected: `{self._user_speaking_buffer}`") completed_text = self._user_speaking_buffer[: -len(self.eou_string)].strip() - self._user_speaking_buffer = "" if self._bot_speaking and self.is_backchannel(completed_text): logger.debug(f" detected for a backchannel phrase while bot is speaking: `{completed_text}`") + await self._handle_backchannel_text(completed_text) else: await self._handle_completed_text(completed_text, direction) await self._handle_user_interruption(UserStoppedSpeakingFrame()) + self._user_speaking_buffer = "" self._have_sent_user_started_speaking = False # user is done speaking, so we reset the flag elif has_eob and self._bot_speaking: - # ignore the backchannel string while bot is speaking - logger.debug(f"Ignoring backchannel string while bot is
speaking: `{self._user_speaking_buffer}`") - # push the backchannel string upstream, not downstream - await self.push_frame( - TranscriptionFrame( - text=f"({self._user_speaking_buffer})", - user_id="", - timestamp=time_now_iso8601(), - language=self.language if self.language else Language.EN_US, - result={"text": f"Backchannel detected: {self._user_speaking_buffer}"}, - ), - direction=FrameDirection.UPSTREAM, - ) - self._have_sent_user_started_speaking = False # treat it as if the user is not speaking - self._user_speaking_buffer = "" # discard backchannel string and reset the buffer + logger.debug(f" detected while bot is speaking: `{self._user_speaking_buffer}`") + await self._handle_backchannel_text(str(self._user_speaking_buffer)) + self._user_speaking_buffer = "" + self._have_sent_user_started_speaking = False # user is done speaking, so we reset the flag else: # if bot is not speaking, the backchannel string is not considered a backchannel phrase # user is still speaking, so we append the text segment to the buffer @@ -309,6 +316,8 @@ async def _handle_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame self._have_sent_user_started_speaking = False elif is_backchannel: logger.debug(f"Backchannel detected: `{self._user_speaking_buffer}`") + if self._audio_logger: + self._audio_logger.save_user_audio() # push the backchannel string upstream, not downstream await self.push_frame( TranscriptionFrame( @@ -331,6 +340,8 @@ async def _handle_user_interruption(self, frame: Frame): await self.push_frame(StartInterruptionFrame(), direction=FrameDirection.DOWNSTREAM) elif isinstance(frame, UserStoppedSpeakingFrame): logger.debug("User stopped speaking") + if self._audio_logger: + self._audio_logger.save_user_audio() await self.push_frame(frame) else: logger.debug(f"Unknown frame type for _handle_user_interruption: {type(frame)}") diff --git a/nemo/agents/voice_agent/pipecat/utils/text/simple_text_aggregator.py b/nemo/agents/voice_agent/pipecat/utils/text/simple_text_aggregator.py index ada66aef6dec..92de2abdf53f 100644 --- a/nemo/agents/voice_agent/pipecat/utils/text/simple_text_aggregator.py +++ b/nemo/agents/voice_agent/pipecat/utils/text/simple_text_aggregator.py @@ -12,41 +12,172 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re from typing import Optional +from loguru import logger from pipecat.utils.string import match_endofsentence from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator +def has_partial_decimal(text: str) -> bool: + """Check if the text ends with a partial decimal. + + Returns True if the text ends with a number that looks like it could + be a partial decimal (e.g., "3.", "3.14", "($3.14)"), but NOT if it's + clearly a complete sentence (e.g., "It costs $3.14.") or a bullet point + (e.g., "1. Alpha; 2."). + """ + text = text.strip() + + # Check for bullet point pattern: ends with 1-3 digits followed by period + # Examples: "1.", "12.", "123.", or "text; 2." + # Bullet points are typically small numbers (1-999) at the end + bullet_match = re.search(r'(?:^|[\s;,]|[^\d])(\d{1,3})\.$', text) + if bullet_match: + # It's likely a bullet point, not a partial decimal + return False + + # Pattern to find decimal numbers near the end, allowing for trailing + # non-word characters like ), ], ", ', etc. 
+ # Match: digit(s) + period + optional digit(s) + optional trailing non-word chars + match = re.search(r'\d+\.(?:\d+)?([^\w\s]*)$', text) + + if not match: + return False + + trailing = match.group(1) # e.g., ")" or "" or "." + + # If trailing contains a period, it's sentence-ending punctuation + # e.g., "3.14." means complete sentence + if '.' in trailing: + return False + + # Otherwise, it's a partial decimal (either incomplete like "3." + # or complete number but sentence not finished like "($3.14)") + return True + + +def find_last_period_index(text: str) -> int: + """ + Find the last occurrence of a period in the text, + but return -1 if the text doesn't seem to be a complete sentence. + """ + num_periods = text.count(".") + if num_periods == 0: + return -1 + + if num_periods == 1: + if has_partial_decimal(text): + # if the only period in the text is part of a number, return -1 + return -1 + # Check if the only period is a bullet point (e.g., "1. Alpha" or incomplete "1.") + if re.search(r'(?:^|[\s;,]|[^\d])(\d{1,3})\.(?:\s+\w|\s*$)', text): + # The period is after a bullet point number, either: + # - followed by content (e.g., "1. Alpha") + # - or at the end with optional whitespace (e.g., "1." or "1. ") + return -1 + + # Check if any of the abbreviation markers "e.", "i.", or "g." (as in "e.g." or "i.e.") are present in the text + if re.search(r'\b(e\.|i\.|g\.)\b', text): + # The period is after a character/word that is likely to be an abbreviation, return -1 + return -1 + + # otherwise, check the last occurrence of a period + idx = text.rfind(".") + if idx <= 0: + return idx + if text[idx - 1].isdigit(): + # if the period is after a digit, it's likely a partial decimal, return -1 + return -1 + elif idx > 2 and text[idx - 3 : idx + 1] in ["e.g.", "i.e."]: + # The period is after a character/word that is likely to be an abbreviation, return -1 + return -1 + + # the text seems to have a complete sentence, return the index of the last period + return idx + + class SimpleSegmentedTextAggregator(SimpleTextAggregator): - def __init__(self, punctuation_marks: str | list[str] = ",!?", **kwargs): + def __init__( + self, + punctuation_marks: str | list[str] = ".,!?;:", + ignore_marks: str | list[str] = "*", + min_sentence_length: int = 0, + use_legacy_eos_detection: bool = False, + **kwargs, + ): + """ + Args: + punctuation_marks: The punctuation marks to use for sentence detection. + ignore_marks: The marks to ignore in the text. + min_sentence_length: The minimum length of a sentence to be considered. + use_legacy_eos_detection: Whether to use the legacy EOS detection from pipecat. + **kwargs: Additional arguments to pass to the SimpleTextAggregator constructor. + """ super().__init__(**kwargs) + self._use_legacy_eos_detection = use_legacy_eos_detection + self._min_sentence_length = min_sentence_length + if not ignore_marks: + self._ignore_marks = set() + else: + self._ignore_marks = set(ignore_marks) if not punctuation_marks: - self._punctuation_marks = set() + self._punctuation_marks = list() else: - self._punctuation_marks = set(punctuation_marks) + punctuation_marks = ( + [c for c in punctuation_marks] if isinstance(punctuation_marks, str) else punctuation_marks + ) + if "." in punctuation_marks: + punctuation_marks.remove(".") + punctuation_marks += [ + "." + ] # put period at the end of the list to ensure it's the last punctuation mark to be matched + self._punctuation_marks = punctuation_marks def _find_segment_end(self, text: str) -> Optional[int]: + """Find the end of a text segment.
+ + Args: + text: The text to search for a segment boundary. + + Returns: + The exclusive end index of the segment (one past the final punctuation mark), or None if the text is too short or no segment boundary is found. + """ + if len(text.strip()) < self._min_sentence_length: + return None + for punc in self._punctuation_marks: - idx = text.find(punc) + if punc == ".": + idx = find_last_period_index(text) + else: + idx = text.find(punc) if idx != -1: - return idx + return idx + 1 return None async def aggregate(self, text: str) -> Optional[str]: result: Optional[str] = None - self._text += text + self._text += str(text) - self._text = self._text.replace("*", "") + for ignore_mark in self._ignore_marks: + self._text = self._text.replace(ignore_mark, "") - eos_end_marker = match_endofsentence(self._text) + eos_end_index = self._find_segment_end(self._text) - if not eos_end_marker: - eos_end_marker = self._find_segment_end(self._text) + if not eos_end_index and not has_partial_decimal(self._text) and self._use_legacy_eos_detection: + # if the text has no partial decimal and no configured punctuation mark was found, + # fall back to pipecat's match_endofsentence to find the end of the sentence + eos_end_index = match_endofsentence(self._text) - if eos_end_marker: - result = self._text[:eos_end_marker] - self._text = self._text[eos_end_marker:] + if eos_end_index: + result = self._text[:eos_end_index] + if len(result.strip()) < self._min_sentence_length: + logger.debug(f"Text is too short, skipping: `{result}`, full text: `{self._text}`") + result = None + else: + logger.debug(f"Text Aggregator Result: `{result}`, full text: `{self._text}`") + self._text = self._text[eos_end_index:] return result
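Note (illustrative only, not part of the patch): a minimal sketch of how the new SimpleSegmentedTextAggregator is expected to segment streaming LLM text before TTS/RTVI, using made-up chunk strings and the same settings as the rtvi_text_aggregator in bot_websocket_server.py.

import asyncio

from nemo.agents.voice_agent.pipecat.utils.text.simple_text_aggregator import SimpleSegmentedTextAggregator


async def demo():
    # Same settings as the RTVI observer aggregator in bot_websocket_server.py.
    aggregator = SimpleSegmentedTextAggregator("\n?!.", min_sentence_length=5)
    # Hypothetical streaming chunks from an LLM.
    for chunk in ["Sure!", " The total is $3.", "14 even.", " Anything else?"]:
        print(repr(await aggregator.aggregate(chunk)))
    # Expected output:
    #   'Sure!'                      <- "!" closes the first segment
    #   None                         <- trailing "$3." is held back as a possibly incomplete number
    #   ' The total is $3.14 even.'  <- emitted once the decimal and sentence are complete
    #   ' Anything else?'


asyncio.run(demo())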