diff --git a/crypto/multiexp/src/batch.rs b/crypto/multiexp/src/batch.rs index 24384596..945c8bcb 100644 --- a/crypto/multiexp/src/batch.rs +++ b/crypto/multiexp/src/batch.rs @@ -130,6 +130,14 @@ where // If 1 and 2 were valid, this would've only taken 2 rounds to complete // To prevent this from being gamed, if there's an odd number of elements, randomize which // side the split occurs on + + // This does risk breaking determinism + // The concern is if the select split point causes different paths to be taken when multiple + // invalid elements exist + // While the split point may move an element from the right to the left, always choosing the + // left side (if it's invalid) means this will still always return the left-most, + // invalid element + if slice.len() % 2 == 1 { split += usize::try_from(split_side & 1).unwrap(); split_side >>= 1; diff --git a/crypto/multiexp/src/lib.rs b/crypto/multiexp/src/lib.rs index d0a4542a..321c5b87 100644 --- a/crypto/multiexp/src/lib.rs +++ b/crypto/multiexp/src/lib.rs @@ -169,6 +169,7 @@ where match algorithm(pairs.len()) { Algorithm::Null => Group::identity(), Algorithm::Single => pairs[0].1 * pairs[0].0, + // These functions panic if called without any pairs Algorithm::Straus(window) => straus(pairs, window), Algorithm::Pippenger(window) => pippenger(pairs, window), } diff --git a/crypto/multiexp/src/tests/batch.rs b/crypto/multiexp/src/tests/batch.rs new file mode 100644 index 00000000..02331a7b --- /dev/null +++ b/crypto/multiexp/src/tests/batch.rs @@ -0,0 +1,94 @@ +use rand_core::OsRng; + +use zeroize::Zeroize; + +use rand_core::RngCore; + +use ff::{Field, PrimeFieldBits}; +use group::Group; + +use crate::BatchVerifier; + +pub(crate) fn test_batch() +where + G::Scalar: PrimeFieldBits + Zeroize, +{ + let valid = |batch: BatchVerifier<_, G>| { + assert!(batch.verify()); + assert!(batch.verify_vartime()); + assert_eq!(batch.blame_vartime(), None); + assert_eq!(batch.verify_with_vartime_blame(), Ok(())); + assert_eq!(batch.verify_vartime_with_vartime_blame(), Ok(())); + }; + + let invalid = |batch: BatchVerifier<_, G>, id| { + assert!(!batch.verify()); + assert!(!batch.verify_vartime()); + assert_eq!(batch.blame_vartime(), Some(id)); + assert_eq!(batch.verify_with_vartime_blame(), Err(id)); + assert_eq!(batch.verify_vartime_with_vartime_blame(), Err(id)); + }; + + // Test an empty batch + let batch = BatchVerifier::new(0); + valid(batch); + + // Test a batch with one set of statements + let valid_statements = + vec![(-G::Scalar::one(), G::generator()), (G::Scalar::one(), G::generator())]; + let mut batch = BatchVerifier::new(1); + batch.queue(&mut OsRng, 0, valid_statements.clone()); + valid(batch); + + // Test a batch with an invalid set of statements fails properly + let invalid_statements = vec![(-G::Scalar::one(), G::generator())]; + let mut batch = BatchVerifier::new(1); + batch.queue(&mut OsRng, 0, invalid_statements.clone()); + invalid(batch, 0); + + // Test blame can properly identify faulty participants + // Run with 17 statements, rotating which one is faulty + for i in 0 .. 17 { + let mut batch = BatchVerifier::new(17); + for j in 0 .. 17 { + batch.queue( + &mut OsRng, + j, + if i == j { invalid_statements.clone() } else { valid_statements.clone() }, + ); + } + invalid(batch, i); + } + + // Test blame always identifies the left-most invalid statement + for i in 1 .. 32 { + for j in 1 .. i { + let mut batch = BatchVerifier::new(j); + let mut leftmost = None; + + // Create j statements + for k in 0 .. j { + batch.queue( + &mut OsRng, + k, + // The usage of i / 10 makes this less likely to add invalid elements, and increases + // the space between them + // For high i values, yet low j values, this will make it likely that random elements + // are at/near the end + if ((OsRng.next_u64() % u64::try_from(1 + (i / 4)).unwrap()) == 0) || + (leftmost.is_none() && (k == (j - 1))) + { + if leftmost.is_none() { + leftmost = Some(k); + } + invalid_statements.clone() + } else { + valid_statements.clone() + }, + ); + } + + invalid(batch, leftmost.unwrap()); + } + } +} diff --git a/crypto/multiexp/src/tests/mod.rs b/crypto/multiexp/src/tests/mod.rs index b587d356..550da1c0 100644 --- a/crypto/multiexp/src/tests/mod.rs +++ b/crypto/multiexp/src/tests/mod.rs @@ -10,7 +10,12 @@ use group::Group; use k256::ProjectivePoint; use dalek_ff_group::EdwardsPoint; -use crate::{straus, pippenger, multiexp, multiexp_vartime}; +use crate::{straus, straus_vartime, pippenger, pippenger_vartime, multiexp, multiexp_vartime}; + +#[cfg(feature = "batch")] +mod batch; +#[cfg(feature = "batch")] +use batch::test_batch; #[allow(dead_code)] fn benchmark_internal(straus_bool: bool) @@ -85,26 +90,59 @@ fn test_multiexp() where G::Scalar: PrimeFieldBits + Zeroize, { + let test = |pairs: &[_], sum| { + // These should automatically determine the best algorithm + assert_eq!(multiexp(pairs), sum); + assert_eq!(multiexp_vartime(pairs), sum); + + // Also explicitly test straus/pippenger for each bit size + if !pairs.is_empty() { + for window in 1 .. 8 { + assert_eq!(straus(pairs, window), sum); + assert_eq!(straus_vartime(pairs, window), sum); + assert_eq!(pippenger(pairs, window), sum); + assert_eq!(pippenger_vartime(pairs, window), sum); + } + } + }; + + // Test an empty multiexp is identity + test(&[], G::identity()); + + // Test an multiexp of identity/zero elements is identity + test(&[(G::Scalar::zero(), G::generator())], G::identity()); + test(&[(G::Scalar::one(), G::identity())], G::identity()); + + // Test a variety of multiexp sizes let mut pairs = Vec::with_capacity(1000); let mut sum = G::identity(); for _ in 0 .. 10 { + // Test a multiexp of a single item + // On successive loop iterations, this will test a multiexp with an odd number of pairs + pairs.push((G::Scalar::random(&mut OsRng), G::generator() * G::Scalar::random(&mut OsRng))); + sum += pairs[pairs.len() - 1].1 * pairs[pairs.len() - 1].0; + test(&pairs, sum); + for _ in 0 .. 100 { pairs.push((G::Scalar::random(&mut OsRng), G::generator() * G::Scalar::random(&mut OsRng))); sum += pairs[pairs.len() - 1].1 * pairs[pairs.len() - 1].0; } - assert_eq!(multiexp(&pairs), sum); - assert_eq!(multiexp_vartime(&pairs), sum); + test(&pairs, sum); } } #[test] fn test_secp256k1() { test_multiexp::(); + #[cfg(feature = "batch")] + test_batch::(); } #[test] fn test_ed25519() { test_multiexp::(); + #[cfg(feature = "batch")] + test_batch::(); } #[ignore]