diff --git a/apps/llm/android/app/src/main/AndroidManifest.xml b/apps/llm/android/app/src/main/AndroidManifest.xml
index c1d260939..87227eb01 100644
--- a/apps/llm/android/app/src/main/AndroidManifest.xml
+++ b/apps/llm/android/app/src/main/AndroidManifest.xml
@@ -7,6 +7,7 @@
+
diff --git a/apps/llm/app.json b/apps/llm/app.json
index c3031d0ed..5eef49366 100644
--- a/apps/llm/app.json
+++ b/apps/llm/app.json
@@ -23,7 +23,22 @@
"calendarPermission": "The app needs to access your calendar."
}
],
- "expo-router"
+ "expo-router",
+ [
+ "react-native-audio-api",
+ {
+ "iosBackgroundMode": true,
+ "iosMicrophonePermission": "This app requires access to the microphone to record audio.",
+ "androidPermissions": [
+ "android.permission.MODIFY_AUDIO_SETTINGS",
+ "android.permission.FOREGROUND_SERVICE",
+ "android.permission.FOREGROUND_SERVICE_MEDIA_PLAYBACK",
+ "android.permission.RECORD_AUDIO"
+ ],
+ "androidForegroundService": true,
+ "androidFSTypes": ["mediaPlayback"]
+ }
+ ]
],
"newArchEnabled": true,
"splash": {
diff --git a/apps/llm/app/voice_chat/index.tsx b/apps/llm/app/voice_chat/index.tsx
index 53c3f8508..c709a6c4a 100644
--- a/apps/llm/app/voice_chat/index.tsx
+++ b/apps/llm/app/voice_chat/index.tsx
@@ -22,38 +22,10 @@ import MicIcon from '../../assets/icons/mic_icon.svg';
import StopIcon from '../../assets/icons/stop_icon.svg';
import ColorPalette from '../../colors';
import Messages from '../../components/Messages';
-import LiveAudioStream from 'react-native-live-audio-stream';
+import { AudioManager, AudioRecorder } from 'react-native-audio-api';
import DeviceInfo from 'react-native-device-info';
-import { Buffer } from 'buffer';
import { useIsFocused } from '@react-navigation/native';
import { GeneratingContext } from '../../context';
-const audioStreamOptions = {
- sampleRate: 16000,
- channels: 1,
- bitsPerSample: 16,
- audioSource: 1,
- bufferSize: 16000,
-};
-
-const startStreamingAudio = (options: any, onChunk: (data: string) => void) => {
- LiveAudioStream.init(options);
- LiveAudioStream.on('data', onChunk);
- LiveAudioStream.start();
-};
-
-const float32ArrayFromPCMBinaryBuffer = (b64EncodedBuffer: string) => {
- const b64DecodedChunk = Buffer.from(b64EncodedBuffer, 'base64');
- const int16Array = new Int16Array(b64DecodedChunk.buffer);
-
- const float32Array = new Float32Array(int16Array.length);
- for (let i = 0; i < int16Array.length; i++) {
- float32Array[i] = Math.max(
- -1,
- Math.min(1, (int16Array[i] / audioStreamOptions.bufferSize) * 8)
- );
- }
- return float32Array;
-};
export default function VoiceChatScreenWrapper() {
const isFocused = useIsFocused();
@@ -63,6 +35,13 @@ export default function VoiceChatScreenWrapper() {
function VoiceChatScreen() {
const [isRecording, setIsRecording] = useState(false);
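+  // 16 kHz mono recorder; 1600-sample buffers deliver a chunk roughly every 100 ms.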
+ const [recorder] = useState(
+ () =>
+ new AudioRecorder({
+ sampleRate: 16000,
+ bufferLengthInSamples: 1600,
+ })
+ );
const messageRecorded = useRef(false);
const { setGlobalGenerating } = useContext(GeneratingContext);
@@ -75,20 +54,27 @@ function VoiceChatScreen() {
setGlobalGenerating(llm.isGenerating || speechToText.isGenerating);
}, [llm.isGenerating, speechToText.isGenerating, setGlobalGenerating]);
- const onChunk = (data: string) => {
- const float32Chunk = float32ArrayFromPCMBinaryBuffer(data);
- speechToText.streamInsert(Array.from(float32Chunk));
- };
+ useEffect(() => {
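+    // Configure the audio session for simultaneous playback and recording, then request microphone access.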
+ AudioManager.setAudioSessionOptions({
+ iosCategory: 'playAndRecord',
+ iosMode: 'spokenAudio',
+ iosOptions: ['allowBluetooth', 'defaultToSpeaker'],
+ });
+ AudioManager.requestRecordingPermissions();
+ }, []);
const handleRecordPress = async () => {
if (isRecording) {
setIsRecording(false);
- LiveAudioStream.stop();
+ recorder.stop();
messageRecorded.current = true;
- speechToText.streamStop();
+ await speechToText.streamStop();
} else {
setIsRecording(true);
- startStreamingAudio(audioStreamOptions, onChunk);
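+      // Feed each recorded audio chunk straight into the native speech-to-text stream.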
+ recorder.onAudioReady(async ({ buffer }) => {
+ await speechToText.streamInsert(buffer.getChannelData(0));
+ });
+ recorder.start();
const transcription = await speechToText.stream();
await llm.sendMessage(transcription);
}
diff --git a/apps/llm/ios/Podfile.lock b/apps/llm/ios/Podfile.lock
index 48632bb03..46d2399e8 100644
--- a/apps/llm/ios/Podfile.lock
+++ b/apps/llm/ios/Podfile.lock
@@ -1403,7 +1403,7 @@ PODS:
- React-jsiexecutor
- React-RCTFBReactNativeSpec
- ReactCommon/turbomodule/core
- - react-native-executorch (0.5.2):
+ - react-native-executorch (0.5.0):
- DoubleConversion
- glog
- hermes-engine
@@ -1826,7 +1826,7 @@ PODS:
- React-logger (= 0.79.2)
- React-perflogger (= 0.79.2)
- React-utils (= 0.79.2)
- - RNAudioAPI (0.5.7):
+ - RNAudioAPI (0.8.2):
- DoubleConversion
- glog
- hermes-engine
@@ -1849,9 +1849,9 @@ PODS:
- ReactCodegen
- ReactCommon/turbomodule/bridging
- ReactCommon/turbomodule/core
- - RNAudioAPI/audioapi (= 0.5.7)
+ - RNAudioAPI/audioapi (= 0.8.2)
- Yoga
- - RNAudioAPI/audioapi (0.5.7):
+ - RNAudioAPI/audioapi (0.8.2):
- DoubleConversion
- glog
- hermes-engine
@@ -1874,9 +1874,9 @@ PODS:
- ReactCodegen
- ReactCommon/turbomodule/bridging
- ReactCommon/turbomodule/core
- - RNAudioAPI/audioapi/ios (= 0.5.7)
+ - RNAudioAPI/audioapi/ios (= 0.8.2)
- Yoga
- - RNAudioAPI/audioapi/ios (0.5.7):
+ - RNAudioAPI/audioapi/ios (0.8.2):
- DoubleConversion
- glog
- hermes-engine
@@ -1926,8 +1926,6 @@ PODS:
- ReactCommon/turbomodule/bridging
- ReactCommon/turbomodule/core
- Yoga
- - RNLiveAudioStream (1.1.1):
- - React
- RNReanimated (3.17.5):
- DoubleConversion
- glog
@@ -2246,7 +2244,6 @@ DEPENDENCIES:
- RNAudioAPI (from `../../../node_modules/react-native-audio-api`)
- RNDeviceInfo (from `../../../node_modules/react-native-device-info`)
- RNGestureHandler (from `../../../node_modules/react-native-gesture-handler`)
- - RNLiveAudioStream (from `../../../node_modules/react-native-live-audio-stream`)
- RNReanimated (from `../../../node_modules/react-native-reanimated`)
- RNScreens (from `../../../node_modules/react-native-screens`)
- RNSVG (from `../../../node_modules/react-native-svg`)
@@ -2430,8 +2427,6 @@ EXTERNAL SOURCES:
:path: "../../../node_modules/react-native-device-info"
RNGestureHandler:
:path: "../../../node_modules/react-native-gesture-handler"
- RNLiveAudioStream:
- :path: "../../../node_modules/react-native-live-audio-stream"
RNReanimated:
:path: "../../../node_modules/react-native-reanimated"
RNScreens:
@@ -2525,15 +2520,14 @@ SPEC CHECKSUMS:
ReactAppDependencyProvider: 04d5eb15eb46be6720e17a4a7fa92940a776e584
ReactCodegen: 7ea266ccd94436294f516247db7402b57b1214af
ReactCommon: 76d2dc87136d0a667678668b86f0fca0c16fdeb0
- RNAudioAPI: 2e3fd4bf75aa5717791babb30126707504996f09
+ RNAudioAPI: 3e398c4e9d44bb6b0c0b00e902057613224fc024
RNDeviceInfo: d863506092aef7e7af3a1c350c913d867d795047
RNGestureHandler: 7d0931a61d7ba0259f32db0ba7d0963c3ed15d2b
- RNLiveAudioStream: 93ac2bb6065be9018d0b00157b220f11cebc1513
RNReanimated: afd6a269a47d6f13ba295c46c6c0e14e3cbd0d8a
RNScreens: 482e9707f9826230810c92e765751af53826d509
RNSVG: 794f269526df9ddc1f79b3d1a202b619df0368e3
SocketRocket: d4aabe649be1e368d1318fdf28a022d714d65748
- sqlite3: 83105acd294c9137c026e2da1931c30b4588ab81
+ sqlite3: 1d85290c3321153511f6e900ede7a1608718bbd5
Yoga: c758bfb934100bb4bf9cbaccb52557cee35e8bdf
PODFILE CHECKSUM: bba19a069e673f2259009e9d2caab44374fdebcf
diff --git a/apps/llm/ios/llm.xcodeproj/project.pbxproj b/apps/llm/ios/llm.xcodeproj/project.pbxproj
index 582d3baf5..215dad425 100644
--- a/apps/llm/ios/llm.xcodeproj/project.pbxproj
+++ b/apps/llm/ios/llm.xcodeproj/project.pbxproj
@@ -28,12 +28,11 @@
B79E360E00239D910BF9B38D /* PrivacyInfo.xcprivacy */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xml; name = PrivacyInfo.xcprivacy; path = llm/PrivacyInfo.xcprivacy; sourceTree = "<group>"; };
BB2F792C24A3F905000567C9 /* Expo.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Expo.plist; sourceTree = "<group>"; };
E8C01EF33FCE4105BBBC9DF6 /* Aeonik-Medium.otf */ = {isa = PBXFileReference; explicitFileType = undefined; fileEncoding = 9; includeInIndex = 0; lastKnownFileType = unknown; name = "Aeonik-Medium.otf"; path = "../assets/fonts/Aeonik-Medium.otf"; sourceTree = "<group>"; };
- EA4529BE680FEB0AB7539557 /* Pods-llm.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-llm.release.xcconfig"; path = "Target Support Files/Pods-llm/Pods-llm.release.xcconfig"; sourceTree = "<group>"; };
ED297162215061F000B7C4FE /* JavaScriptCore.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = JavaScriptCore.framework; path = System/Library/Frameworks/JavaScriptCore.framework; sourceTree = SDKROOT; };
F11748412D0307B40044C1D9 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = AppDelegate.swift; path = llm/AppDelegate.swift; sourceTree = "<group>"; };
F11748442D0722820044C1D9 /* llm-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "llm-Bridging-Header.h"; path = "llm/llm-Bridging-Header.h"; sourceTree = "<group>"; };
+ F5CE0775ADE5923FA417B603 /* libPods-llm.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-llm.a"; sourceTree = BUILT_PRODUCTS_DIR; };
F866B7979FB94C8797EE2E3D /* Aeonik-Regular.otf */ = {isa = PBXFileReference; explicitFileType = undefined; fileEncoding = 9; includeInIndex = 0; lastKnownFileType = unknown; name = "Aeonik-Regular.otf"; path = "../assets/fonts/Aeonik-Regular.otf"; sourceTree = "<group>"; };
- FCA4A9AE0011869427989B32 /* libPods-llm.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-llm.a"; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@@ -96,6 +95,15 @@
name = Frameworks;
sourceTree = "";
};
+ 3014A6CAF64EC97E4003A2A3 /* Pods */ = {
+ isa = PBXGroup;
+ children = (
+ 4F489A14802F01369BFDDEFD /* Pods-llm.debug.xcconfig */,
+ 63C842393C3838DA2ECEFC7C /* Pods-llm.release.xcconfig */,
+ );
+ path = Pods;
+ sourceTree = "";
+ };
832341AE1AAA6A7D00B99B32 /* Libraries */ = {
isa = PBXGroup;
children = (
@@ -292,7 +300,52 @@
shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n";
showEnvVarsInLog = 0;
};
- E0CDBD4D0993974173A0E9FD /* [CP] Copy Pods Resources */ = {
+ 281D8603161F8B331E2BA335 /* [Expo] Configure project */ = {
+ isa = PBXShellScriptBuildPhase;
+ alwaysOutOfDate = 1;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputFileListPaths = (
+ );
+ inputPaths = (
+ );
+ name = "[Expo] Configure project";
+ outputFileListPaths = (
+ );
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "# This script configures Expo modules and generates the modules provider file.\nbash -l -c \"./Pods/Target\\ Support\\ Files/Pods-llm/expo-configure-project.sh\"\n";
+ };
+ 62055444ECB4CA2743E68CDC /* [CP] Embed Pods Frameworks */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputPaths = (
+ "${PODS_ROOT}/Target Support Files/Pods-llm/Pods-llm-frameworks.sh",
+ "${PODS_XCFRAMEWORKS_BUILD_DIR}/RNAudioAPI/libavcodec.framework/libavcodec",
+ "${PODS_XCFRAMEWORKS_BUILD_DIR}/RNAudioAPI/libavformat.framework/libavformat",
+ "${PODS_XCFRAMEWORKS_BUILD_DIR}/RNAudioAPI/libavutil.framework/libavutil",
+ "${PODS_XCFRAMEWORKS_BUILD_DIR}/RNAudioAPI/libswresample.framework/libswresample",
+ "${PODS_XCFRAMEWORKS_BUILD_DIR}/hermes-engine/Pre-built/hermes.framework/hermes",
+ );
+ name = "[CP] Embed Pods Frameworks";
+ outputPaths = (
+ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/libavcodec.framework",
+ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/libavformat.framework",
+ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/libavutil.framework",
+ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/libswresample.framework",
+ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/hermes.framework",
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "\"${PODS_ROOT}/Target Support Files/Pods-llm/Pods-llm-frameworks.sh\"\n";
+ showEnvVarsInLog = 0;
+ };
+ 800E24972A6A228C8D4807E9 /* [CP] Copy Pods Resources */ = {
isa = PBXShellScriptBuildPhase;
buildActionMask = 2147483647;
files = (
diff --git a/apps/llm/metro.config.js b/apps/llm/metro.config.js
index 1ce72df78..fa6ef6d8b 100644
--- a/apps/llm/metro.config.js
+++ b/apps/llm/metro.config.js
@@ -1,21 +1,22 @@
const { getDefaultConfig } = require('expo/metro-config');
+const {
+ wrapWithAudioAPIMetroConfig,
+} = require('react-native-audio-api/metro-config');
-module.exports = (() => {
- const config = getDefaultConfig(__dirname);
+const config = getDefaultConfig(__dirname);
- const { transformer, resolver } = config;
+const { transformer, resolver } = config;
- config.transformer = {
- ...transformer,
- babelTransformerPath: require.resolve('react-native-svg-transformer/expo'),
- };
- config.resolver = {
- ...resolver,
- assetExts: resolver.assetExts.filter((ext) => ext !== 'svg'),
- sourceExts: [...resolver.sourceExts, 'svg'],
- };
+config.transformer = {
+ ...transformer,
+ babelTransformerPath: require.resolve('react-native-svg-transformer/expo'),
+};
+config.resolver = {
+ ...resolver,
+ assetExts: resolver.assetExts.filter((ext) => ext !== 'svg'),
+ sourceExts: [...resolver.sourceExts, 'svg'],
+};
- config.resolver.assetExts.push('pte');
+config.resolver.assetExts.push('pte');
- return config;
-})();
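+// react-native-audio-api ships its own Metro wrapper; apply it last so the SVG and .pte tweaks above are kept.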
+module.exports = wrapWithAudioAPIMetroConfig(config);
diff --git a/apps/llm/package.json b/apps/llm/package.json
index 03c76dbff..32d8c3392 100644
--- a/apps/llm/package.json
+++ b/apps/llm/package.json
@@ -28,11 +28,10 @@
"metro-config": "^0.81.0",
"react": "19.0.0",
"react-native": "0.79.2",
- "react-native-audio-api": "0.5.7",
+ "react-native-audio-api": "^0.8.2",
"react-native-device-info": "^14.0.4",
"react-native-executorch": "workspace:*",
"react-native-gesture-handler": "~2.24.0",
- "react-native-live-audio-stream": "^1.1.1",
"react-native-loading-spinner-overlay": "^3.0.1",
"react-native-markdown-display": "^7.0.2",
"react-native-reanimated": "~3.17.4",
diff --git a/apps/speech-to-text/ios/Podfile.lock b/apps/speech-to-text/ios/Podfile.lock
index 3c162c082..fe20c8804 100644
--- a/apps/speech-to-text/ios/Podfile.lock
+++ b/apps/speech-to-text/ios/Podfile.lock
@@ -1395,7 +1395,7 @@ PODS:
- React-jsiexecutor
- React-RCTFBReactNativeSpec
- ReactCommon/turbomodule/core
- - react-native-executorch (0.4.2):
+ - react-native-executorch (0.5.0):
- DoubleConversion
- glog
- hermes-engine
@@ -2382,7 +2382,7 @@ SPEC CHECKSUMS:
React-logger: 8edfcedc100544791cd82692ca5a574240a16219
React-Mapbuffer: c3f4b608e4a59dd2f6a416ef4d47a14400194468
React-microtasksnativemodule: 054f34e9b82f02bd40f09cebd4083828b5b2beb6
- react-native-executorch: c18d209e226f0530a9ee88f1d60ce5837d4800ee
+ react-native-executorch: 3c871f7ed2e2b0ff92519ce38f06f0904784dbdb
react-native-safe-area-context: 562163222d999b79a51577eda2ea8ad2c32b4d06
React-NativeModulesApple: 2c4377e139522c3d73f5df582e4f051a838ff25e
React-oscompat: ef5df1c734f19b8003e149317d041b8ce1f7d29c
diff --git a/apps/speech-to-text/ios/speechtotext.xcodeproj/project.pbxproj b/apps/speech-to-text/ios/speechtotext.xcodeproj/project.pbxproj
index 6118e5088..0787f2ae6 100644
--- a/apps/speech-to-text/ios/speechtotext.xcodeproj/project.pbxproj
+++ b/apps/speech-to-text/ios/speechtotext.xcodeproj/project.pbxproj
@@ -276,7 +276,6 @@
"${PODS_CONFIGURATION_BUILD_DIR}/React-cxxreact/React-cxxreact_privacy.bundle",
"${PODS_CONFIGURATION_BUILD_DIR}/boost/boost_privacy.bundle",
"${PODS_CONFIGURATION_BUILD_DIR}/glog/glog_privacy.bundle",
- "${PODS_CONFIGURATION_BUILD_DIR}/react-native-image-picker/RNImagePickerPrivacyInfo.bundle",
);
name = "[CP] Copy Pods Resources";
outputPaths = (
@@ -290,7 +289,6 @@
"${TARGET_BUILD_DIR}/${UNLOCALIZED_RESOURCES_FOLDER_PATH}/React-cxxreact_privacy.bundle",
"${TARGET_BUILD_DIR}/${UNLOCALIZED_RESOURCES_FOLDER_PATH}/boost_privacy.bundle",
"${TARGET_BUILD_DIR}/${UNLOCALIZED_RESOURCES_FOLDER_PATH}/glog_privacy.bundle",
- "${TARGET_BUILD_DIR}/${UNLOCALIZED_RESOURCES_FOLDER_PATH}/RNImagePickerPrivacyInfo.bundle",
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
@@ -305,12 +303,10 @@
inputPaths = (
"${PODS_ROOT}/Target Support Files/Pods-speechtotext/Pods-speechtotext-frameworks.sh",
"${PODS_XCFRAMEWORKS_BUILD_DIR}/hermes-engine/Pre-built/hermes.framework/hermes",
- "${PODS_XCFRAMEWORKS_BUILD_DIR}/react-native-executorch/ExecutorchLib.framework/ExecutorchLib",
);
name = "[CP] Embed Pods Frameworks";
outputPaths = (
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/hermes.framework",
- "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/ExecutorchLib.framework",
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
diff --git a/apps/speech-to-text/metro.config.js b/apps/speech-to-text/metro.config.js
index f8ab2ab96..fa6ef6d8b 100644
--- a/apps/speech-to-text/metro.config.js
+++ b/apps/speech-to-text/metro.config.js
@@ -1,7 +1,8 @@
-// Learn more https://docs.expo.io/guides/customizing-metro
const { getDefaultConfig } = require('expo/metro-config');
+const {
+ wrapWithAudioAPIMetroConfig,
+} = require('react-native-audio-api/metro-config');
-/** @type {import('expo/metro-config').MetroConfig} */
const config = getDefaultConfig(__dirname);
const { transformer, resolver } = config;
@@ -18,4 +19,4 @@ config.resolver = {
config.resolver.assetExts.push('pte');
-module.exports = config;
+module.exports = wrapWithAudioAPIMetroConfig(config);
diff --git a/apps/speech-to-text/screens/SpeechToTextScreen.tsx b/apps/speech-to-text/screens/SpeechToTextScreen.tsx
index a7789a7f3..3be1750d3 100644
--- a/apps/speech-to-text/screens/SpeechToTextScreen.tsx
+++ b/apps/speech-to-text/screens/SpeechToTextScreen.tsx
@@ -66,8 +66,7 @@ export const SpeechToTextScreen = () => {
try {
const decodedAudioData = await audioContext.decodeAudioDataSource(uri);
const audioBuffer = decodedAudioData.getChannelData(0);
- const audioArray = Array.from(audioBuffer);
- setTranscription(await model.transcribe(audioArray));
+ setTranscription(await model.transcribe(audioBuffer));
} catch (error) {
console.error('Error decoding audio data', error);
console.warn('Note: Supported file formats: mp3, wav, flac');
@@ -79,8 +78,7 @@ export const SpeechToTextScreen = () => {
setLiveTranscribing(true);
setTranscription('');
recorder.onAudioReady(async ({ buffer }) => {
- const bufferArray = Array.from(buffer.getChannelData(0));
- model.streamInsert(bufferArray);
+ await model.streamInsert(buffer.getChannelData(0));
});
recorder.start();
@@ -93,7 +91,7 @@ export const SpeechToTextScreen = () => {
const handleStopTranscribeFromMicrophone = async () => {
recorder.stop();
- model.streamStop();
+ await model.streamStop();
console.log('Live transcription stopped');
setLiveTranscribing(false);
};
diff --git a/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md b/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md
index 3b5af0e23..0b7d95398 100644
--- a/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md
+++ b/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md
@@ -44,10 +44,9 @@ const { uri } = await FileSystem.downloadAsync(
const audioContext = new AudioContext({ sampleRate: 16000 });
const decodedAudioData = await audioContext.decodeAudioDataSource(uri);
const audioBuffer = decodedAudioData.getChannelData(0);
-const audioArray = Array.from(audioBuffer);
try {
- const transcription = await model.transcribe(audioArray);
+ const transcription = await model.transcribe(audioBuffer);
console.log(transcription);
} catch (error) {
console.error('Error during audio transcription', error);
@@ -76,20 +75,20 @@ For more information on loading resources, take a look at [loading models](../..
### Returns
-| Field | Type | Description |
-| --------------------------- | --------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `transcribe` | `(waveform: number[], options?: DecodingOptions \| undefined) => Promise<string>` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. The second argument is an options object, e.g. `{ language: 'es' }` for multilingual models. Resolves a promise with the output transcription when the model is finished. |
-| `stream` | `() => Promise<string>` | Starts a streaming transcription process. Use in combination with `streamInsert` to feed audio chunks and `streamStop` to end the stream. Updates `committedTranscription` and `nonCommittedTranscription` as transcription progresses. |
-| `streamInsert` | `(waveform: number[]) => void` | Inserts a chunk of audio data (sampled at 16kHz) into the ongoing streaming transcription. Call this repeatedly as new audio data becomes available. |
-| `streamStop` | `() => void` | Stops the ongoing streaming transcription process. |
-| `encode` | `(waveform: Float32Array) => Promise<void>` | Runs the encoding part of the model on the provided waveform. Stores the result internally. |
-| `decode` | `(tokens: number[]) => Promise<Float32Array>` | Runs the decoder of the model. Returns the decoded waveform as a Float32Array. |
-| `committedTranscription` | `string` | Contains the part of the transcription that is finalized and will not change. Useful for displaying stable results during streaming. |
-| `nonCommittedTranscription` | `string` | Contains the part of the transcription that is still being processed and may change. Useful for displaying live, partial results during streaming. |
-| `error` | `string \| null` | Contains the error message if the model failed to load. |
-| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
-| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
-| `downloadProgress` | `number` | Tracks the progress of the model download process. |
+| Field | Type | Description |
+| --------------------------- | ---------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `transcribe` | `(waveform: Float32Array \| number[], options?: DecodingOptions \| undefined) => Promise<string>` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. The second argument is an options object, e.g. `{ language: 'es' }` for multilingual models. Resolves a promise with the output transcription when the model is finished. Passing `number[]` is deprecated. |
+| `stream` | `() => Promise<string>` | Starts a streaming transcription process. Use in combination with `streamInsert` to feed audio chunks and `streamStop` to end the stream. Updates `committedTranscription` and `nonCommittedTranscription` as transcription progresses. |
+| `streamInsert` | `(waveform: Float32Array \| number[]) => Promise<void>` | Inserts a chunk of audio data (sampled at 16kHz) into the ongoing streaming transcription. Call this repeatedly as new audio data becomes available. Passing `number[]` is deprecated. |
+| `streamStop` | `() => Promise<void>` | Stops the ongoing streaming transcription process. |
+| `encode` | `(waveform: Float32Array \| number[]) => Promise<Float32Array>` | Runs the encoding part of the model on the provided waveform. Passing `number[]` is deprecated. |
+| `decode` | `(tokens: number[] \| Int32Array, encoderOutput: Float32Array \| number[]) => Promise<Float32Array>` | Runs the decoder of the model. Passing `number[]` is deprecated. |
+| `committedTranscription` | `string` | Contains the part of the transcription that is finalized and will not change. Useful for displaying stable results during streaming. |
+| `nonCommittedTranscription` | `string` | Contains the part of the transcription that is still being processed and may change. Useful for displaying live, partial results during streaming. |
+| `error` | `string \| null` | Contains the error message if the model failed to load. |
+| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
+| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
+| `downloadProgress` | `number` | Tracks the progress of the model download process. |
Type definitions
@@ -231,7 +230,7 @@ function App() {
const decodedAudioData = await audioContext.decodeAudioDataSource(uri);
const audioBuffer = decodedAudioData.getChannelData(0);
- return Array.from(audioBuffer);
+ return audioBuffer;
};
const handleTranscribe = async () => {
@@ -281,8 +280,7 @@ function App() {
const handleStartStreamingTranscribe = async () => {
recorder.onAudioReady(async ({ buffer }) => {
- const bufferArray = Array.from(buffer.getChannelData(0));
- model.streamInsert(bufferArray);
+ await model.streamInsert(buffer.getChannelData(0));
});
recorder.start();
@@ -295,7 +293,7 @@ function App() {
const handleStopStreamingTranscribe = async () => {
recorder.stop();
- model.streamStop();
+ await model.streamStop();
};
return (
diff --git a/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md b/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md
index df57a8f3d..839e58c78 100644
--- a/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md
+++ b/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md
@@ -19,15 +19,15 @@ await model.transcribe(waveform);
### Methods
-| Method | Type | Description |
-| -------------- | ---------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `load` | `(model: SpeechToTextModelConfig, onDownloadProgressCallback?: (progress: number) => void): Promise<void>` | Loads the model specified by the config object. `onDownloadProgressCallback` allows you to monitor the current progress of the model download. |
-| `encode` | `(waveform: Float32Array): Promise<void>` | Runs the encoding part of the model on the provided waveform. Stores the result internally. |
-| `decode` | `(tokens: number[]): Promise<Float32Array>` | Runs the decoder of the model. Returns the decoded waveform as a Float32Array. |
-| `transcribe` | `(waveform: number[], options?: DecodingOptions): Promise<string>` | Starts a transcription process for a given input array (16kHz waveform). For multilingual models, specify the language in `options`. Returns the transcription as a string. |
-| `stream` | `(options?: DecodingOptions): AsyncGenerator<{ committed: string; nonCommitted: string }>` | Starts a streaming transcription session. Yields objects with `committed` and `nonCommitted` transcriptions. Use with `streamInsert` and `streamStop` to control the stream. |
-| `streamStop` | `(): void` | Stops the current streaming transcription session. |
-| `streamInsert` | `(waveform: number[]): void` | Inserts a new audio chunk into the streaming transcription session. |
+| Method | Type | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `load` | `(model: SpeechToTextModelConfig, onDownloadProgressCallback?: (progress: number) => void): Promise<void>` | Loads the model specified by the config object. `onDownloadProgressCallback` allows you to monitor the current progress of the model download. |
+| `encode` | `(waveform: Float32Array \| number[]): Promise<Float32Array>` | Runs the encoding part of the model on the provided waveform. Returns the encoded waveform as a Float32Array. Passing `number[]` is deprecated. |
+| `decode` | `(tokens: number[] \| Int32Array, encoderOutput: Float32Array \| number[]): Promise<Float32Array>` | Runs the decoder of the model. Passing `number[]` is deprecated. |
+| `transcribe` | `(waveform: Float32Array \| number[], options?: DecodingOptions): Promise<string>` | Starts a transcription process for a given input array (16kHz waveform). For multilingual models, specify the language in `options`. Returns the transcription as a string. Passing `number[]` is deprecated. |
+| `stream` | `(options?: DecodingOptions): AsyncGenerator<{ committed: string; nonCommitted: string }>` | Starts a streaming transcription session. Yields objects with `committed` and `nonCommitted` transcriptions. Use with `streamInsert` and `streamStop` to control the stream. |
+| `streamStop` | `(): Promise<void>` | Stops the current streaming transcription session. |
+| `streamInsert` | `(waveform: Float32Array \| number[]): Promise<void>` | Inserts a new audio chunk into the streaming transcription session. Passing `number[]` is deprecated. |
:::info
@@ -192,11 +192,10 @@ const { uri } = await FileSystem.downloadAsync(
const audioContext = new AudioContext({ sampleRate: 16000 });
const decodedAudioData = await audioContext.decodeAudioDataSource(uri);
const audioBuffer = decodedAudioData.getChannelData(0);
-const audioArray = Array.from(audioBuffer);
// Transcribe the audio
try {
- const transcription = await model.transcribe(audioArray);
+ const transcription = await model.transcribe(audioBuffer);
console.log(transcription);
} catch (error) {
console.error('Error during audio transcription', error);
@@ -229,9 +228,8 @@ const recorder = new AudioRecorder({
bufferLengthInSamples: 1600,
});
recorder.onAudioReady(async ({ buffer }) => {
- const bufferArray = Array.from(buffer.getChannelData(0));
// Insert the audio into the streaming transcription
- model.streamInsert(bufferArray);
+ await model.streamInsert(buffer.getChannelData(0));
});
recorder.start();
@@ -248,6 +246,6 @@ try {
}
// Stop streaming transcription
-model.streamStop();
+await model.streamStop();
recorder.stop();
```
diff --git a/docs/versioned_docs/version-0.5.x/02-hooks/01-natural-language-processing/useSpeechToText.md b/docs/versioned_docs/version-0.5.x/02-hooks/01-natural-language-processing/useSpeechToText.md
index 3b5af0e23..4c3237743 100644
--- a/docs/versioned_docs/version-0.5.x/02-hooks/01-natural-language-processing/useSpeechToText.md
+++ b/docs/versioned_docs/version-0.5.x/02-hooks/01-natural-language-processing/useSpeechToText.md
@@ -44,10 +44,9 @@ const { uri } = await FileSystem.downloadAsync(
const audioContext = new AudioContext({ sampleRate: 16000 });
const decodedAudioData = await audioContext.decodeAudioDataSource(uri);
const audioBuffer = decodedAudioData.getChannelData(0);
-const audioArray = Array.from(audioBuffer);
try {
- const transcription = await model.transcribe(audioArray);
+ const transcription = await model.transcribe(audioBuffer);
console.log(transcription);
} catch (error) {
console.error('Error during audio transcription', error);
@@ -76,20 +75,20 @@ For more information on loading resources, take a look at [loading models](../..
### Returns
-| Field | Type | Description |
-| --------------------------- | --------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `transcribe` | `(waveform: number[], options?: DecodingOptions \| undefined) => Promise<string>` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. The second argument is an options object, e.g. `{ language: 'es' }` for multilingual models. Resolves a promise with the output transcription when the model is finished. |
-| `stream` | `() => Promise<string>` | Starts a streaming transcription process. Use in combination with `streamInsert` to feed audio chunks and `streamStop` to end the stream. Updates `committedTranscription` and `nonCommittedTranscription` as transcription progresses. |
-| `streamInsert` | `(waveform: number[]) => void` | Inserts a chunk of audio data (sampled at 16kHz) into the ongoing streaming transcription. Call this repeatedly as new audio data becomes available. |
-| `streamStop` | `() => void` | Stops the ongoing streaming transcription process. |
-| `encode` | `(waveform: Float32Array) => Promise<void>` | Runs the encoding part of the model on the provided waveform. Stores the result internally. |
-| `decode` | `(tokens: number[]) => Promise<Float32Array>` | Runs the decoder of the model. Returns the decoded waveform as a Float32Array. |
-| `committedTranscription` | `string` | Contains the part of the transcription that is finalized and will not change. Useful for displaying stable results during streaming. |
-| `nonCommittedTranscription` | `string` | Contains the part of the transcription that is still being processed and may change. Useful for displaying live, partial results during streaming. |
-| `error` | `string \| null` | Contains the error message if the model failed to load. |
-| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
-| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
-| `downloadProgress` | `number` | Tracks the progress of the model download process. |
+| Field | Type | Description |
+| --------------------------- | ---------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `transcribe` | `(waveform: Float32Array \| number[], options?: DecodingOptions \| undefined) => Promise<string>` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. The second argument is an options object, e.g. `{ language: 'es' }` for multilingual models. Resolves a promise with the output transcription when the model is finished. Passing `number[]` is deprecated. |
+| `stream` | `() => Promise<string>` | Starts a streaming transcription process. Use in combination with `streamInsert` to feed audio chunks and `streamStop` to end the stream. Updates `committedTranscription` and `nonCommittedTranscription` as transcription progresses. |
+| `streamInsert` | `(waveform: Float32Array \| number[]) => Promise<void>` | Inserts a chunk of audio data (sampled at 16kHz) into the ongoing streaming transcription. Call this repeatedly as new audio data becomes available. Passing `number[]` is deprecated. |
+| `streamStop` | `() => Promise<void>` | Stops the ongoing streaming transcription process. |
+| `encode` | `(waveform: Float32Array \| number[]) => Promise<Float32Array>` | Runs the encoding part of the model on the provided waveform. Passing `number[]` is deprecated. |
+| `decode` | `(tokens: number[] \| Int32Array, encoderOutput: Float32Array \| number[]) => Promise<Float32Array>` | Runs the decoder of the model. Passing `number[]` is deprecated. |
+| `committedTranscription` | `string` | Contains the part of the transcription that is finalized and will not change. Useful for displaying stable results during streaming. |
+| `nonCommittedTranscription` | `string` | Contains the part of the transcription that is still being processed and may change. Useful for displaying live, partial results during streaming. |
+| `error` | `string \| null` | Contains the error message if the model failed to load. |
+| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
+| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
+| `downloadProgress` | `number` | Tracks the progress of the model download process. |
Type definitions
@@ -231,7 +230,7 @@ function App() {
const decodedAudioData = await audioContext.decodeAudioDataSource(uri);
const audioBuffer = decodedAudioData.getChannelData(0);
- return Array.from(audioBuffer);
+ return audioBuffer;
};
const handleTranscribe = async () => {
@@ -281,8 +280,7 @@ function App() {
const handleStartStreamingTranscribe = async () => {
recorder.onAudioReady(async ({ buffer }) => {
- const bufferArray = Array.from(buffer.getChannelData(0));
- model.streamInsert(bufferArray);
+ model.streamInsert(buffer.getChannelData(0));
});
recorder.start();
diff --git a/docs/versioned_docs/version-0.5.x/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md b/docs/versioned_docs/version-0.5.x/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md
index df57a8f3d..039189e73 100644
--- a/docs/versioned_docs/version-0.5.x/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md
+++ b/docs/versioned_docs/version-0.5.x/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md
@@ -19,15 +19,15 @@ await model.transcribe(waveform);
### Methods
-| Method | Type | Description |
-| -------------- | ---------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `load` | `(model: SpeechToTextModelConfig, onDownloadProgressCallback?: (progress: number) => void): Promise<void>` | Loads the model specified by the config object. `onDownloadProgressCallback` allows you to monitor the current progress of the model download. |
-| `encode` | `(waveform: Float32Array): Promise<void>` | Runs the encoding part of the model on the provided waveform. Stores the result internally. |
-| `decode` | `(tokens: number[]): Promise<Float32Array>` | Runs the decoder of the model. Returns the decoded waveform as a Float32Array. |
-| `transcribe` | `(waveform: number[], options?: DecodingOptions): Promise<string>` | Starts a transcription process for a given input array (16kHz waveform). For multilingual models, specify the language in `options`. Returns the transcription as a string. |
-| `stream` | `(options?: DecodingOptions): AsyncGenerator<{ committed: string; nonCommitted: string }>` | Starts a streaming transcription session. Yields objects with `committed` and `nonCommitted` transcriptions. Use with `streamInsert` and `streamStop` to control the stream. |
-| `streamStop` | `(): void` | Stops the current streaming transcription session. |
-| `streamInsert` | `(waveform: number[]): void` | Inserts a new audio chunk into the streaming transcription session. |
+| Method | Type | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `load` | `(model: SpeechToTextModelConfig, onDownloadProgressCallback?: (progress: number) => void): Promise<void>` | Loads the model specified by the config object. `onDownloadProgressCallback` allows you to monitor the current progress of the model download. |
+| `encode` | `(waveform: Float32Array \| number[]): Promise<Float32Array>` | Runs the encoding part of the model on the provided waveform. Returns the encoded waveform as a Float32Array. Passing `number[]` is deprecated. |
+| `decode` | `(tokens: number[] \| Int32Array, encoderOutput: Float32Array \| number[]): Promise<Float32Array>` | Runs the decoder of the model. Passing `number[]` is deprecated. |
+| `transcribe` | `(waveform: Float32Array \| number[], options?: DecodingOptions): Promise<string>` | Starts a transcription process for a given input array (16kHz waveform). For multilingual models, specify the language in `options`. Returns the transcription as a string. Passing `number[]` is deprecated. |
+| `stream` | `(options?: DecodingOptions): AsyncGenerator<{ committed: string; nonCommitted: string }>` | Starts a streaming transcription session. Yields objects with `committed` and `nonCommitted` transcriptions. Use with `streamInsert` and `streamStop` to control the stream. |
+| `streamStop` | `(): void` | Stops the current streaming transcription session. |
+| `streamInsert` | `(waveform: Float32Array \| number[]): void` | Inserts a new audio chunk into the streaming transcription session. Passing `number[]` is deprecated. |
:::info
@@ -192,11 +192,10 @@ const { uri } = await FileSystem.downloadAsync(
const audioContext = new AudioContext({ sampleRate: 16000 });
const decodedAudioData = await audioContext.decodeAudioDataSource(uri);
const audioBuffer = decodedAudioData.getChannelData(0);
-const audioArray = Array.from(audioBuffer);
// Transcribe the audio
try {
- const transcription = await model.transcribe(audioArray);
+ const transcription = await model.transcribe(audioBuffer);
console.log(transcription);
} catch (error) {
console.error('Error during audio transcription', error);
@@ -229,9 +228,8 @@ const recorder = new AudioRecorder({
bufferLengthInSamples: 1600,
});
recorder.onAudioReady(async ({ buffer }) => {
- const bufferArray = Array.from(buffer.getChannelData(0));
// Insert the audio into the streaming transcription
- model.streamInsert(bufferArray);
+ model.streamInsert(buffer.getChannelData(0));
});
recorder.start();
diff --git a/packages/react-native-executorch/android/src/main/cpp/CMakeLists.txt b/packages/react-native-executorch/android/src/main/cpp/CMakeLists.txt
index bf1544aeb..d77ec7563 100644
--- a/packages/react-native-executorch/android/src/main/cpp/CMakeLists.txt
+++ b/packages/react-native-executorch/android/src/main/cpp/CMakeLists.txt
@@ -95,4 +95,5 @@ target_link_libraries(
${OPENCV_THIRD_PARTY_LIBS}
executorch
${EXECUTORCH_LIBS}
-)
\ No newline at end of file
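+  # zlib, required by the new gzip helpers in data_processing/gzip.cpp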
+ z
+)
diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp
index 1b155f72d..0ccac9d07 100644
--- a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp
@@ -9,7 +9,7 @@
#include
namespace rnexecutorch::numerical {
-void softmax(std::vector<float> &v) {
+void softmax(std::span<float> v) {
float max = *std::max_element(v.begin(), v.end());
float sum = 0.0f;
@@ -22,32 +22,40 @@ void softmax(std::vector &v) {
}
}
-void normalize(std::span<float> span) {
- auto sum = 0.0f;
- for (const auto &val : span) {
- sum += val * val;
+void softmaxWithTemperature(std::span<float> input, float temperature) {
+ if (input.empty()) {
+ return;
}
- if (isClose(sum, 0.0f)) {
- return;
+ if (temperature <= 0.0F) {
+ throw std::invalid_argument(
+ "Temperature must be greater than 0 for softmax with temperature.");
}
- float norm = std::sqrt(sum);
- for (auto &val : span) {
- val /= norm;
+ const auto maxElement = *std::ranges::max_element(input);
+
+ for (auto &value : input) {
+ value = std::exp((value - maxElement) / temperature);
}
-}
-void normalize(std::vector<float> &v) {
- float sum = 0.0f;
- for (float &x : v) {
- sum += x * x;
+ const auto sum = std::reduce(input.begin(), input.end());
+
+ // sum is at least 1 since exp(max - max) == exp(0) == 1
+ for (auto &value : input) {
+ value /= sum;
}
+}
- float norm =
- std::max(std::sqrt(sum), 1e-9f); // Solely for preventing division by 0
- for (float &x : v) {
- x /= norm;
+void normalize(std::span<float> input) {
+ const auto sumOfSquares =
+ std::inner_product(input.begin(), input.end(), input.begin(), 0.0F);
+
+ constexpr auto kEpsilon = 1.0e-15F;
+
+ const auto norm = std::sqrt(sumOfSquares) + kEpsilon;
+
+ for (auto &value : input) {
+ value /= norm;
}
}
diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h
index 77a13f44f..fa808f21f 100644
--- a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h
+++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h
@@ -4,10 +4,59 @@
#include
namespace rnexecutorch::numerical {
-void softmax(std::vector<float> &v);
-void normalize(std::span<float> span);
-void normalize(std::vector<float> &v);
-void normalize(std::span<double> span);
+
+/**
+ * @brief Applies the softmax function in-place to a sequence of numbers.
+ *
+ * @param input A mutable span of floating-point numbers. After the function
+ * returns, `input` contains the softmax probabilities.
+ */
+void softmax(std::span<float> input);
+
+/**
+ * @brief Applies the softmax function with temperature scaling in-place to a
+ * sequence of numbers.
+ *
+ * The temperature parameter controls the "sharpness" of the resulting
+ * probability distribution. A temperature of 1.0 means no scaling, while lower
+ * values make the distribution sharper (more peaked), and higher values make it
+ * softer (more uniform).
+ *
+ * @param input A mutable span of floating-point numbers. After the function
+ * returns, `input` contains the softmax probabilities.
+ * @param temperature A positive float value used to scale the logits before
+ * applying softmax. Must be greater than 0.
+ */
+void softmaxWithTemperature(std::span<float> input, float temperature);
+
+/**
+ * @brief Normalizes the elements of the given float span in-place using the
+ * L2 norm method.
+ *
+ * This function scales the input vector such that its L2 norm (Euclidean norm)
+ * becomes 1. If the norm is zero, the result is a zero vector with the same
+ * size as the input.
+ *
+ * @param input A mutable span of floating-point values representing the data to
+ * be normalized.
+ */
+void normalize(std::span<float> input);
+
+/**
+ * @brief Computes mean pooling across the modelOutput adjusted by an attention
+ * mask.
+ *
+ * This function aggregates the `modelOutput` span by sections defined by
+ * `attnMask`, computing the mean of sections influenced by the mask. The result
+ * is a vector where each element is the mean of a segment from the original
+ * data.
+ *
+ * @param modelOutput A span of floating-point numbers representing the model
+ * output.
+ * @param attnMask A span of integers where each integer is a weight
+ * corresponding to the elements in `modelOutput`.
+ * @return A std::vector containing the computed mean values of segments.
+ */
std::vector<float> meanPooling(std::span<const float> modelOutput,
                               std::span<const int64_t> attnMask);
/**
diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp
index f088659b8..d3761dced 100644
--- a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp
@@ -18,7 +18,7 @@ std::vector hannWindow(size_t size) {
return window;
}
-std::vector<float> stftFromWaveform(std::span<float> waveform,
+std::vector<float> stftFromWaveform(std::span<const float> waveform,
size_t fftWindowSize, size_t hopSize) {
// Initialize FFT
FFT fft(fftWindowSize);
diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h
index d04939f12..7eaa26d83 100644
--- a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h
+++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h
@@ -6,7 +6,7 @@
namespace rnexecutorch::dsp {
std::vector hannWindow(size_t size);
-std::vector<float> stftFromWaveform(std::span<float> waveform,
+std::vector<float> stftFromWaveform(std::span<const float> waveform,
size_t fftWindowSize, size_t hopSize);
} // namespace rnexecutorch::dsp
diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.cpp
new file mode 100644
index 000000000..877b85995
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.cpp
@@ -0,0 +1,47 @@
+#include <stdexcept>
+#include <vector>
+#include <zlib.h>
+
+#include "gzip.h"
+
+namespace rnexecutorch::gzip {
+
+namespace {
+constexpr int32_t kGzipWrapper = 16; // gzip header/trailer
+constexpr int32_t kMemLevel = 8; // memory level
+constexpr size_t kChunkSize = 16 * 1024; // 16 KiB stream buffer
+} // namespace
+
+size_t deflateSize(const std::string &input) {
+ z_stream strm{};
+ if (::deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED,
+ MAX_WBITS + kGzipWrapper, kMemLevel,
+ Z_DEFAULT_STRATEGY) != Z_OK) {
+ throw std::runtime_error("deflateInit2 failed");
+ }
+
+ size_t outSize = 0;
+
+  strm.next_in =
+      reinterpret_cast<Bytef *>(const_cast<char *>(input.data()));
+  strm.avail_in = static_cast<uInt>(input.size());
+
+  std::vector<unsigned char> buf(kChunkSize);
+ int ret;
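+  // Stream the whole input through deflate, counting output bytes instead of storing them.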
+ do {
+ strm.next_out = buf.data();
+    strm.avail_out = static_cast<uInt>(buf.size());
+
+ ret = ::deflate(&strm, strm.avail_in ? Z_NO_FLUSH : Z_FINISH);
+ if (ret == Z_STREAM_ERROR) {
+ ::deflateEnd(&strm);
+ throw std::runtime_error("deflate stream error");
+ }
+
+ outSize += buf.size() - strm.avail_out;
+ } while (ret != Z_STREAM_END);
+
+ ::deflateEnd(&strm);
+ return outSize;
+}
+
+} // namespace rnexecutorch::gzip
diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.h
new file mode 100644
index 000000000..73890e1f9
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include <cstddef>
+#include <string>
+
+namespace rnexecutorch::gzip {
+
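+// Returns the gzip-compressed size of the input in bytes; the compressed bytes themselves are discarded.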
+size_t deflateSize(const std::string &input);
+
+} // namespace rnexecutorch::gzip
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
index 17225cce0..949d75ad1 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
@@ -62,6 +62,30 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
"decode"));
}
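+  // Expose the streaming speech-to-text methods only on models that implement them.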
+  if constexpr (meta::HasTranscribe<Model>) {
+ addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject,
+ promiseHostFunction<&Model::transcribe>,
+ "transcribe"));
+ }
+
+  if constexpr (meta::HasStream<Model>) {
+ addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject,
+ promiseHostFunction<&Model::stream>,
+ "stream"));
+ }
+
+  if constexpr (meta::HasStreamInsert<Model>) {
+ addFunctions(JSI_EXPORT_FUNCTION(
+ ModelHostObject, promiseHostFunction<&Model::streamInsert>,
+ "streamInsert"));
+ }
+
+  if constexpr (meta::HasStreamStop<Model>) {
+ addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject,
+ promiseHostFunction<&Model::streamStop>,
+ "streamStop"));
+ }
+
if constexpr (meta::SameAs) {
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject,
promiseHostFunction<&Model::encode>,
diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h
index 253a53238..10bc23af3 100644
--- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h
+++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h
@@ -26,6 +26,26 @@ concept HasDecode = requires(T t) {
{ &T::decode };
};
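+// Member-detection concepts for the optional speech-to-text streaming API.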
+template <typename T>
+concept HasTranscribe = requires(T t) {
+ { &T::transcribe };
+};
+
+template <typename T>
+concept HasStream = requires(T t) {
+ { &T::stream };
+};
+
+template <typename T>
+concept HasStreamInsert = requires(T t) {
+ { &T::streamInsert };
+};
+
+template <typename T>
+concept HasStreamStop = requires(T t) {
+ { &T::streamStop };
+};
+
template <typename T>
concept IsNumeric = std::is_arithmetic_v<T>;
@@ -34,4 +54,4 @@ concept ProvidesMemoryLowerBound = requires(T t) {
{ &T::getMemoryLowerBound };
};
-} // namespace rnexecutorch::meta
\ No newline at end of file
+} // namespace rnexecutorch::meta
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
index 77b4fd63a..3eab6dd83 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
@@ -142,7 +142,8 @@ BaseModel::getMethodMeta(const std::string &methodName) {
return module_->method_meta(methodName);
}
-Result<std::vector<EValue>> BaseModel::forward(const EValue &input_evalue) {
+Result<std::vector<EValue>>
+BaseModel::forward(const EValue &input_evalue) const {
if (!module_) {
throw std::runtime_error("Model not loaded: Cannot perform forward pass");
}
@@ -150,7 +151,7 @@ Result<std::vector<EValue>> BaseModel::forward(const EValue &input_evalue) {
}
Result<std::vector<EValue>>
-BaseModel::forward(const std::vector<EValue> &input_evalues) {
+BaseModel::forward(const std::vector<EValue> &input_evalues) const {
if (!module_) {
throw std::runtime_error("Model not loaded: Cannot perform forward pass");
}
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
index dbf3c433a..983dc9b74 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
@@ -26,8 +26,9 @@ class BaseModel {
getAllInputShapes(std::string methodName = "forward");
std::vector
forwardJS(std::vector tensorViewVec);
-  Result<std::vector<EValue>> forward(const EValue &input_value);
-  Result<std::vector<EValue>> forward(const std::vector<EValue> &input_value);
+  Result<std::vector<EValue>> forward(const EValue &input_value) const;
+  Result<std::vector<EValue>>
+  forward(const std::vector<EValue> &input_value) const;
Result<std::vector<EValue>> execute(const std::string &methodName,
                                    const std::vector<EValue> &input_value);
Result
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.cpp b/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.cpp
deleted file mode 100644
index a0da38708..000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.cpp
+++ /dev/null
@@ -1,21 +0,0 @@
-#include <rnexecutorch/models/EncoderDecoderBase.h>
-
-namespace rnexecutorch::models {
-
-EncoderDecoderBase::EncoderDecoderBase(
- const std::string &encoderPath, const std::string &decoderPath,
-    std::shared_ptr<react::CallInvoker> callInvoker)
- : callInvoker(callInvoker),
-      encoder_(std::make_unique<BaseModel>(encoderPath, callInvoker)),
-      decoder_(std::make_unique<BaseModel>(decoderPath, callInvoker)) {};
-
-size_t EncoderDecoderBase::getMemoryLowerBound() const noexcept {
- return encoder_->getMemoryLowerBound() + decoder_->getMemoryLowerBound();
-}
-
-void EncoderDecoderBase::unload() noexcept {
- encoder_.reset(nullptr);
- decoder_.reset(nullptr);
-}
-
-} // namespace rnexecutorch::models
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.h b/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.h
deleted file mode 100644
index f2f2265aa..000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/EncoderDecoderBase.h
+++ /dev/null
@@ -1,31 +0,0 @@
-#pragma once
-
-#include <ReactCommon/CallInvoker.h>
-#include <memory>
-#include <rnexecutorch/models/BaseModel.h>
-#include <string>
-
-namespace rnexecutorch::models {
-
-using namespace facebook;
-using executorch::aten::Tensor;
-using executorch::runtime::EValue;
-
-class EncoderDecoderBase {
-public:
- explicit EncoderDecoderBase(const std::string &encoderPath,
- const std::string &decoderPath,
-                             std::shared_ptr<react::CallInvoker> callInvoker);
- size_t getMemoryLowerBound() const noexcept;
- void unload() noexcept;
-
-protected:
-  std::shared_ptr<react::CallInvoker> callInvoker;
-  std::unique_ptr<BaseModel> encoder_;
-  std::unique_ptr<BaseModel> decoder_;
-
-private:
- size_t memorySizeLowerBound;
-};
-
-} // namespace rnexecutorch::models
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp
index 3f627095e..d444b9c91 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp
@@ -1,64 +1,125 @@
-#include <rnexecutorch/models/speech_to_text/SpeechToText.h>
-#include <rnexecutorch/models/speech_to_text/WhisperStrategy.h>
-#include <executorch/extension/tensor/tensor_ptr.h>
+#include <chrono>
+#include <cstring>
+#include <thread>
+
+#include "SpeechToText.h"
namespace rnexecutorch::models::speech_to_text {
using namespace ::executorch::extension;
-SpeechToText::SpeechToText(const std::string &encoderPath,
- const std::string &decoderPath,
- const std::string &modelName,
+SpeechToText::SpeechToText(const std::string &encoderSource,
+ const std::string &decoderSource,
+ const std::string &tokenizerSource,
                          std::shared_ptr<react::CallInvoker> callInvoker)
- : EncoderDecoderBase(encoderPath, decoderPath, callInvoker),
- modelName(modelName) {
- initializeStrategy();
+ : callInvoker(std::move(callInvoker)),
+      encoder(std::make_unique<BaseModel>(encoderSource, this->callInvoker)),
+      decoder(std::make_unique<BaseModel>(decoderSource, this->callInvoker)),
+      tokenizer(std::make_unique<TokenizerModule>(tokenizerSource,
+                                                  this->callInvoker)),
+      asr(std::make_unique<ASR>(this->encoder.get(), this->decoder.get(),
+                                this->tokenizer.get())),
+      processor(std::make_unique<OnlineASRProcessor>(this->asr.get())),
+ isStreaming(false), readyToProcess(false) {}
+
+std::shared_ptr<OwningArrayBuffer>
+SpeechToText::encode(std::span<float> waveform) const {
+  std::vector<float> encoderOutput = this->asr->encode(waveform);
+ return this->makeOwningBuffer(encoderOutput);
}
-void SpeechToText::initializeStrategy() {
- if (modelName == "whisper") {
-    strategy = std::make_unique<WhisperStrategy>();
- } else {
- throw std::runtime_error("Unsupported STT model: " + modelName +
- ". Only 'whisper' is supported.");
- }
+std::shared_ptr<OwningArrayBuffer>
+SpeechToText::decode(std::span<int32_t> tokens,
+                     std::span<float> encoderOutput) const {
+  std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
+ return this->makeOwningBuffer(decoderOutput);
}
-void SpeechToText::encode(std::span<float> waveform) {
- const auto modelInputTensor = strategy->prepareAudioInput(waveform);
+std::string SpeechToText::transcribe(std::span<float> waveform,
+                                     std::string languageOption) const {
+  std::vector<Segment> segments =
+      this->asr->transcribe(waveform, DecodingOptions(languageOption));
+ std::string transcription;
+
+ size_t transcriptionLength = 0;
+ for (auto &segment : segments) {
+ for (auto &word : segment.words) {
+ transcriptionLength += word.content.size();
+ }
+ }
+ transcription.reserve(transcriptionLength);
- const auto result = encoder_->forward(modelInputTensor);
- if (!result.ok()) {
- throw std::runtime_error(
- "Forward pass failed during encoding, error code: " +
-        std::to_string(static_cast<uint32_t>(result.error())));
+ for (auto &segment : segments) {
+ for (auto &word : segment.words) {
+ transcription += word.content;
+ }
}
+ return transcription;
+}
- encoderOutput = result.get().at(0);
+size_t SpeechToText::getMemoryLowerBound() const noexcept {
+ return this->encoder->getMemoryLowerBound() +
+ this->decoder->getMemoryLowerBound() +
+ this->tokenizer->getMemoryLowerBound();
}
 std::shared_ptr<OwningArrayBuffer>
-SpeechToText::decode(std::vector<int64_t> prevTokens) {
-  if (encoderOutput.isNone()) {
-    throw std::runtime_error("Empty encodings on decode call, make sure to "
-                             "call encode() prior to decode()!");
+SpeechToText::makeOwningBuffer(std::span<float> vectorView) const {
+  auto owningArrayBuffer =
+      std::make_shared<OwningArrayBuffer>(vectorView.size_bytes());
+ std::memcpy(owningArrayBuffer->data(), vectorView.data(),
+ vectorView.size_bytes());
+ return owningArrayBuffer;
+}
+
+void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
+ std::string languageOption) {
+ if (this->isStreaming) {
+ throw std::runtime_error("Streaming is already in progress");
+ }
+
+ auto nativeCallback = [this, callback](const std::string &committed,
+ const std::string &nonCommitted,
+ bool isDone) {
+ this->callInvoker->invokeAsync(
+ [callback, committed, nonCommitted, isDone](jsi::Runtime &rt) {
+ callback->call(rt, jsi::String::createFromUtf8(rt, committed),
+ jsi::String::createFromUtf8(rt, nonCommitted),
+ jsi::Value(isDone));
+ });
+ };
+
+ this->resetStreamState();
+
+ this->isStreaming = true;
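+  // Polling loop: wait until a chunk has been inserted and at least
+  // kMinAudioSamples are buffered, then run one online-ASR iteration.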
+ while (this->isStreaming) {
+ if (!this->readyToProcess ||
+ this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ continue;
+ }
+ ProcessResult res =
+ this->processor->processIter(DecodingOptions(languageOption));
+ nativeCallback(res.committed, res.nonCommitted, false);
+ this->readyToProcess = false;
}
- const auto prevTokensTensor = strategy->prepareTokenInput(prevTokens);
+ std::string committed = this->processor->finish();
+ nativeCallback(committed, "", true);
+}
- const auto decoderMethod = strategy->getDecoderMethod();
- const auto decoderResult =
- decoder_->execute(decoderMethod, {prevTokensTensor, encoderOutput});
+void SpeechToText::streamStop() { this->isStreaming = false; }
- if (!decoderResult.ok()) {
- throw std::runtime_error(
- "Forward pass failed during decoding, error code: " +
-        std::to_string(static_cast<uint32_t>(decoderResult.error())));
+void SpeechToText::streamInsert(std::span<float> waveform) {
+ if (!this->isStreaming) {
+ throw std::runtime_error("Streaming is not started");
}
+ this->processor->insertAudioChunk(waveform);
+ this->readyToProcess = true;
+}
- const auto decoderOutputTensor = decoderResult.get().at(0).toTensor();
- const auto innerDim = decoderOutputTensor.size(1);
- return strategy->extractOutputToken(decoderOutputTensor);
+void SpeechToText::resetStreamState() {
+ this->isStreaming = false;
+ this->readyToProcess = false;
+  this->processor = std::make_unique<OnlineASRProcessor>(this->asr.get());
}
} // namespace rnexecutorch::models::speech_to_text
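The streaming API above is meant to be driven from two threads: `stream()` blocks and polls until `streamStop()` flips `isStreaming`, while a producer feeds audio via `streamInsert()`. A rough usage sketch, assuming the types reconstructed in this diff; the chunk source and the brief start-up wait are illustrative, not part of the library:

```cpp
#include <chrono>
#include <memory>
#include <span>
#include <thread>
#include <vector>

using rnexecutorch::models::speech_to_text::SpeechToText;

// Hypothetical driver: stream() blocks until streamStop(), so it runs on
// its own thread while audio chunks are inserted from the capture side.
void runStreaming(SpeechToText &stt,
                  std::shared_ptr<facebook::jsi::Function> callback) {
  std::thread consumer([&] { stt.stream(callback, "en"); });
  // Give stream() a moment to set isStreaming before inserting audio.
  std::this_thread::sleep_for(std::chrono::milliseconds(50));

  std::vector<float> chunk(1600); // 100 ms of 16 kHz audio per insert
  for (int i = 0; i < 100; ++i) { // feed ~10 s of audio in total
    stt.streamInsert(std::span<float>(chunk));
  }

  stt.streamStop();
  consumer.join();
}
```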
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h
index fa90b53dc..a6f3779e4 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h
@@ -1,38 +1,60 @@
#pragma once
-#include "ReactCommon/CallInvoker.h"
-#include "executorch/runtime/core/evalue.h"
-#include <cstdint>
-#include <memory>
-#include <span>
-#include <string>
-#include <vector>
-
-#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
-#include <rnexecutorch/models/EncoderDecoderBase.h>
-#include <rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h>
+#include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h"
namespace rnexecutorch {
+
namespace models::speech_to_text {
-class SpeechToText : public EncoderDecoderBase {
+
+using namespace asr;
+using namespace types;
+using namespace stream;
+
+class SpeechToText {
public:
- explicit SpeechToText(const std::string &encoderPath,
- const std::string &decoderPath,
- const std::string &modelName,
+ explicit SpeechToText(const std::string &encoderSource,
+ const std::string &decoderSource,
+ const std::string &tokenizerSource,
                        std::shared_ptr<react::CallInvoker> callInvoker);
-  void encode(std::span<float> waveform);
-  std::shared_ptr<OwningArrayBuffer> decode(std::vector<int64_t> prevTokens);
+
+  std::shared_ptr<OwningArrayBuffer> encode(std::span<float> waveform) const;
+  std::shared_ptr<OwningArrayBuffer>
+  decode(std::span<int32_t> tokens, std::span<float> encoderOutput) const;
+  std::string transcribe(std::span<float> waveform,
+                         std::string languageOption) const;
+
+ size_t getMemoryLowerBound() const noexcept;
+
+ // Stream
+  void stream(std::shared_ptr<jsi::Function> callback,
+              std::string languageOption);
+  void streamStop();
+  void streamInsert(std::span<float> waveform);
private:
- const std::string modelName;
- executorch::runtime::EValue encoderOutput;
-  std::unique_ptr<SpeechToTextStrategy> strategy;
+  std::unique_ptr<BaseModel> encoder;
+  std::unique_ptr<BaseModel> decoder;
+  std::unique_ptr<TokenizerModule> tokenizer;
+  std::unique_ptr<ASR> asr;
- void initializeStrategy();
+  std::shared_ptr<OwningArrayBuffer>
+  makeOwningBuffer(std::span<float> vectorView) const;
+
+ // Stream
+  std::shared_ptr<react::CallInvoker> callInvoker;
+  std::unique_ptr<OnlineASRProcessor> processor;
+ bool isStreaming;
+ bool readyToProcess;
+
+ constexpr static int32_t kMinAudioSamples = 16000; // 1 second
+
+ void resetStreamState();
};
+
} // namespace models::speech_to_text
REGISTER_CONSTRUCTOR(models::speech_to_text::SpeechToText, std::string,
std::string, std::string,
                     std::shared_ptr<react::CallInvoker>);
+
} // namespace rnexecutorch
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h
deleted file mode 100644
index c68f2e395..000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h
+++ /dev/null
@@ -1,27 +0,0 @@
-#pragma once
-
-#include "executorch/extension/tensor/tensor_ptr.h"
-#include <memory>
-#include <span>
-#include <vector>
-
-namespace rnexecutorch::models::speech_to_text {
-
-using TensorPtr = ::executorch::extension::TensorPtr;
-
-class SpeechToTextStrategy {
-public:
- virtual ~SpeechToTextStrategy() = default;
-
-  virtual TensorPtr prepareAudioInput(std::span<float> waveform) = 0;
-
-  virtual TensorPtr
-  prepareTokenInput(const std::vector<int64_t> &prevTokens) = 0;
-
- virtual std::string getDecoderMethod() const = 0;
-
-  virtual std::shared_ptr<OwningArrayBuffer> extractOutputToken(
-      const executorch::aten::Tensor &decoderOutputTensor) const = 0;
-};
-
-} // namespace rnexecutorch::models::speech_to_text
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp
deleted file mode 100644
index 3b33c4450..000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp
+++ /dev/null
@@ -1,50 +0,0 @@
-#include "executorch/extension/tensor/tensor_ptr.h"
-#include "rnexecutorch/data_processing/dsp.h"
-#include <rnexecutorch/models/speech_to_text/WhisperStrategy.h>
-
-namespace rnexecutorch::models::speech_to_text {
-
-using namespace ::executorch::extension;
-using namespace ::executorch::aten;
-
-TensorPtr WhisperStrategy::prepareAudioInput(std::span<float> waveform) {
- constexpr auto fftWindowSize = 512;
- constexpr auto stftHopLength = 160;
- constexpr auto innerDim = 256;
- preprocessedData =
- dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
- const auto numFrames = preprocessedData.size() / innerDim;
-  std::vector<int32_t> inputShape = {static_cast<int32_t>(numFrames),
-                                     innerDim};
- return make_tensor_ptr(std::move(inputShape), std::move(preprocessedData));
-}
-
-TensorPtr
-WhisperStrategy::prepareTokenInput(const std::vector<int64_t> &prevTokens) {
- tokens32.clear();
- tokens32.reserve(prevTokens.size());
- for (auto token : prevTokens) {
-    tokens32.push_back(static_cast<int32_t>(token));
- }
-  auto tensorSizes = {1, static_cast<int32_t>(tokens32.size())};
- return make_tensor_ptr(std::move(tensorSizes), std::move(tokens32));
-}
-
-std::shared_ptr<OwningArrayBuffer> WhisperStrategy::extractOutputToken(
- const executorch::aten::Tensor &decoderOutputTensor) const {
- const auto innerDim = decoderOutputTensor.size(1);
- const auto dictSize = decoderOutputTensor.size(2);
- auto outputNumel = decoderOutputTensor.numel();
-  auto dataPtr =
-      static_cast<const float *>(decoderOutputTensor.const_data_ptr()) +
- (innerDim - 1) * dictSize;
-
-  std::span<const float> modelOutput(dataPtr, outputNumel / innerDim);
- auto createBuffer = [](const auto &data, size_t size) {
-    auto buffer = std::make_shared<OwningArrayBuffer>(size);
- std::memcpy(buffer->data(), data, size);
- return buffer;
- };
- return createBuffer(modelOutput.data(), modelOutput.size_bytes());
-}
-
-} // namespace rnexecutorch::models::speech_to_text
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h
deleted file mode 100644
index 936be3c00..000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h
+++ /dev/null
@@ -1,25 +0,0 @@
-#pragma once
-
-#include "SpeechToTextStrategy.h"
-#include <span>
-#include <vector>
-
-namespace rnexecutorch::models::speech_to_text {
-
-class WhisperStrategy final : public SpeechToTextStrategy {
-public:
-  TensorPtr prepareAudioInput(std::span<float> waveform) override;
-
-  TensorPtr prepareTokenInput(const std::vector<int64_t> &prevTokens) override;
-
- std::string getDecoderMethod() const override { return "forward"; }
-
-  std::shared_ptr<OwningArrayBuffer> extractOutputToken(
-      const executorch::aten::Tensor &decoderOutputTensor) const override;
-
-private:
-  std::vector<float> preprocessedData;
-  std::vector<int32_t> tokens32;
-};
-
-} // namespace rnexecutorch::models::speech_to_text
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
new file mode 100644
index 000000000..5a56f2d7e
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
@@ -0,0 +1,307 @@
+#include <algorithm>
+#include <cmath>
+#include <numeric>
+#include <random>
+#include <sstream>
+#include <utility>
+
+#include "ASR.h"
+#include "executorch/extension/tensor/tensor_ptr.h"
+#include "rnexecutorch/data_processing/Numerical.h"
+#include "rnexecutorch/data_processing/dsp.h"
+#include "rnexecutorch/data_processing/gzip.h"
+
+namespace rnexecutorch::models::speech_to_text::asr {
+
+ASR::ASR(const models::BaseModel *encoder, const models::BaseModel *decoder,
+ const TokenizerModule *tokenizer)
+ : encoder(encoder), decoder(decoder), tokenizer(tokenizer),
+ startOfTranscriptionToken(
+ this->tokenizer->tokenToId("<|startoftranscript|>")),
+ endOfTranscriptionToken(this->tokenizer->tokenToId("<|endoftext|>")),
+ timestampBeginToken(this->tokenizer->tokenToId("<|0.00|>")) {}
+
+std::vector<int32_t>
+ASR::getInitialSequence(const DecodingOptions &options) const {
+  std::vector<int32_t> seq;
+ seq.push_back(this->startOfTranscriptionToken);
+
+ if (options.language.has_value()) {
+ int32_t langToken =
+ this->tokenizer->tokenToId("<|" + options.language.value() + "|>");
+ int32_t taskToken = this->tokenizer->tokenToId("<|transcribe|>");
+ seq.push_back(langToken);
+ seq.push_back(taskToken);
+ }
+
+ seq.push_back(this->timestampBeginToken);
+
+ return seq;
+}
+
+GenerationResult ASR::generate(std::span<float> waveform,
+ float temperature,
+ const DecodingOptions &options) const {
+  std::vector<float> encoderOutput = this->encode(waveform);
+
+  std::vector<int32_t> sequenceIds = this->getInitialSequence(options);
+  const size_t initialSequenceLength = sequenceIds.size();
+  std::vector<float> scores;
+
+ while (std::cmp_less_equal(sequenceIds.size(), ASR::kMaxDecodeLength)) {
+    std::vector<float> logits = this->decode(sequenceIds, encoderOutput);
+
+ // intentionally comparing float to float
+ // temperatures are predefined, so this is safe
+ if (temperature == 0.0f) {
+ numerical::softmax(logits);
+ } else {
+ numerical::softmaxWithTemperature(logits, temperature);
+ }
+
+    const std::vector<float> &probs = logits;
+
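+    // Temperature 0 selects the argmax token (greedy decoding); otherwise
+    // the next token is sampled from the softmax distribution.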
+ int32_t nextId;
+ float nextProb;
+
+ // intentionally comparing float to float
+ // temperatures are predefined, so this is safe
+ if (temperature == 0.0f) {
+ auto maxIt = std::ranges::max_element(probs);
+      nextId = static_cast<int32_t>(std::distance(probs.begin(), maxIt));
+ nextProb = *maxIt;
+ } else {
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
+ std::mt19937 gen((std::random_device{}()));
+ nextId = dist(gen);
+ nextProb = probs[nextId];
+ }
+
+ sequenceIds.push_back(nextId);
+ scores.push_back(nextProb);
+
+ if (nextId == this->endOfTranscriptionToken) {
+ break;
+ }
+ }
+
+  return {.tokens = std::vector<int32_t>(sequenceIds.cbegin() +
+                                             initialSequenceLength,
+                                         sequenceIds.cend()),
+ .scores = scores};
+}
+
+float ASR::getCompressionRatio(const std::string &text) const {
+ size_t compressedSize = gzip::deflateSize(text);
+  return static_cast<float>(text.size()) /
+         static_cast<float>(compressedSize);
+}
+
+std::vector<Segment>
+ASR::generateWithFallback(std::span<float> waveform,
+                          const DecodingOptions &options) const {
+  std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
+  std::vector<int32_t> bestTokens;
+
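+  // Temperature fallback (as in Whisper): retry at increasing temperatures
+  // until the transcription looks reliable, i.e. average log-probability
+  // >= -1.0 and gzip compression ratio < 2.4 (high ratios indicate
+  // degenerate, repetitive output).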
+ for (auto t : temperatures) {
+ auto [tokens, scores] = this->generate(waveform, t, options);
+
+ const float cumLogProb = std::transform_reduce(
+ scores.begin(), scores.end(), 0.0f, std::plus<>(),
+ [](float s) { return std::log(std::max(s, 1e-9f)); });
+
+    const float avgLogProb =
+        cumLogProb / static_cast<float>(tokens.size() + 1);
+ const std::string text = this->tokenizer->decode(tokens, true);
+ const float compressionRatio = this->getCompressionRatio(text);
+
+ if (avgLogProb >= -1.0f && compressionRatio < 2.4f) {
+ bestTokens = std::move(tokens);
+ break;
+ }
+ }
+
+ return this->calculateWordLevelTimestamps(bestTokens, waveform);
+}
+
+std::vector<Segment>
+ASR::calculateWordLevelTimestamps(std::span<int32_t> generatedTokens,
+                                  const std::span<float> waveform) const {
+ const size_t generatedTokensSize = generatedTokens.size();
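+  // Require a well-formed ending: a timestamp token followed by
+  // <|endoftext|>; otherwise the last segment cannot be anchored in time.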
+ if (generatedTokensSize < 2 ||
+ generatedTokens[generatedTokensSize - 1] !=
+ this->endOfTranscriptionToken ||
+ generatedTokens[generatedTokensSize - 2] < this->timestampBeginToken) {
+ return {};
+ }
+  std::vector<Segment> segments;
+  std::vector<int32_t> tokens;
+ int32_t prevTimestamp = this->timestampBeginToken;
+
+ for (size_t i = 0; i < generatedTokensSize; i++) {
+ if (generatedTokens[i] < this->timestampBeginToken) {
+ tokens.push_back(generatedTokens[i]);
+ }
+ if (i > 0 && generatedTokens[i - 1] >= this->timestampBeginToken &&
+ generatedTokens[i] >= this->timestampBeginToken) {
+ const int32_t start = prevTimestamp;
+ const int32_t end = generatedTokens[i - 1];
+ auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end);
+ if (words.size()) {
+ segments.emplace_back(std::move(words), 0.0);
+ }
+ tokens.clear();
+ prevTimestamp = generatedTokens[i];
+ }
+ }
+
+ const int32_t start = prevTimestamp;
+ const int32_t end = generatedTokens[generatedTokensSize - 2];
+ auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end);
+
+ if (words.size()) {
+ segments.emplace_back(std::move(words), 0.0);
+ }
+
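+  // If predicted timestamps overrun the actual audio, scale all word
+  // timings down proportionally to fit the waveform duration.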
+ float scalingFactor =
+      static_cast<float>(waveform.size()) /
+ (ASR::kSamplingRate * (end - this->timestampBeginToken) *
+ ASR::kTimePrecision);
+ if (scalingFactor < 1.0f) {
+ for (auto &seg : segments) {
+ for (auto &w : seg.words) {
+ w.start *= scalingFactor;
+ w.end *= scalingFactor;
+ }
+ }
+ }
+
+ return segments;
+}
+
+std::vector<Word>
+ASR::estimateWordLevelTimestampsLinear(std::span<int32_t> tokens,
+                                       int32_t start, int32_t end) const {
+  const std::vector<int32_t> tokensVec(tokens.begin(), tokens.end());
+ const std::string segmentText = this->tokenizer->decode(tokensVec, true);
+ std::istringstream iss(segmentText);
+  std::vector<std::string> wordsStr;
+ std::string word;
+ while (iss >> word) {
+ wordsStr.emplace_back(" ");
+ wordsStr.back().append(word);
+ }
+
+ size_t numChars = 0;
+ for (const auto &w : wordsStr) {
+ numChars += w.size();
+ }
+ const float duration = (end - start) * ASR::kTimePrecision;
+  const float timePerChar =
+      duration / static_cast<float>(std::max<size_t>(1, numChars));
+ const float startOffset = (start - timestampBeginToken) * ASR::kTimePrecision;
+
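+  // Linear estimate: spread the segment duration across words in
+  // proportion to their character counts (no per-token alignment).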
+  std::vector<Word> wordObjs;
+ wordObjs.reserve(wordsStr.size());
+ int32_t prevCharCount = 0;
+ for (auto &w : wordsStr) {
+    const auto wSize = static_cast<int32_t>(w.size());
+ const float wStart = startOffset + prevCharCount * timePerChar;
+ const float wEnd = wStart + timePerChar * wSize;
+ prevCharCount += wSize;
+ wordObjs.emplace_back(std::move(w), wStart, wEnd);
+ }
+
+ return wordObjs;
+}
+
+std::vector<Segment> ASR::transcribe(std::span<float> waveform,
+                                     const DecodingOptions &options) const {
+  int32_t seek = 0;
+  std::vector<Segment> results;
+
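+  // Iterate over kChunkSize-second windows (seek is in seconds); stop when
+  // the remaining audio is shorter than kMinChunkSamples.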
+ while (std::cmp_less(seek * ASR::kSamplingRate, waveform.size())) {
+ int32_t start = seek * ASR::kSamplingRate;
+    const auto end = std::min<size_t>(
+        (seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
+    std::span<float> chunk = waveform.subspan(start, end - start);
+
+ if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
+ break;
+ }
+
+ std::vector