Enforce FROST StateMachine progression via the type system

A comment on the matter was made in 
https://github.com/serai-dex/serai/issues/12. While I do believe the API 
is slightly worse, I appreciate the explicitness.
This commit is contained in:
Luke Parker
2022-06-24 08:40:14 -04:00
parent 462d0e74ce
commit 1caa6a9606
9 changed files with 276 additions and 351 deletions

View File

@@ -6,7 +6,13 @@ use rand_chacha::ChaCha12Rng;
use curve25519_dalek::{traits::Identity, scalar::Scalar, edwards::{EdwardsPoint, CompressedEdwardsY}};
use transcript::Transcript as TranscriptTrait;
use frost::{FrostError, MultisigKeys, MultisigParams, sign::{State, StateMachine, AlgorithmMachine}};
use frost::{
FrostError, MultisigKeys,
sign::{
PreprocessMachine, SignMachine, SignatureMachine,
AlgorithmMachine, AlgorithmSignMachine, AlgorithmSignatureMachine
}
};
use crate::{
frost::{Transcript, Ed25519},
@@ -24,14 +30,27 @@ pub struct TransactionMachine {
decoys: Vec<Decoys>,
our_preprocess: Vec<u8>,
images: Vec<EdwardsPoint>,
output_masks: Option<Scalar>,
inputs: Vec<Arc<RwLock<Option<ClsagDetails>>>>,
clsags: Vec<AlgorithmMachine<Ed25519, ClsagMultisig>>,
clsags: Vec<AlgorithmMachine<Ed25519, ClsagMultisig>>
}
tx: Option<Transaction>
pub struct TransactionSignMachine {
signable: SignableTransaction,
i: u16,
included: Vec<u16>,
transcript: Transcript,
decoys: Vec<Decoys>,
inputs: Vec<Arc<RwLock<Option<ClsagDetails>>>>,
clsags: Vec<AlgorithmSignMachine<Ed25519, ClsagMultisig>>,
our_preprocess: Vec<u8>
}
pub struct TransactionSignatureMachine {
tx: Transaction,
clsags: Vec<AlgorithmSignatureMachine<Ed25519, ClsagMultisig>>
}
impl SignableTransaction {
@@ -43,8 +62,6 @@ impl SignableTransaction {
height: usize,
mut included: Vec<u16>
) -> Result<TransactionMachine, TransactionError> {
let mut images = vec![];
images.resize(self.inputs.len(), EdwardsPoint::identity());
let mut inputs = vec![];
for _ in 0 .. self.inputs.len() {
// Doesn't resize as that will use a single Rc for the entire Vec
@@ -118,43 +135,38 @@ impl SignableTransaction {
&self.inputs
).await.map_err(|e| TransactionError::RpcError(e))?;
Ok(TransactionMachine {
signable: self,
i: keys.params().i(),
included,
transcript,
Ok(
TransactionMachine {
signable: self,
i: keys.params().i(),
included,
transcript,
decoys,
decoys,
our_preprocess: vec![],
images,
output_masks: None,
inputs,
clsags,
tx: None
})
inputs,
clsags
}
)
}
}
impl StateMachine for TransactionMachine {
impl PreprocessMachine for TransactionMachine {
type Signature = Transaction;
type SignMachine = TransactionSignMachine;
fn preprocess<R: RngCore + CryptoRng>(
&mut self,
mut self,
rng: &mut R
) -> Result<Vec<u8>, FrostError> {
if self.state() != State::Fresh {
Err(FrostError::InvalidSignTransition(State::Fresh, self.state()))?;
}
) -> (TransactionSignMachine, Vec<u8>) {
// Iterate over each CLSAG calling preprocess
let mut serialized = Vec::with_capacity(self.clsags.len() * (64 + ClsagMultisig::serialized_len()));
for clsag in self.clsags.iter_mut() {
serialized.extend(&clsag.preprocess(rng)?);
}
self.our_preprocess = serialized.clone();
let clsags = self.clsags.drain(..).map(|clsag| {
let (clsag, preprocess) = clsag.preprocess(rng);
serialized.extend(&preprocess);
clsag
}).collect();
let our_preprocess = serialized.clone();
// We could add further entropy here, and previous versions of this library did so
// As of right now, the multisig's key, the inputs being spent, and the FROST data itself
@@ -165,18 +177,33 @@ impl StateMachine for TransactionMachine {
// increase privacy. If they're not sent in plain text, or are otherwise inaccessible, they
// already offer sufficient entropy. That's why further entropy is not included
Ok(serialized)
(
TransactionSignMachine {
signable: self.signable,
i: self.i,
included: self.included,
transcript: self.transcript,
decoys: self.decoys,
inputs: self.inputs,
clsags,
our_preprocess,
},
serialized
)
}
}
impl SignMachine<Transaction> for TransactionSignMachine {
type SignatureMachine = TransactionSignatureMachine;
fn sign(
&mut self,
mut self,
mut commitments: HashMap<u16, Vec<u8>>,
msg: &[u8]
) -> Result<Vec<u8>, FrostError> {
if self.state() != State::Preprocessed {
Err(FrostError::InvalidSignTransition(State::Preprocessed, self.state()))?;
}
) -> Result<(TransactionSignatureMachine, Vec<u8>), FrostError> {
if msg.len() != 0 {
Err(
FrostError::InternalError(
@@ -189,7 +216,7 @@ impl StateMachine for TransactionMachine {
// While each CLSAG will do this as they need to for security, they have their own transcripts
// cloned from this TX's initial premise's transcript. For our TX transcript to have the CLSAG
// data for entropy, it'll have to be added ourselves
commitments.insert(self.i, self.our_preprocess.clone());
commitments.insert(self.i, self.our_preprocess);
for l in &self.included {
self.transcript.append_message(b"participant", &(*l).to_be_bytes());
// FROST itself will error if this is None, so let it
@@ -201,30 +228,33 @@ impl StateMachine for TransactionMachine {
// FROST commitments, image, H commitments, and their proofs
let clsag_len = 64 + ClsagMultisig::serialized_len();
let mut commitments = (0 .. self.clsags.len()).map(|c| commitments.iter().map(
|(l, commitments)| (*l, commitments[(c * clsag_len) .. ((c + 1) * clsag_len)].to_vec())
// Convert the unified commitments to a Vec of the individual commitments
let mut commitments = (0 .. self.clsags.len()).map(|_| commitments.iter_mut().map(
|(l, commitments)| (*l, commitments.drain(.. clsag_len).collect::<Vec<_>>())
).collect::<HashMap<_, _>>()).collect::<Vec<_>>();
// Calculate the key images
// Clsag will parse/calculate/validate this as needed, yet doing so here as well provides
// the easiest API overall, as this is where the TX is (which needs the key images in its
// message), along with where the outputs are determined (where our change output needs these
// to be unique)
let mut images = vec![EdwardsPoint::identity(); self.clsags.len()];
for c in 0 .. self.clsags.len() {
// Calculate the key images
// Multisig will parse/calculate/validate this as needed, yet doing so here as well provides
// the easiest API overall, as this is where the TX is (which needs the key images in its
// message), along with where the outputs are determined (where our change output needs these
// to be unique)
for (l, preprocess) in &commitments[c] {
self.images[c] += CompressedEdwardsY(
images[c] += CompressedEdwardsY(
preprocess[64 .. 96].try_into().map_err(|_| FrostError::InvalidCommitment(*l))?
).decompress().ok_or(FrostError::InvalidCommitment(*l))?;
}
}
// Create the actual transaction
let output_masks;
let mut tx = {
// Calculate uniqueness
let mut images = self.images.clone();
images.sort_by(key_image_sort);
let mut sorted_images = images.clone();
sorted_images.sort_by(key_image_sort);
let (commitments, output_masks) = self.signable.prepare_outputs(
let commitments;
(commitments, output_masks) = self.signable.prepare_outputs(
&mut ChaCha12Rng::from_seed(self.transcript.rng_seed(b"tx_keys")),
uniqueness(
&images.iter().map(|image| Input::ToKey {
@@ -234,7 +264,6 @@ impl StateMachine for TransactionMachine {
}).collect::<Vec<_>>()
)
);
self.output_masks = Some(output_masks);
self.signable.prepare_transaction(
&commitments,
@@ -245,18 +274,19 @@ impl StateMachine for TransactionMachine {
)
};
let mut sorted = Vec::with_capacity(self.decoys.len());
while self.decoys.len() != 0 {
// Sort the inputs, as expected
let mut sorted = Vec::with_capacity(self.clsags.len());
while self.clsags.len() != 0 {
sorted.push((
images.swap_remove(0),
self.signable.inputs.swap_remove(0),
self.decoys.swap_remove(0),
self.images.swap_remove(0),
self.inputs.swap_remove(0),
self.clsags.swap_remove(0),
commitments.swap_remove(0)
));
}
sorted.sort_by(|x, y| x.2.compress().to_bytes().cmp(&y.2.compress().to_bytes()).reverse());
sorted.sort_by(|x, y| key_image_sort(&x.0, &y.0));
let mut rng = ChaCha12Rng::from_seed(self.transcript.rng_seed(b"pseudo_out_masks"));
let mut sum_pseudo_outs = Scalar::zero();
@@ -265,7 +295,7 @@ impl StateMachine for TransactionMachine {
let mut mask = random_scalar(&mut rng);
if sorted.len() == 0 {
mask = self.output_masks.unwrap() - sum_pseudo_outs;
mask = output_masks - sum_pseudo_outs;
} else {
sum_pseudo_outs += mask;
}
@@ -273,16 +303,16 @@ impl StateMachine for TransactionMachine {
tx.prefix.inputs.push(
Input::ToKey {
amount: 0,
key_offsets: value.1.offsets.clone(),
key_image: value.2
key_offsets: value.2.offsets.clone(),
key_image: value.0
}
);
*value.3.write().unwrap() = Some(
ClsagDetails::new(
ClsagInput::new(
value.0.commitment,
value.1
value.1.commitment,
value.2
).map_err(|_| panic!("Signing an input which isn't present in the ring we created for it"))?,
mask
)
@@ -293,30 +323,31 @@ impl StateMachine for TransactionMachine {
}
let msg = tx.signature_hash();
self.tx = Some(tx);
// Iterate over each CLSAG calling sign
let mut serialized = Vec::with_capacity(self.clsags.len() * 32);
for clsag in self.clsags.iter_mut() {
serialized.extend(&clsag.sign(commitments.remove(0), &msg)?);
}
let clsags = self.clsags.drain(..).map(|clsag| {
let (clsag, share) = clsag.sign(commitments.remove(0), &msg)?;
serialized.extend(&share);
Ok(clsag)
}).collect::<Result<_, _>>()?;
Ok(serialized)
Ok((TransactionSignatureMachine { tx, clsags }, serialized))
}
}
fn complete(&mut self, shares: HashMap<u16, Vec<u8>>) -> Result<Transaction, FrostError> {
if self.state() != State::Signed {
Err(FrostError::InvalidSignTransition(State::Signed, self.state()))?;
}
let mut tx = self.tx.take().unwrap();
impl SignatureMachine<Transaction> for TransactionSignatureMachine {
fn complete(self, mut shares: HashMap<u16, Vec<u8>>) -> Result<Transaction, FrostError> {
let mut tx = self.tx;
match tx.rct_signatures.prunable {
RctPrunable::Null => panic!("Signing for RctPrunable::Null"),
RctPrunable::Clsag { ref mut clsags, ref mut pseudo_outs, .. } => {
for (c, clsag) in self.clsags.iter_mut().enumerate() {
let (clsag, pseudo_out) = clsag.complete(shares.iter().map(
|(l, shares)| (*l, shares[(c * 32) .. ((c + 1) * 32)].to_vec())
).collect::<HashMap<_, _>>())?;
for clsag in self.clsags {
let (clsag, pseudo_out) = clsag.complete(
shares.iter_mut().map(
|(l, shares)| (*l, shares.drain(.. 32).collect())
).collect::<HashMap<_, _>>()
)?;
clsags.push(clsag);
pseudo_outs.push(pseudo_out);
}
@@ -324,12 +355,4 @@ impl StateMachine for TransactionMachine {
}
Ok(tx)
}
fn multisig_params(&self) -> MultisigParams {
self.clsags[0].multisig_params()
}
fn state(&self) -> State {
self.clsags[0].state()
}
}