Implement variable-sized windows into multiexp

Closes https://github.com/serai-dex/serai/issues/17 by using the 
PrimeFieldBits API to do so.

Should greatly speed up small batches, along with batches in the 
hundreds. Saves almost a full second on the cross-group DLEq proof.
This commit is contained in:
Luke Parker
2022-06-30 09:30:24 -04:00
parent 5d115f1e1c
commit 7890827a48
15 changed files with 342 additions and 148 deletions

View File

@@ -9,9 +9,16 @@ keywords = ["multiexp", "ff", "group"]
edition = "2021"
[dependencies]
ff = "0.12"
group = "0.12"
rand_core = { version = "0.6", optional = true }
[dev-dependencies]
rand_core = "0.6"
k256 = { version = "0.11", features = ["bits"] }
dalek-ff-group = { path = "../dalek-ff-group" }
[features]
batch = ["rand_core"]

View File

@@ -1,16 +1,17 @@
use rand_core::{RngCore, CryptoRng};
use group::{ff::Field, Group};
use ff::{Field, PrimeFieldBits};
use group::Group;
use crate::{multiexp, multiexp_vartime};
#[cfg(feature = "batch")]
pub struct BatchVerifier<Id: Copy, G: Group>(Vec<(Id, Vec<(G::Scalar, G)>)>, bool);
pub struct BatchVerifier<Id: Copy, G: Group>(Vec<(Id, Vec<(G::Scalar, G)>)>);
#[cfg(feature = "batch")]
impl<Id: Copy, G: Group> BatchVerifier<Id, G> {
pub fn new(capacity: usize, endian: bool) -> BatchVerifier<Id, G> {
BatchVerifier(Vec::with_capacity(capacity), endian)
impl<Id: Copy, G: Group> BatchVerifier<Id, G> where <G as Group>::Scalar: PrimeFieldBits {
pub fn new(capacity: usize) -> BatchVerifier<Id, G> {
BatchVerifier(Vec::with_capacity(capacity))
}
pub fn queue<
@@ -28,15 +29,13 @@ impl<Id: Copy, G: Group> BatchVerifier<Id, G> {
pub fn verify(&self) -> bool {
multiexp(
&self.0.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::<Vec<_>>(),
self.1
&self.0.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::<Vec<_>>()
).is_identity().into()
}
pub fn verify_vartime(&self) -> bool {
multiexp_vartime(
&self.0.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::<Vec<_>>(),
self.1
&self.0.iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::<Vec<_>>()
).is_identity().into()
}
@@ -46,8 +45,7 @@ impl<Id: Copy, G: Group> BatchVerifier<Id, G> {
while slice.len() > 1 {
let split = slice.len() / 2;
if multiexp_vartime(
&slice[.. split].iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::<Vec<_>>(),
self.1
&slice[.. split].iter().flat_map(|pairs| pairs.1.iter()).cloned().collect::<Vec<_>>()
).is_identity().into() {
slice = &slice[split ..];
} else {
@@ -56,7 +54,7 @@ impl<Id: Copy, G: Group> BatchVerifier<Id, G> {
}
slice.get(0).filter(
|(_, value)| !bool::from(multiexp_vartime(value, self.1).is_identity())
|(_, value)| !bool::from(multiexp_vartime(value).is_identity())
).map(|(id, _)| *id)
}

View File

@@ -1,3 +1,4 @@
use ff::PrimeFieldBits;
use group::Group;
mod straus;
@@ -11,39 +12,151 @@ mod batch;
#[cfg(feature = "batch")]
pub use batch::BatchVerifier;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Algorithm {
Straus,
Pippenger
#[cfg(test)]
mod tests;
pub(crate) fn prep_bits<G: Group>(
pairs: &[(G::Scalar, G)],
window: u8
) -> Vec<Vec<u8>> where G::Scalar: PrimeFieldBits {
let w_usize = usize::from(window);
let mut groupings = vec![];
for pair in pairs {
let p = groupings.len();
let bits = pair.0.to_le_bits();
groupings.push(vec![0; (bits.len() + (w_usize - 1)) / w_usize]);
for (i, bit) in bits.into_iter().enumerate() {
let bit = bit as u8;
debug_assert_eq!(bit | 1, 1);
groupings[p][i / w_usize] |= bit << (i % w_usize);
}
}
groupings
}
fn algorithm(pairs: usize) -> Algorithm {
// TODO: Replace this with an actual formula determining which will use less additions
// Right now, Straus is used until 600, instead of the far more accurate 300, as Pippenger
// operates per byte instead of per nibble, and therefore requires a much longer series to be
// performant
// Technically, 800 is dalek's number for when to use byte Pippenger, yet given Straus's own
// implementation limitations...
if pairs < 600 {
Algorithm::Straus
pub(crate) fn prep_tables<G: Group>(
pairs: &[(G::Scalar, G)],
window: u8
) -> Vec<Vec<G>> {
let mut tables = Vec::with_capacity(pairs.len());
for pair in pairs {
let p = tables.len();
tables.push(vec![G::identity(); 2_usize.pow(window.into())]);
let mut accum = G::identity();
for i in 1 .. tables[p].len() {
accum += pair.1;
tables[p][i] = accum;
}
}
tables
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Algorithm {
Straus(u8),
Pippenger(u8)
}
/*
Release (with runs 20, so all of these are off by 20x):
k256
Straus 3 is more efficient at 5 with 678µs per
Straus 4 is more efficient at 10 with 530µs per
Straus 5 is more efficient at 35 with 467µs per
Pippenger 5 is more efficient at 125 with 431µs per
Pippenger 6 is more efficient at 275 with 349µs per
Pippenger 7 is more efficient at 375 with 360µs per
dalek
Straus 3 is more efficient at 5 with 519µs per
Straus 4 is more efficient at 10 with 376µs per
Straus 5 is more efficient at 170 with 330µs per
Pippenger 5 is more efficient at 125 with 305µs per
Pippenger 6 is more efficient at 275 with 250µs per
Pippenger 7 is more efficient at 450 with 205µs per
Pippenger 8 is more efficient at 800 with 213µs per
Debug (with runs 5, so...):
k256
Straus 3 is more efficient at 5 with 2532µs per
Straus 4 is more efficient at 10 with 1930µs per
Straus 5 is more efficient at 80 with 1632µs per
Pippenger 5 is more efficient at 150 with 1441µs per
Pippenger 6 is more efficient at 300 with 1235µs per
Pippenger 7 is more efficient at 475 with 1182µs per
Pippenger 8 is more efficient at 625 with 1170µs per
dalek:
Straus 3 is more efficient at 5 with 971µs per
Straus 4 is more efficient at 10 with 782µs per
Straus 5 is more efficient at 75 with 778µs per
Straus 6 is more efficient at 165 with 867µs per
Pippenger 5 is more efficient at 125 with 677µs per
Pippenger 6 is more efficient at 250 with 655µs per
Pippenger 7 is more efficient at 475 with 500µs per
Pippenger 8 is more efficient at 875 with 499µs per
*/
fn algorithm(len: usize) -> Algorithm {
#[cfg(not(debug_assertions))]
if len < 10 {
// Straus 2 never showed a performance benefit, even with just 2 elements
Algorithm::Straus(3)
} else if len < 20 {
Algorithm::Straus(4)
} else if len < 50 {
Algorithm::Straus(5)
} else if len < 100 {
Algorithm::Pippenger(4)
} else if len < 125 {
Algorithm::Pippenger(5)
} else if len < 275 {
Algorithm::Pippenger(6)
} else if len < 400 {
Algorithm::Pippenger(7)
} else {
Algorithm::Pippenger
Algorithm::Pippenger(8)
}
#[cfg(debug_assertions)]
if len < 10 {
Algorithm::Straus(3)
} else if len < 80 {
Algorithm::Straus(4)
} else if len < 100 {
Algorithm::Straus(5)
} else if len < 125 {
Algorithm::Pippenger(4)
} else if len < 275 {
Algorithm::Pippenger(5)
} else if len < 475 {
Algorithm::Pippenger(6)
} else if len < 750 {
Algorithm::Pippenger(7)
} else {
Algorithm::Pippenger(8)
}
}
// Performs a multiexp, automatically selecting the optimal algorithm based on amount of pairs
// Takes in an iterator of scalars and points, with a boolean for if the scalars are little endian
// encoded in their Reprs or not
pub fn multiexp<G: Group>(pairs: &[(G::Scalar, G)], little: bool) -> G {
pub fn multiexp<G: Group>(pairs: &[(G::Scalar, G)]) -> G where G::Scalar: PrimeFieldBits {
match algorithm(pairs.len()) {
Algorithm::Straus => straus(pairs, little),
Algorithm::Pippenger => pippenger(pairs, little)
Algorithm::Straus(window) => straus(pairs, window),
Algorithm::Pippenger(window) => pippenger(pairs, window)
}
}
pub fn multiexp_vartime<G: Group>(pairs: &[(G::Scalar, G)], little: bool) -> G {
pub fn multiexp_vartime<G: Group>(pairs: &[(G::Scalar, G)]) -> G where G::Scalar: PrimeFieldBits {
match algorithm(pairs.len()) {
Algorithm::Straus => straus_vartime(pairs, little),
Algorithm::Pippenger => pippenger_vartime(pairs, little)
Algorithm::Straus(window) => straus_vartime(pairs, window),
Algorithm::Pippenger(window) => pippenger_vartime(pairs, window)
}
}

View File

@@ -1,42 +1,23 @@
use group::{ff::PrimeField, Group};
use ff::PrimeFieldBits;
use group::Group;
fn prep<G: Group>(pairs: &[(G::Scalar, G)], little: bool) -> (Vec<Vec<u8>>, Vec<G>) {
let mut res = vec![];
let mut points = vec![];
for pair in pairs {
let p = res.len();
res.push(vec![]);
{
let mut repr = pair.0.to_repr();
let bytes = repr.as_mut();
if !little {
bytes.reverse();
}
use crate::prep_bits;
res[p].resize(bytes.len(), 0);
for i in 0 .. bytes.len() {
res[p][i] = bytes[i];
}
}
points.push(pair.1);
}
(res, points)
}
pub(crate) fn pippenger<G: Group>(pairs: &[(G::Scalar, G)], little: bool) -> G {
let (bytes, points) = prep(pairs, little);
pub(crate) fn pippenger<G: Group>(
pairs: &[(G::Scalar, G)],
window: u8
) -> G where G::Scalar: PrimeFieldBits {
let bits = prep_bits(pairs, window);
let mut res = G::identity();
for n in (0 .. bytes[0].len()).rev() {
for _ in 0 .. 8 {
for n in (0 .. bits[0].len()).rev() {
for _ in 0 .. window {
res = res.double();
}
let mut buckets = [G::identity(); 256];
for p in 0 .. bytes.len() {
buckets[usize::from(bytes[p][n])] += points[p];
let mut buckets = vec![G::identity(); 2_usize.pow(window.into())];
for p in 0 .. bits.len() {
buckets[usize::from(bits[p][n])] += pairs[p].1;
}
let mut intermediate_sum = G::identity();
@@ -49,22 +30,25 @@ pub(crate) fn pippenger<G: Group>(pairs: &[(G::Scalar, G)], little: bool) -> G {
res
}
pub(crate) fn pippenger_vartime<G: Group>(pairs: &[(G::Scalar, G)], little: bool) -> G {
let (bytes, points) = prep(pairs, little);
pub(crate) fn pippenger_vartime<G: Group>(
pairs: &[(G::Scalar, G)],
window: u8
) -> G where G::Scalar: PrimeFieldBits {
let bits = prep_bits(pairs, window);
let mut res = G::identity();
for n in (0 .. bytes[0].len()).rev() {
if n != (bytes[0].len() - 1) {
for _ in 0 .. 8 {
for n in (0 .. bits[0].len()).rev() {
if n != (bits[0].len() - 1) {
for _ in 0 .. window {
res = res.double();
}
}
let mut buckets = [G::identity(); 256];
for p in 0 .. bytes.len() {
let nibble = usize::from(bytes[p][n]);
let mut buckets = vec![G::identity(); 2_usize.pow(window.into())];
for p in 0 .. bits.len() {
let nibble = usize::from(bits[p][n]);
if nibble != 0 {
buckets[nibble] += points[p];
buckets[nibble] += pairs[p].1;
}
}

View File

@@ -1,66 +1,46 @@
use group::{ff::PrimeField, Group};
use ff::PrimeFieldBits;
use group::Group;
fn prep<G: Group>(pairs: &[(G::Scalar, G)], little: bool) -> (Vec<Vec<u8>>, Vec<[G; 16]>) {
let mut nibbles = vec![];
let mut tables = vec![];
for pair in pairs {
let p = nibbles.len();
nibbles.push(vec![]);
{
let mut repr = pair.0.to_repr();
let bytes = repr.as_mut();
if !little {
bytes.reverse();
}
use crate::{prep_bits, prep_tables};
nibbles[p].resize(bytes.len() * 2, 0);
for i in 0 .. bytes.len() {
nibbles[p][i * 2] = bytes[i] & 0b1111;
nibbles[p][(i * 2) + 1] = (bytes[i] >> 4) & 0b1111;
}
}
tables.push([G::identity(); 16]);
let mut accum = G::identity();
for i in 1 .. 16 {
accum += pair.1;
tables[p][i] = accum;
}
}
(nibbles, tables)
}
pub(crate) fn straus<G: Group>(pairs: &[(G::Scalar, G)], little: bool) -> G {
let (nibbles, tables) = prep(pairs, little);
pub(crate) fn straus<G: Group>(
pairs: &[(G::Scalar, G)],
window: u8
) -> G where G::Scalar: PrimeFieldBits {
let groupings = prep_bits(pairs, window);
let tables = prep_tables(pairs, window);
let mut res = G::identity();
for b in (0 .. nibbles[0].len()).rev() {
for _ in 0 .. 4 {
for b in (0 .. groupings[0].len()).rev() {
for _ in 0 .. window {
res = res.double();
}
for s in 0 .. tables.len() {
res += tables[s][usize::from(nibbles[s][b])];
res += tables[s][usize::from(groupings[s][b])];
}
}
res
}
pub(crate) fn straus_vartime<G: Group>(pairs: &[(G::Scalar, G)], little: bool) -> G {
let (nibbles, tables) = prep(pairs, little);
pub(crate) fn straus_vartime<G: Group>(
pairs: &[(G::Scalar, G)],
window: u8
) -> G where G::Scalar: PrimeFieldBits {
let groupings = prep_bits(pairs, window);
let tables = prep_tables(pairs, window);
let mut res = G::identity();
for b in (0 .. nibbles[0].len()).rev() {
if b != (nibbles[0].len() - 1) {
for _ in 0 .. 4 {
for b in (0 .. groupings[0].len()).rev() {
if b != (groupings[0].len() - 1) {
for _ in 0 .. window {
res = res.double();
}
}
for s in 0 .. tables.len() {
if nibbles[s][b] != 0 {
res += tables[s][usize::from(nibbles[s][b])];
if groupings[s][b] != 0 {
res += tables[s][usize::from(groupings[s][b])];
}
}
}

View File

@@ -0,0 +1,112 @@
use std::time::Instant;
use rand_core::OsRng;
use ff::{Field, PrimeFieldBits};
use group::Group;
use k256::ProjectivePoint;
use dalek_ff_group::EdwardsPoint;
use crate::{straus, pippenger, multiexp, multiexp_vartime};
#[allow(dead_code)]
fn benchmark_internal<G: Group>(straus_bool: bool) where G::Scalar: PrimeFieldBits {
let runs: usize = 20;
let mut start = 0;
let mut increment: usize = 5;
let mut total: usize = 250;
let mut current = 2;
if !straus_bool {
start = 100;
increment = 25;
total = 1000;
current = 4;
};
let mut pairs = Vec::with_capacity(total);
let mut sum = G::identity();
for _ in 0 .. start {
pairs.push((G::Scalar::random(&mut OsRng), G::generator() * G::Scalar::random(&mut OsRng)));
sum += pairs[pairs.len() - 1].1 * pairs[pairs.len() - 1].0;
}
for _ in 0 .. (total / increment) {
for _ in 0 .. increment {
pairs.push((G::Scalar::random(&mut OsRng), G::generator() * G::Scalar::random(&mut OsRng)));
sum += pairs[pairs.len() - 1].1 * pairs[pairs.len() - 1].0;
}
let now = Instant::now();
for _ in 0 .. runs {
if straus_bool {
assert_eq!(straus(&pairs, current), sum);
} else {
assert_eq!(pippenger(&pairs, current), sum);
}
}
let current_per = now.elapsed().as_micros() / u128::try_from(pairs.len()).unwrap();
let now = Instant::now();
for _ in 0 .. runs {
if straus_bool {
assert_eq!(straus(&pairs, current + 1), sum);
} else {
assert_eq!(pippenger(&pairs, current + 1), sum);
}
}
let next_per = now.elapsed().as_micros() / u128::try_from(pairs.len()).unwrap();
if next_per < current_per {
current += 1;
println!(
"{} {} is more efficient at {} with {}µs per",
if straus_bool { "Straus" } else { "Pippenger" }, current, pairs.len(), next_per
);
if current >= 8 {
return;
}
}
}
}
fn test_multiexp<G: Group>() where G::Scalar: PrimeFieldBits {
let mut pairs = Vec::with_capacity(1000);
let mut sum = G::identity();
for _ in 0 .. 10 {
for _ in 0 .. 100 {
pairs.push((G::Scalar::random(&mut OsRng), G::generator() * G::Scalar::random(&mut OsRng)));
sum += pairs[pairs.len() - 1].1 * pairs[pairs.len() - 1].0;
}
assert_eq!(multiexp(&pairs), sum);
assert_eq!(multiexp_vartime(&pairs), sum);
}
}
#[test]
fn test_secp256k1() {
test_multiexp::<ProjectivePoint>();
}
#[test]
fn test_ed25519() {
test_multiexp::<EdwardsPoint>();
}
#[test]
#[ignore]
fn benchmark() {
// Activate the processor's boost clock
for _ in 0 .. 30 {
test_multiexp::<ProjectivePoint>();
}
benchmark_internal::<ProjectivePoint>(true);
benchmark_internal::<ProjectivePoint>(false);
benchmark_internal::<EdwardsPoint>(true);
benchmark_internal::<EdwardsPoint>(false);
}