Tweak multiexp to compile on core

On `core`, it'll use a serial implementation of no benefit other than the fact
that when `alloc` _is_ enabled, it'll use the multi-scalar multiplication
algorithms.

`schnorr-signatures` was prior tweaked to include a shim for
`SchnorrSignature::verify` which didn't use `multiexp_vartime` yet this same
premise. Now, instead of callers writing these shims, it's within `multiexp`.
This commit is contained in:
Luke Parker
2025-09-15 22:37:59 -04:00
parent a82ccadbb0
commit 2be69b23b1
10 changed files with 161 additions and 189 deletions

2
Cargo.lock generated
View File

@@ -5714,8 +5714,6 @@ dependencies = [
"group",
"k256",
"rand_core 0.6.4",
"rustversion",
"std-shims",
"zeroize",
]

View File

@@ -17,11 +17,7 @@ rustdoc-args = ["--cfg", "docsrs"]
workspace = true
[dependencies]
rustversion = "1"
std-shims = { path = "../../common/std-shims", version = "0.1.1", default-features = false, features = ["alloc"] }
zeroize = { version = "^1.5", default-features = false, features = ["zeroize_derive", "alloc"] }
zeroize = { version = "^1.5", default-features = false, features = ["zeroize_derive"] }
ff = { version = "0.13", default-features = false, features = ["bits"] }
group = { version = "0.13", default-features = false }
@@ -35,8 +31,9 @@ k256 = { version = "^0.13.1", default-features = false, features = ["arithmetic"
dalek-ff-group = { path = "../dalek-ff-group" }
[features]
std = ["std-shims/std", "zeroize/std", "ff/std", "rand_core?/std"]
alloc = ["zeroize/alloc"]
std = ["alloc", "zeroize/std", "ff/std", "rand_core?/std"]
batch = ["rand_core"]
batch = ["alloc", "rand_core"]
default = ["std"]

View File

@@ -12,5 +12,6 @@ culminating in commit
[669d2dbffc1dafb82a09d9419ea182667115df06](https://github.com/serai-dex/serai/tree/669d2dbffc1dafb82a09d9419ea182667115df06).
Any subsequent changes have not undergone auditing.
This library is usable under no_std, via alloc, when the default features are
disabled.
This library is usable under no-`std` and no-`alloc`. With the `alloc` feature,
the library is fully functional. Without the `alloc` feature, the `multiexp`
function is shimmed with a serial implementation.

View File

@@ -1,4 +1,4 @@
use std_shims::vec::Vec;
use alloc::vec::Vec;
use rand_core::{RngCore, CryptoRng};

View File

@@ -2,200 +2,177 @@
#![doc = include_str!("../README.md")]
#![cfg_attr(not(feature = "std"), no_std)]
#[cfg(not(feature = "std"))]
#[macro_use]
#[cfg(feature = "alloc")]
extern crate alloc;
#[allow(unused_imports)]
use std_shims::prelude::*;
use std_shims::vec::Vec;
use zeroize::Zeroize;
use ff::PrimeFieldBits;
use group::Group;
#[cfg(feature = "alloc")]
mod straus;
use straus::*;
#[cfg(feature = "alloc")]
mod pippenger;
use pippenger::*;
#[cfg(feature = "batch")]
mod batch;
#[cfg(feature = "batch")]
pub use batch::BatchVerifier;
#[cfg(test)]
#[cfg(all(test, feature = "alloc"))]
mod tests;
// Use black_box when possible
#[rustversion::since(1.66)]
use core::hint::black_box;
#[rustversion::before(1.66)]
fn black_box<T>(val: T) -> T {
val
}
#[cfg(feature = "alloc")]
mod underlying {
use super::*;
fn u8_from_bool(bit_ref: &mut bool) -> u8 {
let bit_ref = black_box(bit_ref);
use core::hint::black_box;
use alloc::{vec, vec::Vec};
let mut bit = black_box(*bit_ref);
#[allow(clippy::cast_lossless)]
let res = black_box(bit as u8);
bit.zeroize();
debug_assert!((res | 1) == 1);
pub(crate) use straus::*;
bit_ref.zeroize();
res
}
pub(crate) use pippenger::*;
// Convert scalars to `window`-sized bit groups, as needed to index a table
// This algorithm works for `window <= 8`
pub(crate) fn prep_bits<G: Group<Scalar: PrimeFieldBits>>(
pairs: &[(G::Scalar, G)],
window: u8,
) -> Vec<Vec<u8>> {
let w_usize = usize::from(window);
#[cfg(feature = "batch")]
pub use batch::BatchVerifier;
let mut groupings = vec![];
for pair in pairs {
let p = groupings.len();
let mut bits = pair.0.to_le_bits();
groupings.push(vec![0; bits.len().div_ceil(w_usize)]);
fn u8_from_bool(bit_ref: &mut bool) -> u8 {
let bit_ref = black_box(bit_ref);
for (i, mut bit) in bits.iter_mut().enumerate() {
let mut bit = u8_from_bool(&mut bit);
groupings[p][i / w_usize] |= bit << (i % w_usize);
bit.zeroize();
let mut bit = black_box(*bit_ref);
#[allow(clippy::cast_lossless)]
let res = black_box(bit as u8);
bit.zeroize();
debug_assert!((res | 1) == 1);
bit_ref.zeroize();
res
}
// Convert scalars to `window`-sized bit groups, as needed to index a table
// This algorithm works for `window <= 8`
pub(crate) fn prep_bits<G: Group<Scalar: PrimeFieldBits>>(
pairs: &[(G::Scalar, G)],
window: u8,
) -> Vec<Vec<u8>> {
let w_usize = usize::from(window);
let mut groupings = vec![];
for pair in pairs {
let p = groupings.len();
let mut bits = pair.0.to_le_bits();
groupings.push(vec![0; bits.len().div_ceil(w_usize)]);
for (i, mut bit) in bits.iter_mut().enumerate() {
let mut bit = u8_from_bool(&mut bit);
groupings[p][i / w_usize] |= bit << (i % w_usize);
bit.zeroize();
}
}
groupings
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Algorithm {
Null,
Single,
Straus(u8),
Pippenger(u8),
}
// These are 'rule of thumb's obtained via benchmarking `k256` and `curve25519-dalek`
fn algorithm(len: usize) -> Algorithm {
#[cfg(not(debug_assertions))]
if len == 0 {
Algorithm::Null
} else if len == 1 {
Algorithm::Single
} else if len < 10 {
// Straus 2 never showed a performance benefit, even with just 2 elements
Algorithm::Straus(3)
} else if len < 20 {
Algorithm::Straus(4)
} else if len < 50 {
Algorithm::Straus(5)
} else if len < 100 {
Algorithm::Pippenger(4)
} else if len < 125 {
Algorithm::Pippenger(5)
} else if len < 275 {
Algorithm::Pippenger(6)
} else if len < 400 {
Algorithm::Pippenger(7)
} else {
Algorithm::Pippenger(8)
}
#[cfg(debug_assertions)]
if len == 0 {
Algorithm::Null
} else if len == 1 {
Algorithm::Single
} else if len < 10 {
Algorithm::Straus(3)
} else if len < 80 {
Algorithm::Straus(4)
} else if len < 100 {
Algorithm::Straus(5)
} else if len < 125 {
Algorithm::Pippenger(4)
} else if len < 275 {
Algorithm::Pippenger(5)
} else if len < 475 {
Algorithm::Pippenger(6)
} else if len < 750 {
Algorithm::Pippenger(7)
} else {
Algorithm::Pippenger(8)
}
}
groupings
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Algorithm {
Null,
Single,
Straus(u8),
Pippenger(u8),
}
/*
Release (with runs 20, so all of these are off by 20x):
k256
Straus 3 is more efficient at 5 with 678µs per
Straus 4 is more efficient at 10 with 530µs per
Straus 5 is more efficient at 35 with 467µs per
Pippenger 5 is more efficient at 125 with 431µs per
Pippenger 6 is more efficient at 275 with 349µs per
Pippenger 7 is more efficient at 375 with 360µs per
dalek
Straus 3 is more efficient at 5 with 519µs per
Straus 4 is more efficient at 10 with 376µs per
Straus 5 is more efficient at 170 with 330µs per
Pippenger 5 is more efficient at 125 with 305µs per
Pippenger 6 is more efficient at 275 with 250µs per
Pippenger 7 is more efficient at 450 with 205µs per
Pippenger 8 is more efficient at 800 with 213µs per
Debug (with runs 5, so...):
k256
Straus 3 is more efficient at 5 with 2532µs per
Straus 4 is more efficient at 10 with 1930µs per
Straus 5 is more efficient at 80 with 1632µs per
Pippenger 5 is more efficient at 150 with 1441µs per
Pippenger 6 is more efficient at 300 with 1235µs per
Pippenger 7 is more efficient at 475 with 1182µs per
Pippenger 8 is more efficient at 625 with 1170µs per
dalek:
Straus 3 is more efficient at 5 with 971µs per
Straus 4 is more efficient at 10 with 782µs per
Straus 5 is more efficient at 75 with 778µs per
Straus 6 is more efficient at 165 with 867µs per
Pippenger 5 is more efficient at 125 with 677µs per
Pippenger 6 is more efficient at 250 with 655µs per
Pippenger 7 is more efficient at 475 with 500µs per
Pippenger 8 is more efficient at 875 with 499µs per
*/
fn algorithm(len: usize) -> Algorithm {
#[cfg(not(debug_assertions))]
if len == 0 {
Algorithm::Null
} else if len == 1 {
Algorithm::Single
} else if len < 10 {
// Straus 2 never showed a performance benefit, even with just 2 elements
Algorithm::Straus(3)
} else if len < 20 {
Algorithm::Straus(4)
} else if len < 50 {
Algorithm::Straus(5)
} else if len < 100 {
Algorithm::Pippenger(4)
} else if len < 125 {
Algorithm::Pippenger(5)
} else if len < 275 {
Algorithm::Pippenger(6)
} else if len < 400 {
Algorithm::Pippenger(7)
} else {
Algorithm::Pippenger(8)
/// Performs a multiexponentiation, automatically selecting the optimal algorithm based on the
/// amount of pairs.
pub fn multiexp<G: Zeroize + Group<Scalar: Zeroize + PrimeFieldBits>>(
pairs: &[(G::Scalar, G)],
) -> G {
match algorithm(pairs.len()) {
Algorithm::Null => Group::identity(),
Algorithm::Single => pairs[0].1 * pairs[0].0,
// These functions panic if called without any pairs
Algorithm::Straus(window) => straus(pairs, window),
Algorithm::Pippenger(window) => pippenger(pairs, window),
}
}
#[cfg(debug_assertions)]
if len == 0 {
Algorithm::Null
} else if len == 1 {
Algorithm::Single
} else if len < 10 {
Algorithm::Straus(3)
} else if len < 80 {
Algorithm::Straus(4)
} else if len < 100 {
Algorithm::Straus(5)
} else if len < 125 {
Algorithm::Pippenger(4)
} else if len < 275 {
Algorithm::Pippenger(5)
} else if len < 475 {
Algorithm::Pippenger(6)
} else if len < 750 {
Algorithm::Pippenger(7)
} else {
Algorithm::Pippenger(8)
/// Performs a multiexponentiation in variable time, automatically selecting the optimal algorithm
/// based on the amount of pairs.
pub fn multiexp_vartime<G: Group<Scalar: PrimeFieldBits>>(pairs: &[(G::Scalar, G)]) -> G {
match algorithm(pairs.len()) {
Algorithm::Null => Group::identity(),
Algorithm::Single => pairs[0].1 * pairs[0].0,
Algorithm::Straus(window) => straus_vartime(pairs, window),
Algorithm::Pippenger(window) => pippenger_vartime(pairs, window),
}
}
}
/// Performs a multiexponentiation, automatically selecting the optimal algorithm based on the
/// amount of pairs.
pub fn multiexp<G: Zeroize + Group<Scalar: Zeroize + PrimeFieldBits>>(
pairs: &[(G::Scalar, G)],
) -> G {
match algorithm(pairs.len()) {
Algorithm::Null => Group::identity(),
Algorithm::Single => pairs[0].1 * pairs[0].0,
// These functions panic if called without any pairs
Algorithm::Straus(window) => straus(pairs, window),
Algorithm::Pippenger(window) => pippenger(pairs, window),
#[cfg(not(feature = "alloc"))]
mod underlying {
use super::*;
/// Performs a multiexponentiation, automatically selecting the optimal algorithm based on the
/// amount of pairs.
pub fn multiexp<G: Zeroize + Group<Scalar: Zeroize + PrimeFieldBits>>(
pairs: &[(G::Scalar, G)],
) -> G {
pairs.iter().map(|(scalar, point)| *point * scalar).sum()
}
/// Performs a multiexponentiation in variable time, automatically selecting the optimal algorithm
/// based on the amount of pairs.
pub fn multiexp_vartime<G: Group<Scalar: PrimeFieldBits>>(pairs: &[(G::Scalar, G)]) -> G {
pairs.iter().map(|(scalar, point)| *point * scalar).sum()
}
}
/// Performs a multiexponentiation in variable time, automatically selecting the optimal algorithm
/// based on the amount of pairs.
pub fn multiexp_vartime<G: Group<Scalar: PrimeFieldBits>>(pairs: &[(G::Scalar, G)]) -> G {
match algorithm(pairs.len()) {
Algorithm::Null => Group::identity(),
Algorithm::Single => pairs[0].1 * pairs[0].0,
Algorithm::Straus(window) => straus_vartime(pairs, window),
Algorithm::Pippenger(window) => pippenger_vartime(pairs, window),
}
}
pub use underlying::*;

View File

@@ -1,3 +1,5 @@
use alloc::vec;
use zeroize::Zeroize;
use ff::PrimeFieldBits;

View File

@@ -1,4 +1,4 @@
use std_shims::vec::Vec;
use alloc::{vec, vec::Vec};
use zeroize::Zeroize;

View File

@@ -27,7 +27,7 @@ digest = { version = "0.11.0-rc.1", default-features = false, features = ["block
transcript = { package = "flexible-transcript", path = "../transcript", version = "^0.3.2", default-features = false, optional = true }
ciphersuite = { path = "../ciphersuite", version = "^0.4.1", default-features = false }
multiexp = { path = "../multiexp", version = "0.4", default-features = false, features = ["batch"], optional = true }
multiexp = { path = "../multiexp", version = "0.4", default-features = false }
[dev-dependencies]
hex = "0.4"
@@ -40,7 +40,7 @@ dalek-ff-group = { path = "../dalek-ff-group" }
ciphersuite = { path = "../ciphersuite" }
[features]
alloc = ["zeroize/alloc", "digest/alloc", "ciphersuite/alloc", "multiexp"]
alloc = ["zeroize/alloc", "digest/alloc", "ciphersuite/alloc", "multiexp/alloc", "multiexp/batch"]
aggregate = ["alloc", "transcript"]
std = ["alloc", "std-shims/std", "rand_core/std", "zeroize/std", "transcript?/std", "ciphersuite/std", "multiexp/std"]
default = ["std"]

View File

@@ -23,8 +23,9 @@ use ciphersuite::{
},
GroupIo,
};
use multiexp::multiexp_vartime;
#[cfg(feature = "alloc")]
use multiexp::{multiexp_vartime, BatchVerifier};
use multiexp::BatchVerifier;
/// Half-aggregation from <https://eprint.iacr.org/2021/350>.
#[cfg(feature = "aggregate")]
@@ -109,12 +110,7 @@ impl<C: GroupIo> SchnorrSignature<C> {
/// different keys/messages.
#[must_use]
pub fn verify(&self, public_key: C::G, challenge: C::F) -> bool {
let statements = self.batch_statements(public_key, challenge);
#[cfg(feature = "alloc")]
let res = multiexp_vartime(&statements);
#[cfg(not(feature = "alloc"))]
let res = statements.into_iter().map(|(scalar, point)| point * scalar).sum::<C::G>();
res.is_identity().into()
multiexp_vartime(&self.batch_statements(public_key, challenge)).is_identity().into()
}
/// Queue a signature for batch verification.

View File

@@ -21,7 +21,7 @@ std-shims = { path = "../../common/std-shims", default-features = false }
flexible-transcript = { path = "../../crypto/transcript", default-features = false, features = ["recommended", "merlin"] }
multiexp = { path = "../../crypto/multiexp", default-features = false, features = ["batch"], optional = true }
multiexp = { path = "../../crypto/multiexp", default-features = false }
dalek-ff-group = { path = "../../crypto/dalek-ff-group", default-features = false }
minimal-ed448 = { path = "../../crypto/ed448", default-features = false }
@@ -46,7 +46,8 @@ bitcoin-serai = { path = "../../networks/bitcoin", default-features = false, fea
alloc = [
"std-shims/alloc",
"multiexp",
"multiexp/alloc",
"multiexp/batch",
"dalek-ff-group/alloc",
"minimal-ed448/alloc",