From a66994aadec8b51a16b0ca843ca59146c29786b1 Mon Sep 17 00:00:00 2001 From: Luke Parker Date: Sun, 27 Aug 2023 15:33:17 -0400 Subject: [PATCH] Use FCMP implementation of BP+ in monero-serai (#344) * Add in an implementation of BP+ based off the paper, intended for clarity and review This was done as part of my work on FCMPs from Monero, and is copied from https://github.com/kayabaNerve/full-chain-membership-proofs * Remove crate structure of BP+ * Remove arithmetic circuit code * Remove AC/VC generators code * Remove generator transcript Monero uses non-transcripted static generators. * Further trimming of generators * Remove the single range proof It's unused by Monero and accordingly unhelpful. * Work on getting BP+ to compile in its new env * Correct BP+ folder name * Further tweaks to get closer to compiling * Remove the ScalarMatrix file It's only used for AC proofs * Compiles, with tests passing * Lock BP+ to Ed25519 instead of the generic Ciphersuite * Resolve most warnings in BP+ * Make existing bulletproofs test easier to read * Further strip generators * Swap G/H as Monero did * Replace RangeCommitment with Commitment * Hard-code BP+ h to Ed25519's generator * Use pub(crate) for BP+, not pub * Replace initial_transcript with hash_plus * Rename hash_plus to initial_transcript * Finish integrating the FCMP BP+ impl * Move BP+ folder * Correct no-std support * Rename "long_n" to eta * Add note on non-prime order dfg points --- coins/monero/build.rs | 2 +- coins/monero/src/ringct/bulletproofs/mod.rs | 87 ++-- coins/monero/src/ringct/bulletproofs/plus.rs | 310 ------------ .../plus/aggregate_range_proof.rs | 249 ++++++++++ .../src/ringct/bulletproofs/plus/mod.rs | 92 ++++ .../ringct/bulletproofs/plus/point_vector.rs | 50 ++ .../ringct/bulletproofs/plus/scalar_vector.rs | 114 +++++ .../ringct/bulletproofs/plus/transcript.rs | 24 + .../plus/weighted_inner_product.rs | 445 ++++++++++++++++++ .../src/ringct/bulletproofs/scalar_vector.rs | 27 -- .../{bulletproofs.rs => bulletproofs/mod.rs} | 4 +- .../plus/aggregate_range_proof.rs | 30 ++ .../monero/src/tests/bulletproofs/plus/mod.rs | 4 + .../plus/weighted_inner_product.rs | 82 ++++ 14 files changed, 1154 insertions(+), 366 deletions(-) delete mode 100644 coins/monero/src/ringct/bulletproofs/plus.rs create mode 100644 coins/monero/src/ringct/bulletproofs/plus/aggregate_range_proof.rs create mode 100644 coins/monero/src/ringct/bulletproofs/plus/mod.rs create mode 100644 coins/monero/src/ringct/bulletproofs/plus/point_vector.rs create mode 100644 coins/monero/src/ringct/bulletproofs/plus/scalar_vector.rs create mode 100644 coins/monero/src/ringct/bulletproofs/plus/transcript.rs create mode 100644 coins/monero/src/ringct/bulletproofs/plus/weighted_inner_product.rs rename coins/monero/src/tests/{bulletproofs.rs => bulletproofs/mod.rs} (99%) create mode 100644 coins/monero/src/tests/bulletproofs/plus/aggregate_range_proof.rs create mode 100644 coins/monero/src/tests/bulletproofs/plus/mod.rs create mode 100644 coins/monero/src/tests/bulletproofs/plus/weighted_inner_product.rs diff --git a/coins/monero/build.rs b/coins/monero/build.rs index 34c34b6b..a54a3f2d 100644 --- a/coins/monero/build.rs +++ b/coins/monero/build.rs @@ -41,7 +41,7 @@ fn generators(prefix: &'static str, path: &str) { .write_all( format!( " - pub static GENERATORS_CELL: OnceLock = OnceLock::new(); + pub(crate) static GENERATORS_CELL: OnceLock = OnceLock::new(); pub fn GENERATORS() -> &'static Generators {{ GENERATORS_CELL.get_or_init(|| Generators {{ G: [ diff --git a/coins/monero/src/ringct/bulletproofs/mod.rs b/coins/monero/src/ringct/bulletproofs/mod.rs index e7e071f5..6b25b1a0 100644 --- a/coins/monero/src/ringct/bulletproofs/mod.rs +++ b/coins/monero/src/ringct/bulletproofs/mod.rs @@ -19,12 +19,10 @@ pub(crate) mod core; use self::core::LOG_N; pub(crate) mod original; -pub use original::GENERATORS as BULLETPROOFS_GENERATORS; -pub(crate) mod plus; -pub use plus::GENERATORS as BULLETPROOFS_PLUS_GENERATORS; +use self::original::OriginalStruct; -pub(crate) use self::original::OriginalStruct; -pub(crate) use self::plus::PlusStruct; +pub(crate) mod plus; +use self::plus::*; pub(crate) const MAX_OUTPUTS: usize = self::core::MAX_M; @@ -33,7 +31,7 @@ pub(crate) const MAX_OUTPUTS: usize = self::core::MAX_M; #[derive(Clone, PartialEq, Eq, Debug)] pub enum Bulletproofs { Original(OriginalStruct), - Plus(PlusStruct), + Plus(AggregateRangeProof), } impl Bulletproofs { @@ -80,13 +78,22 @@ impl Bulletproofs { outputs: &[Commitment], plus: bool, ) -> Result { + if outputs.is_empty() { + Err(TransactionError::NoOutputs)?; + } if outputs.len() > MAX_OUTPUTS { - return Err(TransactionError::TooManyOutputs)?; + Err(TransactionError::TooManyOutputs)?; } Ok(if !plus { Bulletproofs::Original(OriginalStruct::prove(rng, outputs)) } else { - Bulletproofs::Plus(PlusStruct::prove(rng, outputs)) + use dalek_ff_group::EdwardsPoint as DfgPoint; + Bulletproofs::Plus( + AggregateRangeStatement::new(outputs.iter().map(|com| DfgPoint(com.calculate())).collect()) + .unwrap() + .prove(rng, AggregateRangeWitness::new(outputs).unwrap()) + .unwrap(), + ) }) } @@ -95,7 +102,22 @@ impl Bulletproofs { pub fn verify(&self, rng: &mut R, commitments: &[EdwardsPoint]) -> bool { match self { Bulletproofs::Original(bp) => bp.verify(rng, commitments), - Bulletproofs::Plus(bp) => bp.verify(rng, commitments), + Bulletproofs::Plus(bp) => { + let mut verifier = BatchVerifier::new(1); + // If this commitment is torsioned (which is allowed), this won't be a well-formed + // dfg::EdwardsPoint (expected to be of prime-order) + // The actual BP+ impl will perform a torsion clear though, making this safe + // TODO: Have AggregateRangeStatement take in dalek EdwardsPoint for clarity on this + let Some(statement) = AggregateRangeStatement::new( + commitments.iter().map(|c| dalek_ff_group::EdwardsPoint(*c)).collect(), + ) else { + return false; + }; + if !statement.verify(rng, &mut verifier, (), bp.clone()) { + return false; + } + verifier.verify_vartime() + } } } @@ -112,7 +134,14 @@ impl Bulletproofs { ) -> bool { match self { Bulletproofs::Original(bp) => bp.batch_verify(rng, verifier, id, commitments), - Bulletproofs::Plus(bp) => bp.batch_verify(rng, verifier, id, commitments), + Bulletproofs::Plus(bp) => { + let Some(statement) = AggregateRangeStatement::new( + commitments.iter().map(|c| dalek_ff_group::EdwardsPoint(*c)).collect(), + ) else { + return false; + }; + statement.verify(rng, verifier, id, bp.clone()) + } } } @@ -137,14 +166,14 @@ impl Bulletproofs { } Bulletproofs::Plus(bp) => { - write_point(&bp.A, w)?; - write_point(&bp.A1, w)?; - write_point(&bp.B, w)?; - write_scalar(&bp.r1, w)?; - write_scalar(&bp.s1, w)?; - write_scalar(&bp.d1, w)?; - specific_write_vec(&bp.L, w)?; - specific_write_vec(&bp.R, w) + write_point(&bp.A.0, w)?; + write_point(&bp.wip.A.0, w)?; + write_point(&bp.wip.B.0, w)?; + write_scalar(&bp.wip.r_answer.0, w)?; + write_scalar(&bp.wip.s_answer.0, w)?; + write_scalar(&bp.wip.delta_answer.0, w)?; + specific_write_vec(&bp.wip.L.iter().cloned().map(|L| L.0).collect::>(), w)?; + specific_write_vec(&bp.wip.R.iter().cloned().map(|R| R.0).collect::>(), w) } } } @@ -182,15 +211,19 @@ impl Bulletproofs { /// Read Bulletproofs+. pub fn read_plus(r: &mut R) -> io::Result { - Ok(Bulletproofs::Plus(PlusStruct { - A: read_point(r)?, - A1: read_point(r)?, - B: read_point(r)?, - r1: read_scalar(r)?, - s1: read_scalar(r)?, - d1: read_scalar(r)?, - L: read_vec(read_point, r)?, - R: read_vec(read_point, r)?, + use dalek_ff_group::{Scalar as DfgScalar, EdwardsPoint as DfgPoint}; + + Ok(Bulletproofs::Plus(AggregateRangeProof { + A: DfgPoint(read_point(r)?), + wip: WipProof { + A: DfgPoint(read_point(r)?), + B: DfgPoint(read_point(r)?), + r_answer: DfgScalar(read_scalar(r)?), + s_answer: DfgScalar(read_scalar(r)?), + delta_answer: DfgScalar(read_scalar(r)?), + L: read_vec(read_point, r)?.into_iter().map(DfgPoint).collect(), + R: read_vec(read_point, r)?.into_iter().map(DfgPoint).collect(), + }, })) } } diff --git a/coins/monero/src/ringct/bulletproofs/plus.rs b/coins/monero/src/ringct/bulletproofs/plus.rs deleted file mode 100644 index 4d8d2fce..00000000 --- a/coins/monero/src/ringct/bulletproofs/plus.rs +++ /dev/null @@ -1,310 +0,0 @@ -use std_shims::{vec::Vec, sync::OnceLock}; - -use rand_core::{RngCore, CryptoRng}; - -use zeroize::Zeroize; - -use curve25519_dalek::{scalar::Scalar as DalekScalar, edwards::EdwardsPoint as DalekPoint}; - -use group::ff::Field; -use dalek_ff_group::{ED25519_BASEPOINT_POINT as G, Scalar, EdwardsPoint}; - -use multiexp::BatchVerifier; - -use crate::{ - Commitment, hash, - ringct::{hash_to_point::raw_hash_to_point, bulletproofs::core::*}, -}; - -include!(concat!(env!("OUT_DIR"), "/generators_plus.rs")); - -static TRANSCRIPT_CELL: OnceLock<[u8; 32]> = OnceLock::new(); -pub(crate) fn TRANSCRIPT() -> [u8; 32] { - *TRANSCRIPT_CELL.get_or_init(|| { - EdwardsPoint(raw_hash_to_point(hash(b"bulletproof_plus_transcript"))).compress().to_bytes() - }) -} - -// TRANSCRIPT isn't a Scalar, so we need this alternative for the first hash -fn hash_plus>(commitments: C) -> (Scalar, Vec) { - let (cache, commitments) = hash_commitments(commitments); - (hash_to_scalar(&[TRANSCRIPT().as_ref(), &cache.to_bytes()].concat()), commitments) -} - -// d[j*N+i] = z**(2*(j+1)) * 2**i -fn d(z: Scalar, M: usize, MN: usize) -> (ScalarVector, ScalarVector) { - let zpow = ScalarVector::even_powers(z, 2 * M); - let mut d = vec![Scalar::ZERO; MN]; - for j in 0 .. M { - for i in 0 .. N { - d[(j * N) + i] = zpow[j] * TWO_N()[i]; - } - } - (zpow, ScalarVector(d)) -} - -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct PlusStruct { - pub(crate) A: DalekPoint, - pub(crate) A1: DalekPoint, - pub(crate) B: DalekPoint, - pub(crate) r1: DalekScalar, - pub(crate) s1: DalekScalar, - pub(crate) d1: DalekScalar, - pub(crate) L: Vec, - pub(crate) R: Vec, -} - -impl PlusStruct { - pub(crate) fn prove( - rng: &mut R, - commitments: &[Commitment], - ) -> PlusStruct { - let generators = GENERATORS(); - - let (logMN, M, MN) = MN(commitments.len()); - - let (aL, aR) = bit_decompose(commitments); - let commitments_points = commitments.iter().map(Commitment::calculate).collect::>(); - let (mut cache, _) = hash_plus(commitments_points.clone()); - let (mut alpha1, A) = alpha_rho(&mut *rng, generators, &aL, &aR); - - let y = hash_cache(&mut cache, &[A.compress().to_bytes()]); - let mut cache = hash_to_scalar(&y.to_bytes()); - let z = cache; - - let (zpow, d) = d(z, M, MN); - - let aL1 = aL - z; - - let ypow = ScalarVector::powers(y, MN + 2); - let mut y_for_d = ScalarVector(ypow.0[1 ..= MN].to_vec()); - y_for_d.0.reverse(); - let aR1 = (aR + z) + (y_for_d * d); - - for (j, gamma) in commitments.iter().map(|c| Scalar(c.mask)).enumerate() { - alpha1 += zpow[j] * ypow[MN + 1] * gamma; - } - - let mut a = aL1; - let mut b = aR1; - - let yinv = y.invert().unwrap(); - let yinvpow = ScalarVector::powers(yinv, MN); - - let mut G_proof = generators.G[.. a.len()].to_vec(); - let mut H_proof = generators.H[.. a.len()].to_vec(); - - let mut L = Vec::with_capacity(logMN); - let mut R = Vec::with_capacity(logMN); - - while a.len() != 1 { - let (aL, aR) = a.split(); - let (bL, bR) = b.split(); - - let cL = weighted_inner_product(&aL, &bR, y); - let cR = weighted_inner_product(&(&aR * ypow[aR.len()]), &bL, y); - - let (mut dL, mut dR) = (Scalar::random(&mut *rng), Scalar::random(&mut *rng)); - - let (G_L, G_R) = G_proof.split_at(aL.len()); - let (H_L, H_R) = H_proof.split_at(aL.len()); - - let mut L_i = LR_statements(&(&aL * yinvpow[aL.len()]), G_R, &bR, H_L, cL, H()); - L_i.push((dL, G)); - let L_i = prove_multiexp(&L_i); - L.push(L_i); - - let mut R_i = LR_statements(&(&aR * ypow[aR.len()]), G_L, &bL, H_R, cR, H()); - R_i.push((dR, G)); - let R_i = prove_multiexp(&R_i); - R.push(R_i); - - let w = hash_cache(&mut cache, &[L_i.compress().to_bytes(), R_i.compress().to_bytes()]); - let winv = w.invert().unwrap(); - - G_proof = hadamard_fold(G_L, G_R, winv, w * yinvpow[aL.len()]); - H_proof = hadamard_fold(H_L, H_R, w, winv); - - a = (&aL * w) + (aR * (winv * ypow[aL.len()])); - b = (bL * winv) + (bR * w); - - alpha1 += (dL * (w * w)) + (dR * (winv * winv)); - - dL.zeroize(); - dR.zeroize(); - } - - let mut r = Scalar::random(&mut *rng); - let mut s = Scalar::random(&mut *rng); - let mut d = Scalar::random(&mut *rng); - let mut eta = Scalar::random(&mut *rng); - - let A1 = prove_multiexp(&[ - (r, G_proof[0]), - (s, H_proof[0]), - (d, G), - ((r * y * b[0]) + (s * y * a[0]), H()), - ]); - let B = prove_multiexp(&[(r * y * s, H()), (eta, G)]); - let e = hash_cache(&mut cache, &[A1.compress().to_bytes(), B.compress().to_bytes()]); - - let r1 = (a[0] * e) + r; - r.zeroize(); - let s1 = (b[0] * e) + s; - s.zeroize(); - let d1 = ((d * e) + eta) + (alpha1 * (e * e)); - d.zeroize(); - eta.zeroize(); - alpha1.zeroize(); - - let res = PlusStruct { - A: *A, - A1: *A1, - B: *B, - r1: *r1, - s1: *s1, - d1: *d1, - L: L.drain(..).map(|L| *L).collect(), - R: R.drain(..).map(|R| *R).collect(), - }; - debug_assert!(res.verify(rng, &commitments_points)); - res - } - - #[must_use] - fn verify_core( - &self, - rng: &mut R, - verifier: &mut BatchVerifier, - id: ID, - commitments: &[DalekPoint], - ) -> bool { - // Verify commitments are valid - if commitments.is_empty() || (commitments.len() > MAX_M) { - return false; - } - - // Verify L and R are properly sized - if self.L.len() != self.R.len() { - return false; - } - - let (logMN, M, MN) = MN(commitments.len()); - if self.L.len() != logMN { - return false; - } - - // Rebuild all challenges - let (mut cache, commitments) = hash_plus(commitments.iter().copied()); - let y = hash_cache(&mut cache, &[self.A.compress().to_bytes()]); - let yinv = y.invert().unwrap(); - let z = hash_to_scalar(&y.to_bytes()); - cache = z; - - let mut w = Vec::with_capacity(logMN); - let mut winv = Vec::with_capacity(logMN); - for (L, R) in self.L.iter().zip(&self.R) { - w.push(hash_cache(&mut cache, &[L.compress().to_bytes(), R.compress().to_bytes()])); - winv.push(cache.invert().unwrap()); - } - - let e = hash_cache(&mut cache, &[self.A1.compress().to_bytes(), self.B.compress().to_bytes()]); - - // Convert the proof from * INV_EIGHT to its actual form - let normalize = |point: &DalekPoint| EdwardsPoint(point.mul_by_cofactor()); - - let L = self.L.iter().map(normalize).collect::>(); - let R = self.R.iter().map(normalize).collect::>(); - let A = normalize(&self.A); - let A1 = normalize(&self.A1); - let B = normalize(&self.B); - - // Verify it - let mut proof = Vec::with_capacity(logMN + 5 + (2 * (MN + logMN))); - - let mut yMN = y; - for _ in 0 .. logMN { - yMN *= yMN; - } - let yMNy = yMN * y; - - let (zpow, d) = d(z, M, MN); - let zsq = zpow[0]; - - let esq = e * e; - let minus_esq = -esq; - let commitment_weight = minus_esq * yMNy; - for (i, commitment) in commitments.iter().map(EdwardsPoint::mul_by_cofactor).enumerate() { - proof.push((commitment_weight * zpow[i], commitment)); - } - - // Invert B, instead of the Scalar, as the latter is only 2x as expensive yet enables reduction - // to a single addition under vartime for the first BP verified in the batch, which is expected - // to be much more significant - proof.push((Scalar::ONE, -B)); - proof.push((-e, A1)); - proof.push((minus_esq, A)); - proof.push((Scalar(self.d1), G)); - - let d_sum = zpow.sum() * Scalar::from(u64::MAX); - let y_sum = weighted_powers(y, MN).sum(); - proof.push(( - Scalar(self.r1 * y.0 * self.s1) + (esq * ((yMNy * z * d_sum) + ((zsq - z) * y_sum))), - H(), - )); - - let w_cache = challenge_products(&w, &winv); - - let mut e_r1_y = e * Scalar(self.r1); - let e_s1 = e * Scalar(self.s1); - let esq_z = esq * z; - let minus_esq_z = -esq_z; - let mut minus_esq_y = minus_esq * yMN; - - let generators = GENERATORS(); - for i in 0 .. MN { - proof.push((e_r1_y * w_cache[i] + esq_z, generators.G[i])); - proof.push(( - (e_s1 * w_cache[(!i) & (MN - 1)]) + minus_esq_z + (minus_esq_y * d[i]), - generators.H[i], - )); - - e_r1_y *= yinv; - minus_esq_y *= yinv; - } - - for i in 0 .. logMN { - proof.push((minus_esq * w[i] * w[i], L[i])); - proof.push((minus_esq * winv[i] * winv[i], R[i])); - } - - verifier.queue(rng, id, proof); - true - } - - #[must_use] - pub(crate) fn verify( - &self, - rng: &mut R, - commitments: &[DalekPoint], - ) -> bool { - let mut verifier = BatchVerifier::new(1); - if self.verify_core(rng, &mut verifier, (), commitments) { - verifier.verify_vartime() - } else { - false - } - } - - #[must_use] - pub(crate) fn batch_verify( - &self, - rng: &mut R, - verifier: &mut BatchVerifier, - id: ID, - commitments: &[DalekPoint], - ) -> bool { - self.verify_core(rng, verifier, id, commitments) - } -} diff --git a/coins/monero/src/ringct/bulletproofs/plus/aggregate_range_proof.rs b/coins/monero/src/ringct/bulletproofs/plus/aggregate_range_proof.rs new file mode 100644 index 00000000..b99e5f52 --- /dev/null +++ b/coins/monero/src/ringct/bulletproofs/plus/aggregate_range_proof.rs @@ -0,0 +1,249 @@ +use std_shims::vec::Vec; + +use rand_core::{RngCore, CryptoRng}; + +use zeroize::{Zeroize, ZeroizeOnDrop}; + +use multiexp::{multiexp, multiexp_vartime, BatchVerifier}; +use group::{ + ff::{Field, PrimeField}, + Group, GroupEncoding, +}; +use dalek_ff_group::{Scalar, EdwardsPoint}; + +use crate::{ + Commitment, + ringct::{ + bulletproofs::core::{MAX_M, N}, + bulletproofs::plus::{ + ScalarVector, PointVector, GeneratorsList, Generators, + transcript::*, + weighted_inner_product::{WipStatement, WipWitness, WipProof}, + padded_pow_of_2, u64_decompose, + }, + }, +}; + +// Figure 3 +#[derive(Clone, Debug)] +pub(crate) struct AggregateRangeStatement { + generators: Generators, + V: Vec, +} + +impl Zeroize for AggregateRangeStatement { + fn zeroize(&mut self) { + self.V.zeroize(); + } +} + +#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)] +pub(crate) struct AggregateRangeWitness { + values: Vec, + gammas: Vec, +} + +impl AggregateRangeWitness { + pub(crate) fn new(commitments: &[Commitment]) -> Option { + if commitments.is_empty() || (commitments.len() > MAX_M) { + return None; + } + + let mut values = Vec::with_capacity(commitments.len()); + let mut gammas = Vec::with_capacity(commitments.len()); + for commitment in commitments { + values.push(commitment.amount); + gammas.push(Scalar(commitment.mask)); + } + Some(AggregateRangeWitness { values, gammas }) + } +} + +#[derive(Clone, PartialEq, Eq, Debug, Zeroize)] +pub struct AggregateRangeProof { + pub(crate) A: EdwardsPoint, + pub(crate) wip: WipProof, +} + +impl AggregateRangeStatement { + pub(crate) fn new(V: Vec) -> Option { + if V.is_empty() || (V.len() > MAX_M) { + return None; + } + + Some(Self { generators: Generators::new(), V }) + } + + fn transcript_A(transcript: &mut Scalar, A: EdwardsPoint) -> (Scalar, Scalar) { + let y = hash_to_scalar(&[transcript.to_repr().as_ref(), A.to_bytes().as_ref()].concat()); + let z = hash_to_scalar(y.to_bytes().as_ref()); + *transcript = z; + (y, z) + } + + fn d_j(j: usize, m: usize) -> ScalarVector { + let mut d_j = Vec::with_capacity(m * N); + for _ in 0 .. (j - 1) * N { + d_j.push(Scalar::ZERO); + } + d_j.append(&mut ScalarVector::powers(Scalar::from(2u8), N).0); + for _ in 0 .. (m - j) * N { + d_j.push(Scalar::ZERO); + } + ScalarVector(d_j) + } + + fn compute_A_hat( + mut V: PointVector, + generators: &Generators, + transcript: &mut Scalar, + mut A: EdwardsPoint, + ) -> (Scalar, ScalarVector, Scalar, Scalar, ScalarVector, EdwardsPoint) { + let (y, z) = Self::transcript_A(transcript, A); + A = A.mul_by_cofactor(); + + while V.len() < padded_pow_of_2(V.len()) { + V.0.push(EdwardsPoint::identity()); + } + let mn = V.len() * N; + + let mut z_pow = Vec::with_capacity(V.len()); + + let mut d = ScalarVector::new(mn); + for j in 1 ..= V.len() { + z_pow.push(z.pow(Scalar::from(2 * u64::try_from(j).unwrap()))); // TODO: Optimize this + d = d.add_vec(&Self::d_j(j, V.len()).mul(z_pow[j - 1])); + } + + let mut ascending_y = ScalarVector(vec![y]); + for i in 1 .. d.len() { + ascending_y.0.push(ascending_y[i - 1] * y); + } + let y_pows = ascending_y.clone().sum(); + + let mut descending_y = ascending_y.clone(); + descending_y.0.reverse(); + + let d_descending_y = d.mul_vec(&descending_y); + + let y_mn_plus_one = descending_y[0] * y; + + let mut commitment_accum = EdwardsPoint::identity(); + for (j, commitment) in V.0.iter().enumerate() { + commitment_accum += *commitment * z_pow[j]; + } + + let neg_z = -z; + let mut A_terms = Vec::with_capacity((generators.len() * 2) + 2); + for (i, d_y_z) in d_descending_y.add(z).0.drain(..).enumerate() { + A_terms.push((neg_z, generators.generator(GeneratorsList::GBold1, i))); + A_terms.push((d_y_z, generators.generator(GeneratorsList::HBold1, i))); + } + A_terms.push((y_mn_plus_one, commitment_accum)); + A_terms.push(( + ((y_pows * z) - (d.sum() * y_mn_plus_one * z) - (y_pows * z.square())), + generators.g(), + )); + + (y, d_descending_y, y_mn_plus_one, z, ScalarVector(z_pow), A + multiexp_vartime(&A_terms)) + } + + pub(crate) fn prove( + self, + rng: &mut R, + witness: AggregateRangeWitness, + ) -> Option { + // Check for consistency with the witness + if self.V.len() != witness.values.len() { + return None; + } + for (commitment, (value, gamma)) in + self.V.iter().zip(witness.values.iter().zip(witness.gammas.iter())) + { + if Commitment::new(**gamma, *value).calculate() != **commitment { + return None; + } + } + + let Self { generators, V } = self; + // Monero expects all of these points to be torsion-free + // Generally, for Bulletproofs, it sends points * INV_EIGHT and then performs a torsion clear + // by multiplying by 8 + // This also restores the original value due to the preprocessing + // Commitments aren't transmitted INV_EIGHT though, so this multiplies by INV_EIGHT to enable + // clearing its cofactor without mutating the value + // For some reason, these values are transcripted * INV_EIGHT, not as transmitted + let mut V = V.into_iter().map(|V| EdwardsPoint(V.0 * crate::INV_EIGHT())).collect::>(); + let mut transcript = initial_transcript(V.iter()); + V.iter_mut().for_each(|V| *V = V.mul_by_cofactor()); + + // Pad V + while V.len() < padded_pow_of_2(V.len()) { + V.push(EdwardsPoint::identity()); + } + + let generators = generators.reduce(V.len() * N); + + let mut d_js = Vec::with_capacity(V.len()); + let mut a_l = ScalarVector(Vec::with_capacity(V.len() * N)); + for j in 1 ..= V.len() { + d_js.push(Self::d_j(j, V.len())); + a_l.0.append(&mut u64_decompose(*witness.values.get(j - 1).unwrap_or(&0)).0); + } + + let a_r = a_l.sub(Scalar::ONE); + + let alpha = Scalar::random(&mut *rng); + + let mut A_terms = Vec::with_capacity((generators.len() * 2) + 1); + for (i, a_l) in a_l.0.iter().enumerate() { + A_terms.push((*a_l, generators.generator(GeneratorsList::GBold1, i))); + } + for (i, a_r) in a_r.0.iter().enumerate() { + A_terms.push((*a_r, generators.generator(GeneratorsList::HBold1, i))); + } + A_terms.push((alpha, generators.h())); + let mut A = multiexp(&A_terms); + A_terms.zeroize(); + + // Multiply by INV_EIGHT per earlier commentary + A.0 *= crate::INV_EIGHT(); + + let (y, d_descending_y, y_mn_plus_one, z, z_pow, A_hat) = + Self::compute_A_hat(PointVector(V), &generators, &mut transcript, A); + + let a_l = a_l.sub(z); + let a_r = a_r.add_vec(&d_descending_y).add(z); + let mut alpha = alpha; + for j in 1 ..= witness.gammas.len() { + alpha += z_pow[j - 1] * witness.gammas[j - 1] * y_mn_plus_one; + } + + Some(AggregateRangeProof { + A, + wip: WipStatement::new(generators, A_hat, y) + .prove(rng, transcript, WipWitness::new(a_l, a_r, alpha).unwrap()) + .unwrap(), + }) + } + + pub(crate) fn verify( + self, + rng: &mut R, + verifier: &mut BatchVerifier, + id: Id, + proof: AggregateRangeProof, + ) -> bool { + let Self { generators, V } = self; + + let mut V = V.into_iter().map(|V| EdwardsPoint(V.0 * crate::INV_EIGHT())).collect::>(); + let mut transcript = initial_transcript(V.iter()); + V.iter_mut().for_each(|V| *V = V.mul_by_cofactor()); + + let generators = generators.reduce(V.len() * N); + + let (y, _, _, _, _, A_hat) = + Self::compute_A_hat(PointVector(V), &generators, &mut transcript, proof.A); + WipStatement::new(generators, A_hat, y).verify(rng, verifier, id, transcript, proof.wip) + } +} diff --git a/coins/monero/src/ringct/bulletproofs/plus/mod.rs b/coins/monero/src/ringct/bulletproofs/plus/mod.rs new file mode 100644 index 00000000..f52677ee --- /dev/null +++ b/coins/monero/src/ringct/bulletproofs/plus/mod.rs @@ -0,0 +1,92 @@ +#![allow(non_snake_case)] + +use group::Group; +use dalek_ff_group::{Scalar, EdwardsPoint}; + +mod scalar_vector; +pub(crate) use scalar_vector::{ScalarVector, weighted_inner_product}; +mod point_vector; +pub(crate) use point_vector::PointVector; + +pub(crate) mod transcript; +pub(crate) mod weighted_inner_product; +pub(crate) use weighted_inner_product::*; +pub(crate) mod aggregate_range_proof; +pub(crate) use aggregate_range_proof::*; + +pub(crate) fn padded_pow_of_2(i: usize) -> usize { + let mut next_pow_of_2 = 1; + while next_pow_of_2 < i { + next_pow_of_2 <<= 1; + } + next_pow_of_2 +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub(crate) enum GeneratorsList { + GBold1, + HBold1, +} + +// TODO: Table these +#[derive(Clone, Debug)] +pub(crate) struct Generators { + g: EdwardsPoint, + + g_bold1: &'static [EdwardsPoint], + h_bold1: &'static [EdwardsPoint], +} + +mod generators { + use std_shims::sync::OnceLock; + use monero_generators::Generators; + include!(concat!(env!("OUT_DIR"), "/generators_plus.rs")); +} + +impl Generators { + #[allow(clippy::new_without_default)] + pub(crate) fn new() -> Self { + let gens = generators::GENERATORS(); + Generators { g: dalek_ff_group::EdwardsPoint(crate::H()), g_bold1: &gens.G, h_bold1: &gens.H } + } + + pub(crate) fn len(&self) -> usize { + self.g_bold1.len() + } + + pub(crate) fn g(&self) -> EdwardsPoint { + self.g + } + + pub(crate) fn h(&self) -> EdwardsPoint { + EdwardsPoint::generator() + } + + pub(crate) fn generator(&self, list: GeneratorsList, i: usize) -> EdwardsPoint { + match list { + GeneratorsList::GBold1 => self.g_bold1[i], + GeneratorsList::HBold1 => self.h_bold1[i], + } + } + + pub(crate) fn reduce(&self, generators: usize) -> Self { + // Round to the nearest power of 2 + let generators = padded_pow_of_2(generators); + assert!(generators <= self.g_bold1.len()); + + Generators { + g: self.g, + g_bold1: &self.g_bold1[.. generators], + h_bold1: &self.h_bold1[.. generators], + } + } +} + +// Returns the little-endian decomposition. +fn u64_decompose(value: u64) -> ScalarVector { + let mut bits = ScalarVector::new(64); + for bit in 0 .. 64 { + bits[bit] = Scalar::from((value >> bit) & 1); + } + bits +} diff --git a/coins/monero/src/ringct/bulletproofs/plus/point_vector.rs b/coins/monero/src/ringct/bulletproofs/plus/point_vector.rs new file mode 100644 index 00000000..ac753a01 --- /dev/null +++ b/coins/monero/src/ringct/bulletproofs/plus/point_vector.rs @@ -0,0 +1,50 @@ +use core::ops::{Index, IndexMut}; +use std_shims::vec::Vec; + +use zeroize::{Zeroize, ZeroizeOnDrop}; + +use dalek_ff_group::EdwardsPoint; + +#[cfg(test)] +use multiexp::multiexp; +#[cfg(test)] +use crate::ringct::bulletproofs::plus::ScalarVector; + +#[derive(Clone, PartialEq, Eq, Debug, Zeroize, ZeroizeOnDrop)] +pub(crate) struct PointVector(pub(crate) Vec); + +impl Index for PointVector { + type Output = EdwardsPoint; + fn index(&self, index: usize) -> &EdwardsPoint { + &self.0[index] + } +} + +impl IndexMut for PointVector { + fn index_mut(&mut self, index: usize) -> &mut EdwardsPoint { + &mut self.0[index] + } +} + +impl PointVector { + #[cfg(test)] + pub(crate) fn multiexp(&self, vector: &ScalarVector) -> EdwardsPoint { + debug_assert_eq!(self.len(), vector.len()); + let mut res = Vec::with_capacity(self.len()); + for (point, scalar) in self.0.iter().copied().zip(vector.0.iter().copied()) { + res.push((scalar, point)); + } + multiexp(&res) + } + + pub(crate) fn len(&self) -> usize { + self.0.len() + } + + pub(crate) fn split(mut self) -> (Self, Self) { + debug_assert!(self.len() > 1); + let r = self.0.split_off(self.0.len() / 2); + debug_assert_eq!(self.len(), r.len()); + (self, PointVector(r)) + } +} diff --git a/coins/monero/src/ringct/bulletproofs/plus/scalar_vector.rs b/coins/monero/src/ringct/bulletproofs/plus/scalar_vector.rs new file mode 100644 index 00000000..a8b0866e --- /dev/null +++ b/coins/monero/src/ringct/bulletproofs/plus/scalar_vector.rs @@ -0,0 +1,114 @@ +use core::{ + borrow::Borrow, + ops::{Index, IndexMut}, +}; +use std_shims::vec::Vec; + +use zeroize::Zeroize; + +use group::ff::Field; +use dalek_ff_group::Scalar; + +#[derive(Clone, PartialEq, Eq, Debug, Zeroize)] +pub(crate) struct ScalarVector(pub(crate) Vec); + +impl Index for ScalarVector { + type Output = Scalar; + fn index(&self, index: usize) -> &Scalar { + &self.0[index] + } +} + +impl IndexMut for ScalarVector { + fn index_mut(&mut self, index: usize) -> &mut Scalar { + &mut self.0[index] + } +} + +impl ScalarVector { + pub(crate) fn new(len: usize) -> Self { + ScalarVector(vec![Scalar::ZERO; len]) + } + + pub(crate) fn add(&self, scalar: impl Borrow) -> Self { + let mut res = self.clone(); + for val in res.0.iter_mut() { + *val += scalar.borrow(); + } + res + } + + pub(crate) fn sub(&self, scalar: impl Borrow) -> Self { + let mut res = self.clone(); + for val in res.0.iter_mut() { + *val -= scalar.borrow(); + } + res + } + + pub(crate) fn mul(&self, scalar: impl Borrow) -> Self { + let mut res = self.clone(); + for val in res.0.iter_mut() { + *val *= scalar.borrow(); + } + res + } + + pub(crate) fn add_vec(&self, vector: &Self) -> Self { + debug_assert_eq!(self.len(), vector.len()); + let mut res = self.clone(); + for (i, val) in res.0.iter_mut().enumerate() { + *val += vector.0[i]; + } + res + } + + pub(crate) fn mul_vec(&self, vector: &Self) -> Self { + debug_assert_eq!(self.len(), vector.len()); + let mut res = self.clone(); + for (i, val) in res.0.iter_mut().enumerate() { + *val *= vector.0[i]; + } + res + } + + pub(crate) fn inner_product(&self, vector: &Self) -> Scalar { + self.mul_vec(vector).sum() + } + + pub(crate) fn powers(x: Scalar, len: usize) -> Self { + debug_assert!(len != 0); + + let mut res = Vec::with_capacity(len); + res.push(Scalar::ONE); + res.push(x); + for i in 2 .. len { + res.push(res[i - 1] * x); + } + res.truncate(len); + ScalarVector(res) + } + + pub(crate) fn sum(mut self) -> Scalar { + self.0.drain(..).sum() + } + + pub(crate) fn len(&self) -> usize { + self.0.len() + } + + pub(crate) fn split(mut self) -> (Self, Self) { + debug_assert!(self.len() > 1); + let r = self.0.split_off(self.0.len() / 2); + debug_assert_eq!(self.len(), r.len()); + (self, ScalarVector(r)) + } +} + +pub(crate) fn weighted_inner_product( + a: &ScalarVector, + b: &ScalarVector, + y: &ScalarVector, +) -> Scalar { + a.inner_product(&b.mul_vec(y)) +} diff --git a/coins/monero/src/ringct/bulletproofs/plus/transcript.rs b/coins/monero/src/ringct/bulletproofs/plus/transcript.rs new file mode 100644 index 00000000..2108013b --- /dev/null +++ b/coins/monero/src/ringct/bulletproofs/plus/transcript.rs @@ -0,0 +1,24 @@ +use std_shims::{sync::OnceLock, vec::Vec}; + +use dalek_ff_group::{Scalar, EdwardsPoint}; + +use monero_generators::{hash_to_point as raw_hash_to_point}; +use crate::{hash, hash_to_scalar as dalek_hash}; + +// Monero starts BP+ transcripts with the following constant. +static TRANSCRIPT_CELL: OnceLock<[u8; 32]> = OnceLock::new(); +pub(crate) fn TRANSCRIPT() -> [u8; 32] { + // Why this uses a hash_to_point is completely unknown. + *TRANSCRIPT_CELL + .get_or_init(|| raw_hash_to_point(hash(b"bulletproof_plus_transcript")).compress().to_bytes()) +} + +pub(crate) fn hash_to_scalar(data: &[u8]) -> Scalar { + Scalar(dalek_hash(data)) +} + +pub(crate) fn initial_transcript(commitments: core::slice::Iter<'_, EdwardsPoint>) -> Scalar { + let commitments_hash = + hash_to_scalar(&commitments.flat_map(|V| V.compress().to_bytes()).collect::>()); + hash_to_scalar(&[TRANSCRIPT().as_ref(), &commitments_hash.to_bytes()].concat()) +} diff --git a/coins/monero/src/ringct/bulletproofs/plus/weighted_inner_product.rs b/coins/monero/src/ringct/bulletproofs/plus/weighted_inner_product.rs new file mode 100644 index 00000000..6840be37 --- /dev/null +++ b/coins/monero/src/ringct/bulletproofs/plus/weighted_inner_product.rs @@ -0,0 +1,445 @@ +use std_shims::vec::Vec; + +use rand_core::{RngCore, CryptoRng}; + +use zeroize::{Zeroize, ZeroizeOnDrop}; + +use multiexp::{multiexp, multiexp_vartime, BatchVerifier}; +use group::{ + ff::{Field, PrimeField}, + GroupEncoding, +}; +use dalek_ff_group::{Scalar, EdwardsPoint}; + +use crate::ringct::bulletproofs::plus::{ + ScalarVector, PointVector, GeneratorsList, Generators, padded_pow_of_2, weighted_inner_product, + transcript::*, +}; + +// Figure 1 +#[derive(Clone, Debug)] +pub(crate) struct WipStatement { + generators: Generators, + P: EdwardsPoint, + y: ScalarVector, +} + +impl Zeroize for WipStatement { + fn zeroize(&mut self) { + self.P.zeroize(); + self.y.zeroize(); + } +} + +#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)] +pub(crate) struct WipWitness { + a: ScalarVector, + b: ScalarVector, + alpha: Scalar, +} + +impl WipWitness { + pub(crate) fn new(mut a: ScalarVector, mut b: ScalarVector, alpha: Scalar) -> Option { + if a.0.is_empty() || (a.len() != b.len()) { + return None; + } + + // Pad to the nearest power of 2 + let missing = padded_pow_of_2(a.len()) - a.len(); + a.0.reserve(missing); + b.0.reserve(missing); + for _ in 0 .. missing { + a.0.push(Scalar::ZERO); + b.0.push(Scalar::ZERO); + } + + Some(Self { a, b, alpha }) + } +} + +#[derive(Clone, PartialEq, Eq, Debug, Zeroize)] +pub(crate) struct WipProof { + pub(crate) L: Vec, + pub(crate) R: Vec, + pub(crate) A: EdwardsPoint, + pub(crate) B: EdwardsPoint, + pub(crate) r_answer: Scalar, + pub(crate) s_answer: Scalar, + pub(crate) delta_answer: Scalar, +} + +impl WipStatement { + pub(crate) fn new(generators: Generators, P: EdwardsPoint, y: Scalar) -> Self { + debug_assert_eq!(generators.len(), padded_pow_of_2(generators.len())); + + // y ** n + let mut y_vec = ScalarVector::new(generators.len()); + y_vec[0] = y; + for i in 1 .. y_vec.len() { + y_vec[i] = y_vec[i - 1] * y; + } + + Self { generators, P, y: y_vec } + } + + fn transcript_L_R(transcript: &mut Scalar, L: EdwardsPoint, R: EdwardsPoint) -> Scalar { + let e = hash_to_scalar( + &[transcript.to_repr().as_ref(), L.to_bytes().as_ref(), R.to_bytes().as_ref()].concat(), + ); + *transcript = e; + e + } + + fn transcript_A_B(transcript: &mut Scalar, A: EdwardsPoint, B: EdwardsPoint) -> Scalar { + let e = hash_to_scalar( + &[transcript.to_repr().as_ref(), A.to_bytes().as_ref(), B.to_bytes().as_ref()].concat(), + ); + *transcript = e; + e + } + + // Prover's variant of the shared code block to calculate G/H/P when n > 1 + // Returns each permutation of G/H since the prover needs to do operation on each permutation + // P is dropped as it's unused in the prover's path + // TODO: It'd still probably be faster to keep in terms of the original generators, both between + // the reduced amount of group operations and the potential tabling of the generators under + // multiexp + #[allow(clippy::too_many_arguments)] + fn next_G_H( + transcript: &mut Scalar, + mut g_bold1: PointVector, + mut g_bold2: PointVector, + mut h_bold1: PointVector, + mut h_bold2: PointVector, + L: EdwardsPoint, + R: EdwardsPoint, + y_inv_n_hat: Scalar, + ) -> (Scalar, Scalar, Scalar, Scalar, PointVector, PointVector) { + debug_assert_eq!(g_bold1.len(), g_bold2.len()); + debug_assert_eq!(h_bold1.len(), h_bold2.len()); + debug_assert_eq!(g_bold1.len(), h_bold1.len()); + + let e = Self::transcript_L_R(transcript, L, R); + let inv_e = e.invert().unwrap(); + + // This vartime is safe as all of these arguments are public + let mut new_g_bold = Vec::with_capacity(g_bold1.len()); + let e_y_inv = e * y_inv_n_hat; + for g_bold in g_bold1.0.drain(..).zip(g_bold2.0.drain(..)) { + new_g_bold.push(multiexp_vartime(&[(inv_e, g_bold.0), (e_y_inv, g_bold.1)])); + } + + let mut new_h_bold = Vec::with_capacity(h_bold1.len()); + for h_bold in h_bold1.0.drain(..).zip(h_bold2.0.drain(..)) { + new_h_bold.push(multiexp_vartime(&[(e, h_bold.0), (inv_e, h_bold.1)])); + } + + let e_square = e.square(); + let inv_e_square = inv_e.square(); + + (e, inv_e, e_square, inv_e_square, PointVector(new_g_bold), PointVector(new_h_bold)) + } + + /* + This has room for optimization worth investigating further. It currently takes + an iterative approach. It can be optimized further via divide and conquer. + + Assume there are 4 challenges. + + Iterative approach (current): + 1. Do the optimal multiplications across challenge column 0 and 1. + 2. Do the optimal multiplications across that result and column 2. + 3. Do the optimal multiplications across that result and column 3. + + Divide and conquer (worth investigating further): + 1. Do the optimal multiplications across challenge column 0 and 1. + 2. Do the optimal multiplications across challenge column 2 and 3. + 3. Multiply both results together. + + When there are 4 challenges (n=16), the iterative approach does 28 multiplications + versus divide and conquer's 24. + */ + fn challenge_products(challenges: &[(Scalar, Scalar)]) -> Vec { + let mut products = vec![Scalar::ONE; 1 << challenges.len()]; + + if !challenges.is_empty() { + products[0] = challenges[0].1; + products[1] = challenges[0].0; + + for (j, challenge) in challenges.iter().enumerate().skip(1) { + let mut slots = (1 << (j + 1)) - 1; + while slots > 0 { + products[slots] = products[slots / 2] * challenge.0; + products[slots - 1] = products[slots / 2] * challenge.1; + + slots = slots.saturating_sub(2); + } + } + + // Sanity check since if the above failed to populate, it'd be critical + for product in &products { + debug_assert!(!bool::from(product.is_zero())); + } + } + + products + } + + pub(crate) fn prove( + self, + rng: &mut R, + mut transcript: Scalar, + witness: WipWitness, + ) -> Option { + let WipStatement { generators, P, mut y } = self; + + if generators.len() != witness.a.len() { + return None; + } + let (g, h) = (generators.g(), generators.h()); + let mut g_bold = vec![]; + let mut h_bold = vec![]; + for i in 0 .. generators.len() { + g_bold.push(generators.generator(GeneratorsList::GBold1, i)); + h_bold.push(generators.generator(GeneratorsList::HBold1, i)); + } + let mut g_bold = PointVector(g_bold); + let mut h_bold = PointVector(h_bold); + + // Check P has the expected relationship + #[cfg(debug_assertions)] + { + let mut P_terms = witness + .a + .0 + .iter() + .copied() + .zip(g_bold.0.iter().copied()) + .chain(witness.b.0.iter().copied().zip(h_bold.0.iter().copied())) + .collect::>(); + P_terms.push((weighted_inner_product(&witness.a, &witness.b, &y), g)); + P_terms.push((witness.alpha, h)); + debug_assert_eq!(multiexp(&P_terms), P); + P_terms.zeroize(); + } + + let mut a = witness.a.clone(); + let mut b = witness.b.clone(); + let mut alpha = witness.alpha; + + // From here on, g_bold.len() is used as n + debug_assert_eq!(g_bold.len(), a.len()); + + let mut L_vec = vec![]; + let mut R_vec = vec![]; + + // else n > 1 case from figure 1 + while g_bold.len() > 1 { + let (a1, a2) = a.clone().split(); + let (b1, b2) = b.clone().split(); + let (g_bold1, g_bold2) = g_bold.split(); + let (h_bold1, h_bold2) = h_bold.split(); + + let n_hat = g_bold1.len(); + debug_assert_eq!(a1.len(), n_hat); + debug_assert_eq!(a2.len(), n_hat); + debug_assert_eq!(b1.len(), n_hat); + debug_assert_eq!(b2.len(), n_hat); + debug_assert_eq!(g_bold1.len(), n_hat); + debug_assert_eq!(g_bold2.len(), n_hat); + debug_assert_eq!(h_bold1.len(), n_hat); + debug_assert_eq!(h_bold2.len(), n_hat); + + let y_n_hat = y[n_hat - 1]; + y.0.truncate(n_hat); + + let d_l = Scalar::random(&mut *rng); + let d_r = Scalar::random(&mut *rng); + + let c_l = weighted_inner_product(&a1, &b2, &y); + let c_r = weighted_inner_product(&(a2.mul(y_n_hat)), &b1, &y); + + // TODO: Calculate these with a batch inversion + let y_inv_n_hat = y_n_hat.invert().unwrap(); + + let mut L_terms = a1 + .mul(y_inv_n_hat) + .0 + .drain(..) + .zip(g_bold2.0.iter().copied()) + .chain(b2.0.iter().copied().zip(h_bold1.0.iter().copied())) + .collect::>(); + L_terms.push((c_l, g)); + L_terms.push((d_l, h)); + let L = multiexp(&L_terms) * Scalar(crate::INV_EIGHT()); + L_vec.push(L); + L_terms.zeroize(); + + let mut R_terms = a2 + .mul(y_n_hat) + .0 + .drain(..) + .zip(g_bold1.0.iter().copied()) + .chain(b1.0.iter().copied().zip(h_bold2.0.iter().copied())) + .collect::>(); + R_terms.push((c_r, g)); + R_terms.push((d_r, h)); + let R = multiexp(&R_terms) * Scalar(crate::INV_EIGHT()); + R_vec.push(R); + R_terms.zeroize(); + + let (e, inv_e, e_square, inv_e_square); + (e, inv_e, e_square, inv_e_square, g_bold, h_bold) = + Self::next_G_H(&mut transcript, g_bold1, g_bold2, h_bold1, h_bold2, L, R, y_inv_n_hat); + + a = a1.mul(e).add_vec(&a2.mul(y_n_hat * inv_e)); + b = b1.mul(inv_e).add_vec(&b2.mul(e)); + alpha += (d_l * e_square) + (d_r * inv_e_square); + + debug_assert_eq!(g_bold.len(), a.len()); + debug_assert_eq!(g_bold.len(), h_bold.len()); + debug_assert_eq!(g_bold.len(), b.len()); + } + + // n == 1 case from figure 1 + debug_assert_eq!(g_bold.len(), 1); + debug_assert_eq!(h_bold.len(), 1); + + debug_assert_eq!(a.len(), 1); + debug_assert_eq!(b.len(), 1); + + let r = Scalar::random(&mut *rng); + let s = Scalar::random(&mut *rng); + let delta = Scalar::random(&mut *rng); + let eta = Scalar::random(&mut *rng); + + let ry = r * y[0]; + + let mut A_terms = + vec![(r, g_bold[0]), (s, h_bold[0]), ((ry * b[0]) + (s * y[0] * a[0]), g), (delta, h)]; + let A = multiexp(&A_terms) * Scalar(crate::INV_EIGHT()); + A_terms.zeroize(); + + let mut B_terms = vec![(ry * s, g), (eta, h)]; + let B = multiexp(&B_terms) * Scalar(crate::INV_EIGHT()); + B_terms.zeroize(); + + let e = Self::transcript_A_B(&mut transcript, A, B); + + let r_answer = r + (a[0] * e); + let s_answer = s + (b[0] * e); + let delta_answer = eta + (delta * e) + (alpha * e.square()); + + Some(WipProof { L: L_vec, R: R_vec, A, B, r_answer, s_answer, delta_answer }) + } + + pub(crate) fn verify( + self, + rng: &mut R, + verifier: &mut BatchVerifier, + id: Id, + mut transcript: Scalar, + mut proof: WipProof, + ) -> bool { + let WipStatement { generators, P, y } = self; + + let (g, h) = (generators.g(), generators.h()); + + // Verify the L/R lengths + { + let mut lr_len = 0; + while (1 << lr_len) < generators.len() { + lr_len += 1; + } + if (proof.L.len() != lr_len) || + (proof.R.len() != lr_len) || + (generators.len() != (1 << lr_len)) + { + return false; + } + } + + let inv_y = { + let inv_y = y[0].invert().unwrap(); + let mut res = Vec::with_capacity(y.len()); + res.push(inv_y); + while res.len() < y.len() { + res.push(inv_y * res.last().unwrap()); + } + res + }; + + let mut P_terms = vec![(Scalar::ONE, P)]; + P_terms.reserve(6 + (2 * generators.len()) + proof.L.len()); + + let mut challenges = Vec::with_capacity(proof.L.len()); + let product_cache = { + let mut es = Vec::with_capacity(proof.L.len()); + for (L, R) in proof.L.iter_mut().zip(proof.R.iter_mut()) { + es.push(Self::transcript_L_R(&mut transcript, *L, *R)); + *L = L.mul_by_cofactor(); + *R = R.mul_by_cofactor(); + } + + let mut inv_es = es.clone(); + let mut scratch = vec![Scalar::ZERO; es.len()]; + group::ff::BatchInverter::invert_with_external_scratch(&mut inv_es, &mut scratch); + drop(scratch); + + debug_assert_eq!(es.len(), inv_es.len()); + debug_assert_eq!(es.len(), proof.L.len()); + debug_assert_eq!(es.len(), proof.R.len()); + for ((e, inv_e), (L, R)) in + es.drain(..).zip(inv_es.drain(..)).zip(proof.L.iter().zip(proof.R.iter())) + { + debug_assert_eq!(e.invert().unwrap(), inv_e); + + challenges.push((e, inv_e)); + + let e_square = e.square(); + let inv_e_square = inv_e.square(); + P_terms.push((e_square, *L)); + P_terms.push((inv_e_square, *R)); + } + + Self::challenge_products(&challenges) + }; + + let e = Self::transcript_A_B(&mut transcript, proof.A, proof.B); + proof.A = proof.A.mul_by_cofactor(); + proof.B = proof.B.mul_by_cofactor(); + let neg_e_square = -e.square(); + + let mut multiexp = P_terms; + multiexp.reserve(4 + (2 * generators.len())); + for (scalar, _) in multiexp.iter_mut() { + *scalar *= neg_e_square; + } + + let re = proof.r_answer * e; + for i in 0 .. generators.len() { + let mut scalar = product_cache[i] * re; + if i > 0 { + scalar *= inv_y[i - 1]; + } + multiexp.push((scalar, generators.generator(GeneratorsList::GBold1, i))); + } + + let se = proof.s_answer * e; + for i in 0 .. generators.len() { + multiexp.push(( + se * product_cache[product_cache.len() - 1 - i], + generators.generator(GeneratorsList::HBold1, i), + )); + } + + multiexp.push((-e, proof.A)); + multiexp.push((proof.r_answer * y[0] * proof.s_answer, g)); + multiexp.push((proof.delta_answer, h)); + multiexp.push((-Scalar::ONE, proof.B)); + + verifier.queue(rng, id, multiexp); + + true + } +} diff --git a/coins/monero/src/ringct/bulletproofs/scalar_vector.rs b/coins/monero/src/ringct/bulletproofs/scalar_vector.rs index 3596f838..6f94f228 100644 --- a/coins/monero/src/ringct/bulletproofs/scalar_vector.rs +++ b/coins/monero/src/ringct/bulletproofs/scalar_vector.rs @@ -67,24 +67,6 @@ impl ScalarVector { ScalarVector(res) } - pub(crate) fn even_powers(x: Scalar, pow: usize) -> ScalarVector { - debug_assert!(pow != 0); - // Verify pow is a power of two - debug_assert_eq!(((pow - 1) & pow), 0); - - let xsq = x * x; - let mut res = ScalarVector(Vec::with_capacity(pow / 2)); - res.0.push(xsq); - - let mut prev = 2; - while prev < pow { - res.0.push(res[res.len() - 1] * xsq); - prev += 2; - } - - res - } - pub(crate) fn sum(mut self) -> Scalar { self.0.drain(..).sum() } @@ -110,15 +92,6 @@ pub(crate) fn inner_product(a: &ScalarVector, b: &ScalarVector) -> Scalar { (a * b).sum() } -pub(crate) fn weighted_powers(x: Scalar, len: usize) -> ScalarVector { - ScalarVector(ScalarVector::powers(x, len + 1).0[1 ..].to_vec()) -} - -pub(crate) fn weighted_inner_product(a: &ScalarVector, b: &ScalarVector, y: Scalar) -> Scalar { - // y ** 0 is not used as a power - (a * b * weighted_powers(y, a.len())).sum() -} - impl Mul<&[EdwardsPoint]> for &ScalarVector { type Output = EdwardsPoint; fn mul(self, b: &[EdwardsPoint]) -> EdwardsPoint { diff --git a/coins/monero/src/tests/bulletproofs.rs b/coins/monero/src/tests/bulletproofs/mod.rs similarity index 99% rename from coins/monero/src/tests/bulletproofs.rs rename to coins/monero/src/tests/bulletproofs/mod.rs index 2b30caaa..4ad39aa0 100644 --- a/coins/monero/src/tests/bulletproofs.rs +++ b/coins/monero/src/tests/bulletproofs/mod.rs @@ -9,6 +9,8 @@ use crate::{ ringct::bulletproofs::{Bulletproofs, original::OriginalStruct}, }; +mod plus; + #[test] fn bulletproofs_vector() { let scalar = |scalar| Scalar::from_canonical_bytes(scalar).unwrap(); @@ -62,7 +64,7 @@ macro_rules! bulletproofs_tests { fn $name() { // Create Bulletproofs for all possible output quantities let mut verifier = BatchVerifier::new(16); - for i in 1 .. 17 { + for i in 1 ..= 16 { let commitments = (1 ..= i) .map(|i| Commitment::new(random_scalar(&mut OsRng), u64::try_from(i).unwrap())) .collect::>(); diff --git a/coins/monero/src/tests/bulletproofs/plus/aggregate_range_proof.rs b/coins/monero/src/tests/bulletproofs/plus/aggregate_range_proof.rs new file mode 100644 index 00000000..34aa8478 --- /dev/null +++ b/coins/monero/src/tests/bulletproofs/plus/aggregate_range_proof.rs @@ -0,0 +1,30 @@ +use rand_core::{RngCore, OsRng}; + +use multiexp::BatchVerifier; +use group::ff::Field; +use dalek_ff_group::{Scalar, EdwardsPoint}; + +use crate::{ + Commitment, + ringct::bulletproofs::plus::aggregate_range_proof::{ + AggregateRangeStatement, AggregateRangeWitness, + }, +}; + +#[test] +fn test_aggregate_range_proof() { + let mut verifier = BatchVerifier::new(16); + for m in 1 ..= 16 { + let mut commitments = vec![]; + for _ in 0 .. m { + commitments.push(Commitment::new(*Scalar::random(&mut OsRng), OsRng.next_u64())); + } + let commitment_points = commitments.iter().map(|com| EdwardsPoint(com.calculate())).collect(); + let statement = AggregateRangeStatement::new(commitment_points).unwrap(); + let witness = AggregateRangeWitness::new(&commitments).unwrap(); + + let proof = statement.clone().prove(&mut OsRng, witness).unwrap(); + statement.verify(&mut OsRng, &mut verifier, (), proof); + } + assert!(verifier.verify_vartime()); +} diff --git a/coins/monero/src/tests/bulletproofs/plus/mod.rs b/coins/monero/src/tests/bulletproofs/plus/mod.rs new file mode 100644 index 00000000..bd48add5 --- /dev/null +++ b/coins/monero/src/tests/bulletproofs/plus/mod.rs @@ -0,0 +1,4 @@ +#[cfg(test)] +mod weighted_inner_product; +#[cfg(test)] +mod aggregate_range_proof; diff --git a/coins/monero/src/tests/bulletproofs/plus/weighted_inner_product.rs b/coins/monero/src/tests/bulletproofs/plus/weighted_inner_product.rs new file mode 100644 index 00000000..3da9c6ad --- /dev/null +++ b/coins/monero/src/tests/bulletproofs/plus/weighted_inner_product.rs @@ -0,0 +1,82 @@ +// The inner product relation is P = sum(g_bold * a, h_bold * b, g * (a * y * b), h * alpha) + +use rand_core::OsRng; + +use multiexp::BatchVerifier; +use group::{ff::Field, Group}; +use dalek_ff_group::{Scalar, EdwardsPoint}; + +use crate::ringct::bulletproofs::plus::{ + ScalarVector, PointVector, GeneratorsList, Generators, + weighted_inner_product::{WipStatement, WipWitness}, + weighted_inner_product, +}; + +#[test] +fn test_zero_weighted_inner_product() { + #[allow(non_snake_case)] + let P = EdwardsPoint::identity(); + let y = Scalar::random(&mut OsRng); + + let generators = Generators::new().reduce(1); + let statement = WipStatement::new(generators, P, y); + let witness = WipWitness::new(ScalarVector::new(1), ScalarVector::new(1), Scalar::ZERO).unwrap(); + + let transcript = Scalar::random(&mut OsRng); + let proof = statement.clone().prove(&mut OsRng, transcript, witness).unwrap(); + + let mut verifier = BatchVerifier::new(1); + statement.verify(&mut OsRng, &mut verifier, (), transcript, proof); + assert!(verifier.verify_vartime()); +} + +#[test] +fn test_weighted_inner_product() { + // P = sum(g_bold * a, h_bold * b, g * (a * y * b), h * alpha) + let mut verifier = BatchVerifier::new(6); + let generators = Generators::new(); + for i in [1, 2, 4, 8, 16, 32] { + let generators = generators.reduce(i); + let g = generators.g(); + let h = generators.h(); + assert_eq!(generators.len(), i); + let mut g_bold = vec![]; + let mut h_bold = vec![]; + for i in 0 .. i { + g_bold.push(generators.generator(GeneratorsList::GBold1, i)); + h_bold.push(generators.generator(GeneratorsList::HBold1, i)); + } + let g_bold = PointVector(g_bold); + let h_bold = PointVector(h_bold); + + let mut a = ScalarVector::new(i); + let mut b = ScalarVector::new(i); + let alpha = Scalar::random(&mut OsRng); + + let y = Scalar::random(&mut OsRng); + let mut y_vec = ScalarVector::new(g_bold.len()); + y_vec[0] = y; + for i in 1 .. y_vec.len() { + y_vec[i] = y_vec[i - 1] * y; + } + + for i in 0 .. i { + a[i] = Scalar::random(&mut OsRng); + b[i] = Scalar::random(&mut OsRng); + } + + #[allow(non_snake_case)] + let P = g_bold.multiexp(&a) + + h_bold.multiexp(&b) + + (g * weighted_inner_product(&a, &b, &y_vec)) + + (h * alpha); + + let statement = WipStatement::new(generators, P, y); + let witness = WipWitness::new(a, b, alpha).unwrap(); + + let transcript = Scalar::random(&mut OsRng); + let proof = statement.clone().prove(&mut OsRng, transcript, witness).unwrap(); + statement.verify(&mut OsRng, &mut verifier, (), transcript, proof); + } + assert!(verifier.verify_vartime()); +}