From e18cff6b90acea6f960c752bcfc904974a489794 Mon Sep 17 00:00:00 2001 From: Yung-Yu Chen Date: Sat, 30 May 2026 22:23:24 +0800 Subject: [PATCH] Implement matrix power SimpleArray::pow() Add SimpleArray::pow(n) computing A^n for a square matrix and a non-negative integer exponent via exponentiation by squaring (O(log n) matmuls). A^0 returns the identity; negative exponents are rejected since they require matrix inversion (tracked in issue #719). Include Float32/Float64 tests covering exponent values, identity invariance, a numpy cross-check over fixed fixtures, and the non-square and non-2D (1D/3D/4D) error cases. For issue #821. --- cpp/modmesh/buffer/SimpleArray.hpp | 47 +++++++ cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp | 1 + tests/test_matrix.py | 120 ++++++++++++++++++ 3 files changed, 168 insertions(+) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 039e4c95..0c2406b5 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -1364,6 +1364,8 @@ class SimpleArrayMixinMatrix return result; } + A pow(ssize_t n) const; + private: static void validate_positive(const char * operation_name, ssize_t n) @@ -1407,6 +1409,51 @@ class SimpleArrayMixinMatrix }; /* end class SimpleArrayMixinMatrix */ +/** + * Compute the matrix power A^n for a square SimpleArray and a non-negative + * integer exponent. The exponentiation-by-squaring algorithm keeps the + * number of matrix multiplications at O(log n). A^0 is the identity matrix. + * Negative exponents are not supported because they require matrix + * inversion. + */ +template +A SimpleArrayMixinMatrix::pow(ssize_t n) const +{ + auto const * athis = static_cast(this); + validate_square("pow"); + if (n < 0) + { + throw std::invalid_argument( + std::format("SimpleArray::pow(): exponent must be non-negative, " + "but got {}", + n)); + } + + auto const dim = static_cast(athis->shape(0)); + A result = eye(dim); + if (n == 0) + { + return result; + } + + // Exponentiation by squaring. n is non-negative here, so scanning its + // bits is well-defined. + A base = *athis; + while (n > 0) + { + if ((n & 1) != 0) + { + result.imatmul(base); + } + n >>= 1; + if (n > 0) + { + base.imatmul(base); + } + } + return result; +} + } /* end namespace detail */ // Tag type for explicit alignment constructor diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index 96c4c571..b61c29ef 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -499,6 +499,7 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray (*this) .def_static("eye", &wrapped_type::eye, py::arg("n"), "Create an identity matrix of size n x n") + .def("pow", &wrapped_type::pow, py::arg("n"), "Compute the matrix power A^n for a non-negative integer n") .def_static("scaled_eye", &wrapped_type::scaled_eye, py::arg("n"), py::arg("scale"), "Create a scaled identity matrix of size n x n") .def("hermitian", &wrapped_type::hermitian, "Create hermitian (conjugate transpose) of the matrix") .def("symmetrize", &wrapped_type::symmetrize, "Create symmetric matrix by averaging with its transpose") diff --git a/tests/test_matrix.py b/tests/test_matrix.py index 6d5a4925..4121416d 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -495,6 +495,126 @@ def test_imatmul_operator(self): np.testing.assert_array_almost_equal(a.ndarray, expected) +class MatrixPowerTestBase(mm.testing.TestBase): + """Tests for matrix power A^n with non-negative integer n""" + + def assert_pow(self, mat, mat_data, n): + result = mat.pow(n) + expected = np.linalg.matrix_power(mat_data, n) + + self.assertEqual(list(result.shape), list(expected.shape)) + np.testing.assert_array_almost_equal(result.ndarray, expected) + return result + + def test_zero_exponent(self): + """A^0 is the identity matrix""" + mat_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) + mat = self.SimpleArray(array=mat_data) + + result = self.assert_pow(mat, mat_data, 0) + np.testing.assert_array_almost_equal( + result.ndarray, np.eye(2, dtype=self.dtype)) + + def test_one_exponent(self): + """A^1 reproduces the original matrix""" + mat_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) + mat = self.SimpleArray(array=mat_data) + + result = self.assert_pow(mat, mat_data, 1) + np.testing.assert_array_almost_equal(result.ndarray, mat_data) + + def test_small_exponents(self): + """A^n matches numpy.linalg.matrix_power for small n""" + mat_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) + mat = self.SimpleArray(array=mat_data) + + for n in range(0, 8): + with self.subTest(n=n): + self.assert_pow(mat, mat_data, n) + + def test_identity_power(self): + """The identity matrix is invariant under any power""" + identity = self.SimpleArray.eye(4) + identity_data = np.eye(4, dtype=self.dtype) + + for n in (0, 1, 5, 10): + with self.subTest(n=n): + self.assert_pow(identity, identity_data, n) + + def test_matrix_dim_to_5(self): + """A^n matches numpy across several square matrices and exponents""" + fixtures = [ + [[-3]], + [[2, 1], [0, 0]], + [[3, -3, 1], [-2, -3, 0], [3, 2, 2]], + [[2, 2, 0, -3, 2], + [0, 0, -1, -2, 3], + [2, 1, -1, 2, 0], + [0, 0, -2, -3, 0], + [3, -3, 3, 2, -2]], + ] + exponents = [0, 1, 2, 3, 6] + + for fixture in fixtures: + mat_data = np.array(fixture, dtype=self.dtype) + mat = self.SimpleArray(array=mat_data) + for n in exponents: + with self.subTest(size=mat_data.shape[0], n=n): + self.assert_pow(mat, mat_data, n) + + def test_negative_exponent_error(self): + """A negative exponent is rejected""" + mat_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) + mat = self.SimpleArray(array=mat_data) + + with self.assertRaisesRegex( + ValueError, + r"SimpleArray::pow\(\): exponent must be non-negative, " + r"but got -1"): + mat.pow(-1) + + def test_non_square_error(self): + """A non-square matrix cannot be raised to a power""" + mat_data = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + dtype=self.dtype) + mat = self.SimpleArray(array=mat_data) + + with self.assertRaisesRegex( + RuntimeError, + r"SimpleArray::pow\(\): operation requires square " + r"SimpleArray, but got 2x3 shape"): + mat.pow(2) + + def test_non_2d_error(self): + """A non-2D SimpleArray cannot be raised to a power""" + # 1D, 3D, and 4D arrays must all be rejected, and the error must + # report the offending dimensionality. + shapes = [(3,), (2, 2, 2), (2, 2, 2, 2)] + + for shape in shapes: + ndim = len(shape) + with self.subTest(ndim=ndim): + mat = self.SimpleArray( + array=np.ones(shape, dtype=self.dtype)) + with self.assertRaisesRegex( + RuntimeError, + r"SimpleArray::pow\(\): operation requires 2D " + r"SimpleArray, but got %dD SimpleArray" % ndim): + mat.pow(2) + + +class MatrixPowerFloat32TC(MatrixPowerTestBase, unittest.TestCase): + def setUp(self): + self.dtype = np.float32 + self.SimpleArray = mm.SimpleArrayFloat32 + + +class MatrixPowerFloat64TC(MatrixPowerTestBase, unittest.TestCase): + def setUp(self): + self.dtype = np.float64 + self.SimpleArray = mm.SimpleArrayFloat64 + + class MatmulFloat32TC(MatmulTestBase, unittest.TestCase): def setUp(self): self.dtype = np.float32