|
10 | 10 | #include "hexl/logging/logging.hpp" |
11 | 11 | #include "hexl/number-theory/number-theory.hpp" |
12 | 12 | #include "test-util.hpp" |
| 13 | +#include "util/util-internal.hpp" |
13 | 14 |
|
14 | 15 | namespace intel { |
15 | 16 | namespace hexl { |
@@ -79,5 +80,48 @@ TEST(EltwiseReduceMod, 4_2) { |
79 | 80 | CheckEqual(result, exp_out); |
80 | 81 | } |
81 | 82 |
|
| 83 | +// First parameter is the number of bits in the modulus |
| 84 | +// Second parameter is whether or not to prefer small moduli |
| 85 | +class EltwiseReduceModTest |
| 86 | + : public ::testing::TestWithParam<std::tuple<uint64_t, bool>> { |
| 87 | + protected: |
| 88 | + void SetUp() override { |
| 89 | + m_modulus_bits = std::get<0>(GetParam()); |
| 90 | + m_prefer_small_primes = std::get<1>(GetParam()); |
| 91 | + m_modulus = GeneratePrimes(1, m_modulus_bits, m_prefer_small_primes)[0]; |
| 92 | + } |
| 93 | + |
| 94 | + void TearDown() override {} |
| 95 | + |
| 96 | + public: |
| 97 | + uint64_t m_N{1024 + 7}; // m_N % 8 = 7 to test AVX512 boundary case |
| 98 | + uint64_t m_modulus_bits; |
| 99 | + bool m_prefer_small_primes; |
| 100 | + uint64_t m_modulus; |
| 101 | +}; |
| 102 | + |
| 103 | +// Test public API matches Native implementation on random values |
| 104 | +TEST_P(EltwiseReduceModTest, Random) { |
| 105 | + uint64_t upper_bound = |
| 106 | + m_modulus < (1ULL << 32) ? m_modulus * m_modulus : 1ULL << 63; |
| 107 | + |
| 108 | + auto input = GenerateInsecureUniformRandomValues(m_N, 0, upper_bound); |
| 109 | + std::vector<uint64_t> result_native(m_N, 0); |
| 110 | + std::vector<uint64_t> result_public_api(m_N, 0); |
| 111 | + |
| 112 | + EltwiseReduceModNative(result_native.data(), input.data(), m_N, m_modulus, |
| 113 | + m_modulus, 1); |
| 114 | + EltwiseReduceMod(result_public_api.data(), input.data(), m_N, m_modulus, |
| 115 | + m_modulus, 1); |
| 116 | + AssertEqual(result_native, result_public_api); |
| 117 | +} |
| 118 | + |
| 119 | +INSTANTIATE_TEST_SUITE_P( |
| 120 | + EltwiseReduceMod, EltwiseReduceModTest, |
| 121 | + ::testing::Combine(::testing::ValuesIn(AlignedVector64<uint64_t>{ |
| 122 | + 20, 25, 30, 31, 32, 33, 35, 40, 48, 49, 50, 51, 52, |
| 123 | + 55, 58, 59, 60}), |
| 124 | + ::testing::ValuesIn(std::vector<bool>{false, true}))); |
| 125 | + |
82 | 126 | } // namespace hexl |
83 | 127 | } // namespace intel |
0 commit comments