Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions fastcrypto-tbls/src/fast_mult.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) 2022, Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use fastcrypto::groups::Scalar;
use std::borrow::Borrow;

/// Multiply x.1 with y using u128s if possible, otherwise convert x.1 to the group element and multiply.
/// Invariant: If res = fast_mult(x1, x2, y) then x.0 * x.1 * y = res.0 * res.1.
pub(crate) fn fast_mult<C: Scalar>(x: (C, u128), y: u128) -> (C, u128) {
if x.1.leading_zeros() >= (128 - y.leading_zeros()) {
(x.0, x.1 * y)
} else {
(x.0 * C::from(x.1), y)
}
}

/// Compute initial * \prod factors.
pub(crate) fn fast_product<C: Scalar>(initial: C, factors: impl Iterator<Item = u128>) -> C {
let (result, remaining) = factors.fold((initial, 1), |acc, factor| {
debug_assert_ne!(factor, 0);
fast_mult(acc, factor)
});
debug_assert_ne!(remaining, 0);
result * C::ScalarType::from(remaining)
}

/// Compute initial * (terms_0 - base) * (terms_1 - base)...
pub(crate) fn fast_product_of_differences<C: Scalar>(
initial: C,
base: u128,
terms: impl Iterator<Item = impl Borrow<u128>>,
) -> C {
let mut negative = false;
let mut result = fast_product(
initial,
terms.map(|term| {
let term = term.borrow();
if base > *term {
negative = !negative;
base - term
} else {
term - base
}
}),
);
if negative {
result = -result;
};
result
}
1 change: 1 addition & 0 deletions fastcrypto-tbls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub mod nodes_tests;
// #[path = "tests/nidkg_tests.rs"]
// pub mod nidkg_tests;

mod fast_mult;
#[cfg(test)]
#[path = "tests/nizk_tests.rs"]
pub mod nizk_tests;
79 changes: 16 additions & 63 deletions fastcrypto-tbls/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// modified for our needs.
//

use crate::fast_mult::fast_product_of_differences;
use crate::types::{to_scalar, IndexedValue, ShareIndex};
use fastcrypto::error::{FastCryptoError, FastCryptoResult};
use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar};
Expand Down Expand Up @@ -169,29 +170,6 @@ impl<C: GroupElement> Poly<C> {
))
}

/// Multiply x.1 with y using u128s if possible, otherwise convert x.1 to the group element and multiply.
/// Invariant: If res = fast_mult(x1, x2, y) then x.0 * x.1 * y = res.0 * res.1.
pub(crate) fn fast_mult(x: (C::ScalarType, u128), y: u128) -> (C::ScalarType, u128) {
if x.1.leading_zeros() >= (128 - y.leading_zeros()) {
(x.0, x.1 * y)
} else {
(x.0 * C::ScalarType::from(x.1), y)
}
}

/// Compute initial * \prod factors.
pub(crate) fn fast_product(
initial: C::ScalarType,
factors: impl Iterator<Item = u128>,
) -> C::ScalarType {
let (result, remaining) = factors.fold((initial, 1), |acc, factor| {
debug_assert_ne!(factor, 0);
Self::fast_mult(acc, factor)
});
debug_assert_ne!(remaining, 0);
result * C::ScalarType::from(remaining)
}

fn get_lagrange_coefficients_for_c0(
t: u16,
shares: impl Iterator<Item = impl Borrow<Eval<C>>>,
Expand All @@ -212,34 +190,21 @@ impl<C: GroupElement> Poly<C> {
}

let x_as_scalar = C::ScalarType::from(x);
let full_numerator = C::ScalarType::product(
indices
.iter()
.map(|i| C::ScalarType::from(*i) - x_as_scalar),
);
let full_numerator =
fast_product_of_differences(C::ScalarType::generator(), x, indices.iter());

Ok((
full_numerator,
indices
.iter()
.map(|i| {
let mut negative = false;
let mut denominator = Self::fast_product(
C::ScalarType::from(*i) - x_as_scalar,
indices.iter().filter(|j| *j != i).map(|j| {
if i > j {
negative = !negative;
i - j
} else {
// i < j (but not equal)
j - i
}
}),
);
if negative {
denominator = -denominator;
}
denominator.inverse().expect("safe since i != j")
.map(|&i| {
fast_product_of_differences(
C::ScalarType::from(i) - x_as_scalar,
i,
indices.iter().filter(|&j| *j != i),
)
.inverse()
.expect("safe since i != j")
})
.collect(),
))
Expand Down Expand Up @@ -374,28 +339,16 @@ impl<C: Scalar> Poly<C> {
full_numerator *= MonicLinear(-to_scalar::<C>(point.index));
}

Ok(Poly::sum(points.iter().enumerate().map(|(i, p_i)| {
Ok(Poly::sum(points.iter().map(|p_i| {
let x_i = p_i.index.get() as u128;
let mut negative = false;
let mut denominator = Self::fast_product(
let denominator = fast_product_of_differences(
C::ScalarType::generator(),
x_i,
points
.iter()
.enumerate()
.filter(|(j, _)| *j != i)
.map(|(_, p_j)| {
let x_j = p_j.index.get() as u128;
if x_i > x_j {
negative = !negative;
x_i - x_j
} else {
x_j - x_i
}
}),
.map(|p_j| p_j.index.get() as u128)
.filter(|&x_j| x_j != x_i),
);
if negative {
denominator = -denominator;
}
(&full_numerator / MonicLinear(-to_scalar::<C>(p_i.index)))
* &(p_i.value / denominator).unwrap()
})))
Expand Down
25 changes: 7 additions & 18 deletions fastcrypto-tbls/src/tests/polynomial_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ mod scalar_tests {
#[generic_tests::define]
mod points_tests {
use super::*;
use crate::fast_mult::fast_mult;

#[test]
fn test_eval_and_commit<G: GroupElement>() {
Expand Down Expand Up @@ -308,44 +309,32 @@ mod points_tests {
let x = 1u128 << 109; // 110 bit set
let y = 1u128 << 17; // 18 bit set
assert!(
Poly::<G::ScalarType>::fast_mult((G::ScalarType::generator(), x), y)
== (G::ScalarType::generator(), x * y)
fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::generator(), x * y)
);

let x = 1u128 << 17;
let y = 1u128 << 109;
assert!(
Poly::<G::ScalarType>::fast_mult((G::ScalarType::generator(), x), y)
== (G::ScalarType::generator(), x * y)
fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::generator(), x * y)
);

let x = 1u128 << (109 - 1); // all 109 bits set
let y = 1u128 << (19 - 1); // all 19 bits set
assert!(
Poly::<G::ScalarType>::fast_mult((G::ScalarType::generator(), x), y)
== (G::ScalarType::generator(), x * y)
fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::generator(), x * y)
);

let x = 1u128 << 120;
let y = 1u128 << 13;
assert!(
Poly::<G::ScalarType>::fast_mult((G::ScalarType::generator(), x), y)
== (G::ScalarType::from(x), y)
);
assert!(fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::from(x), y));

let x = 1u128 << 21;
let y = 1u128 << 120;
assert!(
Poly::<G::ScalarType>::fast_mult((G::ScalarType::generator(), x), y)
== (G::ScalarType::from(x), y)
);
assert!(fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::from(x), y));

let x = u128::MAX;
let y = 1u128;
assert!(
Poly::<G::ScalarType>::fast_mult((G::ScalarType::generator(), x), y)
== (G::ScalarType::from(x), y)
);
assert!(fast_mult((G::ScalarType::generator(), x), y) == (G::ScalarType::from(x), y));
}

#[instantiate_tests(<RistrettoPoint>)]
Expand Down
Loading