diff --git a/coordinator/src/p2p/libp2p/authenticate.rs b/coordinator/src/p2p/libp2p/authenticate.rs index d00d0dac..a8167d9e 100644 --- a/coordinator/src/p2p/libp2p/authenticate.rs +++ b/coordinator/src/p2p/libp2p/authenticate.rs @@ -19,13 +19,12 @@ use libp2p::{ noise, }; -use crate::p2p::libp2p::{validators::Validators, peer_id_from_public}; +use crate::p2p::libp2p::peer_id_from_public; const PROTOCOL: &str = "/serai/coordinator/validators"; #[derive(Clone)] pub(crate) struct OnlyValidators { - pub(crate) validators: Arc>, pub(crate) serai_key: Zeroizing, pub(crate) noise_keypair: identity::Keypair, } @@ -108,12 +107,7 @@ impl OnlyValidators { .verify_simple(PROTOCOL.as_bytes(), &msg, &sig) .map_err(|_| io::Error::other("invalid signature"))?; - let peer_id = peer_id_from_public(Public::from_raw(public_key.to_bytes())); - if !self.validators.read().await.contains(&peer_id) { - Err(io::Error::other("peer which tried to connect isn't a known active validator"))?; - } - - Ok(peer_id) + Ok(peer_id_from_public(Public::from_raw(public_key.to_bytes()))) } } diff --git a/coordinator/src/p2p/libp2p/dial.rs b/coordinator/src/p2p/libp2p/dial.rs index 03795a51..e8611797 100644 --- a/coordinator/src/p2p/libp2p/dial.rs +++ b/coordinator/src/p2p/libp2p/dial.rs @@ -37,7 +37,7 @@ pub(crate) struct DialTask { impl DialTask { pub(crate) fn new(serai: Serai, peers: Peers, to_dial: mpsc::UnboundedSender) -> Self { - DialTask { serai: serai.clone(), validators: Validators::new(serai), peers, to_dial } + DialTask { serai: serai.clone(), validators: Validators::new(serai).0, peers, to_dial } } } diff --git a/coordinator/src/p2p/libp2p/mod.rs b/coordinator/src/p2p/libp2p/mod.rs index ce60d285..fccf7ce1 100644 --- a/coordinator/src/p2p/libp2p/mod.rs +++ b/coordinator/src/p2p/libp2p/mod.rs @@ -25,7 +25,7 @@ use libp2p::{ multihash::Multihash, identity::{self, PeerId}, tcp::Config as TcpConfig, - yamux, + yamux, allow_block_list, swarm::NetworkBehaviour, SwarmBuilder, }; @@ -112,6 +112,7 @@ struct Peers { #[derive(NetworkBehaviour)] struct Behavior { + allow_list: allow_block_list::Behaviour, ping: ping::Behavior, reqres: reqres::Behavior, gossip: gossip::Behavior, @@ -147,43 +148,43 @@ impl Libp2p { .continually_run(dial_task_def, vec![]), ); - // Define the Validators object used for validating new connections - let connection_validators = UpdateValidatorsTask::spawn(serai.clone()); - let new_only_validators = |noise_keypair: &identity::Keypair| -> Result<_, ()> { - Ok(OnlyValidators { - serai_key: serai_key.clone(), - validators: connection_validators.clone(), - noise_keypair: noise_keypair.clone(), - }) + let swarm = { + let new_only_validators = |noise_keypair: &identity::Keypair| -> Result<_, ()> { + Ok(OnlyValidators { serai_key: serai_key.clone(), noise_keypair: noise_keypair.clone() }) + }; + + let new_yamux = || { + let mut config = yamux::Config::default(); + // 1 MiB default + max message size + config.set_max_buffer_size((1024 * 1024) + MAX_LIBP2P_MESSAGE_SIZE); + // 256 KiB default + max message size + config + .set_receive_window_size(((256 * 1024) + MAX_LIBP2P_MESSAGE_SIZE).try_into().unwrap()); + config + }; + + let mut swarm = SwarmBuilder::with_existing_identity(identity::Keypair::generate_ed25519()) + .with_tokio() + .with_tcp(TcpConfig::default().nodelay(false), new_only_validators, new_yamux) + .unwrap() + .with_behaviour(|_| Behavior { + allow_list: allow_block_list::Behaviour::default(), + ping: ping::new_behavior(), + reqres: reqres::new_behavior(), + gossip: gossip::new_behavior(), + }) + .unwrap() + .with_swarm_config(|config| { + config + .with_idle_connection_timeout(ping::INTERVAL + ping::TIMEOUT + Duration::from_secs(5)) + }) + .build(); + swarm.listen_on(format!("/ip4/0.0.0.0/tcp/{PORT}").parse().unwrap()).unwrap(); + swarm.listen_on(format!("/ip6/::/tcp/{PORT}").parse().unwrap()).unwrap(); + swarm }; - let new_yamux = || { - let mut config = yamux::Config::default(); - // 1 MiB default + max message size - config.set_max_buffer_size((1024 * 1024) + MAX_LIBP2P_MESSAGE_SIZE); - // 256 KiB default + max message size - config.set_receive_window_size(((256 * 1024) + MAX_LIBP2P_MESSAGE_SIZE).try_into().unwrap()); - config - }; - - let mut swarm = SwarmBuilder::with_existing_identity(identity::Keypair::generate_ed25519()) - .with_tokio() - .with_tcp(TcpConfig::default().nodelay(false), new_only_validators, new_yamux) - .unwrap() - .with_behaviour(|_| Behavior { - ping: ping::new_behavior(), - reqres: reqres::new_behavior(), - gossip: gossip::new_behavior(), - }) - .unwrap() - .with_swarm_config(|config| { - config.with_idle_connection_timeout(ping::INTERVAL + ping::TIMEOUT + Duration::from_secs(5)) - }) - .build(); - swarm.listen_on(format!("/ip4/0.0.0.0/tcp/{PORT}").parse().unwrap()).unwrap(); - swarm.listen_on(format!("/ip6/::/tcp/{PORT}").parse().unwrap()).unwrap(); - - let swarm_validators = UpdateValidatorsTask::spawn(serai); + let (swarm_validators, validator_changes) = UpdateValidatorsTask::spawn(serai); let (gossip_send, gossip_recv) = mpsc::unbounded_channel(); let (signed_cosigns_send, signed_cosigns_recv) = mpsc::unbounded_channel(); @@ -201,6 +202,7 @@ impl Libp2p { dial_task, to_dial_recv, swarm_validators, + validator_changes, peers.clone(), swarm, gossip_recv, diff --git a/coordinator/src/p2p/libp2p/swarm.rs b/coordinator/src/p2p/libp2p/swarm.rs index f4c5d7fe..10d91818 100644 --- a/coordinator/src/p2p/libp2p/swarm.rs +++ b/coordinator/src/p2p/libp2p/swarm.rs @@ -23,7 +23,7 @@ use libp2p::{ use crate::p2p::libp2p::{ Peers, BehaviorEvent, Behavior, - validators::Validators, + validators::{self, Validators}, ping, reqres::{self, Request, Response}, gossip, @@ -52,6 +52,7 @@ pub(crate) struct SwarmTask { last_dial_task_run: Instant, validators: Arc>, + validator_changes: mpsc::UnboundedReceiver, peers: Peers, rebuild_peers_at: Instant, @@ -135,6 +136,18 @@ impl SwarmTask { let time_till_rebuild_peers = self.rebuild_peers_at.saturating_duration_since(Instant::now()); tokio::select! { + // If the validators have changed, update the allow list + validator_changes = self.validator_changes.recv() => { + let validator_changes = validator_changes.expect("validators update task shut down?"); + let behavior = &mut self.swarm.behaviour_mut().allow_list; + for removed in validator_changes.removed { + behavior.disallow_peer(removed); + } + for added in validator_changes.added { + behavior.allow_peer(added); + } + } + // Dial peers we're instructed to dial_opts = self.to_dial.recv() => { let dial_opts = dial_opts.expect("DialTask was closed?"); @@ -155,26 +168,15 @@ impl SwarmTask { let validators_by_network = self.validators.read().await.by_network().clone(); let connected_peers = self.swarm.connected_peers().copied().collect::>(); - // We initially populate the list of peers to disconnect with all peers - let mut to_disconnect = connected_peers.clone(); - // Build the new peers object let mut peers = HashMap::new(); for (network, validators) in validators_by_network { peers.insert(network, validators.intersection(&connected_peers).copied().collect()); - - // If this peer is in this validator set, don't keep it flagged for disconnection - to_disconnect.retain(|peer| !validators.contains(peer)); } // Write the new peers object *self.peers.peers.write().await = peers; self.rebuild_peers_at = Instant::now() + TIME_BETWEEN_REBUILD_PEERS; - - // Disconnect all peers marked for disconnection - for peer in to_disconnect { - let _: Result<_, _> = self.swarm.disconnect_peer_id(peer); - } } // Handle swarm events @@ -223,6 +225,10 @@ impl SwarmTask { } } + SwarmEvent::Behaviour(BehaviorEvent::AllowList(event)) => { + // Ensure this is an unreachable case, not an actual event + let _: void::Void = event; + } SwarmEvent::Behaviour( BehaviorEvent::Ping(ping::Event { peer: _, connection, result, }) ) => { @@ -305,6 +311,7 @@ impl SwarmTask { to_dial: mpsc::UnboundedReceiver, validators: Arc>, + validator_changes: mpsc::UnboundedReceiver, peers: Peers, swarm: Swarm, @@ -326,6 +333,7 @@ impl SwarmTask { last_dial_task_run: Instant::now(), validators, + validator_changes, peers, rebuild_peers_at: Instant::now() + TIME_BETWEEN_REBUILD_PEERS, diff --git a/coordinator/src/p2p/libp2p/validators.rs b/coordinator/src/p2p/libp2p/validators.rs index b5be7c9e..7eb2e996 100644 --- a/coordinator/src/p2p/libp2p/validators.rs +++ b/coordinator/src/p2p/libp2p/validators.rs @@ -11,10 +11,15 @@ use serai_task::{Task, ContinuallyRan}; use libp2p::PeerId; use futures_util::stream::{StreamExt, FuturesUnordered}; -use tokio::sync::RwLock; +use tokio::sync::{mpsc, RwLock}; use crate::p2p::libp2p::peer_id_from_public; +pub(crate) struct Changes { + pub(crate) removed: HashSet, + pub(crate) added: HashSet, +} + pub(crate) struct Validators { serai: Serai, @@ -24,16 +29,22 @@ pub(crate) struct Validators { by_network: HashMap>, // The validators and their networks validators: HashMap>, + + // The channel to send the changes down + changes: mpsc::UnboundedSender, } impl Validators { - pub(crate) fn new(serai: Serai) -> Self { - Validators { + pub(crate) fn new(serai: Serai) -> (Self, mpsc::UnboundedReceiver) { + let (send, recv) = mpsc::unbounded_channel(); + let validators = Validators { serai, sessions: HashMap::new(), by_network: HashMap::new(), validators: HashMap::new(), - } + changes: send, + }; + (validators, recv) } async fn session_changes( @@ -89,6 +100,9 @@ impl Validators { &mut self, session_changes: Vec<(NetworkId, Session, HashSet)>, ) { + let mut removed = HashSet::new(); + let mut added = HashSet::new(); + for (network, session, validators) in session_changes { // Remove the existing validators for validator in self.by_network.remove(&network).unwrap_or_else(HashSet::new) { @@ -96,21 +110,31 @@ impl Validators { let mut networks = self.validators.remove(&validator).unwrap(); // Remove this one networks.remove(&network); - // Insert the networks back if the validator was present in other networks if !networks.is_empty() { + // Insert the networks back if the validator was present in other networks self.validators.insert(validator, networks); + } else { + // Because this validator is no longer present in any network, mark them as removed + removed.insert(validator); } } // Add the new validators for validator in validators.iter().copied() { self.validators.entry(validator).or_insert_with(HashSet::new).insert(network); + added.insert(validator); } self.by_network.insert(network, validators); // Update the session we have populated self.sessions.insert(network, session); } + + // Only flag validators for removal if they weren't simultaneously added by these changes + removed.retain(|validator| !added.contains(validator)); + // Send the changes, dropping the error + // This lets the caller opt-out of change notifications by dropping the receiver + let _: Result<_, _> = self.changes.send(Changes { removed, added }); } /// Update the view of the validators. @@ -145,9 +169,10 @@ impl UpdateValidatorsTask { /// Spawn a new instance of the UpdateValidatorsTask. /// /// This returns a reference to the Validators it updates after spawning itself. - pub(crate) fn spawn(serai: Serai) -> Arc> { + pub(crate) fn spawn(serai: Serai) -> (Arc>, mpsc::UnboundedReceiver) { // The validators which will be updated - let validators = Arc::new(RwLock::new(Validators::new(serai))); + let (validators, changes) = Validators::new(serai); + let validators = Arc::new(RwLock::new(validators)); // Define the task let (update_validators_task, update_validators_task_handle) = Task::new(); @@ -159,7 +184,7 @@ impl UpdateValidatorsTask { ); // Return the validators - validators + (validators, changes) } }