Skip to content

Commit eecc3bf

Browse files
authored
Adding a clear top 12 bit utility (#87)
* Testing clearTop12b function * Adding ClearTop12b_64 function * Fix typo * Converting into template function * Fixing formula * Fixing description * Adding template param descriptors
1 parent 49c44ac commit eecc3bf

File tree

4 files changed

+58
-67
lines changed

4 files changed

+58
-67
lines changed

benchmark/bench-eltwise-reduce-mod.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,9 @@ static void BM_EltwiseReduceModMontAVX512BitShift52LT(
239239
AlignedVector64<uint64_t> output(input_size, 0);
240240

241241
for (auto _ : state) {
242-
EltwiseMontReduceModAVX512<52>(output.data(), input_a.data(),
243-
input_b.data(), input_size, modulus, inv_mod,
244-
r);
242+
EltwiseMontReduceModAVX512<52, 46>(output.data(), input_a.data(),
243+
input_b.data(), input_size, modulus,
244+
inv_mod);
245245
}
246246
}
247247

@@ -266,8 +266,8 @@ static void BM_EltwiseReduceModMontFormAVX512BitShift52LT(
266266
AlignedVector64<uint64_t> output(input_size, 0);
267267

268268
for (auto _ : state) {
269-
EltwiseMontgomeryFormAVX512<52>(output.data(), input_a.data(), R2_mod_q,
270-
input_size, modulus, inv_mod, r);
269+
EltwiseMontgomeryFormAVX512<52, 46>(output.data(), input_a.data(), R2_mod_q,
270+
input_size, modulus, inv_mod);
271271
}
272272
}
273273

@@ -292,8 +292,8 @@ static void BM_EltwiseReduceModMontFormAVX512BitShift64LT(
292292
AlignedVector64<uint64_t> output(input_size, 0);
293293

294294
for (auto _ : state) {
295-
EltwiseMontgomeryFormAVX512<64>(output.data(), input_a.data(), R2_mod_q,
296-
input_size, modulus, inv_mod, r);
295+
EltwiseMontgomeryFormAVX512<64, 46>(output.data(), input_a.data(), R2_mod_q,
296+
input_size, modulus, inv_mod);
297297
}
298298
}
299299

@@ -318,10 +318,10 @@ static void BM_EltwiseReduceModInOutMontFormAVX512BitShift52LT(
318318
AlignedVector64<uint64_t> output(input_size, 0);
319319

320320
for (auto _ : state) {
321-
EltwiseMontgomeryFormAVX512<52>(output.data(), input_a.data(), R2_mod_q,
322-
input_size, modulus, inv_mod, r);
323-
EltwiseMontgomeryFormAVX512<52>(output.data(), output.data(), 1ULL,
324-
input_size, modulus, inv_mod, r);
321+
EltwiseMontgomeryFormAVX512<52, 46>(output.data(), input_a.data(), R2_mod_q,
322+
input_size, modulus, inv_mod);
323+
EltwiseMontgomeryFormAVX512<52, 46>(output.data(), output.data(), 1ULL,
324+
input_size, modulus, inv_mod);
325325
}
326326
}
327327

hexl/eltwise/eltwise-reduce-mod-avx512.hpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,20 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand,
141141

142142
/// @brief Returns Montgomery form of modular product ab mod q, computed via the
143143
/// REDC algorithm, also known as Montgomery reduction.
144+
/// @tparam BitShift denotes the operational length, in bits, of the operands
145+
/// and result values.
146+
/// @tparam r defines the value of R, being R = 2^r. R > modulus.
144147
/// @param[in] a input vector. T = ab in the range [0, Rq − 1].
145148
/// @param[in] b input vector.
146-
/// @param[in] r 2 pow r is R.
147-
/// @param[in] modulus with R = 2^r such that gcd(R, modulus) = 1. R > modulus.
149+
/// @param[in] modulus such that gcd(R, modulus) = 1.
148150
/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
149151
/// @param[in] n number of elements in input vector.
150152
/// @param[out] result unsigned long int vector in the range [0, q − 1] such
151153
/// that S ≡ TR^−1 mod q
152-
template <int BitShift>
154+
template <int BitShift, int r>
153155
void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
154156
const uint64_t* b, uint64_t n, uint64_t modulus,
155-
uint64_t inv_mod, int r) {
157+
uint64_t inv_mod) {
156158
HEXL_CHECK(a != nullptr, "Require operand a != nullptr");
157159
HEXL_CHECK(b != nullptr, "Require operand b != nullptr");
158160
HEXL_CHECK(n != 0, "Require n != 0");
@@ -192,7 +194,6 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
192194
const __m512i* v_a = reinterpret_cast<const __m512i*>(a);
193195
const __m512i* v_b = reinterpret_cast<const __m512i*>(b);
194196
__m512i* v_result = reinterpret_cast<__m512i*>(result);
195-
__m512i v_mod_R_mask = _mm512_set1_epi64(mod_R_mask);
196197
__m512i v_modulus = _mm512_set1_epi64(modulus);
197198
__m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
198199
__m512i v_prod_rs = _mm512_set1_epi64(prod_rs);
@@ -210,8 +211,8 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
210211
v_T_lo = _mm512_and_epi64(v_T_lo, v_prod_rs);
211212
}
212213

213-
__m512i v_c = _mm512_hexl_montgomery_reduce<BitShift>(
214-
v_T_hi, v_T_lo, v_modulus, r, v_mod_R_mask, v_inv_mod, v_prod_rs);
214+
__m512i v_c = _mm512_hexl_montgomery_reduce<BitShift, r>(
215+
v_T_hi, v_T_lo, v_modulus, v_inv_mod, v_prod_rs);
215216
HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus,
216217
"v_op exceeds bound " << modulus);
217218
_mm512_storeu_si512(v_result, v_c);
@@ -223,18 +224,20 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
223224

224225
/// @brief Returns Montgomery form of a mod q, computed via the REDC algorithm,
225226
/// also known as Montgomery reduction.
227+
/// @tparam BitShift denotes the operational length, in bits, of the operands
228+
/// and result values.
229+
/// @tparam r defines the value of R, being R = 2^r. R > modulus.
226230
/// @param[in] a input vector. T = a(R^2 mod q) in the range [0, Rq − 1].
227231
/// @param[in] R2_mod_q R^2 mod q.
228-
/// @param[in] r 2 pow r is R.
229-
/// @param[in] modulus with R = 2^r such that gcd(R, modulus) = 1. R > modulus.
232+
/// @param[in] modulus such that gcd(R, modulus) = 1.
230233
/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
231234
/// @param[in] n number of elements in input vector.
232235
/// @param[out] result unsigned long int vector in the range [0, q − 1] such
233236
/// that S ≡ TR^−1 mod q
234-
template <int BitShift>
237+
template <int BitShift, int r>
235238
void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a,
236239
uint64_t R2_mod_q, uint64_t n,
237-
uint64_t modulus, uint64_t inv_mod, int r) {
240+
uint64_t modulus, uint64_t inv_mod) {
238241
HEXL_CHECK(a != nullptr, "Require operand a != nullptr");
239242
HEXL_CHECK(n != 0, "Require n != 0");
240243
HEXL_CHECK(modulus > 1, "Require modulus > 1");
@@ -271,7 +274,6 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a,
271274

272275
const __m512i* v_a = reinterpret_cast<const __m512i*>(a);
273276
__m512i* v_result = reinterpret_cast<__m512i*>(result);
274-
__m512i v_mod_R_mask = _mm512_set1_epi64(mod_R_mask);
275277
__m512i v_b = _mm512_set1_epi64(R2_mod_q);
276278
__m512i v_modulus = _mm512_set1_epi64(modulus);
277279
__m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
@@ -289,8 +291,8 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a,
289291
v_T_lo = _mm512_and_epi64(v_T_lo, v_prod_rs);
290292
}
291293

292-
__m512i v_c = _mm512_hexl_montgomery_reduce<BitShift>(
293-
v_T_hi, v_T_lo, v_modulus, r, v_mod_R_mask, v_inv_mod, v_prod_rs);
294+
__m512i v_c = _mm512_hexl_montgomery_reduce<BitShift, r>(
295+
v_T_hi, v_T_lo, v_modulus, v_inv_mod, v_prod_rs);
294296
HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus,
295297
"v_op exceeds bound " << modulus);
296298
_mm512_storeu_si512(v_result, v_c);

hexl/util/avx512-util.hpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ inline std::vector<double> ExtractValues(__m512d x) {
6262
return ret;
6363
}
6464

65+
// Returns lower NumBits bits from a 64-bit value
66+
template <int NumBits>
67+
inline __m512i ClearTopBits64(__m512i x) {
68+
const __m512i low52b_mask = _mm512_set1_epi64((1ULL << NumBits) - 1);
69+
return _mm512_and_epi64(x, low52b_mask);
70+
}
71+
6572
// Multiply packed unsigned BitShift-bit integers in each 64-bit element of x
6673
// and y to form a 2*BitShift-bit intermediate result.
6774
// Returns the high BitShift-bit unsigned integer from the intermediate result
@@ -231,8 +238,7 @@ inline __m512i _mm512_hexl_mullo_add_lo_epi<52>(__m512i x, __m512i y,
231238
__m512i result = _mm512_madd52lo_epu64(x, y, z);
232239

233240
// Clear high 12 bits from result
234-
const __m512i two_pow52_min1 = _mm512_set1_epi64((1ULL << 52) - 1);
235-
result = _mm512_and_epi64(result, two_pow52_min1);
241+
result = ClearTopBits64<52>(result);
236242
return result;
237243
}
238244
#endif
@@ -372,16 +378,15 @@ inline __m512i _mm512_hexl_cmple_epu64(__m512i a, __m512i b,
372378

373379
// Returns Montgomery form of ab mod q, computed via the REDC algorithm,
374380
// also known as Montgomery reduction.
375-
// Inputs: r and q with R = 2^r such that gcd(R, q) = 1. R > q.
381+
// Template: r with R = 2^r
382+
// Inputs: q such that gcd(R, q) = 1. R > q.
376383
// v_inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
377384
// T = ab in the range [0, Rq − 1].
378385
// T_hi and T_lo for BitShift = 64 should be given in 63 bits.
379386
// Output: Integer S in the range [0, q − 1] such that S ≡ TR^−1 mod q
380-
template <int BitShift>
387+
template <int BitShift, int r>
381388
inline __m512i _mm512_hexl_montgomery_reduce(__m512i T_hi, __m512i T_lo,
382-
__m512i q, int r,
383-
__m512i v_mod_R_msk,
384-
__m512i v_inv_mod,
389+
__m512i q, __m512i v_inv_mod,
385390
__m512i v_rs_or_msk) {
386391
HEXL_CHECK(BitShift == 52 || BitShift == 64,
387392
"Invalid bitshift " << BitShift << "; need 52 or 64");
@@ -390,9 +395,9 @@ inline __m512i _mm512_hexl_montgomery_reduce(__m512i T_hi, __m512i T_lo,
390395
if (BitShift == 52) {
391396
// Operation:
392397
// m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask
393-
__m512i m = _mm512_and_epi64(T_lo, v_mod_R_msk);
398+
__m512i m = ClearTopBits64<r>(T_lo);
394399
m = _mm512_hexl_mullo_epi<BitShift>(m, v_inv_mod);
395-
m = _mm512_and_epi64(m, v_mod_R_msk);
400+
m = ClearTopBits64<r>(m);
396401

397402
// Operation: t ← (T + mN) / R = (T + m*q) >> r
398403
// Hi part
@@ -415,9 +420,9 @@ inline __m512i _mm512_hexl_montgomery_reduce(__m512i T_hi, __m512i T_lo,
415420

416421
// Operation:
417422
// m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask
418-
__m512i m = _mm512_and_epi64(T_lo, v_mod_R_msk);
423+
__m512i m = ClearTopBits64<r>(T_lo);
419424
m = _mm512_hexl_mullo_epi<BitShift>(m, v_inv_mod);
420-
m = _mm512_and_epi64(m, v_mod_R_msk);
425+
m = ClearTopBits64<r>(m);
421426

422427
__m512i mq_hi = _mm512_hexl_mulhi_epi<BitShift>(m, q);
423428
__m512i mq_lo = _mm512_hexl_mullo_epi<BitShift>(m, q);
@@ -461,9 +466,7 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q,
461466
if (mask != 0) {
462467
// values above 2^52
463468
__m512i x_hi = _mm512_srli_epi64(x, static_cast<unsigned int>(52ULL));
464-
__m512i x_intr = _mm512_slli_epi64(x, static_cast<unsigned int>(12ULL));
465-
__m512i x_lo =
466-
_mm512_srli_epi64(x_intr, static_cast<unsigned int>(12ULL));
469+
__m512i x_lo = ClearTopBits64<52>(x);
467470

468471
// c1 = floor(U / 2^{n + beta})
469472
__m512i c1_lo =

test/test-avx512-util.cpp

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -379,23 +379,21 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce52) {
379379

380380
uint64_t modulus = 5;
381381
int r = 3;
382-
uint64_t R = (1ULL << r);
383382
uint64_t prod_rs = (1ULL << (52 - r));
384383
uint64_t inv_mod = HenselLemma2adicRoot(r, modulus);
385384

386385
// mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones
387-
__m512i v_mod_R_mask = _mm512_set1_epi64(R - 1);
388386
__m512i v_modulus = _mm512_set1_epi64(modulus);
389387
__m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
390388
__m512i v_prod_rs = _mm512_set1_epi64(prod_rs);
391389

392-
__m512i _c = _mm512_hexl_montgomery_reduce<52>(
393-
T_hi, T_lo, v_modulus, r, v_mod_R_mask, v_inv_mod, v_prod_rs);
390+
__m512i _c = _mm512_hexl_montgomery_reduce<52, 3>(T_hi, T_lo, v_modulus,
391+
v_inv_mod, v_prod_rs);
394392
AssertEqual(_c, expected_out);
395393

396394
// Out of Montgomery form
397-
_c = _mm512_hexl_montgomery_reduce<52>(T_hi, _c, v_modulus, r, v_mod_R_mask,
398-
v_inv_mod, v_prod_rs);
395+
_c = _mm512_hexl_montgomery_reduce<52, 3>(T_hi, _c, v_modulus, v_inv_mod,
396+
v_prod_rs);
399397

400398
AssertEqual(_c, expected_c_out);
401399
}
@@ -419,16 +417,13 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce52) {
419417
// Also, for r = 46 and N = 67280421310725 then N' = 62463730494515
420418
__m512i T_hi = _mm512_set_epi64(559639348720ULL, 0, 0, 0, 0, 0, 0, 0);
421419
__m512i T_lo = _mm512_set_epi64(1832906312477596ULL, 0, 0, 0, 0, 0, 0, 0);
422-
423-
int r = 46;
424420
__m512i v_modulus = _mm512_set1_epi64(67280421310725);
425421
__m512i v_inv_mod = _mm512_set1_epi64(62463730494515);
426-
__m512i v_mod_R_mask = _mm512_set1_epi64(70368744177663);
427422
__m512i v_prod_rs = _mm512_set1_epi64(64);
428423

429424
// 52 bits
430-
__m512i c = _mm512_hexl_montgomery_reduce<52>(
431-
T_hi, T_lo, v_modulus, r, v_mod_R_mask, v_inv_mod, v_prod_rs);
425+
__m512i c = _mm512_hexl_montgomery_reduce<52, 46>(T_hi, T_lo, v_modulus,
426+
v_inv_mod, v_prod_rs);
432427
AssertEqual(c, expected_out);
433428
}
434429

@@ -437,18 +432,16 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce52) {
437432
int r = 51;
438433
uint64_t modulus = 2251799813684809;
439434
uint64_t inv_mod = HenselLemma2adicRoot(r, modulus);
440-
uint64_t mod_R_mask = (1ULL << r) - 1;
441435
uint64_t prod_rs = (1ULL << (52 - r));
442436
__m512i expected_out =
443437
_mm512_set_epi64(1832909426971103, 0, 0, 0, 0, 0, 0, 0);
444438
__m512i T_hi = _mm512_set_epi64(5446ULL, 0, 0, 0, 0, 0, 0, 0);
445439
__m512i T_lo = _mm512_set_epi64(3006504763740625ULL, 0, 0, 0, 0, 0, 0, 0);
446440
__m512i v_modulus = _mm512_set1_epi64(modulus);
447441
__m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
448-
__m512i v_mod_R_mask = _mm512_set1_epi64(mod_R_mask);
449442
__m512i v_prod_rs = _mm512_set1_epi64(prod_rs);
450-
__m512i c = _mm512_hexl_montgomery_reduce<52>(
451-
T_hi, T_lo, v_modulus, r, v_mod_R_mask, v_inv_mod, v_prod_rs);
443+
__m512i c = _mm512_hexl_montgomery_reduce<52, 51>(T_hi, T_lo, v_modulus,
444+
v_inv_mod, v_prod_rs);
452445
AssertEqual(c, expected_out);
453446
}
454447
}
@@ -465,11 +458,8 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce64) {
465458
__m512i expected_out = _mm512_set_epi64(1546598034044, 0, 0, 0, 0, 0, 0, 0);
466459
__m512i T_hi = _mm512_set_epi64(559639348720ULL, 0, 0, 0, 0, 0, 0, 0);
467460
__m512i T_lo = _mm512_set_epi64(1832906312477596ULL, 0, 0, 0, 0, 0, 0, 0);
468-
469-
int r = 46;
470461
__m512i v_modulus = _mm512_set1_epi64(67280421310725);
471462
__m512i v_inv_mod = _mm512_set1_epi64(62463730494515);
472-
__m512i v_mod_R_mask = _mm512_set1_epi64(70368744177663);
473463

474464
// 64 bits
475465
uint64_t prod_rs = (1ULL << 63) - 1;
@@ -478,8 +468,8 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce64) {
478468
T_hi = _mm512_set_epi64(273261400, 0, 0, 0, 0, 0, 0, 0);
479469
T_lo = _mm512_set_epi64(6847304339915631516, 0, 0, 0, 0, 0, 0, 0);
480470

481-
__m512i c = _mm512_hexl_montgomery_reduce<64>(
482-
T_hi, T_lo, v_modulus, r, v_mod_R_mask, v_inv_mod, v_prod_rs);
471+
__m512i c = _mm512_hexl_montgomery_reduce<64, 46>(T_hi, T_lo, v_modulus,
472+
v_inv_mod, v_prod_rs);
483473
AssertEqual(c, expected_out);
484474
}
485475

@@ -488,7 +478,6 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce64) {
488478
int r = 61;
489479
uint64_t modulus = 2305843009213693487;
490480
uint64_t inv_mod = HenselLemma2adicRoot(r, modulus);
491-
uint64_t mod_R_mask = (1ULL << r) - 1ULL;
492481
uint64_t prod_rs = (1ULL << 63) - 1;
493482
__m512i expected_out =
494483
_mm512_set_epi64(59185395909485265, 0, 0, 0, 0, 0, 0, 0);
@@ -497,10 +486,9 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce64) {
497486
_mm512_set_epi64(9074465024201096609ULL, 0, 0, 0, 0, 0, 0, 0);
498487
__m512i v_modulus = _mm512_set1_epi64(modulus);
499488
__m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
500-
__m512i v_mod_R_mask = _mm512_set1_epi64(mod_R_mask);
501489
__m512i v_prod_rs = _mm512_set1_epi64(prod_rs);
502-
__m512i c = _mm512_hexl_montgomery_reduce<64>(
503-
T_hi, T_lo, v_modulus, r, v_mod_R_mask, v_inv_mod, v_prod_rs);
490+
__m512i c = _mm512_hexl_montgomery_reduce<64, 61>(T_hi, T_lo, v_modulus,
491+
v_inv_mod, v_prod_rs);
504492
AssertEqual(c, expected_out);
505493
}
506494

@@ -509,18 +497,16 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce64) {
509497
int r = 62;
510498
uint64_t modulus = 4611686018427387631;
511499
uint64_t inv_mod = HenselLemma2adicRoot(r, modulus);
512-
uint64_t mod_R_mask = (1ULL << r) - 1;
513500
uint64_t prod_rs = (1ULL << 63) - 1;
514501
__m512i expected_out =
515502
_mm512_set_epi64(34747555017826833, 0, 0, 0, 0, 0, 0, 0);
516503
__m512i T_hi = _mm512_set_epi64(1ULL, 0, 0, 0, 0, 0, 0, 0);
517504
__m512i T_lo = _mm512_set_epi64(262710483011949601ULL, 0, 0, 0, 0, 0, 0, 0);
518505
__m512i v_modulus = _mm512_set1_epi64(modulus);
519506
__m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
520-
__m512i v_mod_R_mask = _mm512_set1_epi64(mod_R_mask);
521507
__m512i v_prod_rs = _mm512_set1_epi64(prod_rs);
522-
__m512i c = _mm512_hexl_montgomery_reduce<64>(
523-
T_hi, T_lo, v_modulus, r, v_mod_R_mask, v_inv_mod, v_prod_rs);
508+
__m512i c = _mm512_hexl_montgomery_reduce<64, 62>(T_hi, T_lo, v_modulus,
509+
v_inv_mod, v_prod_rs);
524510
AssertEqual(c, expected_out);
525511
}
526512
}

0 commit comments

Comments
 (0)