diff --git a/fastcrypto-tbls/src/fast_mult.rs b/fastcrypto-tbls/src/fast_mult.rs new file mode 100644 index 0000000000..29e25810c9 --- /dev/null +++ b/fastcrypto-tbls/src/fast_mult.rs @@ -0,0 +1,50 @@ +// Copyright (c) 2022, Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use fastcrypto::groups::Scalar; +use std::borrow::Borrow; + +/// Multiply x.1 with y using u128s if possible, otherwise convert x.1 to the group element and multiply. +/// Invariant: If res = fast_mult(x1, x2, y) then x.0 * x.1 * y = res.0 * res.1. +pub(crate) fn fast_mult(x: (C, u128), y: u128) -> (C, u128) { + if x.1.leading_zeros() >= (128 - y.leading_zeros()) { + (x.0, x.1 * y) + } else { + (x.0 * C::from(x.1), y) + } +} + +/// Compute initial * \prod factors. +pub(crate) fn fast_product(initial: C, factors: impl Iterator) -> C { + let (result, remaining) = factors.fold((initial, 1), |acc, factor| { + debug_assert_ne!(factor, 0); + fast_mult(acc, factor) + }); + debug_assert_ne!(remaining, 0); + result * C::ScalarType::from(remaining) +} + +/// Compute initial * (terms_0 - base) * (terms_1 - base)... +pub(crate) fn fast_product_of_differences( + initial: C, + base: u128, + terms: impl Iterator>, +) -> C { + let mut negative = false; + let mut result = fast_product( + initial, + terms.map(|term| { + let term = term.borrow(); + if base > *term { + negative = !negative; + base - term + } else { + term - base + } + }), + ); + if negative { + result = -result; + }; + result +} diff --git a/fastcrypto-tbls/src/lib.rs b/fastcrypto-tbls/src/lib.rs index 00451a4c43..9ad95859bb 100644 --- a/fastcrypto-tbls/src/lib.rs +++ b/fastcrypto-tbls/src/lib.rs @@ -56,6 +56,7 @@ pub mod nodes_tests; // #[path = "tests/nidkg_tests.rs"] // pub mod nidkg_tests; +mod fast_mult; #[cfg(test)] #[path = "tests/nizk_tests.rs"] pub mod nizk_tests; diff --git a/fastcrypto-tbls/src/polynomial.rs b/fastcrypto-tbls/src/polynomial.rs index 24edaa1006..039cfe19e8 100644 --- a/fastcrypto-tbls/src/polynomial.rs +++ b/fastcrypto-tbls/src/polynomial.rs @@ -5,6 +5,7 @@ // modified for our needs. // +use crate::fast_mult::fast_product_of_differences; use crate::types::{to_scalar, IndexedValue, ShareIndex}; use fastcrypto::error::{FastCryptoError, FastCryptoResult}; use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar}; @@ -169,29 +170,6 @@ impl Poly { )) } - /// Multiply x.1 with y using u128s if possible, otherwise convert x.1 to the group element and multiply. - /// Invariant: If res = fast_mult(x1, x2, y) then x.0 * x.1 * y = res.0 * res.1. - pub(crate) fn fast_mult(x: (C::ScalarType, u128), y: u128) -> (C::ScalarType, u128) { - if x.1.leading_zeros() >= (128 - y.leading_zeros()) { - (x.0, x.1 * y) - } else { - (x.0 * C::ScalarType::from(x.1), y) - } - } - - /// Compute initial * \prod factors. - pub(crate) fn fast_product( - initial: C::ScalarType, - factors: impl Iterator, - ) -> C::ScalarType { - let (result, remaining) = factors.fold((initial, 1), |acc, factor| { - debug_assert_ne!(factor, 0); - Self::fast_mult(acc, factor) - }); - debug_assert_ne!(remaining, 0); - result * C::ScalarType::from(remaining) - } - fn get_lagrange_coefficients_for_c0( t: u16, shares: impl Iterator>>, @@ -212,34 +190,21 @@ impl Poly { } let x_as_scalar = C::ScalarType::from(x); - let full_numerator = C::ScalarType::product( - indices - .iter() - .map(|i| C::ScalarType::from(*i) - x_as_scalar), - ); + let full_numerator = + fast_product_of_differences(C::ScalarType::generator(), x, indices.iter()); Ok(( full_numerator, indices .iter() - .map(|i| { - let mut negative = false; - let mut denominator = Self::fast_product( - C::ScalarType::from(*i) - x_as_scalar, - indices.iter().filter(|j| *j != i).map(|j| { - if i > j { - negative = !negative; - i - j - } else { - // i < j (but not equal) - j - i - } - }), - ); - if negative { - denominator = -denominator; - } - denominator.inverse().expect("safe since i != j") + .map(|&i| { + fast_product_of_differences( + C::ScalarType::from(i) - x_as_scalar, + i, + indices.iter().filter(|&j| *j != i), + ) + .inverse() + .expect("safe since i != j") }) .collect(), )) @@ -374,28 +339,16 @@ impl Poly { full_numerator *= MonicLinear(-to_scalar::(point.index)); } - Ok(Poly::sum(points.iter().enumerate().map(|(i, p_i)| { + Ok(Poly::sum(points.iter().map(|p_i| { let x_i = p_i.index.get() as u128; - let mut negative = false; - let mut denominator = Self::fast_product( + let denominator = fast_product_of_differences( C::ScalarType::generator(), + x_i, points .iter() - .enumerate() - .filter(|(j, _)| *j != i) - .map(|(_, p_j)| { - let x_j = p_j.index.get() as u128; - if x_i > x_j { - negative = !negative; - x_i - x_j - } else { - x_j - x_i - } - }), + .map(|p_j| p_j.index.get() as u128) + .filter(|&x_j| x_j != x_i), ); - if negative { - denominator = -denominator; - } (&full_numerator / MonicLinear(-to_scalar::(p_i.index))) * &(p_i.value / denominator).unwrap() }))) diff --git a/fastcrypto-tbls/src/tests/polynomial_tests.rs b/fastcrypto-tbls/src/tests/polynomial_tests.rs index 1550e889ab..32c1e9350f 100644 --- a/fastcrypto-tbls/src/tests/polynomial_tests.rs +++ b/fastcrypto-tbls/src/tests/polynomial_tests.rs @@ -231,6 +231,7 @@ mod scalar_tests { #[generic_tests::define] mod points_tests { use super::*; + use crate::fast_mult::fast_mult; #[test] fn test_eval_and_commit() { @@ -308,44 +309,32 @@ mod points_tests { let x = 1u128 << 109; // 110 bit set let y = 1u128 << 17; // 18 bit set assert!( - Poly::::fast_mult((G::ScalarType::generator(), x), y) - == (G::ScalarType::generator(), x * y) + fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::generator(), x * y) ); let x = 1u128 << 17; let y = 1u128 << 109; assert!( - Poly::::fast_mult((G::ScalarType::generator(), x), y) - == (G::ScalarType::generator(), x * y) + fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::generator(), x * y) ); let x = 1u128 << (109 - 1); // all 109 bits set let y = 1u128 << (19 - 1); // all 19 bits set assert!( - Poly::::fast_mult((G::ScalarType::generator(), x), y) - == (G::ScalarType::generator(), x * y) + fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::generator(), x * y) ); let x = 1u128 << 120; let y = 1u128 << 13; - assert!( - Poly::::fast_mult((G::ScalarType::generator(), x), y) - == (G::ScalarType::from(x), y) - ); + assert!(fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::from(x), y)); let x = 1u128 << 21; let y = 1u128 << 120; - assert!( - Poly::::fast_mult((G::ScalarType::generator(), x), y) - == (G::ScalarType::from(x), y) - ); + assert!(fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::from(x), y)); let x = u128::MAX; let y = 1u128; - assert!( - Poly::::fast_mult((G::ScalarType::generator(), x), y) - == (G::ScalarType::from(x), y) - ); + assert!(fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::from(x), y)); } #[instantiate_tests()]