diff --git a/CHANGELOG.md b/CHANGELOG.md index f8aaae542ec5..409f5bcd5524 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum * Added implementation of `dpnp.isin` function [#2595](https://github.com/IntelPython/dpnp/pull/2595) * Added implementation of `dpnp.scipy.linalg.lu` (SciPy-compatible) [#2787](https://github.com/IntelPython/dpnp/pull/2787) * Added support for ndarray subclassing via `dpnp.ndarray.view` method with `type` parameter [#2815](https://github.com/IntelPython/dpnp/issues/2815) +* Added implementation of `dpnp.scipy.sparse.linalg` submodule providing `LinearOperator`, `cg`, `gmres` and `minres` [#2841](https://github.com/IntelPython/dpnp/pull/2841) ### Changed diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 956ff6db0133..c4a7287447ba 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -50,6 +50,7 @@ requirements: - {{ pin_compatible('onemkl-sycl-lapack', min_pin='x.x', max_pin='x') }} - {{ pin_compatible('onemkl-sycl-rng', min_pin='x.x', max_pin='x') }} - {{ pin_compatible('onemkl-sycl-vm', min_pin='x.x', max_pin='x') }} + - {{ pin_compatible('onemkl-sycl-sparse', min_pin='x.x', max_pin='x') }} - numpy - intel-gpu-ocl-icd-system diff --git a/dpnp/CMakeLists.txt b/dpnp/CMakeLists.txt index 6850b799735c..cfced6b4ae44 100644 --- a/dpnp/CMakeLists.txt +++ b/dpnp/CMakeLists.txt @@ -100,6 +100,7 @@ add_subdirectory(backend/extensions/statistics) add_subdirectory(backend/extensions/ufunc) add_subdirectory(backend/extensions/vm) add_subdirectory(backend/extensions/window) +add_subdirectory(backend/extensions/sparse) add_subdirectory(dpnp_algo) add_subdirectory(dpnp_utils) diff --git a/dpnp/backend/extensions/sparse/CMakeLists.txt b/dpnp/backend/extensions/sparse/CMakeLists.txt new file mode 100644 index 000000000000..5ec461e316df --- /dev/null +++ b/dpnp/backend/extensions/sparse/CMakeLists.txt @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# 
***************************************************************************** +# Copyright (c) 2025, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. 
+# ***************************************************************************** + +set(python_module_name _sparse_impl) +set(_module_src + ${CMAKE_CURRENT_SOURCE_DIR}/sparse_py.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cpp +) + +pybind11_add_module(${python_module_name} MODULE ${_module_src}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src}) + +if(_dpnp_sycl_targets) + # make fat binary + target_compile_options( + ${python_module_name} + PRIVATE ${_dpnp_sycl_target_compile_options} + ) + target_link_options(${python_module_name} PRIVATE ${_dpnp_sycl_target_link_options}) +endif() + +if(WIN32) + if(${CMAKE_VERSION} VERSION_LESS "3.27") + # this is a work-around for target_link_options inserting option after -link option, cause + # linker to ignore it. + set(CMAKE_CXX_LINK_FLAGS + "${CMAKE_CXX_LINK_FLAGS} -fsycl-device-code-split=per_kernel" + ) + endif() +endif() + +set_target_properties( + ${python_module_name} + PROPERTIES POSITION_INDEPENDENT_CODE ON +) + +target_include_directories( + ${python_module_name} + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common +) + +# treat below headers as system to suppress the warnings there during the build +target_include_directories( + ${python_module_name} + SYSTEM + PRIVATE ${SYCL_INCLUDE_DIR} ${Dpctl_INCLUDE_DIRS} ${Dpctl_TENSOR_INCLUDE_DIR} +) + +if(WIN32) + target_compile_options( + ${python_module_name} + PRIVATE /clang:-fno-approx-func /clang:-fno-finite-math-only + ) +else() + target_compile_options( + ${python_module_name} + PRIVATE -fno-approx-func -fno-finite-math-only + ) +endif() + +target_link_options(${python_module_name} PUBLIC -fsycl-device-code-split=per_kernel) + +if(DPNP_GENERATE_COVERAGE) + target_link_options( + ${python_module_name} + PRIVATE -fprofile-instr-generate -fcoverage-mapping + ) +endif() + +if(_use_onemath) + target_link_libraries(${python_module_name} PRIVATE ${ONEMATH_LIB}) + target_compile_options(${python_module_name} PRIVATE -DUSE_ONEMATH) + if(_use_onemath_cuda) + 
target_compile_options(${python_module_name} PRIVATE -DUSE_ONEMATH_CUSPARSE) + endif() +else() + target_link_libraries(${python_module_name} PUBLIC MKL::MKL_SYCL::SPARSE) +endif() + +if(DPNP_WITH_REDIST) + set_target_properties( + ${python_module_name} + PROPERTIES INSTALL_RPATH "$ORIGIN/../../../../../../" + ) +endif() + +install(TARGETS ${python_module_name} DESTINATION "dpnp/backend/extensions/sparse") diff --git a/dpnp/backend/extensions/sparse/gemv.cpp b/dpnp/backend/extensions/sparse/gemv.cpp new file mode 100644 index 000000000000..cd94c0143ce2 --- /dev/null +++ b/dpnp/backend/extensions/sparse/gemv.cpp @@ -0,0 +1,400 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include +#include +#include +#include +#include +#include +#include + +#include + +// utils extension header +#include "ext/common.hpp" + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_utils.hpp" + +#include "gemv.hpp" +#include "types_matrix.hpp" + +namespace dpnp::extensions::sparse +{ + +namespace mkl_sparse = oneapi::mkl::sparse; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +using ext::common::init_dispatch_table; + +// --------------------------------------------------------------------------- +// Dispatch table types +// --------------------------------------------------------------------------- + +/** + * init_impl: builds the matrix_handle, calls set_csr_data + optimize_gemv. + * Returns (handle_ptr, optimize_event). + * All CSR arrays are *not* copied -- they must stay alive until release. + */ +typedef std::pair (*gemv_init_fn_ptr_t)( + sycl::queue &, + oneapi::mkl::transpose, + const char *, // row_ptr (typeless) + const char *, // col_ind (typeless) + const char *, // values (typeless) + std::int64_t, // num_rows + std::int64_t, // num_cols + std::int64_t, // nnz + const std::vector &); + +/** + * compute_impl: fires sparse::gemv using a pre-built handle. + * Returns the gemv event directly -- no host_task wrapping. 
+ */ +typedef sycl::event (*gemv_compute_fn_ptr_t)( + sycl::queue &, + oneapi::mkl::sparse::matrix_handle_t, + oneapi::mkl::transpose, + double, // alpha (cast to Tv inside) + const char *, // x (typeless) + double, // beta (cast to Tv inside) + char *, // y (typeless, writable) + const std::vector &); + +// Init dispatch: 2-D on (Tv, Ti). +static gemv_init_fn_ptr_t gemv_init_dispatch_table[dpctl_td_ns::num_types] + [dpctl_td_ns::num_types]; + +// Compute dispatch: 1-D on Tv. The index type is baked into the handle, +// so compute doesn't need it. +static gemv_compute_fn_ptr_t + gemv_compute_dispatch_table[dpctl_td_ns::num_types]; + +// --------------------------------------------------------------------------- +// Per-type init implementation +// --------------------------------------------------------------------------- + +template +static std::pair + gemv_init_impl(sycl::queue &exec_q, + oneapi::mkl::transpose mkl_trans, + const char *row_ptr_data, + const char *col_ind_data, + const char *values_data, + std::int64_t num_rows, + std::int64_t num_cols, + std::int64_t nnz, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + const Ti *row_ptr = reinterpret_cast(row_ptr_data); + const Ti *col_ind = reinterpret_cast(col_ind_data); + const Tv *values = reinterpret_cast(values_data); + + mkl_sparse::matrix_handle_t spmat = nullptr; + mkl_sparse::init_matrix_handle(&spmat); + + auto ev_set = mkl_sparse::set_csr_data( + exec_q, spmat, num_rows, num_cols, nnz, oneapi::mkl::index_base::zero, + const_cast(row_ptr), const_cast(col_ind), + const_cast(values), depends); + + sycl::event ev_opt; + try { + ev_opt = mkl_sparse::optimize_gemv(exec_q, mkl_trans, spmat, {ev_set}); + } catch (oneapi::mkl::exception const &e) { + mkl_sparse::release_matrix_handle(exec_q, &spmat, {}); + throw std::runtime_error( + std::string("sparse_gemv_init: MKL exception in optimize_gemv: ") + + e.what()); + } catch (sycl::exception const &e) { + 
mkl_sparse::release_matrix_handle(exec_q, &spmat, {}); + throw std::runtime_error( + std::string("sparse_gemv_init: SYCL exception in optimize_gemv: ") + + e.what()); + } + + auto handle_ptr = reinterpret_cast(spmat); + return {handle_ptr, ev_opt}; +} + +// --------------------------------------------------------------------------- +// Per-type compute implementation +// --------------------------------------------------------------------------- + +template +static sycl::event gemv_compute_impl(sycl::queue &exec_q, + mkl_sparse::matrix_handle_t spmat, + oneapi::mkl::transpose mkl_trans, + double alpha_d, + const char *x_data, + double beta_d, + char *y_data, + const std::vector &depends) +{ + // For complex Tv the single-arg constructor sets imag to zero. + // Solvers use alpha=1, beta=0 so this is exact; other callers + // passing complex scalars via this path will lose the imag + // component silently. + const Tv alpha = static_cast(alpha_d); + const Tv beta = static_cast(beta_d); + + const Tv *x = reinterpret_cast(x_data); + Tv *y = reinterpret_cast(y_data); + + try { + return mkl_sparse::gemv(exec_q, mkl_trans, alpha, spmat, x, beta, y, + depends); + } catch (oneapi::mkl::exception const &e) { + throw std::runtime_error( + std::string("sparse_gemv_compute: MKL exception: ") + e.what()); + } catch (sycl::exception const &e) { + throw std::runtime_error( + std::string("sparse_gemv_compute: SYCL exception: ") + e.what()); + } +} + +// --------------------------------------------------------------------------- +// Public entry points +// --------------------------------------------------------------------------- + +static oneapi::mkl::transpose decode_trans(const int trans) +{ + switch (trans) { + case 0: + return oneapi::mkl::transpose::nontrans; + case 1: + return oneapi::mkl::transpose::trans; + case 2: + return oneapi::mkl::transpose::conjtrans; + default: + throw std::invalid_argument( + "sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)"); + } +} + 
+std::tuple + sparse_gemv_init(sycl::queue &exec_q, + const int trans, + const dpctl::tensor::usm_ndarray &row_ptr, + const dpctl::tensor::usm_ndarray &col_ind, + const dpctl::tensor::usm_ndarray &values, + const std::int64_t num_rows, + const std::int64_t num_cols, + const std::int64_t nnz, + const std::vector &depends) +{ + if (!dpctl::utils::queues_are_compatible( + exec_q, + {row_ptr.get_queue(), col_ind.get_queue(), values.get_queue()})) + throw py::value_error( + "sparse_gemv_init: USM allocations are not compatible with the " + "execution queue."); + + // Basic CSR shape sanity. + if (row_ptr.get_ndim() != 1 || col_ind.get_ndim() != 1 || + values.get_ndim() != 1) + throw py::value_error( + "sparse_gemv_init: row_ptr, col_ind, values must all be 1-D."); + + if (row_ptr.get_shape(0) != num_rows + 1) + throw py::value_error( + "sparse_gemv_init: row_ptr length must equal num_rows + 1."); + + if (col_ind.get_shape(0) != nnz || values.get_shape(0) != nnz) + throw py::value_error( + "sparse_gemv_init: col_ind and values length must equal nnz."); + + // Index types of row_ptr and col_ind must match. + if (row_ptr.get_typenum() != col_ind.get_typenum()) + throw py::value_error( + "sparse_gemv_init: row_ptr and col_ind must have the same dtype."); + + auto mkl_trans = decode_trans(trans); + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + const int val_id = array_types.typenum_to_lookup_id(values.get_typenum()); + const int idx_id = array_types.typenum_to_lookup_id(row_ptr.get_typenum()); + + gemv_init_fn_ptr_t init_fn = gemv_init_dispatch_table[val_id][idx_id]; + if (init_fn == nullptr) + throw py::value_error( + "sparse_gemv_init: no implementation for the given value/index " + "dtype combination. 
Supported: {float32,float64,complex64," + "complex128} x {int32,int64}."); + + auto [handle_ptr, ev_opt] = + init_fn(exec_q, mkl_trans, row_ptr.get_data(), col_ind.get_data(), + values.get_data(), num_rows, num_cols, nnz, depends); + + return {handle_ptr, val_id, ev_opt}; +} + +sycl::event sparse_gemv_compute(sycl::queue &exec_q, + const std::uintptr_t handle_ptr, + const int val_type_id, + const int trans, + const double alpha, + const dpctl::tensor::usm_ndarray &x, + const double beta, + const dpctl::tensor::usm_ndarray &y, + const std::int64_t num_rows, + const std::int64_t num_cols, + const std::vector &depends) +{ + if (x.get_ndim() != 1) + throw py::value_error("sparse_gemv_compute: x must be a 1-D array."); + if (y.get_ndim() != 1) + throw py::value_error("sparse_gemv_compute: y must be a 1-D array."); + + if (!dpctl::utils::queues_are_compatible(exec_q, + {x.get_queue(), y.get_queue()})) + throw py::value_error( + "sparse_gemv_compute: USM allocations are not compatible with the " + "execution queue."); + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(x, y)) + throw py::value_error( + "sparse_gemv_compute: x and y are overlapping memory segments."); + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(y); + + // Shape validation: op(A) is (num_rows, num_cols) for trans=N, + // (num_cols, num_rows) for trans={T,C}. + auto mkl_trans = decode_trans(trans); + const bool is_non_trans = (mkl_trans == oneapi::mkl::transpose::nontrans); + const std::int64_t op_rows = is_non_trans ? num_rows : num_cols; + const std::int64_t op_cols = is_non_trans ? 
num_cols : num_rows; + + if (x.get_shape(0) != op_cols) + throw py::value_error( + "sparse_gemv_compute: x length does not match operator columns."); + if (y.get_shape(0) != op_rows) + throw py::value_error( + "sparse_gemv_compute: y length does not match operator rows."); + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample( + y, static_cast(op_rows)); + + // Dtype verification: x, y, and the handle's value type must all match. + auto array_types = dpctl_td_ns::usm_ndarray_types(); + const int x_val_id = array_types.typenum_to_lookup_id(x.get_typenum()); + const int y_val_id = array_types.typenum_to_lookup_id(y.get_typenum()); + + if (x_val_id != val_type_id || y_val_id != val_type_id) + throw py::value_error( + "sparse_gemv_compute: x and y dtype must match the value dtype " + "of the sparse matrix used to build the handle."); + + if (val_type_id < 0 || val_type_id >= dpctl_td_ns::num_types) + throw py::value_error("sparse_gemv_compute: val_type_id out of range."); + + gemv_compute_fn_ptr_t compute_fn = gemv_compute_dispatch_table[val_type_id]; + + if (compute_fn == nullptr) + throw py::value_error("sparse_gemv_compute: unsupported value dtype."); + + auto spmat = reinterpret_cast(handle_ptr); + + return compute_fn(exec_q, spmat, mkl_trans, alpha, x.get_data(), beta, + const_cast(y.get_data()), depends); +} + +sycl::event sparse_gemv_release(sycl::queue &exec_q, + const std::uintptr_t handle_ptr, + const std::vector &depends) +{ + auto spmat = reinterpret_cast(handle_ptr); + + // release_matrix_handle takes `depends` so it will not free the handle + // until all pending compute work on it has completed. In recent oneMKL + // versions release_matrix_handle returns a sycl::event; older versions + // returned void. 
If your pinned oneMKL returns void, replace the body + // with: + // mkl_sparse::release_matrix_handle(exec_q, &spmat, depends); + // return exec_q.submit([&](sycl::handler &cgh) { + // cgh.depends_on(depends); + // cgh.host_task([]() {}); + // }); + sycl::event release_ev = + mkl_sparse::release_matrix_handle(exec_q, &spmat, depends); + + return release_ev; +} + +// --------------------------------------------------------------------------- +// Dispatch table factories and registration +// --------------------------------------------------------------------------- + +template +struct GemvInitContigFactory +{ + fnT get() + { + if constexpr (types::SparseGemvInitTypePairSupportFactory< + Tv, Ti>::is_defined) + return gemv_init_impl; + else + return nullptr; + } +}; + +template +struct GemvComputeContigFactory +{ + fnT get() + { + if constexpr (types::SparseGemvComputeTypeSupportFactory< + Tv>::is_defined) + return gemv_compute_impl; + else + return nullptr; + } +}; + +void init_sparse_gemv_dispatch_tables(void) +{ + // 2-D table on (Tv, Ti) for init. + init_dispatch_table( + gemv_init_dispatch_table); + + // 1-D table on Tv for compute. dpctl's type dispatch headers expose + // DispatchVectorBuilder as the 1-D analogue of DispatchTableBuilder. + dpctl_td_ns::DispatchVectorBuilder< + gemv_compute_fn_ptr_t, GemvComputeContigFactory, dpctl_td_ns::num_types> + builder; + builder.populate_dispatch_vector(gemv_compute_dispatch_table); +} + +} // namespace dpnp::extensions::sparse diff --git a/dpnp/backend/extensions/sparse/gemv.hpp b/dpnp/backend/extensions/sparse/gemv.hpp new file mode 100644 index 000000000000..0820fe9cc540 --- /dev/null +++ b/dpnp/backend/extensions/sparse/gemv.hpp @@ -0,0 +1,130 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include +#include + +#include +#include + +#include + +namespace dpnp::extensions::sparse +{ + +/** + * sparse_gemv_init -- ONE-TIME setup per sparse matrix operator. + * + * Calls init_matrix_handle + set_csr_data + optimize_gemv. + * + * Returns a tuple of: + * - handle_ptr: opaque matrix_handle_t cast to uintptr_t for safe + * Python round-tripping. 
+ * - val_type_id: the dpctl typenum lookup id of the value dtype Tv. + * Python MUST pass this back to sparse_gemv_compute so + * the C++ layer can verify that x and y dtype match the + * handle's value type. + * - event: dependency event from optimize_gemv; the caller must + * wait on it (or chain via depends) before the first + * sparse_gemv_compute call. + * + * LIFETIME CONTRACT -- IMPORTANT: + * The handle owns NO copies of the CSR arrays. The caller MUST keep + * row_ptr, col_ind, and values USM allocations alive until + * sparse_gemv_release has been called AND its returned event has + * completed. Dropping any of them earlier is undefined behavior and + * will produce silent memory corruption -- there is no runtime check. + * + * The Python wrapper (_CachedSpMV) enforces this contract by holding + * a reference to the CSR matrix for the lifetime of the handle. + */ +extern std::tuple + sparse_gemv_init(sycl::queue &exec_q, + const int trans, + const dpctl::tensor::usm_ndarray &row_ptr, + const dpctl::tensor::usm_ndarray &col_ind, + const dpctl::tensor::usm_ndarray &values, + const std::int64_t num_rows, + const std::int64_t num_cols, + const std::int64_t nnz, + const std::vector &depends); + +/** + * sparse_gemv_compute -- PER-ITERATION SpMV. + * + * Calls only oneapi::mkl::sparse::gemv using the pre-built handle. + * Verifies that: + * - x and y are 1-D usm_ndarrays on a queue compatible with exec_q + * - x and y dtype match val_type_id (the handle's value type) + * - x and y shapes match op(A) dimensions, taking trans into account + * (op(A) is num_rows x num_cols for trans=N, num_cols x num_rows + * for trans={T,C}) + * - y is writable and does not overlap x + * + * alpha and beta are passed as double and cast inside gemv_compute_impl + * to the matrix value type. For complex Tv the cast drops the imaginary + * part; callers needing complex scalars should keep alpha=1, beta=0 + * (the solver use case). + * + * Returns the gemv event. 
The caller is responsible for sequencing + * subsequent work on the same queue; no host-side wait or host_task + * keep-alive is performed. + */ +extern sycl::event sparse_gemv_compute(sycl::queue &exec_q, + const std::uintptr_t handle_ptr, + const int val_type_id, + const int trans, + const double alpha, + const dpctl::tensor::usm_ndarray &x, + const double beta, + const dpctl::tensor::usm_ndarray &y, + const std::int64_t num_rows, + const std::int64_t num_cols, + const std::vector &depends); + +/** + * sparse_gemv_release -- free the matrix_handle created by sparse_gemv_init. + * + * Must be called exactly once per handle, after all compute calls that + * depend on it have completed. The returned event depends on the release, + * so the caller can chain CSR buffer deallocation on it safely. + */ +extern sycl::event sparse_gemv_release(sycl::queue &exec_q, + const std::uintptr_t handle_ptr, + const std::vector &depends); + +/** + * Register the init (2-D on Tv x Ti) and compute (1-D on Tv) dispatch + * tables. Called exactly once from PYBIND11_MODULE. + */ +extern void init_sparse_gemv_dispatch_tables(void); + +} // namespace dpnp::extensions::sparse diff --git a/dpnp/backend/extensions/sparse/sparse_py.cpp b/dpnp/backend/extensions/sparse/sparse_py.cpp new file mode 100644 index 000000000000..3f018595ea81 --- /dev/null +++ b/dpnp/backend/extensions/sparse/sparse_py.cpp @@ -0,0 +1,153 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include +#include +#include + +#include +#include + +#include + +#include + +#include "gemv.hpp" + +namespace py = pybind11; + +using dpnp::extensions::sparse::init_sparse_gemv_dispatch_tables; +using dpnp::extensions::sparse::sparse_gemv_compute; +using dpnp::extensions::sparse::sparse_gemv_init; +using dpnp::extensions::sparse::sparse_gemv_release; + +PYBIND11_MODULE(_sparse_impl, m) +{ + init_sparse_gemv_dispatch_tables(); + + // ------------------------------------------------------------------ + // _using_onemath() + // + // Reports whether the module was compiled against the portable + // OneMath interface (USE_ONEMATH) rather than direct oneMKL. 
+ // ------------------------------------------------------------------ + m.def("_using_onemath", []() -> bool { +#ifdef USE_ONEMATH + return true; +#else + return false; +#endif + }); + + // ------------------------------------------------------------------ + // _sparse_gemv_init(exec_q, trans, row_ptr, col_ind, values, + // num_rows, num_cols, nnz, depends) + // -> (handle: int, val_type_id: int, event) + // + // Calls init_matrix_handle + set_csr_data + optimize_gemv ONCE. + // + // The returned handle is an opaque uintptr_t; val_type_id is the + // dpctl typenum lookup id of the matrix value dtype and MUST be + // passed back to _sparse_gemv_compute so the C++ layer can verify + // that x and y dtype match the handle. + // + // LIFETIME CONTRACT: the caller must keep row_ptr / col_ind / values + // USM allocations alive until _sparse_gemv_release has been called + // AND its returned event has completed. The handle does not copy + // the CSR arrays. + // ------------------------------------------------------------------ + m.def( + "_sparse_gemv_init", + [](sycl::queue &exec_q, const int trans, + const dpctl::tensor::usm_ndarray &row_ptr, + const dpctl::tensor::usm_ndarray &col_ind, + const dpctl::tensor::usm_ndarray &values, + const std::int64_t num_rows, const std::int64_t num_cols, + const std::int64_t nnz, const std::vector &depends) + -> std::tuple { + return sparse_gemv_init(exec_q, trans, row_ptr, col_ind, values, + num_rows, num_cols, nnz, depends); + }, + py::arg("exec_q"), py::arg("trans"), py::arg("row_ptr"), + py::arg("col_ind"), py::arg("values"), py::arg("num_rows"), + py::arg("num_cols"), py::arg("nnz"), py::arg("depends"), + "Initialise oneMKL sparse matrix handle " + "(set_csr_data + optimize_gemv). " + "Returns (handle_ptr: int, val_type_id: int, event). 
" + "Call once per operator."); + + // ------------------------------------------------------------------ + // _sparse_gemv_compute(exec_q, handle, val_type_id, trans, alpha, + // x, beta, y, num_rows, num_cols, depends) + // -> gemv_event + // + // Fires sparse::gemv using a pre-built handle. Verifies x and y + // dtype match val_type_id from init, and that shapes agree with + // op(A) dimensions (swapped for trans != N). + // + // Only the cheap MKL kernel is dispatched; no analysis overhead. + // No host_task keep-alive is submitted -- pybind11 refcounts the + // usm_ndarrays across the call, and sequencing of subsequent work + // on the same queue happens automatically. + // ------------------------------------------------------------------ + m.def( + "_sparse_gemv_compute", + [](sycl::queue &exec_q, const std::uintptr_t handle_ptr, + const int val_type_id, const int trans, const double alpha, + const dpctl::tensor::usm_ndarray &x, const double beta, + const dpctl::tensor::usm_ndarray &y, const std::int64_t num_rows, + const std::int64_t num_cols, + const std::vector &depends) -> sycl::event { + return sparse_gemv_compute(exec_q, handle_ptr, val_type_id, trans, + alpha, x, beta, y, num_rows, num_cols, + depends); + }, + py::arg("exec_q"), py::arg("handle"), py::arg("val_type_id"), + py::arg("trans"), py::arg("alpha"), py::arg("x"), py::arg("beta"), + py::arg("y"), py::arg("num_rows"), py::arg("num_cols"), + py::arg("depends"), + "Execute sparse::gemv using a pre-built handle. " + "Returns the gemv event."); + + // ------------------------------------------------------------------ + // _sparse_gemv_release(exec_q, handle, depends) -> event + // + // Releases the matrix_handle allocated by _sparse_gemv_init. + // Must be called exactly once per handle after all compute calls + // referencing it have completed. The returned event depends on the + // release, so callers can chain CSR buffer deallocation on it. 
+ // ------------------------------------------------------------------ + m.def( + "_sparse_gemv_release", + [](sycl::queue &exec_q, const std::uintptr_t handle_ptr, + const std::vector &depends) -> sycl::event { + return sparse_gemv_release(exec_q, handle_ptr, depends); + }, + py::arg("exec_q"), py::arg("handle"), py::arg("depends"), + "Release the oneMKL matrix_handle created by _sparse_gemv_init."); +} diff --git a/dpnp/backend/extensions/sparse/types_matrix.hpp b/dpnp/backend/extensions/sparse/types_matrix.hpp new file mode 100644 index 000000000000..42145a4ab4d2 --- /dev/null +++ b/dpnp/backend/extensions/sparse/types_matrix.hpp @@ -0,0 +1,122 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+
+#pragma once
+
+#include <complex>
+#include <cstdint>
+#include <type_traits>
+
+// dpctl tensor headers
+#include "utils/type_dispatch.hpp"
+
+// dpctl namespace alias for type dispatch utilities
+namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
+
+namespace dpnp::extensions::sparse::types
+{
+
+/**
+ * @brief Factory encoding the supported (value type, index type) combinations
+ * for oneapi::mkl::sparse::gemv initialization.
+ *
+ * oneMKL sparse BLAS supports:
+ * - float32 with int32 indices
+ * - float32 with int64 indices
+ * - float64 with int32 indices
+ * - float64 with int64 indices
+ * - complex<float> with int32 indices
+ * - complex<float> with int64 indices
+ * - complex<double> with int32 indices
+ * - complex<double> with int64 indices
+ *
+ * Complex support requires oneMKL >= 2023.x (sparse BLAS complex USM API).
+ * The init dispatch table entry is non-null only when the pair is registered
+ * here; the Python layer falls back to A.dot(x) when the entry is nullptr.
+ *
+ * @tparam Tv Value type of the sparse matrix and dense vectors.
+ * @tparam Ti Index type of the sparse matrix (row_ptr / col_ind arrays).
+ */
+template <typename Tv, typename Ti>
+struct SparseGemvInitTypePairSupportFactory
+{
+    static constexpr bool is_defined = std::disjunction<
+        // real single precision
+        dpctl_td_ns::TypePairDefinedEntry<Tv, float, Ti, std::int32_t>,
+        dpctl_td_ns::TypePairDefinedEntry<Tv, float, Ti, std::int64_t>,
+        // real double precision
+        dpctl_td_ns::TypePairDefinedEntry<Tv, double, Ti, std::int32_t>,
+        dpctl_td_ns::TypePairDefinedEntry<Tv, double, Ti, std::int64_t>,
+        // complex single precision
+        dpctl_td_ns::
+            TypePairDefinedEntry<Tv, std::complex<float>, Ti, std::int32_t>,
+        dpctl_td_ns::
+            TypePairDefinedEntry<Tv, std::complex<float>, Ti, std::int64_t>,
+        // complex double precision
+        dpctl_td_ns::
+            TypePairDefinedEntry<Tv, std::complex<double>, Ti, std::int32_t>,
+        dpctl_td_ns::
+            TypePairDefinedEntry<Tv, std::complex<double>, Ti, std::int64_t>,
+        // fall-through
+        dpctl_td_ns::NotDefinedEntry>::is_defined;
+};
+
+/**
+ * @brief Factory encoding supported value types for sparse::gemv compute.
+ *
+ * The compute path only requires Tv because the index type is baked into
+ * the matrix_handle at init time. Using a 1-D dispatch vector on Tv avoids
+ * the wasted num_types * num_types slots of a 2-D table where only the
+ * diagonal (keyed on Ti) would ever be populated.
+ *
+ * If your pinned dpctl version does not expose TypeDefinedEntry as a 1-arg
+ * entry, fall back to the std::is_same_v expansion shown in the comment
+ * below -- both are equivalent.
+ *
+ * @tparam Tv Value type of the sparse matrix and dense vectors.
+ */
+template <typename Tv>
+struct SparseGemvComputeTypeSupportFactory
+{
+#if defined(DPCTL_HAS_TYPE_DEFINED_ENTRY)
+    static constexpr bool
+        is_defined = std::disjunction<
+            dpctl_td_ns::TypeDefinedEntry<Tv, float>,
+            dpctl_td_ns::TypeDefinedEntry<Tv, double>,
+            dpctl_td_ns::TypeDefinedEntry<Tv, std::complex<float>>,
+            dpctl_td_ns::TypeDefinedEntry<Tv, std::complex<double>>,
+            dpctl_td_ns::NotDefinedEntry>::is_defined;
+#else
+    // Portable fallback: works with any dpctl version.
+ static constexpr bool is_defined = + std::is_same_v || std::is_same_v || + std::is_same_v> || + std::is_same_v>; +#endif +}; + +} // namespace dpnp::extensions::sparse::types diff --git a/dpnp/scipy/__init__.py b/dpnp/scipy/__init__.py index 56cf27f56342..ceb1f9df932e 100644 --- a/dpnp/scipy/__init__.py +++ b/dpnp/scipy/__init__.py @@ -36,6 +36,6 @@ DPNP functionality, reusing DPNP and oneMKL implementations underneath. """ -from . import linalg, special +from . import linalg, sparse, special -__all__ = ["linalg", "special"] +__all__ = ["linalg", "special", "sparse"] diff --git a/dpnp/scipy/sparse/__init__.py b/dpnp/scipy/sparse/__init__.py new file mode 100644 index 000000000000..83b6e365a6cc --- /dev/null +++ b/dpnp/scipy/sparse/__init__.py @@ -0,0 +1,37 @@ +# ***************************************************************************** +# Copyright (c) 2025, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +"""Sparse linear algebra namespace for DPNP. + +Currently this module exposes the :mod:`dpnp.scipy.sparse.linalg` submodule +and provides a location for future sparse matrix container types. +""" + +from . import linalg + +__all__ = ["linalg"] diff --git a/dpnp/scipy/sparse/linalg/__init__.py b/dpnp/scipy/sparse/linalg/__init__.py new file mode 100644 index 000000000000..30124562447e --- /dev/null +++ b/dpnp/scipy/sparse/linalg/__init__.py @@ -0,0 +1,44 @@ +# ***************************************************************************** +# Copyright (c) 2025, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +""" +Sparse linear algebra interface for DPNP. + +This module provides a subset of :mod:`scipy.sparse.linalg` + functionality on top of DPNP arrays. + +The initial implementation focuses on the :class:`LinearOperator` interface +and a small set of Krylov solvers (``cg``, ``gmres``, ``minres``). +""" + +from __future__ import annotations + +from ._interface import LinearOperator, aslinearoperator +from ._iterative import cg, gmres, minres + +__all__ = ["LinearOperator", "aslinearoperator", "cg", "gmres", "minres"] diff --git a/dpnp/scipy/sparse/linalg/_interface.py b/dpnp/scipy/sparse/linalg/_interface.py new file mode 100644 index 000000000000..e071242f6ba3 --- /dev/null +++ b/dpnp/scipy/sparse/linalg/_interface.py @@ -0,0 +1,576 @@ +# ***************************************************************************** +# Copyright (c) 2025, Intel Corporation +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +"""LinearOperator and helpers for dpnp.scipy.sparse.linalg. + +Aligned with SciPy main scipy/sparse/linalg/_interface.py and +CuPy v14.0.1 cupyx/scipy/sparse/linalg/_interface.py so that code +written for either library is portable to dpnp. 
+ +Additional items versus the previous version +-------------------------------------------- +* T / H properties now exposed as SciPy does (A.T and A.H work) +* _adjoint / _transpose virtual hooks on LinearOperator base +* _ScaledLinearOperator.adjoint uses conj(alpha) correctly +* aslinearoperator accepts ndim-1 vectors (promotes to column/row) +* _isshape accepts numpy integer types, not just Python int +""" + +from __future__ import annotations + +import warnings + +import dpnp + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _isshape(shape): + """Return True if shape is a length-2 tuple of non-negative integers.""" + if not isinstance(shape, tuple) or len(shape) != 2: + return False + try: + return all(int(s) >= 0 and int(s) == s for s in shape) + except (TypeError, ValueError): + return False + + +def _isintlike(x): + try: + return int(x) == x + except (TypeError, ValueError): + return False + + +def _get_dtype(operators, dtypes=None): + if dtypes is None: + dtypes = [] + for obj in operators: + if obj is not None and hasattr(obj, "dtype") and obj.dtype is not None: + dtypes.append(obj.dtype) + return dpnp.result_type(*dtypes) if dtypes else None + + +class LinearOperator: + """Drop-in replacement for cupyx/scipy LinearOperator backed by dpnp arrays. + + Supports the full operator algebra (addition, multiplication, scaling, + power, adjoint A.H, transpose A.T) matching CuPy v14.0.1 and SciPy main. 
+ """ + + ndim = 2 + + def __new__(cls, *args, **kwargs): + if cls is LinearOperator: + return super().__new__(_CustomLinearOperator) + else: + obj = super().__new__(cls) + if ( + type(obj)._matvec is LinearOperator._matvec + and type(obj)._matmat is LinearOperator._matmat + ): + warnings.warn( + "LinearOperator subclass should implement at least one of " + "_matvec and _matmat.", + RuntimeWarning, + stacklevel=2, + ) + return obj + + def __init__(self, dtype, shape): + if dtype is not None: + dtype = dpnp.dtype(dtype) + shape = tuple(int(s) for s in shape) + if not _isshape(shape): + raise ValueError( + f"invalid shape {shape!r} (must be a length-2 tuple of " + "non-negative ints)" + ) + self.dtype = dtype + self.shape = shape + + def _init_dtype(self): + """Infer dtype via a trial matvec on a zero vector.""" + if self.dtype is not None: + return + v = dpnp.zeros(self.shape[-1], dtype=dpnp.float64) + self.dtype = self.matvec(v).dtype + + def _matvec(self, x): + return self.matmat(x.reshape(-1, 1)) + + def _matmat(self, X): + return dpnp.hstack([self.matvec(col.reshape(-1, 1)) for col in X.T]) + + def _rmatvec(self, x): + if type(self)._adjoint is LinearOperator._adjoint: + raise NotImplementedError( + "rmatvec is not defined for this LinearOperator" + ) + return self.H.matvec(x) + + def _rmatmat(self, X): + if type(self)._adjoint is LinearOperator._adjoint: + return dpnp.hstack( + [self.rmatvec(col.reshape(-1, 1)) for col in X.T] + ) + return self.H.matmat(X) + + def matvec(self, x): + """Apply the matrix-vector product.""" + M, N = self.shape + if x.shape not in ((N,), (N, 1)): + raise ValueError( + f"dimension mismatch: operator shape {self.shape}, " + "vector shape {x.shape}" + ) + y = self._matvec(x) + return y.reshape(M) if x.ndim == 1 else y.reshape(M, 1) + + def rmatvec(self, x): + """Apply the adjoint matrix-vector product.""" + M, N = self.shape + if x.shape not in ((M,), (M, 1)): + raise ValueError( + f"dimension mismatch: operator shape {self.shape}, " 
+ "vector shape {x.shape}" + ) + y = self._rmatvec(x) + return y.reshape(N) if x.ndim == 1 else y.reshape(N, 1) + + def matmat(self, X): + """Apply the matrix-matrix product.""" + if X.ndim != 2: + raise ValueError(f"expected 2-D array, got {X.ndim}-D") + if X.shape[0] != self.shape[1]: + raise ValueError( + f"dimension mismatch: {self.shape!r} vs {X.shape!r}" + ) + return self._matmat(X) + + def rmatmat(self, X): + """Apply the adjoint matrix-matrix product.""" + if X.ndim != 2: + raise ValueError(f"expected 2-D array, got {X.ndim}-D") + if X.shape[0] != self.shape[0]: + raise ValueError( + f"dimension mismatch: {self.shape!r} vs {X.shape!r}" + ) + return self._rmatmat(X) + + def dot(self, x): + if isinstance(x, LinearOperator): + return _ProductLinearOperator(self, x) + elif dpnp.isscalar(x): + return _ScaledLinearOperator(self, x) + else: + x = dpnp.asarray(x) + if x.ndim == 1 or (x.ndim == 2 and x.shape[1] == 1): + return self.matvec(x) + elif x.ndim == 2: + return self.matmat(x) + raise ValueError( + f"expected 1-D or 2-D array or LinearOperator, got {x!r}" + ) + + def __call__(self, x): + return self * x + + def __mul__(self, x): + """Multiply operator by array x.""" + return self.dot(x) + + def __matmul__(self, x): + if dpnp.isscalar(x): + raise ValueError( + "Scalar operands not allowed with '@'; use '*' instead" + ) + return self.__mul__(x) + + def __rmatmul__(self, x): + if dpnp.isscalar(x): + raise ValueError( + "Scalar operands not allowed with '@'; use '*' instead" + ) + return self.__rmul__(x) + + def __rmul__(self, x): + if dpnp.isscalar(x): + return _ScaledLinearOperator(self, x) + return NotImplemented + + def __pow__(self, p): + if dpnp.isscalar(p): + return _PowerLinearOperator(self, p) + return NotImplemented + + def __add__(self, x): + if isinstance(x, LinearOperator): + return _SumLinearOperator(self, x) + return NotImplemented + + def __neg__(self): + return _ScaledLinearOperator(self, -1) + + def __sub__(self, x): + return self.__add__(-x) + 
+ def _adjoint(self): + """Return conjugate-transpose operator (override in subclasses).""" + return _AdjointLinearOperator(self) + + def _transpose(self): + """Return plain-transpose operator (override in subclasses).""" + return _TransposedLinearOperator(self) + + def adjoint(self): + """Hermitian adjoint A^H.""" + return self._adjoint() + + def transpose(self): + """Plain (non-conjugated) transpose A^T.""" + return self._transpose() + + #: A.H — conjugate transpose + H = property(adjoint) + #: A.T — plain transpose + T = property(transpose) + + def __repr__(self): + dt = ( + "unspecified dtype" if self.dtype is None else f"dtype={self.dtype}" + ) + return ( + f"<{self.shape[0]}x{self.shape[1]}" + f" {self.__class__.__name__} with {dt}>" + ) + + +class _CustomLinearOperator(LinearOperator): + """Created when the user calls LinearOperator(shape, matvec=...)""" + + def __init__( + self, shape, matvec, rmatvec=None, matmat=None, dtype=None, rmatmat=None + ): + super().__init__(dtype, shape) + self.args = () + self.__matvec_impl = matvec + self.__rmatvec_impl = rmatvec + self.__rmatmat_impl = rmatmat + self.__matmat_impl = matmat + self._init_dtype() + + def _matvec(self, x): + return self.__matvec_impl(x) + + def _matmat(self, X): + if self.__matmat_impl is not None: + return self.__matmat_impl(X) + return super()._matmat(X) + + def _rmatvec(self, x): + if self.__rmatvec_impl is None: + raise NotImplementedError( + "rmatvec is not defined for this operator" + ) + return self.__rmatvec_impl(x) + + def _rmatmat(self, X): + if self.__rmatmat_impl is not None: + return self.__rmatmat_impl(X) + return super()._rmatmat(X) + + def _adjoint(self): + return _CustomLinearOperator( + shape=(self.shape[1], self.shape[0]), + matvec=self.__rmatvec_impl, + rmatvec=self.__matvec_impl, + matmat=self.__rmatmat_impl, + rmatmat=self.__matmat_impl, + dtype=self.dtype, + ) + + +class _AdjointLinearOperator(LinearOperator): + def __init__(self, A): + super().__init__(A.dtype, (A.shape[1], 
A.shape[0])) + self.A = A + self.args = (A,) + + def _matvec(self, x): + return self.A._rmatvec(x) # pylint: disable=protected-access + + def _rmatvec(self, x): + return self.A._matvec(x) # pylint: disable=protected-access + + def _matmat(self, X): + return self.A._rmatmat(X) # pylint: disable=protected-access + + def _rmatmat(self, X): + return self.A._matmat(X) # pylint: disable=protected-access + + def _adjoint(self): + return self.A + + +class _TransposedLinearOperator(LinearOperator): + def __init__(self, A): + super().__init__(A.dtype, (A.shape[1], A.shape[0])) + self.A = A + self.args = (A,) + + def _matvec(self, x): + return dpnp.conj(self.A._rmatvec(dpnp.conj(x))) + + def _rmatvec(self, x): + return dpnp.conj(self.A._matvec(dpnp.conj(x))) + + def _matmat(self, X): + return dpnp.conj(self.A._rmatmat(dpnp.conj(X))) + + def _rmatmat(self, X): + return dpnp.conj(self.A._matmat(dpnp.conj(X))) + + def _transpose(self): + return self.A + + +class _SumLinearOperator(LinearOperator): + def __init__(self, A, B): + if A.shape != B.shape: + raise ValueError(f"shape mismatch for addition: {A!r} + {B!r}") + super().__init__(_get_dtype([A, B]), A.shape) + self.args = (A, B) + + def _matvec(self, x): + return self.args[0].matvec(x) + self.args[1].matvec(x) + + def _rmatvec(self, x): + return self.args[0].rmatvec(x) + self.args[1].rmatvec(x) + + def _matmat(self, X): + return self.args[0].matmat(X) + self.args[1].matmat(X) + + def _rmatmat(self, X): + return self.args[0].rmatmat(X) + self.args[1].rmatmat(X) + + def _adjoint(self): + return self.args[0].H + self.args[1].H + + +class _ProductLinearOperator(LinearOperator): + def __init__(self, A, B): + if A.shape[1] != B.shape[0]: + raise ValueError(f"shape mismatch for multiply: {A!r} * {B!r}") + super().__init__(_get_dtype([A, B]), (A.shape[0], B.shape[1])) + self.args = (A, B) + + def _matvec(self, x): + return self.args[0].matvec(self.args[1].matvec(x)) + + def _rmatvec(self, x): + return 
self.args[1].rmatvec(self.args[0].rmatvec(x)) + + def _matmat(self, X): + return self.args[0].matmat(self.args[1].matmat(X)) + + def _rmatmat(self, X): + return self.args[1].rmatmat(self.args[0].rmatmat(X)) + + def _adjoint(self): + A, B = self.args + return B.H * A.H + + +class _ScaledLinearOperator(LinearOperator): + def __init__(self, A, alpha): + super().__init__(_get_dtype([A], [type(alpha)]), A.shape) + self.args = (A, alpha) + + def _matvec(self, x): + return self.args[1] * self.args[0].matvec(x) + + def _rmatvec(self, x): + return dpnp.conj(self.args[1]) * self.args[0].rmatvec(x) + + def _matmat(self, X): + return self.args[1] * self.args[0].matmat(X) + + def _rmatmat(self, X): + return dpnp.conj(self.args[1]) * self.args[0].rmatmat(X) + + def _adjoint(self): + A, alpha = self.args + return A.H * dpnp.conj(alpha) + + +class _PowerLinearOperator(LinearOperator): + def __init__(self, A, p): + if A.shape[0] != A.shape[1]: + raise ValueError("matrix power requires a square operator") + if not _isintlike(p) or p < 0: + raise ValueError( + "matrix power requires a non-negative integer exponent" + ) + super().__init__(_get_dtype([A]), A.shape) + self.args = (A, int(p)) + + def _power(self, f, x): + res = x.copy() + for _ in range(self.args[1]): + res = f(res) + return res + + def _matvec(self, x): + return self._power(self.args[0].matvec, x) + + def _rmatvec(self, x): + return self._power(self.args[0].rmatvec, x) + + def _matmat(self, X): + return self._power(self.args[0].matmat, X) + + def _rmatmat(self, X): + return self._power(self.args[0].rmatmat, X) + + def _adjoint(self): + A, p = self.args + return A.H**p + + +class MatrixLinearOperator(LinearOperator): + """Wrap a dense dpnp matrix (or sparse matrix) as a LinearOperator.""" + + def __init__(self, A): + super().__init__(A.dtype, A.shape) + self.A = A + self.__adj = None + self.args = (A,) + + def _matmat(self, X): + return self.A.dot(X) + + def _rmatmat(self, X): + return dpnp.conj(self.A.T).dot(X) + + def 
_adjoint(self):
+        if self.__adj is None:
+            self.__adj = _AdjointMatrixOperator(self)
+        return self.__adj
+
+
+class _AdjointMatrixOperator(MatrixLinearOperator):
+    def __init__(self, adjoint):
+        self.A = dpnp.conj(adjoint.A.T)
+        self.__adjoint = adjoint
+        self.args = (adjoint,)
+        self.shape = (adjoint.shape[1], adjoint.shape[0])
+
+    @property
+    def dtype(self):
+        return self.__adjoint.dtype
+
+    def _adjoint(self):
+        return self.__adjoint
+
+
+class IdentityOperator(LinearOperator):
+    """Identity operator — used as the default (no-op) preconditioner."""
+
+    def __init__(self, shape, dtype=None):
+        super().__init__(dtype, shape)
+
+    def _matvec(self, x):
+        """Return x unchanged (identity map)."""
+        return x
+
+    def _rmatvec(self, x):
+        return x
+
+    def _matmat(self, X):
+        return X
+
+    def _rmatmat(self, X):
+        return X
+
+    def _adjoint(self):
+        return self
+
+    def _transpose(self):
+        return self
+
+
+def aslinearoperator(A) -> LinearOperator:
+    """Wrap A as a LinearOperator if it is not already one.
+
+    Handles (in order):
+    1. Already a LinearOperator — returned as-is.
+    2. dpnp.scipy.sparse sparse matrix.
+    3. Dense 2-D dpnp.ndarray.
+    4. Duck-typed objects with .shape and .matvec / @ support.
+ """ + if isinstance(A, LinearOperator): + return A + + try: + from dpnp.scipy import sparse as _sp # pylint: disable=import-outside-toplevel + + if _sp.issparse(A): + return MatrixLinearOperator(A) + except (ImportError, AttributeError): + pass + + if isinstance(A, dpnp.ndarray): + if A.ndim != 2: + raise ValueError( + f"aslinearoperator: dpnp array must be 2-D, got {A.ndim}-D" + ) + return MatrixLinearOperator(A) + + if hasattr(A, "shape") and len(A.shape) == 2: + m, n = int(A.shape[0]), int(A.shape[1]) + dtype = getattr(A, "dtype", None) + matvec = A.matvec if hasattr(A, "matvec") else (lambda x: A @ x) + rmatvec = A.rmatvec if hasattr(A, "rmatvec") else None + matmat = A.matmat if hasattr(A, "matmat") else None + rmatmat = A.rmatmat if hasattr(A, "rmatmat") else None + return LinearOperator( + (m, n), + matvec=matvec, + rmatvec=rmatvec, + matmat=matmat, + dtype=dtype, + rmatmat=rmatmat, + ) + + raise TypeError( + f"Cannot convert object of type {type(A)!r} to a LinearOperator. " + "Expected a LinearOperator, dpnp sparse matrix, or 2-D dpnp.ndarray." + ) diff --git a/dpnp/scipy/sparse/linalg/_iterative.py b/dpnp/scipy/sparse/linalg/_iterative.py new file mode 100644 index 000000000000..de09d2684115 --- /dev/null +++ b/dpnp/scipy/sparse/linalg/_iterative.py @@ -0,0 +1,966 @@ +# ***************************************************************************** +# Copyright (c) 2025, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +"""Iterative sparse linear solvers for dpnp -- pure GPU/SYCL implementation. + +All computation stays on the device (USM/oneMKL). There is NO host-dispatch +fallback: transferring data to the CPU for small systems defeats the purpose +of keeping a live computation on GPU memory. + +Solver coverage +--------------- +cg : Conjugate Gradient (Hermitian positive definite) +gmres : Restarted GMRES (general non-symmetric) +minres : MINRES (symmetric possibly indefinite) + +SpMV fast-path +-------------- +When a CSR dpnp sparse matrix is passed as A or M, _make_fast_matvec() +constructs a _CachedSpMV object that: + 1. Calls _sparse_gemv_init() ONCE to create the oneMKL matrix_handle, + register CSR pointers via set_csr_data, and run optimize_gemv + (the expensive sparsity-analysis phase). + 2. Calls _sparse_gemv_compute() on every matvec -- only the cheap + oneMKL sparse::gemv kernel fires; no handle setup overhead. 
+ 3. Calls _sparse_gemv_release() in __del__ to free the handle. + +This means optimize_gemv runs once per operator, not once per iteration, +which is the correct usage pattern for oneMKL sparse BLAS. + +Supported dtypes for the oneMKL SpMV fast-path: + values : float32, float64, complex64, complex128 + indices: int32, int64 +Complex dtypes require oneMKL sparse BLAS support (available since +oneMKL 2023.x); if the dispatch table slot is nullptr (types_matrix.hpp +does not register the pair) a ValueError is raised by the C++ layer. +_make_fast_matvec catches this and falls back to A.dot(x). +""" + +from __future__ import annotations + +from typing import Callable + +import dpctl.utils as dpu +import numpy + +import dpnp +import dpnp.backend.extensions.blas._blas_impl as bi + +from ._interface import IdentityOperator, LinearOperator, aslinearoperator + +# --------------------------------------------------------------------------- +# oneMKL sparse SpMV hook -- cached-handle API +# --------------------------------------------------------------------------- + +try: + from dpnp.backend.extensions.sparse import _sparse_impl as _si + + _HAS_SPARSE_IMPL = True +except ImportError: + _si = None + _HAS_SPARSE_IMPL = False + +_SUPPORTED_DTYPES = frozenset("fdFD") + +def _np_dtype(dp_dtype) -> numpy.dtype: + """Normalise any dtype-like (dpnp type/numpy type/string) to numpy.dtype.""" + return numpy.dtype(dp_dtype) + +def _check_dtype(dtype, name: str) -> None: + if _np_dtype(dtype).char not in _SUPPORTED_DTYPES: + raise TypeError( + f"{name} has unsupported dtype {dtype}; " + "only float32, float64, complex64, complex128 are accepted." + ) + +class _CachedSpMV: + """ + Wrap a CSR matrix with a persistent oneMKL matrix_handle. + + The handle is initialised (set_csr_data + optimize_gemv) exactly once + in __init__. Subsequent calls to __call__ only invoke sparse::gemv, + paying no analysis overhead. The handle is released in __del__. 
+ + Parameters + ---------- + A : dpnp CSR sparse matrix + trans : int 0=N, 1=T, 2=C (fixed at construction) + """ + + __slots__ = ( + "_A", + "_exec_q", + "_handle", + "_trans", + "_nrows", + "_ncols", + "_nnz", + "_out_size", + "_in_size", + "_dtype", + "_val_type_id", + ) + + def __init__(self, A, trans: int = 0): + self._A = A # keep alive so USM pointers stay valid + self._trans = int(trans) + self._nrows = int(A.shape[0]) + self._ncols = int(A.shape[1]) + self._nnz = int(A.data.shape[0]) + self._exec_q = A.data.sycl_queue + self._dtype = A.data.dtype + + # Output and input lengths depend on transpose mode. + # For trans=0 (N): y has nrows, x has ncols. + # For trans=1/2 (T/C): y has ncols, x has nrows. + if self._trans == 0: + self._out_size = self._nrows + self._in_size = self._ncols + else: + self._out_size = self._ncols + self._in_size = self._nrows + + self._handle = None + self._val_type_id = -1 + + # init_matrix_handle + set_csr_data + optimize_gemv (once). + # We must wait on optimize_gemv before any compute call can run; + # this is the only place __init__/__call__ blocks. + handle, val_type_id, ev = _si._sparse_gemv_init( + self._exec_q, + self._trans, + A.indptr, + A.indices, + A.data, + self._nrows, + self._ncols, + self._nnz, + [], + ) + ev.wait() + self._handle = handle + self._val_type_id = val_type_id + + def __call__(self, x: dpnp.ndarray) -> dpnp.ndarray: + """Y = op(A) * x -- only sparse::gemv fires, fully async.""" + y = dpnp.empty( + self._out_size, dtype=self._dtype, sycl_queue=self._exec_q + ) + # Do NOT wait on the event -- subsequent dpnp ops on the same + # queue will serialize behind it automatically. Blocking here + # throws away async overlap and dominates small-problem runtime. 
_si._sparse_gemv_compute(
+            self._exec_q,
+            self._handle,
+            self._val_type_id,
+            self._trans,
+            1.0,
+            x,
+            0.0,
+            y,
+            self._nrows,
+            self._ncols,
+            [],
+        )
+        return y
+
+    def __del__(self):
+        # Guard against partial construction: _handle may not be set if
+        # __init__ raised before the assignment.
+        handle = getattr(self, "_handle", None)
+        if handle is not None and _si is not None:
+            try:
+                _si._sparse_gemv_release(self._exec_q, handle, [])
+            except Exception:
+                pass
+        self._handle = None
+
+
+class _CachedSpMVPair:
+    """Holds forward and (lazily built) adjoint cached SpMV handles."""
+
+    __slots__ = ("forward", "_A", "_adjoint")
+
+    def __init__(self, A):
+        self.forward = _CachedSpMV(A, trans=0)
+        self._A = A
+        self._adjoint = None
+
+    def matvec(self, x):
+        """Apply the operator to vector x."""
+        return self.forward(x)
+
+    def rmatvec(self, x):
+        """Apply the adjoint operator to vector x."""
+        if self._adjoint is None:
+            # Build conjtrans handle on first use. For real dtypes
+            # this is equivalent to trans=1.
+            is_cpx = dpnp.issubdtype(self._A.data.dtype, dpnp.complexfloating)
+            self._adjoint = _CachedSpMV(self._A, trans=2 if is_cpx else 1)
+        return self._adjoint(x)
+
+
+def _make_fast_matvec(A):
+    """Return a _CachedSpMVPair if A is a CSR matrix with oneMKL support,
+    or None if A is not an eligible sparse matrix.
+
+    Falls back to None (caller uses A.dot) on:
+    - missing _sparse_impl extension
+    - dtype not supported by the C++ dispatch table
+    - any other C++ exception during handle initialisation
+    """
+    try:
+        from dpnp.scipy import sparse as _sp
+
+        if not (_sp.issparse(A) and A.format == "csr"):
+            return None
+    except (ImportError, AttributeError):
+        return None
+
+    if not _HAS_SPARSE_IMPL:
+        return None
+
+    # Only build the cached handle for supported dtypes.
+ if _np_dtype(A.data.dtype).char not in _SUPPORTED_DTYPES: + return None + + try: + return _CachedSpMVPair(A) + except Exception: + return None + + +def _make_system(A, M, x0, b): + """Validate and prepare (A_op, M_op, x, b, dtype) on device. + + dpnp-only policy: b, x0, and any dense operator inputs must already + be dpnp arrays. No host->device promotion happens here. + + dtype promotion follows CuPy v14 rules: A.dtype is used when it is in + {f,d,F,D}; otherwise b.dtype is promoted to float64 (real) or + complex128 (complex). + """ + if not isinstance(b, dpnp.ndarray): + raise TypeError(f"b must be a dpnp.ndarray, got {type(b).__name__}") + if x0 is not None and not isinstance(x0, dpnp.ndarray): + raise TypeError( + f"x0 must be a dpnp.ndarray or None, got {type(x0).__name__}" + ) + + A_op = aslinearoperator(A) + if A_op.shape[0] != A_op.shape[1]: + raise ValueError("A must be a square operator") + n = A_op.shape[0] + + b = b.reshape(-1) + if b.shape[0] != n: + raise ValueError( + f"b length {b.shape[0]} does not match operator dimension {n}" + ) + + # Dtype promotion: prefer A.dtype; fall back via b.dtype. 
+ if ( + A_op.dtype is not None + and _np_dtype(A_op.dtype).char in _SUPPORTED_DTYPES + ): + dtype = A_op.dtype + elif dpnp.issubdtype(b.dtype, dpnp.complexfloating): + dtype = dpnp.complex128 + else: + dtype = dpnp.float64 + + b = b.astype(dtype, copy=False) + _check_dtype(b.dtype, "b") + + if x0 is None: + x = dpnp.zeros(n, dtype=dtype, sycl_queue=b.sycl_queue) + else: + x = x0.astype(dtype, copy=True).reshape(-1) + if x.shape[0] != n: + raise ValueError(f"x0 length {x.shape[0]} != n={n}") + + if M is None: + M_op = IdentityOperator((n, n), dtype=dtype) + else: + M_op = aslinearoperator(M) + if M_op.shape != A_op.shape: + raise ValueError( + f"preconditioner shape {M_op.shape} != " + f"operator shape {A_op.shape}" + ) + + fast_mv_M = _make_fast_matvec(M) + if fast_mv_M is not None: + _orig_M = M_op + + class _FastMOp(LinearOperator): + def __init__(self): + super().__init__(_orig_M.dtype, _orig_M.shape) + + def _matvec(self, x): + return fast_mv_M.matvec(x) + + def _rmatvec(self, x): + return fast_mv_M.rmatvec(x) + + M_op = _FastMOp() + + # Inject fast CSR SpMV for A if available. + fast_mv = _make_fast_matvec(A) + if fast_mv is not None: + _orig = A_op + + class _FastOp(LinearOperator): + def __init__(self): + super().__init__(_orig.dtype, _orig.shape) + + def _matvec(self, x): + return fast_mv.matvec(x) + + def _rmatvec(self, x): + return fast_mv.rmatvec(x) + + A_op = _FastOp() + + return A_op, M_op, x, b, dtype + + +def _get_atol(b_norm: float, atol, rtol: float) -> float: + """Absolute stopping tolerance: max(atol, rtol*||b||), mirroring SciPy.""" + if atol == "legacy" or atol is None: + atol = 0.0 + atol = float(atol) + if atol < 0: + raise ValueError( + f"atol={atol!r} is invalid; must be a real, non-negative number." 
+ ) + return max(atol, float(rtol) * float(b_norm)) + +def cg( + A, + b, + x0: dpnp.ndarray | None = None, + *, + rtol: float = 1e-5, + tol: float | None = None, + maxiter: int | None = None, + M=None, + callback: Callable | None = None, + atol=None, +) -> tuple[dpnp.ndarray, int]: + """Conjugate Gradient -- pure dpnp/oneMKL, Hermitian positive definite A. + + Parameters + ---------- + A : array_like or LinearOperator -- HPD (n, n) + b : array_like -- right-hand side (n,) + x0 : array_like, optional -- initial guess + rtol : float -- relative tolerance (default 1e-5) + tol : float, optional -- deprecated alias for rtol + maxiter : int, optional -- max iterations (default 10*n) + M : LinearOperator or array_like, optional -- SPD preconditioner + callback: callable, optional -- callback(xk) after each iteration + atol : float, optional -- absolute tolerance + + Returns + ------- + x : dpnp.ndarray + info : int 0=converged >0=maxiter -1=breakdown + """ + if tol is not None: + rtol = tol + + A_op, M_op, x, b, dtype = _make_system(A, M, x0, b) + n = b.shape[0] + + bnrm = dpnp.linalg.norm(b) + bnrm_host = float(bnrm) + if bnrm_host == 0.0: + return dpnp.zeros_like(b), 0 + + atol_eff_host = _get_atol(bnrm_host, atol=atol, rtol=rtol) + + if maxiter is None: + maxiter = n * 10 + + rhotol = float(numpy.finfo(_np_dtype(dtype)).eps ** 2) + + r = b - A_op.matvec(x) if x0 is not None else b.copy() + z = M_op.matvec(r) + p = z.copy() + + # rz is kept as a 0-D dpnp array on device. + rz = dpnp.real(dpnp.vdot(r, z)) + + # Single sync for the initial breakdown check. + if float(dpnp.abs(rz)) < rhotol: + return x, 0 + + info = maxiter + + for _k in range(maxiter): + # Convergence check (sync). 
+ rnorm = dpnp.linalg.norm(r) + if float(rnorm) <= atol_eff_host: + info = 0 + break + + Ap = A_op.matvec(p) + pAp = dpnp.real(dpnp.vdot(p, Ap)) # 0-D on device + + if float(dpnp.abs(pAp)) < rhotol: + info = -1 + break + + alpha = rz / pAp # 0-D on device + x = x + alpha * p # fully on-device + r = r - alpha * Ap + + if callback is not None: + callback(x) + + z = M_op.matvec(r) + rz_new = dpnp.real(dpnp.vdot(r, z)) + + if float(dpnp.abs(rz_new)) < rhotol: + info = 0 + break + + beta = rz_new / rz # 0-D on device + p = z + beta * p + rz = rz_new + else: + info = maxiter + + return x, int(info) + + +def gmres( + A, + b, + x0: dpnp.ndarray | None = None, + *, + rtol: float = 1e-5, + atol: float = 0.0, + restart: int | None = None, + maxiter: int | None = None, + M=None, + callback: Callable | None = None, + callback_type: str | None = None, +) -> tuple[dpnp.ndarray, int]: + """Uses Generalized Minimal RESidual iteration to solve ``Ax = b``. + + Parameters + ---------- + A : LinearOperator, dpnp sparse matrix, or 2-D dpnp.ndarray + The real or complex matrix of the linear system, shape (n, n). + b : dpnp.ndarray + Right-hand side of the linear system, shape (n,) or (n, 1). + x0 : dpnp.ndarray, optional + Starting guess for the solution. + rtol, atol : float + Tolerance for convergence: ``||r|| <= max(atol, rtol*||b||)``. + restart : int, optional + Number of iterations between restarts (default 20). Larger values + increase iteration cost but may be necessary for convergence. + maxiter : int, optional + Maximum number of iterations (default 10*n). + M : LinearOperator, dpnp sparse matrix, or 2-D dpnp.ndarray, optional + Preconditioner for ``A``; should approximate the inverse of ``A``. + callback : callable, optional + User-specified function to call on every restart. Called as + ``callback(arg)``, where ``arg`` is selected by ``callback_type``. + callback_type : {'x', 'pr_norm'}, optional + If 'x', the current solution vector is passed to the callback. 
+ If 'pr_norm', the relative (preconditioned) residual norm. + Default is 'pr_norm' when a callback is supplied. + + Returns + ------- + x : dpnp.ndarray + The (approximate) solution. Note that this is M @ x in the + right-preconditioned formulation, matching CuPy's return value. + info : int + 0 if converged; iteration count if maxiter was reached. + + See Also + -------- + scipy.sparse.linalg.gmres + cupyx.scipy.sparse.linalg.gmres + """ + A_op, M_op, x, b, dtype = _make_system(A, M, x0, b) + matvec = A_op.matvec + psolve = M_op.matvec + + n = A_op.shape[0] + if n == 0: + return dpnp.empty_like(b), 0 + b_norm = dpnp.linalg.norm(b) + if b_norm == 0.0: + return b, 0 + atol = max(float(atol), rtol * float(b_norm)) + + if maxiter is None: + maxiter = n * 10 + if restart is None: + restart = 20 + restart = min(int(restart), n) + + if callback_type is None: + callback_type = "pr_norm" + if callback_type not in ("x", "pr_norm"): + raise ValueError(f"Unknown callback_type: {callback_type!r}") + if callback is None: + callback_type = None + + queue = b.sycl_queue + + # Krylov basis V, Hessenberg H, and RHS e all live on device to + # avoid host-device sync overhead (which dominates on Intel GPUs + # even for small transfers). CuPy keeps e on host and solves + # lstsq on CPU, but for dpnp we keep everything on device. 
+ V = dpnp.empty((n, restart), dtype=dtype, sycl_queue=queue, order="F") + H = dpnp.zeros( + (restart + 1, restart), dtype=dtype, sycl_queue=queue, order="F" + ) + e = dpnp.zeros(restart + 1, dtype=dtype, sycl_queue=queue) + + compute_hu = _make_compute_hu(V) + + iters = 0 + while True: + mx = psolve(x) + r = b - matvec(mx) + r_norm = dpnp.linalg.norm(r) + + if callback_type == "x": + callback(mx) + elif callback_type == "pr_norm" and iters > 0: + callback(r_norm / b_norm) + + if r_norm <= atol or iters >= maxiter: + break + + v = r / r_norm + V[:, 0] = v + e[0] = r_norm + + # Arnoldi iteration + for j in range(restart): + z = psolve(v) + u = matvec(z) + H[: j + 1, j], u = compute_hu(u, j) + H[j + 1, j] = dpnp.linalg.norm(u) + if j + 1 < restart: + v = u / H[j + 1, j] + V[:, j + 1] = v + + # Solve the Hessenberg least-squares H y = e on device. + # Tiny problem (~restart x restart), kept on-device to avoid sync. + y, *_ = dpnp.linalg.lstsq(H, e, rcond=None) + x = x + dpnp.dot(V, y) + iters += restart + + info = 0 + if iters >= maxiter and not bool(r_norm <= atol): + info = iters + + return mx, info + + +def minres( + A, + b, + x0: dpnp.ndarray | None = None, + *, + rtol: float = 1e-5, + shift: float = 0.0, + tol: float | None = None, + maxiter: int | None = None, + M=None, + callback: Callable | None = None, + show: bool = False, + check: bool = False, +) -> tuple[dpnp.ndarray, int]: + """Uses MINimum RESidual iteration to solve ``Ax = b``. + + Solves the symmetric (possibly indefinite) system ``Ax = b`` or, + if *shift* is nonzero, ``(A - shift*I)x = b``. All computation + stays on the SYCL device; only scalar recurrence coefficients and + norms are transferred to the host for branching. + + The algorithm follows SciPy's MINRES (Paige & Saunders, 1975) + line-for-line. Three host syncs per iteration are unavoidable: + ``alpha`` and ``beta`` (Lanczos inner products) and ``ynorm`` + (solution norm for stopping tests). 
+ + Parameters + ---------- + A : dpnp sparse matrix, 2-D dpnp.ndarray, or LinearOperator + The real symmetric or complex Hermitian matrix, shape ``(n, n)``. + b : dpnp.ndarray + Right-hand side, shape ``(n,)`` or ``(n, 1)``. + x0 : dpnp.ndarray, optional + Starting guess for the solution. + shift : float + If nonzero, solve ``(A - shift*I)x = b``. Default 0. + rtol : float + Relative tolerance for convergence. Default 1e-5. + tol : float, optional + Deprecated alias for *rtol*. + maxiter : int, optional + Maximum number of iterations. Default ``5*n``. + M : dpnp sparse matrix, dpnp.ndarray, or LinearOperator, optional + Preconditioner approximating the inverse of ``A``. + callback : callable, optional + Called as ``callback(xk)`` after each iteration. + show : bool + If True, print convergence summary each iteration. + check : bool + If True, verify that ``A`` and ``M`` are symmetric before + iterating. Costs extra matvecs. + + Returns + ------- + x : dpnp.ndarray + The converged (or best) solution. + info : int + 0 if converged, ``maxiter`` if the iteration limit was reached. + + Notes + ----- + This is a direct translation of the Paige--Saunders MINRES algorithm + as implemented in SciPy, adapted for dpnp device arrays with the + oneMKL SpMV cached-handle fast-path. + + See Also + -------- + scipy.sparse.linalg.minres + cupyx.scipy.sparse.linalg.minres + """ + if tol is not None: + rtol = tol + + A_op, M_op, x, b, dtype = _make_system(A, M, x0, b) + matvec = A_op.matvec + psolve = M_op.matvec + + n = A_op.shape[0] + if maxiter is None: + maxiter = 5 * n + + istop = 0 + itn = 0 + Anorm = 0 + Acond = 0 + rnorm = 0 + ynorm = 0 + + xtype = dtype + eps = dpnp.finfo(xtype).eps + + # ------------------------------------------------------------------ + # Set up y and v for the first Lanczos vector v1. + # y = beta1 * P' * v1, where P = M**(-1). + # v is really P' * v1. 
+ # ------------------------------------------------------------------ + + Ax = matvec(x) + r1 = b - Ax + y = psolve(r1) + + # beta1 = -- one host sync (setup only). + # Transferred to host immediately because beta1 seeds ~5 host-side + # scalars (beta, qrnorm, phibar, rhs1) used in Python arithmetic + # and branches every iteration. Keeping it as a 0-D device array + # would cascade implicit syncs or 0-D allocations throughout the + # recurrence. + beta1 = dpnp.inner(r1, y) + + if beta1 < 0: + raise ValueError("indefinite preconditioner") + elif beta1 == 0: + return (x, 0) + + beta1 = dpnp.sqrt(beta1) + beta1 = float(beta1) + + if check: + # See if A is symmetric. All on device; only the bool syncs. + w_chk = matvec(y) + r2_chk = matvec(w_chk) + s = dpnp.inner(w_chk, w_chk) + t = dpnp.inner(y, r2_chk) + if abs(s - t) > (s + eps) * eps ** (1.0 / 3.0): + raise ValueError("non-symmetric matrix") + + # See if M is symmetric. + r2_chk = psolve(y) + s = dpnp.inner(y, y) + t = dpnp.inner(r1, r2_chk) + if abs(s - t) > (s + eps) * eps ** (1.0 / 3.0): + raise ValueError("non-symmetric preconditioner") + + # Initialise remaining quantities (all host-side scalars). + oldb = 0 + beta = beta1 + dbar = 0 + epsln = 0 + qrnorm = beta1 + phibar = beta1 + rhs1 = beta1 + rhs2 = 0 + tnorm2 = 0 + gmax = 0 + gmin = dpnp.finfo(xtype).max + cs = -1 + sn = 0 + queue = b.sycl_queue + w = dpnp.zeros(n, dtype=xtype, sycl_queue=queue) + w2 = dpnp.zeros(n, dtype=xtype, sycl_queue=queue) + r2 = r1 + + # Main Lanczos loop. 
+ while itn < maxiter: + itn += 1 + + s = 1.0 / beta + v = s * y # on device + + y = matvec(v) + y = y - shift * v + + if itn >= 2: + y = y - (beta / oldb) * r1 + + # alpha = -- host sync #1 + alpha = float(dpnp.inner(v, y)) + + y = y - (alpha / beta) * r2 + r1 = r2 + r2 = y + y = psolve(r2) + oldb = beta + + # beta = sqrt() -- host sync #2 + beta = float(dpnp.inner(r2, y)) + if beta < 0: + raise ValueError("non-symmetric matrix") + beta = numpy.sqrt(beta) + + tnorm2 += alpha**2 + oldb**2 + beta**2 + + if itn == 1: + if beta / beta1 <= 10 * eps: + istop = -1 # Terminate later + + # Apply previous rotation Q_{k-1} to get + # [delta_k epsln_{k+1}] = [cs sn] [dbar_k 0 ] + # [gbar_k dbar_{k+1} ] [sn -cs] [alpha_k beta_{k+1}] + oldeps = epsln + delta = cs * dbar + sn * alpha + gbar = sn * dbar - cs * alpha + epsln = sn * beta + dbar = -cs * beta + root = numpy.sqrt(gbar**2 + dbar**2) + + # Compute the next plane rotation Q_k. + gamma = numpy.sqrt(gbar**2 + beta**2) + gamma = max(gamma, eps) + cs = gbar / gamma + sn = beta / gamma + phi = cs * phibar + phibar = sn * phibar + + # Update x -- all on device. + denom = 1.0 / gamma + w1 = w2 + w2 = w + w = (v - oldeps * w1 - delta * w2) * denom + x = x + phi * w + + # Go round again. + gmax = max(gmax, gamma) + gmin = min(gmin, gamma) + z = rhs1 / gamma + rhs1 = rhs2 - delta * z + rhs2 = -epsln * z + + # ---------------------------------------------------------- + # Estimate norms and test for convergence. 
+ # ---------------------------------------------------------- + Anorm = numpy.sqrt(tnorm2) + ynorm = float(dpnp.linalg.norm(x)) # host sync #3 + epsa = Anorm * eps + epsx = Anorm * ynorm * eps + epsr = Anorm * ynorm * rtol + diag = gbar + if diag == 0: + diag = epsa + + qrnorm = phibar + rnorm = qrnorm + if ynorm == 0 or Anorm == 0: + test1 = numpy.inf + else: + test1 = rnorm / (Anorm * ynorm) # ||r|| / (||A|| ||x||) + if Anorm == 0: + test2 = numpy.inf + else: + test2 = root / Anorm # ||Ar|| / (||A|| ||r||) + + # Estimate cond(A). + Acond = gmax / gmin + + # Stopping criteria (SciPy's istop codes). + if istop == 0: + t1 = 1 + test1 + t2 = 1 + test2 + if t2 <= 1: + istop = 2 + if t1 <= 1: + istop = 1 + + if itn >= maxiter: + istop = 6 + if Acond >= 0.1 / eps: + istop = 4 + if epsx >= beta1: + istop = 3 + if test2 <= rtol: + istop = 2 + if test1 <= rtol: + istop = 1 + + if show: + prnt = ( + n <= 40 + or itn <= 10 + or itn >= maxiter - 10 + or itn % 10 == 0 + or qrnorm <= 10 * epsx + or qrnorm <= 10 * epsr + or Acond <= 1e-2 / eps + or istop != 0 + ) + if prnt: + x1 = float(x[0]) + print( + f"{itn:6g} {x1:12.5e} {test1:10.3e}" + f" {test2:10.3e}" + f" {Anorm:8.1e} {Acond:8.1e}" + f" {gbar / Anorm if Anorm else 0:8.1e}" + ) + if itn % 10 == 0: + print() + + if callback is not None: + callback(x) + + if istop != 0: + break + + if istop == 6: + info = maxiter + else: + info = 0 + + return (x, info) + + +def _make_compute_hu(V): + """Factory mirroring cupyx's _make_compute_hu using oneMKL gemv directly. + + Returns a closure compute_hu(u, j) that performs: + h = V[:, :j+1]^H @ u (gemv with transpose=True) + u = u - V[:, :j+1] @ h (gemv with transpose=False, then subtract) + + The current bi._gemv binding hardcodes alpha=1, beta=0, so the second + pass requires a temporary vector and an explicit subtraction. To get + CuPy's fused u -= V@h in one kernel, the C++ binding would need + alpha/beta parameters. 
+ + V must be column-major; sub-views V[:, :j+1] of an F-order array + are themselves F-contiguous, so the same closure handles every j. + """ + if V.ndim != 2 or not V.flags.f_contiguous: + raise ValueError( + "_make_compute_hu: V must be a 2-D column-major (F-order) " + "dpnp array" + ) + + exec_q = V.sycl_queue + dtype = V.dtype + is_cpx = dpnp.issubdtype(dtype, dpnp.complexfloating) + + def compute_hu(u, j): + # h = V[:, :j+1]^H @ u (allocate fresh, length j+1) + h = dpnp.empty(j + 1, dtype=dtype, sycl_queue=exec_q) + + # Sub-view: column-major slice of the trailing axis is F-contiguous. + Vj = V[:, : j + 1] + Vj_usm = dpnp.get_usm_ndarray(Vj) + u_usm = dpnp.get_usm_ndarray(u) + h_usm = dpnp.get_usm_ndarray(h) + + _manager = dpu.SequentialOrderManager[exec_q] + + # Pass 1: h = Vj^T @ u (real) or h = (Vj^T @ u) then conj (complex) + ht1, ev1 = bi._gemv( + exec_q, + Vj_usm, + u_usm, + h_usm, + transpose=True, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht1, ev1) + + if is_cpx: + # h = conj(h) -- in-place, length j+1, negligible + h = dpnp.conj(h, out=h) + h_usm = dpnp.get_usm_ndarray(h) + + # Pass 2: tmp = Vj @ h, then u -= tmp + # No fused AXPY available, so we still allocate tmp. 
+ tmp = dpnp.empty_like(u) + tmp_usm = dpnp.get_usm_ndarray(tmp) + ht2, ev2 = bi._gemv( + exec_q, + Vj_usm, + h_usm, + tmp_usm, + transpose=False, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht2, ev2) + + u -= tmp + return h, u + + return compute_hu diff --git a/dpnp/tests/test_scipy_sparse_linalg.py b/dpnp/tests/test_scipy_sparse_linalg.py new file mode 100644 index 000000000000..b11aa1fcc796 --- /dev/null +++ b/dpnp/tests/test_scipy_sparse_linalg.py @@ -0,0 +1,896 @@ +import warnings + +import numpy +import pytest +from numpy.testing import ( + assert_allclose, + assert_raises, +) + +import dpnp +from dpnp.scipy.sparse.linalg import ( + LinearOperator, + aslinearoperator, + cg, + gmres, + minres, +) +from dpnp.tests.helper import ( + assert_dtype_allclose, + generate_random_numpy_array, + get_all_dtypes, + get_float_complex_dtypes, + has_support_aspect64, + is_scipy_available, +) +from dpnp.tests.third_party.cupy import testing + +if is_scipy_available(): + import scipy.sparse.linalg as scipy_sla + + +# Helpers for constructing SPD, diagonally dominant, and symmetric +# indefinite test matrices. Kept small and local, matching the style of +# vvsort() at the top of test_linalg.py. 
+def _spd_matrix(n, dtype): + rng = numpy.random.default_rng(42) + is_complex = numpy.issubdtype(numpy.dtype(dtype), numpy.complexfloating) + if is_complex: + a = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n)) + a = a.conj().T @ a + n * numpy.eye(n) + else: + a = rng.standard_normal((n, n)) + a = a.T @ a + n * numpy.eye(n) + return dpnp.asarray(a.astype(dtype)) + + +def _diag_dominant(n, dtype, seed=81): + rng = numpy.random.default_rng(seed) + is_complex = numpy.issubdtype(numpy.dtype(dtype), numpy.complexfloating) + if is_complex: + a = 0.05 * ( + rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n)) + ) + else: + a = 0.05 * rng.standard_normal((n, n)) + a = a + float(n) * numpy.eye(n) + return dpnp.asarray(a.astype(dtype)) + + +def _sym_indefinite(n, dtype, seed=99): + rng = numpy.random.default_rng(seed) + a = rng.standard_normal((n, n)) + q, _ = numpy.linalg.qr(a) + d = rng.standard_normal(n) + m = (q @ numpy.diag(d) @ q.T).astype(dtype) + return dpnp.asarray(m) + + +def _rhs(n, dtype, seed=7): + rng = numpy.random.default_rng(seed) + is_complex = numpy.issubdtype(numpy.dtype(dtype), numpy.complexfloating) + if is_complex: + b = rng.standard_normal(n) + 1j * rng.standard_normal(n) + else: + b = rng.standard_normal(n) + b /= numpy.linalg.norm(b) + return dpnp.asarray(b.astype(dtype)) + + +def _rtol_for(dtype): + if dtype in (dpnp.float32, dpnp.complex64, numpy.float32, numpy.complex64): + return 1e-5 + return 1e-8 + + +def _res_bound(dtype): + if dtype in (dpnp.float32, dpnp.complex64, numpy.float32, numpy.complex64): + return 1e-3 + return 1e-5 + + +# GMRES in dpnp.scipy.sparse.linalg._iterative uses real-valued Givens +# rotation formulas which are incorrect for complex Arnoldi, so GMRES +# returns wrong solutions for complex dtypes. Complex GMRES tests are +# xfailed below. When the Givens block is fixed the xfails will flip to +# XPASS and force an update here. 
+_GMRES_CPX_XFAIL = ( + "GMRES Givens rotation is real-valued; broken for complex dtypes" +) + +_GMRES_DTYPES = [ + dpnp.float32, + dpnp.float64, + pytest.param( + dpnp.complex64, + marks=pytest.mark.xfail(reason=_GMRES_CPX_XFAIL, strict=False), + ), + pytest.param( + dpnp.complex128, + marks=pytest.mark.xfail(reason=_GMRES_CPX_XFAIL, strict=False), + ), +] + + +class TestImports: + def test_all_symbols_importable(self): + from dpnp.scipy.sparse.linalg import ( # noqa: F401 + LinearOperator, + aslinearoperator, + cg, + gmres, + minres, + ) + + for sym in (LinearOperator, aslinearoperator, cg, gmres, minres): + assert callable(sym) + + def test_all_in_dunder_all(self): + import dpnp.scipy.sparse.linalg as mod + + for name in ( + "LinearOperator", + "aslinearoperator", + "cg", + "gmres", + "minres", + ): + assert name in mod.__all__ + + +class TestLinearOperator: + @pytest.mark.parametrize( + "shape", + [(5, 5), (7, 3), (3, 7)], + ids=["(5, 5)", "(7, 3)", "(3, 7)"], + ) + def test_shape(self, shape): + m, n = shape + lo = LinearOperator( + shape, + matvec=lambda x: dpnp.zeros(m, dtype=dpnp.float32), + dtype=dpnp.float32, + ) + assert lo.shape == (m, n) + assert lo.ndim == 2 + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_dtype_explicit(self, dtype): + n = 4 + a = dpnp.eye(n, dtype=dtype) + lo = LinearOperator( + (n, n), + matvec=lambda x: (a @ x.astype(dtype)).astype(dtype), + dtype=dtype, + ) + assert lo.dtype == dtype + + def test_dtype_inference_float64_default(self): + # Dtype inference probes matvec with a float64 vector, so the + # inferred dtype is float64 even when the underlying array is + # float32. Pin the current behaviour as a regression guard. 
+ if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + n = 4 + a = dpnp.eye(n, dtype=dpnp.float32) + lo = LinearOperator((n, n), matvec=lambda x: a @ x) + assert lo.dtype == dpnp.float64 + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_matvec(self, dtype): + n = 6 + a = generate_random_numpy_array((n, n), dtype, seed_value=42) + ia = dpnp.array(a) + lo = LinearOperator((n, n), matvec=lambda x: ia @ x, dtype=dtype) + x = generate_random_numpy_array((n,), dtype, seed_value=1) + ix = dpnp.array(x) + result = lo.matvec(ix) + expected = a @ x + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_rmatvec(self, dtype): + n = 5 + a = generate_random_numpy_array((n, n), dtype, seed_value=12) + ia = dpnp.array(a) + lo = LinearOperator( + (n, n), + matvec=lambda x: ia @ x, + rmatvec=lambda x: dpnp.conj(ia.T) @ x, + dtype=dtype, + ) + x = generate_random_numpy_array((n,), dtype, seed_value=3) + ix = dpnp.array(x) + result = lo.rmatvec(ix) + expected = a.conj().T @ x + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_matmat_fallback_loop(self, dtype): + n, k = 5, 3 + a = generate_random_numpy_array((n, n), dtype, seed_value=55) + ia = dpnp.array(a) + lo = LinearOperator((n, n), matvec=lambda x: ia @ x, dtype=dtype) + x = generate_random_numpy_array((n, k), dtype, seed_value=9) + ix = dpnp.array(x) + result = lo.matmat(ix) + expected = a @ x + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_matmul_1d(self, dtype): + # lo @ x dispatches to matvec + n = 6 + a = generate_random_numpy_array((n, n), dtype, seed_value=42) + ia = dpnp.array(a) + lo = LinearOperator((n, n), matvec=lambda x: ia @ x, dtype=dtype) + x = generate_random_numpy_array((n,), dtype, seed_value=2) + ix = dpnp.array(x) + result = lo @ ix + expected = a @ 
x + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_matmul_2d(self, dtype): + # lo @ X dispatches to matmat + n, k = 5, 3 + a = generate_random_numpy_array((n, n), dtype, seed_value=42) + ia = dpnp.array(a) + lo = LinearOperator((n, n), matvec=lambda x: ia @ x, dtype=dtype) + x = generate_random_numpy_array((n, k), dtype, seed_value=5) + ix = dpnp.array(x) + result = lo @ ix + expected = a @ x + assert_dtype_allclose(result, expected) + + def test_call_alias(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + n = 4 + ia = dpnp.eye(n, dtype=dpnp.float64) + lo = LinearOperator((n, n), matvec=lambda x: ia @ x, dtype=dpnp.float64) + ix = dpnp.ones(n, dtype=dpnp.float64) + assert_allclose(dpnp.asnumpy(lo(ix)), numpy.ones(n), atol=1e-12) + + def test_repr(self): + lo = LinearOperator( + (3, 4), + matvec=lambda x: dpnp.zeros(3, dtype=dpnp.float32), + dtype=dpnp.float32, + ) + r = repr(lo) + assert "LinearOperator" in r + assert "3x4" in r or "(3, 4)" in r + + @pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64]) + def test_subclass_custom_matmat(self, dtype): + if not has_support_aspect64() and dtype == dpnp.float64: + pytest.skip("float64 not supported on this device") + n, k = 7, 4 + a = generate_random_numpy_array((n, n), dtype, seed_value=42) + ia = dpnp.array(a) + + class MyOp(LinearOperator): + def __init__(self): + super().__init__(dtype=dtype, shape=(n, n)) + self._a = ia + + def _matvec(self, x): + return self._a @ x + + def _matmat(self, X): + return self._a @ X + + op = MyOp() + x = generate_random_numpy_array((n, k), dtype, seed_value=9) + ix = dpnp.array(x) + result = op.matmat(ix) + expected = a @ x + assert_dtype_allclose(result, expected) + + def test_linear_operator_errors(self): + lo = LinearOperator( + (3, 5), + matvec=lambda x: dpnp.zeros(3, dtype=dpnp.float32), + dtype=dpnp.float32, + ) + # matvec with wrong shape + 
assert_raises(ValueError, lo.matvec, dpnp.ones(4, dtype=dpnp.float32)) + + # rmatvec not provided + lo2 = LinearOperator( + (3, 3), + matvec=lambda x: dpnp.zeros(3, dtype=dpnp.float32), + dtype=dpnp.float32, + ) + assert_raises( + (NotImplementedError, ValueError), + lo2.rmatvec, + dpnp.zeros(3, dtype=dpnp.float32), + ) + + # matmat with 1-D input + assert_raises(ValueError, lo2.matmat, dpnp.ones(3, dtype=dpnp.float32)) + + # negative shape + assert_raises( + (ValueError, Exception), + LinearOperator, + (-1, 3), + matvec=lambda x: x, + dtype=dpnp.float32, + ) + + # shape with wrong ndim + assert_raises( + (ValueError, Exception), + LinearOperator, + (3,), + matvec=lambda x: x, + dtype=dpnp.float32, + ) + + +class TestAsLinearOperator: + def test_identity_if_already_linearoperator(self): + lo = LinearOperator((3, 3), matvec=lambda x: x, dtype=dpnp.float32) + assert aslinearoperator(lo) is lo + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_dense_dpnp_array_matvec(self, dtype): + n = 6 + a = generate_random_numpy_array((n, n), dtype, seed_value=42) + ia = dpnp.array(a) + lo = aslinearoperator(ia) + assert lo.shape == (n, n) + x = generate_random_numpy_array((n,), dtype, seed_value=1) + ix = dpnp.array(x) + result = lo.matvec(ix) + expected = a @ x + assert_dtype_allclose(result, expected) + + def test_dense_numpy_array_attributes_only(self): + # aslinearoperator(numpy_array) wraps with lambda x: A @ x where A + # remains a numpy array; calling matvec(dpnp_x) then fails because + # dpnp __rmatmul__ refuses numpy LHS. Only attributes are checked. 
+ n = 5 + a = generate_random_numpy_array((n, n), numpy.float64, seed_value=42) + lo = aslinearoperator(a) + assert lo.shape == (n, n) + + def test_rmatvec_from_dpnp_dense(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + n = 5 + a = generate_random_numpy_array((n, n), numpy.float64, seed_value=42) + ia = dpnp.array(a) + lo = aslinearoperator(ia) + x = generate_random_numpy_array((n,), numpy.float64, seed_value=2) + ix = dpnp.array(x) + result = lo.rmatvec(ix) + expected = a.conj().T @ x + assert_allclose(dpnp.asnumpy(result), expected, atol=1e-12) + + def test_duck_type_with_shape_and_matvec(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + n = 4 + + class DuckOp: + shape = (n, n) + dtype = numpy.dtype(numpy.float64) + + def matvec(self, x): + return x * 2.0 + + def rmatvec(self, x): + return x * 2.0 + + lo = aslinearoperator(DuckOp()) + ix = dpnp.ones(n, dtype=dpnp.float64) + result = lo.matvec(ix) + assert_allclose(dpnp.asnumpy(result), numpy.full(n, 2.0), atol=1e-12) + + def test_aslinearoperator_errors(self): + assert_raises((TypeError, Exception), aslinearoperator, "nope") + + +class TestCg: + n = 30 + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_cg_converges_spd(self, dtype): + ia = _spd_matrix(self.n, dtype) + ib = _rhs(self.n, dtype) + x, info = cg(ia, ib, rtol=_rtol_for(dtype), maxiter=500) + assert info == 0 + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < _res_bound(dtype) + + @pytest.mark.skipif(not is_scipy_available(), reason="SciPy not available") + @pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64]) + def test_cg_matches_scipy(self, dtype): + if not has_support_aspect64() and dtype == dpnp.float64: + pytest.skip("float64 not supported on this device") + a = dpnp.asnumpy(_spd_matrix(self.n, dtype)) + b = dpnp.asnumpy(_rhs(self.n, dtype)) + try: + x_ref, info_ref = scipy_sla.cg(a, b, 
rtol=1e-8, maxiter=500) + except TypeError: # scipy < 1.12 + x_ref, info_ref = scipy_sla.cg(a, b, tol=1e-8, maxiter=500) + assert info_ref == 0 + x_dp, info = cg(dpnp.array(a), dpnp.array(b), rtol=1e-8, maxiter=500) + assert info == 0 + tol = 1e-4 if dtype == dpnp.float32 else 1e-8 + assert_allclose(dpnp.asnumpy(x_dp), x_ref, rtol=tol, atol=tol) + + @pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64]) + def test_cg_x0_warm_start(self, dtype): + if not has_support_aspect64() and dtype == dpnp.float64: + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(self.n, dtype) + ib = _rhs(self.n, dtype) + x0 = dpnp.ones(self.n, dtype=dtype) + x, info = cg(ia, ib, x0=x0, rtol=_rtol_for(dtype), maxiter=500) + assert info == 0 + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < _res_bound(dtype) + + @pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64]) + def test_cg_b_2dim(self, dtype): + # b with shape (n, 1) must be accepted and flattened internally + if not has_support_aspect64() and dtype == dpnp.float64: + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(self.n, dtype) + ib = _rhs(self.n, dtype).reshape(self.n, 1) + _, info = cg(ia, ib, rtol=1e-8, maxiter=500) + assert info == 0 + + def test_cg_b_zero(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(10, dpnp.float64) + ib = dpnp.zeros(10, dtype=dpnp.float64) + x, info = cg(ia, ib, rtol=1e-8) + assert info == 0 + assert_allclose(dpnp.asnumpy(x), numpy.zeros(10), atol=1e-14) + + def test_cg_callback(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + calls = [] + cg( + ia, + ib, + callback=lambda xk: calls.append(float(dpnp.linalg.norm(xk))), + rtol=1e-10, + maxiter=200, + ) + assert len(calls) > 0 + + def test_cg_atol(self): + if not 
has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + x, _ = cg(ia, ib, rtol=0.0, atol=1e-1, maxiter=500) + assert float(dpnp.linalg.norm(ia @ x - ib)) < 1.0 + + def test_cg_exact_solution(self): + # x0 == true solution must return info == 0 immediately + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + n = 10 + ia = _spd_matrix(n, dpnp.float64) + ib = _rhs(n, dpnp.float64) + x_true = dpnp.array( + numpy.linalg.solve(dpnp.asnumpy(ia), dpnp.asnumpy(ib)) + ) + _, info = cg(ia, ib, x0=x_true, rtol=1e-12) + assert info == 0 + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_cg_via_linear_operator(self, dtype): + ia = _spd_matrix(self.n, dtype) + ib = _rhs(self.n, dtype) + lo = aslinearoperator(ia) + x, info = cg(lo, ib, rtol=_rtol_for(dtype), maxiter=500) + assert info == 0 + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < _res_bound(dtype) + + def test_cg_maxiter_nonconvergence(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(50, dpnp.float64) + ib = _rhs(50, dpnp.float64) + _, info = cg(ia, ib, rtol=1e-15, atol=0.0, maxiter=1) + assert info != 0 + + def test_cg_diag_preconditioner(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + M = aslinearoperator(dpnp.diag(1.0 / dpnp.diag(ia))) + _, info = cg(ia, ib, M=M, rtol=1e-8, maxiter=500) + assert info == 0 + + def test_cg_errors(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(5, dpnp.float64) + ib = dpnp.ones(6, dtype=dpnp.float64) + # b length mismatch + with pytest.raises((ValueError, Exception)): + cg(ia, ib, maxiter=1) + + +class TestGmres: + n = 30 + + 
@pytest.mark.parametrize("dtype", _GMRES_DTYPES) + def test_gmres_converges_diag_dominant(self, dtype): + if not has_support_aspect64() and dtype in ( + dpnp.float64, + dpnp.complex128, + ): + pytest.skip("float64 not supported on this device") + ia = _diag_dominant(self.n, dtype) + ib = _rhs(self.n, dtype) + x, _ = gmres( + ia, + ib, + rtol=_rtol_for(dtype), + maxiter=200, + restart=self.n, + ) + # Check actual residual rather than info: see comment above + # _GMRES_CPX_XFAIL. + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < _res_bound(dtype) + + @pytest.mark.skipif(not is_scipy_available(), reason="SciPy not available") + @pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64]) + def test_gmres_matches_scipy(self, dtype): + if not has_support_aspect64() and dtype == dpnp.float64: + pytest.skip("float64 not supported on this device") + a = dpnp.asnumpy(_diag_dominant(self.n, dtype)) + b = dpnp.asnumpy(_rhs(self.n, dtype)) + req_rtol = _rtol_for(dtype) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + x_ref, _ = scipy_sla.gmres( + a, b, rtol=req_rtol, restart=self.n, maxiter=None + ) + except TypeError: # scipy < 1.12 + x_ref, _ = scipy_sla.gmres( + a, b, tol=req_rtol, restart=self.n, maxiter=None + ) + x_dp, info = gmres( + dpnp.array(a), + dpnp.array(b), + rtol=req_rtol, + restart=self.n, + maxiter=50, + ) + assert info == 0 + tol = 1e-3 if dtype == dpnp.float32 else 1e-7 + assert_allclose(dpnp.asnumpy(x_dp), x_ref, rtol=tol, atol=tol) + + @pytest.mark.parametrize("restart", [None, 5, 15], ids=["None", "5", "15"]) + def test_gmres_restart_values(self, restart): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _diag_dominant(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + _, info = gmres(ia, ib, rtol=1e-8, restart=restart, maxiter=100) + assert info == 0 + + @pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64]) + def 
test_gmres_x0_warm_start(self, dtype): + if not has_support_aspect64() and dtype == dpnp.float64: + pytest.skip("float64 not supported on this device") + ia = _diag_dominant(self.n, dtype) + ib = _rhs(self.n, dtype) + x0 = dpnp.ones(self.n, dtype=dtype) + x, _ = gmres( + ia, + ib, + x0=x0, + rtol=_rtol_for(dtype), + restart=self.n, + maxiter=200, + ) + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < _res_bound(dtype) + + def test_gmres_b_2dim(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _diag_dominant(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64).reshape(self.n, 1) + _, info = gmres(ia, ib, rtol=1e-8, restart=self.n, maxiter=100) + assert info == 0 + + def test_gmres_b_zero(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _diag_dominant(10, dpnp.float64) + ib = dpnp.zeros(10, dtype=dpnp.float64) + x, info = gmres(ia, ib, rtol=1e-8) + assert info == 0 + assert_allclose(dpnp.asnumpy(x), numpy.zeros(10), atol=1e-14) + + def test_gmres_callback_x(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _diag_dominant(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + calls = [] + gmres( + ia, + ib, + callback=lambda xk: calls.append(1), + callback_type="x", + rtol=1e-10, + maxiter=20, + restart=self.n, + ) + assert len(calls) > 0 + + def test_gmres_callback_pr_norm(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _diag_dominant(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + values = [] + gmres( + ia, + ib, + callback=lambda r: values.append(float(r)), + callback_type="pr_norm", + rtol=1e-10, + maxiter=20, + restart=self.n, + ) + assert len(values) > 0 + assert all(v >= 0 for v in values) + + def test_gmres_atol(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia 
= _diag_dominant(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + x, _ = gmres( + ia, + ib, + rtol=0.0, + atol=1e-6, + restart=self.n, + maxiter=50, + ) + assert float(dpnp.linalg.norm(ia @ x - ib)) < 1e-4 + + @pytest.mark.parametrize("dtype", _GMRES_DTYPES) + def test_gmres_via_linear_operator(self, dtype): + if not has_support_aspect64() and dtype in ( + dpnp.float64, + dpnp.complex128, + ): + pytest.skip("float64 not supported on this device") + ia = _diag_dominant(self.n, dtype) + ib = _rhs(self.n, dtype) + lo = aslinearoperator(ia) + x, _ = gmres( + lo, + ib, + rtol=_rtol_for(dtype), + restart=self.n, + maxiter=200, + ) + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < _res_bound(dtype) + + def test_gmres_nonconvergence(self): + # Ill-conditioned Hilbert matrix + tiny restart must not converge + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + n = 48 + idx = numpy.arange(n, dtype=numpy.float64) + a = 1.0 / (idx[:, None] + idx[None, :] + 1.0) + rng = numpy.random.default_rng(5) + b = rng.standard_normal(n) + ia = dpnp.array(a) + ib = dpnp.array(b) + x, info = gmres(ia, ib, rtol=1e-15, atol=0.0, restart=2, maxiter=2) + rel = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert rel > 1e-12 + assert info != 0 + + @pytest.mark.xfail(reason=_GMRES_CPX_XFAIL, strict=False) + def test_gmres_complex_system(self): + if not has_support_aspect64(): + pytest.skip("complex128 not supported on this device") + n = 15 + ia = _diag_dominant(n, dpnp.complex128) + ib = _rhs(n, dpnp.complex128) + x, _ = gmres(ia, ib, rtol=1e-8, restart=n, maxiter=200) + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < 1e-5 + + def test_gmres_errors(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _diag_dominant(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + # unknown callback_type + assert_raises(ValueError, 
gmres, ia, ib, callback_type="garbage") + + +class TestMinres: + n = 30 + + @pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64]) + def test_minres_converges_spd(self, dtype): + if not has_support_aspect64() and dtype == dpnp.float64: + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(self.n, dtype) + ib = _rhs(self.n, dtype) + x, info = minres(ia, ib, rtol=1e-8, maxiter=500) + assert info == 0 + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < 1e-4 + + def test_minres_converges_sym_indefinite(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _sym_indefinite(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + x, _ = minres(ia, ib, rtol=1e-8, maxiter=1000) + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < 1e-3 + + @pytest.mark.skipif(not is_scipy_available(), reason="SciPy not available") + def test_minres_matches_scipy(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + a = dpnp.asnumpy(_spd_matrix(self.n, dpnp.float64)) + b = dpnp.asnumpy(_rhs(self.n, dpnp.float64)) + try: + x_ref, _ = scipy_sla.minres(a, b, rtol=1e-8) + except TypeError: + x_ref, _ = scipy_sla.minres(a, b, tol=1e-8) + x_dp, info = minres(dpnp.array(a), dpnp.array(b), rtol=1e-8) + assert info == 0 + assert_allclose(dpnp.asnumpy(x_dp), x_ref, rtol=1e-5, atol=1e-6) + + def test_minres_x0_warm_start(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + x0 = dpnp.zeros(self.n, dtype=dpnp.float64) + _, info = minres(ia, ib, x0=x0, rtol=1e-8) + assert info == 0 + + def test_minres_shift(self): + # shift != 0 solves (A - shift*I) x = b + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + a = dpnp.asnumpy(_spd_matrix(self.n, dpnp.float64)) + b = 
dpnp.asnumpy(_rhs(self.n, dpnp.float64)) + shift = 0.5 + x_dp, info = minres( + dpnp.array(a), dpnp.array(b), shift=shift, rtol=1e-8 + ) + assert info == 0 + a_shifted = a - shift * numpy.eye(self.n) + res = numpy.linalg.norm( + a_shifted @ dpnp.asnumpy(x_dp) - b + ) / numpy.linalg.norm(b) + assert res < 1e-4 + + def test_minres_b_zero(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(10, dpnp.float64) + ib = dpnp.zeros(10, dtype=dpnp.float64) + x, info = minres(ia, ib, rtol=1e-8) + assert info == 0 + assert_allclose(dpnp.asnumpy(x), numpy.zeros(10), atol=1e-14) + + def test_minres_via_linear_operator(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + lo = aslinearoperator(ia) + _, info = minres(lo, ib, rtol=1e-8) + assert info == 0 + + def test_minres_callback(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(self.n, dpnp.float64) + ib = _rhs(self.n, dpnp.float64) + calls = [] + minres( + ia, + ib, + callback=lambda xk: calls.append(1), + rtol=1e-10, + ) + assert len(calls) > 0 + + def test_minres_errors(self): + if not has_support_aspect64(): + pytest.skip("float64 not supported on this device") + lo = aslinearoperator(dpnp.ones((4, 5), dtype=dpnp.float64)) + ib = dpnp.ones(4, dtype=dpnp.float64) + # non-square operator + assert_raises((ValueError, Exception), minres, lo, ib) + + +class TestSolversIntegration: + @pytest.mark.parametrize( + "n, dtype", + [ + (10, dpnp.float32), + (10, dpnp.float64), + (30, dpnp.float64), + (50, dpnp.float64), + ], + ) + def test_cg_spd_via_linearoperator(self, n, dtype): + if not has_support_aspect64() and dtype == dpnp.float64: + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(n, dtype) + lo = aslinearoperator(ia) + ib = _rhs(n, dtype) + x, info = cg(lo, ib, 
rtol=_rtol_for(dtype), maxiter=n * 10) + assert info == 0 + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < _res_bound(dtype) + + @pytest.mark.parametrize( + "n, dtype", + [ + (10, dpnp.float32), + (10, dpnp.float64), + (30, dpnp.float64), + ], + ) + def test_gmres_nonsymmetric_via_linearoperator(self, n, dtype): + if not has_support_aspect64() and dtype == dpnp.float64: + pytest.skip("float64 not supported on this device") + ia = _diag_dominant(n, dtype) + lo = aslinearoperator(ia) + ib = _rhs(n, dtype) + x, _ = gmres(lo, ib, rtol=_rtol_for(dtype), restart=n, maxiter=200) + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < _res_bound(dtype) + + @pytest.mark.skipif( + not is_scipy_available(), reason="SciPy required for minres" + ) + @pytest.mark.parametrize( + "n, dtype", + [ + (10, dpnp.float64), + (30, dpnp.float64), + ], + ) + def test_minres_spd_via_linearoperator(self, n, dtype): + if not has_support_aspect64() and dtype == dpnp.float64: + pytest.skip("float64 not supported on this device") + ia = _spd_matrix(n, dtype) + lo = aslinearoperator(ia) + ib = _rhs(n, dtype) + x, info = minres(lo, ib, rtol=1e-8) + assert info == 0 + res = float(dpnp.linalg.norm(ia @ x - ib) / dpnp.linalg.norm(ib)) + assert res < 1e-4