diff --git a/coins/monero/src/ringct/bulletproofs/mod.rs b/coins/monero/src/ringct/bulletproofs/mod.rs index 6dc2297b..2d62553d 100644 --- a/coins/monero/src/ringct/bulletproofs/mod.rs +++ b/coins/monero/src/ringct/bulletproofs/mod.rs @@ -26,15 +26,15 @@ use self::plus::*; pub(crate) const MAX_OUTPUTS: usize = self::core::MAX_M; -/// Bulletproofs enum, supporting the original and plus formulations. +/// Bulletproof enum, encapsulating both Bulletproofs and Bulletproofs+. #[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq, Eq, Debug)] -pub enum Bulletproofs { +pub enum Bulletproof { Original(OriginalStruct), Plus(AggregateRangeProof), } -impl Bulletproofs { +impl Bulletproof { fn bp_fields(plus: bool) -> usize { if plus { 6 @@ -57,7 +57,7 @@ impl Bulletproofs { let mut bp_clawback = 0; if n_padded_outputs > 2 { - let fields = Bulletproofs::bp_fields(plus); + let fields = Bulletproof::bp_fields(plus); let base = ((fields + (2 * (LOG_N + 1))) * 32) / 2; let size = (fields + (2 * LR_len)) * 32; bp_clawback = ((base * n_padded_outputs) - size) * 4 / 5; @@ -68,40 +68,49 @@ impl Bulletproofs { pub(crate) fn fee_weight(plus: bool, outputs: usize) -> usize { #[allow(non_snake_case)] - let (bp_clawback, LR_len) = Bulletproofs::calculate_bp_clawback(plus, outputs); - 32 * (Bulletproofs::bp_fields(plus) + (2 * LR_len)) + 2 + bp_clawback + let (bp_clawback, LR_len) = Bulletproof::calculate_bp_clawback(plus, outputs); + 32 * (Bulletproof::bp_fields(plus) + (2 * LR_len)) + 2 + bp_clawback } - /// Prove the list of commitments are within [0 .. 2^64). + /// Prove the list of commitments are within [0 .. 2^64) with an aggregate Bulletproof. pub fn prove( rng: &mut R, outputs: &[Commitment], - plus: bool, - ) -> Result { + ) -> Result { if outputs.is_empty() { Err(TransactionError::NoOutputs)?; } if outputs.len() > MAX_OUTPUTS { Err(TransactionError::TooManyOutputs)?; } - Ok(if !plus { - Bulletproofs::Original(OriginalStruct::prove(rng, outputs)) - } else { - Bulletproofs::Plus( - AggregateRangeStatement::new(outputs.iter().map(Commitment::calculate).collect()) - .unwrap() - .prove(rng, &Zeroizing::new(AggregateRangeWitness::new(outputs.to_vec()).unwrap())) - .unwrap(), - ) - }) + Ok(Bulletproof::Original(OriginalStruct::prove(rng, outputs))) } - /// Verify the given Bulletproofs. + /// Prove the list of commitments are within [0 .. 2^64) with an aggregate Bulletproof+. + pub fn prove_plus( + rng: &mut R, + outputs: Vec, + ) -> Result { + if outputs.is_empty() { + Err(TransactionError::NoOutputs)?; + } + if outputs.len() > MAX_OUTPUTS { + Err(TransactionError::TooManyOutputs)?; + } + Ok(Bulletproof::Plus( + AggregateRangeStatement::new(outputs.iter().map(Commitment::calculate).collect()) + .unwrap() + .prove(rng, &Zeroizing::new(AggregateRangeWitness::new(outputs).unwrap())) + .unwrap(), + )) + } + + /// Verify the given Bulletproof(+). #[must_use] pub fn verify(&self, rng: &mut R, commitments: &[EdwardsPoint]) -> bool { match self { - Bulletproofs::Original(bp) => bp.verify(rng, commitments), - Bulletproofs::Plus(bp) => { + Bulletproof::Original(bp) => bp.verify(rng, commitments), + Bulletproof::Plus(bp) => { let mut verifier = BatchVerifier::new(1); let Some(statement) = AggregateRangeStatement::new(commitments.to_vec()) else { return false; @@ -114,9 +123,11 @@ impl Bulletproofs { } } - /// Accumulate the verification for the given Bulletproofs into the specified BatchVerifier. - /// Returns false if the Bulletproofs aren't sane, without mutating the BatchVerifier. - /// Returns true if the Bulletproofs are sane, regardless of their validity. + /// Accumulate the verification for the given Bulletproof into the specified BatchVerifier. + /// + /// Returns false if the Bulletproof isn't sane, leaving the BatchVerifier in an undefined + /// state. + /// Returns true if the Bulletproof is sane, regardless of their validity. #[must_use] pub fn batch_verify( &self, @@ -126,8 +137,8 @@ impl Bulletproofs { commitments: &[EdwardsPoint], ) -> bool { match self { - Bulletproofs::Original(bp) => bp.batch_verify(rng, verifier, id, commitments), - Bulletproofs::Plus(bp) => { + Bulletproof::Original(bp) => bp.batch_verify(rng, verifier, id, commitments), + Bulletproof::Plus(bp) => { let Some(statement) = AggregateRangeStatement::new(commitments.to_vec()) else { return false; }; @@ -142,7 +153,7 @@ impl Bulletproofs { specific_write_vec: F, ) -> io::Result<()> { match self { - Bulletproofs::Original(bp) => { + Bulletproof::Original(bp) => { write_point(&bp.A, w)?; write_point(&bp.S, w)?; write_point(&bp.T1, w)?; @@ -156,7 +167,7 @@ impl Bulletproofs { write_scalar(&bp.t, w) } - Bulletproofs::Plus(bp) => { + Bulletproof::Plus(bp) => { write_point(&bp.A.0, w)?; write_point(&bp.wip.A.0, w)?; write_point(&bp.wip.B.0, w)?; @@ -173,19 +184,21 @@ impl Bulletproofs { self.write_core(w, |points, w| write_raw_vec(write_point, points, w)) } + /// Write the Bulletproof(+) to a writer. pub fn write(&self, w: &mut W) -> io::Result<()> { self.write_core(w, |points, w| write_vec(write_point, points, w)) } + /// Serialize the Bulletproof(+) to a Vec. pub fn serialize(&self) -> Vec { let mut serialized = vec![]; self.write(&mut serialized).unwrap(); serialized } - /// Read Bulletproofs. - pub fn read(r: &mut R) -> io::Result { - Ok(Bulletproofs::Original(OriginalStruct { + /// Read a Bulletproof. + pub fn read(r: &mut R) -> io::Result { + Ok(Bulletproof::Original(OriginalStruct { A: read_point(r)?, S: read_point(r)?, T1: read_point(r)?, @@ -200,11 +213,11 @@ impl Bulletproofs { })) } - /// Read Bulletproofs+. - pub fn read_plus(r: &mut R) -> io::Result { + /// Read a Bulletproof+. + pub fn read_plus(r: &mut R) -> io::Result { use dalek_ff_group::{Scalar as DfgScalar, EdwardsPoint as DfgPoint}; - Ok(Bulletproofs::Plus(AggregateRangeProof { + Ok(Bulletproof::Plus(AggregateRangeProof { A: DfgPoint(read_point(r)?), wip: WipProof { A: DfgPoint(read_point(r)?), diff --git a/coins/monero/src/ringct/mod.rs b/coins/monero/src/ringct/mod.rs index bcd7f0c8..3b00dda8 100644 --- a/coins/monero/src/ringct/mod.rs +++ b/coins/monero/src/ringct/mod.rs @@ -23,7 +23,7 @@ pub mod bulletproofs; use crate::{ Protocol, serialize::*, - ringct::{mlsag::Mlsag, clsag::Clsag, borromean::BorromeanRange, bulletproofs::Bulletproofs}, + ringct::{mlsag::Mlsag, clsag::Clsag, borromean::BorromeanRange, bulletproofs::Bulletproof}, }; /// Generate a key image for a given key. Defined as `x * hash_to_point(xG)`. @@ -199,12 +199,12 @@ pub enum RctPrunable { mlsags: Vec, }, MlsagBulletproofs { - bulletproofs: Bulletproofs, + bulletproofs: Bulletproof, mlsags: Vec, pseudo_outs: Vec, }, Clsag { - bulletproofs: Bulletproofs, + bulletproofs: Bulletproof, clsags: Vec, pseudo_outs: Vec, }, @@ -213,7 +213,7 @@ pub enum RctPrunable { impl RctPrunable { pub(crate) fn fee_weight(protocol: Protocol, inputs: usize, outputs: usize) -> usize { // 1 byte for number of BPs (technically a VarInt, yet there's always just zero or one) - 1 + Bulletproofs::fee_weight(protocol.bp_plus(), outputs) + + 1 + Bulletproof::fee_weight(protocol.bp_plus(), outputs) + (inputs * (Clsag::fee_weight(protocol.ring_len()) + 32)) } @@ -294,7 +294,7 @@ impl RctPrunable { { Err(io::Error::other("n bulletproofs instead of one"))?; } - Bulletproofs::read(r)? + Bulletproof::read(r)? }, mlsags: (0 .. inputs) .map(|_| Mlsag::read(ring_length, 2, r)) @@ -307,9 +307,7 @@ impl RctPrunable { if read_varint::<_, u64>(r)? != 1 { Err(io::Error::other("n bulletproofs instead of one"))?; } - (if rct_type == RctType::Clsag { Bulletproofs::read } else { Bulletproofs::read_plus })( - r, - )? + (if rct_type == RctType::Clsag { Bulletproof::read } else { Bulletproof::read_plus })(r)? }, clsags: (0 .. inputs).map(|_| Clsag::read(ring_length, r)).collect::>()?, pseudo_outs: read_raw_vec(read_point, inputs, r)?, @@ -360,7 +358,7 @@ impl RctSignatures { } } RctPrunable::Clsag { bulletproofs, .. } => { - if matches!(bulletproofs, Bulletproofs::Original { .. }) { + if matches!(bulletproofs, Bulletproof::Original { .. }) { RctType::Clsag } else { RctType::BulletproofsPlus diff --git a/coins/monero/src/tests/bulletproofs/mod.rs b/coins/monero/src/tests/bulletproofs/mod.rs index 6c276206..9ffb7bc3 100644 --- a/coins/monero/src/tests/bulletproofs/mod.rs +++ b/coins/monero/src/tests/bulletproofs/mod.rs @@ -7,7 +7,8 @@ use multiexp::BatchVerifier; use crate::{ Commitment, random_scalar, - ringct::bulletproofs::{Bulletproofs, original::OriginalStruct}, + ringct::bulletproofs::{Bulletproof, original::OriginalStruct}, + wallet::TransactionError, }; mod plus; @@ -18,7 +19,7 @@ fn bulletproofs_vector() { let point = |point| decompress_point(point).unwrap(); // Generated from Monero - assert!(Bulletproofs::Original(OriginalStruct { + assert!(Bulletproof::Original(OriginalStruct { A: point(hex!("ef32c0b9551b804decdcb107eb22aa715b7ce259bf3c5cac20e24dfa6b28ac71")), S: point(hex!("e1285960861783574ee2b689ae53622834eb0b035d6943103f960cd23e063fa0")), T1: point(hex!("4ea07735f184ba159d0e0eb662bac8cde3eb7d39f31e567b0fbda3aa23fe5620")), @@ -70,7 +71,11 @@ macro_rules! bulletproofs_tests { .map(|i| Commitment::new(random_scalar(&mut OsRng), u64::try_from(i).unwrap())) .collect::>(); - let bp = Bulletproofs::prove(&mut OsRng, &commitments, $plus).unwrap(); + let bp = if $plus { + Bulletproof::prove_plus(&mut OsRng, commitments.clone()).unwrap() + } else { + Bulletproof::prove(&mut OsRng, &commitments).unwrap() + }; let commitments = commitments.iter().map(Commitment::calculate).collect::>(); assert!(bp.verify(&mut OsRng, &commitments)); @@ -86,7 +91,15 @@ macro_rules! bulletproofs_tests { for _ in 0 .. 17 { commitments.push(Commitment::new(Scalar::ZERO, 0)); } - assert!(Bulletproofs::prove(&mut OsRng, &commitments, $plus).is_err()); + assert_eq!( + (if $plus { + Bulletproof::prove_plus(&mut OsRng, commitments) + } else { + Bulletproof::prove(&mut OsRng, &commitments) + }) + .unwrap_err(), + TransactionError::TooManyOutputs, + ); } }; } diff --git a/coins/monero/src/transaction.rs b/coins/monero/src/transaction.rs index 89d489fe..2d351ff1 100644 --- a/coins/monero/src/transaction.rs +++ b/coins/monero/src/transaction.rs @@ -12,7 +12,7 @@ use crate::{ Protocol, hash, serialize::*, ring_signatures::RingSignature, - ringct::{bulletproofs::Bulletproofs, RctType, RctBase, RctPrunable, RctSignatures}, + ringct::{bulletproofs::Bulletproof, RctType, RctBase, RctPrunable, RctSignatures}, }; #[derive(Clone, PartialEq, Eq, Debug)] @@ -426,7 +426,7 @@ impl Transaction { if !(bp || bp_plus) { blob_size } else { - blob_size + Bulletproofs::calculate_bp_clawback(bp_plus, self.prefix.outputs.len()).0 + blob_size + Bulletproof::calculate_bp_clawback(bp_plus, self.prefix.outputs.len()).0 } } } diff --git a/coins/monero/src/wallet/send/mod.rs b/coins/monero/src/wallet/send/mod.rs index f4ac208e..6a0ab889 100644 --- a/coins/monero/src/wallet/send/mod.rs +++ b/coins/monero/src/wallet/send/mod.rs @@ -31,7 +31,7 @@ use crate::{ ringct::{ generate_key_image, clsag::{ClsagError, ClsagInput, Clsag}, - bulletproofs::{MAX_OUTPUTS, Bulletproofs}, + bulletproofs::{MAX_OUTPUTS, Bulletproof}, RctBase, RctPrunable, RctSignatures, }, transaction::{Input, Output, Timelock, TransactionPrefix, Transaction}, @@ -783,7 +783,11 @@ impl SignableTransaction { let sum = commitments.iter().map(|commitment| commitment.mask).sum(); // Safe due to the constructor checking MAX_OUTPUTS - let bp = Bulletproofs::prove(rng, &commitments, self.protocol.bp_plus()).unwrap(); + let bp = if self.protocol.bp_plus() { + Bulletproof::prove_plus(rng, commitments.clone()).unwrap() + } else { + Bulletproof::prove(rng, &commitments).unwrap() + }; // Create the TX extra let extra = Self::extra(