From de2d6568a410bd64714cc2b01742453e6a01785e Mon Sep 17 00:00:00 2001 From: Luke Parker Date: Wed, 8 Jan 2025 17:40:08 -0500 Subject: [PATCH] Actually implement the Peer abstraction for Libp2p --- coordinator/src/p2p/heartbeat.rs | 8 +++- coordinator/src/p2p/libp2p/mod.rs | 69 ++++++++++++++++++++++++---- coordinator/src/p2p/libp2p/reqres.rs | 9 +--- coordinator/src/p2p/libp2p/swarm.rs | 15 ++---- coordinator/src/p2p/mod.rs | 10 ++-- 5 files changed, 77 insertions(+), 34 deletions(-) diff --git a/coordinator/src/p2p/heartbeat.rs b/coordinator/src/p2p/heartbeat.rs index 0f000dcc..025bfd73 100644 --- a/coordinator/src/p2p/heartbeat.rs +++ b/coordinator/src/p2p/heartbeat.rs @@ -1,9 +1,10 @@ use core::future::Future; - use std::time::{Duration, SystemTime}; use serai_client::validator_sets::primitives::ValidatorSet; +use futures_util::FutureExt; + use tributary::{ReadWrite, Block, Tributary, TributaryReader}; use serai_db::*; @@ -71,7 +72,10 @@ impl ContinuallyRan for HeartbeatTask { tip = self.reader.tip(); tip_is_stale = false; } - let Ok(blocks) = peer.send_heartbeat(self.set, tip).await else { continue 'peer }; + // Necessary due to https://github.com/rust-lang/rust/issues/100013 + let Some(blocks) = peer.send_heartbeat(self.set, tip).boxed().await else { + continue 'peer; + }; // This is the final batch if it has less than the maximum amount of blocks // (signifying there weren't more blocks after this to fill the batch with) diff --git a/coordinator/src/p2p/libp2p/mod.rs b/coordinator/src/p2p/libp2p/mod.rs index b103a63f..79f06c19 100644 --- a/coordinator/src/p2p/libp2p/mod.rs +++ b/coordinator/src/p2p/libp2p/mod.rs @@ -1,4 +1,4 @@ -use core::future::Future; +use core::{future::Future, time::Duration}; use std::{ sync::Arc, collections::{HashSet, HashMap}, @@ -9,10 +9,11 @@ use schnorrkel::Keypair; use serai_client::{ primitives::{NetworkId, PublicKey}, + validator_sets::primitives::ValidatorSet, Serai, }; -use tokio::sync::{mpsc, RwLock}; +use tokio::sync::{mpsc, oneshot, RwLock}; use serai_task::{Task, ContinuallyRan}; @@ -25,6 +26,8 @@ use libp2p::{ SwarmBuilder, }; +use crate::p2p::TributaryBlockWithCommit; + /// A struct to sync the validators from the Serai node in order to keep track of them. mod validators; use validators::UpdateValidatorsTask; @@ -64,10 +67,31 @@ fn peer_id_from_public(public: PublicKey) -> PeerId { PeerId::from_multihash(Multihash::wrap(0, &public.0).unwrap()).unwrap() } -struct Peer; -impl Peer { - async fn send(&self, request: Request) -> Result { - (async move { todo!("TODO") }).await +struct Peer<'a> { + outbound_requests: &'a mpsc::UnboundedSender<(PeerId, Request, oneshot::Sender)>, + id: PeerId, +} +impl crate::p2p::Peer<'_> for Peer<'_> { + fn send_heartbeat( + &self, + set: ValidatorSet, + latest_block_hash: [u8; 32], + ) -> impl Send + Future>> { + const HEARBEAT_TIMEOUT: Duration = Duration::from_secs(5); + async move { + let request = Request::Heartbeat { set, latest_block_hash }; + let (sender, receiver) = oneshot::channel(); + self + .outbound_requests + .send((self.id, request, sender)) + .expect("outbound requests recv channel was dropped?"); + match tokio::time::timeout(HEARBEAT_TIMEOUT, receiver).await.ok()?.ok()? { + Response::None => Some(vec![]), + Response::Blocks(blocks) => Some(blocks), + // TODO: Disconnect this peer + Response::NotableCosigns(_) => None, + } + } } } @@ -82,9 +106,14 @@ struct Behavior { gossip: gossip::Behavior, } -struct LibP2p; -impl LibP2p { - pub(crate) fn new(serai_key: &Zeroizing, serai: Serai) -> LibP2p { +#[derive(Clone)] +struct Libp2p { + peers: Peers, + outbound_requests: mpsc::UnboundedSender<(PeerId, Request, oneshot::Sender)>, +} + +impl Libp2p { + pub(crate) fn new(serai_key: &Zeroizing, serai: Serai) -> Libp2p { // Define the object we track peers with let peers = Peers { peers: Arc::new(RwLock::new(HashMap::new())) }; @@ -161,3 +190,25 @@ impl LibP2p { todo!("TODO"); } } + +impl tributary::P2p for Libp2p { + fn broadcast(&self, genesis: [u8; 32], msg: Vec) -> impl Send + Future { + async move { todo!("TODO") } + } +} + +impl crate::p2p::P2p for Libp2p { + type Peer<'a> = Peer<'a>; + fn peers(&self, network: NetworkId) -> impl Send + Future>> { + async move { + let Some(peer_ids) = self.peers.peers.read().await.get(&network).cloned() else { + return vec![]; + }; + let mut res = vec![]; + for id in peer_ids { + res.push(Peer { outbound_requests: &self.outbound_requests, id }); + } + res + } + } +} diff --git a/coordinator/src/p2p/libp2p/reqres.rs b/coordinator/src/p2p/libp2p/reqres.rs index cf7575e4..e3d761e5 100644 --- a/coordinator/src/p2p/libp2p/reqres.rs +++ b/coordinator/src/p2p/libp2p/reqres.rs @@ -15,6 +15,8 @@ pub use request_response::Message; use serai_cosign::SignedCosign; +use crate::p2p::TributaryBlockWithCommit; + /// The maximum message size for the request-response protocol // This is derived from the heartbeat message size as it's our largest message pub(crate) const MAX_LIBP2P_REQRES_MESSAGE_SIZE: usize = @@ -36,13 +38,6 @@ pub(crate) enum Request { NotableCosigns { global_session: [u8; 32] }, } -/// A tributary block and its commit. -#[derive(Clone, BorshSerialize, BorshDeserialize)] -pub(crate) struct TributaryBlockWithCommit { - pub(crate) block: Vec, - pub(crate) commit: Vec, -} - /// Responses which can be received via the request-response protocol. #[derive(Clone, BorshSerialize, BorshDeserialize)] pub(crate) enum Response { diff --git a/coordinator/src/p2p/libp2p/swarm.rs b/coordinator/src/p2p/libp2p/swarm.rs index 3962e81b..615295f4 100644 --- a/coordinator/src/p2p/libp2p/swarm.rs +++ b/coordinator/src/p2p/libp2p/swarm.rs @@ -63,8 +63,8 @@ pub(crate) struct SwarmTask { signed_cosigns: mpsc::UnboundedSender, tributary_gossip: mpsc::UnboundedSender<(ValidatorSet, Vec)>, - outbound_requests: mpsc::UnboundedReceiver<(PeerId, Request, oneshot::Sender>)>, - outbound_request_responses: HashMap>>, + outbound_requests: mpsc::UnboundedReceiver<(PeerId, Request, oneshot::Sender)>, + outbound_request_responses: HashMap>, inbound_request_response_channels: HashMap>, heartbeat_requests: mpsc::UnboundedSender<(RequestId, ValidatorSet, [u8; 32])>, @@ -120,16 +120,15 @@ impl SwarmTask { } }, reqres::Message::Response { request_id, response } => { - // Send Some(response) as the response for the request if let Some(channel) = self.outbound_request_responses.remove(&request_id) { - let _: Result<_, _> = channel.send(Some(response)); + let _: Result<_, _> = channel.send(response); } } }, reqres::Event::OutboundFailure { request_id, .. } => { // Send None as the response for the request if let Some(channel) = self.outbound_request_responses.remove(&request_id) { - let _: Result<_, _> = channel.send(None); + let _: Result<_, _> = channel.send(Response::None); } } reqres::Event::InboundFailure { .. } | reqres::Event::ResponseSent { .. } => {} @@ -299,11 +298,7 @@ impl SwarmTask { signed_cosigns: mpsc::UnboundedSender, tributary_gossip: mpsc::UnboundedSender<(ValidatorSet, Vec)>, - outbound_requests: mpsc::UnboundedReceiver<( - PeerId, - Request, - oneshot::Sender>, - )>, + outbound_requests: mpsc::UnboundedReceiver<(PeerId, Request, oneshot::Sender)>, heartbeat_requests: mpsc::UnboundedSender<(RequestId, ValidatorSet, [u8; 32])>, notable_cosign_requests: mpsc::UnboundedSender<(RequestId, [u8; 32])>, diff --git a/coordinator/src/p2p/mod.rs b/coordinator/src/p2p/mod.rs index 534e44dc..414e4ec3 100644 --- a/coordinator/src/p2p/mod.rs +++ b/coordinator/src/p2p/mod.rs @@ -1,7 +1,5 @@ use core::future::Future; -use tokio::time::error::Elapsed; - use borsh::{BorshSerialize, BorshDeserialize}; use serai_client::{primitives::NetworkId, validator_sets::primitives::ValidatorSet}; @@ -19,15 +17,15 @@ pub(crate) struct TributaryBlockWithCommit { pub(crate) commit: Vec, } -trait Peer: Send { +trait Peer<'a>: Send { fn send_heartbeat( &self, set: ValidatorSet, latest_block_hash: [u8; 32], - ) -> impl Send + Future, Elapsed>>; + ) -> impl Send + Future>>; } trait P2p: Send + Sync + tributary::P2p { - type Peer: Peer; - fn peers(&self, network: NetworkId) -> impl Send + Future>; + type Peer<'a>: Peer<'a>; + fn peers(&self, network: NetworkId) -> impl Send + Future>>; }