Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .cspell-wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,10 @@ timesteps
Timesteps
denoises
denoise
denoising
denoising
threadpool
chrono
setpriority
errno
ifdef
elif
6 changes: 5 additions & 1 deletion apps/llm/app.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@
"foregroundImage": "./assets/icons/adaptive-icon.png",
"backgroundColor": "#ffffff"
},
"package": "com.anonymous.llm"
"package": "com.anonymous.llm",
"permissions": [
"android.permission.READ_CALENDAR",
"android.permission.WRITE_CALENDAR"
]
},
"web": {
"favicon": "./assets/icons/favicon.png"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,8 @@
#include <rnexecutorch/models/speech_to_text/SpeechToText.h>
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
#include <rnexecutorch/models/vertical_ocr/VerticalOCR.h>

#if defined(__ANDROID__) && defined(__aarch64__)
#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <executorch/extension/threadpool/threadpool.h>
#include <rnexecutorch/Log.h>
#endif
#include <rnexecutorch/threads/GlobalThreadPool.h>
#include <rnexecutorch/threads/utils/ThreadUtils.h>

namespace rnexecutorch {

Expand Down Expand Up @@ -97,21 +93,8 @@ void RnExecutorchInstaller::injectJSIBindings(
RnExecutorchInstaller::loadModel<models::speech_to_text::SpeechToText>(
jsiRuntime, jsCallInvoker, "loadSpeechToText"));

#if defined(__ANDROID__) && defined(__aarch64__)
auto num_of_perf_cores =
::executorch::extension::cpuinfo::get_num_performant_cores();
log(LOG_LEVEL::Info, "Detected ", num_of_perf_cores, " performant cores");
// setting num_of_cores to floor(num_of_perf_cores / 2) + 1) because depending
// on cpu arch as when possible we want to leave at least 2 performant cores
// for other tasks (setting more actually results in drop of performance). For
// older devices (i.e. samsung s22) resolves to 3 cores, and for newer ones
// (like OnePlus 12) resolves to 4, which when benchmarked gives highest
// throughput.
auto num_of_cores = static_cast<uint32_t>(num_of_perf_cores / 2) + 1;
::executorch::extension::threadpool::get_threadpool()
->_unsafe_reset_threadpool(num_of_cores);
log(LOG_LEVEL::Info, "Configuring xnnpack for ", num_of_cores, " threads");
#endif
threads::utils::unsafeSetupThreadPool();
threads::GlobalThreadPool::initialize();
}

} // namespace rnexecutorch
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <rnexecutorch/models/speech_to_text/SpeechToText.h>
#include <rnexecutorch/models/text_to_image/TextToImage.h>
#include <rnexecutorch/models/vertical_ocr/VerticalOCR.h>
#include <rnexecutorch/threads/GlobalThreadPool.h>

namespace rnexecutorch {

Expand Down Expand Up @@ -201,58 +202,60 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
// We need to dispatch a thread if we want the function to be
// asynchronous. In this thread all accesses to jsi::Runtime need to
// be done via the callInvoker.
std::thread([this, promise,
argsConverted = std::move(argsConverted)]() {
try {
if constexpr (std::is_void_v<decltype(std::apply(
std::bind_front(FnPtr, model),
argsConverted))>) {
// For void functions, just call the function and resolve with
// undefined
std::apply(std::bind_front(FnPtr, model),
std::move(argsConverted));
callInvoker->invokeAsync([promise](jsi::Runtime &runtime) {
promise->resolve(jsi::Value::undefined());
});
} else {
// For non-void functions, capture the result and convert it
auto result = std::apply(std::bind_front(FnPtr, model),
std::move(argsConverted));
// The result is copied. It should either be quickly copiable,
// or passed with a shared_ptr.
callInvoker->invokeAsync(
[promise, result](jsi::Runtime &runtime) {
promise->resolve(jsi_conversion::getJsiValue(
std::move(result), runtime));
});
}
} catch (const std::runtime_error &e) {
// This catch should be merged with the next two
// (std::runtime_error and jsi::JSError inherit from
// std::exception) HOWEVER react native has broken RTTI which
// breaks proper exception type checking. Remove when the
// following change is present in our version:
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
callInvoker->invokeAsync([e = std::move(e), promise]() {
promise->reject(e.what());
threads::GlobalThreadPool::detach(
[this, promise, argsConverted = std::move(argsConverted)]() {
try {
if constexpr (std::is_void_v<decltype(std::apply(
std::bind_front(FnPtr, model),
argsConverted))>) {
// For void functions, just call the function and resolve
// with undefined
std::apply(std::bind_front(FnPtr, model),
std::move(argsConverted));
callInvoker->invokeAsync(
[promise](jsi::Runtime &runtime) {
promise->resolve(jsi::Value::undefined());
});
} else {
// For non-void functions, capture the result and convert
// it
auto result = std::apply(std::bind_front(FnPtr, model),
std::move(argsConverted));
// The result is copied. It should either be quickly
// copiable, or passed with a shared_ptr.
callInvoker->invokeAsync(
[promise, result](jsi::Runtime &runtime) {
promise->resolve(jsi_conversion::getJsiValue(
std::move(result), runtime));
});
}
} catch (const std::runtime_error &e) {
// This catch should be merged with the next two
// (std::runtime_error and jsi::JSError inherit
// std::exception) HOWEVER react native has broken RTTI
// which breaks proper exception type checking. Remove when
// the following change is present in our version:
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
callInvoker->invokeAsync([e = std::move(e), promise]() {
promise->reject(e.what());
});
return;
} catch (const jsi::JSError &e) {
callInvoker->invokeAsync([e = std::move(e), promise]() {
promise->reject(e.what());
});
return;
} catch (const std::exception &e) {
callInvoker->invokeAsync([e = std::move(e), promise]() {
promise->reject(e.what());
});
return;
} catch (...) {
callInvoker->invokeAsync(
[promise]() { promise->reject("Unknown error"); });
return;
}
});
return;
} catch (const jsi::JSError &e) {
callInvoker->invokeAsync([e = std::move(e), promise]() {
promise->reject(e.what());
});
return;
} catch (const std::exception &e) {
callInvoker->invokeAsync([e = std::move(e), promise]() {
promise->reject(e.what());
});
return;
} catch (...) {
callInvoker->invokeAsync(
[promise]() { promise->reject("Unknown error"); });
return;
}
}).detach();
} catch (...) {
promise->reject("Couldn't parse JS arguments in a native function");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// GlobalThreadPool.h
#pragma once

#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <memory>
#include <mutex>
#include <optional>
#include <rnexecutorch/Log.h>
#include <rnexecutorch/threads/HighPerformanceThreadPool.h>

namespace rnexecutorch::threads {

class GlobalThreadPool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider removing this class and leaving all functions inside the threads::global_thread_pool namespace

Copy link
Member

@msluszniak msluszniak Sep 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HighPerformanceThreadPool should be probably implemented similarly to this one:

class Singleton {
public:
    Singleton(const Singleton&) = delete;
    Singleton& operator=(const Singleton&) = delete;
    Singleton(Singleton&&) = delete;
    Singleton& operator=(Singleton&&) = delete;

    static Singleton& getInstance() {
        static std::once_flag initFlag;
        if (!instance) { 
            throw std::logic_error("Instance not initialized.");
        }
        return *instance;
    }

    static void initialize(int arg) {
        static std::once_flag initFlag;
        std::call_once(initFlag, [arg]() {
            instance = std::make_unique<Singleton>(arg);
        });
    }

private:
    int value;
    Singleton(int val) : value(val) {}

    static std::unique_ptr<Singleton> instance;
};

std::unique_ptr<Singleton> Singleton::instance = nullptr;

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you meant GlobalThreadPool, right? HighPerformanceThreadPool was never supposed to be a singleton

Copy link
Member

@msluszniak msluszniak Sep 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, HighPerformanceThreadPool is used only in the context of GlobalThreadPool. I guess Jakub meant that GlobalThreadPool is only a wrapper on a class that might be a singleton itself. If in the future there might be any other way to use HighPerformanceThreadPool the current design will make more sense to me then.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was my idea when creating this, HighPerformanceThreadPool is not meant to be a singleton, but GlobalThreadPool is. Also semantically it makes more sense to me to keep GlobalThreadPool as a class instead of a namespace even though it doesn't have any nonstatic members.

public:
GlobalThreadPool() = delete;
GlobalThreadPool(const GlobalThreadPool &) = delete;
GlobalThreadPool &operator=(const GlobalThreadPool &) = delete;
GlobalThreadPool(GlobalThreadPool &&) = delete;
GlobalThreadPool &operator=(GlobalThreadPool &&) = delete;

static HighPerformanceThreadPool &get() {
if (!instance) {
initialize();
}
return *instance;
}

static void initialize(std::optional<uint32_t> numThreads = std::nullopt,
ThreadConfig config = {}) {
std::call_once(initFlag, [&numThreads, config]() {
if (!numThreads) {
numThreads =
::executorch::extension::cpuinfo::get_num_performant_cores();
}

log(rnexecutorch::LOG_LEVEL::Info, "Initializing global thread pool with",
numThreads, "threads");
instance = std::make_unique<HighPerformanceThreadPool>(numThreads.value(),
config);
});
}

// Convenience methods that mirror std::thread interface
template <typename Func, typename... Args>
static auto async(Func &&func, Args &&...args) {
return get().submit(std::forward<Func>(func), std::forward<Args>(args)...);
}

template <typename Func, typename... Args>
static auto async_high_priority(Func &&func, Args &&...args) {
return get().submitWithPriority(Priority::HIGH, std::forward<Func>(func),
std::forward<Args>(args)...);
}

// Fire and forget (like std::thread{}.detach())
template <typename Func, typename... Args>
static void detach(Func &&func, Args &&...args) {
get().submitDetached(std::forward<Func>(func), std::forward<Args>(args)...);
}

// Execute and wait (like std::thread{}.join())
template <typename Func, typename... Args>
static auto execute(Func &&func, Args &&...args) {
return get().execute(std::forward<Func>(func), std::forward<Args>(args)...);
}

static void shutdown() {
if (instance) {
instance->shutdown();
instance.reset();
}
}

private:
inline static std::unique_ptr<HighPerformanceThreadPool> instance;
inline static std::once_flag initFlag;
};

} // namespace rnexecutorch::threads
Loading