@@ -13,101 +13,19 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16- #include < stdexcept>
17- #include < utility>
18-
19- #include " absl/container/flat_hash_map.h"
20- #include " absl/status/statusor.h"
21- #include " absl/strings/str_format.h"
2216#include " nanobind/nanobind.h"
23- #include " nanobind/stl/pair.h" // IWYU pragma: keep
24- #include " jaxlib/gpu/gpu_kernel_helpers.h"
25- #include " jaxlib/gpu/solver_handle_pool.h"
26- #include " jaxlib/gpu/solver_kernels.h"
2717#include " jaxlib/gpu/solver_kernels_ffi.h"
2818#include " jaxlib/gpu/vendor.h"
2919#include " jaxlib/kernel_nanobind_helpers.h"
30- #include " xla/tsl/python/lib/core/numpy.h"
3120
3221namespace jax {
3322namespace JAX_GPU_NAMESPACE {
3423namespace {
3524
3625namespace nb = nanobind;
3726
38- // Converts a NumPy dtype to a Type.
39- SolverType DtypeToSolverType (const dtype& np_type) {
40- static auto * types =
41- new absl::flat_hash_map<std::pair<char , int >, SolverType>({
42- {{' f' , 4 }, SolverType::F32},
43- {{' f' , 8 }, SolverType::F64},
44- {{' c' , 8 }, SolverType::C64},
45- {{' c' , 16 }, SolverType::C128},
46- });
47- auto it = types->find ({np_type.kind (), np_type.itemsize ()});
48- if (it == types->end ()) {
49- nb::str repr = nb::repr (np_type);
50- throw std::invalid_argument (
51- absl::StrFormat (" Unsupported dtype %s" , repr.c_str ()));
52- }
53- return it->second ;
54- }
55-
56- #ifdef JAX_GPU_CUDA
57-
58- // csrlsvqr: Linear system solve via Sparse QR
59-
60- // Returns a descriptor for a csrlsvqr operation.
61- nb::bytes BuildCsrlsvqrDescriptor (const dtype& dtype, int n, int nnzA,
62- int reorder, double tol) {
63- SolverType type = DtypeToSolverType (dtype);
64- return PackDescriptor (CsrlsvqrDescriptor{type, n, nnzA, reorder, tol});
65- }
66-
67- #endif // JAX_GPU_CUDA
68-
69- // Returns the workspace size and a descriptor for a geqrf operation.
70- std::pair<int , nb::bytes> BuildSytrdDescriptor (const dtype& dtype, bool lower,
71- int b, int n) {
72- SolverType type = DtypeToSolverType (dtype);
73- auto h = SolverHandlePool::Borrow (/* stream=*/ nullptr );
74- JAX_THROW_IF_ERROR (h.status ());
75- auto & handle = *h;
76- int lwork;
77- gpusolverFillMode_t uplo =
78- lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
79- switch (type) {
80- case SolverType::F32:
81- JAX_THROW_IF_ERROR (JAX_AS_STATUS (gpusolverDnSsytrd_bufferSize (
82- handle.get (), uplo, n, /* A=*/ nullptr , /* lda=*/ n, /* D=*/ nullptr ,
83- /* E=*/ nullptr , /* tau=*/ nullptr , &lwork)));
84- break ;
85- case SolverType::F64:
86- JAX_THROW_IF_ERROR (JAX_AS_STATUS (gpusolverDnDsytrd_bufferSize (
87- handle.get (), uplo, n, /* A=*/ nullptr , /* lda=*/ n, /* D=*/ nullptr ,
88- /* E=*/ nullptr , /* tau=*/ nullptr , &lwork)));
89- break ;
90- case SolverType::C64:
91- JAX_THROW_IF_ERROR (JAX_AS_STATUS (gpusolverDnChetrd_bufferSize (
92- handle.get (), uplo, n, /* A=*/ nullptr , /* lda=*/ n, /* D=*/ nullptr ,
93- /* E=*/ nullptr , /* tau=*/ nullptr , &lwork)));
94- break ;
95- case SolverType::C128:
96- JAX_THROW_IF_ERROR (JAX_AS_STATUS (gpusolverDnZhetrd_bufferSize (
97- handle.get (), uplo, n, /* A=*/ nullptr , /* lda=*/ n, /* D=*/ nullptr ,
98- /* E=*/ nullptr , /* tau=*/ nullptr , &lwork)));
99- break ;
100- }
101- return {lwork, PackDescriptor (SytrdDescriptor{type, uplo, b, n, n, lwork})};
102- }
103-
10427nb::dict Registrations () {
10528 nb::dict dict;
106- dict[JAX_GPU_PREFIX " solver_sytrd" ] = EncapsulateFunction (Sytrd);
107-
108- #ifdef JAX_GPU_CUDA
109- dict[" cusolver_csrlsvqr" ] = EncapsulateFunction (Csrlsvqr);
110- #endif // JAX_GPU_CUDA
11129
11230 dict[JAX_GPU_PREFIX " solver_getrf_ffi" ] = EncapsulateFfiHandler (GetrfFfi);
11331 dict[JAX_GPU_PREFIX " solver_geqrf_ffi" ] = EncapsulateFfiHandler (GeqrfFfi);
@@ -127,12 +45,7 @@ nb::dict Registrations() {
12745}
12846
12947NB_MODULE (_solver, m) {
130- tsl::ImportNumpy ();
13148 m.def (" registrations" , &Registrations);
132- m.def (" build_sytrd_descriptor" , &BuildSytrdDescriptor);
133- #ifdef JAX_GPU_CUDA
134- m.def (" build_csrlsvqr_descriptor" , &BuildCsrlsvqrDescriptor);
135- #endif // JAX_GPU_CUDA
13649}
13750
13851} // namespace
0 commit comments