From 663a693a4ff285a3a132f0cf894623aea7a6789e Mon Sep 17 00:00:00 2001 From: ThreeMonth03 Date: Thu, 21 May 2026 21:53:08 +0800 Subject: [PATCH] Implement matmul_veclib() with cblas --- cpp/modmesh/buffer/CMakeLists.txt | 2 + cpp/modmesh/buffer/SimpleArray.hpp | 395 ++------------ cpp/modmesh/buffer/matmul.cpp | 205 ++++++++ cpp/modmesh/buffer/matmul.hpp | 491 ++++++++++++++++++ cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp | 21 +- cpp/modmesh/math/Complex.hpp | 22 + gtests/CMakeLists.txt | 5 + profiling/profile_matrix_ops.py | 24 +- tests/test_matrix.py | 96 +++- 9 files changed, 884 insertions(+), 377 deletions(-) create mode 100644 cpp/modmesh/buffer/matmul.cpp create mode 100644 cpp/modmesh/buffer/matmul.hpp diff --git a/cpp/modmesh/buffer/CMakeLists.txt b/cpp/modmesh/buffer/CMakeLists.txt index cc86aa890..34b007083 100644 --- a/cpp/modmesh/buffer/CMakeLists.txt +++ b/cpp/modmesh/buffer/CMakeLists.txt @@ -11,11 +11,13 @@ set(MODMESH_BUFFER_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/BufferExpander.hpp ${CMAKE_CURRENT_SOURCE_DIR}/SimpleArray.hpp ${CMAKE_CURRENT_SOURCE_DIR}/SimpleCollector.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.hpp CACHE FILEPATH "" FORCE) set(MODMESH_BUFFER_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/BufferExpander.cpp ${CMAKE_CURRENT_SOURCE_DIR}/SimpleArray.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp CACHE FILEPATH "" FORCE) set(MODMESH_BUFFER_PYMODHEADERS diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 4cff83183..494013e88 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -29,6 +29,7 @@ */ #include +#include #include #include @@ -134,363 +135,6 @@ struct SimpleArrayInternalTypes using buffer_type = ConcreteBuffer; }; /* end class SimpleArrayInternalType */ -template -class SimpleArrayMatmulHelper -{ - -private: - - using internal_types = detail::SimpleArrayInternalTypes; - -public: - - using value_type = typename internal_types::value_type; - using shape_type = typename internal_types::shape_type; - - SimpleArrayMatmulHelper() = delete; - SimpleArrayMatmulHelper(A const & lhs, A const & rhs); - SimpleArrayMatmulHelper(A const & lhs, - A const & rhs, - size_t tile_x, - size_t tile_y, - size_t tile_z); - ~SimpleArrayMatmulHelper() = default; - - SimpleArrayMatmulHelper(SimpleArrayMatmulHelper const &) = delete; - SimpleArrayMatmulHelper(SimpleArrayMatmulHelper &&) = delete; - SimpleArrayMatmulHelper & operator=(SimpleArrayMatmulHelper const &) = delete; - SimpleArrayMatmulHelper & operator=(SimpleArrayMatmulHelper &&) = delete; - - A matmul(); - A matmul_fast(); - -private: - - static std::string shape_str(A const & arr); - void check_dims() const; - void check_inner(size_t lhs_idx, size_t rhs_idx) const; - void check_tiles() const; - A matmul_vec_vec(); - A matmul_vec_mat(); - A matmul_mat_vec(); - A matmul_mat_mat(); - A pack_rhs(size_t n, size_t k); - void accumulate_tile(A const & packed_rhs, - size_t row_begin, - size_t row_end, - size_t col_begin, - size_t col_end, - size_t inner_begin, - size_t inner_end); - A matmul_mat_mat_tiled(); - - A const & m_lhs; - A const & m_rhs; - A m_result; - size_t m_tile_x; - size_t m_tile_y; - size_t m_tile_z; - -}; /* end class SimpleArrayMatmulHelper */ - -template -SimpleArrayMatmulHelper::SimpleArrayMatmulHelper(A const & lhs, A const & rhs) - : SimpleArrayMatmulHelper(lhs, rhs, 0, 0, 0) -{ -} - -template -SimpleArrayMatmulHelper::SimpleArrayMatmulHelper(A const & lhs, - A const & rhs, - size_t tile_x, - size_t tile_y, - size_t tile_z) - : m_lhs(lhs) - , m_rhs(rhs) - , m_tile_x(tile_x) - , m_tile_y(tile_y) - , m_tile_z(tile_z) -{ - check_dims(); - - size_t const lhs_ndim = m_lhs.ndim(); - size_t const rhs_ndim = m_rhs.ndim(); - - if (lhs_ndim == 1 && rhs_ndim == 1) - { - check_inner(0, 0); - m_result = A(1); - return; - } - - if (lhs_ndim == 1) - { - check_inner(0, 0); - m_result = A(m_rhs.shape(1)); - return; - } - - if (rhs_ndim == 1) - { - check_inner(1, 0); - m_result = A(m_lhs.shape(0)); - return; - } - - check_inner(1, 0); - shape_type const result_shape{m_lhs.shape(0), m_rhs.shape(1)}; - m_result = A(result_shape); -} - -template -A SimpleArrayMatmulHelper::matmul() -{ - if (m_lhs.ndim() == 1 && m_rhs.ndim() == 1) - { - return matmul_vec_vec(); - } - if (m_lhs.ndim() == 1) - { - return matmul_vec_mat(); - } - if (m_rhs.ndim() == 1) - { - return matmul_mat_vec(); - } - - return matmul_mat_mat(); -} - -/** - * Perform fast matrix multiplication for SimpleArrays. - * This implementation currently uses tiling for 2D x 2D matrix multiplication. - * Future optimizations may add other techniques such as SIMD kernels. - */ -template -A SimpleArrayMatmulHelper::matmul_fast() -{ - check_tiles(); - - if (m_lhs.ndim() == 1 && m_rhs.ndim() == 1) - { - return matmul_vec_vec(); - } - if (m_lhs.ndim() == 1) - { - return matmul_vec_mat(); - } - if (m_rhs.ndim() == 1) - { - return matmul_mat_vec(); - } - - return matmul_mat_mat_tiled(); -} - -/** - * Format shape for matrix multiplication diagnostics. - */ -template -std::string SimpleArrayMatmulHelper::shape_str(A const & arr) -{ - if (arr.ndim() == 0) - { - return "()"; - } - - std::string result = "("; - for (size_t i = 0; i < arr.ndim(); ++i) - { - if (i > 0) - { - result += ","; - } - result += std::to_string(arr.shape(i)); - } - result += ")"; - return result; -} - -template -void SimpleArrayMatmulHelper::check_dims() const -{ - bool const lhs_is_supported = m_lhs.ndim() == 1 || m_lhs.ndim() == 2; - bool const rhs_is_supported = m_rhs.ndim() == 1 || m_rhs.ndim() == 2; - if (lhs_is_supported && rhs_is_supported) - { - return; - } - - std::string const err = std::format("SimpleArray::matmul(): unsupported dimensions: " - "this={} other={}. SimpleArray must be 1D or 2D.", - shape_str(m_lhs), - shape_str(m_rhs)); - throw std::out_of_range(err); -} - -template -void SimpleArrayMatmulHelper::check_inner(size_t lhs_idx, size_t rhs_idx) const -{ - if (m_lhs.shape(lhs_idx) == m_rhs.shape(rhs_idx)) - { - return; - } - - throw std::out_of_range( - std::format("SimpleArray::matmul(): shape mismatch: this={} other={}", - shape_str(m_lhs), - shape_str(m_rhs))); -} - -template -void SimpleArrayMatmulHelper::check_tiles() const -{ - if (m_tile_x != 0 && m_tile_y != 0 && m_tile_z != 0) - { - return; - } - - throw std::out_of_range( - std::format("SimpleArray::fast_matmul(): tile sizes must be positive: " - "tile_x={} tile_y={} tile_z={}", - m_tile_x, - m_tile_y, - m_tile_z)); -} - -template -A SimpleArrayMatmulHelper::matmul_vec_vec() -{ - size_t const k = m_lhs.shape(0); - value_type v = 0; - for (size_t i = 0; i < k; ++i) - { - v += m_lhs(i) * m_rhs.data(i); - } - m_result.data(0) = v; - return std::move(m_result); -} - -template -A SimpleArrayMatmulHelper::matmul_vec_mat() -{ - size_t const n = m_result.size(); - size_t const k = m_lhs.shape(0); - for (size_t j = 0; j < n; ++j) - { - value_type v = 0; - for (size_t l = 0; l < k; ++l) - { - v += m_lhs(l) * m_rhs(l, j); - } - m_result.data(j) = v; - } - return std::move(m_result); -} - -template -A SimpleArrayMatmulHelper::matmul_mat_vec() -{ - size_t const m = m_result.size(); - size_t const k = m_lhs.shape(1); - for (size_t i = 0; i < m; ++i) - { - value_type v = 0; - for (size_t l = 0; l < k; ++l) - { - v += m_lhs(i, l) * m_rhs(l); - } - m_result.data(i) = v; - } - return std::move(m_result); -} - -template -A SimpleArrayMatmulHelper::matmul_mat_mat() -{ - size_t const m = m_result.shape(0); - size_t const n = m_result.shape(1); - size_t const k = m_lhs.shape(1); - for (size_t i = 0; i < m; ++i) - { - for (size_t j = 0; j < n; ++j) - { - value_type v = 0; - for (size_t l = 0; l < k; ++l) - { - v += m_lhs(i, l) * m_rhs(l, j); - } - m_result(i, j) = v; - } - } - return std::move(m_result); -} - -template -A SimpleArrayMatmulHelper::pack_rhs(size_t n, size_t k) -{ - shape_type const packing_shape{n, k}; - A packing(packing_shape); - for (size_t i = 0; i < n; ++i) - { - for (size_t j = 0; j < k; ++j) - { - packing(i, j) = m_rhs(j, i); - } - } - return packing; -} - -template -void SimpleArrayMatmulHelper::accumulate_tile(A const & packed_rhs, - size_t row_begin, - size_t row_end, - size_t col_begin, - size_t col_end, - size_t inner_begin, - size_t inner_end) -{ - for (size_t i = row_begin; i < row_end; ++i) - { - for (size_t j = col_begin; j < col_end; ++j) - { - value_type v = m_result(i, j); - for (size_t l = inner_begin; l < inner_end; ++l) - { - v += m_lhs(i, l) * packed_rhs(j, l); - } - m_result(i, j) = v; - } - } -} - -template -A SimpleArrayMatmulHelper::matmul_mat_mat_tiled() -{ - size_t const m = m_result.shape(0); - size_t const n = m_result.shape(1); - size_t const k = m_lhs.shape(1); - A packed_rhs = pack_rhs(n, k); - for (size_t i = 0; i < m_result.size(); ++i) - { - m_result.data(i) = value_type{0}; - } - for (size_t row = 0; row < m; row += m_tile_x) - { - size_t const row_end = std::min(row + m_tile_x, m); - for (size_t col = 0; col < n; col += m_tile_y) - { - size_t const col_end = std::min(col + m_tile_y, n); - for (size_t inner = 0; inner < k; inner += m_tile_z) - { - size_t const inner_end = std::min(inner + m_tile_z, k); - accumulate_tile(packed_rhs, row, row_end, col, col_end, inner, inner_end); - } - } - } - return std::move(m_result); -} - template class SimpleArrayMixinModifiers { @@ -1309,11 +953,13 @@ class SimpleArrayMixinCalculators A matmul(A const & other) const; A & imatmul(A const & other); - A fast_matmul(A const & other, + A matmul_veclib(A const & other) const; + A & imatmul_veclib(A const & other); + A matmul_fast(A const & other, size_t tile_x, size_t tile_y, size_t tile_z) const; - A & fast_imatmul(A const & other, + A & imatmul_fast(A const & other, size_t tile_x, size_t tile_y, size_t tile_z); @@ -1437,12 +1083,37 @@ A & SimpleArrayMixinCalculators::imatmul(A const & other) return *athis; } +/** + * Perform matrix multiplication using Accelerate/CBLAS when available. + */ +template +A SimpleArrayMixinCalculators::matmul_veclib(A const & other) const +{ + auto const * athis = static_cast(this); + SimpleArrayMatmulHelper helper(*athis, other); + return helper.matmul_veclib(); +} + +/** + * Perform in-place matrix multiplication using Accelerate/CBLAS when available. + * The result replaces the content of the current array. + */ +template +A & SimpleArrayMixinCalculators::imatmul_veclib(A const & other) +{ + auto athis = static_cast(this); + A result = athis->matmul_veclib(other); + *athis = std::move(result); + + return *athis; +} + /** * Perform fast matrix multiplication for SimpleArrays. * This implementation supports 1D x 1D, 1D x 2D, 2D x 1D, and 2D x 2D matrix multiplication. */ template -A SimpleArrayMixinCalculators::fast_matmul(A const & other, +A SimpleArrayMixinCalculators::matmul_fast(A const & other, size_t tile_x, size_t tile_y, size_t tile_z) const @@ -1458,13 +1129,13 @@ A SimpleArrayMixinCalculators::fast_matmul(A const & other, * The result replaces the content of the current array. */ template -A & SimpleArrayMixinCalculators::fast_imatmul(A const & other, +A & SimpleArrayMixinCalculators::imatmul_fast(A const & other, size_t tile_x, size_t tile_y, size_t tile_z) { auto athis = static_cast(this); - A result = athis->fast_matmul(other, tile_x, tile_y, tile_z); + A result = athis->matmul_fast(other, tile_x, tile_y, tile_z); *athis = std::move(result); return *athis; diff --git a/cpp/modmesh/buffer/matmul.cpp b/cpp/modmesh/buffer/matmul.cpp new file mode 100644 index 000000000..6a3bd1371 --- /dev/null +++ b/cpp/modmesh/buffer/matmul.cpp @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2026, Chun-Shih Chang + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * - Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * - Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * - Neither the name of the copyright holder nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include + +#if defined(__APPLE__) && defined(__arm64__) +#ifndef ACCELERATE_NEW_LAPACK +#define ACCELERATE_NEW_LAPACK +#endif +#ifndef ACCELERATE_LAPACK_ILP64 +#define ACCELERATE_LAPACK_ILP64 +#endif +#include + +#include +#include +#endif + +#include + +namespace modmesh +{ + +namespace detail +{ + +#if defined(__APPLE__) && defined(__arm64__) +struct BlasDims +{ + BlasDims(size_t m_in, size_t n_in, size_t k_in) + : m(to_lapack_int(m_in, "m")) + , n(to_lapack_int(n_in, "n")) + , k(to_lapack_int(k_in, "k")) + { + } + + __LAPACK_int m; + __LAPACK_int n; + __LAPACK_int k; + +private: + + static __LAPACK_int to_lapack_int(size_t value, char const * name) + { + if (value <= static_cast(std::numeric_limits<__LAPACK_int>::max())) + { + return static_cast<__LAPACK_int>(value); + } + + throw std::out_of_range( + std::format("SimpleArray::matmul_veclib(): {}={} exceeds LAPACK integer range", + name, + value)); + } +}; + +void matmul_veclib_backend(size_t m, + size_t n, + size_t k, + float const * lhs, + float const * rhs, + float * result) +{ + BlasDims const dims(m, n, k); + cblas_sgemm(CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + dims.m, + dims.n, + dims.k, + 1.0F, + lhs, + dims.k, + rhs, + dims.n, + 0.0F, + result, + dims.n); +} + +void matmul_veclib_backend(size_t m, + size_t n, + size_t k, + double const * lhs, + double const * rhs, + double * result) +{ + BlasDims const dims(m, n, k); + cblas_dgemm(CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + dims.m, + dims.n, + dims.k, + 1.0, + lhs, + dims.k, + rhs, + dims.n, + 0.0, + result, + dims.n); +} + +void matmul_veclib_backend(size_t m, + size_t n, + size_t k, + Complex const * lhs, + Complex const * rhs, + Complex * result) +{ + BlasDims const dims(m, n, k); + std::complex const alpha{1.0F, 0.0F}; + std::complex const beta{0.0F, 0.0F}; + cblas_cgemm(CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + dims.m, + dims.n, + dims.k, + &alpha, + as_std_complex_pointer(lhs), + dims.k, + as_std_complex_pointer(rhs), + dims.n, + &beta, + as_std_complex_pointer(result), + dims.n); +} + +void matmul_veclib_backend(size_t m, + size_t n, + size_t k, + Complex const * lhs, + Complex const * rhs, + Complex * result) +{ + BlasDims const dims(m, n, k); + std::complex const alpha{1.0, 0.0}; + std::complex const beta{0.0, 0.0}; + cblas_zgemm(CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + dims.m, + dims.n, + dims.k, + &alpha, + as_std_complex_pointer(lhs), + dims.k, + as_std_complex_pointer(rhs), + dims.n, + &beta, + as_std_complex_pointer(result), + dims.n); +} +#else +void matmul_veclib_backend(size_t, size_t, size_t, float const *, float const *, float *) +{ + throw_matmul_veclib_unavailable(); +} + +void matmul_veclib_backend(size_t, size_t, size_t, double const *, double const *, double *) +{ + throw_matmul_veclib_unavailable(); +} + +void matmul_veclib_backend(size_t, size_t, size_t, Complex const *, Complex const *, Complex *) +{ + throw_matmul_veclib_unavailable(); +} + +void matmul_veclib_backend(size_t, size_t, size_t, Complex const *, Complex const *, Complex *) +{ + throw_matmul_veclib_unavailable(); +} +#endif + +} /* end namespace detail */ + +} /* end namespace modmesh */ diff --git a/cpp/modmesh/buffer/matmul.hpp b/cpp/modmesh/buffer/matmul.hpp new file mode 100644 index 000000000..417592338 --- /dev/null +++ b/cpp/modmesh/buffer/matmul.hpp @@ -0,0 +1,491 @@ +#pragma once + +/* + * Copyright (c) 2026, Chun-Shih Chang + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * - Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * - Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * - Neither the name of the copyright holder nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace modmesh +{ + +namespace detail +{ + +#if defined(__APPLE__) && defined(__arm64__) +template +inline constexpr bool can_matmul_veclib_v = std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v>; +#else +template +inline constexpr bool can_matmul_veclib_v = false; +#endif /* defined(__APPLE__) && defined(__arm64__) */ + +inline void throw_matmul_veclib_unavailable() +{ + throw std::runtime_error( + "SimpleArray::matmul_veclib(): Accelerate/CBLAS matmul is only " + "available on Apple Silicon"); +} + +void matmul_veclib_backend(size_t m, + size_t n, + size_t k, + float const * lhs, + float const * rhs, + float * result); +void matmul_veclib_backend(size_t m, + size_t n, + size_t k, + double const * lhs, + double const * rhs, + double * result); +void matmul_veclib_backend(size_t m, + size_t n, + size_t k, + Complex const * lhs, + Complex const * rhs, + Complex * result); +void matmul_veclib_backend(size_t m, + size_t n, + size_t k, + Complex const * lhs, + Complex const * rhs, + Complex * result); + +template +void matmul_veclib_backend(size_t, size_t, size_t, T const *, T const *, T *) +{ + throw_matmul_veclib_unavailable(); +} + +template +class SimpleArrayMatmulHelper +{ + +public: + + using value_type = T; + using shape_type = small_vector; + + SimpleArrayMatmulHelper() = delete; + SimpleArrayMatmulHelper(A const & lhs, A const & rhs); + SimpleArrayMatmulHelper(A const & lhs, + A const & rhs, + size_t tile_x, + size_t tile_y, + size_t tile_z); + ~SimpleArrayMatmulHelper() = default; + + SimpleArrayMatmulHelper(SimpleArrayMatmulHelper const &) = delete; + SimpleArrayMatmulHelper(SimpleArrayMatmulHelper &&) = delete; + SimpleArrayMatmulHelper & operator=(SimpleArrayMatmulHelper const &) = delete; + SimpleArrayMatmulHelper & operator=(SimpleArrayMatmulHelper &&) = delete; + + A matmul(); + A matmul_fast(); + A matmul_veclib(); + +private: + + static std::string shape_str(A const & arr); + void check_dims() const; + void check_inner(size_t lhs_idx, size_t rhs_idx) const; + void check_tiles() const; + A matmul_vec_vec(); + A matmul_vec_mat(); + A matmul_mat_vec(); + A matmul_mat_mat(); + A matmul_mat_mat_veclib(); + A pack_rhs(size_t n, size_t k); + void accumulate_tile(A const & packed_rhs, + size_t row_begin, + size_t row_end, + size_t col_begin, + size_t col_end, + size_t inner_begin, + size_t inner_end); + A matmul_mat_mat_tiled(); + + A const & m_lhs; + A const & m_rhs; + A m_result; + size_t m_tile_x; + size_t m_tile_y; + size_t m_tile_z; + +}; /* end class SimpleArrayMatmulHelper */ + +template +SimpleArrayMatmulHelper::SimpleArrayMatmulHelper(A const & lhs, A const & rhs) + : SimpleArrayMatmulHelper(lhs, rhs, 0, 0, 0) +{ +} + +template +SimpleArrayMatmulHelper::SimpleArrayMatmulHelper(A const & lhs, + A const & rhs, + size_t tile_x, + size_t tile_y, + size_t tile_z) + : m_lhs(lhs) + , m_rhs(rhs) + , m_tile_x(tile_x) + , m_tile_y(tile_y) + , m_tile_z(tile_z) +{ + check_dims(); + + size_t const lhs_ndim = m_lhs.ndim(); + size_t const rhs_ndim = m_rhs.ndim(); + + if (lhs_ndim == 1 && rhs_ndim == 1) + { + check_inner(0, 0); + m_result = A(1); + return; + } + + if (lhs_ndim == 1) + { + check_inner(0, 0); + m_result = A(m_rhs.shape(1)); + return; + } + + if (rhs_ndim == 1) + { + check_inner(1, 0); + m_result = A(m_lhs.shape(0)); + return; + } + + check_inner(1, 0); + shape_type const result_shape{m_lhs.shape(0), m_rhs.shape(1)}; + m_result = A(result_shape); +} + +template +A SimpleArrayMatmulHelper::matmul() +{ + if (m_lhs.ndim() == 1 && m_rhs.ndim() == 1) + { + return matmul_vec_vec(); + } + if (m_lhs.ndim() == 1) + { + return matmul_vec_mat(); + } + if (m_rhs.ndim() == 1) + { + return matmul_mat_vec(); + } + + return matmul_mat_mat(); +} + +/** + * Perform fast matrix multiplication for SimpleArrays. + * This implementation currently uses tiling for 2D x 2D matrix multiplication. + * Future optimizations may add other techniques such as SIMD kernels. + */ +template +A SimpleArrayMatmulHelper::matmul_fast() +{ + check_tiles(); + + if (m_lhs.ndim() == 1 && m_rhs.ndim() == 1) + { + return matmul_vec_vec(); + } + if (m_lhs.ndim() == 1) + { + return matmul_vec_mat(); + } + if (m_rhs.ndim() == 1) + { + return matmul_mat_vec(); + } + + return matmul_mat_mat_tiled(); +} + +/** + * Perform matrix multiplication using Accelerate/CBLAS when available. + */ +template +A SimpleArrayMatmulHelper::matmul_veclib() +{ + if (m_lhs.ndim() == 1 && m_rhs.ndim() == 1) + { + return matmul_vec_vec(); + } + if (m_lhs.ndim() == 1) + { + return matmul_vec_mat(); + } + if (m_rhs.ndim() == 1) + { + return matmul_mat_vec(); + } + + return matmul_mat_mat_veclib(); +} + +/** + * Format shape for matrix multiplication diagnostics. + */ +template +std::string SimpleArrayMatmulHelper::shape_str(A const & arr) +{ + if (arr.ndim() == 0) + { + return "()"; + } + + std::string result = "("; + for (size_t i = 0; i < arr.ndim(); ++i) + { + if (i > 0) + { + result += ","; + } + result += std::to_string(arr.shape(i)); + } + result += ")"; + return result; +} + +template +void SimpleArrayMatmulHelper::check_dims() const +{ + bool const lhs_is_supported = m_lhs.ndim() == 1 || m_lhs.ndim() == 2; + bool const rhs_is_supported = m_rhs.ndim() == 1 || m_rhs.ndim() == 2; + if (lhs_is_supported && rhs_is_supported) + { + return; + } + + std::string const err = std::format("SimpleArray::matmul(): unsupported dimensions: " + "this={} other={}. SimpleArray must be 1D or 2D.", + shape_str(m_lhs), + shape_str(m_rhs)); + throw std::out_of_range(err); +} + +template +void SimpleArrayMatmulHelper::check_inner(size_t lhs_idx, size_t rhs_idx) const +{ + if (m_lhs.shape(lhs_idx) == m_rhs.shape(rhs_idx)) + { + return; + } + + throw std::out_of_range( + std::format("SimpleArray::matmul(): shape mismatch: this={} other={}", + shape_str(m_lhs), + shape_str(m_rhs))); +} + +template +void SimpleArrayMatmulHelper::check_tiles() const +{ + if (m_tile_x != 0 && m_tile_y != 0 && m_tile_z != 0) + { + return; + } + + throw std::out_of_range( + std::format("SimpleArray::matmul_fast(): tile sizes must be positive: " + "tile_x={} tile_y={} tile_z={}", + m_tile_x, + m_tile_y, + m_tile_z)); +} + +template +A SimpleArrayMatmulHelper::matmul_vec_vec() +{ + size_t const k = m_lhs.shape(0); + value_type v = 0; + for (size_t i = 0; i < k; ++i) + { + v += m_lhs(i) * m_rhs.data(i); + } + m_result.data(0) = v; + return std::move(m_result); +} + +template +A SimpleArrayMatmulHelper::matmul_vec_mat() +{ + size_t const n = m_result.size(); + size_t const k = m_lhs.shape(0); + for (size_t j = 0; j < n; ++j) + { + value_type v = 0; + for (size_t l = 0; l < k; ++l) + { + v += m_lhs(l) * m_rhs(l, j); + } + m_result.data(j) = v; + } + return std::move(m_result); +} + +template +A SimpleArrayMatmulHelper::matmul_mat_vec() +{ + size_t const m = m_result.size(); + size_t const k = m_lhs.shape(1); + for (size_t i = 0; i < m; ++i) + { + value_type v = 0; + for (size_t l = 0; l < k; ++l) + { + v += m_lhs(i, l) * m_rhs(l); + } + m_result.data(i) = v; + } + return std::move(m_result); +} + +template +A SimpleArrayMatmulHelper::matmul_mat_mat() +{ + size_t const m = m_result.shape(0); + size_t const n = m_result.shape(1); + size_t const k = m_lhs.shape(1); + for (size_t i = 0; i < m; ++i) + { + for (size_t j = 0; j < n; ++j) + { + value_type v = 0; + for (size_t l = 0; l < k; ++l) + { + v += m_lhs(i, l) * m_rhs(l, j); + } + m_result(i, j) = v; + } + } + return std::move(m_result); +} + +template +A SimpleArrayMatmulHelper::matmul_mat_mat_veclib() +{ + if (can_matmul_veclib_v && m_lhs.is_c_contiguous() && m_rhs.is_c_contiguous()) + { + size_t const m = m_result.shape(0); + size_t const n = m_result.shape(1); + size_t const k = m_lhs.shape(1); + matmul_veclib_backend(m, n, k, m_lhs.data(), m_rhs.data(), m_result.data()); + return std::move(m_result); + } + + return matmul_mat_mat(); +} + +template +A SimpleArrayMatmulHelper::pack_rhs(size_t n, size_t k) +{ + shape_type const packing_shape{n, k}; + A packing(packing_shape); + for (size_t i = 0; i < n; ++i) + { + for (size_t j = 0; j < k; ++j) + { + packing(i, j) = m_rhs(j, i); + } + } + return packing; +} + +template +void SimpleArrayMatmulHelper::accumulate_tile(A const & packed_rhs, + size_t row_begin, + size_t row_end, + size_t col_begin, + size_t col_end, + size_t inner_begin, + size_t inner_end) +{ + for (size_t i = row_begin; i < row_end; ++i) + { + for (size_t j = col_begin; j < col_end; ++j) + { + value_type v = m_result(i, j); + for (size_t l = inner_begin; l < inner_end; ++l) + { + v += m_lhs(i, l) * packed_rhs(j, l); + } + m_result(i, j) = v; + } + } +} + +template +A SimpleArrayMatmulHelper::matmul_mat_mat_tiled() +{ + size_t const m = m_result.shape(0); + size_t const n = m_result.shape(1); + size_t const k = m_lhs.shape(1); + A packed_rhs = pack_rhs(n, k); + for (size_t i = 0; i < m_result.size(); ++i) + { + m_result.data(i) = value_type{0}; + } + for (size_t row = 0; row < m; row += m_tile_x) + { + size_t const row_end = std::min(row + m_tile_x, m); + for (size_t col = 0; col < n; col += m_tile_y) + { + size_t const col_end = std::min(col + m_tile_y, n); + for (size_t inner = 0; inner < k; inner += m_tile_z) + { + size_t const inner_end = std::min(inner + m_tile_z, k); + accumulate_tile(packed_rhs, row, row_end, col, col_end, inner, inner_end); + } + } + } + return std::move(m_result); +} + +} /* end namespace detail */ + +} /* end namespace modmesh */ diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index 800ff87fd..65127ab21 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -379,19 +379,20 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray [](wrapped_type const & self, value_type scalar) { return self.div(scalar); }) .def("matmul", &wrapped_type::matmul) - .def("__matmul__", &wrapped_type::matmul) + .def("matmul_veclib", &wrapped_type::matmul_veclib) .def( - "fast_matmul", + "matmul_fast", [](wrapped_type const & self, wrapped_type const & other, size_t tile_x, size_t tile_y, size_t tile_z) - { return self.fast_matmul(other, tile_x, tile_y, tile_z); }, + { return self.matmul_fast(other, tile_x, tile_y, tile_z); }, py::arg("other"), py::arg("tile_x") = 16, py::arg("tile_y") = 16, py::arg("tile_z") = 16) + .def("__matmul__", &wrapped_type::matmul) // TODO: In-place operation should return reference to self to support function chaining /* * Regular in-place methods (iadd, imul, etc.) are procedural calls and do @@ -433,6 +434,20 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray { self.idiv(scalar); }) .def("imatmul", [](wrapped_type & self, wrapped_type const & other) { self.imatmul(other); }) + .def("imatmul_veclib", [](wrapped_type & self, wrapped_type const & other) + { self.imatmul_veclib(other); }) + .def( + "imatmul_fast", + [](wrapped_type & self, + wrapped_type const & other, + size_t tile_x, + size_t tile_y, + size_t tile_z) + { self.imatmul_fast(other, tile_x, tile_y, tile_z); }, + py::arg("other"), + py::arg("tile_x") = 16, + py::arg("tile_y") = 16, + py::arg("tile_z") = 16) .def("__imatmul__", [](wrapped_type & self, wrapped_type const & other) { self.imatmul(other); diff --git a/cpp/modmesh/math/Complex.hpp b/cpp/modmesh/math/Complex.hpp index 34c9c1acd..605642ea5 100644 --- a/cpp/modmesh/math/Complex.hpp +++ b/cpp/modmesh/math/Complex.hpp @@ -1,3 +1,5 @@ +#pragma once + /* * Copyright (c) 2025, Chun-Hsu Lai * @@ -260,6 +262,26 @@ ComplexImpl operator/(T lhs, const ComplexImpl & rhs) template using Complex = detail::ComplexImpl; +template +inline constexpr bool is_std_complex_layout_compatible_v = std::is_standard_layout_v> && + sizeof(Complex) == sizeof(std::complex) && + alignof(Complex) == alignof(std::complex); + +static_assert(is_std_complex_layout_compatible_v); +static_assert(is_std_complex_layout_compatible_v); + +template +std::complex const * as_std_complex_pointer(Complex const * ptr) +{ + return reinterpret_cast const *>(ptr); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast) +} + +template +std::complex * as_std_complex_pointer(Complex * ptr) +{ + return reinterpret_cast *>(ptr); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast) +} + // clang-format off template struct is_complex : std::false_type {}; diff --git a/gtests/CMakeLists.txt b/gtests/CMakeLists.txt index 2ade2370a..5fed97ebd 100644 --- a/gtests/CMakeLists.txt +++ b/gtests/CMakeLists.txt @@ -17,6 +17,10 @@ FetchContent_MakeAvailable(googletest) enable_testing() +if (APPLE) + find_library(APPLE_FWK_ACCELERATE Accelerate REQUIRED) +endif () # APPLE + # The `test_nopython` target is only for testing the C++ interface of the non-Python part of the library. add_executable( test_nopython @@ -41,6 +45,7 @@ target_link_libraries( test_nopython GTest::gtest_main GTest::gmock_main + ${APPLE_FWK_ACCELERATE} ) include(GoogleTest) diff --git a/profiling/profile_matrix_ops.py b/profiling/profile_matrix_ops.py index 1cfddfa55..c533148b3 100644 --- a/profiling/profile_matrix_ops.py +++ b/profiling/profile_matrix_ops.py @@ -25,6 +25,7 @@ # POSSIBILITY OF SUCH DAMAGE. import functools +import platform import numpy as np import modmesh @@ -58,10 +59,26 @@ def profile_matmul_naive_sa(lhs, rhs): return lhs.matmul(rhs) +@profile_function +def profile_matmul_veclib_sa(lhs, rhs): + return lhs.matmul_veclib(rhs) + + def profile_matmul_fast_sa(lhs, rhs, tile_x, tile_y, tile_z): name = f"profile_matmul_fast_sa_{tile_x}_{tile_y}_{tile_z}" _ = modmesh.CallProfilerProbe(name) - return lhs.fast_matmul(rhs, tile_x=tile_x, tile_y=tile_y, tile_z=tile_z) + return lhs.matmul_fast(rhs, tile_x=tile_x, tile_y=tile_y, tile_z=tile_z) + + +def supports_matmul_veclib(dtype): + return ( + platform.system() == "Darwin" + and platform.machine() == "arm64" + and np.dtype(dtype) in ( + np.dtype(np.float32), + np.dtype(np.float64), + ) + ) def make_data(dtype, shape): @@ -79,10 +96,13 @@ def profile_matmul_operation(dtype, shapes, it=10): rhs = make_data(dtype, (m, m)) lhs_sa = make_container(lhs) rhs_sa = make_container(rhs) + use_veclib = supports_matmul_veclib(dtype) modmesh.call_profiler.reset() for _ in range(it): profile_matmul_np(lhs, rhs) profile_matmul_naive_sa(lhs_sa, rhs_sa) + if use_veclib: + profile_matmul_veclib_sa(lhs_sa, rhs_sa) for tile_x, tile_y, tile_z in tile_configs: profile_matmul_fast_sa(lhs_sa, rhs_sa, tile_x, tile_y, tile_z) @@ -104,6 +124,8 @@ def print_row(*cols): print_row("-" * 20, "-" * 15, "-" * 15) npbase = out["np"] keys = ["np", "naive_sa"] + if use_veclib: + keys += ["veclib_sa"] keys += [ f"fast_sa_{tile_x}_{tile_y}_{tile_z}" for tile_x, tile_y, tile_z in tile_configs diff --git a/tests/test_matrix.py b/tests/test_matrix.py index fe3e8a0c0..00bd6f94f 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -83,6 +83,27 @@ def setUp(self): class MatmulTestBase(mm.testing.TestBase): """Tests for matrix-matrix multiplication""" + def assert_matmul_veclib(self, lhs, rhs, expected, matmul_result): + veclib_unavailable = ( + "SimpleArray::matmul_veclib(): " + "Accelerate/CBLAS matmul is only " + "available on Apple Silicon" + ) + + try: + veclib_result = lhs.matmul_veclib(rhs) + except RuntimeError as exc: + # FIXME: Split veclib backend coverage into dedicated tests and + # mark unsupported platforms as expected failures once a follow-up + # issue is filed. + self.assertEqual(str(exc), veclib_unavailable) + return + + self.assertEqual(list(veclib_result.shape), list(expected.shape)) + np.testing.assert_array_almost_equal(veclib_result.ndarray, expected) + np.testing.assert_array_almost_equal(veclib_result.ndarray, + matmul_result.ndarray) + def test_square(self): """Test basic square matrix multiplication""" # Create 2x2 matrices @@ -97,7 +118,7 @@ def test_square(self): # Test matrix multiplication result = a.matmul(b) - fast_result = a.fast_matmul(b) + fast_result = a.matmul_fast(b) self.assertEqual(list(result.shape), [2, 2]) np.testing.assert_array_almost_equal(result.ndarray, expected) @@ -105,6 +126,7 @@ def test_square(self): np.testing.assert_array_almost_equal(fast_result.ndarray, expected) np.testing.assert_array_almost_equal(fast_result.ndarray, result.ndarray) + self.assert_matmul_veclib(a, b, expected, result) def test_rectangular(self): """Test rectangular matrix multiplication""" @@ -122,7 +144,7 @@ def test_rectangular(self): dtype=self.dtype) result = a.matmul(b) - fast_result = a.fast_matmul(b) + fast_result = a.matmul_fast(b) self.assertEqual(list(result.shape), [2, 2]) np.testing.assert_array_almost_equal(result.ndarray, expected) @@ -130,6 +152,7 @@ def test_rectangular(self): np.testing.assert_array_almost_equal(fast_result.ndarray, expected) np.testing.assert_array_almost_equal(fast_result.ndarray, result.ndarray) + self.assert_matmul_veclib(a, b, expected, result) def test_identity(self): """Test multiplication with identity matrix""" @@ -141,7 +164,7 @@ def test_identity(self): identity = self.SimpleArray.eye(3) result = a.matmul(identity) - fast_result = a.fast_matmul(identity) + fast_result = a.matmul_fast(identity) self.assertEqual(list(result.shape), [3, 3]) np.testing.assert_array_almost_equal(result.ndarray, a_data) @@ -149,6 +172,7 @@ def test_identity(self): np.testing.assert_array_almost_equal(fast_result.ndarray, a_data) np.testing.assert_array_almost_equal(fast_result.ndarray, result.ndarray) + self.assert_matmul_veclib(a, identity, a_data, result) def test_zero(self): """Test multiplication with zero matrix""" @@ -159,7 +183,7 @@ def test_zero(self): zero = self.SimpleArray(array=zero_data) result = a.matmul(zero) - fast_result = a.fast_matmul(zero) + fast_result = a.matmul_fast(zero) self.assertEqual(list(result.shape), [2, 2]) np.testing.assert_array_almost_equal(result.ndarray, zero_data) @@ -167,6 +191,7 @@ def test_zero(self): np.testing.assert_array_almost_equal(fast_result.ndarray, zero_data) np.testing.assert_array_almost_equal(fast_result.ndarray, result.ndarray) + self.assert_matmul_veclib(a, zero, zero_data, result) def test_dimension_mismatch_error(self): """Test error handling for incompatible dimensions""" @@ -192,7 +217,13 @@ def test_dimension_mismatch_error(self): r"SimpleArray::matmul\(\): shape mismatch: this=\(2,2\) other=" r"\(3,3\)" ): - a.fast_matmul(b) + a.matmul_fast(b) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): shape mismatch: this=\(2,2\) other=" + r"\(3,3\)" + ): + a.matmul_veclib(b) def test_compare_with_numpy(self): """Compare results with NumPy using fixed test data""" @@ -294,7 +325,7 @@ def test_compare_with_numpy(self): # Compute with our implementation result = a.matmul(b) - fast_result = a.fast_matmul(b) + fast_result = a.matmul_fast(b) # Verify with NumPy np_result = np.matmul(a_data, b_data) @@ -306,6 +337,7 @@ def test_compare_with_numpy(self): list(expected.shape)) np.testing.assert_array_almost_equal(fast_result.ndarray, result.ndarray) + self.assert_matmul_veclib(a, b, expected, result) if self.dtype == np.float32: np.testing.assert_array_almost_equal( result.ndarray, expected, decimal=4) @@ -338,7 +370,13 @@ def test_wrong_shape_error(self): r"SimpleArray::matmul\(\): unsupported dimensions: " r"this=\(2,2,2\) other=\(2,2,2\)\. SimpleArray must be 1D or 2D." ): - a_3d.fast_matmul(b_3d) + a_3d.matmul_fast(b_3d) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): unsupported dimensions: " + r"this=\(2,2,2\) other=\(2,2,2\)\. SimpleArray must be 1D or 2D." + ): + a_3d.matmul_veclib(b_3d) a = np.zeros((3, 3), dtype=self.dtype) b = np.zeros((2, 3), dtype=self.dtype) @@ -355,7 +393,13 @@ def test_wrong_shape_error(self): r"SimpleArray::matmul\(\): shape mismatch: " r"this=\(3,3\) other=\(2,3\)" ): - a.fast_matmul(b) + a.matmul_fast(b) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): shape mismatch: " + r"this=\(3,3\) other=\(2,3\)" + ): + a.matmul_veclib(b) a = np.zeros((3, 3), dtype=self.dtype) b = np.zeros((2), dtype=self.dtype) @@ -372,7 +416,13 @@ def test_wrong_shape_error(self): r"SimpleArray::matmul\(\): shape mismatch: " r"this=\(3,3\) other=\(2\)" ): - a.fast_matmul(b) + a.matmul_fast(b) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): shape mismatch: " + r"this=\(3,3\) other=\(2\)" + ): + a.matmul_veclib(b) a = np.zeros((2), dtype=self.dtype) b = np.zeros((3, 3), dtype=self.dtype) @@ -389,7 +439,13 @@ def test_wrong_shape_error(self): r"SimpleArray::matmul\(\): shape mismatch: " r"this=\(2\) other=\(3,3\)" ): - a.fast_matmul(b) + a.matmul_fast(b) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): shape mismatch: " + r"this=\(2\) other=\(3,3\)" + ): + a.matmul_veclib(b) a = np.zeros((2), dtype=self.dtype) b = np.zeros((3), dtype=self.dtype) @@ -406,7 +462,13 @@ def test_wrong_shape_error(self): r"SimpleArray::matmul\(\): shape mismatch: " r"this=\(2\) other=\(3\)" ): - a.fast_matmul(b) + a.matmul_fast(b) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): shape mismatch: " + r"this=\(2\) other=\(3\)" + ): + a.matmul_veclib(b) def test_matmul_operator(self): """Test @ operator for matrix multiplication""" @@ -433,6 +495,8 @@ def test_imatmul_method(self): b_data = np.array([[5.0, 6.0], [7.0, 8.0]], dtype=self.dtype) a = self.SimpleArray(array=a_data) + a_fast = self.SimpleArray(array=a_data) + a_veclib = self.SimpleArray(array=a_data) b = self.SimpleArray(array=b_data) # Expected result: [[19, 22], [43, 50]] @@ -440,10 +504,20 @@ def test_imatmul_method(self): # Test imatmul() method a.imatmul(b) + # Test imatmul_veclib() method + a_veclib.imatmul_veclib(b) + # Test imatmul_fast() method + a_fast.imatmul_fast(b) # Verify the result self.assertEqual(list(a.shape), [2, 2]) np.testing.assert_array_almost_equal(a.ndarray, expected) + self.assertEqual(list(a_veclib.shape), [2, 2]) + np.testing.assert_array_almost_equal(a_veclib.ndarray, expected) + self.assertEqual(list(a_fast.shape), [2, 2]) + np.testing.assert_array_almost_equal(a_fast.ndarray, expected) + np.testing.assert_array_almost_equal(a_veclib.ndarray, a.ndarray) + np.testing.assert_array_almost_equal(a_fast.ndarray, a.ndarray) def test_imatmul_operator(self): """Test @= operator for in-place matrix multiplication"""