From be68e27551515adc7b5ace8cbda4df2aa81e5c01 Mon Sep 17 00:00:00 2001 From: Luke Parker Date: Mon, 15 Sep 2025 22:37:59 -0400 Subject: [PATCH] Tweak `multiexp` to compile on `core` On `core`, it'll use a serial implementation of no benefit other than the fact that when `alloc` _is_ enabled, it'll use the multi-scalar multiplication algorithms. `schnorr-signatures` was prior tweaked to include a shim for `SchnorrSignature::verify` which didn't use `multiexp_vartime` yet this same premise. Now, instead of callers writing these shims, it's within `multiexp`. --- Cargo.lock | 2 - crypto/multiexp/Cargo.toml | 11 +- crypto/multiexp/README.md | 5 +- crypto/multiexp/src/batch.rs | 2 +- crypto/multiexp/src/lib.rs | 307 ++++++++++++++----------------- crypto/multiexp/src/pippenger.rs | 2 + crypto/multiexp/src/straus.rs | 2 +- crypto/schnorr/Cargo.toml | 4 +- crypto/schnorr/src/lib.rs | 10 +- tests/no-std/Cargo.toml | 5 +- 10 files changed, 161 insertions(+), 189 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ef25987e..a09e75d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6334,8 +6334,6 @@ dependencies = [ "group", "k256", "rand_core 0.6.4", - "rustversion", - "std-shims", "zeroize", ] diff --git a/crypto/multiexp/Cargo.toml b/crypto/multiexp/Cargo.toml index 18efef3f..8239c68b 100644 --- a/crypto/multiexp/Cargo.toml +++ b/crypto/multiexp/Cargo.toml @@ -17,11 +17,7 @@ rustdoc-args = ["--cfg", "docsrs"] workspace = true [dependencies] -rustversion = "1" - -std-shims = { path = "../../common/std-shims", version = "0.1.1", default-features = false, features = ["alloc"] } - -zeroize = { version = "^1.5", default-features = false, features = ["zeroize_derive", "alloc"] } +zeroize = { version = "^1.5", default-features = false, features = ["zeroize_derive"] } ff = { version = "0.13", default-features = false, features = ["bits"] } group = { version = "0.13", default-features = false } @@ -35,8 +31,9 @@ k256 = { version = "^0.13.1", default-features = false, features = ["arithmetic" dalek-ff-group = { path = "../dalek-ff-group" } [features] -std = ["std-shims/std", "zeroize/std", "ff/std", "rand_core?/std"] +alloc = ["zeroize/alloc"] +std = ["alloc", "zeroize/std", "ff/std", "rand_core?/std"] -batch = ["rand_core"] +batch = ["alloc", "rand_core"] default = ["std"] diff --git a/crypto/multiexp/README.md b/crypto/multiexp/README.md index 1366f7a6..547dfa06 100644 --- a/crypto/multiexp/README.md +++ b/crypto/multiexp/README.md @@ -12,5 +12,6 @@ culminating in commit [669d2dbffc1dafb82a09d9419ea182667115df06](https://github.com/serai-dex/serai/tree/669d2dbffc1dafb82a09d9419ea182667115df06). Any subsequent changes have not undergone auditing. -This library is usable under no_std, via alloc, when the default features are -disabled. +This library is usable under no-`std` and no-`alloc`. With the `alloc` feature, +the library is fully functional. Without the `alloc` feature, the `multiexp` +function is shimmed with a serial implementation. diff --git a/crypto/multiexp/src/batch.rs b/crypto/multiexp/src/batch.rs index ea8044dd..50f07fc8 100644 --- a/crypto/multiexp/src/batch.rs +++ b/crypto/multiexp/src/batch.rs @@ -1,4 +1,4 @@ -use std_shims::vec::Vec; +use alloc::vec::Vec; use rand_core::{RngCore, CryptoRng}; diff --git a/crypto/multiexp/src/lib.rs b/crypto/multiexp/src/lib.rs index 8b16aa91..9bc35e4e 100644 --- a/crypto/multiexp/src/lib.rs +++ b/crypto/multiexp/src/lib.rs @@ -2,200 +2,177 @@ #![doc = include_str!("../README.md")] #![cfg_attr(not(feature = "std"), no_std)] -#[cfg(not(feature = "std"))] -#[macro_use] +#[cfg(feature = "alloc")] extern crate alloc; -#[allow(unused_imports)] -use std_shims::prelude::*; -use std_shims::vec::Vec; use zeroize::Zeroize; use ff::PrimeFieldBits; use group::Group; +#[cfg(feature = "alloc")] mod straus; -use straus::*; - +#[cfg(feature = "alloc")] mod pippenger; -use pippenger::*; #[cfg(feature = "batch")] mod batch; -#[cfg(feature = "batch")] -pub use batch::BatchVerifier; -#[cfg(test)] +#[cfg(all(test, feature = "alloc"))] mod tests; -// Use black_box when possible -#[rustversion::since(1.66)] -use core::hint::black_box; -#[rustversion::before(1.66)] -fn black_box(val: T) -> T { - val -} +#[cfg(feature = "alloc")] +mod underlying { + use super::*; -fn u8_from_bool(bit_ref: &mut bool) -> u8 { - let bit_ref = black_box(bit_ref); + use core::hint::black_box; + use alloc::{vec, vec::Vec}; - let mut bit = black_box(*bit_ref); - #[allow(clippy::cast_lossless)] - let res = black_box(bit as u8); - bit.zeroize(); - debug_assert!((res | 1) == 1); + pub(crate) use straus::*; - bit_ref.zeroize(); - res -} + pub(crate) use pippenger::*; -// Convert scalars to `window`-sized bit groups, as needed to index a table -// This algorithm works for `window <= 8` -pub(crate) fn prep_bits>( - pairs: &[(G::Scalar, G)], - window: u8, -) -> Vec> { - let w_usize = usize::from(window); + #[cfg(feature = "batch")] + pub use batch::BatchVerifier; - let mut groupings = vec![]; - for pair in pairs { - let p = groupings.len(); - let mut bits = pair.0.to_le_bits(); - groupings.push(vec![0; bits.len().div_ceil(w_usize)]); + fn u8_from_bool(bit_ref: &mut bool) -> u8 { + let bit_ref = black_box(bit_ref); - for (i, mut bit) in bits.iter_mut().enumerate() { - let mut bit = u8_from_bool(&mut bit); - groupings[p][i / w_usize] |= bit << (i % w_usize); - bit.zeroize(); + let mut bit = black_box(*bit_ref); + #[allow(clippy::cast_lossless)] + let res = black_box(bit as u8); + bit.zeroize(); + debug_assert!((res | 1) == 1); + + bit_ref.zeroize(); + res + } + + // Convert scalars to `window`-sized bit groups, as needed to index a table + // This algorithm works for `window <= 8` + pub(crate) fn prep_bits>( + pairs: &[(G::Scalar, G)], + window: u8, + ) -> Vec> { + let w_usize = usize::from(window); + + let mut groupings = vec![]; + for pair in pairs { + let p = groupings.len(); + let mut bits = pair.0.to_le_bits(); + groupings.push(vec![0; bits.len().div_ceil(w_usize)]); + + for (i, mut bit) in bits.iter_mut().enumerate() { + let mut bit = u8_from_bool(&mut bit); + groupings[p][i / w_usize] |= bit << (i % w_usize); + bit.zeroize(); + } + } + + groupings + } + + #[derive(Clone, Copy, PartialEq, Eq, Debug)] + enum Algorithm { + Null, + Single, + Straus(u8), + Pippenger(u8), + } + + // These are 'rule of thumb's obtained via benchmarking `k256` and `curve25519-dalek` + fn algorithm(len: usize) -> Algorithm { + #[cfg(not(debug_assertions))] + if len == 0 { + Algorithm::Null + } else if len == 1 { + Algorithm::Single + } else if len < 10 { + // Straus 2 never showed a performance benefit, even with just 2 elements + Algorithm::Straus(3) + } else if len < 20 { + Algorithm::Straus(4) + } else if len < 50 { + Algorithm::Straus(5) + } else if len < 100 { + Algorithm::Pippenger(4) + } else if len < 125 { + Algorithm::Pippenger(5) + } else if len < 275 { + Algorithm::Pippenger(6) + } else if len < 400 { + Algorithm::Pippenger(7) + } else { + Algorithm::Pippenger(8) + } + + #[cfg(debug_assertions)] + if len == 0 { + Algorithm::Null + } else if len == 1 { + Algorithm::Single + } else if len < 10 { + Algorithm::Straus(3) + } else if len < 80 { + Algorithm::Straus(4) + } else if len < 100 { + Algorithm::Straus(5) + } else if len < 125 { + Algorithm::Pippenger(4) + } else if len < 275 { + Algorithm::Pippenger(5) + } else if len < 475 { + Algorithm::Pippenger(6) + } else if len < 750 { + Algorithm::Pippenger(7) + } else { + Algorithm::Pippenger(8) } } - groupings -} - -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -enum Algorithm { - Null, - Single, - Straus(u8), - Pippenger(u8), -} - -/* -Release (with runs 20, so all of these are off by 20x): - -k256 -Straus 3 is more efficient at 5 with 678µs per -Straus 4 is more efficient at 10 with 530µs per -Straus 5 is more efficient at 35 with 467µs per - -Pippenger 5 is more efficient at 125 with 431µs per -Pippenger 6 is more efficient at 275 with 349µs per -Pippenger 7 is more efficient at 375 with 360µs per - -dalek -Straus 3 is more efficient at 5 with 519µs per -Straus 4 is more efficient at 10 with 376µs per -Straus 5 is more efficient at 170 with 330µs per - -Pippenger 5 is more efficient at 125 with 305µs per -Pippenger 6 is more efficient at 275 with 250µs per -Pippenger 7 is more efficient at 450 with 205µs per -Pippenger 8 is more efficient at 800 with 213µs per - -Debug (with runs 5, so...): - -k256 -Straus 3 is more efficient at 5 with 2532µs per -Straus 4 is more efficient at 10 with 1930µs per -Straus 5 is more efficient at 80 with 1632µs per - -Pippenger 5 is more efficient at 150 with 1441µs per -Pippenger 6 is more efficient at 300 with 1235µs per -Pippenger 7 is more efficient at 475 with 1182µs per -Pippenger 8 is more efficient at 625 with 1170µs per - -dalek: -Straus 3 is more efficient at 5 with 971µs per -Straus 4 is more efficient at 10 with 782µs per -Straus 5 is more efficient at 75 with 778µs per -Straus 6 is more efficient at 165 with 867µs per - -Pippenger 5 is more efficient at 125 with 677µs per -Pippenger 6 is more efficient at 250 with 655µs per -Pippenger 7 is more efficient at 475 with 500µs per -Pippenger 8 is more efficient at 875 with 499µs per -*/ -fn algorithm(len: usize) -> Algorithm { - #[cfg(not(debug_assertions))] - if len == 0 { - Algorithm::Null - } else if len == 1 { - Algorithm::Single - } else if len < 10 { - // Straus 2 never showed a performance benefit, even with just 2 elements - Algorithm::Straus(3) - } else if len < 20 { - Algorithm::Straus(4) - } else if len < 50 { - Algorithm::Straus(5) - } else if len < 100 { - Algorithm::Pippenger(4) - } else if len < 125 { - Algorithm::Pippenger(5) - } else if len < 275 { - Algorithm::Pippenger(6) - } else if len < 400 { - Algorithm::Pippenger(7) - } else { - Algorithm::Pippenger(8) + /// Performs a multiexponentiation, automatically selecting the optimal algorithm based on the + /// amount of pairs. + pub fn multiexp>( + pairs: &[(G::Scalar, G)], + ) -> G { + match algorithm(pairs.len()) { + Algorithm::Null => Group::identity(), + Algorithm::Single => pairs[0].1 * pairs[0].0, + // These functions panic if called without any pairs + Algorithm::Straus(window) => straus(pairs, window), + Algorithm::Pippenger(window) => pippenger(pairs, window), + } } - #[cfg(debug_assertions)] - if len == 0 { - Algorithm::Null - } else if len == 1 { - Algorithm::Single - } else if len < 10 { - Algorithm::Straus(3) - } else if len < 80 { - Algorithm::Straus(4) - } else if len < 100 { - Algorithm::Straus(5) - } else if len < 125 { - Algorithm::Pippenger(4) - } else if len < 275 { - Algorithm::Pippenger(5) - } else if len < 475 { - Algorithm::Pippenger(6) - } else if len < 750 { - Algorithm::Pippenger(7) - } else { - Algorithm::Pippenger(8) + /// Performs a multiexponentiation in variable time, automatically selecting the optimal algorithm + /// based on the amount of pairs. + pub fn multiexp_vartime>(pairs: &[(G::Scalar, G)]) -> G { + match algorithm(pairs.len()) { + Algorithm::Null => Group::identity(), + Algorithm::Single => pairs[0].1 * pairs[0].0, + Algorithm::Straus(window) => straus_vartime(pairs, window), + Algorithm::Pippenger(window) => pippenger_vartime(pairs, window), + } } } -/// Performs a multiexponentiation, automatically selecting the optimal algorithm based on the -/// amount of pairs. -pub fn multiexp>( - pairs: &[(G::Scalar, G)], -) -> G { - match algorithm(pairs.len()) { - Algorithm::Null => Group::identity(), - Algorithm::Single => pairs[0].1 * pairs[0].0, - // These functions panic if called without any pairs - Algorithm::Straus(window) => straus(pairs, window), - Algorithm::Pippenger(window) => pippenger(pairs, window), +#[cfg(not(feature = "alloc"))] +mod underlying { + use super::*; + + /// Performs a multiexponentiation, automatically selecting the optimal algorithm based on the + /// amount of pairs. + pub fn multiexp>( + pairs: &[(G::Scalar, G)], + ) -> G { + pairs.iter().map(|(scalar, point)| *point * scalar).sum() + } + + /// Performs a multiexponentiation in variable time, automatically selecting the optimal algorithm + /// based on the amount of pairs. + pub fn multiexp_vartime>(pairs: &[(G::Scalar, G)]) -> G { + pairs.iter().map(|(scalar, point)| *point * scalar).sum() } } -/// Performs a multiexponentiation in variable time, automatically selecting the optimal algorithm -/// based on the amount of pairs. -pub fn multiexp_vartime>(pairs: &[(G::Scalar, G)]) -> G { - match algorithm(pairs.len()) { - Algorithm::Null => Group::identity(), - Algorithm::Single => pairs[0].1 * pairs[0].0, - Algorithm::Straus(window) => straus_vartime(pairs, window), - Algorithm::Pippenger(window) => pippenger_vartime(pairs, window), - } -} +pub use underlying::*; diff --git a/crypto/multiexp/src/pippenger.rs b/crypto/multiexp/src/pippenger.rs index faf9edc2..42f91ab2 100644 --- a/crypto/multiexp/src/pippenger.rs +++ b/crypto/multiexp/src/pippenger.rs @@ -1,3 +1,5 @@ +use alloc::vec; + use zeroize::Zeroize; use ff::PrimeFieldBits; diff --git a/crypto/multiexp/src/straus.rs b/crypto/multiexp/src/straus.rs index 638b2827..c8994f1b 100644 --- a/crypto/multiexp/src/straus.rs +++ b/crypto/multiexp/src/straus.rs @@ -1,4 +1,4 @@ -use std_shims::vec::Vec; +use alloc::{vec, vec::Vec}; use zeroize::Zeroize; diff --git a/crypto/schnorr/Cargo.toml b/crypto/schnorr/Cargo.toml index 006275a1..0314223a 100644 --- a/crypto/schnorr/Cargo.toml +++ b/crypto/schnorr/Cargo.toml @@ -27,7 +27,7 @@ digest = { version = "0.11.0-rc.1", default-features = false, features = ["block transcript = { package = "flexible-transcript", path = "../transcript", version = "^0.3.2", default-features = false, optional = true } ciphersuite = { path = "../ciphersuite", version = "^0.4.1", default-features = false } -multiexp = { path = "../multiexp", version = "0.4", default-features = false, features = ["batch"], optional = true } +multiexp = { path = "../multiexp", version = "0.4", default-features = false } [dev-dependencies] hex = "0.4" @@ -40,7 +40,7 @@ dalek-ff-group = { path = "../dalek-ff-group" } ciphersuite = { path = "../ciphersuite" } [features] -alloc = ["zeroize/alloc", "digest/alloc", "ciphersuite/alloc", "multiexp"] +alloc = ["zeroize/alloc", "digest/alloc", "ciphersuite/alloc", "multiexp/alloc", "multiexp/batch"] aggregate = ["alloc", "transcript"] std = ["alloc", "std-shims/std", "rand_core/std", "zeroize/std", "transcript?/std", "ciphersuite/std", "multiexp/std"] default = ["std"] diff --git a/crypto/schnorr/src/lib.rs b/crypto/schnorr/src/lib.rs index 071ef8e6..23f64c79 100644 --- a/crypto/schnorr/src/lib.rs +++ b/crypto/schnorr/src/lib.rs @@ -23,8 +23,9 @@ use ciphersuite::{ }, GroupIo, }; +use multiexp::multiexp_vartime; #[cfg(feature = "alloc")] -use multiexp::{multiexp_vartime, BatchVerifier}; +use multiexp::BatchVerifier; /// Half-aggregation from . #[cfg(feature = "aggregate")] @@ -109,12 +110,7 @@ impl SchnorrSignature { /// different keys/messages. #[must_use] pub fn verify(&self, public_key: C::G, challenge: C::F) -> bool { - let statements = self.batch_statements(public_key, challenge); - #[cfg(feature = "alloc")] - let res = multiexp_vartime(&statements); - #[cfg(not(feature = "alloc"))] - let res = statements.into_iter().map(|(scalar, point)| point * scalar).sum::(); - res.is_identity().into() + multiexp_vartime(&self.batch_statements(public_key, challenge)).is_identity().into() } /// Queue a signature for batch verification. diff --git a/tests/no-std/Cargo.toml b/tests/no-std/Cargo.toml index a828fe9c..0b75f8ab 100644 --- a/tests/no-std/Cargo.toml +++ b/tests/no-std/Cargo.toml @@ -21,7 +21,7 @@ std-shims = { path = "../../common/std-shims", default-features = false } flexible-transcript = { path = "../../crypto/transcript", default-features = false, features = ["recommended", "merlin"] } -multiexp = { path = "../../crypto/multiexp", default-features = false, features = ["batch"], optional = true } +multiexp = { path = "../../crypto/multiexp", default-features = false } dalek-ff-group = { path = "../../crypto/dalek-ff-group", default-features = false } minimal-ed448 = { path = "../../crypto/ed448", default-features = false } @@ -46,7 +46,8 @@ bitcoin-serai = { path = "../../networks/bitcoin", default-features = false, fea alloc = [ "std-shims/alloc", - "multiexp", + "multiexp/alloc", + "multiexp/batch", "dalek-ff-group/alloc", "minimal-ed448/alloc",