diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index b0e93a39b3a4..b7359495f184 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -235,7 +235,6 @@ cc_library( "-Wl,--export-dynamic-symbol='nvshmemx_mc_ptr'", "-Wl,--export-dynamic-symbol='nvshmemx_barrier_all_on_stream'", "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_init'", - "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_finalize'", "-Wl,--export-dynamic-symbol='nvshmemx_init_status'", ], deps = [ @@ -352,9 +351,6 @@ cc_test( deps = [ ":mosaic_gpu_support", "//testing/base/public:gunit_main", - "@com_google_absl//absl/base:log_severity", - "@com_google_absl//absl/log:globals", - "@com_google_absl//absl/log:scoped_mock_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index f1d7e1199ee6..29931d56e139 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -50,10 +50,8 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cuda/include/driver_types.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Debug.h" @@ -132,8 +130,6 @@ limitations under the License. namespace { -using ::mosaic::gpu::NvshmemApi; - namespace ffi = xla::ffi; namespace se = stream_executor; @@ -511,25 +507,19 @@ absl::StatusOr, bool>> Compile( class CompiledKernel { public: CompiledKernel(std::unique_ptr engine, void* ctx, - CUmodule module, MosaicHostFunc* host_launch, - bool is_comm_used) + MosaicHostFunc* host_launch, bool is_comm_used) : engine_(std::move(engine)), ctx_(ctx), - module_(module), host_launch_(host_launch), is_comm_used_(is_comm_used) {} - std::tuple GetHostLaunch() const { + std::tuple GetHostLaunch() { return std::make_tuple(ctx_, host_launch_, is_comm_used_); } - CUmodule module() const { return module_; } - bool is_comm_used() const { return is_comm_used_; } - private: std::unique_ptr engine_; void* ctx_; // TODO(apaszke): Destroy this properly - CUmodule module_; MosaicHostFunc* host_launch_; bool is_comm_used_; }; @@ -569,7 +559,7 @@ absl::StatusOr> GetHostAndInitFuncNames( return std::make_pair(host_func_name, init_func_name); } -absl::StatusOr CompileAndInit(absl::string_view module) { +absl::StatusOr CompileAndInit(llvm::StringRef module) { mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); context.allowUnregisteredDialects(true); InitContext(&context); @@ -610,30 +600,13 @@ absl::StatusOr CompileAndInit(absl::string_view module) { void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); return CompiledKernel(std::move(maybe_engine.value().first), kernel_ptr, - reinterpret_cast(module_ptr), reinterpret_cast(*host), is_comm_used); } -absl::Status Unload(const CompiledKernel& kernel, CUcontext ctx) { - CUDA_RETURN_IF_ERROR(cuCtxPushCurrent(ctx)); - if (kernel.is_comm_used()) { - if (NvshmemApi::Default().cumodule_finalize(kernel.module()) != - NVSHMEM_SUCCESS) { - return absl::InternalError("nvshmemx_cumodule_finalize failed"); - } - } - CUDA_RETURN_IF_ERROR(cuModuleUnload(kernel.module())); - CUcontext unused; - CUDA_RETURN_IF_ERROR(cuCtxPopCurrent(&unused)); - return absl::OkStatus(); -} - using KernelHash = std::array; +using CacheKey = std::pair; -// A reference counted cache of compiled and loaded kernels. -class KernelCache { - public: - // A global cache of compiled and loaded kernels. +struct KernelCache { static KernelCache& Global() { static absl::NoDestructor cache; return *cache; @@ -644,89 +617,80 @@ class KernelCache { KernelCache(const KernelCache&) = delete; KernelCache(KernelCache&&) = delete; - // Holds a reference to a compiled and loaded kernel. - // Unload the kernel when the handle is destroyed. - class KernelHandle { - public: - KernelHandle(CompiledKernel kernel, CUcontext ctx) - : kernel_(std::move(kernel)), ctx_(ctx) {} - ~KernelHandle() { - CHECK_OK(Unload(kernel_, ctx_)); - VLOG(5) << "Successfully unloaded GPU module"; - } - const CompiledKernel* kernel() const { return &kernel_; } - - private: - CompiledKernel kernel_; - CUcontext ctx_; // The CUDA context in which the kernel was loaded. - }; - - // Compile and load the given module in the current CUDA context. - absl::StatusOr> CompileAndInit( - const KernelHash& kernel_hash, absl::string_view module) { - CUcontext ctx; - CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); - CacheKey key(kernel_hash, reinterpret_cast(ctx)); - absl::MutexLock lock(mutex_); - if (auto it = kernels_.find(key); it != kernels_.end()) { - std::shared_ptr handle = it->second.lock(); - if (handle) { - return handle; - } - } - // Kernel not found or has expired, create a new value. - tsl::profiler::TraceMe trace("Compilation cache miss"); - TF_ASSIGN_OR_RETURN(CompiledKernel compiled, ::CompileAndInit(module)); - VLOG(5) << "Successfully compiled and initialized Mosaic GPU kernel"; - auto handle = std::make_shared(std::move(compiled), ctx); - kernels_[key] = handle; - return handle; - } - - private: - using CacheKey = std::pair; - absl::Mutex mutex_; - absl::flat_hash_map> kernels_ - ABSL_GUARDED_BY(mutex_); + absl::Mutex mutex; + absl::flat_hash_map kernels ABSL_GUARDED_BY(mutex); }; -// Tracks the compiled and loaded kernels for a given custom call. -// There is a single global cache in the process and a process can have -// multiple devices, each of which must load/unload the module. We expect each -// device/module pair to have a unique cache key. -class CustomCallResources { - public: - CustomCallResources() = default; +// Each compiled kernel has a unique init func, and each kernel is used from +// a single HLO module. So it should be safe to not include the CUDA context +// in the key. +absl::StatusOr CachedCompileAndInit(CacheKey key, + llvm::StringRef module) { + KernelCache& cache = KernelCache::Global(); - const CompiledKernel* KernelForDevice(int32_t device_ordinal) const { - absl::MutexLock lock(mutex_); - return kernels_.at(device_ordinal)->kernel(); + { + // Fast path uses reader lock (as hash map look-up is relatively slow). + absl::ReaderMutexLock lock(cache.mutex); + auto it = cache.kernels.find(key); + if (ABSL_PREDICT_TRUE(it != cache.kernels.end())) return &it->second; } - void AddKernel(int32_t device_ordinal, - std::shared_ptr kernel) { - absl::MutexLock lock(mutex_); - kernels_[device_ordinal] = std::move(kernel); + absl::MutexLock lock(cache.mutex); + // We released the reader lock, another thread might have initialized it. + if (cache.kernels.find(key) == cache.kernels.end()) { + tsl::profiler::TraceMe trace("Compilation cache miss"); + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return compiled.status(); + } + cache.kernels.insert_or_assign(key, std::move(*compiled)); } + return &cache.kernels.at(key); +} - private: - mutable absl::Mutex mutex_; - absl::flat_hash_map> - kernels_ ABSL_GUARDED_BY(mutex_); -}; - -absl::StatusOr> InstantiateResources() { - // TODO(b/466097203): Ideally we would compile the module here. - // Sadly we need to acquire a lock on LLVM command line options which is - // already held by XLA causing a deadlock. - // See `GpuCompiler::CompileToBackendResult`. - return std::make_unique(); +// TODO(b/464203195): Backward-compatible version using the legacy FFI +// API. Remove once backward compatibility window has passed. +void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + if (reinterpret_cast(opaque) % alignof(KernelHash)) { + fprintf(stderr, "Misaligned opaque pointer\n"); + abort(); + } + auto hash = *reinterpret_cast(opaque); + CUcontext ctx; + if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) { + fprintf(stderr, "Failed to get current CUDA context\n"); + abort(); + } + CacheKey key(hash, reinterpret_cast(ctx)); + auto compiled_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); + if (!compiled_kernel.ok()) { + XlaCustomCallStatusSetFailure(status, + compiled_kernel.status().message().data(), + compiled_kernel.status().message().size()); + return; + } + auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch(); + bool is_comm_used = std::get<2>(ctx_kernel_comm); + void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers}; + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( + reinterpret_cast(stream)); + } + std::get<1>(ctx_kernel_comm)(args); } -absl::Status InitializeResources(int32_t device_ordinal, - CustomCallResources* resources, - std::string_view kernel_hash, - std::string_view module, bool) { +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, + "CUDA"); + +absl::Status MosaicGpuExecute(cudaStream_t stream, ffi::RemainingArgs inputs, + ffi::RemainingRets results, + std::string_view kernel_hash, + std::string_view module, + bool use_custom_barrier) { + if (use_custom_barrier) { + return absl::UnimplementedError("Custom barrier is not supported on GPUs."); + } if (kernel_hash.size() != sizeof(KernelHash)) { return absl::InvalidArgumentError( absl::StrFormat("Kernel hash size is %d bytes, expected %d bytes", @@ -734,23 +698,11 @@ absl::Status InitializeResources(int32_t device_ordinal, } KernelHash hash; std::memcpy(hash.data(), kernel_hash.data(), sizeof(KernelHash)); - TF_ASSIGN_OR_RETURN( - std::shared_ptr handle, - KernelCache::Global().CompileAndInit(hash, module)); - resources->AddKernel(device_ordinal, std::move(handle)); - return absl::OkStatus(); -} - -absl::Status MosaicGpuExecute(cudaStream_t stream, int32_t device_ordinal, - ffi::RemainingArgs inputs, - ffi::RemainingRets results, - CustomCallResources* resources, std::string_view, - std::string_view, bool use_custom_barrier) { - if (use_custom_barrier) { - return absl::UnimplementedError("Custom barrier is not supported on GPUs."); - } - const CompiledKernel* compiled_kernel = - resources->KernelForDevice(device_ordinal); + CUcontext ctx; + CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); + CacheKey key(hash, reinterpret_cast(ctx)); + TF_ASSIGN_OR_RETURN(auto compiled_kernel, + CachedCompileAndInit(key, module)); auto ctx_kernel_comm = compiled_kernel->GetHostLaunch(); bool is_comm_used = std::get<2>(ctx_kernel_comm); @@ -778,30 +730,17 @@ absl::Status MosaicGpuExecute(cudaStream_t stream, int32_t device_ordinal, void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers_ptr}; if (is_comm_used) { - NvshmemApi::Default().barrier_all_on_stream(stream); + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream(stream); } std::get<1>(ctx_kernel_comm)(args); return absl::OkStatus(); } -XLA_FFI_DEFINE_HANDLER(kInstantiateResources, InstantiateResources, - ffi::Ffi::BindInstantiate()); - -XLA_FFI_DEFINE_HANDLER(kInitializeResources, InitializeResources, - ffi::Ffi::BindInitialize() - .Ctx() - .Ctx>() - .Attr("kernel_hash") - .Attr("module") - .Attr("use_custom_barrier")); - XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute, ffi::Ffi::Bind() .Ctx>() - .Ctx() .RemainingArgs() .RemainingRets() - .Ctx>() .Attr("kernel_hash") .Attr("module") .Attr("use_custom_barrier"), @@ -809,78 +748,12 @@ XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute, XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "mosaic_gpu_v2", "CUDA", { - /*instantiate=*/kInstantiateResources, + /*instantiate=*/nullptr, /*prepare=*/nullptr, - /*initialize=*/kInitializeResources, + /*initialize=*/nullptr, /*execute=*/kMosaicGpuExecute, }); -// Cache compiled and loaded kernels in the current CUDA context. -// Loaded kernels are never unloaded. -absl::StatusOr LegacyCachedCompileAndInit( - const KernelHash& kernel_hash, absl::string_view module) { - using CacheKey = std::pair; - struct LegacyCache { - absl::Mutex mutex; - absl::flat_hash_map kernels - ABSL_GUARDED_BY(mutex); - }; - static absl::NoDestructor cache; - - CUcontext ctx; - CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); - - CacheKey key(kernel_hash, reinterpret_cast(ctx)); - { - // Fast path uses reader lock (as hash map look-up is relatively slow). - absl::ReaderMutexLock lock(cache->mutex); - auto it = cache->kernels.find(key); - if (ABSL_PREDICT_TRUE(it != cache->kernels.end())) return &it->second; - } - - absl::MutexLock lock(cache->mutex); - // We released the reader lock, another thread might have initialized it. - if (cache->kernels.find(key) == cache->kernels.end()) { - tsl::profiler::TraceMe trace("Compilation cache miss"); - auto compiled = CompileAndInit(module); - if (!compiled.ok()) { - return compiled.status(); - } - cache->kernels.insert_or_assign(key, std::move(*compiled)); - } - return &cache->kernels.at(key); -} - -// TODO(b/464203195): Backward-compatible version using the legacy FFI -// API. Remove once backward compatibility window has passed. -void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - if (reinterpret_cast(opaque) % alignof(KernelHash)) { - fprintf(stderr, "Misaligned opaque pointer\n"); - abort(); - } - auto hash = *reinterpret_cast(opaque); - auto compiled_kernel = - LegacyCachedCompileAndInit(hash, opaque + sizeof(KernelHash)); - if (!compiled_kernel.ok()) { - XlaCustomCallStatusSetFailure(status, - compiled_kernel.status().message().data(), - compiled_kernel.status().message().size()); - return; - } - auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch(); - bool is_comm_used = std::get<2>(ctx_kernel_comm); - void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers}; - if (is_comm_used) { - mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( - reinterpret_cast(stream)); - } - std::get<1>(ctx_kernel_comm)(args); -} - -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, - "CUDA"); - } // namespace extern "C" { diff --git a/jaxlib/mosaic/gpu/custom_call_test.cc b/jaxlib/mosaic/gpu/custom_call_test.cc index d3426c0fd71a..e4756a394325 100644 --- a/jaxlib/mosaic/gpu/custom_call_test.cc +++ b/jaxlib/mosaic/gpu/custom_call_test.cc @@ -19,9 +19,6 @@ limitations under the License. #include #include -#include "absl/base/log_severity.h" -#include "absl/log/globals.h" -#include "absl/log/scoped_mock_log.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/str_cat.h" @@ -39,7 +36,6 @@ limitations under the License. namespace { using ::absl_testing::IsOk; -using ::testing::_; absl::Status ExecuteSync(xla::PjRtLoadedExecutable* executable) { std::vector no_buffers; @@ -69,16 +65,16 @@ ENTRY main { custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI })"; - ASSERT_OK_AND_ASSIGN(auto module, - xla::ParseAndReturnUnverifiedModule(kHloModule)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + xla::ParseAndReturnUnverifiedModule(kHloModule)); std::string tmp_path = testing::TempDir(); tsl::setenv("XLA_FLAGS", absl::StrCat("--xla_dump_to=", tmp_path).c_str(), /*overwrite=*/true); - ASSERT_OK_AND_ASSIGN(std::unique_ptr client, - xla::GetXlaPjrtGpuClient(/*options=*/{})); - ASSERT_OK_AND_ASSIGN( + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetXlaPjrtGpuClient(/*options=*/{})); + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr executable, client->CompileAndLoad(xla::XlaComputation(module->ToProto()), /*options=*/{})); @@ -138,145 +134,4 @@ TEST(CustomCallTest, LegacyCustomCall) { EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); } -absl::string_view TestMGPUHloModule() { - // Dumped from the following JAX program: - // - // ``` - // @functools.partial( - // plgpu.pallas_call, - // out_shape=jax.ShapeDtypeStruct((), jnp.int32), - // out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - // ) - // def kernel(o_ref): - // o_ref[...] = jnp.array(42) - // ``` - return R"hlo( - HloModule test - - ENTRY main { - ROOT result = s32[] custom-call(), custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI, backend_config={kernel_hash = "\90\C7\1F$\92=c\9D\E4\A8\15\B1Y\9B.\02\B4\B0\0B\16\C5Ol\D4\ED\CDdA-\C9\D77", module = "ML\EFR\01MLIR\00\01O\0D\01\03\05\07\09\0B\01\03\0D\037\0F\11\13\15\17\19\1B\1D\1F!#%')+-/13579;=?AC\03\12\02\C9\1D\01\BB\0F\13\0B\0B\0F\13\13\13\13\0B\07\0B\0B\13\13\0B\0F\13\13\13e\1B\0B\0F\0B\0B#\0B\0B\0B\0B;\0B\0B\0B\0B\0B\0B\0B#\0B\0B\07\0B\13\0F\0F\13\13\13\0F\13\13\0B\133\133\133U\1B\0B\C3\0B\13\13\13\13\13\13\13\13\13\17\17\17\0B\0F\1F\0F\0B\0B\13\13\0B\0B\0F\0B\0F\0B\17\0B\05\03a\07\09y111\09\03Y\0B\03U\01\15\0F\07\0F\0B\0B\1B/\17\13;\05\07)yQ\07\03E\02\AE\0A\1D3\15\03\03\9B\C5\05E\05G\11\05\01\03\03\07]\03\03\19\BF\03\03\19\C1\03\03\19\C3\05I\1F\05K\05M\03\03\07\9D\03\03\A5\09\05O\11\01\11\03\03\07\9F\03\03\07\A1\03\03\A3\C7affine_map<(d0) -> (d0)>\00\03\05-/\131\05Q\11\05\19\05S\05U\03\07\1F7\139;=\0D\0D\05W\05Y\05[\03\0DA!CEG\BB\13IK\09M\09\05]\05_\0D\19\05a\05c\05e\05g\03\07\1FQSU\13W\0D\0F\05i\0F\05k\03\03\07[\11\01\A9\11\01\01\03\03\07a\11\03\02\04\03\03\07e\11\03\05\03\03\07\09\03\03k\09\05m\03\03\17o#\05\03\11\00\00\00\00\00\00\00\00\03\03\17s#\05\03\11\01\00\00\00\00\00\00\00\03\03\17w#\05\03\11\02\00\00\00\00\00\00\00affine_map<() -> ()>\00\03\05}\7F\81\09\05o#\01\17Y\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\05q\17\05%O\17\05%]\17\05%k\17\05%\E1\17\05%\EF\17\05%\FD\17\05%\81\17\05%\9B\17\05%\B5\17\05%&\02\17\05%f\02\17\05%\9E\02\05s\11\01\15\11\01\D0\FF\FF\FF?\11\01}\05u\05w\03\03\07!\03\03\AB\AD\05y\01\01\1D\B1\B3\05{\1D\B5\B7\05}\17\B9\06\03\0D\05\7F#llvm.linkage\00#gpu.address_space\00#gpu\00#gpu\00#gpu\00#arith.overflow\00#nvvm\00\01\02\02\03\01\02\04\01\09\01A\17\BD\03\01\09)\05\11\15\15\05\05\15\15\05\15\01\05\05\15\15\01\15\01\01y\17\BD\03\00\FF\FF\FF\FF\FF\FF\FF\FF\09)!llvm.ptr\00!llvm.struct<(ptr, ptr, i64)>\00!llvm.array<0 x i8>\00!gpu.async.token\00\04Z\0C\05\01\11\01+\07\03\01\0D\17\11\015\07\01\1F\11\01?\07\01\17\11\01O\07\03\1F;\05\15\01\15\01\05\03\15Y\03\01\05\03\15\0B\03\01\05\03\01_\03\03\05\03\15c\03\03!\03\01g\03\05#\02\01\03\17\0F\06\01\03\1B\03\01%\07\01i\03\15\03\03\11\07\01m\03\17\05\0F\13\11\07\01q\03\17\05\15\13\11\07\01u\03\17\05\17\0D\0F\06\01\03\11\03\19'\17\01{\03\1B\11\11\0B\0B\0B\09\0B\0B\07\05\03\C1\C6\02\19\03\83\03\85\03\87\03\89\03\8B\03\8D\03\8F\03\91\03\93\03\95\03\97\03\99\19\02\01\03\07\09\03\01\0D\03\03\03\06\01\03\01\039\0B\03\01\0D\03\03\03\06\01\03\01\03=\09\03\01\0F\03\03\03\06\01\03\01\03A\07\07\01\03\03\01\05C?\0D\07\01\03\03\01\05;E\0B\03\01\0F\03\03\03\06\01\03\01\03I\07\07\01\03\03\01\05?K\09\03\01\11\03\03\03\06\01\03\01\03O\07\07\01\03\03\01\05QM\0D\07\01\03\03\01\05GS\0B\03\01\11\03\03\03\06\01\03\01\03W\07\07\01\03\03\01\05MY\05\03\01\1B\03\01\13\06\01\03\01\05U]\05\03\01#\03\01\05\03\01\0B\03\01\05\03\01%\03\01\1B\07\01'\03\01\09a_ce\05\03\01\0B\03\01\15\07\01\1D\03\07\05gi\1D\06\01\03\07\05k7\19\02\01\03\07\09\03\01\0D\03\03\03\06\01\03\01\03q\0B\03\01\0D\03\03\03\06\01\03\01\03u\09\03\01\0F\03\03\03\06\01\03\01\03y\07\07\01\03\03\01\05{w\0D\07\01\03\03\01\05s}\0B\03\01\0F\03\03\03\06\01\03\01\03\81\07\07\01\03\03\01\05w\83\09\03\01\11\03\03\03\06\01\03\01\03\87\07\07\01\03\03\01\05\89\85\0D\07\01\03\03\01\05\7F\8B\0B\03\01\11\03\03\03\06\01\03\01\03\8F\07\07\01\03\03\01\05\85\91\05\03\01\1B\03\01\13\06\01\03\01\05\8D\95\05\03\01#\03\01\05\03\01\0B\03\01\05\03\01%\03\01\1B\07\01'\03\01\09\99\97\9B\9D\05\03\01\A7\03\01+\06\01\03\01\05\9F\A1\05\03\01\0B\03\01\15\07\01\1D\03\07\05\A3\A5\1D\06\01\03\07\05\A7o\09\03\01\0D\03\03\03\06\01\03\01\03\AB\0B\03\01\0D\03\03\03\06\01\03\01\03\AF\09\03\01\0F\03\03\03\06\01\03\01\03\B3\07\07\01\03\03\01\05\B5\B1\0D\07\01\03\03\01\05\AD\B7\0B\03\01\0F\03\03\03\06\01\03\01\03\BB\07\07\01\03\03\01\05\B1\BD\09\03\01\11\03\03\03\06\01\03\01\03\C1\07\07\01\03\03\01\05\C3\BF\0D\07\01\03\03\01\05\B9\C5\0B\03\01\11\03\03\03\06\01\03\01\03\C9\07\07\01\03\03\01\05\BF\CB\05\03\01\1B\03\01\13\06\01\03\01\05\C7\CF\05\03\01\0B\03\01\15\07\01\1D\03\07\05\D1\D3-\02\01\03\13\03\06\01\03\03\03\07/\06\01\03\0B\05\D7\D9\0F\07\01\A9\03\0B\03\DB1\00\013\00\015\04\AF\05\05\1B7\00\01)\00\01\06\03\01\05\01\00\9E\0E\81g\0B\0D\17\15\0B\1D/)\13%-\19\1B\1F\11\19\17\11\1F3\19\0F5\1D\15\13\13\0D\05\1F\1B\193\195\19\19\17\15!'#\17\1F!\15\17\19#G\17\1D\1D\17\1F#\0F\0B\0D\09\0B%\11builtin\00stable_mosaic_gpu\00llvm\00gpu\00arith\00nvvm\00module\00arith.index_cast\00arith.constant\00arith.muli\00gpu.thread_id\00gpu.block_dim\00arith.addi\00builtin.unrealized_conversion_cast\00llvm.insertvalue\00arith.shrui\00arith.cmpi\00func.func\00nvvm.elect.sync\00nvvm.shfl.sync\00arith.andi\00llvm.mlir.global\00llvm.mlir.constant\00llvm.mlir.undef\00llvm.load\00gpu.launch\00func.return\00arith.remui\00gpu.dynamic_shared_memory\00memref.view\00nvvm.fence.mbarrier.init\00gpu.barrier\00memref.store\00gpu.terminator\00-\00value\00sym_name\00position\00dimension\00function_type\00stable_mosaic_gpu.version\00kernel\00pallas_call\00mosaic_gpu_init_tma_desc\00sym_visibility\00private\00addr_space\00global_type\00linkage\00global_scratch\00unnamed_addr\00visibility_\00llvm.emit_c_interface\00kernel_mosaic_gpu\00ordering\00operandSegmentSizes\00workgroup_attributions\00overflowFlags\00kind\00predicate\00transforms\00swap:\00swap\00third_party/py/jax/tests/pallas/mosaic_gpu_test.py\00", use_custom_barrier = false} - } - )hlo"; -} - -TEST(CustomCallTest, UnloadGPUModule) { - ASSERT_OK_AND_ASSIGN( - auto module, xla::ParseAndReturnUnverifiedModule(TestMGPUHloModule())); - - ASSERT_OK_AND_ASSIGN(std::unique_ptr client, - xla::GetXlaPjrtGpuClient(/*options=*/{})); - ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, - client->CompileAndLoad(xla::XlaComputation(module->ToProto()), - /*options=*/{})); - - absl::SetVLogLevel("custom_call", 5); - { - absl::ScopedMockLog log; - EXPECT_CALL(log, - Log(absl::LogSeverity::kInfo, _, - "Successfully compiled and initialized Mosaic GPU kernel")) - .Times(1); - log.StartCapturingLogs(); - EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); - } - - { - // The second execution the compilation should be cached. - absl::ScopedMockLog log; - EXPECT_CALL(log, - Log(absl::LogSeverity::kInfo, _, - "Successfully compiled and initialized Mosaic GPU kernel")) - .Times(0); - log.StartCapturingLogs(); - EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); - } - - { - // GPU module should be unloaded when the executable is destroyed. - absl::ScopedMockLog log; - EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, - "Successfully unloaded GPU module")) - .Times(1); - log.StartCapturingLogs(); - executable.reset(); - } -} - -TEST(CustomCallTest, GPUModuleIsOnlyUnloadedWhenAllExecutablesAreDestroyed) { - ASSERT_OK_AND_ASSIGN( - auto module, xla::ParseAndReturnUnverifiedModule(TestMGPUHloModule())); - ASSERT_OK_AND_ASSIGN(std::unique_ptr client, - xla::GetXlaPjrtGpuClient(/*options=*/{})); - ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable1, - client->CompileAndLoad(xla::XlaComputation(module->ToProto()), - /*options=*/{})); - ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable2, - client->CompileAndLoad(xla::XlaComputation(module->ToProto()), - /*options=*/{})); - - EXPECT_THAT(ExecuteSync(executable1.get()), IsOk()); - EXPECT_THAT(ExecuteSync(executable2.get()), IsOk()); - - absl::SetVLogLevel("custom_call", 5); - { - // executable2 still holds a reference to the GPU module. - absl::ScopedMockLog log; - EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, - "Successfully unloaded GPU module")) - .Times(0); - log.StartCapturingLogs(); - executable1.reset(); - } - EXPECT_THAT(ExecuteSync(executable2.get()), IsOk()); - { - absl::ScopedMockLog log; - EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, - "Successfully unloaded GPU module")) - .Times(1); - log.StartCapturingLogs(); - executable2.reset(); - } -} - -TEST(CustomCallTest, GPUModuleIsRecompiledAfterExpiration) { - ASSERT_OK_AND_ASSIGN( - auto module, xla::ParseAndReturnUnverifiedModule(TestMGPUHloModule())); - ASSERT_OK_AND_ASSIGN(std::unique_ptr client, - xla::GetXlaPjrtGpuClient(/*options=*/{})); - ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, - client->CompileAndLoad(xla::XlaComputation(module->ToProto()), - /*options=*/{})); - - EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); - - { - absl::ScopedMockLog log; - EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, - "Successfully unloaded GPU module")) - .Times(1); - log.StartCapturingLogs(); - executable.reset(); - } - - ASSERT_OK_AND_ASSIGN( - executable, client->CompileAndLoad(xla::XlaComputation(module->ToProto()), - /*options=*/{})); - - { - // executable was destroyed and the module was unloaded. We re-compile the - // kernel. - absl::ScopedMockLog log; - EXPECT_CALL(log, - Log(absl::LogSeverity::kInfo, _, - "Successfully compiled and initialized Mosaic GPU kernel")) - .Times(1); - log.StartCapturingLogs(); - EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); - } -} - } // namespace diff --git a/jaxlib/mosaic/gpu/nvshmem.h b/jaxlib/mosaic/gpu/nvshmem.h index 267f17de8324..dbd11aa1d373 100644 --- a/jaxlib/mosaic/gpu/nvshmem.h +++ b/jaxlib/mosaic/gpu/nvshmem.h @@ -54,11 +54,6 @@ class NvshmemApi { return nvshmemx_cumodule_init(module); } - int cumodule_finalize(CUmodule module) { - std::lock_guard lock(mutex_); - return nvshmemx_cumodule_finalize(module); - } - void barrier_all_on_stream(cudaStream_t stream) { nvshmemx_barrier_all_on_stream(stream); } @@ -83,13 +78,11 @@ class NvshmemApi { NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream) NVSHMEM_SET_FN(nvshmemx_cumodule_init) - NVSHMEM_SET_FN(nvshmemx_cumodule_finalize) NVSHMEM_SET_FN(nvshmemx_init_status) } int (*nvshmemx_barrier_all_on_stream)(cudaStream_t); int (*nvshmemx_cumodule_init)(CUmodule); - int (*nvshmemx_cumodule_finalize)(CUmodule); int (*nvshmemx_init_status)(); std::mutex mutex_;