Skip to content

Commit 1fa6d8f

Browse files
souptcchentaMS
andauthored
support loading external execution provider from python frontend (#7332)
* initial dynamic load example * support load EP in the provider options * support dynamic load EP in orttrainer * split the provider interface; fix comments in pr * remove experiment code * add test * remove useless file * add test model file;fix linux brewak * fix linux build and missing file * fix python build * fix python build * fix python binding * fix python test * fix runtime path for posix env * exclude the shared library from minimal build * fix comments in pr; * seperate the provider shared lib loading * excluded from minimal / macos / ios build * skip copy the provider shared lib for minimal build and mac os * fix macos build * exclude the test for macos build * exclude from andorid build * exclude from web assembly build * enable the invalid ep test Co-authored-by: Cheng Tang <[email protected]>
1 parent 75e054c commit 1fa6d8f

25 files changed

+427
-28
lines changed

cmake/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,8 +1494,6 @@ foreach(provider_name ${ONNXRUNTIME_PROVIDER_NAMES})
14941494
endif()
14951495
endforeach()
14961496

1497-
1498-
14991497
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
15001498
list(APPEND onnxruntime_EXTERNAL_LIBRARIES log)
15011499
endif()

cmake/onnxruntime_providers.cmake

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -354,10 +354,13 @@ if (onnxruntime_USE_CUDA)
354354
endif()
355355
endif()
356356

357-
if (onnxruntime_USE_TENSORRT OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO)
357+
if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD
358+
AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS"
359+
AND NOT (CMAKE_SYSTEM_NAME STREQUAL "Android")
360+
AND NOT onnxruntime_BUILD_WEBASSEMBLY)
358361
file(GLOB onnxruntime_providers_shared_cc_srcs CONFIGURE_DEPENDS
359-
"${ONNXRUNTIME_ROOT}/core/providers/shared/*.h"
360-
"${ONNXRUNTIME_ROOT}/core/providers/shared/*.cc"
362+
"${ONNXRUNTIME_ROOT}/core/providers/shared/*.h"
363+
"${ONNXRUNTIME_ROOT}/core/providers/shared/*.cc"
361364
)
362365

363366
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_shared_cc_srcs})
@@ -366,13 +369,13 @@ if (onnxruntime_USE_TENSORRT OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO
366369
set_target_properties(onnxruntime_providers_shared PROPERTIES LINKER_LANGUAGE CXX)
367370

368371
if(APPLE)
369-
set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/shared/exported_symbols.lst")
372+
set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/shared/exported_symbols.lst")
370373
elseif(UNIX)
371-
set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/shared/version_script.lds -Xlinker --gc-sections")
374+
set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/shared/version_script.lds -Xlinker --gc-sections")
372375
elseif(WIN32)
373-
set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/shared/symbols.def")
376+
set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/shared/symbols.def")
374377
else()
375-
message(FATAL_ERROR "onnxruntime_providers_shared unknown platform, need to specify shared library exports for it")
378+
message(FATAL_ERROR "onnxruntime_providers_shared unknown platform, need to specify shared library exports for it")
376379
endif()
377380

378381
install(TARGETS onnxruntime_providers_shared

cmake/onnxruntime_python.cmake

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,18 @@ add_custom_command(
349349
$<TARGET_FILE_DIR:${build_output_target}>
350350
)
351351

352+
if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD
353+
AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS"
354+
AND NOT (CMAKE_SYSTEM_NAME STREQUAL "Android")
355+
AND NOT onnxruntime_BUILD_WEBASSEMBLY)
356+
add_custom_command(
357+
TARGET onnxruntime_pybind11_state POST_BUILD
358+
COMMAND ${CMAKE_COMMAND} -E copy
359+
$<TARGET_FILE:onnxruntime_providers_shared>
360+
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/
361+
)
362+
endif()
363+
352364
if (onnxruntime_BUILD_UNIT_TESTS)
353365
add_custom_command(
354366
TARGET onnxruntime_pybind11_state POST_BUILD

cmake/onnxruntime_unittests.cmake

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,4 +1092,32 @@ if (onnxruntime_BUILD_JAVA)
10921092
set_property(TEST onnxruntime4j_test APPEND PROPERTY DEPENDS onnxruntime4j_jni)
10931093
endif()
10941094

1095+
if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD
1096+
AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS"
1097+
AND NOT (CMAKE_SYSTEM_NAME STREQUAL "Android")
1098+
AND NOT onnxruntime_BUILD_WEBASSEMBLY)
1099+
file(GLOB_RECURSE test_execution_provider_srcs
1100+
"${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/*.h"
1101+
"${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/*.cc"
1102+
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h"
1103+
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc"
1104+
)
1105+
1106+
add_library(test_execution_provider SHARED ${test_execution_provider_srcs})
1107+
add_dependencies(test_execution_provider onnxruntime_providers_shared)
1108+
target_link_libraries(test_execution_provider PRIVATE onnxruntime_providers_shared)
1109+
target_include_directories(test_execution_provider PRIVATE $<TARGET_PROPERTY:onnx,INTERFACE_INCLUDE_DIRECTORIES>)
1110+
target_include_directories(test_execution_provider PRIVATE $<TARGET_PROPERTY:onnxruntime_common,INTERFACE_INCLUDE_DIRECTORIES>)
1111+
target_include_directories(test_execution_provider PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR})
1112+
if(APPLE)
1113+
set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst")
1114+
elseif(UNIX)
1115+
set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds -Xlinker --gc-sections")
1116+
elseif(WIN32)
1117+
set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def")
1118+
else()
1119+
message(FATAL_ERROR "test_execution_provider unknown platform, need to specify shared library exports for it")
1120+
endif()
1121+
endif()
1122+
10951123
include(onnxruntime_fuzz_test.cmake)

onnxruntime/core/framework/provider_bridge_ort.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "core/framework/data_transfer_manager.h"
1111
#include "core/framework/execution_provider.h"
1212
#include "core/framework/kernel_registry.h"
13+
#include "core/framework/provider_bridge_ort.h"
1314
#include "core/framework/provider_shutdown.h"
1415
#include "core/graph/model.h"
1516
#include "core/platform/env.h"
@@ -18,6 +19,7 @@
1819
#include "core/session/abi_session_options_impl.h"
1920
#include "core/session/ort_apis.h"
2021

22+
2123
#ifdef USE_TENSORRT
2224
#include "core/providers/cuda/cuda_allocator.h"
2325
#include "core/providers/cuda/gpu_data_transfer.h"
@@ -591,6 +593,10 @@ struct ProviderSharedLibrary {
591593

592594
static ProviderSharedLibrary s_library_shared;
593595

596+
bool InitProvidersSharedLibrary(){
597+
return s_library_shared.Ensure();
598+
}
599+
594600
struct ProviderLibrary {
595601
ProviderLibrary(const char* filename) : filename_{filename} {}
596602
~ProviderLibrary() { /*assert(!handle_);*/
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
namespace onnxruntime {
7+
8+
bool InitProvidersSharedLibrary();
9+
10+
} // namespace onnxruntime

onnxruntime/core/platform/posix/env.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,10 @@ class PosixEnv : public Env {
359359
return S_ISDIR(sb.st_mode);
360360
}
361361

362+
std::string GetRuntimePath() const override {
363+
return "./";
364+
}
365+
362366
common::Status CreateFolder(const std::string& path) const override {
363367
size_t pos = 0;
364368
do {
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
namespace onnxruntime {
5+
6+
struct Provider {
7+
// Takes a pointer to a provider specific structure to create the factory. For example, with OpenVINO it is a pointer to an OrtOpenVINOProviderOptions structure
8+
virtual std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* /*provider_options*/) { return nullptr; }
9+
10+
// Old simple device_id API to create provider factories, currently used by DNNL And TensorRT
11+
virtual std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(int /*device_id*/) { return nullptr; }
12+
13+
virtual const void* GetInfo() { return nullptr; } // Returns a provider specific information interface if it exists
14+
virtual void Shutdown() = 0;
15+
};
16+
17+
} // namespace onnxruntime

onnxruntime/core/providers/shared_library/provider_interfaces.h

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
// Public wrappers around internal ort interfaces (currently)
55
// In the future the internal implementations could derive from these to remove the need for the wrapper implementations
6+
#include "core/providers/shared_library/provider_host_api.h"
67

78
#define PROVIDER_DISALLOW_ALL(TypeName) \
89
TypeName() = delete; \
@@ -85,17 +86,6 @@ struct Node__EdgeIterator {
8586
virtual int GetDstArgIndex() const = 0;
8687
};
8788

88-
struct Provider {
89-
// Takes a pointer to a provider specific structure to create the factory. For example, with OpenVINO it is a pointer to an OrtOpenVINOProviderOptions structure
90-
virtual std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* /*provider_options*/) { return nullptr; }
91-
92-
// Old simple device_id API to create provider factories, currently used by DNNL And TensorRT
93-
virtual std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(int /*device_id*/) { return nullptr; }
94-
95-
virtual const void* GetInfo() { return nullptr; } // Returns a provider specific information interface if it exists
96-
virtual void Shutdown() = 0;
97-
};
98-
9989
// There are two ways to route a function, one is a virtual method and the other is a function pointer (or pointer to member function)
10090
// The function pointers are nicer in that they directly call the target function, but they cannot be used in cases where we're calling
10191
// a specific implementation of a virtual class member. Trying to get a pointer to member of a virtual function will return a thunk that

onnxruntime/python/onnxruntime_inference_collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def check_and_normalize_provider_args(providers, provider_options, available_pro
5050

5151
def set_provider_options(name, options):
5252
if name not in available_provider_names:
53-
raise ValueError("Specified provider '{}' is unavailable. Available providers: '{}'".format(
54-
name, ", ".join(available_provider_names)))
53+
warnings.warn("Specified provider '{}' is not in available provider names."
54+
"Available providers: '{}'".format(name, ", ".join(available_provider_names)))
5555

5656
if name in provider_name_to_options:
5757
warnings.warn("Duplicate provider '{}' encountered, ignoring.".format(name))

0 commit comments

Comments
 (0)