Skip to content

Commit 8a976dd

Browse files
authored
Fboemer/reference inv ntt (#89)
* Add reference radix-2 Inv NTT
1 parent eecc3bf commit 8a976dd

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

hexl/ntt/ntt-internal.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ void ForwardTransformToBitReverseRadix4(
5757
const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1,
5858
uint64_t output_mod_factor = 1);
5959

60-
/// @brief Reference NTT which is written for clarity rather than performance
60+
/// @brief Reference forward NTT which is written for clarity rather than
61+
/// performance
6162
/// @param[in, out] operand Input data. Overwritten with NTT output
6263
/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a
6364
/// power of two.
@@ -68,6 +69,18 @@ void ReferenceForwardTransformToBitReverse(
6869
uint64_t* operand, uint64_t n, uint64_t modulus,
6970
const uint64_t* root_of_unity_powers);
7071

72+
/// @brief Reference inverse NTT which is written for clarity rather than
73+
/// performance
74+
/// @param[in, out] operand Input data. Overwritten with NTT output
75+
/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a
76+
/// power of two.
77+
/// @param[in] modulus Prime modulus. Must satisfy q == 1 mod 2n
78+
/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity in
79+
/// F_q. In bit-reversed order.
80+
void ReferenceInverseTransformFromBitReverse(
81+
uint64_t* operand, uint64_t n, uint64_t modulus,
82+
const uint64_t* inv_root_of_unity_powers);
83+
7184
/// @brief Radix-2 native C++ NTT implementation of the inverse NTT
7285
/// @param[out] result Output data. Overwritten with NTT output
7386
/// @param[in] operand Input data.

hexl/ntt/ntt-radix-2.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,43 @@ void ReferenceForwardTransformToBitReverse(
290290
}
291291
}
292292

293+
void ReferenceInverseTransformFromBitReverse(
294+
uint64_t* operand, uint64_t n, uint64_t modulus,
295+
const uint64_t* inv_root_of_unity_powers) {
296+
HEXL_CHECK(NTT::CheckArguments(n, modulus), "");
297+
HEXL_CHECK(inv_root_of_unity_powers != nullptr,
298+
"inv_root_of_unity_powers == nullptr");
299+
HEXL_CHECK(operand != nullptr, "operand == nullptr");
300+
301+
size_t t = 1;
302+
size_t root_index = 1;
303+
for (size_t m = (n >> 1); m >= 1; m >>= 1) {
304+
size_t j1 = 0;
305+
for (size_t i = 0; i < m; i++, root_index++) {
306+
const uint64_t W = inv_root_of_unity_powers[root_index];
307+
uint64_t* X_r = operand + j1;
308+
uint64_t* Y_r = X_r + t;
309+
for (size_t j = 0; j < t; j++) {
310+
uint64_t X_op = *X_r;
311+
uint64_t Y_op = *Y_r;
312+
// Butterfly X' = (X + Y) mod q, Y' = W(X-Y) mod q
313+
*X_r = AddUIntMod(X_op, Y_op, modulus);
314+
*Y_r = MultiplyMod(W, SubUIntMod(X_op, Y_op, modulus), modulus);
315+
X_r++;
316+
Y_r++;
317+
}
318+
j1 += (t << 1);
319+
}
320+
t <<= 1;
321+
}
322+
323+
// Final multiplication by N^{-1}
324+
const uint64_t inv_n = InverseMod(n, modulus);
325+
for (size_t i = 0; i < n; ++i) {
326+
operand[i] = MultiplyMod(operand[i], inv_n, modulus);
327+
}
328+
}
329+
293330
void InverseTransformFromBitReverseRadix2(
294331
uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus,
295332
const uint64_t* inv_root_of_unity_powers,

test/test-ntt.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,9 @@ TEST_P(DegreeModulusInputOutput, API) {
255255
ReferenceForwardTransformToBitReverse(input.data(), N, modulus,
256256
ntt.GetRootOfUnityPowers().data());
257257
AssertEqual(input, exp_output);
258+
ReferenceInverseTransformFromBitReverse(input.data(), N, modulus,
259+
ntt.GetInvRootOfUnityPowers().data());
260+
AssertEqual(input, input_copy);
258261

259262
// Test round-trip
260263
input = input_copy;
@@ -448,6 +451,22 @@ TEST_P(NttNativeTest, InverseRadix4Random) {
448451
AssertEqual(input, input_radix4);
449452
}
450453

454+
TEST_P(NttNativeTest, InverseRadix2Random) {
455+
auto input = GenerateInsecureUniformRandomValues(m_N, 1, 2);
456+
auto input_reference = input;
457+
458+
InverseTransformFromBitReverseRadix2(
459+
input.data(), input.data(), m_N, m_modulus,
460+
m_ntt.GetInvRootOfUnityPowers().data(),
461+
m_ntt.GetPrecon64InvRootOfUnityPowers().data(), 2, 1);
462+
463+
ReferenceInverseTransformFromBitReverse(
464+
input_reference.data(), m_N, m_modulus,
465+
m_ntt.GetInvRootOfUnityPowers().data());
466+
467+
AssertEqual(input, input_reference);
468+
}
469+
451470
INSTANTIATE_TEST_SUITE_P(
452471
NTT, NttNativeTest,
453472
::testing::Combine(

0 commit comments

Comments
 (0)