diff --git a/crypto/dleq/src/cross_group/mod.rs b/crypto/dleq/src/cross_group/mod.rs index ce00aef2..4a0ce530 100644 --- a/crypto/dleq/src/cross_group/mod.rs +++ b/crypto/dleq/src/cross_group/mod.rs @@ -12,7 +12,7 @@ use group::{ff::{Field, PrimeField, PrimeFieldBits}, prime::PrimeGroup}; use crate::Generators; pub mod scalar; -use scalar::scalar_convert; +use scalar::{scalar_convert, mutual_scalar_from_bytes}; pub(crate) mod schnorr; use schnorr::SchnorrPoK; @@ -121,22 +121,11 @@ impl DLEqProof blinding_key } - fn mutual_scalar_from_bytes(bytes: &[u8]) -> (G0::Scalar, G1::Scalar) { - let capacity = usize::try_from(G0::Scalar::CAPACITY.min(G1::Scalar::CAPACITY)).unwrap(); - debug_assert!((bytes.len() * 8) >= capacity); - - let mut accum = G0::Scalar::zero(); - for b in 0 .. capacity { - accum += G0::Scalar::from((bytes[b / 8] & (1 << (b % 8))).into()); - } - (accum, scalar_convert(accum).unwrap()) - } - #[allow(non_snake_case)] fn nonces(mut transcript: T, nonces: (G0, G1)) -> (G0::Scalar, G1::Scalar) { transcript.append_message(b"nonce_0", nonces.0.to_bytes().as_ref()); transcript.append_message(b"nonce_1", nonces.1.to_bytes().as_ref()); - Self::mutual_scalar_from_bytes(transcript.challenge(b"challenge").as_ref()) + mutual_scalar_from_bytes(transcript.challenge(b"challenge").as_ref()) } #[allow(non_snake_case)] @@ -268,7 +257,7 @@ impl DLEqProof rng, transcript, generators, - Self::mutual_scalar_from_bytes(digest.finalize().as_ref()) + mutual_scalar_from_bytes(digest.finalize().as_ref()) ) } diff --git a/crypto/dleq/src/cross_group/scalar.rs b/crypto/dleq/src/cross_group/scalar.rs index 8d922719..6df5dee7 100644 --- a/crypto/dleq/src/cross_group/scalar.rs +++ b/crypto/dleq/src/cross_group/scalar.rs @@ -18,10 +18,12 @@ pub fn scalar_normalize(scalar: F0) -> ( for bit in bits.iter().skip(bits.len() - usize::try_from(mutual_capacity).unwrap()) { res1 = res1.double(); res2 = res2.double(); - if *bit { - res1 += F0::one(); - res2 += F1::one(); - } + + let bit = *bit as u8; + debug_assert_eq!(bit | 1, 1); + + res1 += F0::from(bit.into()); + res2 += F1::from(bit.into()); } (res1, res2) @@ -32,3 +34,16 @@ pub fn scalar_convert(scalar: F0) -> Opt let (valid, converted) = scalar_normalize(scalar); Some(converted).filter(|_| scalar == valid) } + +/// Create a mutually valid scalar from bytes via bit truncation to not introduce bias +pub fn mutual_scalar_from_bytes(bytes: &[u8]) -> (F0, F1) { + let capacity = usize::try_from(F0::CAPACITY.min(F1::CAPACITY)).unwrap(); + debug_assert!((bytes.len() * 8) >= capacity); + + let mut accum = F0::zero(); + for b in 0 .. capacity { + accum = accum.double(); + accum += F0::from(((bytes[b / 8] >> (b % 8)) & 1).into()); + } + (accum, scalar_convert(accum).unwrap()) +}