Skip to content

Commit 31ebea6

Browse files
mkopcins (Mateusz Kopciński)
authored and committed
feat: Implemented GlobalThreadPool for async functions (#603)
## Description Added singleton class GlobalThreadPool for single threadpool management so that we don't have to spawn new threads for each async function and instead we can delegate functions to the threadpool. Also added pthreadpool and cpuinfo binaries for iOS to allow for XNNPack threadpool configuration just like on Android ### Introduces a breaking change? - [ ] Yes - [x] No ### Type of change - [ ] Bug fix (change which fixes an issue) - [x] New feature (change which adds functionality) - [ ] Documentation update (improves or adds clarity to existing documentation) - [ ] Other (chores, tests, code style improvements etc.) ### Tested on - [x] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. --> --------- Co-authored-by: Mateusz Kopciński <[email protected]>
1 parent c89e9d0 commit 31ebea6

File tree

13 files changed

+553
-78
lines changed

13 files changed

+553
-78
lines changed

.cspell-wordlist.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,10 @@ timesteps
7373
Timesteps
7474
denoises
7575
denoise
76-
denoising
76+
denoising
77+
threadpool
78+
chrono
79+
setpriority
80+
errno
81+
ifdef
82+
elif

apps/llm/app.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@
5858
"foregroundImage": "./assets/icons/adaptive-icon.png",
5959
"backgroundColor": "#ffffff"
6060
},
61-
"package": "com.anonymous.llm"
61+
"package": "com.anonymous.llm",
62+
"permissions": [
63+
"android.permission.READ_CALENDAR",
64+
"android.permission.WRITE_CALENDAR"
65+
]
6266
},
6367
"web": {
6468
"favicon": "./assets/icons/favicon.png"

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,8 @@
1313
#include <rnexecutorch/models/speech_to_text/SpeechToText.h>
1414
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
1515
#include <rnexecutorch/models/vertical_ocr/VerticalOCR.h>
16-
17-
#if defined(__ANDROID__) && defined(__aarch64__)
18-
#include <executorch/extension/threadpool/cpuinfo_utils.h>
19-
#include <executorch/extension/threadpool/threadpool.h>
20-
#include <rnexecutorch/Log.h>
21-
#endif
16+
#include <rnexecutorch/threads/GlobalThreadPool.h>
17+
#include <rnexecutorch/threads/utils/ThreadUtils.h>
2218

2319
namespace rnexecutorch {
2420

@@ -97,21 +93,8 @@ void RnExecutorchInstaller::injectJSIBindings(
9793
RnExecutorchInstaller::loadModel<models::speech_to_text::SpeechToText>(
9894
jsiRuntime, jsCallInvoker, "loadSpeechToText"));
9995

100-
#if defined(__ANDROID__) && defined(__aarch64__)
101-
auto num_of_perf_cores =
102-
::executorch::extension::cpuinfo::get_num_performant_cores();
103-
log(LOG_LEVEL::Info, "Detected ", num_of_perf_cores, " performant cores");
104-
// setting num_of_cores to floor(num_of_perf_cores / 2) + 1) because depending
105-
// on cpu arch as when possible we want to leave at least 2 performant cores
106-
// for other tasks (setting more actually results in drop of performance). For
107-
// older devices (i.e. samsung s22) resolves to 3 cores, and for newer ones
108-
// (like OnePlus 12) resolves to 4, which when benchmarked gives highest
109-
// throughput.
110-
auto num_of_cores = static_cast<uint32_t>(num_of_perf_cores / 2) + 1;
111-
::executorch::extension::threadpool::get_threadpool()
112-
->_unsafe_reset_threadpool(num_of_cores);
113-
log(LOG_LEVEL::Info, "Configuring xnnpack for ", num_of_cores, " threads");
114-
#endif
96+
threads::utils::unsafeSetupThreadPool();
97+
threads::GlobalThreadPool::initialize();
11598
}
11699

117100
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <rnexecutorch/models/speech_to_text/SpeechToText.h>
2222
#include <rnexecutorch/models/text_to_image/TextToImage.h>
2323
#include <rnexecutorch/models/vertical_ocr/VerticalOCR.h>
24+
#include <rnexecutorch/threads/GlobalThreadPool.h>
2425

2526
namespace rnexecutorch {
2627

@@ -201,58 +202,60 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
201202
// We need to dispatch a thread if we want the function to be
202203
// asynchronous. In this thread all accesses to jsi::Runtime need to
203204
// be done via the callInvoker.
204-
std::thread([this, promise,
205-
argsConverted = std::move(argsConverted)]() {
206-
try {
207-
if constexpr (std::is_void_v<decltype(std::apply(
208-
std::bind_front(FnPtr, model),
209-
argsConverted))>) {
210-
// For void functions, just call the function and resolve with
211-
// undefined
212-
std::apply(std::bind_front(FnPtr, model),
213-
std::move(argsConverted));
214-
callInvoker->invokeAsync([promise](jsi::Runtime &runtime) {
215-
promise->resolve(jsi::Value::undefined());
216-
});
217-
} else {
218-
// For non-void functions, capture the result and convert it
219-
auto result = std::apply(std::bind_front(FnPtr, model),
220-
std::move(argsConverted));
221-
// The result is copied. It should either be quickly copiable,
222-
// or passed with a shared_ptr.
223-
callInvoker->invokeAsync(
224-
[promise, result](jsi::Runtime &runtime) {
225-
promise->resolve(jsi_conversion::getJsiValue(
226-
std::move(result), runtime));
227-
});
228-
}
229-
} catch (const std::runtime_error &e) {
230-
// This catch should be merged with the next two
231-
// (std::runtime_error and jsi::JSError inherits from
232-
// std::exception) HOWEVER react native has broken RTTI which
233-
// breaks proper exception type checking. Remove when the
234-
// following change is present in our version:
235-
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
236-
callInvoker->invokeAsync([e = std::move(e), promise]() {
237-
promise->reject(e.what());
205+
threads::GlobalThreadPool::detach(
206+
[this, promise, argsConverted = std::move(argsConverted)]() {
207+
try {
208+
if constexpr (std::is_void_v<decltype(std::apply(
209+
std::bind_front(FnPtr, model),
210+
argsConverted))>) {
211+
// For void functions, just call the function and resolve
212+
// with undefined
213+
std::apply(std::bind_front(FnPtr, model),
214+
std::move(argsConverted));
215+
callInvoker->invokeAsync(
216+
[promise](jsi::Runtime &runtime) {
217+
promise->resolve(jsi::Value::undefined());
218+
});
219+
} else {
220+
// For non-void functions, capture the result and convert
221+
// it
222+
auto result = std::apply(std::bind_front(FnPtr, model),
223+
std::move(argsConverted));
224+
// The result is copied. It should either be quickly
225+
// copiable, or passed with a shared_ptr.
226+
callInvoker->invokeAsync(
227+
[promise, result](jsi::Runtime &runtime) {
228+
promise->resolve(jsi_conversion::getJsiValue(
229+
std::move(result), runtime));
230+
});
231+
}
232+
} catch (const std::runtime_error &e) {
233+
// This catch should be merged with the next two
234+
// (std::runtime_error and jsi::JSError inherits from
235+
// std::exception) HOWEVER react native has broken RTTI
236+
// which breaks proper exception type checking. Remove when
237+
// the following change is present in our version:
238+
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
239+
callInvoker->invokeAsync([e = std::move(e), promise]() {
240+
promise->reject(e.what());
241+
});
242+
return;
243+
} catch (const jsi::JSError &e) {
244+
callInvoker->invokeAsync([e = std::move(e), promise]() {
245+
promise->reject(e.what());
246+
});
247+
return;
248+
} catch (const std::exception &e) {
249+
callInvoker->invokeAsync([e = std::move(e), promise]() {
250+
promise->reject(e.what());
251+
});
252+
return;
253+
} catch (...) {
254+
callInvoker->invokeAsync(
255+
[promise]() { promise->reject("Unknown error"); });
256+
return;
257+
}
238258
});
239-
return;
240-
} catch (const jsi::JSError &e) {
241-
callInvoker->invokeAsync([e = std::move(e), promise]() {
242-
promise->reject(e.what());
243-
});
244-
return;
245-
} catch (const std::exception &e) {
246-
callInvoker->invokeAsync([e = std::move(e), promise]() {
247-
promise->reject(e.what());
248-
});
249-
return;
250-
} catch (...) {
251-
callInvoker->invokeAsync(
252-
[promise]() { promise->reject("Unknown error"); });
253-
return;
254-
}
255-
}).detach();
256259
} catch (...) {
257260
promise->reject("Couldn't parse JS arguments in a native function");
258261
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// GlobalThreadPool.h
2+
#pragma once
3+
4+
#include <executorch/extension/threadpool/cpuinfo_utils.h>
5+
#include <memory>
6+
#include <mutex>
7+
#include <optional>
8+
#include <rnexecutorch/Log.h>
9+
#include <rnexecutorch/threads/HighPerformanceThreadPool.h>
10+
11+
namespace rnexecutorch::threads {
12+
13+
class GlobalThreadPool {
14+
public:
15+
GlobalThreadPool() = delete;
16+
GlobalThreadPool(const GlobalThreadPool &) = delete;
17+
GlobalThreadPool &operator=(const GlobalThreadPool &) = delete;
18+
GlobalThreadPool(GlobalThreadPool &&) = delete;
19+
GlobalThreadPool &operator=(GlobalThreadPool &&) = delete;
20+
21+
static HighPerformanceThreadPool &get() {
22+
if (!instance) {
23+
initialize();
24+
}
25+
return *instance;
26+
}
27+
28+
static void initialize(std::optional<uint32_t> numThreads = std::nullopt,
29+
ThreadConfig config = {}) {
30+
std::call_once(initFlag, [&numThreads, config]() {
31+
if (!numThreads) {
32+
numThreads =
33+
::executorch::extension::cpuinfo::get_num_performant_cores();
34+
}
35+
36+
log(rnexecutorch::LOG_LEVEL::Info, "Initializing global thread pool with",
37+
numThreads, "threads");
38+
instance = std::make_unique<HighPerformanceThreadPool>(numThreads.value(),
39+
config);
40+
});
41+
}
42+
43+
// Convenience methods that mirror std::thread interface
44+
template <typename Func, typename... Args>
45+
static auto async(Func &&func, Args &&...args) {
46+
return get().submit(std::forward<Func>(func), std::forward<Args>(args)...);
47+
}
48+
49+
template <typename Func, typename... Args>
50+
static auto async_high_priority(Func &&func, Args &&...args) {
51+
return get().submitWithPriority(Priority::HIGH, std::forward<Func>(func),
52+
std::forward<Args>(args)...);
53+
}
54+
55+
// Fire and forget (like std::thread{}.detach())
56+
template <typename Func, typename... Args>
57+
static void detach(Func &&func, Args &&...args) {
58+
get().submitDetached(std::forward<Func>(func), std::forward<Args>(args)...);
59+
}
60+
61+
// Execute and wait (like std::thread{}.join())
62+
template <typename Func, typename... Args>
63+
static auto execute(Func &&func, Args &&...args) {
64+
return get().execute(std::forward<Func>(func), std::forward<Args>(args)...);
65+
}
66+
67+
static void shutdown() {
68+
if (instance) {
69+
instance->shutdown();
70+
instance.reset();
71+
}
72+
}
73+
74+
private:
75+
inline static std::unique_ptr<HighPerformanceThreadPool> instance;
76+
inline static std::once_flag initFlag;
77+
};
78+
79+
} // namespace rnexecutorch::threads

0 commit comments

Comments
 (0)