Skip to content

Commit 4b113f7

Browse files
mkopcinsMateusz Kopciński
authored and committed
feat: Implemented GlobalThreadPool for async functions (#603)
Added singleton class GlobalThreadPool for single threadpool management so that we don't have to spawn new threads for each async function and instead we can delegate functions to the threadpool. Also added pthreadpool and cpuinfo binaries for iOS to allow for XNNPack threadpool configuration just like on Android - [ ] Yes - [x] No - [ ] Bug fix (change which fixes an issue) - [x] New feature (change which adds functionality) - [ ] Documentation update (improves or adds clarity to existing documentation) - [ ] Other (chores, tests, code style improvements etc.) - [x] iOS - [x] Android <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> <!-- Add screenshots here, if applicable --> <!-- Link related issues here using #issue-number --> - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. --> --------- Co-authored-by: Mateusz Kopciński <[email protected]>
1 parent 731fd77 commit 4b113f7

File tree

13 files changed

+560
-77
lines changed

13 files changed

+560
-77
lines changed

.cspell-wordlist.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,17 @@ softmax
6262
logit
6363
logits
6464
probs
65+
unet
66+
Unet
67+
VPRED
68+
timesteps
69+
Timesteps
70+
denoises
71+
denoise
72+
denoising
73+
threadpool
74+
chrono
75+
setpriority
76+
errno
77+
ifdef
78+
elif

apps/llm/app.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@
5858
"foregroundImage": "./assets/icons/adaptive-icon.png",
5959
"backgroundColor": "#ffffff"
6060
},
61-
"package": "com.anonymous.llm"
61+
"package": "com.anonymous.llm",
62+
"permissions": [
63+
"android.permission.READ_CALENDAR",
64+
"android.permission.WRITE_CALENDAR"
65+
]
6266
},
6367
"web": {
6468
"favicon": "./assets/icons/favicon.png"

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,8 @@
1212
#include <rnexecutorch/models/speech_to_text/SpeechToText.h>
1313
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
1414
#include <rnexecutorch/models/vertical_ocr/VerticalOCR.h>
15-
16-
#if defined(__ANDROID__) && defined(__aarch64__)
17-
#include <executorch/extension/threadpool/cpuinfo_utils.h>
18-
#include <executorch/extension/threadpool/threadpool.h>
19-
#include <rnexecutorch/Log.h>
20-
#endif
15+
#include <rnexecutorch/threads/GlobalThreadPool.h>
16+
#include <rnexecutorch/threads/utils/ThreadUtils.h>
2117

2218
namespace rnexecutorch {
2319

@@ -92,21 +88,8 @@ void RnExecutorchInstaller::injectJSIBindings(
9288
RnExecutorchInstaller::loadModel<models::speech_to_text::SpeechToText>(
9389
jsiRuntime, jsCallInvoker, "loadSpeechToText"));
9490

95-
#if defined(__ANDROID__) && defined(__aarch64__)
96-
auto num_of_perf_cores =
97-
::executorch::extension::cpuinfo::get_num_performant_cores();
98-
log(LOG_LEVEL::Info, "Detected ", num_of_perf_cores, " performant cores");
99-
// setting num_of_cores to floor(num_of_perf_cores / 2) + 1) because depending
100-
// on cpu arch as when possible we want to leave at least 2 performant cores
101-
// for other tasks (setting more actually results in drop of performance). For
102-
// older devices (i.e. samsung s22) resolves to 3 cores, and for newer ones
103-
// (like OnePlus 12) resolves to 4, which when benchmarked gives highest
104-
// throughput.
105-
auto num_of_cores = static_cast<uint32_t>(num_of_perf_cores / 2) + 1;
106-
::executorch::extension::threadpool::get_threadpool()
107-
->_unsafe_reset_threadpool(num_of_cores);
108-
log(LOG_LEVEL::Info, "Configuring xnnpack for ", num_of_cores, " threads");
109-
#endif
91+
threads::utils::unsafeSetupThreadPool();
92+
threads::GlobalThreadPool::initialize();
11093
}
11194

11295
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <rnexecutorch/models/llm/LLM.h>
2020
#include <rnexecutorch/models/ocr/OCR.h>
2121
#include <rnexecutorch/models/vertical_ocr/VerticalOCR.h>
22+
#include <rnexecutorch/threads/GlobalThreadPool.h>
2223

2324
namespace rnexecutorch {
2425

@@ -195,58 +196,60 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
195196
// We need to dispatch a thread if we want the function to be
196197
// asynchronous. In this thread all accesses to jsi::Runtime need to
197198
// be done via the callInvoker.
198-
std::thread([this, promise,
199-
argsConverted = std::move(argsConverted)]() {
200-
try {
201-
if constexpr (std::is_void_v<decltype(std::apply(
202-
std::bind_front(FnPtr, model),
203-
argsConverted))>) {
204-
// For void functions, just call the function and resolve with
205-
// undefined
206-
std::apply(std::bind_front(FnPtr, model),
207-
std::move(argsConverted));
208-
callInvoker->invokeAsync([promise](jsi::Runtime &runtime) {
209-
promise->resolve(jsi::Value::undefined());
210-
});
211-
} else {
212-
// For non-void functions, capture the result and convert it
213-
auto result = std::apply(std::bind_front(FnPtr, model),
214-
std::move(argsConverted));
215-
// The result is copied. It should either be quickly copiable,
216-
// or passed with a shared_ptr.
217-
callInvoker->invokeAsync(
218-
[promise, result](jsi::Runtime &runtime) {
219-
promise->resolve(jsi_conversion::getJsiValue(
220-
std::move(result), runtime));
221-
});
222-
}
223-
} catch (const std::runtime_error &e) {
224-
// This catch should be merged with the next two
225-
// (std::runtime_error and jsi::JSError inherits from
226-
// std::exception) HOWEVER react native has broken RTTI which
227-
// breaks proper exception type checking. Remove when the
228-
// following change is present in our version:
229-
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
230-
callInvoker->invokeAsync([e = std::move(e), promise]() {
231-
promise->reject(e.what());
199+
threads::GlobalThreadPool::detach(
200+
[this, promise, argsConverted = std::move(argsConverted)]() {
201+
try {
202+
if constexpr (std::is_void_v<decltype(std::apply(
203+
std::bind_front(FnPtr, model),
204+
argsConverted))>) {
205+
// For void functions, just call the function and resolve
206+
// with undefined
207+
std::apply(std::bind_front(FnPtr, model),
208+
std::move(argsConverted));
209+
callInvoker->invokeAsync(
210+
[promise](jsi::Runtime &runtime) {
211+
promise->resolve(jsi::Value::undefined());
212+
});
213+
} else {
214+
// For non-void functions, capture the result and convert
215+
// it
216+
auto result = std::apply(std::bind_front(FnPtr, model),
217+
std::move(argsConverted));
218+
// The result is copied. It should either be quickly
219+
// copiable, or passed with a shared_ptr.
220+
callInvoker->invokeAsync(
221+
[promise, result](jsi::Runtime &runtime) {
222+
promise->resolve(jsi_conversion::getJsiValue(
223+
std::move(result), runtime));
224+
});
225+
}
226+
} catch (const std::runtime_error &e) {
227+
// This catch should be merged with the next two
228+
// (std::runtime_error and jsi::JSError inherits from
229+
// std::exception) HOWEVER react native has broken RTTI
230+
// which breaks proper exception type checking. Remove when
231+
// the following change is present in our version:
232+
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
233+
callInvoker->invokeAsync([e = std::move(e), promise]() {
234+
promise->reject(e.what());
235+
});
236+
return;
237+
} catch (const jsi::JSError &e) {
238+
callInvoker->invokeAsync([e = std::move(e), promise]() {
239+
promise->reject(e.what());
240+
});
241+
return;
242+
} catch (const std::exception &e) {
243+
callInvoker->invokeAsync([e = std::move(e), promise]() {
244+
promise->reject(e.what());
245+
});
246+
return;
247+
} catch (...) {
248+
callInvoker->invokeAsync(
249+
[promise]() { promise->reject("Unknown error"); });
250+
return;
251+
}
232252
});
233-
return;
234-
} catch (const jsi::JSError &e) {
235-
callInvoker->invokeAsync([e = std::move(e), promise]() {
236-
promise->reject(e.what());
237-
});
238-
return;
239-
} catch (const std::exception &e) {
240-
callInvoker->invokeAsync([e = std::move(e), promise]() {
241-
promise->reject(e.what());
242-
});
243-
return;
244-
} catch (...) {
245-
callInvoker->invokeAsync(
246-
[promise]() { promise->reject("Unknown error"); });
247-
return;
248-
}
249-
}).detach();
250253
} catch (...) {
251254
promise->reject("Couldn't parse JS arguments in a native function");
252255
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// GlobalThreadPool.h
2+
#pragma once
3+
4+
#include <executorch/extension/threadpool/cpuinfo_utils.h>
5+
#include <memory>
6+
#include <mutex>
7+
#include <optional>
8+
#include <rnexecutorch/Log.h>
9+
#include <rnexecutorch/threads/HighPerformanceThreadPool.h>
10+
11+
namespace rnexecutorch::threads {
12+
13+
class GlobalThreadPool {
14+
public:
15+
GlobalThreadPool() = delete;
16+
GlobalThreadPool(const GlobalThreadPool &) = delete;
17+
GlobalThreadPool &operator=(const GlobalThreadPool &) = delete;
18+
GlobalThreadPool(GlobalThreadPool &&) = delete;
19+
GlobalThreadPool &operator=(GlobalThreadPool &&) = delete;
20+
21+
static HighPerformanceThreadPool &get() {
22+
if (!instance) {
23+
initialize();
24+
}
25+
return *instance;
26+
}
27+
28+
static void initialize(std::optional<uint32_t> numThreads = std::nullopt,
29+
ThreadConfig config = {}) {
30+
std::call_once(initFlag, [&numThreads, config]() {
31+
if (!numThreads) {
32+
numThreads =
33+
::executorch::extension::cpuinfo::get_num_performant_cores();
34+
}
35+
36+
log(rnexecutorch::LOG_LEVEL::Info, "Initializing global thread pool with",
37+
numThreads, "threads");
38+
instance = std::make_unique<HighPerformanceThreadPool>(numThreads.value(),
39+
config);
40+
});
41+
}
42+
43+
// Convenience methods that mirror std::thread interface
44+
template <typename Func, typename... Args>
45+
static auto async(Func &&func, Args &&...args) {
46+
return get().submit(std::forward<Func>(func), std::forward<Args>(args)...);
47+
}
48+
49+
template <typename Func, typename... Args>
50+
static auto async_high_priority(Func &&func, Args &&...args) {
51+
return get().submitWithPriority(Priority::HIGH, std::forward<Func>(func),
52+
std::forward<Args>(args)...);
53+
}
54+
55+
// Fire and forget (like std::thread{}.detach())
56+
template <typename Func, typename... Args>
57+
static void detach(Func &&func, Args &&...args) {
58+
get().submitDetached(std::forward<Func>(func), std::forward<Args>(args)...);
59+
}
60+
61+
// Execute and wait (like std::thread{}.join())
62+
template <typename Func, typename... Args>
63+
static auto execute(Func &&func, Args &&...args) {
64+
return get().execute(std::forward<Func>(func), std::forward<Args>(args)...);
65+
}
66+
67+
static void shutdown() {
68+
if (instance) {
69+
instance->shutdown();
70+
instance.reset();
71+
}
72+
}
73+
74+
private:
75+
inline static std::unique_ptr<HighPerformanceThreadPool> instance;
76+
inline static std::once_flag initFlag;
77+
};
78+
79+
} // namespace rnexecutorch::threads

0 commit comments

Comments
 (0)