Correct misc TODOs in monero-serai

This commit is contained in:
Luke Parker
2024-07-05 23:30:02 -04:00
parent 90880cc9c8
commit 1f5e5fc7ac
27 changed files with 266 additions and 111 deletions

View File

@@ -3,7 +3,10 @@
#![deny(missing_docs)]
#![cfg_attr(not(feature = "std"), no_std)]
use core::fmt::Debug;
use core::{
fmt::Debug,
ops::{Bound, RangeBounds},
};
use std_shims::{
alloc::{boxed::Box, format},
vec,
@@ -26,6 +29,7 @@ use monero_serai::{
transaction::{Input, Timelock, Transaction},
block::Block,
};
use monero_address::Address;
// Number of blocks the fee estimate will be valid for
// https://github.com/monero-project/monero/blob/94e67bf96bbc010241f29ada6abc89f49a81759c/
@@ -70,9 +74,9 @@ pub enum RpcError {
#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroize)]
pub struct FeeRate {
/// The fee per-weight of the transaction.
pub per_weight: u64,
per_weight: u64,
/// The mask to round with.
pub mask: u64,
mask: u64,
}
impl FeeRate {
@@ -108,22 +112,22 @@ impl FeeRate {
/// This is not a Monero protocol defined struct, and this is accordingly not a Monero protocol
/// defined serialization.
pub fn read(r: &mut impl io::Read) -> io::Result<FeeRate> {
Ok(FeeRate { per_weight: read_u64(r)?, mask: read_u64(r)? })
let per_weight = read_u64(r)?;
let mask = read_u64(r)?;
FeeRate::new(per_weight, mask).map_err(io::Error::other)
}
/// Calculate the fee to use from the weight.
///
/// This function may panic if any of the `FeeRate`'s fields are zero.
/// This function may panic upon overflow.
pub fn calculate_fee_from_weight(&self, weight: usize) -> u64 {
let fee = self.per_weight * u64::try_from(weight).unwrap();
let fee = ((fee + self.mask - 1) / self.mask) * self.mask;
let fee = fee.div_ceil(self.mask) * self.mask;
debug_assert_eq!(weight, self.calculate_weight_from_fee(fee), "Miscalculated weight from fee");
fee
}
/// Calculate the weight from the fee.
///
/// This function may panic if any of the `FeeRate`'s fields are zero.
pub fn calculate_weight_from_fee(&self, fee: u64) -> usize {
usize::try_from(fee / self.per_weight).unwrap()
}
@@ -323,7 +327,11 @@ pub trait Rpc: Sync + Clone + Debug {
struct HeightResponse {
height: usize,
}
Ok(self.rpc_call::<Option<()>, HeightResponse>("get_height", None).await?.height)
let res = self.rpc_call::<Option<()>, HeightResponse>("get_height", None).await?.height;
if res == 0 {
Err(RpcError::InvalidNode("node responded with 0 for the height".to_string()))?;
}
Ok(res)
}
/// Get the specified transactions.
@@ -460,7 +468,7 @@ pub trait Rpc: Sync + Clone + Debug {
// Make sure this is actually the block for this number
match block.miner_transaction.prefix().inputs.first() {
Some(Input::Gen(actual)) => {
if usize::try_from(*actual) == Ok(number) {
if *actual == number {
Ok(block)
} else {
Err(RpcError::InvalidNode("different block than requested (number)".to_string()))
@@ -658,8 +666,11 @@ pub trait Rpc: Sync + Clone + Debug {
/// Get the output distribution.
///
/// `from` and `to` are heights, not block numbers, and inclusive.
async fn get_output_distribution(&self, from: usize, to: usize) -> Result<Vec<u64>, RpcError> {
/// `range` is in terms of block numbers.
async fn get_output_distribution(
&self,
range: impl Send + RangeBounds<usize>,
) -> Result<Vec<u64>, RpcError> {
#[derive(Deserialize, Debug)]
struct Distribution {
distribution: Vec<u64>,
@@ -667,10 +678,31 @@ pub trait Rpc: Sync + Clone + Debug {
#[derive(Deserialize, Debug)]
struct Distributions {
distributions: Vec<Distribution>,
distributions: [Distribution; 1],
}
let mut distributions: Distributions = self
let from = match range.start_bound() {
Bound::Included(from) => *from,
Bound::Excluded(from) => from
.checked_add(1)
.ok_or_else(|| RpcError::InternalError("range's from wasn't representable".to_string()))?,
Bound::Unbounded => 0,
};
let to = match range.end_bound() {
Bound::Included(to) => *to,
Bound::Excluded(to) => to
.checked_sub(1)
.ok_or_else(|| RpcError::InternalError("range's to wasn't representable".to_string()))?,
Bound::Unbounded => self.get_height().await? - 1,
};
if from > to {
Err(RpcError::InternalError(format!(
"malformed range: inclusive start {from}, inclusive end {to}"
)))?;
}
let zero_zero_case = (from == 0) && (to == 0);
let distributions: Distributions = self
.json_rpc_call(
"get_output_distribution",
Some(json!({
@@ -678,12 +710,27 @@ pub trait Rpc: Sync + Clone + Debug {
"amounts": [0],
"cumulative": true,
"from_height": from,
"to_height": to,
"to_height": if zero_zero_case { 1 } else { to },
})),
)
.await?;
let mut distributions = distributions.distributions;
let mut distribution = core::mem::take(&mut distributions[0].distribution);
Ok(distributions.distributions.swap_remove(0).distribution)
let expected_len = if zero_zero_case { 2 } else { (to - from) + 1 };
if expected_len != distribution.len() {
Err(RpcError::InvalidNode(format!(
"distribution length ({}) wasn't of the requested length ({})",
distribution.len(),
expected_len
)))?;
}
// Requesting 0, 0 returns the distribution for the entire chain
// We work-around this by requesting 0, 1 (yielding two blocks), then popping the second block
if zero_zero_case {
distribution.pop();
}
Ok(distribution)
}
/// Get the specified outputs from the RingCT (zero-amount) pool.
@@ -763,6 +810,7 @@ pub trait Rpc: Sync + Clone + Debug {
};
Ok(Some([key, rpc_point(&out.mask)?]).filter(|_| {
if fingerprintable_canonical {
// TODO: Are timelock blocks by height or number?
Timelock::Block(height) >= txs[i].prefix().additional_timelock
} else {
out.unlocked
@@ -864,10 +912,9 @@ pub trait Rpc: Sync + Clone + Debug {
/// Generate blocks, with the specified address receiving the block reward.
///
/// Returns the hashes of the generated blocks and the last block's number.
// TODO: Take &Address, not &str?
async fn generate_blocks(
async fn generate_blocks<const ADDR_BYTES: u128>(
&self,
address: &str,
address: &Address<ADDR_BYTES>,
block_count: usize,
) -> Result<(Vec<[u8; 32]>, usize), RpcError> {
#[derive(Debug, Deserialize)]
@@ -880,7 +927,7 @@ pub trait Rpc: Sync + Clone + Debug {
.json_rpc_call::<BlocksResponse>(
"generateblocks",
Some(json!({
"wallet_address": address,
"wallet_address": address.to_string(),
"amount_of_blocks": block_count
})),
)