Add a message queue

This is intended to be a reliable transport between the processors and
coordinator. Since it'll be intranet only, it's written as never fail.

Primarily needs testing and a proper ID.
This commit is contained in:
Luke Parker
2023-07-01 08:53:46 -04:00
parent a95ecc2512
commit 6267acf3df
11 changed files with 340 additions and 2 deletions

2
message-queue/src/lib.rs Normal file
View File

@@ -0,0 +1,2 @@
mod messages;
pub use messages::*;

152
message-queue/src/main.rs Normal file
View File

@@ -0,0 +1,152 @@
use std::{
sync::{Arc, RwLock},
collections::HashMap,
};
use ciphersuite::{group::GroupEncoding, Ciphersuite, Ristretto};
use schnorr_signatures::SchnorrSignature;
use serai_primitives::NetworkId;
use jsonrpsee::{RpcModule, server::ServerBuilder};
mod messages;
use messages::*;
mod queue;
use queue::Queue;
lazy_static::lazy_static! {
static ref KEYS: Arc<RwLock<HashMap<Service, <Ristretto as Ciphersuite>::G>>> =
Arc::new(RwLock::new(HashMap::new()));
static ref QUEUES: Arc<RwLock<HashMap<Service, RwLock<Queue<serai_db::MemDb>>>>> =
Arc::new(RwLock::new(HashMap::new()));
}
// queue RPC method
fn queue_message(meta: Metadata, msg: Vec<u8>, sig: SchnorrSignature<Ristretto>) {
{
let from = (*KEYS).read().unwrap()[&meta.from];
assert!(sig.verify(from, message_challenge(from, &msg, sig.R)));
}
// Assert one, and only one of these, is the coordinator
assert!(matches!(meta.from, Service::Coordinator) ^ matches!(meta.to, Service::Coordinator));
// TODO: Verify the from_id hasn't been prior seen
// Queue it
(*QUEUES).read().unwrap()[&meta.to].write().unwrap().queue_message(QueuedMessage {
from: meta.from,
msg,
sig: sig.serialize(),
});
}
// get RPC method
fn get_next_message(
service: Service,
_expected: u64,
_signature: SchnorrSignature<Ristretto>,
) -> Option<QueuedMessage> {
// TODO: Verify the signature
// TODO: Verify the expected next message ID matches
let queue_outer = (*QUEUES).read().unwrap();
let queue = queue_outer[&service].read().unwrap();
let next = queue.last_acknowledged().map(|i| i + 1).unwrap_or(0);
queue.get_message(next)
}
// ack RPC method
fn ack_message(service: Service, id: u64, _signature: SchnorrSignature<Ristretto>) {
// TODO: Verify the signature
// Is it:
// The acknowledged message should be > last acknowledged OR
// The acknowledged message should be >=
// It's the first if we save messages as acknowledged before acknowledging them
// It's the second if we acknowledge messages before saving them as acknowledged
// TODO: Check only a proper message is being acked
(*QUEUES).read().unwrap()[&service].write().unwrap().ack_message(id)
}
#[tokio::main]
async fn main() {
// Open the DB
// TODO
let db = serai_db::MemDb::new();
let read_key = |str| {
let Ok(key) = std::env::var(str) else { None? };
let mut repr = <<Ristretto as Ciphersuite>::G as GroupEncoding>::Repr::default();
repr.as_mut().copy_from_slice(&hex::decode(key).unwrap());
Some(<Ristretto as Ciphersuite>::G::from_bytes(&repr).unwrap())
};
let register_service = |service, key| {
(*KEYS).write().unwrap().insert(service, key);
(*QUEUES).write().unwrap().insert(service, RwLock::new(Queue(db.clone(), service)));
};
// Make queues for each NetworkId, other than Serai
for network in [NetworkId::Bitcoin, NetworkId::Ethereum, NetworkId::Monero] {
// Use a match so we error if the list of NetworkIds changes
let Some(key) = read_key(match network {
NetworkId::Serai => unreachable!(),
NetworkId::Bitcoin => "BITCOIN_KEY",
NetworkId::Ethereum => "ETHEREUM_KEY",
NetworkId::Monero => "MONERO_KEY",
}) else { continue };
register_service(Service::Processor(network), key);
}
// And the coordinator's
register_service(Service::Coordinator, read_key("COORDINATOR_KEY").unwrap());
// Start server
let builder = ServerBuilder::new();
// TODO: Set max request/response size
let listen_on: &[std::net::SocketAddr] = &["0.0.0.0".parse().unwrap()];
let server = builder.build(listen_on).await.unwrap();
let mut module = RpcModule::new(());
module
.register_method("queue", |args, _| {
let args = args.parse::<(Metadata, Vec<u8>, Vec<u8>)>().unwrap();
queue_message(
args.0,
args.1,
SchnorrSignature::<Ristretto>::read(&mut args.2.as_slice()).unwrap(),
);
Ok(())
})
.unwrap();
module
.register_method("next", |args, _| {
let args = args.parse::<(Service, u64, Vec<u8>)>().unwrap();
get_next_message(
args.0,
args.1,
SchnorrSignature::<Ristretto>::read(&mut args.2.as_slice()).unwrap(),
);
Ok(())
})
.unwrap();
module
.register_method("ack", |args, _| {
let args = args.parse::<(Service, u64, Vec<u8>)>().unwrap();
ack_message(
args.0,
args.1,
SchnorrSignature::<Ristretto>::read(&mut args.2.as_slice()).unwrap(),
);
Ok(())
})
.unwrap();
server.start(module).unwrap();
}

View File

@@ -0,0 +1,40 @@
use transcript::{Transcript, RecommendedTranscript};
use ciphersuite::{group::GroupEncoding, Ciphersuite, Ristretto};
use serde::{Serialize, Deserialize};
use serai_primitives::NetworkId;
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub enum Service {
Processor(NetworkId),
Coordinator,
}
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct QueuedMessage {
pub from: Service,
pub msg: Vec<u8>,
pub sig: Vec<u8>,
}
#[derive(Clone, Copy, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct Metadata {
pub from: Service,
pub to: Service,
pub from_id: u64,
}
pub fn message_challenge(
from: <Ristretto as Ciphersuite>::G,
msg: &[u8],
nonce: <Ristretto as Ciphersuite>::G,
) -> <Ristretto as Ciphersuite>::F {
let mut transcript = RecommendedTranscript::new(b"Serai Message Queue v0.1");
transcript.domain_separate(b"message");
transcript.append_message(b"from", from.to_bytes());
transcript.append_message(b"msg", msg);
transcript.domain_separate(b"signature");
transcript.append_message(b"nonce", nonce.to_bytes());
<Ristretto as Ciphersuite>::hash_to_F(b"challenge", &transcript.challenge(b"challenge"))
}

View File

@@ -0,0 +1,57 @@
use serai_db::{DbTxn, Db};
use crate::messages::*;
#[derive(Clone, Debug)]
pub(crate) struct Queue<D: Db>(pub(crate) D, pub(crate) Service);
impl<D: Db> Queue<D> {
fn key(domain: &'static [u8], key: impl AsRef<[u8]>) -> Vec<u8> {
[&[u8::try_from(domain.len()).unwrap()], domain, key.as_ref()].concat()
}
fn message_count_key(&self) -> Vec<u8> {
Self::key(b"message_count", serde_json::to_vec(&self.1).unwrap())
}
pub(crate) fn message_count(&self) -> u64 {
self
.0
.get(self.message_count_key())
.map(|bytes| u64::from_le_bytes(bytes.try_into().unwrap()))
.unwrap_or(0)
}
fn last_acknowledged_key(&self) -> Vec<u8> {
Self::key(b"last_acknowledged", serde_json::to_vec(&self.1).unwrap())
}
pub(crate) fn last_acknowledged(&self) -> Option<u64> {
self
.0
.get(self.last_acknowledged_key())
.map(|bytes| u64::from_le_bytes(bytes.try_into().unwrap()))
}
fn message_key(&self, id: u64) -> Vec<u8> {
Self::key(b"message", serde_json::to_vec(&(self.1, id)).unwrap())
}
pub(crate) fn queue_message(&mut self, msg: QueuedMessage) {
let id = self.message_count();
let msg_key = self.message_key(id);
let msg_count_key = self.message_count_key();
let mut txn = self.0.txn();
txn.put(msg_key, serde_json::to_vec(&msg).unwrap());
txn.put(msg_count_key, (id + 1).to_le_bytes());
txn.commit();
}
pub(crate) fn get_message(&self, id: u64) -> Option<QueuedMessage> {
self.0.get(self.message_key(id)).map(|bytes| serde_json::from_slice(&bytes).unwrap())
}
pub(crate) fn ack_message(&mut self, id: u64) {
let ack_key = self.last_acknowledged_key();
let mut txn = self.0.txn();
txn.put(ack_key, id.to_le_bytes());
txn.commit();
}
}