Make the coordinator's P2P modules their own crates

Luke Parker
2025-01-09 01:26:25 -05:00
parent adf20773ac
commit 465e8498c4
24 changed files with 234 additions and 63 deletions

View File

@@ -0,0 +1,176 @@
use core::{pin::Pin, future::Future};
use std::io;
use zeroize::Zeroizing;
use rand_core::{RngCore, OsRng};
use blake2::{Digest, Blake2s256};
use schnorrkel::{Keypair, PublicKey, Signature};
use serai_client::primitives::PublicKey as Public;
use futures_util::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use libp2p::{
core::UpgradeInfo,
InboundUpgrade, OutboundUpgrade,
identity::{self, PeerId},
noise,
};
use crate::peer_id_from_public;
const PROTOCOL: &str = "/serai/coordinator/validators";
#[derive(Clone)]
pub(crate) struct OnlyValidators {
pub(crate) serai_key: Zeroizing<Keypair>,
pub(crate) noise_keypair: identity::Keypair,
}
impl OnlyValidators {
/// The ephemeral challenge protocol for authentication.
///
/// We use ephemeral challenges to prevent replaying signatures from historic sessions.
///
/// We don't immediately send the challenge. We only send a commitment to it. This prevents our
/// remote peer from choosing their challenge in response to our challenge, in case there was any
/// benefit to doing so.
async fn challenges<S: 'static + Send + Unpin + AsyncRead + AsyncWrite>(
socket: &mut noise::Output<S>,
) -> io::Result<([u8; 32], [u8; 32])> {
let mut our_challenge = [0; 32];
OsRng.fill_bytes(&mut our_challenge);
// Write the hash of our challenge
socket.write_all(&Blake2s256::digest(our_challenge)).await?;
// Read the hash of their challenge
let mut their_challenge_commitment = [0; 32];
socket.read_exact(&mut their_challenge_commitment).await?;
// Reveal our challenge
socket.write_all(&our_challenge).await?;
// Read their challenge
let mut their_challenge = [0; 32];
socket.read_exact(&mut their_challenge).await?;
// Verify their challenge
if <[u8; 32]>::from(Blake2s256::digest(their_challenge)) != their_challenge_commitment {
Err(io::Error::other("challenge didn't match challenge commitment"))?;
}
Ok((our_challenge, their_challenge))
}
// We sign the two noise peer IDs and the ephemeral challenges.
//
// Signing the noise peer IDs ensures we're authenticating this noise connection. The only
// expectations placed on noise are for it to prevent a MITM from impersonating the other end or
// modifying any messages sent.
//
// Signing the ephemeral challenges prevents any replays. While that should be unnecessary, as
// noise MAY prevent replays across sessions (even when the same key is used), and noise IDs
// shouldn't be reused (so it should be fine to reuse an existing signature for these noise IDs),
// it doesn't hurt.
async fn authenticate<S: 'static + Send + Unpin + AsyncRead + AsyncWrite>(
&self,
socket: &mut noise::Output<S>,
dialer_peer_id: PeerId,
dialer_challenge: [u8; 32],
listener_peer_id: PeerId,
listener_challenge: [u8; 32],
) -> io::Result<PeerId> {
// Write our public key
socket.write_all(&self.serai_key.public.to_bytes()).await?;
let msg = borsh::to_vec(&(
dialer_peer_id.to_bytes(),
dialer_challenge,
listener_peer_id.to_bytes(),
listener_challenge,
))
.unwrap();
let signature = self.serai_key.sign_simple(PROTOCOL.as_bytes(), &msg);
socket.write_all(&signature.to_bytes()).await?;
let mut public_key_and_sig = [0; 96];
socket.read_exact(&mut public_key_and_sig).await?;
let public_key = PublicKey::from_bytes(&public_key_and_sig[.. 32])
.map_err(|_| io::Error::other("invalid public key"))?;
let sig = Signature::from_bytes(&public_key_and_sig[32 ..])
.map_err(|_| io::Error::other("invalid signature serialization"))?;
public_key
.verify_simple(PROTOCOL.as_bytes(), &msg, &sig)
.map_err(|_| io::Error::other("invalid signature"))?;
Ok(peer_id_from_public(Public::from_raw(public_key.to_bytes())))
}
}
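// For reference, the full authentication transcript each side writes over the established
// noise channel (a sketch; the sizes follow from the types used above):
//   32 bytes | Blake2s256(challenge)  | commitment to our ephemeral challenge
//   32 bytes | challenge              | reveal of our ephemeral challenge
//   32 bytes | schnorrkel public key  | our Serai key
//   64 bytes | schnorrkel signature   | sign_simple over PROTOCOL and
//            |                        | borsh((dialer_peer_id, dialer_challenge,
//            |                        |        listener_peer_id, listener_challenge))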
impl UpgradeInfo for OnlyValidators {
type Info = <noise::Config as UpgradeInfo>::Info;
type InfoIter = <noise::Config as UpgradeInfo>::InfoIter;
fn protocol_info(&self) -> Self::InfoIter {
// A keypair only causes an error if its sign operation fails, which is only possible with RSA,
// which isn't used within this codebase
noise::Config::new(&self.noise_keypair).unwrap().protocol_info()
}
}
impl<S: 'static + Send + Unpin + AsyncRead + AsyncWrite> InboundUpgrade<S> for OnlyValidators {
type Output = (PeerId, noise::Output<S>);
type Error = io::Error;
type Future = Pin<Box<dyn Send + Future<Output = Result<Self::Output, Self::Error>>>>;
fn upgrade_inbound(self, socket: S, info: Self::Info) -> Self::Future {
Box::pin(async move {
let (dialer_noise_peer_id, mut socket) = noise::Config::new(&self.noise_keypair)
.unwrap()
.upgrade_inbound(socket, info)
.await
.map_err(io::Error::other)?;
let (our_challenge, dialer_challenge) = OnlyValidators::challenges(&mut socket).await?;
let dialer_serai_validator = self
.authenticate(
&mut socket,
dialer_noise_peer_id,
dialer_challenge,
PeerId::from_public_key(&self.noise_keypair.public()),
our_challenge,
)
.await?;
Ok((dialer_serai_validator, socket))
})
}
}
impl<S: 'static + Send + Unpin + AsyncRead + AsyncWrite> OutboundUpgrade<S> for OnlyValidators {
type Output = (PeerId, noise::Output<S>);
type Error = io::Error;
type Future = Pin<Box<dyn Send + Future<Output = Result<Self::Output, Self::Error>>>>;
fn upgrade_outbound(self, socket: S, info: Self::Info) -> Self::Future {
Box::pin(async move {
let (listener_noise_peer_id, mut socket) = noise::Config::new(&self.noise_keypair)
.unwrap()
.upgrade_outbound(socket, info)
.await
.map_err(io::Error::other)?;
let (our_challenge, listener_challenge) = OnlyValidators::challenges(&mut socket).await?;
let listener_serai_validator = self
.authenticate(
&mut socket,
PeerId::from_public_key(&self.noise_keypair.public()),
our_challenge,
listener_noise_peer_id,
listener_challenge,
)
.await?;
Ok((listener_serai_validator, socket))
})
}
}

View File

@@ -0,0 +1,122 @@
use core::future::Future;
use std::collections::HashSet;
use rand_core::{RngCore, OsRng};
use tokio::sync::mpsc;
use serai_client::Serai;
use libp2p::{
core::multiaddr::{Protocol, Multiaddr},
swarm::dial_opts::DialOpts,
};
use serai_task::ContinuallyRan;
use crate::{PORT, Peers, validators::Validators};
const TARGET_PEERS_PER_NETWORK: usize = 5;
/*
If we only tracked the target amount of peers per network, we'd risk being eclipsed by an
adversary who immediately connects to us with their array of validators upon our boot. Their
array would satisfy our target amount of peers, so we'd never seek more, enabling the adversary
to be the only entity we peered with.
We solve this by additionally requiring an explicit amount of peers we ourselves dialed,
as those are peers we randomly chose to connect to.
*/
// TODO const TARGET_DIALED_PEERS_PER_NETWORK: usize = 3;
pub(crate) struct DialTask {
serai: Serai,
validators: Validators,
peers: Peers,
to_dial: mpsc::UnboundedSender<DialOpts>,
}
impl DialTask {
pub(crate) fn new(serai: Serai, peers: Peers, to_dial: mpsc::UnboundedSender<DialOpts>) -> Self {
DialTask { serai: serai.clone(), validators: Validators::new(serai).0, peers, to_dial }
}
}
impl ContinuallyRan for DialTask {
// Only run every five minutes, not the default of every five seconds
const DELAY_BETWEEN_ITERATIONS: u64 = 5 * 60;
const MAX_DELAY_BETWEEN_ITERATIONS: u64 = 10 * 60;
fn run_iteration(&mut self) -> impl Send + Future<Output = Result<bool, String>> {
async move {
self.validators.update().await?;
// If any of our peers is lacking, try to connect to more
let mut dialed = false;
let peer_counts = self
.peers
.peers
.read()
.await
.iter()
.map(|(network, peers)| (*network, peers.len()))
.collect::<Vec<_>>();
for (network, peer_count) in peer_counts {
/*
If we don't have the target amount of peers, and we aren't already connected to all but
one of the validators in the set, attempt to connect to more validators within this set.
The latter clause ensures that, if there's a set with only 3 validators, we don't
infinitely try to connect to the target amount of peers for this network, as we never
will. Instead, we only try to connect to most of the validators actually present.
*/
if (peer_count < TARGET_PEERS_PER_NETWORK) &&
(peer_count <
self
.validators
.by_network()
.get(&network)
.map(HashSet::len)
.unwrap_or(0)
.saturating_sub(1))
{
let mut potential_peers =
self.serai.p2p_validators(network).await.map_err(|e| format!("{e:?}"))?;
for _ in 0 .. (TARGET_PEERS_PER_NETWORK - peer_count) {
if potential_peers.is_empty() {
break;
}
let index_to_dial =
usize::try_from(OsRng.next_u64() % u64::try_from(potential_peers.len()).unwrap())
.unwrap();
let randomly_selected_peer = potential_peers.swap_remove(index_to_dial);
log::info!("found peer from substrate: {randomly_selected_peer}");
// Map the peer from a Substrate P2P network peer to a Coordinator P2P network peer
let mapped_peer = randomly_selected_peer
.into_iter()
.filter_map(|protocol| match protocol {
// Drop PeerIds from the Substrate P2P network
Protocol::P2p(_) => None,
// Use our own TCP port
Protocol::Tcp(_) => Some(Protocol::Tcp(PORT)),
// Pass-through any other specifications (IPv4, IPv6, etc)
other => Some(other),
})
.collect::<Multiaddr>();
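// As an example of this mapping, an advertised address of the form
// "/ip4/203.0.113.1/tcp/30333/p2p/12D3Koo..." (the IP, port, and PeerId here are
// hypothetical) becomes "/ip4/203.0.113.1/tcp/30563", dropping the Substrate PeerId and
// substituting our own PORT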
log::debug!("mapped found peer: {mapped_peer}");
self
.to_dial
.send(DialOpts::unknown_peer_id().address(mapped_peer).build())
.expect("dial receiver closed?");
dialed = true;
}
}
}
Ok(dialed)
}
}
}

View File

@@ -0,0 +1,74 @@
use core::time::Duration;
use blake2::{Digest, Blake2s256};
use borsh::{BorshSerialize, BorshDeserialize};
use libp2p::gossipsub::{
IdentTopic, MessageId, MessageAuthenticity, ValidationMode, ConfigBuilder, IdentityTransform,
AllowAllSubscriptionFilter, Behaviour,
};
pub use libp2p::gossipsub::Event;
use serai_cosign::SignedCosign;
// Block size limit + 16 KiB of space for signatures/metadata
pub(crate) const MAX_LIBP2P_GOSSIP_MESSAGE_SIZE: usize = tributary::BLOCK_SIZE_LIMIT + 16384;
const LIBP2P_PROTOCOL: &str = "/serai/coordinator/gossip/1.0.0";
const BASE_TOPIC: &str = "/";
fn topic_for_tributary(tributary: [u8; 32]) -> IdentTopic {
IdentTopic::new(format!("/tributary/{}", hex::encode(tributary)))
}
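// e.g. topic_for_tributary([0xab; 32]) yields the IdentTopic for the string
// "/tributary/" followed by "ab" repeated 32 times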
#[derive(Clone, BorshSerialize, BorshDeserialize)]
pub(crate) enum Message {
Tributary { tributary: [u8; 32], message: Vec<u8> },
Cosign(SignedCosign),
}
impl Message {
pub(crate) fn topic(&self) -> IdentTopic {
match self {
Message::Tributary { tributary, .. } => topic_for_tributary(*tributary),
Message::Cosign(_) => IdentTopic::new(BASE_TOPIC),
}
}
}
pub(crate) type Behavior = Behaviour<IdentityTransform, AllowAllSubscriptionFilter>;
pub(crate) fn new_behavior() -> Behavior {
// The latency used by the Tendermint protocol, used here as the gossip epoch duration
// libp2p-rs defaults to 1 second, whereas ours will be ~2
let heartbeat_interval = tributary::tendermint::LATENCY_TIME;
// The amount of heartbeats which will occur within a single Tributary block
let heartbeats_per_block = tributary::tendermint::TARGET_BLOCK_TIME.div_ceil(heartbeat_interval);
// libp2p-rs defaults to 5, whereas ours will be ~8
let heartbeats_to_keep = 2 * heartbeats_per_block;
// libp2p-rs defaults to 3 whereas ours will be ~4
let heartbeats_to_gossip = heartbeats_per_block;
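// A worked sketch using the approximate values noted above: a ~2000 ms heartbeat
// interval and a ~8000 ms target block time yield heartbeats_per_block = ~4, so we keep
// ~8 heartbeats (~16 seconds) of history yet only gossip metadata for the most recent ~4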
let config = ConfigBuilder::default()
.protocol_id_prefix(LIBP2P_PROTOCOL)
.history_length(usize::try_from(heartbeats_to_keep).unwrap())
.history_gossip(usize::try_from(heartbeats_to_gossip).unwrap())
.heartbeat_interval(Duration::from_millis(heartbeat_interval.into()))
.max_transmit_size(MAX_LIBP2P_GOSSIP_MESSAGE_SIZE)
.duplicate_cache_time(Duration::from_millis((heartbeats_to_keep * heartbeat_interval).into()))
.validation_mode(ValidationMode::Anonymous)
// Uses a content-based message ID to avoid duplicates as much as possible
.message_id_fn(|msg| {
MessageId::new(&Blake2s256::digest([msg.topic.as_str().as_bytes(), &msg.data].concat()))
})
.build();
let mut gossip = Behavior::new(MessageAuthenticity::Anonymous, config.unwrap()).unwrap();
// Subscribe to the base topic
let topic = IdentTopic::new(BASE_TOPIC);
let _ = gossip.subscribe(&topic);
gossip
}

View File

@@ -0,0 +1,419 @@
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
use core::{future::Future, time::Duration};
use std::{
sync::Arc,
collections::{HashSet, HashMap},
};
use rand_core::{RngCore, OsRng};
use zeroize::Zeroizing;
use schnorrkel::Keypair;
use serai_client::{
primitives::{NetworkId, PublicKey},
validator_sets::primitives::ValidatorSet,
Serai,
};
use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
use serai_task::{Task, ContinuallyRan};
use serai_cosign::SignedCosign;
use libp2p::{
multihash::Multihash,
identity::{self, PeerId},
tcp::Config as TcpConfig,
yamux, allow_block_list,
connection_limits::{self, ConnectionLimits},
swarm::NetworkBehaviour,
SwarmBuilder,
};
use serai_coordinator_p2p::TributaryBlockWithCommit;
/// A struct to sync the validators from the Serai node in order to keep track of them.
mod validators;
use validators::UpdateValidatorsTask;
/// The authentication protocol upgrade to limit the P2P network to active validators.
mod authenticate;
use authenticate::OnlyValidators;
/// The ping behavior, used to ensure connection latency is below the limit
mod ping;
/// The request-response messages and behavior
mod reqres;
use reqres::{RequestId, Request, Response};
/// The gossip messages and behavior
mod gossip;
use gossip::Message;
/// The swarm task, running it and dispatching to/from it
mod swarm;
use swarm::SwarmTask;
/// The dial task, to find new peers to connect to
mod dial;
use dial::DialTask;
const PORT: u16 = 30563; // 5132 ^ (('c' << 8) | 'o')
// usize::max, manually implemented, as max isn't a const fn
const MAX_LIBP2P_MESSAGE_SIZE: usize =
if gossip::MAX_LIBP2P_GOSSIP_MESSAGE_SIZE > reqres::MAX_LIBP2P_REQRES_MESSAGE_SIZE {
gossip::MAX_LIBP2P_GOSSIP_MESSAGE_SIZE
} else {
reqres::MAX_LIBP2P_REQRES_MESSAGE_SIZE
};
fn peer_id_from_public(public: PublicKey) -> PeerId {
// 0 represents the identity Multihash, meaning no hash was performed
// It's an internal constant within libp2p we can't refer to, so we write its value inline
PeerId::from_multihash(Multihash::wrap(0, &public.0).unwrap()).unwrap()
}
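// A sketch of the inverse, assuming a PeerId produced solely by the above mapping:
// decode the PeerId's multihash, check its code is 0 (the identity), and take the
// 32-byte digest as the public key:
//   let hash = Multihash::from(peer_id);
//   assert_eq!(hash.code(), 0);
//   let public = PublicKey::from_raw(<[u8; 32]>::try_from(hash.digest()).unwrap());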
/// The representation of a peer.
pub struct Peer<'a> {
outbound_requests: &'a mpsc::UnboundedSender<(PeerId, Request, oneshot::Sender<Response>)>,
id: PeerId,
}
impl serai_coordinator_p2p::Peer<'_> for Peer<'_> {
fn send_heartbeat(
&self,
set: ValidatorSet,
latest_block_hash: [u8; 32],
) -> impl Send + Future<Output = Option<Vec<TributaryBlockWithCommit>>> {
async move {
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
let request = Request::Heartbeat { set, latest_block_hash };
let (sender, receiver) = oneshot::channel();
self
.outbound_requests
.send((self.id, request, sender))
.expect("outbound requests recv channel was dropped?");
if let Ok(Ok(Response::Blocks(blocks))) =
tokio::time::timeout(HEARTBEAT_TIMEOUT, receiver).await
{
Some(blocks)
} else {
None
}
}
}
}
#[derive(Clone)]
struct Peers {
peers: Arc<RwLock<HashMap<NetworkId, HashSet<PeerId>>>>,
}
// Consider adding identify/kad/autonat/rendezvous/(relay + dcutr). While we currently use the Serai
// network for peers, we could use it solely for bootstrapping/as a fallback.
#[derive(NetworkBehaviour)]
struct Behavior {
// Used to only allow Serai validators as peers
allow_list: allow_block_list::Behaviour<allow_block_list::AllowedPeers>,
// Used to limit each peer to a single connection
connection_limits: connection_limits::Behaviour,
// Used to ensure connection latency is within tolerances
ping: ping::Behavior,
// Used to request data from specific peers
reqres: reqres::Behavior,
// Used to broadcast messages to all other peers subscribed to a topic
gossip: gossip::Behavior,
}
/// The libp2p-backed P2P implementation.
///
/// The P2p trait implementation does not support backpressure and expects its entire API to
/// be continually polled. Failure to poll the entire API will cause unbounded memory growth.
#[allow(clippy::type_complexity)]
#[derive(Clone)]
pub struct Libp2p {
peers: Peers,
gossip: mpsc::UnboundedSender<Message>,
outbound_requests: mpsc::UnboundedSender<(PeerId, Request, oneshot::Sender<Response>)>,
tributary_gossip: Arc<Mutex<mpsc::UnboundedReceiver<([u8; 32], Vec<u8>)>>>,
signed_cosigns: Arc<Mutex<mpsc::UnboundedReceiver<SignedCosign>>>,
signed_cosigns_send: mpsc::UnboundedSender<SignedCosign>,
heartbeat_requests: Arc<Mutex<mpsc::UnboundedReceiver<(RequestId, ValidatorSet, [u8; 32])>>>,
notable_cosign_requests: Arc<Mutex<mpsc::UnboundedReceiver<(RequestId, [u8; 32])>>>,
inbound_request_responses: mpsc::UnboundedSender<(RequestId, Response)>,
}
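// As the documentation above notes, every receiving method must be continually polled.
// A caller's event loop might look like the following sketch (response handling elided):
//   loop {
//     tokio::select! {
//       (tributary, message) = p2p.tributary_message() => { /* forward to the Tributary */ }
//       cosign = p2p.cosign() => { /* intake the cosign */ }
//       (set, block_hash, respond) = p2p.heartbeat() => { /* send blocks via `respond` */ }
//       (session, respond) = p2p.notable_cosigns_request() => { /* send cosigns */ }
//     }
//   }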
impl Libp2p {
/// Create a new libp2p-backed P2P instance.
///
/// This will spawn all of the internal tasks necessary for functioning.
pub fn new(serai_key: &Zeroizing<Keypair>, serai: Serai) -> Libp2p {
// Define the object we track peers with
let peers = Peers { peers: Arc::new(RwLock::new(HashMap::new())) };
// Define the dial task
let (dial_task_def, dial_task) = Task::new();
let (to_dial_send, to_dial_recv) = mpsc::unbounded_channel();
tokio::spawn(
DialTask::new(serai.clone(), peers.clone(), to_dial_send)
.continually_run(dial_task_def, vec![]),
);
let swarm = {
let new_only_validators = |noise_keypair: &identity::Keypair| -> Result<_, ()> {
Ok(OnlyValidators { serai_key: serai_key.clone(), noise_keypair: noise_keypair.clone() })
};
let new_yamux = || {
let mut config = yamux::Config::default();
// 1 MiB default + max message size
config.set_max_buffer_size((1024 * 1024) + MAX_LIBP2P_MESSAGE_SIZE);
// 256 KiB default + max message size
config
.set_receive_window_size(((256 * 1024) + MAX_LIBP2P_MESSAGE_SIZE).try_into().unwrap());
config
};
let mut swarm = SwarmBuilder::with_existing_identity(identity::Keypair::generate_ed25519())
.with_tokio()
.with_tcp(TcpConfig::default().nodelay(false), new_only_validators, new_yamux)
.unwrap()
.with_behaviour(|_| Behavior {
allow_list: allow_block_list::Behaviour::default(),
// Limit each peer to a single connection
connection_limits: connection_limits::Behaviour::new(
ConnectionLimits::default().with_max_established_per_peer(Some(1)),
),
ping: ping::new_behavior(),
reqres: reqres::new_behavior(),
gossip: gossip::new_behavior(),
})
.unwrap()
.with_swarm_config(|config| {
config
.with_idle_connection_timeout(ping::INTERVAL + ping::TIMEOUT + Duration::from_secs(5))
})
.build();
swarm.listen_on(format!("/ip4/0.0.0.0/tcp/{PORT}").parse().unwrap()).unwrap();
swarm.listen_on(format!("/ip6/::/tcp/{PORT}").parse().unwrap()).unwrap();
swarm
};
let (swarm_validators, validator_changes) = UpdateValidatorsTask::spawn(serai);
let (gossip_send, gossip_recv) = mpsc::unbounded_channel();
let (signed_cosigns_send, signed_cosigns_recv) = mpsc::unbounded_channel();
let (tributary_gossip_send, tributary_gossip_recv) = mpsc::unbounded_channel();
let (outbound_requests_send, outbound_requests_recv) = mpsc::unbounded_channel();
let (heartbeat_requests_send, heartbeat_requests_recv) = mpsc::unbounded_channel();
let (notable_cosign_requests_send, notable_cosign_requests_recv) = mpsc::unbounded_channel();
let (inbound_request_responses_send, inbound_request_responses_recv) =
mpsc::unbounded_channel();
// Create the swarm task
SwarmTask::spawn(
dial_task,
to_dial_recv,
swarm_validators,
validator_changes,
peers.clone(),
swarm,
gossip_recv,
signed_cosigns_send.clone(),
tributary_gossip_send,
outbound_requests_recv,
heartbeat_requests_send,
notable_cosign_requests_send,
inbound_request_responses_recv,
);
Libp2p {
peers,
gossip: gossip_send,
outbound_requests: outbound_requests_send,
tributary_gossip: Arc::new(Mutex::new(tributary_gossip_recv)),
signed_cosigns: Arc::new(Mutex::new(signed_cosigns_recv)),
signed_cosigns_send,
heartbeat_requests: Arc::new(Mutex::new(heartbeat_requests_recv)),
notable_cosign_requests: Arc::new(Mutex::new(notable_cosign_requests_recv)),
inbound_request_responses: inbound_request_responses_send,
}
}
}
impl tributary::P2p for Libp2p {
fn broadcast(&self, tributary: [u8; 32], message: Vec<u8>) -> impl Send + Future<Output = ()> {
async move {
self
.gossip
.send(Message::Tributary { tributary, message })
.expect("gossip recv channel was dropped?");
}
}
}
impl serai_cosign::RequestNotableCosigns for Libp2p {
type Error = ();
fn request_notable_cosigns(
&self,
global_session: [u8; 32],
) -> impl Send + Future<Output = Result<(), Self::Error>> {
async move {
const AMOUNT_OF_PEERS_TO_REQUEST_FROM: usize = 3;
const NOTABLE_COSIGNS_TIMEOUT: Duration = Duration::from_secs(5);
let request = Request::NotableCosigns { global_session };
let peers = self.peers.peers.read().await.clone();
// HashSet of all peers
let peers = peers.into_values().flat_map(<_>::into_iter).collect::<HashSet<_>>();
// Vec of all peers
let mut peers = peers.into_iter().collect::<Vec<_>>();
let mut channels = Vec::with_capacity(AMOUNT_OF_PEERS_TO_REQUEST_FROM);
for _ in 0 .. AMOUNT_OF_PEERS_TO_REQUEST_FROM {
if peers.is_empty() {
break;
}
let i = usize::try_from(OsRng.next_u64() % u64::try_from(peers.len()).unwrap()).unwrap();
let peer = peers.swap_remove(i);
let (sender, receiver) = oneshot::channel();
self
.outbound_requests
.send((peer, request, sender))
.expect("outbound requests recv channel was dropped?");
channels.push(receiver);
}
// We could reduce our latency by using FuturesUnordered here but the latency isn't a concern
for channel in channels {
if let Ok(Ok(Response::NotableCosigns(cosigns))) =
tokio::time::timeout(NOTABLE_COSIGNS_TIMEOUT, channel).await
{
for cosign in cosigns {
self
.signed_cosigns_send
.send(cosign)
.expect("signed_cosigns recv in this object was dropped?");
}
}
}
Ok(())
}
}
}
impl serai_coordinator_p2p::P2p for Libp2p {
type Peer<'a> = Peer<'a>;
fn peers(&self, network: NetworkId) -> impl Send + Future<Output = Vec<Self::Peer<'_>>> {
async move {
let Some(peer_ids) = self.peers.peers.read().await.get(&network).cloned() else {
return vec![];
};
let mut res = vec![];
for id in peer_ids {
res.push(Peer { outbound_requests: &self.outbound_requests, id });
}
res
}
}
fn heartbeat(
&self,
) -> impl Send
+ Future<Output = (ValidatorSet, [u8; 32], oneshot::Sender<Vec<TributaryBlockWithCommit>>)>
{
async move {
let (request_id, set, latest_block_hash) = self
.heartbeat_requests
.lock()
.await
.recv()
.await
.expect("heartbeat_requests_send was dropped?");
let (sender, receiver) = oneshot::channel();
tokio::spawn({
let respond = self.inbound_request_responses.clone();
async move {
// The swarm task expects us to respond to every request. If the caller drops this
// channel, we'll receive `Err` and respond with `None`, safely satisfying that bound
// without requiring the caller send a value down this channel
let response =
if let Ok(blocks) = receiver.await { Response::Blocks(blocks) } else { Response::None };
respond
.send((request_id, response))
.expect("inbound_request_responses_recv was dropped?");
}
});
(set, latest_block_hash, sender)
}
}
fn notable_cosigns_request(
&self,
) -> impl Send + Future<Output = ([u8; 32], oneshot::Sender<Vec<SignedCosign>>)> {
async move {
let (request_id, global_session) = self
.notable_cosign_requests
.lock()
.await
.recv()
.await
.expect("notable_cosign_requests_send was dropped?");
let (sender, receiver) = oneshot::channel();
tokio::spawn({
let respond = self.inbound_request_responses.clone();
async move {
let response = if let Ok(notable_cosigns) = receiver.await {
Response::NotableCosigns(notable_cosigns)
} else {
Response::None
};
respond
.send((request_id, response))
.expect("inbound_request_responses_recv was dropped?");
}
});
(global_session, sender)
}
}
fn tributary_message(&self) -> impl Send + Future<Output = ([u8; 32], Vec<u8>)> {
async move {
self.tributary_gossip.lock().await.recv().await.expect("tributary_gossip send was dropped?")
}
}
fn cosign(&self) -> impl Send + Future<Output = SignedCosign> {
async move {
self
.signed_cosigns
.lock()
.await
.recv()
.await
.expect("signed_cosigns couldn't recv despite send in same object?")
}
}
}

View File

@@ -0,0 +1,17 @@
use core::time::Duration;
use tributary::tendermint::LATENCY_TIME;
use libp2p::ping::{self, Config, Behaviour};
pub use ping::Event;
pub(crate) const INTERVAL: Duration = Duration::from_secs(30);
// LATENCY_TIME represents the maximum latency for message delivery. Sending the ping, and
// receiving the pong, each have to occur within this time bound to validate the connection. We
// enforce that, as best we can, by requiring the round-trip be within twice the allowed latency.
pub(crate) const TIMEOUT: Duration = Duration::from_millis((2 * LATENCY_TIME) as u64);
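// With LATENCY_TIME at ~2000 ms (per the gossip module's comments), a ping not ponged
// within ~4 seconds fails, upon which the swarm task closes the connection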
pub(crate) type Behavior = Behaviour;
pub(crate) fn new_behavior() -> Behavior {
Behavior::new(Config::default().with_interval(INTERVAL).with_timeout(TIMEOUT))
}

View File

@@ -0,0 +1,136 @@
use core::{fmt, time::Duration};
use std::io;
use async_trait::async_trait;
use borsh::{BorshSerialize, BorshDeserialize};
use serai_client::validator_sets::primitives::ValidatorSet;
use futures_util::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use libp2p::request_response::{
self, Codec as CodecTrait, Event as GenericEvent, Config, Behaviour, ProtocolSupport,
};
pub use request_response::{RequestId, Message};
use serai_cosign::SignedCosign;
use serai_coordinator_p2p::TributaryBlockWithCommit;
/// The maximum message size for the request-response protocol
// This is derived from the heartbeat message size as it's our largest message
pub(crate) const MAX_LIBP2P_REQRES_MESSAGE_SIZE: usize =
(tributary::BLOCK_SIZE_LIMIT * serai_coordinator_p2p::heartbeat::BLOCKS_PER_BATCH) + 1024;
const PROTOCOL: &str = "/serai/coordinator/reqres/1.0.0";
/// Requests which can be made via the request-response protocol.
#[derive(Clone, Copy, Debug, BorshSerialize, BorshDeserialize)]
pub(crate) enum Request {
/// A heartbeat informing our peers of our latest block, for the specified blockchain, at
/// regular intervals.
///
/// If our peers have more blocks than us, they're expected to respond with those blocks.
Heartbeat { set: ValidatorSet, latest_block_hash: [u8; 32] },
/// A request for the notable cosigns for a global session.
NotableCosigns { global_session: [u8; 32] },
}
/// Responses which can be received via the request-response protocol.
#[derive(Clone, BorshSerialize, BorshDeserialize)]
pub(crate) enum Response {
None,
Blocks(Vec<TributaryBlockWithCommit>),
NotableCosigns(Vec<SignedCosign>),
}
impl fmt::Debug for Response {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Response::None => fmt.debug_struct("Response::None").finish(),
Response::Blocks(_) => fmt.debug_struct("Response::Blocks").finish_non_exhaustive(),
Response::NotableCosigns(_) => {
fmt.debug_struct("Response::NotableCosigns").finish_non_exhaustive()
}
}
}
}
/// The codec used for the request-response protocol.
///
/// We don't use CBOR or JSON, but use borsh to create `Vec<u8>`s we then length-prefix. While
/// ideally, we'd use borsh directly with the `io` traits used here, they're async and borsh
/// doesn't offer an amenable API for incremental deserialization.
#[derive(Default, Clone, Copy, Debug)]
pub(crate) struct Codec;
impl Codec {
async fn read<M: BorshDeserialize>(io: &mut (impl Unpin + AsyncRead)) -> io::Result<M> {
let mut len = [0; 4];
io.read_exact(&mut len).await?;
let len = usize::try_from(u32::from_le_bytes(len)).expect("not at least a 32-bit platform?");
if len > MAX_LIBP2P_REQRES_MESSAGE_SIZE {
Err(io::Error::other("request length exceeded MAX_LIBP2P_REQRES_MESSAGE_SIZE"))?;
}
// This may be a non-trivial allocation a remote peer can cheaply cause
// While we could chunk the read, meaning we'd only perform the allocation as bandwidth is
// used, the max message size should be sufficiently sane for this to be acceptable
let mut buf = vec![0; len];
io.read_exact(&mut buf).await?;
let mut buf = buf.as_slice();
let res = M::deserialize(&mut buf)?;
if !buf.is_empty() {
Err(io::Error::other("p2p message had extra data appended to it"))?;
}
Ok(res)
}
async fn write(io: &mut (impl Unpin + AsyncWrite), msg: &impl BorshSerialize) -> io::Result<()> {
let msg = borsh::to_vec(msg).unwrap();
io.write_all(&u32::try_from(msg.len()).unwrap().to_le_bytes()).await?;
io.write_all(&msg).await
}
}
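// The resulting wire format is the four-byte little-endian length prefix followed by the
// borsh-serialized message. As a sketch (assuming borsh's single-byte enum tag), a
// NotableCosigns request for an all-zero global session is:
//   [33, 0, 0, 0] ++ [1] ++ [0; 32]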
#[async_trait]
impl CodecTrait for Codec {
type Protocol = &'static str;
type Request = Request;
type Response = Response;
async fn read_request<R: Send + Unpin + AsyncRead>(
&mut self,
_: &Self::Protocol,
io: &mut R,
) -> io::Result<Request> {
Self::read(io).await
}
async fn read_response<R: Send + Unpin + AsyncRead>(
&mut self,
_: &Self::Protocol,
io: &mut R,
) -> io::Result<Response> {
Self::read(io).await
}
async fn write_request<W: Send + Unpin + AsyncWrite>(
&mut self,
_: &Self::Protocol,
io: &mut W,
req: Request,
) -> io::Result<()> {
Self::write(io, &req).await
}
async fn write_response<W: Send + Unpin + AsyncWrite>(
&mut self,
_: &Self::Protocol,
io: &mut W,
res: Response,
) -> io::Result<()> {
Self::write(io, &res).await
}
}
pub(crate) type Event = GenericEvent<Request, Response>;
pub(crate) type Behavior = Behaviour<Codec>;
pub(crate) fn new_behavior() -> Behavior {
let mut config = Config::default();
config.set_request_timeout(Duration::from_secs(5));
Behavior::new([(PROTOCOL, ProtocolSupport::Full)], config)
}

View File

@@ -0,0 +1,359 @@
use std::{
sync::Arc,
collections::{HashSet, HashMap},
time::{Duration, Instant},
};
use borsh::BorshDeserialize;
use serai_client::validator_sets::primitives::ValidatorSet;
use tokio::sync::{mpsc, oneshot, RwLock};
use serai_task::TaskHandle;
use serai_cosign::SignedCosign;
use futures_util::StreamExt;
use libp2p::{
identity::PeerId,
request_response::{RequestId, ResponseChannel},
swarm::{dial_opts::DialOpts, SwarmEvent, Swarm},
};
use crate::{
Peers, BehaviorEvent, Behavior,
validators::{self, Validators},
ping,
reqres::{self, Request, Response},
gossip,
};
const TIME_BETWEEN_REBUILD_PEERS: Duration = Duration::from_secs(10 * 60);
/*
`SwarmTask` handles everything we need the `Swarm` object for. The goal is to minimize the
contention on this task. Unfortunately, the `Swarm` object itself is needed for a variety of
purposes making this a rather large task.
Responsibilities include:
- Actually dialing new peers (the selection process occurs in another task)
- Maintaining the peers structure (as we need the Swarm object to see who our peers are)
- Gossiping messages
- Dispatching gossiped messages
- Sending requests
- Dispatching responses to requests
- Dispatching received requests
- Sending responses
*/
pub(crate) struct SwarmTask {
dial_task: TaskHandle,
to_dial: mpsc::UnboundedReceiver<DialOpts>,
last_dial_task_run: Instant,
validators: Arc<RwLock<Validators>>,
validator_changes: mpsc::UnboundedReceiver<validators::Changes>,
peers: Peers,
rebuild_peers_at: Instant,
swarm: Swarm<Behavior>,
gossip: mpsc::UnboundedReceiver<gossip::Message>,
signed_cosigns: mpsc::UnboundedSender<SignedCosign>,
tributary_gossip: mpsc::UnboundedSender<([u8; 32], Vec<u8>)>,
outbound_requests: mpsc::UnboundedReceiver<(PeerId, Request, oneshot::Sender<Response>)>,
outbound_request_responses: HashMap<RequestId, oneshot::Sender<Response>>,
inbound_request_response_channels: HashMap<RequestId, ResponseChannel<Response>>,
heartbeat_requests: mpsc::UnboundedSender<(RequestId, ValidatorSet, [u8; 32])>,
/* TODO
let cosigns = Cosigning::<D>::notable_cosigns(&self.db, global_session);
let res = reqres::Response::NotableCosigns(cosigns);
let _: Result<_, _> = self.swarm.behaviour_mut().reqres.send_response(channel, res);
*/
notable_cosign_requests: mpsc::UnboundedSender<(RequestId, [u8; 32])>,
inbound_request_responses: mpsc::UnboundedReceiver<(RequestId, Response)>,
}
impl SwarmTask {
fn handle_gossip(&mut self, event: gossip::Event) {
match event {
gossip::Event::Message { message, .. } => {
let Ok(message) = gossip::Message::deserialize(&mut message.data.as_slice()) else {
// TODO: Penalize the PeerId which created this message, which requires authenticating
// each message OR moving to explicit acknowledgement before re-gossiping
return;
};
match message {
gossip::Message::Tributary { tributary, message } => {
let _: Result<_, _> = self.tributary_gossip.send((tributary, message));
}
gossip::Message::Cosign(signed_cosign) => {
let _: Result<_, _> = self.signed_cosigns.send(signed_cosign);
}
}
}
gossip::Event::Subscribed { .. } | gossip::Event::Unsubscribed { .. } => {}
gossip::Event::GossipsubNotSupported { peer_id } => {
let _: Result<_, _> = self.swarm.disconnect_peer_id(peer_id);
}
}
}
fn handle_reqres(&mut self, event: reqres::Event) {
match event {
reqres::Event::Message { message, .. } => match message {
reqres::Message::Request { request_id, request, channel } => match request {
reqres::Request::Heartbeat { set, latest_block_hash } => {
self.inbound_request_response_channels.insert(request_id, channel);
let _: Result<_, _> =
self.heartbeat_requests.send((request_id, set, latest_block_hash));
}
reqres::Request::NotableCosigns { global_session } => {
self.inbound_request_response_channels.insert(request_id, channel);
let _: Result<_, _> = self.notable_cosign_requests.send((request_id, global_session));
}
},
reqres::Message::Response { request_id, response } => {
if let Some(channel) = self.outbound_request_responses.remove(&request_id) {
let _: Result<_, _> = channel.send(response);
}
}
},
reqres::Event::OutboundFailure { request_id, .. } => {
// Send None as the response for the request
if let Some(channel) = self.outbound_request_responses.remove(&request_id) {
let _: Result<_, _> = channel.send(Response::None);
}
}
reqres::Event::InboundFailure { .. } | reqres::Event::ResponseSent { .. } => {}
}
}
async fn run(mut self) {
loop {
let time_till_rebuild_peers = self.rebuild_peers_at.saturating_duration_since(Instant::now());
tokio::select! {
// If the validators have changed, update the allow list
validator_changes = self.validator_changes.recv() => {
let validator_changes = validator_changes.expect("validators update task shut down?");
let behavior = &mut self.swarm.behaviour_mut().allow_list;
for removed in validator_changes.removed {
behavior.disallow_peer(removed);
}
for added in validator_changes.added {
behavior.allow_peer(added);
}
}
// Dial peers we're instructed to
dial_opts = self.to_dial.recv() => {
let dial_opts = dial_opts.expect("DialTask was closed?");
let _: Result<_, _> = self.swarm.dial(dial_opts);
}
/*
Rebuild the peers every 10 minutes.
This protects against any race conditions/edge cases we have in our logic to track peers,
along with unrepresented behavior such as when a peer changes the networks they're active
in. This lets the peer tracking logic simply be 'good enough' to not become horribly
corrupt over the span of `TIME_BETWEEN_REBUILD_PEERS`.
We also use this to disconnect all peers who are no longer active in any network.
*/
() = tokio::time::sleep(time_till_rebuild_peers) => {
let validators_by_network = self.validators.read().await.by_network().clone();
let connected_peers = self.swarm.connected_peers().copied().collect::<HashSet<_>>();
// Build the new peers object
let mut peers = HashMap::new();
for (network, validators) in validators_by_network {
peers.insert(network, validators.intersection(&connected_peers).copied().collect());
}
// Write the new peers object
*self.peers.peers.write().await = peers;
self.rebuild_peers_at = Instant::now() + TIME_BETWEEN_REBUILD_PEERS;
}
// Handle swarm events
event = self.swarm.next() => {
// `Swarm::next` will never return `Poll::Ready(None)`
// https://docs.rs/
// libp2p/0.54.1/libp2p/struct.Swarm.html#impl-Stream-for-Swarm%3CTBehaviour%3E
let event = event.unwrap();
match event {
// New connection, so update peers
SwarmEvent::ConnectionEstablished { peer_id, .. } => {
let Some(networks) =
self.validators.read().await.networks(&peer_id).cloned() else { continue };
let mut peers = self.peers.peers.write().await;
for network in networks {
peers.entry(network).or_insert_with(HashSet::new).insert(peer_id);
}
}
// Connection closed, so update peers
SwarmEvent::ConnectionClosed { peer_id, .. } => {
let Some(networks) =
self.validators.read().await.networks(&peer_id).cloned() else { continue };
let mut peers = self.peers.peers.write().await;
for network in networks {
peers.entry(network).or_insert_with(HashSet::new).remove(&peer_id);
}
/*
We want to re-run the dial task, since we lost a peer, in case we should find new
peers. This opens a DoS where a validator repeatedly opens/closes connections to
force iterations of the dial task. We prevent this by requiring a minimum delay
since the last explicitly triggered iteration.
This is suboptimal. If we have several disconnects in immediate proximity, we'll
trigger the dial task upon the first (where we may still have enough peers we
shouldn't dial more) but not the last (where we may have so few peers left we
should dial more). This is accepted as the dial task will eventually run on its
natural timer.
*/
const MINIMUM_TIME_SINCE_LAST_EXPLICIT_DIAL: Duration = Duration::from_secs(60);
let now = Instant::now();
if (self.last_dial_task_run + MINIMUM_TIME_SINCE_LAST_EXPLICIT_DIAL) < now {
self.dial_task.run_now();
self.last_dial_task_run = now;
}
}
SwarmEvent::Behaviour(
BehaviorEvent::AllowList(event) | BehaviorEvent::ConnectionLimits(event)
) => {
// Ensure these are unreachable cases, not actual events
let _: void::Void = event;
}
SwarmEvent::Behaviour(
BehaviorEvent::Ping(ping::Event { peer: _, connection, result, })
) => {
if result.is_err() {
self.swarm.close_connection(connection);
}
}
SwarmEvent::Behaviour(BehaviorEvent::Reqres(event)) => {
self.handle_reqres(event)
}
SwarmEvent::Behaviour(BehaviorEvent::Gossip(event)) => {
self.handle_gossip(event)
}
// We don't handle any of these
SwarmEvent::IncomingConnection { .. } |
SwarmEvent::IncomingConnectionError { .. } |
SwarmEvent::OutgoingConnectionError { .. } |
SwarmEvent::NewListenAddr { .. } |
SwarmEvent::ExpiredListenAddr { .. } |
SwarmEvent::ListenerClosed { .. } |
SwarmEvent::ListenerError { .. } |
SwarmEvent::Dialing { .. } => {}
}
}
message = self.gossip.recv() => {
let message = message.expect("channel for messages to gossip was closed?");
let topic = message.topic();
let message = borsh::to_vec(&message).unwrap();
/*
If we're sending a message for this topic, it's because this topic is relevant to us.
Subscribe to it.
We create topics roughly weekly, one per validator set/session. Once present in a
topic, we're interested in all messages for it until the validator set/session retires.
Then there should no longer be any messages for the topic as we should drop the
Tributary which creates the messages.
We use this as an argument to not bother implementing unsubscribing from topics. They're
created incredibly infrequently, and old topics shouldn't still have messages published
to them. Having a coordinator reboot be our method of unsubscribing is fine.
Alternatively, we could route an API to determine when a topic is retired, or retire
any topics we haven't sent messages on in the past hour.
*/
let behavior = self.swarm.behaviour_mut();
let _: Result<_, _> = behavior.gossip.subscribe(&topic);
/*
This may be an error of `InsufficientPeers`. If so, we could ask DialTask to dial more
peers for this network. We don't as we assume DialTask will detect the lack of peers
for this network, and will already successfully handle this.
*/
let _: Result<_, _> = behavior.gossip.publish(topic.hash(), message);
}
request = self.outbound_requests.recv() => {
let (peer, request, response_channel) =
request.expect("channel for requests was closed?");
let request_id = self.swarm.behaviour_mut().reqres.send_request(&peer, request);
self.outbound_request_responses.insert(request_id, response_channel);
}
response = self.inbound_request_responses.recv() => {
let (request_id, response) =
response.expect("channel for inbound request responses was closed?");
if let Some(channel) = self.inbound_request_response_channels.remove(&request_id) {
let _: Result<_, _> =
self.swarm.behaviour_mut().reqres.send_response(channel, response);
}
}
}
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn spawn(
dial_task: TaskHandle,
to_dial: mpsc::UnboundedReceiver<DialOpts>,
validators: Arc<RwLock<Validators>>,
validator_changes: mpsc::UnboundedReceiver<validators::Changes>,
peers: Peers,
swarm: Swarm<Behavior>,
gossip: mpsc::UnboundedReceiver<gossip::Message>,
signed_cosigns: mpsc::UnboundedSender<SignedCosign>,
tributary_gossip: mpsc::UnboundedSender<([u8; 32], Vec<u8>)>,
outbound_requests: mpsc::UnboundedReceiver<(PeerId, Request, oneshot::Sender<Response>)>,
heartbeat_requests: mpsc::UnboundedSender<(RequestId, ValidatorSet, [u8; 32])>,
notable_cosign_requests: mpsc::UnboundedSender<(RequestId, [u8; 32])>,
inbound_request_responses: mpsc::UnboundedReceiver<(RequestId, Response)>,
) {
tokio::spawn(
SwarmTask {
dial_task,
to_dial,
last_dial_task_run: Instant::now(),
validators,
validator_changes,
peers,
rebuild_peers_at: Instant::now() + TIME_BETWEEN_REBUILD_PEERS,
swarm,
gossip,
signed_cosigns,
tributary_gossip,
outbound_requests,
outbound_request_responses: HashMap::new(),
inbound_request_response_channels: HashMap::new(),
heartbeat_requests,
notable_cosign_requests,
inbound_request_responses,
}
.run(),
);
}
}

View File

@@ -0,0 +1,213 @@
use core::{borrow::Borrow, future::Future};
use std::{
sync::Arc,
collections::{HashSet, HashMap},
};
use serai_client::{primitives::NetworkId, validator_sets::primitives::Session, Serai};
use serai_task::{Task, ContinuallyRan};
use libp2p::PeerId;
use futures_util::stream::{StreamExt, FuturesUnordered};
use tokio::sync::{mpsc, RwLock};
use crate::peer_id_from_public;
pub(crate) struct Changes {
pub(crate) removed: HashSet<PeerId>,
pub(crate) added: HashSet<PeerId>,
}
pub(crate) struct Validators {
serai: Serai,
// A cache for which session we're populated with the validators of
sessions: HashMap<NetworkId, Session>,
// The validators by network
by_network: HashMap<NetworkId, HashSet<PeerId>>,
// The validators and their networks
validators: HashMap<PeerId, HashSet<NetworkId>>,
// The channel to send the changes down
changes: mpsc::UnboundedSender<Changes>,
}
impl Validators {
pub(crate) fn new(serai: Serai) -> (Self, mpsc::UnboundedReceiver<Changes>) {
let (send, recv) = mpsc::unbounded_channel();
let validators = Validators {
serai,
sessions: HashMap::new(),
by_network: HashMap::new(),
validators: HashMap::new(),
changes: send,
};
(validators, recv)
}
async fn session_changes(
serai: impl Borrow<Serai>,
sessions: impl Borrow<HashMap<NetworkId, Session>>,
) -> Result<Vec<(NetworkId, Session, HashSet<PeerId>)>, String> {
let temporal_serai =
serai.borrow().as_of_latest_finalized_block().await.map_err(|e| format!("{e:?}"))?;
let temporal_serai = temporal_serai.validator_sets();
let mut session_changes = vec![];
{
// FuturesUnordered can be bad practice, as it'll cause timeouts if infrequently polled, but
// we poll it until it yields all of its futures, with the most minimal processing possible
let mut futures = FuturesUnordered::new();
for network in serai_client::primitives::NETWORKS {
if network == NetworkId::Serai {
continue;
}
let sessions = sessions.borrow();
futures.push(async move {
let session = match temporal_serai.session(network).await {
Ok(Some(session)) => session,
Ok(None) => return Ok(None),
Err(e) => return Err(format!("{e:?}")),
};
if sessions.get(&network) == Some(&session) {
Ok(None)
} else {
match temporal_serai.active_network_validators(network).await {
Ok(validators) => Ok(Some((
network,
session,
validators.into_iter().map(peer_id_from_public).collect(),
))),
Err(e) => Err(format!("{e:?}")),
}
}
});
}
while let Some(session_change) = futures.next().await {
if let Some(session_change) = session_change? {
session_changes.push(session_change);
}
}
}
Ok(session_changes)
}
fn incorporate_session_changes(
&mut self,
session_changes: Vec<(NetworkId, Session, HashSet<PeerId>)>,
) {
let mut removed = HashSet::new();
let mut added = HashSet::new();
for (network, session, validators) in session_changes {
// Remove the existing validators
for validator in self.by_network.remove(&network).unwrap_or_else(HashSet::new) {
// Get all networks this validator is in
let mut networks = self.validators.remove(&validator).unwrap();
// Remove this one
networks.remove(&network);
if !networks.is_empty() {
// Insert the networks back if the validator was present in other networks
self.validators.insert(validator, networks);
} else {
// Because this validator is no longer present in any network, mark them as removed
/*
This isn't accurate. The validator isn't present in the latest session for this
network, yet they were present in the prior session, which has yet to retire. Because
we don't explicitly include both the prior session and the current session, only the
validators mutually present in both sessions are responsible for all actions still
ongoing as the prior validator set retires.
TODO: Fix this
*/
removed.insert(validator);
}
}
// Add the new validators
for validator in validators.iter().copied() {
self.validators.entry(validator).or_insert_with(HashSet::new).insert(network);
added.insert(validator);
}
self.by_network.insert(network, validators);
// Update the session we have populated
self.sessions.insert(network, session);
}
// Only flag validators for removal if they weren't simultaneously added by these changes
removed.retain(|validator| !added.contains(validator));
// Send the changes, dropping the error
// This lets the caller opt-out of change notifications by dropping the receiver
let _: Result<_, _> = self.changes.send(Changes { removed, added });
}
/// Update the view of the validators.
pub(crate) async fn update(&mut self) -> Result<(), String> {
let session_changes = Self::session_changes(&self.serai, &self.sessions).await?;
self.incorporate_session_changes(session_changes);
Ok(())
}
pub(crate) fn by_network(&self) -> &HashMap<NetworkId, HashSet<PeerId>> {
&self.by_network
}
pub(crate) fn networks(&self, peer_id: &PeerId) -> Option<&HashSet<NetworkId>> {
self.validators.get(peer_id)
}
}
/// A task which updates a set of validators.
///
/// The validators managed by this task will have their exclusive lock held for a minimal amount of
/// time while the update occurs to minimize the disruption to the services relying on it.
pub(crate) struct UpdateValidatorsTask {
validators: Arc<RwLock<Validators>>,
}
impl UpdateValidatorsTask {
/// Spawn a new instance of the UpdateValidatorsTask.
///
/// This returns a reference to the Validators it updates after spawning itself.
pub(crate) fn spawn(serai: Serai) -> (Arc<RwLock<Validators>>, mpsc::UnboundedReceiver<Changes>) {
// The validators which will be updated
let (validators, changes) = Validators::new(serai);
let validators = Arc::new(RwLock::new(validators));
// Define the task
let (update_validators_task, update_validators_task_handle) = Task::new();
// Forget the handle, as dropping the handle would stop the task
core::mem::forget(update_validators_task_handle);
// Spawn the task
tokio::spawn(
(Self { validators: validators.clone() }).continually_run(update_validators_task, vec![]),
);
// Return the validators
(validators, changes)
}
}
impl ContinuallyRan for UpdateValidatorsTask {
// Only run every minute, not the default of every five seconds
const DELAY_BETWEEN_ITERATIONS: u64 = 60;
const MAX_DELAY_BETWEEN_ITERATIONS: u64 = 5 * 60;
fn run_iteration(&mut self) -> impl Send + Future<Output = Result<bool, String>> {
async move {
let session_changes = {
let validators = self.validators.read().await;
Validators::session_changes(validators.serai.clone(), validators.sessions.clone())
.await
.map_err(|e| format!("{e:?}"))?
};
self.validators.write().await.incorporate_session_changes(session_changes);
Ok(true)
}
}
}