Skip to content

Commit accf7a5

Browse files
authored
Fix EltwiseReduceMod (#90)
* Fix AVX512DQ EltwiseReduceMod
1 parent 8a976dd commit accf7a5

File tree

2 files changed

+56
-7
lines changed

2 files changed

+56
-7
lines changed

hexl/eltwise/eltwise-reduce-mod.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,23 @@ void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n,
9797
}
9898
return;
9999
}
100+
101+
#ifdef HEXL_HAS_AVX512IFMA
102+
if (has_avx512ifma && modulus < (1ULL << 52)) {
103+
EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor,
104+
output_mod_factor);
105+
return;
106+
}
107+
#endif
108+
100109
#ifdef HEXL_HAS_AVX512DQ
101110
if (has_avx512dq) {
102-
if (modulus < (1ULL << 52)) {
103-
EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor,
104-
output_mod_factor);
105-
} else {
106-
EltwiseReduceModAVX512<64>(result, operand, n, modulus, input_mod_factor,
107-
output_mod_factor);
108-
}
111+
EltwiseReduceModAVX512<64>(result, operand, n, modulus, input_mod_factor,
112+
output_mod_factor);
109113
return;
110114
}
111115
#endif
116+
112117
HEXL_VLOG(3, "Calling EltwiseReduceModNative");
113118
EltwiseReduceModNative(result, operand, n, modulus, input_mod_factor,
114119
output_mod_factor);

test/test-eltwise-reduce-mod.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "hexl/logging/logging.hpp"
1111
#include "hexl/number-theory/number-theory.hpp"
1212
#include "test-util.hpp"
13+
#include "util/util-internal.hpp"
1314

1415
namespace intel {
1516
namespace hexl {
@@ -79,5 +80,48 @@ TEST(EltwiseReduceMod, 4_2) {
7980
CheckEqual(result, exp_out);
8081
}
8182

83+
// First parameter is the number of bits in the modulus
84+
// Second parameter is whether or not to prefer small moduli
85+
class EltwiseReduceModTest
86+
: public ::testing::TestWithParam<std::tuple<uint64_t, bool>> {
87+
protected:
88+
void SetUp() override {
89+
m_modulus_bits = std::get<0>(GetParam());
90+
m_prefer_small_primes = std::get<1>(GetParam());
91+
m_modulus = GeneratePrimes(1, m_modulus_bits, m_prefer_small_primes)[0];
92+
}
93+
94+
void TearDown() override {}
95+
96+
public:
97+
uint64_t m_N{1024 + 7}; // m_N % 8 = 7 to test AVX512 boundary case
98+
uint64_t m_modulus_bits;
99+
bool m_prefer_small_primes;
100+
uint64_t m_modulus;
101+
};
102+
103+
// Test public API matches Native implementation on random values
104+
TEST_P(EltwiseReduceModTest, Random) {
105+
uint64_t upper_bound =
106+
m_modulus < (1ULL << 32) ? m_modulus * m_modulus : 1ULL << 63;
107+
108+
auto input = GenerateInsecureUniformRandomValues(m_N, 0, upper_bound);
109+
std::vector<uint64_t> result_native(m_N, 0);
110+
std::vector<uint64_t> result_public_api(m_N, 0);
111+
112+
EltwiseReduceModNative(result_native.data(), input.data(), m_N, m_modulus,
113+
m_modulus, 1);
114+
EltwiseReduceMod(result_public_api.data(), input.data(), m_N, m_modulus,
115+
m_modulus, 1);
116+
AssertEqual(result_native, result_public_api);
117+
}
118+
119+
INSTANTIATE_TEST_SUITE_P(
120+
EltwiseReduceMod, EltwiseReduceModTest,
121+
::testing::Combine(::testing::ValuesIn(AlignedVector64<uint64_t>{
122+
20, 25, 30, 31, 32, 33, 35, 40, 48, 49, 50, 51, 52,
123+
55, 58, 59, 60}),
124+
::testing::ValuesIn(std::vector<bool>{false, true})));
125+
82126
} // namespace hexl
83127
} // namespace intel

0 commit comments

Comments
 (0)