diff --git a/apps/llm/android/app/src/main/AndroidManifest.xml b/apps/llm/android/app/src/main/AndroidManifest.xml index c1d260939..87227eb01 100644 --- a/apps/llm/android/app/src/main/AndroidManifest.xml +++ b/apps/llm/android/app/src/main/AndroidManifest.xml @@ -7,6 +7,7 @@ + diff --git a/apps/llm/app.json b/apps/llm/app.json index c3031d0ed..5eef49366 100644 --- a/apps/llm/app.json +++ b/apps/llm/app.json @@ -23,7 +23,22 @@ "calendarPermission": "The app needs to access your calendar." } ], - "expo-router" + "expo-router", + [ + "react-native-audio-api", + { + "iosBackgroundMode": true, + "iosMicrophonePermission": "This app requires access to the microphone to record audio.", + "androidPermissions": [ + "android.permission.MODIFY_AUDIO_SETTINGS", + "android.permission.FOREGROUND_SERVICE", + "android.permission.FOREGROUND_SERVICE_MEDIA_PLAYBACK", + "android.permission.RECORD_AUDIO" + ], + "androidForegroundService": true, + "androidFSTypes": ["mediaPlayback"] + } + ] ], "newArchEnabled": true, "splash": { diff --git a/apps/llm/app/voice_chat/index.tsx b/apps/llm/app/voice_chat/index.tsx index 53c3f8508..c709a6c4a 100644 --- a/apps/llm/app/voice_chat/index.tsx +++ b/apps/llm/app/voice_chat/index.tsx @@ -22,38 +22,10 @@ import MicIcon from '../../assets/icons/mic_icon.svg'; import StopIcon from '../../assets/icons/stop_icon.svg'; import ColorPalette from '../../colors'; import Messages from '../../components/Messages'; -import LiveAudioStream from 'react-native-live-audio-stream'; +import { AudioManager, AudioRecorder } from 'react-native-audio-api'; import DeviceInfo from 'react-native-device-info'; -import { Buffer } from 'buffer'; import { useIsFocused } from '@react-navigation/native'; import { GeneratingContext } from '../../context'; -const audioStreamOptions = { - sampleRate: 16000, - channels: 1, - bitsPerSample: 16, - audioSource: 1, - bufferSize: 16000, -}; - -const startStreamingAudio = (options: any, onChunk: (data: string) => void) => { - LiveAudioStream.init(options); - LiveAudioStream.on('data', onChunk); - LiveAudioStream.start(); -}; - -const float32ArrayFromPCMBinaryBuffer = (b64EncodedBuffer: string) => { - const b64DecodedChunk = Buffer.from(b64EncodedBuffer, 'base64'); - const int16Array = new Int16Array(b64DecodedChunk.buffer); - - const float32Array = new Float32Array(int16Array.length); - for (let i = 0; i < int16Array.length; i++) { - float32Array[i] = Math.max( - -1, - Math.min(1, (int16Array[i] / audioStreamOptions.bufferSize) * 8) - ); - } - return float32Array; -}; export default function VoiceChatScreenWrapper() { const isFocused = useIsFocused(); @@ -63,6 +35,13 @@ export default function VoiceChatScreenWrapper() { function VoiceChatScreen() { const [isRecording, setIsRecording] = useState(false); + const [recorder] = useState( + () => + new AudioRecorder({ + sampleRate: 16000, + bufferLengthInSamples: 1600, + }) + ); const messageRecorded = useRef(false); const { setGlobalGenerating } = useContext(GeneratingContext); @@ -75,20 +54,27 @@ function VoiceChatScreen() { setGlobalGenerating(llm.isGenerating || speechToText.isGenerating); }, [llm.isGenerating, speechToText.isGenerating, setGlobalGenerating]); - const onChunk = (data: string) => { - const float32Chunk = float32ArrayFromPCMBinaryBuffer(data); - speechToText.streamInsert(Array.from(float32Chunk)); - }; + useEffect(() => { + AudioManager.setAudioSessionOptions({ + iosCategory: 'playAndRecord', + iosMode: 'spokenAudio', + iosOptions: ['allowBluetooth', 'defaultToSpeaker'], + }); + 
AudioManager.requestRecordingPermissions(); + }, []); const handleRecordPress = async () => { if (isRecording) { setIsRecording(false); - LiveAudioStream.stop(); + recorder.stop(); messageRecorded.current = true; - speechToText.streamStop(); + await speechToText.streamStop(); } else { setIsRecording(true); - startStreamingAudio(audioStreamOptions, onChunk); + recorder.onAudioReady(async ({ buffer }) => { + await speechToText.streamInsert(buffer.getChannelData(0)); + }); + recorder.start(); const transcription = await speechToText.stream(); await llm.sendMessage(transcription); } diff --git a/apps/llm/ios/Podfile.lock b/apps/llm/ios/Podfile.lock index 48632bb03..46d2399e8 100644 --- a/apps/llm/ios/Podfile.lock +++ b/apps/llm/ios/Podfile.lock @@ -1403,7 +1403,7 @@ PODS: - React-jsiexecutor - React-RCTFBReactNativeSpec - ReactCommon/turbomodule/core - - react-native-executorch (0.5.2): + - react-native-executorch (0.5.0): - DoubleConversion - glog - hermes-engine @@ -1826,7 +1826,7 @@ PODS: - React-logger (= 0.79.2) - React-perflogger (= 0.79.2) - React-utils (= 0.79.2) - - RNAudioAPI (0.5.7): + - RNAudioAPI (0.8.2): - DoubleConversion - glog - hermes-engine @@ -1849,9 +1849,9 @@ PODS: - ReactCodegen - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core - - RNAudioAPI/audioapi (= 0.5.7) + - RNAudioAPI/audioapi (= 0.8.2) - Yoga - - RNAudioAPI/audioapi (0.5.7): + - RNAudioAPI/audioapi (0.8.2): - DoubleConversion - glog - hermes-engine @@ -1874,9 +1874,9 @@ PODS: - ReactCodegen - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core - - RNAudioAPI/audioapi/ios (= 0.5.7) + - RNAudioAPI/audioapi/ios (= 0.8.2) - Yoga - - RNAudioAPI/audioapi/ios (0.5.7): + - RNAudioAPI/audioapi/ios (0.8.2): - DoubleConversion - glog - hermes-engine @@ -1926,8 +1926,6 @@ PODS: - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core - Yoga - - RNLiveAudioStream (1.1.1): - - React - RNReanimated (3.17.5): - DoubleConversion - glog @@ -2246,7 +2244,6 @@ DEPENDENCIES: - RNAudioAPI (from `../../../node_modules/react-native-audio-api`) - RNDeviceInfo (from `../../../node_modules/react-native-device-info`) - RNGestureHandler (from `../../../node_modules/react-native-gesture-handler`) - - RNLiveAudioStream (from `../../../node_modules/react-native-live-audio-stream`) - RNReanimated (from `../../../node_modules/react-native-reanimated`) - RNScreens (from `../../../node_modules/react-native-screens`) - RNSVG (from `../../../node_modules/react-native-svg`) @@ -2430,8 +2427,6 @@ EXTERNAL SOURCES: :path: "../../../node_modules/react-native-device-info" RNGestureHandler: :path: "../../../node_modules/react-native-gesture-handler" - RNLiveAudioStream: - :path: "../../../node_modules/react-native-live-audio-stream" RNReanimated: :path: "../../../node_modules/react-native-reanimated" RNScreens: @@ -2525,15 +2520,14 @@ SPEC CHECKSUMS: ReactAppDependencyProvider: 04d5eb15eb46be6720e17a4a7fa92940a776e584 ReactCodegen: 7ea266ccd94436294f516247db7402b57b1214af ReactCommon: 76d2dc87136d0a667678668b86f0fca0c16fdeb0 - RNAudioAPI: 2e3fd4bf75aa5717791babb30126707504996f09 + RNAudioAPI: 3e398c4e9d44bb6b0c0b00e902057613224fc024 RNDeviceInfo: d863506092aef7e7af3a1c350c913d867d795047 RNGestureHandler: 7d0931a61d7ba0259f32db0ba7d0963c3ed15d2b - RNLiveAudioStream: 93ac2bb6065be9018d0b00157b220f11cebc1513 RNReanimated: afd6a269a47d6f13ba295c46c6c0e14e3cbd0d8a RNScreens: 482e9707f9826230810c92e765751af53826d509 RNSVG: 794f269526df9ddc1f79b3d1a202b619df0368e3 SocketRocket: d4aabe649be1e368d1318fdf28a022d714d65748 - 
sqlite3: 83105acd294c9137c026e2da1931c30b4588ab81 + sqlite3: 1d85290c3321153511f6e900ede7a1608718bbd5 Yoga: c758bfb934100bb4bf9cbaccb52557cee35e8bdf PODFILE CHECKSUM: bba19a069e673f2259009e9d2caab44374fdebcf diff --git a/apps/llm/ios/llm.xcodeproj/project.pbxproj b/apps/llm/ios/llm.xcodeproj/project.pbxproj index 582d3baf5..215dad425 100644 --- a/apps/llm/ios/llm.xcodeproj/project.pbxproj +++ b/apps/llm/ios/llm.xcodeproj/project.pbxproj @@ -28,12 +28,11 @@ B79E360E00239D910BF9B38D /* PrivacyInfo.xcprivacy */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xml; name = PrivacyInfo.xcprivacy; path = llm/PrivacyInfo.xcprivacy; sourceTree = ""; }; BB2F792C24A3F905000567C9 /* Expo.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Expo.plist; sourceTree = ""; }; E8C01EF33FCE4105BBBC9DF6 /* Aeonik-Medium.otf */ = {isa = PBXFileReference; explicitFileType = undefined; fileEncoding = 9; includeInIndex = 0; lastKnownFileType = unknown; name = "Aeonik-Medium.otf"; path = "../assets/fonts/Aeonik-Medium.otf"; sourceTree = ""; }; - EA4529BE680FEB0AB7539557 /* Pods-llm.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-llm.release.xcconfig"; path = "Target Support Files/Pods-llm/Pods-llm.release.xcconfig"; sourceTree = ""; }; ED297162215061F000B7C4FE /* JavaScriptCore.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = JavaScriptCore.framework; path = System/Library/Frameworks/JavaScriptCore.framework; sourceTree = SDKROOT; }; F11748412D0307B40044C1D9 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = AppDelegate.swift; path = llm/AppDelegate.swift; sourceTree = ""; }; F11748442D0722820044C1D9 /* llm-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "llm-Bridging-Header.h"; path = "llm/llm-Bridging-Header.h"; sourceTree = ""; }; + F5CE0775ADE5923FA417B603 /* libPods-llm.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-llm.a"; sourceTree = BUILT_PRODUCTS_DIR; }; F866B7979FB94C8797EE2E3D /* Aeonik-Regular.otf */ = {isa = PBXFileReference; explicitFileType = undefined; fileEncoding = 9; includeInIndex = 0; lastKnownFileType = unknown; name = "Aeonik-Regular.otf"; path = "../assets/fonts/Aeonik-Regular.otf"; sourceTree = ""; }; - FCA4A9AE0011869427989B32 /* libPods-llm.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-llm.a"; sourceTree = BUILT_PRODUCTS_DIR; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -96,6 +95,15 @@ name = Frameworks; sourceTree = ""; }; + 3014A6CAF64EC97E4003A2A3 /* Pods */ = { + isa = PBXGroup; + children = ( + 4F489A14802F01369BFDDEFD /* Pods-llm.debug.xcconfig */, + 63C842393C3838DA2ECEFC7C /* Pods-llm.release.xcconfig */, + ); + path = Pods; + sourceTree = ""; + }; 832341AE1AAA6A7D00B99B32 /* Libraries */ = { isa = PBXGroup; children = ( @@ -292,7 +300,52 @@ shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. 
Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; showEnvVarsInLog = 0; }; - E0CDBD4D0993974173A0E9FD /* [CP] Copy Pods Resources */ = { + 281D8603161F8B331E2BA335 /* [Expo] Configure project */ = { + isa = PBXShellScriptBuildPhase; + alwaysOutOfDate = 1; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + ); + name = "[Expo] Configure project"; + outputFileListPaths = ( + ); + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "# This script configures Expo modules and generates the modules provider file.\nbash -l -c \"./Pods/Target\\ Support\\ Files/Pods-llm/expo-configure-project.sh\"\n"; + }; + 62055444ECB4CA2743E68CDC /* [CP] Embed Pods Frameworks */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + "${PODS_ROOT}/Target Support Files/Pods-llm/Pods-llm-frameworks.sh", + "${PODS_XCFRAMEWORKS_BUILD_DIR}/RNAudioAPI/libavcodec.framework/libavcodec", + "${PODS_XCFRAMEWORKS_BUILD_DIR}/RNAudioAPI/libavformat.framework/libavformat", + "${PODS_XCFRAMEWORKS_BUILD_DIR}/RNAudioAPI/libavutil.framework/libavutil", + "${PODS_XCFRAMEWORKS_BUILD_DIR}/RNAudioAPI/libswresample.framework/libswresample", + "${PODS_XCFRAMEWORKS_BUILD_DIR}/hermes-engine/Pre-built/hermes.framework/hermes", + ); + name = "[CP] Embed Pods Frameworks"; + outputPaths = ( + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/libavcodec.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/libavformat.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/libavutil.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/libswresample.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/hermes.framework", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${PODS_ROOT}/Target Support Files/Pods-llm/Pods-llm-frameworks.sh\"\n"; + showEnvVarsInLog = 0; + }; + 800E24972A6A228C8D4807E9 /* [CP] Copy Pods Resources */ = { isa = PBXShellScriptBuildPhase; buildActionMask = 2147483647; files = ( diff --git a/apps/llm/metro.config.js b/apps/llm/metro.config.js index 1ce72df78..fa6ef6d8b 100644 --- a/apps/llm/metro.config.js +++ b/apps/llm/metro.config.js @@ -1,21 +1,22 @@ const { getDefaultConfig } = require('expo/metro-config'); +const { + wrapWithAudioAPIMetroConfig, +} = require('react-native-audio-api/metro-config'); -module.exports = (() => { - const config = getDefaultConfig(__dirname); +const config = getDefaultConfig(__dirname); - const { transformer, resolver } = config; +const { transformer, resolver } = config; - config.transformer = { - ...transformer, - babelTransformerPath: require.resolve('react-native-svg-transformer/expo'), - }; - config.resolver = { - ...resolver, - assetExts: resolver.assetExts.filter((ext) => ext !== 'svg'), - sourceExts: [...resolver.sourceExts, 'svg'], - }; +config.transformer = { + ...transformer, + babelTransformerPath: require.resolve('react-native-svg-transformer/expo'), +}; +config.resolver = { + ...resolver, + assetExts: resolver.assetExts.filter((ext) => ext !== 'svg'), + sourceExts: [...resolver.sourceExts, 'svg'], +}; - config.resolver.assetExts.push('pte'); +config.resolver.assetExts.push('pte'); - return config; -})(); +module.exports = wrapWithAudioAPIMetroConfig(config); diff --git a/apps/llm/package.json b/apps/llm/package.json 
index 03c76dbff..32d8c3392 100644 --- a/apps/llm/package.json +++ b/apps/llm/package.json @@ -28,11 +28,10 @@ "metro-config": "^0.81.0", "react": "19.0.0", "react-native": "0.79.2", - "react-native-audio-api": "0.5.7", + "react-native-audio-api": "^0.8.2", "react-native-device-info": "^14.0.4", "react-native-executorch": "workspace:*", "react-native-gesture-handler": "~2.24.0", - "react-native-live-audio-stream": "^1.1.1", "react-native-loading-spinner-overlay": "^3.0.1", "react-native-markdown-display": "^7.0.2", "react-native-reanimated": "~3.17.4", diff --git a/apps/speech-to-text/ios/Podfile.lock b/apps/speech-to-text/ios/Podfile.lock index 3c162c082..fe20c8804 100644 --- a/apps/speech-to-text/ios/Podfile.lock +++ b/apps/speech-to-text/ios/Podfile.lock @@ -1395,7 +1395,7 @@ PODS: - React-jsiexecutor - React-RCTFBReactNativeSpec - ReactCommon/turbomodule/core - - react-native-executorch (0.4.2): + - react-native-executorch (0.5.0): - DoubleConversion - glog - hermes-engine @@ -2382,7 +2382,7 @@ SPEC CHECKSUMS: React-logger: 8edfcedc100544791cd82692ca5a574240a16219 React-Mapbuffer: c3f4b608e4a59dd2f6a416ef4d47a14400194468 React-microtasksnativemodule: 054f34e9b82f02bd40f09cebd4083828b5b2beb6 - react-native-executorch: c18d209e226f0530a9ee88f1d60ce5837d4800ee + react-native-executorch: 3c871f7ed2e2b0ff92519ce38f06f0904784dbdb react-native-safe-area-context: 562163222d999b79a51577eda2ea8ad2c32b4d06 React-NativeModulesApple: 2c4377e139522c3d73f5df582e4f051a838ff25e React-oscompat: ef5df1c734f19b8003e149317d041b8ce1f7d29c diff --git a/apps/speech-to-text/ios/speechtotext.xcodeproj/project.pbxproj b/apps/speech-to-text/ios/speechtotext.xcodeproj/project.pbxproj index 6118e5088..0787f2ae6 100644 --- a/apps/speech-to-text/ios/speechtotext.xcodeproj/project.pbxproj +++ b/apps/speech-to-text/ios/speechtotext.xcodeproj/project.pbxproj @@ -276,7 +276,6 @@ "${PODS_CONFIGURATION_BUILD_DIR}/React-cxxreact/React-cxxreact_privacy.bundle", "${PODS_CONFIGURATION_BUILD_DIR}/boost/boost_privacy.bundle", "${PODS_CONFIGURATION_BUILD_DIR}/glog/glog_privacy.bundle", - "${PODS_CONFIGURATION_BUILD_DIR}/react-native-image-picker/RNImagePickerPrivacyInfo.bundle", ); name = "[CP] Copy Pods Resources"; outputPaths = ( @@ -290,7 +289,6 @@ "${TARGET_BUILD_DIR}/${UNLOCALIZED_RESOURCES_FOLDER_PATH}/React-cxxreact_privacy.bundle", "${TARGET_BUILD_DIR}/${UNLOCALIZED_RESOURCES_FOLDER_PATH}/boost_privacy.bundle", "${TARGET_BUILD_DIR}/${UNLOCALIZED_RESOURCES_FOLDER_PATH}/glog_privacy.bundle", - "${TARGET_BUILD_DIR}/${UNLOCALIZED_RESOURCES_FOLDER_PATH}/RNImagePickerPrivacyInfo.bundle", ); runOnlyForDeploymentPostprocessing = 0; shellPath = /bin/sh; @@ -305,12 +303,10 @@ inputPaths = ( "${PODS_ROOT}/Target Support Files/Pods-speechtotext/Pods-speechtotext-frameworks.sh", "${PODS_XCFRAMEWORKS_BUILD_DIR}/hermes-engine/Pre-built/hermes.framework/hermes", - "${PODS_XCFRAMEWORKS_BUILD_DIR}/react-native-executorch/ExecutorchLib.framework/ExecutorchLib", ); name = "[CP] Embed Pods Frameworks"; outputPaths = ( "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/hermes.framework", - "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/ExecutorchLib.framework", ); runOnlyForDeploymentPostprocessing = 0; shellPath = /bin/sh; diff --git a/apps/speech-to-text/metro.config.js b/apps/speech-to-text/metro.config.js index f8ab2ab96..fa6ef6d8b 100644 --- a/apps/speech-to-text/metro.config.js +++ b/apps/speech-to-text/metro.config.js @@ -1,7 +1,8 @@ -// Learn more https://docs.expo.io/guides/customizing-metro const { getDefaultConfig } = 
require('expo/metro-config'); +const { + wrapWithAudioAPIMetroConfig, +} = require('react-native-audio-api/metro-config'); -/** @type {import('expo/metro-config').MetroConfig} */ const config = getDefaultConfig(__dirname); const { transformer, resolver } = config; @@ -18,4 +19,4 @@ config.resolver = { config.resolver.assetExts.push('pte'); -module.exports = config; +module.exports = wrapWithAudioAPIMetroConfig(config); diff --git a/apps/speech-to-text/screens/SpeechToTextScreen.tsx b/apps/speech-to-text/screens/SpeechToTextScreen.tsx index a7789a7f3..3be1750d3 100644 --- a/apps/speech-to-text/screens/SpeechToTextScreen.tsx +++ b/apps/speech-to-text/screens/SpeechToTextScreen.tsx @@ -66,8 +66,7 @@ export const SpeechToTextScreen = () => { try { const decodedAudioData = await audioContext.decodeAudioDataSource(uri); const audioBuffer = decodedAudioData.getChannelData(0); - const audioArray = Array.from(audioBuffer); - setTranscription(await model.transcribe(audioArray)); + setTranscription(await model.transcribe(audioBuffer)); } catch (error) { console.error('Error decoding audio data', error); console.warn('Note: Supported file formats: mp3, wav, flac'); @@ -79,8 +78,7 @@ export const SpeechToTextScreen = () => { setLiveTranscribing(true); setTranscription(''); recorder.onAudioReady(async ({ buffer }) => { - const bufferArray = Array.from(buffer.getChannelData(0)); - model.streamInsert(bufferArray); + await model.streamInsert(buffer.getChannelData(0)); }); recorder.start(); @@ -93,7 +91,7 @@ export const SpeechToTextScreen = () => { const handleStopTranscribeFromMicrophone = async () => { recorder.stop(); - model.streamStop(); + await model.streamStop(); console.log('Live transcription stopped'); setLiveTranscribing(false); }; diff --git a/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md b/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md index 3b5af0e23..0b7d95398 100644 --- a/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md +++ b/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md @@ -44,10 +44,9 @@ const { uri } = await FileSystem.downloadAsync( const audioContext = new AudioContext({ sampleRate: 16000 }); const decodedAudioData = await audioContext.decodeAudioDataSource(uri); const audioBuffer = decodedAudioData.getChannelData(0); -const audioArray = Array.from(audioBuffer); try { - const transcription = await model.transcribe(audioArray); + const transcription = await model.transcribe(audioBuffer); console.log(transcription); } catch (error) { console.error('Error during audio transcription', error); @@ -76,20 +75,20 @@ For more information on loading resources, take a look at [loading models](../.. ### Returns -| Field | Type | Description | -| --------------------------- | --------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `transcribe` | `(waveform: number[], options?: DecodingOptions \| undefined) => Promise` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. The second argument is an options object, e.g. `{ language: 'es' }` for multilingual models. Resolves a promise with the output transcription when the model is finished. 
|
-| `stream` | `() => Promise<string>` | Starts a streaming transcription process. Use in combination with `streamInsert` to feed audio chunks and `streamStop` to end the stream. Updates `committedTranscription` and `nonCommittedTranscription` as transcription progresses. |
-| `streamInsert` | `(waveform: number[]) => void` | Inserts a chunk of audio data (sampled at 16kHz) into the ongoing streaming transcription. Call this repeatedly as new audio data becomes available. |
-| `streamStop` | `() => void` | Stops the ongoing streaming transcription process. |
-| `encode` | `(waveform: Float32Array) => Promise<void>` | Runs the encoding part of the model on the provided waveform. Stores the result internally. |
-| `decode` | `(tokens: number[]) => Promise<Float32Array>` | Runs the decoder of the model. Returns the decoded waveform as a Float32Array. |
-| `committedTranscription` | `string` | Contains the part of the transcription that is finalized and will not change. Useful for displaying stable results during streaming. |
-| `nonCommittedTranscription` | `string` | Contains the part of the transcription that is still being processed and may change. Useful for displaying live, partial results during streaming. |
-| `error` | `string \| null` | Contains the error message if the model failed to load. |
-| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
-| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
-| `downloadProgress` | `number` | Tracks the progress of the model download process. |
+| Field | Type | Description |
+| --------------------------- | ------------------------------------------------------------------------------------------------ | ----------- |
+| `transcribe` | `(waveform: Float32Array \| number[], options?: DecodingOptions \| undefined) => Promise<string>` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. The second argument is an options object, e.g. `{ language: 'es' }` for multilingual models. Resolves a promise with the output transcription when the model is finished. Passing `number[]` is deprecated. |
+| `stream` | `() => Promise<string>` | Starts a streaming transcription process. Use in combination with `streamInsert` to feed audio chunks and `streamStop` to end the stream. Updates `committedTranscription` and `nonCommittedTranscription` as transcription progresses. |
+| `streamInsert` | `(waveform: Float32Array \| number[]) => Promise<void>` | Inserts a chunk of audio data (sampled at 16kHz) into the ongoing streaming transcription. Call this repeatedly as new audio data becomes available. Passing `number[]` is deprecated. |
+| `streamStop` | `() => Promise<void>` | Stops the ongoing streaming transcription process. |
+| `encode` | `(waveform: Float32Array \| number[]) => Promise<Float32Array>` | Runs the encoding part of the model on the provided waveform. Passing `number[]` is deprecated. |
+| `decode` | `(tokens: number[] \| Int32Array, encoderOutput: Float32Array \| number[]) => Promise<Float32Array>` | Runs the decoder of the model. Passing `number[]` is deprecated. |
+| `committedTranscription` | `string` | Contains the part of the transcription that is finalized and will not change. Useful for displaying stable results during streaming. |
+| `nonCommittedTranscription` | `string` | Contains the part of the transcription that is still being processed and may change. Useful for displaying live, partial results during streaming. |
+| `error` | `string \| null` | Contains the error message if the model failed to load. |
+| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
+| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
+| `downloadProgress` | `number` | Tracks the progress of the model download process. |
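The change in this file drops the `Array.from` copy from the file-transcription example. A minimal TypeScript sketch of the before/after, using the same `AudioContext` setup as the snippets on this page; `transcribeFile` and its inline `TranscribeModel` type are hypothetical names for illustration:

```typescript
import { AudioContext } from 'react-native-audio-api';

// `model` stands in for the value returned by useSpeechToText;
// only the relevant method is typed here.
type TranscribeModel = {
  transcribe(waveform: Float32Array | number[]): Promise<string>;
};

async function transcribeFile(model: TranscribeModel, uri: string) {
  const audioContext = new AudioContext({ sampleRate: 16000 });
  const decoded = await audioContext.decodeAudioDataSource(uri);
  const audioBuffer = decoded.getChannelData(0); // Float32Array

  // Deprecated path: copying into a plain number[] still works for now.
  // const transcription = await model.transcribe(Array.from(audioBuffer));

  // Preferred: pass the typed array directly and skip the O(n) copy.
  return model.transcribe(audioBuffer);
}
```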
Type definitions @@ -231,7 +230,7 @@ function App() { const decodedAudioData = await audioContext.decodeAudioDataSource(uri); const audioBuffer = decodedAudioData.getChannelData(0); - return Array.from(audioBuffer); + return audioBuffer; }; const handleTranscribe = async () => { @@ -281,8 +280,7 @@ function App() { const handleStartStreamingTranscribe = async () => { recorder.onAudioReady(async ({ buffer }) => { - const bufferArray = Array.from(buffer.getChannelData(0)); - model.streamInsert(bufferArray); + await model.streamInsert(buffer.getChannelData(0)); }); recorder.start(); @@ -295,7 +293,7 @@ function App() { const handleStopStreamingTranscribe = async () => { recorder.stop(); - model.streamStop(); + await model.streamStop(); }; return ( diff --git a/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md b/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md index df57a8f3d..839e58c78 100644 --- a/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md +++ b/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md @@ -19,15 +19,15 @@ await model.transcribe(waveform); ### Methods -| Method | Type | Description | -| -------------- | ---------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `load` | `(model: SpeechToTextModelConfig, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model specified by the config object. `onDownloadProgressCallback` allows you to monitor the current progress of the model download. | -| `encode` | `(waveform: Float32Array): Promise` | Runs the encoding part of the model on the provided waveform. Stores the result internally. | -| `decode` | `(tokens: number[]): Promise` | Runs the decoder of the model. Returns the decoded waveform as a Float32Array. | -| `transcribe` | `(waveform: number[], options?: DecodingOptions): Promise` | Starts a transcription process for a given input array (16kHz waveform). For multilingual models, specify the language in `options`. Returns the transcription as a string. | -| `stream` | `(options?: DecodingOptions): AsyncGenerator<{ committed: string; nonCommitted: string }>` | Starts a streaming transcription session. Yields objects with `committed` and `nonCommitted` transcriptions. Use with `streamInsert` and `streamStop` to control the stream. | -| `streamStop` | `(): void` | Stops the current streaming transcription session. | -| `streamInsert` | `(waveform: number[]): void` | Inserts a new audio chunk into the streaming transcription session. | +| Method | Type | Description | +| -------------- | ---------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `load` | `(model: SpeechToTextModelConfig, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model specified by the config object. `onDownloadProgressCallback` allows you to monitor the current progress of the model download. 
| +| `encode` | `(waveform: Float32Array \| number[]): Promise` | Runs the encoding part of the model on the provided waveform. Returns the encoded waveform as a Float32Array. Passing `number[]` is deprecated. | +| `decode` | `(tokens: number[] \| Int32Array, encoderOutput: Float32Array \| number[]): Promise` | Runs the decoder of the model. Passing `number[]` is deprecated. | +| `transcribe` | `(waveform: Float32Array \| number[], options?: DecodingOptions): Promise` | Starts a transcription process for a given input array (16kHz waveform). For multilingual models, specify the language in `options`. Returns the transcription as a string. Passing `number[]` is deprecated. | +| `stream` | `(options?: DecodingOptions): AsyncGenerator<{ committed: string; nonCommitted: string }>` | Starts a streaming transcription session. Yields objects with `committed` and `nonCommitted` transcriptions. Use with `streamInsert` and `streamStop` to control the stream. | +| `streamStop` | `(): Promise` | Stops the current streaming transcription session. | +| `streamInsert` | `(waveform: Float32Array \| number[]): Promise` | Inserts a new audio chunk into the streaming transcription session. Passing `number[]` is deprecated. | :::info @@ -192,11 +192,10 @@ const { uri } = await FileSystem.downloadAsync( const audioContext = new AudioContext({ sampleRate: 16000 }); const decodedAudioData = await audioContext.decodeAudioDataSource(uri); const audioBuffer = decodedAudioData.getChannelData(0); -const audioArray = Array.from(audioBuffer); // Transcribe the audio try { - const transcription = await model.transcribe(audioArray); + const transcription = await model.transcribe(audioBuffer); console.log(transcription); } catch (error) { console.error('Error during audio transcription', error); @@ -229,9 +228,8 @@ const recorder = new AudioRecorder({ bufferLengthInSamples: 1600, }); recorder.onAudioReady(async ({ buffer }) => { - const bufferArray = Array.from(buffer.getChannelData(0)); // Insert the audio into the streaming transcription - model.streamInsert(bufferArray); + await model.streamInsert(buffer.getChannelData(0)); }); recorder.start(); @@ -248,6 +246,6 @@ try { } // Stop streaming transcription -model.streamStop(); +await model.streamStop(); recorder.stop(); ``` diff --git a/docs/versioned_docs/version-0.5.x/02-hooks/01-natural-language-processing/useSpeechToText.md b/docs/versioned_docs/version-0.5.x/02-hooks/01-natural-language-processing/useSpeechToText.md index 3b5af0e23..4c3237743 100644 --- a/docs/versioned_docs/version-0.5.x/02-hooks/01-natural-language-processing/useSpeechToText.md +++ b/docs/versioned_docs/version-0.5.x/02-hooks/01-natural-language-processing/useSpeechToText.md @@ -44,10 +44,9 @@ const { uri } = await FileSystem.downloadAsync( const audioContext = new AudioContext({ sampleRate: 16000 }); const decodedAudioData = await audioContext.decodeAudioDataSource(uri); const audioBuffer = decodedAudioData.getChannelData(0); -const audioArray = Array.from(audioBuffer); try { - const transcription = await model.transcribe(audioArray); + const transcription = await model.transcribe(audioBuffer); console.log(transcription); } catch (error) { console.error('Error during audio transcription', error); @@ -76,20 +75,20 @@ For more information on loading resources, take a look at [loading models](../.. 
### Returns

-| Field | Type | Description |
-| --------------------------- | --------------------------------------------------------------------------------- | ----------- |
-| `transcribe` | `(waveform: number[], options?: DecodingOptions \| undefined) => Promise<string>` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. The second argument is an options object, e.g. `{ language: 'es' }` for multilingual models. Resolves a promise with the output transcription when the model is finished. |
-| `stream` | `() => Promise<string>` | Starts a streaming transcription process. Use in combination with `streamInsert` to feed audio chunks and `streamStop` to end the stream. Updates `committedTranscription` and `nonCommittedTranscription` as transcription progresses. |
-| `streamInsert` | `(waveform: number[]) => void` | Inserts a chunk of audio data (sampled at 16kHz) into the ongoing streaming transcription. Call this repeatedly as new audio data becomes available. |
-| `streamStop` | `() => void` | Stops the ongoing streaming transcription process. |
-| `encode` | `(waveform: Float32Array) => Promise<void>` | Runs the encoding part of the model on the provided waveform. Stores the result internally. |
-| `decode` | `(tokens: number[]) => Promise<Float32Array>` | Runs the decoder of the model. Returns the decoded waveform as a Float32Array. |
-| `committedTranscription` | `string` | Contains the part of the transcription that is finalized and will not change. Useful for displaying stable results during streaming. |
-| `nonCommittedTranscription` | `string` | Contains the part of the transcription that is still being processed and may change. Useful for displaying live, partial results during streaming. |
-| `error` | `string \| null` | Contains the error message if the model failed to load. |
-| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
-| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
-| `downloadProgress` | `number` | Tracks the progress of the model download process. |
+| Field | Type | Description |
+| --------------------------- | ------------------------------------------------------------------------------------------------ | ----------- |
+| `transcribe` | `(waveform: Float32Array \| number[], options?: DecodingOptions \| undefined) => Promise<string>` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. The second argument is an options object, e.g. `{ language: 'es' }` for multilingual models. Resolves a promise with the output transcription when the model is finished. Passing `number[]` is deprecated. |
+| `stream` | `() => Promise<string>` | Starts a streaming transcription process. Use in combination with `streamInsert` to feed audio chunks and `streamStop` to end the stream. Updates `committedTranscription` and `nonCommittedTranscription` as transcription progresses. |
+| `streamInsert` | `(waveform: Float32Array \| number[]) => Promise<void>` | Inserts a chunk of audio data (sampled at 16kHz) into the ongoing streaming transcription. Call this repeatedly as new audio data becomes available. Passing `number[]` is deprecated. |
+| `streamStop` | `() => Promise<void>` | Stops the ongoing streaming transcription process. |
+| `encode` | `(waveform: Float32Array \| number[]) => Promise<Float32Array>` | Runs the encoding part of the model on the provided waveform. Passing `number[]` is deprecated. |
+| `decode` | `(tokens: number[] \| Int32Array, encoderOutput: Float32Array \| number[]) => Promise<Float32Array>` | Runs the decoder of the model. Passing `number[]` is deprecated. |
+| `committedTranscription` | `string` | Contains the part of the transcription that is finalized and will not change. Useful for displaying stable results during streaming. |
+| `nonCommittedTranscription` | `string` | Contains the part of the transcription that is still being processed and may change. Useful for displaying live, partial results during streaming. |
+| `error` | `string \| null` | Contains the error message if the model failed to load. |
+| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
+| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
+| `downloadProgress` | `number` | Tracks the progress of the model download process. |
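As on the unversioned page, the streaming methods now return promises. A short sketch of the full start/stop lifecycle under those signatures — `startLiveTranscription` and `stopLiveTranscription` are hypothetical helper names, and the recorder settings mirror the ones used in the examples below:

```typescript
import { AudioRecorder } from 'react-native-audio-api';

// `model` stands in for the value returned by useSpeechToText.
type StreamingModel = {
  stream(): Promise<string>;
  streamInsert(waveform: Float32Array | number[]): Promise<void>;
  streamStop(): Promise<void>;
};

const recorder = new AudioRecorder({
  sampleRate: 16000, // the speech-to-text models expect 16 kHz audio
  bufferLengthInSamples: 1600, // ~100 ms chunks
});

// Starts recording; resolves with the final transcription once
// stopLiveTranscription() has been called.
async function startLiveTranscription(model: StreamingModel) {
  recorder.onAudioReady(async ({ buffer }) => {
    // Feed each recorded chunk into the ongoing stream as it arrives.
    await model.streamInsert(buffer.getChannelData(0));
  });
  recorder.start();
  return model.stream();
}

async function stopLiveTranscription(model: StreamingModel) {
  recorder.stop();
  await model.streamStop();
}
```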
Type definitions @@ -231,7 +230,7 @@ function App() { const decodedAudioData = await audioContext.decodeAudioDataSource(uri); const audioBuffer = decodedAudioData.getChannelData(0); - return Array.from(audioBuffer); + return audioBuffer; }; const handleTranscribe = async () => { @@ -281,8 +280,7 @@ function App() { const handleStartStreamingTranscribe = async () => { recorder.onAudioReady(async ({ buffer }) => { - const bufferArray = Array.from(buffer.getChannelData(0)); - model.streamInsert(bufferArray); + model.streamInsert(buffer.getChannelData(0)); }); recorder.start(); diff --git a/docs/versioned_docs/version-0.5.x/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md b/docs/versioned_docs/version-0.5.x/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md index df57a8f3d..039189e73 100644 --- a/docs/versioned_docs/version-0.5.x/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md +++ b/docs/versioned_docs/version-0.5.x/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md @@ -19,15 +19,15 @@ await model.transcribe(waveform); ### Methods -| Method | Type | Description | -| -------------- | ---------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `load` | `(model: SpeechToTextModelConfig, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model specified by the config object. `onDownloadProgressCallback` allows you to monitor the current progress of the model download. | -| `encode` | `(waveform: Float32Array): Promise` | Runs the encoding part of the model on the provided waveform. Stores the result internally. | -| `decode` | `(tokens: number[]): Promise` | Runs the decoder of the model. Returns the decoded waveform as a Float32Array. | -| `transcribe` | `(waveform: number[], options?: DecodingOptions): Promise` | Starts a transcription process for a given input array (16kHz waveform). For multilingual models, specify the language in `options`. Returns the transcription as a string. | -| `stream` | `(options?: DecodingOptions): AsyncGenerator<{ committed: string; nonCommitted: string }>` | Starts a streaming transcription session. Yields objects with `committed` and `nonCommitted` transcriptions. Use with `streamInsert` and `streamStop` to control the stream. | -| `streamStop` | `(): void` | Stops the current streaming transcription session. | -| `streamInsert` | `(waveform: number[]): void` | Inserts a new audio chunk into the streaming transcription session. | +| Method | Type | Description | +| -------------- | ---------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `load` | `(model: SpeechToTextModelConfig, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model specified by the config object. `onDownloadProgressCallback` allows you to monitor the current progress of the model download. | +| `encode` | `(waveform: Float32Array \| number[]): Promise` | Runs the encoding part of the model on the provided waveform. Returns the encoded waveform as a Float32Array. 
Passing `number[]` is deprecated. | +| `decode` | `(tokens: number[] \| Int32Array, encoderOutput: Float32Array \| number[]): Promise` | Runs the decoder of the model. Passing `number[]` is deprecated. | +| `transcribe` | `(waveform: Float32Array \| number[], options?: DecodingOptions): Promise` | Starts a transcription process for a given input array (16kHz waveform). For multilingual models, specify the language in `options`. Returns the transcription as a string. Passing `number[]` is deprecated. | +| `stream` | `(options?: DecodingOptions): AsyncGenerator<{ committed: string; nonCommitted: string }>` | Starts a streaming transcription session. Yields objects with `committed` and `nonCommitted` transcriptions. Use with `streamInsert` and `streamStop` to control the stream. | +| `streamStop` | `(): void` | Stops the current streaming transcription session. | +| `streamInsert` | `(waveform: Float32Array \| number[]): void` | Inserts a new audio chunk into the streaming transcription session. Passing `number[]` is deprecated. | :::info @@ -192,11 +192,10 @@ const { uri } = await FileSystem.downloadAsync( const audioContext = new AudioContext({ sampleRate: 16000 }); const decodedAudioData = await audioContext.decodeAudioDataSource(uri); const audioBuffer = decodedAudioData.getChannelData(0); -const audioArray = Array.from(audioBuffer); // Transcribe the audio try { - const transcription = await model.transcribe(audioArray); + const transcription = await model.transcribe(audioBuffer); console.log(transcription); } catch (error) { console.error('Error during audio transcription', error); @@ -229,9 +228,8 @@ const recorder = new AudioRecorder({ bufferLengthInSamples: 1600, }); recorder.onAudioReady(async ({ buffer }) => { - const bufferArray = Array.from(buffer.getChannelData(0)); // Insert the audio into the streaming transcription - model.streamInsert(bufferArray); + model.streamInsert(buffer.getChannelData(0)); }); recorder.start(); diff --git a/packages/react-native-executorch/android/src/main/cpp/CMakeLists.txt b/packages/react-native-executorch/android/src/main/cpp/CMakeLists.txt index bf1544aeb..d77ec7563 100644 --- a/packages/react-native-executorch/android/src/main/cpp/CMakeLists.txt +++ b/packages/react-native-executorch/android/src/main/cpp/CMakeLists.txt @@ -95,4 +95,5 @@ target_link_libraries( ${OPENCV_THIRD_PARTY_LIBS} executorch ${EXECUTORCH_LIBS} -) \ No newline at end of file + z +) diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp index 1b155f72d..0ccac9d07 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp @@ -9,7 +9,7 @@ #include namespace rnexecutorch::numerical { -void softmax(std::vector &v) { +void softmax(std::span v) { float max = *std::max_element(v.begin(), v.end()); float sum = 0.0f; @@ -22,32 +22,40 @@ void softmax(std::vector &v) { } } -void normalize(std::span span) { - auto sum = 0.0f; - for (const auto &val : span) { - sum += val * val; +void softmaxWithTemperature(std::span input, float temperature) { + if (input.empty()) { + return; } - if (isClose(sum, 0.0f)) { - return; + if (temperature <= 0.0F) { + throw std::invalid_argument( + "Temperature must be greater than 0 for softmax with temperature."); } - float norm = std::sqrt(sum); - for (auto &val : span) { - val /= norm; + const auto maxElement 
= *std::ranges::max_element(input); + + for (auto &value : input) { + value = std::exp((value - maxElement) / temperature); } -} -void normalize(std::vector &v) { - float sum = 0.0f; - for (float &x : v) { - sum += x * x; + const auto sum = std::reduce(input.begin(), input.end()); + + // sum is at least 1 since exp(max - max) == exp(0) == 1 + for (auto &value : input) { + value /= sum; } +} - float norm = - std::max(std::sqrt(sum), 1e-9f); // Solely for preventing division by 0 - for (float &x : v) { - x /= norm; +void normalize(std::span input) { + const auto sumOfSquares = + std::inner_product(input.begin(), input.end(), input.begin(), 0.0F); + + constexpr auto kEpsilon = 1.0e-15F; + + const auto norm = std::sqrt(sumOfSquares) + kEpsilon; + + for (auto &value : input) { + value /= norm; } } diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h index 77a13f44f..fa808f21f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h @@ -4,10 +4,59 @@ #include namespace rnexecutorch::numerical { -void softmax(std::vector &v); -void normalize(std::span span); -void normalize(std::vector &v); -void normalize(std::span span); + +/** + * @brief Applies the softmax function in-place to a sequence of numbers. + * + * @param input A mutable span of floating-point numbers. After the function + * returns, `input` contains the softmax probabilities. + */ +void softmax(std::span input); + +/** + * @brief Applies the softmax function with temperature scaling in-place to a + * sequence of numbers. + * + * The temperature parameter controls the "sharpness" of the resulting + * probability distribution. A temperature of 1.0 means no scaling, while lower + * values make the distribution sharper (more peaked), and higher values make it + * softer (more uniform). + * + * @param input A mutable span of floating-point numbers. After the function + * returns, `input` contains the softmax probabilities. + * @param temperature A positive float value used to scale the logits before + * applying softmax. Must be greater than 0. + */ +void softmaxWithTemperature(std::span input, float temperature); + +/** + * @brief Normalizes the elements of the given float span in-place using the + * L2 norm method. + * + * This function scales the input vector such that its L2 norm (Euclidean norm) + * becomes 1. If the norm is zero, the result is a zero vector with the same + * size as the input. + * + * @param input A mutable span of floating-point values representing the data to + * be normalized. + */ +void normalize(std::span input); + +/** + * @brief Computes mean pooling across the modelOutput adjusted by an attention + * mask. + * + * This function aggregates the `modelOutput` span by sections defined by + * `attnMask`, computing the mean of sections influenced by the mask. The result + * is a vector where each element is the mean of a segment from the original + * data. + * + * @param modelOutput A span of floating-point numbers representing the model + * output. + * @param attnMask A span of integers where each integer is a weight + * corresponding to the elements in `modelOutput`. + * @return A std::vector containing the computed mean values of segments. 
+ */ std::vector meanPooling(std::span modelOutput, std::span attnMask); /** diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp index f088659b8..d3761dced 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp @@ -18,7 +18,7 @@ std::vector hannWindow(size_t size) { return window; } -std::vector stftFromWaveform(std::span waveform, +std::vector stftFromWaveform(std::span waveform, size_t fftWindowSize, size_t hopSize) { // Initialize FFT FFT fft(fftWindowSize); diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h index d04939f12..7eaa26d83 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h @@ -6,7 +6,7 @@ namespace rnexecutorch::dsp { std::vector hannWindow(size_t size); -std::vector stftFromWaveform(std::span waveform, +std::vector stftFromWaveform(std::span waveform, size_t fftWindowSize, size_t hopSize); } // namespace rnexecutorch::dsp diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.cpp new file mode 100644 index 000000000..877b85995 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.cpp @@ -0,0 +1,47 @@ +#include +#include + +#include "gzip.h" + +namespace rnexecutorch::gzip { + +namespace { +constexpr int32_t kGzipWrapper = 16; // gzip header/trailer +constexpr int32_t kMemLevel = 8; // memory level +constexpr size_t kChunkSize = 16 * 1024; // 16 KiB stream buffer +} // namespace + +size_t deflateSize(const std::string &input) { + z_stream strm{}; + if (::deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, + MAX_WBITS + kGzipWrapper, kMemLevel, + Z_DEFAULT_STRATEGY) != Z_OK) { + throw std::runtime_error("deflateInit2 failed"); + } + + size_t outSize = 0; + + strm.next_in = reinterpret_cast( + const_cast(input.data())); + strm.avail_in = static_cast(input.size()); + + std::vector buf(kChunkSize); + int ret; + do { + strm.next_out = buf.data(); + strm.avail_out = static_cast(buf.size()); + + ret = ::deflate(&strm, strm.avail_in ? 
Z_NO_FLUSH : Z_FINISH); + if (ret == Z_STREAM_ERROR) { + ::deflateEnd(&strm); + throw std::runtime_error("deflate stream error"); + } + + outSize += buf.size() - strm.avail_out; + } while (ret != Z_STREAM_END); + + ::deflateEnd(&strm); + return outSize; +} + +} // namespace rnexecutorch::gzip diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.h new file mode 100644 index 000000000..73890e1f9 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.h @@ -0,0 +1,7 @@ +#pragma once + +namespace rnexecutorch::gzip { + +size_t deflateSize(const std::string &input); + +} // namespace rnexecutorch::gzip diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index 17225cce0..949d75ad1 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -62,6 +62,30 @@ template class ModelHostObject : public JsiHostObject { "decode")); } + if constexpr (meta::HasTranscribe) { + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::transcribe>, + "transcribe")); + } + + if constexpr (meta::HasStream) { + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::stream>, + "stream")); + } + + if constexpr (meta::HasStreamInsert) { + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, promiseHostFunction<&Model::streamInsert>, + "streamInsert")); + } + + if constexpr (meta::HasStreamStop) { + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::streamStop>, + "streamStop")); + } + if constexpr (meta::SameAs) { addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::encode>, diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h index 253a53238..10bc23af3 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h @@ -26,6 +26,26 @@ concept HasDecode = requires(T t) { { &T::decode }; }; +template +concept HasTranscribe = requires(T t) { + { &T::transcribe }; +}; + +template +concept HasStream = requires(T t) { + { &T::stream }; +}; + +template +concept HasStreamInsert = requires(T t) { + { &T::streamInsert }; +}; + +template +concept HasStreamStop = requires(T t) { + { &T::streamStop }; +}; + template concept IsNumeric = std::is_arithmetic_v; @@ -34,4 +54,4 @@ concept ProvidesMemoryLowerBound = requires(T t) { { &T::getMemoryLowerBound }; }; -} // namespace rnexecutorch::meta \ No newline at end of file +} // namespace rnexecutorch::meta diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp index 77b4fd63a..3eab6dd83 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp @@ -142,7 +142,8 @@ BaseModel::getMethodMeta(const std::string &methodName) { return module_->method_meta(methodName); } -Result> BaseModel::forward(const EValue &input_evalue) { 
+Result> +BaseModel::forward(const EValue &input_evalue) const { if (!module_) { throw std::runtime_error("Model not loaded: Cannot perform forward pass"); } @@ -150,7 +151,7 @@ Result> BaseModel::forward(const EValue &input_evalue) { } Result> -BaseModel::forward(const std::vector &input_evalues) { +BaseModel::forward(const std::vector &input_evalues) const { if (!module_) { throw std::runtime_error("Model not loaded: Cannot perform forward pass"); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h index dbf3c433a..983dc9b74 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h @@ -26,8 +26,9 @@ class BaseModel { getAllInputShapes(std::string methodName = "forward"); std::vector forwardJS(std::vector tensorViewVec); - Result> forward(const EValue &input_value); - Result> forward(const std::vector &input_value); + Result> forward(const EValue &input_value) const; + Result> + forward(const std::vector &input_value) const; Result> execute(const std::string &methodName, const std::vector &input_value); Result diff --git a/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.cpp b/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.cpp deleted file mode 100644 index a0da38708..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include - -namespace rnexecutorch::models { - -EncoderDecoderBase::EncoderDecoderBase( - const std::string &encoderPath, const std::string &decoderPath, - std::shared_ptr callInvoker) - : callInvoker(callInvoker), - encoder_(std::make_unique(encoderPath, callInvoker)), - decoder_(std::make_unique(decoderPath, callInvoker)) {}; - -size_t EncoderDecoderBase::getMemoryLowerBound() const noexcept { - return encoder_->getMemoryLowerBound() + decoder_->getMemoryLowerBound(); -} - -void EncoderDecoderBase::unload() noexcept { - encoder_.reset(nullptr); - decoder_.reset(nullptr); -} - -} // namespace rnexecutorch::models diff --git a/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.h b/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.h deleted file mode 100644 index f2f2265aa..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace rnexecutorch::models { - -using namespace facebook; -using executorch::aten::Tensor; -using executorch::runtime::EValue; - -class EncoderDecoderBase { -public: - explicit EncoderDecoderBase(const std::string &encoderPath, - const std::string &decoderPath, - std::shared_ptr callInvoker); - size_t getMemoryLowerBound() const noexcept; - void unload() noexcept; - -protected: - std::shared_ptr callInvoker; - std::unique_ptr encoder_; - std::unique_ptr decoder_; - -private: - size_t memorySizeLowerBound; -}; - -} // namespace rnexecutorch::models diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp index 3f627095e..d444b9c91 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +++ 
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp
index 3f627095e..d444b9c91 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp
@@ -1,64 +1,127 @@
-#include
-#include
-#include
+#include <chrono>
+#include <cstring>
+#include <thread>
+
+#include "SpeechToText.h"
 
 namespace rnexecutorch::models::speech_to_text {
 
 using namespace ::executorch::extension;
 
-SpeechToText::SpeechToText(const std::string &encoderPath,
-                           const std::string &decoderPath,
-                           const std::string &modelName,
+SpeechToText::SpeechToText(const std::string &encoderSource,
+                           const std::string &decoderSource,
+                           const std::string &tokenizerSource,
                            std::shared_ptr<react::CallInvoker> callInvoker)
-    : EncoderDecoderBase(encoderPath, decoderPath, callInvoker),
-      modelName(modelName) {
-  initializeStrategy();
+    : callInvoker(std::move(callInvoker)),
+      encoder(std::make_unique<BaseModel>(encoderSource, this->callInvoker)),
+      decoder(std::make_unique<BaseModel>(decoderSource, this->callInvoker)),
+      tokenizer(std::make_unique<TokenizerModule>(tokenizerSource,
+                                                  this->callInvoker)),
+      asr(std::make_unique<ASR>(this->encoder.get(), this->decoder.get(),
+                                this->tokenizer.get())),
+      processor(std::make_unique<OnlineASRProcessor>(this->asr.get())),
+      isStreaming(false), readyToProcess(false) {}
+
+std::shared_ptr<OwningArrayBuffer>
+SpeechToText::encode(std::span<float> waveform) const {
+  std::vector<float> encoderOutput = this->asr->encode(waveform);
+  return this->makeOwningBuffer(encoderOutput);
 }
 
-void SpeechToText::initializeStrategy() {
-  if (modelName == "whisper") {
-    strategy = std::make_unique<WhisperStrategy>();
-  } else {
-    throw std::runtime_error("Unsupported STT model: " + modelName +
-                             ". Only 'whisper' is supported.");
-  }
+std::shared_ptr<OwningArrayBuffer>
+SpeechToText::decode(std::span<int32_t> tokens,
+                     std::span<float> encoderOutput) const {
+  std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
+  return this->makeOwningBuffer(decoderOutput);
 }
 
-void SpeechToText::encode(std::span<float> waveform) {
-  const auto modelInputTensor = strategy->prepareAudioInput(waveform);
+std::string SpeechToText::transcribe(std::span<float> waveform,
+                                     std::string languageOption) const {
+  std::vector<Segment> segments =
+      this->asr->transcribe(waveform, DecodingOptions(languageOption));
+  std::string transcription;
+
+  size_t transcriptionLength = 0;
+  for (auto &segment : segments) {
+    for (auto &word : segment.words) {
+      transcriptionLength += word.content.size();
+    }
+  }
+  transcription.reserve(transcriptionLength);
 
-  const auto result = encoder_->forward(modelInputTensor);
-  if (!result.ok()) {
-    throw std::runtime_error(
-        "Forward pass failed during encoding, error code: " +
-        std::to_string(static_cast<int>(result.error())));
+  for (auto &segment : segments) {
+    for (auto &word : segment.words) {
+      transcription += word.content;
+    }
   }
+  return transcription;
+}
 
-  encoderOutput = result.get().at(0);
+size_t SpeechToText::getMemoryLowerBound() const noexcept {
+  return this->encoder->getMemoryLowerBound() +
+         this->decoder->getMemoryLowerBound() +
+         this->tokenizer->getMemoryLowerBound();
 }
 
 std::shared_ptr<OwningArrayBuffer>
-SpeechToText::decode(std::vector prevTokens) {
-  if (encoderOutput.isNone()) {
-    throw std::runtime_error("Empty encodings on decode call, make sure to "
-                             "call encode() prior to decode()!");
+SpeechToText::makeOwningBuffer(std::span<float> vectorView) const {
+  auto owningArrayBuffer =
+      std::make_shared<OwningArrayBuffer>(vectorView.size_bytes());
+  std::memcpy(owningArrayBuffer->data(), vectorView.data(),
+              vectorView.size_bytes());
+  return owningArrayBuffer;
+}
+
+void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
+                          std::string languageOption) {
+  if (this->isStreaming) {
+    throw std::runtime_error("Streaming is already in progress");
+  }
+
+  auto nativeCallback = [this, callback](const std::string &committed,
+                                         const std::string &nonCommitted,
+                                         bool isDone) {
+    this->callInvoker->invokeAsync(
+        [callback, committed, nonCommitted, isDone](jsi::Runtime &rt) {
+          callback->call(rt, jsi::String::createFromUtf8(rt, committed),
+                         jsi::String::createFromUtf8(rt, nonCommitted),
+                         jsi::Value(isDone));
+        });
+  };
+
+  this->resetStreamState();
+
+  this->isStreaming = true;
+  while (this->isStreaming) {
+    if (!this->readyToProcess ||
+        this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) {
+      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+      continue;
+    }
+    ProcessResult res =
+        this->processor->processIter(DecodingOptions(languageOption));
+    nativeCallback(res.committed, res.nonCommitted, false);
+    this->readyToProcess = false;
   }
-  const auto prevTokensTensor = strategy->prepareTokenInput(prevTokens);
 
+  std::string committed = this->processor->finish();
+  nativeCallback(committed, "", true);
+}
 
-  const auto decoderMethod = strategy->getDecoderMethod();
-  const auto decoderResult =
-      decoder_->execute(decoderMethod, {prevTokensTensor, encoderOutput});
+void SpeechToText::streamStop() { this->isStreaming = false; }
 
-  if (!decoderResult.ok()) {
-    throw std::runtime_error(
-        "Forward pass failed during decoding, error code: " +
-        std::to_string(static_cast<int>(decoderResult.error())));
+void SpeechToText::streamInsert(std::span<float> waveform) {
+  if (!this->isStreaming) {
+    throw std::runtime_error("Streaming is not started");
   }
+  this->processor->insertAudioChunk(waveform);
+  this->readyToProcess = true;
+}
 
-  const auto decoderOutputTensor = decoderResult.get().at(0).toTensor();
-  const auto innerDim = decoderOutputTensor.size(1);
-  return strategy->extractOutputToken(decoderOutputTensor);
+void SpeechToText::resetStreamState() {
+  this->isStreaming = false;
+  this->readyToProcess = false;
+  this->processor = std::make_unique<OnlineASRProcessor>(this->asr.get());
 }
 
 } // namespace rnexecutorch::models::speech_to_text
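SpeechToText::stream above is a blocking poll loop intended to run off the JS thread (it is exported through promiseHostFunction): streamInsert appends audio and raises a ready flag, each loop iteration re-transcribes the buffered audio and lowers the flag, and streamStop ends the loop. A reduced model of that handshake (illustrative only; the real class also rebuilds its processor in resetStreamState, and the lock here is the sketch's stand-in for whatever synchronization the buffer needs):

    #include <atomic>
    #include <chrono>
    #include <mutex>
    #include <thread>
    #include <vector>

    std::atomic<bool> streaming{false};
    std::atomic<bool> ready{false};
    std::mutex bufferMutex;
    std::vector<float> audioBuffer;

    void streamLoop() { // ~ SpeechToText::stream, runs on a worker thread
      streaming = true;
      while (streaming) {
        if (!ready) {
          std::this_thread::sleep_for(std::chrono::milliseconds(100));
          continue;
        }
        {
          std::scoped_lock lock(bufferMutex);
          // ... transcribe audioBuffer, emit committed/nonCommitted ...
        }
        ready = false; // wait for the next insert before decoding again
      }
      // ... emit the final committed text with isDone = true ...
    }

    void insert(const std::vector<float> &chunk) { // ~ streamInsert
      std::scoped_lock lock(bufferMutex);
      audioBuffer.insert(audioBuffer.end(), chunk.begin(), chunk.end());
      ready = true;
    }

    void stop() { streaming = false; } // ~ streamStop

Because the flags are written from the JS thread and read from the worker, they need to be atomics; the header below declares them accordingly.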
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h
index fa90b53dc..a6f3779e4 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h
@@ -1,38 +1,62 @@
 #pragma once
 
-#include "ReactCommon/CallInvoker.h"
-#include "executorch/runtime/core/evalue.h"
-#include
-#include
-#include
-#include
-#include
-
-#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
-#include
-#include
+#include <atomic>
+
+#include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h"
 
 namespace rnexecutorch {
+
 namespace models::speech_to_text {
-class SpeechToText : public EncoderDecoderBase {
+
+using namespace asr;
+using namespace types;
+using namespace stream;
+
+class SpeechToText {
 public:
-  explicit SpeechToText(const std::string &encoderPath,
-                        const std::string &decoderPath,
-                        const std::string &modelName,
+  explicit SpeechToText(const std::string &encoderSource,
+                        const std::string &decoderSource,
+                        const std::string &tokenizerSource,
                         std::shared_ptr<react::CallInvoker> callInvoker);
-  void encode(std::span<float> waveform);
-  std::shared_ptr<OwningArrayBuffer> decode(std::vector prevTokens);
+
+  std::shared_ptr<OwningArrayBuffer> encode(std::span<float> waveform) const;
+  std::shared_ptr<OwningArrayBuffer>
+  decode(std::span<int32_t> tokens, std::span<float> encoderOutput) const;
+  std::string transcribe(std::span<float> waveform,
+                         std::string languageOption) const;
+
+  size_t getMemoryLowerBound() const noexcept;
+
+  // Stream
+  void stream(std::shared_ptr<jsi::Function> callback,
+              std::string languageOption);
+  void streamStop();
+  void streamInsert(std::span<float> waveform);
 
 private:
-  const std::string modelName;
-  executorch::runtime::EValue encoderOutput;
-  std::unique_ptr<SpeechToTextStrategy> strategy;
+  std::unique_ptr<BaseModel> encoder;
+  std::unique_ptr<BaseModel> decoder;
+  std::unique_ptr<TokenizerModule> tokenizer;
+  std::unique_ptr<ASR> asr;
 
-  void initializeStrategy();
+  std::shared_ptr<OwningArrayBuffer>
+  makeOwningBuffer(std::span<float> vectorView) const;
+
+  // Stream
+  std::shared_ptr<react::CallInvoker> callInvoker;
+  std::unique_ptr<OnlineASRProcessor> processor;
+  std::atomic<bool> isStreaming;
+  std::atomic<bool> readyToProcess;
+
+  constexpr static int32_t kMinAudioSamples = 16000; // 1 second
+
+  void resetStreamState();
 };
+
 } // namespace models::speech_to_text
 
 REGISTER_CONSTRUCTOR(models::speech_to_text::SpeechToText, std::string,
                      std::string, std::string,
                      std::shared_ptr<react::CallInvoker>);
+
 } // namespace rnexecutorch
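The languageOption strings in the header above are converted to DecodingOptions (defined in types/DecodingOptions.h later in this diff): the empty string means "no language hint", anything else later turns into a <|xx|> language token plus the <|transcribe|> task token in the decoder prompt. A compile-checkable sketch of that mapping, mirroring the constructor's logic:

    #include <cassert>
    #include <optional>
    #include <string>

    // Mirror of DecodingOptions' constructor semantics (sketch, same logic).
    struct DecodingOptionsSketch {
      explicit DecodingOptionsSketch(const std::string &language)
          : language(language.empty() ? std::nullopt
                                      : std::optional(language)) {}
      std::optional<std::string> language;
    };

    int main() {
      assert(!DecodingOptionsSketch("").language.has_value()); // no hint
      assert(DecodingOptionsSketch("en").language == "en");    // <|en|> later
    }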
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h
deleted file mode 100644
index c68f2e395..000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h
+++ /dev/null
@@ -1,27 +0,0 @@
-#pragma once
-
-#include "executorch/extension/tensor/tensor_ptr.h"
-#include
-#include
-#include
-
-namespace rnexecutorch::models::speech_to_text {
-
-using TensorPtr = ::executorch::extension::TensorPtr;
-
-class SpeechToTextStrategy {
-public:
-  virtual ~SpeechToTextStrategy() = default;
-
-  virtual TensorPtr prepareAudioInput(std::span<float> waveform) = 0;
-
-  virtual TensorPtr
-  prepareTokenInput(const std::vector &prevTokens) = 0;
-
-  virtual std::string getDecoderMethod() const = 0;
-
-  virtual std::shared_ptr<OwningArrayBuffer> extractOutputToken(
-      const executorch::aten::Tensor &decoderOutputTensor) const = 0;
-};
-
-} // namespace rnexecutorch::models::speech_to_text
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp
deleted file mode 100644
index 3b33c4450..000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp
+++ /dev/null
@@ -1,50 +0,0 @@
-#include "executorch/extension/tensor/tensor_ptr.h"
-#include "rnexecutorch/data_processing/dsp.h"
-#include
-
-namespace rnexecutorch::models::speech_to_text {
-
-using namespace ::executorch::extension;
-using namespace ::executorch::aten;
-
-TensorPtr WhisperStrategy::prepareAudioInput(std::span<float> waveform) {
-  constexpr auto fftWindowSize = 512;
-  constexpr auto stftHopLength = 160;
-  constexpr auto innerDim = 256;
-  preprocessedData =
-      dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
-  const auto numFrames = preprocessedData.size() / innerDim;
-  std::vector<int32_t> inputShape = {static_cast<int32_t>(numFrames), innerDim};
-  return make_tensor_ptr(std::move(inputShape), std::move(preprocessedData));
-}
-
-TensorPtr
-WhisperStrategy::prepareTokenInput(const std::vector &prevTokens) {
-  tokens32.clear();
-  tokens32.reserve(prevTokens.size());
-  for (auto token : prevTokens) {
-    tokens32.push_back(static_cast<int32_t>(token));
-  }
-  auto tensorSizes = {1, static_cast<int32_t>(tokens32.size())};
-  return make_tensor_ptr(std::move(tensorSizes), std::move(tokens32));
-}
-
-std::shared_ptr<OwningArrayBuffer> WhisperStrategy::extractOutputToken(
-    const executorch::aten::Tensor &decoderOutputTensor) const {
-  const auto innerDim = decoderOutputTensor.size(1);
-  const auto dictSize = decoderOutputTensor.size(2);
-  auto outputNumel = decoderOutputTensor.numel();
-  auto dataPtr =
-      static_cast<const float *>(decoderOutputTensor.const_data_ptr()) +
-      (innerDim - 1) * dictSize;
-
-  std::span<const float> modelOutput(dataPtr, outputNumel / innerDim);
-  auto createBuffer = [](const auto &data, size_t size) {
-    auto buffer = std::make_shared<OwningArrayBuffer>(size);
-    std::memcpy(buffer->data(), data, size);
-    return buffer;
-  };
-  return createBuffer(modelOutput.data(), modelOutput.size_bytes());
-}
-
-} // namespace rnexecutorch::models::speech_to_text
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h
deleted file mode 100644
index 936be3c00..000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h
+++ /dev/null
@@ -1,25 +0,0 @@
-#pragma once
-
-#include "SpeechToTextStrategy.h"
-#include
-#include
-
-namespace rnexecutorch::models::speech_to_text {
-
-class WhisperStrategy final : public SpeechToTextStrategy {
-public:
-  TensorPtr prepareAudioInput(std::span<float> waveform) override;
-
-  TensorPtr prepareTokenInput(const std::vector &prevTokens) override;
-
-  std::string getDecoderMethod() const override { return "forward"; }
-
-  std::shared_ptr<OwningArrayBuffer> extractOutputToken(
-      const executorch::aten::Tensor &decoderOutputTensor) const override;
-
-private:
-  std::vector<float> preprocessedData;
-  std::vector<int32_t> tokens32;
-};
-
-} // namespace rnexecutorch::models::speech_to_text
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
new file mode 100644
index 000000000..5a56f2d7e
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
@@ -0,0 +1,313 @@
+#include <algorithm>
+#include <cmath>
+#include <iterator>
+#include <numeric>
+#include <random>
+#include <sstream>
+#include <utility>
+
+#include "ASR.h"
+#include "executorch/extension/tensor/tensor_ptr.h"
+#include "rnexecutorch/data_processing/Numerical.h"
+#include "rnexecutorch/data_processing/dsp.h"
+#include "rnexecutorch/data_processing/gzip.h"
+
+namespace rnexecutorch::models::speech_to_text::asr {
+
+ASR::ASR(const models::BaseModel *encoder, const models::BaseModel *decoder,
+         const TokenizerModule *tokenizer)
+    : encoder(encoder), decoder(decoder), tokenizer(tokenizer),
+      startOfTranscriptionToken(
+          this->tokenizer->tokenToId("<|startoftranscript|>")),
+      endOfTranscriptionToken(this->tokenizer->tokenToId("<|endoftext|>")),
+      timestampBeginToken(this->tokenizer->tokenToId("<|0.00|>")) {}
+
+std::vector<int32_t>
+ASR::getInitialSequence(const DecodingOptions &options) const {
+  std::vector<int32_t> seq;
+  seq.push_back(this->startOfTranscriptionToken);
+
+  if (options.language.has_value()) {
+    int32_t langToken =
+        this->tokenizer->tokenToId("<|" + options.language.value() + "|>");
+    int32_t taskToken = this->tokenizer->tokenToId("<|transcribe|>");
+    seq.push_back(langToken);
+    seq.push_back(taskToken);
+  }
+
+  seq.push_back(this->timestampBeginToken);
+
+  return seq;
+}
+
+GenerationResult ASR::generate(std::span<float> waveform, float temperature,
+                               const DecodingOptions &options) const {
+  std::vector<float> encoderOutput = this->encode(waveform);
+
+  std::vector<int32_t> sequenceIds = this->getInitialSequence(options);
+  const size_t initialSequenceLength = sequenceIds.size();
+  std::vector<float> scores;
+
+  while (std::cmp_less_equal(sequenceIds.size(), ASR::kMaxDecodeLength)) {
+    std::vector<float> logits = this->decode(sequenceIds, encoderOutput);
+
+    // intentionally comparing float to float
+    // temperatures are predefined, so this is safe
+    if (temperature == 0.0f) {
+      numerical::softmax(logits);
+    } else {
+      numerical::softmaxWithTemperature(logits, temperature);
+    }
+
+    const std::vector<float> &probs = logits;
+
+    int32_t nextId;
+    float nextProb;
+
+    // intentionally comparing float to float
+    // temperatures are predefined, so this is safe
+    if (temperature == 0.0f) {
+      auto maxIt = std::ranges::max_element(probs);
+      nextId = static_cast<int32_t>(std::distance(probs.begin(), maxIt));
+      nextProb = *maxIt;
+    } else {
+      std::discrete_distribution<> dist(probs.begin(), probs.end());
+      std::mt19937 gen((std::random_device{}()));
+      nextId = dist(gen);
+      nextProb = probs[nextId];
+    }
+
+    sequenceIds.push_back(nextId);
+    scores.push_back(nextProb);
+
+    if (nextId == this->endOfTranscriptionToken) {
+      break;
+    }
+  }
+
+  return {.tokens = std::vector<int32_t>(
+              sequenceIds.cbegin() + initialSequenceLength, sequenceIds.cend()),
+          .scores = scores};
+}
+
+float ASR::getCompressionRatio(const std::string &text) const {
+  size_t compressedSize = gzip::deflateSize(text);
+  return static_cast<float>(text.size()) / static_cast<float>(compressedSize);
+}
+
+std::vector<Segment>
+ASR::generateWithFallback(std::span<float> waveform,
+                          const DecodingOptions &options) const {
+  std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
+  std::vector<int32_t> bestTokens;
+
+  for (auto t : temperatures) {
+    auto [tokens, scores] = this->generate(waveform, t, options);
+
+    const float cumLogProb = std::transform_reduce(
+        scores.begin(), scores.end(), 0.0f, std::plus<>(),
+        [](float s) { return std::log(std::max(s, 1e-9f)); });
+
+    const float avgLogProb = cumLogProb / static_cast<float>(tokens.size() + 1);
+    const std::string text = this->tokenizer->decode(tokens, true);
+    const float compressionRatio = this->getCompressionRatio(text);
+
+    if (avgLogProb >= -1.0f && compressionRatio < 2.4f) {
+      bestTokens = std::move(tokens);
+      break;
+    }
+  }
+
+  return this->calculateWordLevelTimestamps(bestTokens, waveform);
+}
+
+std::vector<Segment>
+ASR::calculateWordLevelTimestamps(std::span<int32_t> generatedTokens,
+                                  const std::span<float> waveform) const {
+  const size_t generatedTokensSize = generatedTokens.size();
+  if (generatedTokensSize < 2 ||
+      generatedTokens[generatedTokensSize - 1] !=
+          this->endOfTranscriptionToken ||
+      generatedTokens[generatedTokensSize - 2] < this->timestampBeginToken) {
+    return {};
+  }
+  std::vector<Segment> segments;
+  std::vector<int32_t> tokens;
+  int32_t prevTimestamp = this->timestampBeginToken;
+
+  for (size_t i = 0; i < generatedTokensSize; i++) {
+    if (generatedTokens[i] < this->timestampBeginToken) {
+      tokens.push_back(generatedTokens[i]);
+    }
+    if (i > 0 && generatedTokens[i - 1] >= this->timestampBeginToken &&
+        generatedTokens[i] >= this->timestampBeginToken) {
+      const int32_t start = prevTimestamp;
+      const int32_t end = generatedTokens[i - 1];
+      auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end);
+      if (words.size()) {
+        segments.emplace_back(std::move(words), 0.0);
+      }
+      tokens.clear();
+      prevTimestamp = generatedTokens[i];
+    }
+  }
+
+  const int32_t start = prevTimestamp;
+  const int32_t end = generatedTokens[generatedTokensSize - 2];
+  auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end);
+
+  if (words.size()) {
+    segments.emplace_back(std::move(words), 0.0);
+  }
+
+  float scalingFactor =
+      static_cast<float>(waveform.size()) /
+      (ASR::kSamplingRate * (end - this->timestampBeginToken) *
+       ASR::kTimePrecision);
+  if (scalingFactor < 1.0f) {
+    for (auto &seg : segments) {
+      for (auto &w : seg.words) {
+        w.start *= scalingFactor;
+        w.end *= scalingFactor;
+      }
+    }
+  }
+
+  return segments;
+}
+
+std::vector<Word>
+ASR::estimateWordLevelTimestampsLinear(std::span<int32_t> tokens,
+                                       int32_t start, int32_t end) const {
+  const std::vector<int32_t> tokensVec(tokens.begin(), tokens.end());
+  const std::string segmentText = this->tokenizer->decode(tokensVec, true);
+  std::istringstream iss(segmentText);
+  std::vector<std::string> wordsStr;
+  std::string word;
+  while (iss >> word) {
+    wordsStr.emplace_back(" ");
+    wordsStr.back().append(word);
+  }
+
+  size_t numChars = 0;
+  for (const auto &w : wordsStr) {
+    numChars += w.size();
+  }
+  const float duration = (end - start) * ASR::kTimePrecision;
+  const float timePerChar = duration / std::max<size_t>(1, numChars);
+  const float startOffset = (start - timestampBeginToken) * ASR::kTimePrecision;
+
+  std::vector<Word> wordObjs;
+  wordObjs.reserve(wordsStr.size());
+  int32_t prevCharCount = 0;
+  for (auto &w : wordsStr) {
+    const auto wSize = static_cast<int32_t>(w.size());
+    const float wStart = startOffset + prevCharCount * timePerChar;
+    const float wEnd = wStart + timePerChar * wSize;
+    prevCharCount += wSize;
+    wordObjs.emplace_back(std::move(w), wStart, wEnd);
+  }
+
+  return wordObjs;
+}
+
+std::vector<Segment> ASR::transcribe(std::span<float> waveform,
+                                     const DecodingOptions &options) const {
+  int32_t seek = 0;
+  std::vector<Segment> results;
+
+  while (std::cmp_less(seek * ASR::kSamplingRate, waveform.size())) {
+    int32_t start = seek * ASR::kSamplingRate;
+    const auto end = std::min<size_t>(
+        (seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
+    std::span<float> chunk = waveform.subspan(start, end - start);
+
+    if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
+      break;
+    }
+
+    std::vector<Segment> segments = this->generateWithFallback(chunk, options);
+
+    if (segments.empty()) {
+      seek += ASR::kChunkSize;
+      continue;
+    }
+
+    for (auto &seg : segments) {
+      for (auto &w : seg.words) {
+        w.start += seek;
+        w.end += seek;
+      }
+    }
+
+    seek = static_cast<int32_t>(segments.back().words.back().end);
+    results.insert(results.end(), std::make_move_iterator(segments.begin()),
+                   std::make_move_iterator(segments.end()));
+  }
+
+  return results;
+}
+
+std::vector<float> ASR::encode(std::span<float> waveform) const {
+  constexpr int32_t fftWindowSize = 512;
+  constexpr int32_t stftHopLength = 160;
+  constexpr int32_t innerDim = 256;
+
+  std::vector<float> preprocessedData =
+      dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
+  const auto numFrames =
+      static_cast<int32_t>(preprocessedData.size()) / innerDim;
+  std::vector<int32_t> inputShape = {numFrames, innerDim};
+
+  const auto modelInputTensor = executorch::extension::make_tensor_ptr(
+      std::move(inputShape), std::move(preprocessedData));
+  const auto encoderResult = this->encoder->forward(modelInputTensor);
+
+  if (!encoderResult.ok()) {
+    throw std::runtime_error(
+        "Forward pass failed during encoding, error code: " +
+        std::to_string(static_cast<int>(encoderResult.error())));
+  }
+
+  const auto encoderOutputTensor = encoderResult.get().at(0).toTensor();
+  const int32_t outputNumel = encoderOutputTensor.numel();
+
+  const float *const dataPtr = encoderOutputTensor.const_data_ptr<float>();
+  return {dataPtr, dataPtr + outputNumel};
+}
+
+std::vector<float> ASR::decode(std::span<int32_t> tokens,
+                               std::span<float> encoderOutput) const {
+  std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
+  auto tokenTensor = executorch::extension::make_tensor_ptr(
+      std::move(tokenShape), tokens.data(), ScalarType::Int);
+
+  const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
+  std::vector<int32_t> encShape = {1, ASR::kNumFrames,
+                                   encoderOutputSize / ASR::kNumFrames};
+  auto encoderTensor = executorch::extension::make_tensor_ptr(
+      std::move(encShape), encoderOutput.data(), ScalarType::Float);
+
+  const auto decoderResult =
+      this->decoder->forward({tokenTensor, encoderTensor});
+
+  if (!decoderResult.ok()) {
+    throw std::runtime_error(
+        "Forward pass failed during decoding, error code: " +
+        std::to_string(static_cast<int>(decoderResult.error())));
+  }
+
+  const auto logitsTensor = decoderResult.get().at(0).toTensor();
+  const int32_t outputNumel = logitsTensor.numel();
+
+  const size_t innerDim = logitsTensor.size(1);
+  const size_t dictSize = logitsTensor.size(2);
+
+  const float *const dataPtr =
+      logitsTensor.const_data_ptr<float>() + (innerDim - 1) * dictSize;
+
+  return {dataPtr, dataPtr + outputNumel / innerDim};
+}
+
+} // namespace rnexecutorch::models::speech_to_text::asr
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h
new file mode 100644
index 000000000..605052363
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include "rnexecutorch/TokenizerModule.h"
+#include "rnexecutorch/models/BaseModel.h"
+#include "rnexecutorch/models/speech_to_text/types/DecodingOptions.h"
+#include "rnexecutorch/models/speech_to_text/types/GenerationResult.h"
+#include "rnexecutorch/models/speech_to_text/types/Segment.h"
+
+namespace rnexecutorch::models::speech_to_text::asr {
+
+using namespace types;
+
+class ASR {
+public:
+  explicit ASR(const models::BaseModel *encoder,
+               const models::BaseModel *decoder,
+               const TokenizerModule *tokenizer);
+  std::vector<Segment> transcribe(std::span<float> waveform,
+                                  const DecodingOptions &options) const;
+  std::vector<float> encode(std::span<float> waveform) const;
+  std::vector<float> decode(std::span<int32_t> tokens,
+                            std::span<float> encoderOutput) const;
+
+private:
+  const models::BaseModel *encoder;
+  const models::BaseModel *decoder;
+  const TokenizerModule *tokenizer;
+
+  int32_t startOfTranscriptionToken;
+  int32_t endOfTranscriptionToken;
+  int32_t timestampBeginToken;
+
+  // Time precision used by Whisper timestamps: each token spans 0.02 seconds
+  constexpr static float kTimePrecision = 0.02f;
+  // The maximum number of tokens the decoder can generate per chunk
+  constexpr static int32_t kMaxDecodeLength = 128;
+  // Maximum duration of each audio chunk to process (in seconds)
+  constexpr static int32_t kChunkSize = 30;
+  // Sampling rate expected by Whisper and the model's audio pipeline (16 kHz)
+  constexpr static int32_t kSamplingRate = 16000;
+  // Minimum allowed chunk length before processing (in audio samples)
+  constexpr static int32_t kMinChunkSamples = 1 * 16000;
+  // Number of mel frames output by the encoder (derived from input
+  // spectrogram)
+  constexpr static int32_t kNumFrames = 1500;
+
+  std::vector<int32_t> getInitialSequence(const DecodingOptions &options) const;
+  GenerationResult generate(std::span<float> waveform, float temperature,
+                            const DecodingOptions &options) const;
+  std::vector<Segment>
+  generateWithFallback(std::span<float> waveform,
+                       const DecodingOptions &options) const;
+  std::vector<Segment>
+  calculateWordLevelTimestamps(std::span<int32_t> tokens,
+                               std::span<float> waveform) const;
+  std::vector<Word>
+  estimateWordLevelTimestampsLinear(std::span<int32_t> tokens, int32_t start,
+                                    int32_t end) const;
+  float getCompressionRatio(const std::string &text) const;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::asr
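ASR::generateWithFallback accepts a decode only when avgLogProb >= -1.0 and the gzip compression ratio stays under 2.4, retrying at higher temperatures otherwise; these are the classic Whisper fallback heuristics, and the podspec's s.libraries = "z" later in this diff exists to link the zlib this check depends on. A sketch of the same measurement using plain zlib's one-shot API instead of the streaming deflate in gzip.cpp:

    #include <string>
    #include <vector>
    #include <zlib.h>

    // One-shot deflate ratio; gzip::deflateSize streams instead, but the
    // measured quantity is the same.
    float compressionRatio(const std::string &text) {
      uLongf destLen = compressBound(text.size());
      std::vector<Bytef> dest(destLen);
      compress2(dest.data(), &destLen,
                reinterpret_cast<const Bytef *>(text.data()), text.size(),
                Z_DEFAULT_COMPRESSION);
      return static_cast<float>(text.size()) / static_cast<float>(destLen);
    }

A looping decode such as "ha ha ha ha ..." compresses extremely well, so its ratio blows past 2.4 and the decode is retried at a higher temperature.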
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp
new file mode 100644
index 000000000..3e4d6a7ca
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp
@@ -0,0 +1,83 @@
+#include <algorithm>
+#include <cmath>
+
+#include "HypothesisBuffer.h"
+
+namespace rnexecutorch::models::speech_to_text::stream {
+
+void HypothesisBuffer::insert(std::span<Word> newWords, float offset) {
+  this->fresh.clear();
+  for (const auto &word : newWords) {
+    const float newStart = word.start + offset;
+    if (newStart > lastCommittedTime - 0.5f) {
+      this->fresh.emplace_back(word.content, newStart, word.end + offset);
+    }
+  }
+
+  if (!this->fresh.empty() && !this->committedInBuffer.empty()) {
+    const float a = this->fresh.front().start;
+    if (std::fabs(a - lastCommittedTime) < 1.0f) {
+      const size_t cn = this->committedInBuffer.size();
+      const size_t nn = this->fresh.size();
+      const std::size_t maxCheck = std::min<std::size_t>({cn, nn, 5});
+      for (size_t i = 1; i <= maxCheck; i++) {
+        std::string c;
+        for (auto it = this->committedInBuffer.cend() - i;
+             it != this->committedInBuffer.cend(); ++it) {
+          if (!c.empty()) {
+            c += ' ';
+          }
+          c += it->content;
+        }
+
+        std::string tail;
+        auto it = this->fresh.cbegin();
+        for (size_t k = 0; k < i; k++, it++) {
+          if (!tail.empty()) {
+            tail += ' ';
+          }
+          tail += it->content;
+        }
+
+        if (c == tail) {
+          this->fresh.erase(this->fresh.begin(), this->fresh.begin() + i);
+          break;
+        }
+      }
+    }
+  }
+}
+
+std::deque<Word> HypothesisBuffer::flush() {
+  std::deque<Word> commit;
+
+  while (!this->fresh.empty() && !this->buffer.empty()) {
+    if (this->fresh.front().content != this->buffer.front().content) {
+      break;
+    }
+    commit.push_back(this->fresh.front());
+    this->buffer.pop_front();
+    this->fresh.pop_front();
+  }
+
+  if (!commit.empty()) {
+    lastCommittedTime = commit.back().end;
+  }
+
+  this->buffer = std::move(this->fresh);
+  this->fresh.clear();
+  this->committedInBuffer.insert(this->committedInBuffer.end(), commit.begin(),
+                                 commit.end());
+  return commit;
+}
+
+void HypothesisBuffer::popCommitted(float time) {
+  while (!this->committedInBuffer.empty() &&
+         this->committedInBuffer.front().end <= time) {
+    this->committedInBuffer.pop_front();
+  }
+}
+
+std::deque<Word> HypothesisBuffer::complete() const { return this->buffer; }
+
+} // namespace rnexecutorch::models::speech_to_text::stream
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h
new file mode 100644
index 000000000..ea4e73328
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <deque>
+#include <span>
+
+#include "rnexecutorch/models/speech_to_text/types/Word.h"
+
+namespace rnexecutorch::models::speech_to_text::stream {
+
+using namespace types;
+
+class HypothesisBuffer {
+public:
+  void insert(std::span<Word> newWords, float offset);
+  std::deque<Word> flush();
+  void popCommitted(float time);
+  std::deque<Word> complete() const;
+
+private:
+  float lastCommittedTime = 0.0f;
+
+  std::deque<Word> committedInBuffer;
+  std::deque<Word> buffer;
+  std::deque<Word> fresh;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::stream
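HypothesisBuffer::flush above commits a word only once two consecutive hypotheses agree on it (a LocalAgreement-style policy), while insert's n-gram comparison drops tails that were already committed. Stripped of timestamps, the flush step reduces to a longest-common-prefix commit; a sketch:

    #include <string>
    #include <vector>

    // Commit the longest common prefix of the previous and current
    // hypotheses; the remainder stays pending until the next hypothesis
    // confirms it.
    std::vector<std::string>
    commitAgreedPrefix(std::vector<std::string> &pending,
                       const std::vector<std::string> &fresh) {
      std::vector<std::string> committed;
      size_t i = 0;
      while (i < pending.size() && i < fresh.size() && pending[i] == fresh[i]) {
        committed.push_back(fresh[i]);
        ++i;
      }
      pending.assign(fresh.begin() + i, fresh.end());
      return committed;
    }

With pending = {"the","cat","sat"} and fresh = {"the","cat","sat","on"}, this commits "the cat sat" and leaves {"on"} pending, which is exactly the committed/nonCommitted split the stream callback reports.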
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp
new file mode 100644
index 000000000..63cffd67c
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp
@@ -0,0 +1,97 @@
+#include <algorithm>
+#include <numeric>
+
+#include "OnlineASRProcessor.h"
+
+namespace rnexecutorch::models::speech_to_text::stream {
+
+OnlineASRProcessor::OnlineASRProcessor(const ASR *asr) : asr(asr) {}
+
+void OnlineASRProcessor::insertAudioChunk(std::span<float> audio) {
+  audioBuffer.insert(audioBuffer.end(), audio.begin(), audio.end());
+}
+
+ProcessResult OnlineASRProcessor::processIter(const DecodingOptions &options) {
+  std::vector<Segment> res = asr->transcribe(audioBuffer, options);
+
+  std::vector<Word> tsw;
+  for (const auto &segment : res) {
+    for (const auto &word : segment.words) {
+      tsw.push_back(word);
+    }
+  }
+
+  this->hypothesisBuffer.insert(tsw, this->bufferTimeOffset);
+  std::deque<Word> flushed = this->hypothesisBuffer.flush();
+  this->committed.insert(this->committed.end(), flushed.begin(),
+                         flushed.end());
+
+  constexpr int32_t chunkThresholdSec = 15;
+  if (static_cast<float>(audioBuffer.size()) /
+          OnlineASRProcessor::kSamplingRate >
+      chunkThresholdSec) {
+    chunkCompletedSegment(res);
+  }
+
+  std::deque<Word> nonCommittedWords = this->hypothesisBuffer.complete();
+  return {this->toFlush(flushed), this->toFlush(nonCommittedWords)};
+}
+
+void OnlineASRProcessor::chunkCompletedSegment(std::span<Segment> res) {
+  if (this->committed.empty())
+    return;
+
+  std::vector<float> ends(res.size());
+  std::ranges::transform(res, ends.begin(), [](const Segment &seg) {
+    return seg.words.back().end;
+  });
+
+  const float t = this->committed.back().end;
+
+  if (ends.size() > 1) {
+    float e = ends[ends.size() - 2] + this->bufferTimeOffset;
+    while (ends.size() > 2 && e > t) {
+      ends.pop_back();
+      e = ends[ends.size() - 2] + this->bufferTimeOffset;
+    }
+    if (e <= t) {
+      chunkAt(e);
+    }
+  }
+}
+
+void OnlineASRProcessor::chunkAt(float time) {
+  this->hypothesisBuffer.popCommitted(time);
+
+  const float cutSeconds = time - this->bufferTimeOffset;
+  auto startIndex =
+      static_cast<size_t>(cutSeconds * OnlineASRProcessor::kSamplingRate);
+
+  if (startIndex < audioBuffer.size()) {
+    audioBuffer.erase(audioBuffer.begin(), audioBuffer.begin() + startIndex);
+  } else {
+    audioBuffer.clear();
+  }
+
+  this->bufferTimeOffset = time;
+}
+
+std::string OnlineASRProcessor::finish() {
+  const std::deque<Word> buffer = this->hypothesisBuffer.complete();
+  std::string committedText = this->toFlush(buffer);
+  this->bufferTimeOffset += static_cast<float>(audioBuffer.size()) /
+                            OnlineASRProcessor::kSamplingRate;
+  return committedText;
+}
+
+std::string OnlineASRProcessor::toFlush(const std::deque<Word> &words) const {
+  std::string text;
+  text.reserve(std::accumulate(
+      words.cbegin(), words.cend(), size_t{0},
+      [](size_t sum, const Word &w) { return sum + w.content.size(); }));
+  for (const auto &word : words) {
+    text.append(word.content);
+  }
+  return text;
+}
+
+} // namespace rnexecutorch::models::speech_to_text::stream
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h
new file mode 100644
index 000000000..403cf87d1
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include "rnexecutorch/models/speech_to_text/asr/ASR.h"
+#include "rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h"
+#include "rnexecutorch/models/speech_to_text/types/ProcessResult.h"
+
+namespace rnexecutorch::models::speech_to_text::stream {
+
+using namespace asr;
+using namespace types;
+
+class OnlineASRProcessor {
+public:
+  explicit OnlineASRProcessor(const ASR *asr);
+
+  void insertAudioChunk(std::span<float> audio);
+  ProcessResult processIter(const DecodingOptions &options);
+  std::string finish();
+
+  std::vector<float> audioBuffer;
+
+private:
+  const ASR *asr;
+  constexpr static int32_t kSamplingRate = 16000;
+
+  HypothesisBuffer hypothesisBuffer;
+  float bufferTimeOffset = 0.0f;
+  std::vector<Word> committed;
+
+  void chunkCompletedSegment(std::span<Segment> res);
+  void chunkAt(float time);
+
+  std::string toFlush(const std::deque<Word> &words) const;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::stream
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h
new file mode 100644
index 000000000..c351ddc55
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <optional>
+#include <string>
+
+namespace rnexecutorch::models::speech_to_text::types {
+
+struct DecodingOptions {
+  explicit DecodingOptions(const std::string &language)
+      : language(language.empty() ? std::nullopt : std::optional(language)) {}
+
+  std::optional<std::string> language;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::types
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h
new file mode 100644
index 000000000..efd520442
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+namespace rnexecutorch::models::speech_to_text::types {
+
+struct GenerationResult {
+  std::vector<int32_t> tokens;
+  std::vector<float> scores;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::types
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h
new file mode 100644
index 000000000..0cb05e5a6
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+
+namespace rnexecutorch::models::speech_to_text::types {
+
+struct ProcessResult {
+  std::string committed;
+  std::string nonCommitted;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::types
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Segment.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Segment.h
new file mode 100644
index 000000000..5b8368fe4
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Segment.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <vector>
+
+#include "Word.h"
+
+namespace rnexecutorch::models::speech_to_text::types {
+
+struct Segment {
+  std::vector<Word> words;
+  float noSpeechProbability;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::types
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Word.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Word.h
new file mode 100644
index 000000000..98c72f273
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Word.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <string>
+
+namespace rnexecutorch::models::speech_to_text::types {
+
+struct Word {
+  std::string content;
+  float start;
+  float end;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::types
diff --git a/packages/react-native-executorch/react-native-executorch.podspec b/packages/react-native-executorch/react-native-executorch.podspec
index e58cf31c2..b94cc771a 100644
--- a/packages/react-native-executorch/react-native-executorch.podspec
+++ b/packages/react-native-executorch/react-native-executorch.podspec
@@ -75,6 +75,8 @@ Pod::Spec.new do |s|
     "common/**/*.{cpp,c,h,hpp}",
   ]
 
+  s.libraries = "z"
+
  # Exclude file with tests to not introduce gtest dependency.
  # Do not include the headers from common/rnexecutorch/jsi/ as source files.
  # Xcode/Cocoapods leaks them to other pods that an app also depends on, so if
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts
index 878a67381..74167b8b7 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts
@@ -1,84 +1,154 @@
+import { Logger } from '../../common/Logger';
 import { DecodingOptions, SpeechToTextModelConfig } from '../../types/stt';
-import { ASR } from '../../utils/SpeechToTextModule/ASR';
-import { OnlineASRProcessor } from '../../utils/SpeechToTextModule/OnlineProcessor';
+import { ResourceFetcher } from '../../utils/ResourceFetcher';
 
 export class SpeechToTextModule {
-  private modelConfig!: SpeechToTextModelConfig;
-  private asr: ASR = new ASR();
+  private nativeModule: any;
 
-  private processor: OnlineASRProcessor = new OnlineASRProcessor(this.asr);
-  private isStreaming = false;
-  private readyToProcess = false;
-  private minAudioSamples: number = 1 * 16000; // 1 second
+  private modelConfig!: SpeechToTextModelConfig;
 
   public async load(
     model: SpeechToTextModelConfig,
     onDownloadProgressCallback: (progress: number) => void = () => {}
   ) {
     this.modelConfig = model;
-    return this.asr.load(model, onDownloadProgressCallback);
+
+    const tokenizerLoadPromise = ResourceFetcher.fetch(
+      undefined,
+      model.tokenizerSource
+    );
+    const encoderDecoderPromise = ResourceFetcher.fetch(
+      onDownloadProgressCallback,
+      model.encoderSource,
+      model.decoderSource
+    );
+    const [tokenizerSources, encoderDecoderResults] = await Promise.all([
+      tokenizerLoadPromise,
+      encoderDecoderPromise,
+    ]);
+    const encoderSource = encoderDecoderResults?.[0];
+    const decoderSource = encoderDecoderResults?.[1];
+    if (!encoderSource || !decoderSource || !tokenizerSources) {
+      throw new Error('Download interrupted.');
+    }
+    this.nativeModule = await global.loadSpeechToText(
+      encoderSource,
+      decoderSource,
+      tokenizerSources[0]!
+    );
   }
 
-  public async encode(waveform: Float32Array): Promise<Float32Array> {
-    return this.asr.encode(waveform);
+  public async encode(
+    waveform: Float32Array | number[]
+  ): Promise<Float32Array> {
+    if (Array.isArray(waveform)) {
+      Logger.info(
+        'Passing waveform as number[] is deprecated, use Float32Array instead'
+      );
+      waveform = new Float32Array(waveform);
+    }
+    return new Float32Array(await this.nativeModule.encode(waveform));
   }
 
-  public async decode(tokens: number[]): Promise<Float32Array> {
-    return this.asr.decode(tokens);
+  public async decode(
+    tokens: Int32Array | number[],
+    encoderOutput: Float32Array | number[]
+  ): Promise<Float32Array> {
+    if (Array.isArray(tokens)) {
+      Logger.info(
+        'Passing tokens as number[] is deprecated, use Int32Array instead'
+      );
+      tokens = new Int32Array(tokens);
+    }
+    if (Array.isArray(encoderOutput)) {
+      Logger.info(
+        'Passing encoderOutput as number[] is deprecated, use Float32Array instead'
+      );
+      encoderOutput = new Float32Array(encoderOutput);
+    }
+    return new Float32Array(
+      await this.nativeModule.decode(tokens, encoderOutput)
+    );
   }
 
   public async transcribe(
-    waveform: number[],
+    waveform: Float32Array | number[],
     options: DecodingOptions = {}
   ): Promise<string> {
     this.validateOptions(options);
-    const segments = await this.asr.transcribe(waveform, options);
-
-    let transcription = '';
-    for (const segment of segments) {
-      for (const word of segment.words) {
-        transcription += ` ${word.word}`;
-      }
+    if (Array.isArray(waveform)) {
+      Logger.info(
+        'Passing waveform as number[] is deprecated, use Float32Array instead'
+      );
+      waveform = new Float32Array(waveform);
     }
-    return transcription.trim();
+    return this.nativeModule.transcribe(waveform, options.language || '');
   }
 
-  public async *stream(options: DecodingOptions = {}) {
-    if (this.isStreaming) {
-      throw new Error('Streaming is already in progress');
-    }
+  public async *stream(
+    options: DecodingOptions = {}
+  ): AsyncGenerator<{ committed: string; nonCommitted: string }> {
     this.validateOptions(options);
-    this.resetStreamState();
-
-    this.isStreaming = true;
-    while (this.isStreaming) {
-      if (
-        !this.readyToProcess ||
-        this.processor.audioBuffer.length < this.minAudioSamples
-      ) {
-        await new Promise((resolve) => setTimeout(resolve, 100));
+
+    const queue: { committed: string; nonCommitted: string }[] = [];
+    let waiter: (() => void) | null = null;
+    let finished = false;
+    let error: unknown;
+
+    const wake = () => {
+      waiter?.();
+      waiter = null;
+    };
+
+    (async () => {
+      try {
+        await this.nativeModule.stream(
+          (committed: string, nonCommitted: string, isDone: boolean) => {
+            queue.push({ committed, nonCommitted });
+            if (isDone) {
+              finished = true;
+            }
+            wake();
+          },
+          options.language || ''
+        );
+        finished = true;
+        wake();
+      } catch (e) {
+        error = e;
+        finished = true;
+        wake();
+      }
+    })();
+
+    while (true) {
+      if (queue.length > 0) {
+        yield queue.shift()!;
+        if (finished && queue.length === 0) {
+          return;
+        }
         continue;
       }
-
-      const { committed, nonCommitted } =
-        await this.processor.processIter(options);
-      yield { committed, nonCommitted };
-      this.readyToProcess = false;
+      if (error) throw error;
+      if (finished) return;
+      await new Promise<void>((r) => (waiter = r));
     }
-
-    const { committed } = await this.processor.finish();
-    yield { committed, nonCommitted: '' };
   }
 
-  public streamStop() {
-    this.isStreaming = false;
+  public async streamInsert(waveform: Float32Array | number[]): Promise<void> {
+    if (Array.isArray(waveform)) {
+      Logger.info(
+        'Passing waveform as number[] is deprecated, use Float32Array instead'
+      );
+      waveform = new Float32Array(waveform);
+    }
+    return this.nativeModule.streamInsert(waveform);
   }
 
-  public streamInsert(waveform: number[]) {
-    this.processor.insertAudioChunk(waveform);
-    this.readyToProcess = true;
+  public async streamStop(): Promise<void> {
+    return this.nativeModule.streamStop();
   }
 
   private validateOptions(options: DecodingOptions) {
@@ -89,10 +159,4 @@ export class SpeechToTextModule {
       throw new Error('Model is multilingual, provide a language');
     }
   }
-
-  private resetStreamState() {
-    this.isStreaming = false;
-    this.readyToProcess = false;
-    this.processor = new OnlineASRProcessor(this.asr);
-  }
 }
diff --git a/packages/react-native-executorch/src/types/stt.ts b/packages/react-native-executorch/src/types/stt.ts
index fce1179a7..20627ca11 100644
--- a/packages/react-native-executorch/src/types/stt.ts
+++ b/packages/react-native-executorch/src/types/stt.ts
@@ -1,17 +1,5 @@
 import { ResourceSource } from './common';
 
-export type WordTuple = [number, number, string];
-
-export interface WordObject {
-  start: number;
-  end: number;
-  word: string;
-}
-
-export interface Segment {
-  words: WordObject[];
-}
-
 // Languages supported by whisper (not whisper.en)
 export type SpeechToTextLanguage =
   | 'af'
diff --git a/packages/react-native-executorch/src/utils/SpeechToTextModule/ASR.ts b/packages/react-native-executorch/src/utils/SpeechToTextModule/ASR.ts deleted file mode 100644 index a599e1f2e..000000000 --- a/packages/react-native-executorch/src/utils/SpeechToTextModule/ASR.ts +++ /dev/null @@ -1,303 +0,0 @@ -// NOTE: This will be implemented in C++ - -import { TokenizerModule } from '../../modules/natural_language_processing/TokenizerModule'; -import { - DecodingOptions, - Segment, - SpeechToTextModelConfig, - WordObject, - WordTuple, -} from '../../types/stt'; -import { ResourceFetcher } from '../ResourceFetcher'; - -export class ASR { - private nativeModule: any; - private tokenizerModule: TokenizerModule = new TokenizerModule(); - - private timePrecision: number = 0.02; // Whisper timestamp precision - private maxDecodeLength: number = 128; - private chunkSize: number = 30; // 30 seconds - private minChunkSamples: number = 1 * 16000; // 1 second - private samplingRate: number = 16000; - - private startOfTranscriptToken!: number; - private endOfTextToken!: number; - private timestampBeginToken!: number; - - public async load( - model: SpeechToTextModelConfig, - onDownloadProgressCallback: (progress: number) => void - ) { - const tokenizerLoadPromise = this.tokenizerModule.load(model); - const encoderDecoderPromise = ResourceFetcher.fetch( - onDownloadProgressCallback, - model.encoderSource, - model.decoderSource - ); - const [_, encoderDecoderResults] = await Promise.all([ - tokenizerLoadPromise, - encoderDecoderPromise, - ]); - const encoderSource = encoderDecoderResults?.[0]; - const decoderSource = encoderDecoderResults?.[1]; - if (!encoderSource || !decoderSource) { - throw new Error('Download interrupted.'); - } - this.nativeModule = await global.loadSpeechToText( - encoderSource, - decoderSource, - 'whisper' - ); - - this.startOfTranscriptToken = await this.tokenizerModule.tokenToId( - '<|startoftranscript|>' - ); - this.endOfTextToken = await this.tokenizerModule.tokenToId('<|endoftext|>'); - this.timestampBeginToken = await this.tokenizerModule.tokenToId('<|0.00|>'); - } - - private async getInitialSequence( - options: DecodingOptions - ): Promise { - const initialSequence: number[] = [this.startOfTranscriptToken]; - if (options.language) { - const languageToken = await
this.tokenizerModule.tokenToId( - `<|${options.language}|>` - ); - const taskToken = await this.tokenizerModule.tokenToId('<|transcribe|>'); - initialSequence.push(languageToken); - initialSequence.push(taskToken); - } - initialSequence.push(this.timestampBeginToken); - return initialSequence; - } - - private async generate( - audio: number[], - temperature: number, - options: DecodingOptions - ): Promise<{ - sequencesIds: number[]; - scores: number[]; - }> { - await this.encode(new Float32Array(audio)); - const initialSequence = await this.getInitialSequence(options); - const sequencesIds = [...initialSequence]; - const scores: number[] = []; - - while (sequencesIds.length <= this.maxDecodeLength) { - const logits = this.softmaxWithTemperature( - Array.from(await this.decode(sequencesIds)), - temperature === 0 ? 1 : temperature - ); - const nextTokenId = - temperature === 0 - ? logits.indexOf(Math.max(...logits)) - : this.sampleFromDistribution(logits); - const nextTokenProb = logits[nextTokenId]!; - sequencesIds.push(nextTokenId); - scores.push(nextTokenProb); - if (nextTokenId === this.endOfTextToken) { - break; - } - } - - return { - sequencesIds: sequencesIds.slice(initialSequence.length), - scores: scores.slice(initialSequence.length), - }; - } - - private softmaxWithTemperature(logits: number[], temperature = 1.0) { - const max = Math.max(...logits); - const exps = logits.map((logit) => Math.exp((logit - max) / temperature)); - const sum = exps.reduce((a, b) => a + b, 0); - return exps.map((exp) => exp / sum); - } - - private sampleFromDistribution(probs: number[]): number { - const r = Math.random(); - let cumulative = 0; - for (let i = 0; i < probs.length; i++) { - cumulative += probs[i]!; - if (r < cumulative) { - return i; - } - } - return probs.length - 1; - } - - private async generateWithFallback( - audio: number[], - options: DecodingOptions - ) { - const temperatures = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; - let generatedTokens: number[] = []; - - for (const temperature of temperatures) { - const result = await this.generate(audio, temperature, options); - const tokens = result.sequencesIds; - const scores = result.scores; - - const seqLen = tokens.length; - const cumLogProb = scores.reduce( - (acc, score) => acc + Math.log(score), - 0 - ); - const avgLogProb = cumLogProb / seqLen; - - if (avgLogProb >= -1.0) { - generatedTokens = tokens; - break; - } - } - - return this.calculateWordLevelTimestamps(generatedTokens, audio); - } - - private async calculateWordLevelTimestamps( - generatedTokens: number[], - audio: number[] - ): Promise { - const segments: Segment[] = []; - - let tokens: number[] = []; - let prevTimestamp = this.timestampBeginToken; - for (let i = 0; i < generatedTokens.length; i++) { - if (generatedTokens[i]! < this.timestampBeginToken) { - tokens.push(generatedTokens[i]!); - } - - if ( - i > 0 && - generatedTokens[i - 1]! >= this.timestampBeginToken && - generatedTokens[i]! 
>= this.timestampBeginToken - ) { - const start = prevTimestamp; - const end = generatedTokens[i - 1]!; - const wordObjects = await this.estimateWordTimestampsLinear( - tokens, - start, - end - ); - segments.push({ - words: wordObjects, - }); - tokens = []; - prevTimestamp = generatedTokens[i]!; - } - } - - const start = prevTimestamp; - const end = generatedTokens.at(-2)!; - const wordObjects = await this.estimateWordTimestampsLinear( - tokens, - start, - end - ); - segments.push({ - words: wordObjects, - }); - - const scalingFactor = - audio.length / - this.samplingRate / - ((end - this.timestampBeginToken) * this.timePrecision); - if (scalingFactor < 1) { - for (const segment of segments) { - for (const word of segment.words) { - word.start *= scalingFactor; - word.end *= scalingFactor; - } - } - } - - return segments; - } - - private async estimateWordTimestampsLinear( - tokens: number[], - timestampStart: number, - timestampEnd: number - ): Promise { - const duration = (timestampEnd - timestampStart) * this.timePrecision; - const segmentText = ( - (await this.tokenizerModule.decode(tokens)) as string - ).trim(); - - const words = segmentText.split(' ').map((w) => ` ${w}`); - const numOfCharacters = words.reduce( - (acc: number, word: string) => acc + word.length, - 0 - ); - - const timePerCharacter = duration / numOfCharacters; - - const wordObjects: WordObject[] = []; - const startTimeOffset = - (timestampStart - this.timestampBeginToken) * this.timePrecision; - - let prevCharNum = 0; - for (let j = 0; j < words.length; j++) { - const word = words[j]!; - const start = startTimeOffset + prevCharNum * timePerCharacter; - const end = start + timePerCharacter * word.length; - wordObjects.push({ word, start, end }); - prevCharNum += word.length; - } - - return wordObjects; - } - - public async transcribe( - audio: number[], - options: DecodingOptions - ): Promise { - let seek = 0; - const allSegments: Segment[] = []; - - while (seek * this.samplingRate < audio.length) { - const chunk = audio.slice( - seek * this.samplingRate, - (seek + this.chunkSize) * this.samplingRate - ); - if (chunk.length < this.minChunkSamples) { - return allSegments; - } - const segments = await this.generateWithFallback(chunk, options); - for (const segment of segments) { - for (const word of segment.words) { - word.start += seek; - word.end += seek; - } - } - allSegments.push(...segments); - const lastTimeStamp = segments.at(-1)!.words.at(-1)!.end; - seek = lastTimeStamp; - } - - return allSegments; - } - - public tsWords(segments: Segment[]): WordTuple[] { - const o: WordTuple[] = []; - for (const segment of segments) { - for (const word of segment.words) { - o.push([word.start, word.end, word.word]); - } - } - return o; - } - - public segmentsEndTs(res: Segment[]) { - return res.map((segment) => segment.words.at(-1)!.end); - } - - public async encode(waveform: Float32Array): Promise { - await this.nativeModule.encode(waveform); - } - - public async decode(tokens: number[]): Promise { - return new Float32Array(await this.nativeModule.decode(tokens)); - } -} diff --git a/packages/react-native-executorch/src/utils/SpeechToTextModule/OnlineProcessor.ts b/packages/react-native-executorch/src/utils/SpeechToTextModule/OnlineProcessor.ts deleted file mode 100644 index 2185060b9..000000000 --- a/packages/react-native-executorch/src/utils/SpeechToTextModule/OnlineProcessor.ts +++ /dev/null @@ -1,87 +0,0 @@ -// NOTE: This will be implemented in C++ - -import { WordTuple, DecodingOptions, Segment } from '../../types/stt'; 
-import { ASR } from './ASR'; -import { HypothesisBuffer } from './hypothesisBuffer'; - -export class OnlineASRProcessor { - private asr: ASR; - - private samplingRate: number = 16000; - public audioBuffer: number[] = []; - private transcriptBuffer: HypothesisBuffer = new HypothesisBuffer(); - private bufferTimeOffset: number = 0; - private committed: WordTuple[] = []; - - constructor(asr: ASR) { - this.asr = asr; - } - - public insertAudioChunk(audio: number[]) { - this.audioBuffer.push(...audio); - } - - public async processIter(options: DecodingOptions) { - const res = await this.asr.transcribe(this.audioBuffer, options); - const tsw = this.asr.tsWords(res); - this.transcriptBuffer.insert(tsw, this.bufferTimeOffset); - const o = this.transcriptBuffer.flush(); - this.committed.push(...o); - - const s = 15; - if (this.audioBuffer.length / this.samplingRate > s) { - this.chunkCompletedSegment(res); - } - - const committed = this.toFlush(o)[2]; - const nonCommitted = this.transcriptBuffer - .complete() - .map((x) => x[2]) - .join(''); - return { committed, nonCommitted }; - } - - private chunkCompletedSegment(res: Segment[]) { - if (this.committed.length === 0) { - return; - } - - const ends = this.asr.segmentsEndTs(res); - const t = this.committed.at(-1)![1]; - - if (ends.length > 1) { - let e = ends.at(-2)! + this.bufferTimeOffset; - while (ends.length > 2 && e > t) { - ends.pop(); - e = ends.at(-2)! + this.bufferTimeOffset; - } - - if (e <= t) { - this.chunkAt(e); - } - } - } - - private chunkAt(time: number) { - this.transcriptBuffer.popCommitted(time); - const cutSeconds = time - this.bufferTimeOffset; - this.audioBuffer = this.audioBuffer.slice( - Math.floor(cutSeconds * this.samplingRate) - ); - this.bufferTimeOffset = time; - } - - public async finish() { - const o = this.transcriptBuffer.complete(); - const f = this.toFlush(o); - this.bufferTimeOffset += this.audioBuffer.length / this.samplingRate; - return { committed: f[2] }; - } - - private toFlush(words: WordTuple[]): [number | null, number | null, string] { - const t = words.map((s) => s[2]).join(' '); - const b = words.length === 0 ? null : words[0]![0]; - const e = words.length === 0 ? 
null : words.at(-1)![1]; - return [b, e, t]; - } -} diff --git a/packages/react-native-executorch/src/utils/SpeechToTextModule/hypothesisBuffer.ts b/packages/react-native-executorch/src/utils/SpeechToTextModule/hypothesisBuffer.ts deleted file mode 100644 index 78a98a5ae..000000000 --- a/packages/react-native-executorch/src/utils/SpeechToTextModule/hypothesisBuffer.ts +++ /dev/null @@ -1,79 +0,0 @@ -// NOTE: This will be implemented in C++ - -import { WordTuple } from '../../types/stt'; - -export class HypothesisBuffer { - private committedInBuffer: WordTuple[] = []; - private buffer: WordTuple[] = []; - private new: WordTuple[] = []; - - private lastCommittedTime: number = 0; - public lastCommittedWord: string | null = null; - - public insert(newWords: WordTuple[], offset: number) { - const newWordsOffset: WordTuple[] = newWords.map(([a, b, t]) => [ - a + offset, - b + offset, - t, - ]); - this.new = newWordsOffset.filter( - ([a, _b, _t]) => a > this.lastCommittedTime - 0.5 - ); - - if (this.new.length > 0) { - const [a, _b, _t] = this.new[0]!; - if ( - Math.abs(a - this.lastCommittedTime) < 1 && - this.committedInBuffer.length > 0 - ) { - const cn = this.committedInBuffer.length; - const nn = this.new.length; - - for (let i = 1; i <= Math.min(cn, nn, 5); i++) { - const c = this.committedInBuffer - .slice(-i) - .map((w) => w[2]) - .join(' '); - const tail = this.new - .slice(0, i) - .map((w) => w[2]) - .join(' '); - if (c === tail) { - for (let j = 0; j < i; j++) { - this.new.shift(); - } - break; - } - } - } - } - } - - public flush(): WordTuple[] { - const commit: WordTuple[] = []; - while (this.new.length > 0 && this.buffer.length > 0) { - if (this.new[0]![2] !== this.buffer[0]![2]) { - break; - } - commit.push(this.new[0]!); - this.lastCommittedWord = this.new[0]![2]; - this.lastCommittedTime = this.new[0]![1]; - this.buffer.shift(); - this.new.shift(); - } - this.buffer = this.new; - this.new = []; - this.committedInBuffer.push(...commit); - return commit; - } - - public popCommitted(time: number) { - this.committedInBuffer = this.committedInBuffer.filter( - ([_a, b, _t]) => b > time - ); - } - - public complete(): WordTuple[] { - return this.buffer; - } -} diff --git a/packages/react-native-executorch/src/utils/stt.ts b/packages/react-native-executorch/src/utils/stt.ts deleted file mode 100644 index a4a912e35..000000000 --- a/packages/react-native-executorch/src/utils/stt.ts +++ /dev/null @@ -1,28 +0,0 @@ -export const longCommonInfPref = ( - seq1: number[], - seq2: number[], - hammingDistThreshold: number -) => { - let maxInd = 0; - let maxLength = 0; - - for (let i = 0; i < seq1.length; i++) { - let j = 0; - let hammingDist = 0; - while ( - j < seq2.length && - i + j < seq1.length && - (seq1[i + j] === seq2[j] || hammingDist < hammingDistThreshold) - ) { - if (seq1[i + j] !== seq2[j]) { - hammingDist++; - } - j++; - } - if (j >= maxLength) { - maxLength = j; - maxInd = i; - } - } - return maxInd; -}; diff --git a/yarn.lock b/yarn.lock index 504faebd8..0e98536c6 100644 --- a/yarn.lock +++ b/yarn.lock @@ -9872,11 +9872,10 @@ __metadata: metro-config: "npm:^0.81.0" react: "npm:19.0.0" react-native: "npm:0.79.2" - react-native-audio-api: "npm:0.5.7" + react-native-audio-api: "npm:^0.8.2" react-native-device-info: "npm:^14.0.4" react-native-executorch: "workspace:*" react-native-gesture-handler: "npm:~2.24.0" - react-native-live-audio-stream: "npm:^1.1.1" react-native-loading-spinner-overlay: "npm:^3.0.1" react-native-markdown-display: "npm:^7.0.2" react-native-reanimated: 
"npm:~3.17.4" @@ -12167,27 +12166,27 @@ __metadata: languageName: node linkType: hard -"react-native-audio-api@npm:0.5.7": - version: 0.5.7 - resolution: "react-native-audio-api@npm:0.5.7" +"react-native-audio-api@npm:0.6.5": + version: 0.6.5 + resolution: "react-native-audio-api@npm:0.6.5" peerDependencies: react: "*" react-native: "*" bin: setup-rn-audio-api-web: scripts/setup-rn-audio-api-web.js - checksum: 10/5bedfa026bf912932bfaddbcfbfa1e2feba629592f447facadc0e327e72376ae78cea4fbee2645c3d93981961164343fa3896215ed42c930cc5396433da75ce8 + checksum: 10/9bf5b124ff902f359a237bcd3c386a37b354cc6263ce66765c1788c7a8d42c307a133780c8b57ab2f0db530bfed8ac1d3ff8fb55055228854ccebc8da9a595d8 languageName: node linkType: hard -"react-native-audio-api@npm:0.6.5": - version: 0.6.5 - resolution: "react-native-audio-api@npm:0.6.5" +"react-native-audio-api@npm:^0.8.2": + version: 0.8.2 + resolution: "react-native-audio-api@npm:0.8.2" peerDependencies: react: "*" react-native: "*" bin: setup-rn-audio-api-web: scripts/setup-rn-audio-api-web.js - checksum: 10/9bf5b124ff902f359a237bcd3c386a37b354cc6263ce66765c1788c7a8d42c307a133780c8b57ab2f0db530bfed8ac1d3ff8fb55055228854ccebc8da9a595d8 + checksum: 10/064b87f8949786a4cf8c824ba7cda91e31803aec187f9f3fbb02dfe848d7d49c441a97f1697667168cc58bfd07ab7cb3f89e55520fbf897af28bdb7bb42c10cb languageName: node linkType: hard @@ -12356,13 +12355,6 @@ __metadata: languageName: node linkType: hard -"react-native-live-audio-stream@npm:^1.1.1": - version: 1.1.1 - resolution: "react-native-live-audio-stream@npm:1.1.1" - checksum: 10/50fa9d0af62d7199d143d83e779ce04b0a7418729e48ee9135980814adb9c20baa002773d68e203b0745a16ffad9acd662b72f625d42f58c160e7efb8657720a - languageName: node - linkType: hard - "react-native-loading-spinner-overlay@npm:^3.0.1": version: 3.0.1 resolution: "react-native-loading-spinner-overlay@npm:3.0.1"