-
Notifications
You must be signed in to change notification settings - Fork 63
Add SimpleArray::pow() for matrix power
#844
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1364,6 +1364,8 @@ class SimpleArrayMixinMatrix | |
| return result; | ||
| } | ||
|
|
||
| A pow(ssize_t n) const; | ||
|
|
||
| private: | ||
|
|
||
| static void validate_positive(const char * operation_name, ssize_t n) | ||
|
|
@@ -1407,6 +1409,51 @@ class SimpleArrayMixinMatrix | |
|
|
||
| }; /* end class SimpleArrayMixinMatrix */ | ||
|
|
||
| /** | ||
| * Compute the matrix power A^n for a square SimpleArray and a non-negative | ||
| * integer exponent. The exponentiation-by-squaring algorithm keeps the | ||
| * number of matrix multiplications at O(log n). A^0 is the identity matrix. | ||
| * Negative exponents are not supported because they require matrix | ||
| * inversion. | ||
| */ | ||
| template <typename A, typename T> | ||
| A SimpleArrayMixinMatrix<A, T>::pow(ssize_t n) const | ||
| { | ||
| auto const * athis = static_cast<A const *>(this); | ||
| validate_square("pow"); | ||
| if (n < 0) | ||
| { | ||
| throw std::invalid_argument( | ||
| std::format("SimpleArray::pow(): exponent must be non-negative, " | ||
| "but got {}", | ||
| n)); | ||
| } | ||
|
|
||
| auto const dim = static_cast<ssize_t>(athis->shape(0)); | ||
| A result = eye(dim); | ||
| if (n == 0) | ||
| { | ||
| return result; | ||
| } | ||
|
|
||
| // Exponentiation by squaring. n is non-negative here, so scanning its | ||
| // bits is well-defined. | ||
| A base = *athis; | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Squaring (O(ln n) complexity) needs two arrays. |
||
| while (n > 0) | ||
| { | ||
| if ((n & 1) != 0) | ||
| { | ||
| result.imatmul(base); | ||
| } | ||
| n >>= 1; | ||
| if (n > 0) | ||
| { | ||
| base.imatmul(base); | ||
| } | ||
| } | ||
| return result; | ||
| } | ||
|
|
||
| } /* end namespace detail */ | ||
|
|
||
| // Tag type for explicit alignment constructor | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -495,6 +495,126 @@ def test_imatmul_operator(self): | |
| np.testing.assert_array_almost_equal(a.ndarray, expected) | ||
|
|
||
|
|
||
| class MatrixPowerTestBase(mm.testing.TestBase): | ||
| """Tests for matrix power A^n with non-negative integer n""" | ||
|
|
||
| def assert_pow(self, mat, mat_data, n): | ||
| result = mat.pow(n) | ||
| expected = np.linalg.matrix_power(mat_data, n) | ||
|
|
||
| self.assertEqual(list(result.shape), list(expected.shape)) | ||
| np.testing.assert_array_almost_equal(result.ndarray, expected) | ||
| return result | ||
|
|
||
| def test_zero_exponent(self): | ||
| """A^0 is the identity matrix""" | ||
| mat_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) | ||
| mat = self.SimpleArray(array=mat_data) | ||
|
|
||
| result = self.assert_pow(mat, mat_data, 0) | ||
| np.testing.assert_array_almost_equal( | ||
| result.ndarray, np.eye(2, dtype=self.dtype)) | ||
|
|
||
| def test_one_exponent(self): | ||
| """A^1 reproduces the original matrix""" | ||
| mat_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) | ||
| mat = self.SimpleArray(array=mat_data) | ||
|
|
||
| result = self.assert_pow(mat, mat_data, 1) | ||
| np.testing.assert_array_almost_equal(result.ndarray, mat_data) | ||
|
|
||
| def test_small_exponents(self): | ||
| """A^n matches numpy.linalg.matrix_power for small n""" | ||
| mat_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) | ||
| mat = self.SimpleArray(array=mat_data) | ||
|
|
||
| for n in range(0, 8): | ||
| with self.subTest(n=n): | ||
| self.assert_pow(mat, mat_data, n) | ||
|
|
||
| def test_identity_power(self): | ||
| """The identity matrix is invariant under any power""" | ||
| identity = self.SimpleArray.eye(4) | ||
| identity_data = np.eye(4, dtype=self.dtype) | ||
|
|
||
| for n in (0, 1, 5, 10): | ||
| with self.subTest(n=n): | ||
| self.assert_pow(identity, identity_data, n) | ||
|
|
||
| def test_matrix_dim_to_5(self): | ||
| """A^n matches numpy across several square matrices and exponents""" | ||
| fixtures = [ | ||
| [[-3]], | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test with 1x1, 2x2, 3x3, and 5x5 matrices. |
||
| [[2, 1], [0, 0]], | ||
| [[3, -3, 1], [-2, -3, 0], [3, 2, 2]], | ||
| [[2, 2, 0, -3, 2], | ||
| [0, 0, -1, -2, 3], | ||
| [2, 1, -1, 2, 0], | ||
| [0, 0, -2, -3, 0], | ||
| [3, -3, 3, 2, -2]], | ||
| ] | ||
| exponents = [0, 1, 2, 3, 6] | ||
|
|
||
| for fixture in fixtures: | ||
| mat_data = np.array(fixture, dtype=self.dtype) | ||
| mat = self.SimpleArray(array=mat_data) | ||
| for n in exponents: | ||
| with self.subTest(size=mat_data.shape[0], n=n): | ||
| self.assert_pow(mat, mat_data, n) | ||
|
|
||
| def test_negative_exponent_error(self): | ||
| """A negative exponent is rejected""" | ||
| mat_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) | ||
| mat = self.SimpleArray(array=mat_data) | ||
|
|
||
| with self.assertRaisesRegex( | ||
| ValueError, | ||
| r"SimpleArray::pow\(\): exponent must be non-negative, " | ||
| r"but got -1"): | ||
| mat.pow(-1) | ||
|
|
||
| def test_non_square_error(self): | ||
| """A non-square matrix cannot be raised to a power""" | ||
| mat_data = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Power requires square matrices. |
||
| dtype=self.dtype) | ||
| mat = self.SimpleArray(array=mat_data) | ||
|
|
||
| with self.assertRaisesRegex( | ||
| RuntimeError, | ||
| r"SimpleArray::pow\(\): operation requires square " | ||
| r"SimpleArray, but got 2x3 shape"): | ||
| mat.pow(2) | ||
|
|
||
| def test_non_2d_error(self): | ||
| """A non-2D SimpleArray cannot be raised to a power""" | ||
| # 1D, 3D, and 4D arrays must all be rejected, and the error must | ||
| # report the offending dimensionality. | ||
| shapes = [(3,), (2, 2, 2), (2, 2, 2, 2)] | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Matrices must be 2D arrays. |
||
|
|
||
| for shape in shapes: | ||
| ndim = len(shape) | ||
| with self.subTest(ndim=ndim): | ||
| mat = self.SimpleArray( | ||
| array=np.ones(shape, dtype=self.dtype)) | ||
| with self.assertRaisesRegex( | ||
| RuntimeError, | ||
| r"SimpleArray::pow\(\): operation requires 2D " | ||
| r"SimpleArray, but got %dD SimpleArray" % ndim): | ||
| mat.pow(2) | ||
|
|
||
|
|
||
| class MatrixPowerFloat32TC(MatrixPowerTestBase, unittest.TestCase): | ||
| def setUp(self): | ||
| self.dtype = np.float32 | ||
| self.SimpleArray = mm.SimpleArrayFloat32 | ||
|
|
||
|
|
||
| class MatrixPowerFloat64TC(MatrixPowerTestBase, unittest.TestCase): | ||
| def setUp(self): | ||
| self.dtype = np.float64 | ||
| self.SimpleArray = mm.SimpleArrayFloat64 | ||
|
|
||
|
|
||
| class MatmulFloat32TC(MatmulTestBase, unittest.TestCase): | ||
| def setUp(self): | ||
| self.dtype = np.float32 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Raise exceptions.