Enforce FROST StateMachine progression via the type system

A comment on the matter was made in 
https://github.com/serai-dex/serai/issues/12. While I do believe the API 
is slightly worse, I appreciate the explicitness.
This commit is contained in:
Luke Parker
2022-06-24 08:40:14 -04:00
parent 462d0e74ce
commit 1caa6a9606
9 changed files with 276 additions and 351 deletions

View File

@@ -1,5 +1,4 @@
use core::fmt;
use std::collections::HashMap;
use std::{marker::PhantomData, collections::HashMap};
use rand_core::{RngCore, CryptoRng};
@@ -271,100 +270,76 @@ fn complete_r2<R: RngCore + CryptoRng, C: Curve>(
)
}
/// State of a Key Generation machine
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum State {
Fresh,
GeneratedCoefficients,
GeneratedSecretShares,
Complete,
}
impl fmt::Display for State {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
/// State machine which manages key generation
#[allow(non_snake_case)]
pub struct StateMachine<C: Curve> {
pub struct KeyGenMachine<C: Curve> {
params: MultisigParams,
context: String,
state: State,
coefficients: Option<Vec<C::F>>,
our_commitments: Option<Vec<u8>>,
secret: Option<C::F>,
commitments: Option<HashMap<u16, Vec<C::G>>>
_curve: PhantomData<C>,
}
impl<C: Curve> StateMachine<C> {
pub struct SecretShareMachine<C: Curve> {
params: MultisigParams,
context: String,
coefficients: Vec<C::F>,
our_commitments: Vec<u8>,
}
pub struct KeyMachine<C: Curve> {
params: MultisigParams,
secret: C::F,
commitments: HashMap<u16, Vec<C::G>>,
}
impl<C: Curve> KeyGenMachine<C> {
/// Creates a new machine to generate a key for the specified curve in the specified multisig
// The context string must be unique among multisigs
pub fn new(params: MultisigParams, context: String) -> StateMachine<C> {
StateMachine {
params,
context,
state: State::Fresh,
coefficients: None,
our_commitments: None,
secret: None,
commitments: None
}
pub fn new(params: MultisigParams, context: String) -> KeyGenMachine<C> {
KeyGenMachine { params, context, _curve: PhantomData }
}
/// Start generating a key according to the FROST DKG spec
/// Returns a serialized list of commitments to be sent to all parties over an authenticated
/// channel. If any party submits multiple sets of commitments, they MUST be treated as malicious
pub fn generate_coefficients<R: RngCore + CryptoRng>(
&mut self,
self,
rng: &mut R
) -> Result<Vec<u8>, FrostError> {
if self.state != State::Fresh {
Err(FrostError::InvalidKeyGenTransition(State::Fresh, self.state))?;
}
let (coefficients, serialized) = generate_key_r1::<R, C>(
rng,
&self.params,
&self.context,
);
self.coefficients = Some(coefficients);
self.our_commitments = Some(serialized.clone());
self.state = State::GeneratedCoefficients;
Ok(serialized)
) -> (SecretShareMachine<C>, Vec<u8>) {
let (coefficients, serialized) = generate_key_r1::<R, C>(rng, &self.params, &self.context);
(
SecretShareMachine {
params: self.params,
context: self.context,
coefficients,
our_commitments: serialized.clone()
},
serialized,
)
}
}
impl<C: Curve> SecretShareMachine<C> {
/// Continue generating a key
/// Takes in everyone else's commitments, which are expected to be in a Vec where participant
/// index = Vec index. An empty vector is expected at index 0 to allow for this. An empty vector
/// is also expected at index i which is locally handled. Returns a byte vector representing a
/// secret share for each other participant which should be encrypted before sending
pub fn generate_secret_shares<R: RngCore + CryptoRng>(
&mut self,
self,
rng: &mut R,
commitments: HashMap<u16, Vec<u8>>,
) -> Result<HashMap<u16, Vec<u8>>, FrostError> {
if self.state != State::GeneratedCoefficients {
Err(FrostError::InvalidKeyGenTransition(State::GeneratedCoefficients, self.state))?;
}
) -> Result<(KeyMachine<C>, HashMap<u16, Vec<u8>>), FrostError> {
let (secret, commitments, shares) = generate_key_r2::<R, C>(
rng,
&self.params,
&self.context,
self.coefficients.take().unwrap(),
self.our_commitments.take().unwrap(),
self.coefficients,
self.our_commitments,
commitments,
)?;
self.secret = Some(secret);
self.commitments = Some(commitments);
self.state = State::GeneratedSecretShares;
Ok(shares)
Ok((KeyMachine { params: self.params, secret, commitments }, shares))
}
}
impl<C: Curve> KeyMachine<C> {
/// Complete key generation
/// Takes in everyone elses' shares submitted to us as a Vec, expecting participant index =
/// Vec index with an empty vector at index 0 and index i. Returns a byte vector representing the
@@ -372,31 +347,10 @@ impl<C: Curve> StateMachine<C> {
/// must report completion without issue before this key can be considered usable, yet you should
/// wait for all participants to report as such
pub fn complete<R: RngCore + CryptoRng>(
&mut self,
self,
rng: &mut R,
shares: HashMap<u16, Vec<u8>>,
) -> Result<MultisigKeys<C>, FrostError> {
if self.state != State::GeneratedSecretShares {
Err(FrostError::InvalidKeyGenTransition(State::GeneratedSecretShares, self.state))?;
}
let keys = complete_r2(
rng,
self.params,
self.secret.take().unwrap(),
self.commitments.take().unwrap(),
shares,
)?;
self.state = State::Complete;
Ok(keys)
}
pub fn params(&self) -> MultisigParams {
self.params.clone()
}
pub fn state(&self) -> State {
self.state
complete_r2(rng, self.params, self.secret, self.commitments, shares)
}
}

View File

@@ -181,11 +181,6 @@ pub enum FrostError {
InvalidProofOfKnowledge(u16),
#[error("invalid share (participant {0})")]
InvalidShare(u16),
#[error("invalid key generation state machine transition (expected {0}, was {1})")]
InvalidKeyGenTransition(key_gen::State, key_gen::State),
#[error("invalid sign state machine transition (expected {0}, was {1})")]
InvalidSignTransition(sign::State, sign::State),
#[error("internal error ({0})")]
InternalError(String),

View File

@@ -236,31 +236,21 @@ fn complete<C: Curve, A: Algorithm<C>>(
)
}
/// State of a Sign machine
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum State {
Fresh,
Preprocessed,
Signed,
Complete,
}
impl fmt::Display for State {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
pub trait StateMachine {
pub trait PreprocessMachine {
type Signature: Clone + PartialEq + fmt::Debug;
type SignMachine: SignMachine<Self::Signature>;
/// Perform the preprocessing round required in order to sign
/// Returns a byte vector which must be transmitted to all parties selected for this signing
/// process, over an authenticated channel
fn preprocess<R: RngCore + CryptoRng>(
&mut self,
self,
rng: &mut R
) -> Result<Vec<u8>, FrostError>;
) -> (Self::SignMachine, Vec<u8>);
}
pub trait SignMachine<S> {
type SignatureMachine: SignatureMachine<S>;
/// Sign a message
/// Takes in the participant's commitments, which are expected to be in a Vec where participant
@@ -268,29 +258,33 @@ pub trait StateMachine {
/// index i which is locally handled. Returns a byte vector representing a share of the signature
/// for every other participant to receive, over an authenticated channel
fn sign(
&mut self,
self,
commitments: HashMap<u16, Vec<u8>>,
msg: &[u8],
) -> Result<Vec<u8>, FrostError>;
) -> Result<(Self::SignatureMachine, Vec<u8>), FrostError>;
}
pub trait SignatureMachine<S> {
/// Complete signing
/// Takes in everyone elses' shares submitted to us as a Vec, expecting participant index =
/// Vec index with None at index 0 and index i. Returns a byte vector representing the serialized
/// signature
fn complete(&mut self, shares: HashMap<u16, Vec<u8>>) -> Result<Self::Signature, FrostError>;
fn multisig_params(&self) -> MultisigParams;
fn state(&self) -> State;
fn complete(self, shares: HashMap<u16, Vec<u8>>) -> Result<S, FrostError>;
}
/// State machine which manages signing for an arbitrary signature algorithm
#[allow(non_snake_case)]
pub struct AlgorithmMachine<C: Curve, A: Algorithm<C>> {
params: Params<C, A>
}
pub struct AlgorithmSignMachine<C: Curve, A: Algorithm<C>> {
params: Params<C, A>,
state: State,
preprocess: Option<PreprocessPackage<C>>,
sign: Option<Package<C>>,
preprocess: PreprocessPackage<C>,
}
pub struct AlgorithmSignatureMachine<C: Curve, A: Algorithm<C>> {
params: Params<C, A>,
sign: Package<C>,
}
impl<C: Curve, A: Algorithm<C>> AlgorithmMachine<C, A> {
@@ -300,85 +294,52 @@ impl<C: Curve, A: Algorithm<C>> AlgorithmMachine<C, A> {
keys: Arc<MultisigKeys<C>>,
included: &[u16],
) -> Result<AlgorithmMachine<C, A>, FrostError> {
Ok(
AlgorithmMachine {
params: Params::new(algorithm, keys, included)?,
state: State::Fresh,
preprocess: None,
sign: None,
}
)
Ok(AlgorithmMachine { params: Params::new(algorithm, keys, included)? })
}
pub(crate) fn unsafe_override_preprocess(&mut self, preprocess: PreprocessPackage<C>) {
if self.state != State::Fresh {
// This would be unacceptable, yet this is pub(crate) and explicitly labelled unsafe
// It's solely used in a testing environment, which is how it's justified
Err::<(), _>(FrostError::InvalidSignTransition(State::Fresh, self.state)).unwrap();
}
self.preprocess = Some(preprocess);
self.state = State::Preprocessed;
pub(crate) fn unsafe_override_preprocess(
self,
preprocess: PreprocessPackage<C>
) -> (AlgorithmSignMachine<C, A>, Vec<u8>) {
let serialized = preprocess.serialized.clone();
(AlgorithmSignMachine { params: self.params, preprocess }, serialized)
}
}
impl<C: Curve, A: Algorithm<C>> StateMachine for AlgorithmMachine<C, A> {
impl<C: Curve, A: Algorithm<C>> PreprocessMachine for AlgorithmMachine<C, A> {
type Signature = A::Signature;
type SignMachine = AlgorithmSignMachine<C, A>;
fn preprocess<R: RngCore + CryptoRng>(
&mut self,
self,
rng: &mut R
) -> Result<Vec<u8>, FrostError> {
if self.state != State::Fresh {
Err(FrostError::InvalidSignTransition(State::Fresh, self.state))?;
}
let preprocess = preprocess::<R, C, A>(rng, &mut self.params);
) -> (Self::SignMachine, Vec<u8>) {
let mut params = self.params;
let preprocess = preprocess::<R, C, A>(rng, &mut params);
let serialized = preprocess.serialized.clone();
self.preprocess = Some(preprocess);
self.state = State::Preprocessed;
Ok(serialized)
}
fn sign(
&mut self,
commitments: HashMap<u16, Vec<u8>>,
msg: &[u8],
) -> Result<Vec<u8>, FrostError> {
if self.state != State::Preprocessed {
Err(FrostError::InvalidSignTransition(State::Preprocessed, self.state))?;
}
let (sign, serialized) = sign_with_share(
&mut self.params,
self.preprocess.take().unwrap(),
commitments,
msg,
)?;
self.sign = Some(sign);
self.state = State::Signed;
Ok(serialized)
}
fn complete(&mut self, shares: HashMap<u16, Vec<u8>>) -> Result<A::Signature, FrostError> {
if self.state != State::Signed {
Err(FrostError::InvalidSignTransition(State::Signed, self.state))?;
}
let signature = complete(
&self.params,
self.sign.take().unwrap(),
shares,
)?;
self.state = State::Complete;
Ok(signature)
}
fn multisig_params(&self) -> MultisigParams {
self.params.multisig_params().clone()
}
fn state(&self) -> State {
self.state
(AlgorithmSignMachine { params, preprocess }, serialized)
}
}
impl<C: Curve, A: Algorithm<C>> SignMachine<A::Signature> for AlgorithmSignMachine<C, A> {
type SignatureMachine = AlgorithmSignatureMachine<C, A>;
fn sign(
self,
commitments: HashMap<u16, Vec<u8>>,
msg: &[u8]
) -> Result<(Self::SignatureMachine, Vec<u8>), FrostError> {
let mut params = self.params;
let (sign, serialized) = sign_with_share(&mut params, self.preprocess, commitments, msg)?;
Ok((AlgorithmSignatureMachine { params, sign }, serialized))
}
}
impl<
C: Curve,
A: Algorithm<C>
> SignatureMachine<A::Signature> for AlgorithmSignatureMachine<C, A> {
fn complete(self, shares: HashMap<u16, Vec<u8>>) -> Result<A::Signature, FrostError> {
complete(&self.params, self.sign, shares)
}
}

View File

@@ -8,9 +8,9 @@ use crate::{
Curve,
MultisigParams, MultisigKeys,
lagrange,
key_gen,
key_gen::KeyGenMachine,
algorithm::Algorithm,
sign::{StateMachine, AlgorithmMachine}
sign::{PreprocessMachine, SignMachine, SignatureMachine, AlgorithmMachine}
};
// Test suites for public usage
@@ -37,49 +37,36 @@ pub fn clone_without<K: Clone + std::cmp::Eq + std::hash::Hash, V: Clone>(
pub fn key_gen<R: RngCore + CryptoRng, C: Curve>(
rng: &mut R
) -> HashMap<u16, Arc<MultisigKeys<C>>> {
let mut params = HashMap::new();
let mut machines = HashMap::new();
let mut commitments = HashMap::new();
for i in 1 ..= PARTICIPANTS {
params.insert(
i,
MultisigParams::new(
THRESHOLD,
PARTICIPANTS,
i
).unwrap()
);
machines.insert(
i,
key_gen::StateMachine::<C>::new(
params[&i],
"FROST Test key_gen".to_string()
)
);
commitments.insert(
i,
machines.get_mut(&i).unwrap().generate_coefficients(rng).unwrap()
let machine = KeyGenMachine::<C>::new(
MultisigParams::new(THRESHOLD, PARTICIPANTS, i).unwrap(),
"FROST Test key_gen".to_string()
);
let (machine, these_commitments) = machine.generate_coefficients(rng);
machines.insert(i, machine);
commitments.insert(i, these_commitments);
}
let mut secret_shares = HashMap::new();
for (l, machine) in machines.iter_mut() {
secret_shares.insert(
*l,
let mut machines = machines.drain().map(|(l, machine)| {
let (machine, shares) = machine.generate_secret_shares(
rng,
// clone_without isn't necessary, as this machine's own data will be inserted without
// conflict, yet using it ensures the machine's own data is actually inserted as expected
machine.generate_secret_shares(rng, clone_without(&commitments, l)).unwrap()
);
}
clone_without(&commitments, &l)
).unwrap();
secret_shares.insert(l, shares);
(l, machine)
}).collect::<HashMap<_, _>>();
let mut verification_shares = None;
let mut group_key = None;
let mut keys = HashMap::new();
for (i, machine) in machines.iter_mut() {
machines.drain().map(|(i, machine)| {
let mut our_secret_shares = HashMap::new();
for (l, shares) in &secret_shares {
if i == l {
if i == *l {
continue;
}
our_secret_shares.insert(*l, shares[&i].clone());
@@ -98,10 +85,8 @@ pub fn key_gen<R: RngCore + CryptoRng, C: Curve>(
}
assert_eq!(group_key.unwrap(), these_keys.group_key());
keys.insert(*i, Arc::new(these_keys));
}
keys
(i, Arc::new(these_keys))
}).collect::<HashMap<_, _>>()
}
pub fn recover<C: Curve>(keys: &HashMap<u16, MultisigKeys<C>>) -> C::F {
@@ -147,27 +132,28 @@ pub fn algorithm_machines<R: RngCore, C: Curve, A: Algorithm<C>>(
).collect()
}
pub fn sign<R: RngCore + CryptoRng, M: StateMachine>(
pub fn sign<R: RngCore + CryptoRng, M: PreprocessMachine>(
rng: &mut R,
mut machines: HashMap<u16, M>,
msg: &[u8]
) -> M::Signature {
let mut commitments = HashMap::new();
for (i, machine) in machines.iter_mut() {
commitments.insert(*i, machine.preprocess(rng).unwrap());
}
let mut machines = machines.drain().map(|(i, machine)| {
let (machine, preprocess) = machine.preprocess(rng);
commitments.insert(i, preprocess);
(i, machine)
}).collect::<HashMap<_, _>>();
let mut shares = HashMap::new();
for (i, machine) in machines.iter_mut() {
shares.insert(
*i,
machine.sign(clone_without(&commitments, i), msg).unwrap()
);
}
let mut machines = machines.drain().map(|(i, machine)| {
let (machine, share) = machine.sign(clone_without(&commitments, &i), msg).unwrap();
shares.insert(i, share);
(i, machine)
}).collect::<HashMap<_, _>>();
let mut signature = None;
for (i, machine) in machines.iter_mut() {
let sig = machine.complete(clone_without(&shares, i)).unwrap();
for (i, machine) in machines.drain() {
let sig = machine.complete(clone_without(&shares, &i)).unwrap();
if signature.is_none() {
signature = Some(sig.clone());
}

View File

@@ -5,7 +5,7 @@ use rand_core::{RngCore, CryptoRng};
use crate::{
Curve, MultisigKeys,
algorithm::{Schnorr, Hram},
sign::{PreprocessPackage, StateMachine, AlgorithmMachine},
sign::{PreprocessPackage, SignMachine, SignatureMachine, AlgorithmMachine},
tests::{curve::test_curve, schnorr::test_schnorr, recover}
};
@@ -92,33 +92,40 @@ pub fn test_with_vectors<
let mut commitments = HashMap::new();
let mut c = 0;
for (i, machine) in machines.iter_mut() {
let mut machines = machines.drain(..).map(|(i, machine)| {
let nonces = [
C::F_from_slice(&hex::decode(vectors.nonces[c][0]).unwrap()).unwrap(),
C::F_from_slice(&hex::decode(vectors.nonces[c][1]).unwrap()).unwrap()
];
c += 1;
let mut serialized = C::G_to_bytes(&(C::GENERATOR * nonces[0]));
serialized.extend(&C::G_to_bytes(&(C::GENERATOR * nonces[1])));
machine.unsafe_override_preprocess(
let (machine, serialized) = machine.unsafe_override_preprocess(
PreprocessPackage { nonces, serialized: serialized.clone() }
);
commitments.insert(*i, serialized);
c += 1;
}
commitments.insert(i, serialized);
(i, machine)
}).collect::<Vec<_>>();
let mut shares = HashMap::new();
c = 0;
for (i, machine) in machines.iter_mut() {
let share = machine.sign(commitments.clone(), &hex::decode(vectors.msg).unwrap()).unwrap();
assert_eq!(share, hex::decode(vectors.sig_shares[c]).unwrap());
shares.insert(*i, share);
c += 1;
}
let mut machines = machines.drain(..).map(|(i, machine)| {
let (machine, share) = machine.sign(
commitments.clone(),
&hex::decode(vectors.msg).unwrap()
).unwrap();
for (_, machine) in machines.iter_mut() {
assert_eq!(share, hex::decode(vectors.sig_shares[c]).unwrap());
c += 1;
shares.insert(i, share);
(i, machine)
}).collect::<HashMap<_, _>>();
for (_, machine) in machines.drain() {
let sig = machine.complete(shares.clone()).unwrap();
let mut serialized = C::G_to_bytes(&sig.R);
serialized.extend(C::F_to_bytes(&sig.s));