diff --git a/Cargo.lock b/Cargo.lock index 084e382db24..df435a13557 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -435,14 +435,18 @@ dependencies = [ "agave-math-utils", "agave-votor", "agave-votor-messages", + "arc-swap", "bitvec", + "bytes", "crossbeam-channel", "itertools 0.14.0", "lazy-lru", "log", "parking_lot 0.12.3", "qualifier_attr", + "quinn", "rand 0.9.4", + "rustls", "serde", "serde_bytes", "solana-accounts-db", @@ -472,6 +476,7 @@ dependencies = [ "solana-signer-store", "solana-streamer", "solana-time-utils", + "solana-tls-utils", "solana-transaction", "solana-transaction-error", "solana-vote", @@ -479,6 +484,7 @@ dependencies = [ "tempfile", "test-case", "thiserror 2.0.18", + "tokio", "tokio-util 0.7.18", "wincode", ] @@ -7797,6 +7803,7 @@ dependencies = [ "num_cpus", "num_enum", "qualifier_attr", + "quinn", "rand 0.9.4", "rand_chacha 0.9.0", "rayon", @@ -7807,6 +7814,7 @@ dependencies = [ "serde_bytes", "serial_test", "shaq", + "shuttle", "signal-hook", "slab", "solana-account 4.3.0", diff --git a/core/Cargo.toml b/core/Cargo.toml index bf171f15265..da7e851f705 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -38,6 +38,7 @@ frozen-abi = [ "solana-vote/frozen-abi", "solana-vote-program/frozen-abi", ] +shuttle-test = ["dep:shuttle"] [dependencies] agave-banking-stage-ingress-types = { workspace = true } @@ -75,6 +76,7 @@ min-max-heap = { workspace = true } num_cpus = { workspace = true } num_enum = { workspace = true } qualifier_attr = { workspace = true } +quinn = { workspace = true } rand = { workspace = true } rand_chacha = { workspace = true } rayon = { workspace = true } @@ -84,6 +86,7 @@ rustls = { workspace = true } serde = { workspace = true } serde_bytes = { workspace = true } shaq = { workspace = true } +shuttle = { workspace = true, optional = true } signal-hook = { workspace = true } slab = { workspace = true } solana-account = { workspace = true } diff --git a/core/src/admin_rpc_post_init.rs b/core/src/admin_rpc_post_init.rs index 79259fc625c..b054e0eef4b 100644 --- a/core/src/admin_rpc_post_init.rs +++ b/core/src/admin_rpc_post_init.rs @@ -32,8 +32,8 @@ pub enum KeyUpdaterType { RpcService, /// BLS all-to-all streamer key updater Bls, - /// BLS all-to-all connection cache key updater - BlsConnectionCache, + /// BLS all-to-all QUIC datagram client key updater + BlsDatagramClient, } /// Responsible for managing the updaters for identity key change diff --git a/core/src/bls_quic_datagram.rs b/core/src/bls_quic_datagram.rs new file mode 100644 index 00000000000..bbfeaaf40bc --- /dev/null +++ b/core/src/bls_quic_datagram.rs @@ -0,0 +1,1725 @@ +use { + bytes::Bytes, + crossbeam_channel::{Sender, TrySendError}, + quinn::{ + Connection, Endpoint, EndpointConfig, IdleTimeout, Incoming, ServerConfig, TokioRuntime, + VarInt, + crypto::rustls::{NoInitialCipherSuite, QuicServerConfig}, + }, + rustls::KeyLogFile, + solana_keypair::Keypair, + solana_net_utils::token_bucket::TokenBucket, + solana_packet::Meta, + solana_perf::packet::{BytesPacket, PACKET_DATA_SIZE, PacketBatch}, + solana_pubkey::Pubkey, + solana_runtime::bank::MAX_ALPENGLOW_VOTE_ACCOUNTS, + solana_streamer::nonblocking::{quic::ALPN_TPU_PROTOCOL_ID, simple_qos::SimpleQosBanlist}, + solana_tls_utils::{ + NotifyKeyUpdate, get_remote_pubkey, new_dummy_x509_certificate, tls_server_config_builder, + }, + std::{ + collections::HashMap, + net::{SocketAddr, UdpSocket}, + sync::Arc, + thread::{self, JoinHandle}, + time::Duration, + }, + tokio::{sync::mpsc, time::timeout}, + tokio_util::sync::CancellationToken, +}; + +const QUIC_MAX_TIMEOUT: Duration = Duration::from_secs(30); +const QUIC_CONNECTION_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(2); +const DATAGRAM_RECEIVE_BUFFER_SIZE: usize = PACKET_DATA_SIZE * 64; +const MAX_BLS_DATAGRAM_CONNECTIONS: usize = MAX_ALPENGLOW_VOTE_ACCOUNTS * 2; +const MAX_BLS_DATAGRAM_CONNECTIONS_PER_PEER: usize = 2; +const MAX_BLS_DATAGRAMS_PER_SECOND_PER_CONNECTION: f64 = 100.0; +const BLS_DATAGRAM_RATE_LIMIT_BURST: u64 = 1_000; +const BLS_DATAGRAM_DOS_BURST: u64 = 100_000; +const BLS_DATAGRAM_DOS_BAN_TIMEOUT: Duration = Duration::from_secs(48 * 60 * 60); +const BLS_DATAGRAM_BAN_CHECK_INTERVAL: Duration = Duration::from_secs(1); +const CONNECTION_CLOSE_CODE_DISALLOWED: u32 = 2; +const CONNECTION_CLOSE_REASON_DISALLOWED: &[u8] = b"disallowed"; +const CONNECTION_CLOSE_CODE_TOO_MANY: u32 = 4; +const CONNECTION_CLOSE_REASON_TOO_MANY: &[u8] = b"too_many"; +const CONNECTION_CLOSE_CODE_PACKET_CHANNEL_CLOSED: u32 = 6; +const CONNECTION_CLOSE_REASON_PACKET_CHANNEL_CLOSED: &[u8] = b"packet_channel_closed"; + +pub(crate) type StakedPeerChecker = Arc bool + Send + Sync + 'static>; + +enum ServerEvent { + Accepted { + connection: Connection, + remote_addr: SocketAddr, + remote_pubkey: Pubkey, + }, + Closed { + remote_pubkey: Pubkey, + }, +} + +pub(crate) struct SpawnBlsQuicDatagramServerResult { + pub(crate) thread: JoinHandle<()>, + pub(crate) key_updater: Arc, + pub(crate) banlist: Arc, +} + +#[derive(Debug, thiserror::Error)] +pub(crate) enum BlsQuicDatagramServerError { + #[error("endpoint creation failed: {0}")] + EndpointFailed(#[from] std::io::Error), + #[error("TLS error: {0}")] + TlsError(#[from] rustls::Error), + #[error("no initial cipher suite")] + NoInitialCipherSuite(#[from] NoInitialCipherSuite), +} + +pub(crate) struct BlsQuicDatagramServerKeyUpdater { + endpoint: Endpoint, +} + +impl NotifyKeyUpdate for BlsQuicDatagramServerKeyUpdater { + fn update_key(&self, key: &Keypair) -> Result<(), Box> { + self.endpoint + .set_server_config(Some(configure_server(key)?)); + Ok(()) + } +} + +pub(crate) fn spawn_bls_quic_datagram_server( + name: &'static str, + socket: UdpSocket, + identity_keypair: &Keypair, + packet_sender: Sender, + is_staked_peer: StakedPeerChecker, + cancel: CancellationToken, +) -> Result { + info!( + "Start {name} QUIC datagram server on {:?}", + socket.local_addr() + ); + let server_config = configure_server(identity_keypair)?; + let (banlist, banlist_eviction_receiver) = SimpleQosBanlist::new(); + let banlist = Arc::new(banlist); + let (init_sender, init_receiver) = std::sync::mpsc::sync_channel(1); + let thread = thread::Builder::new() + .name(name.to_string()) + .spawn({ + let banlist = banlist.clone(); + move || { + let runtime = tokio::runtime::Builder::new_current_thread() + .thread_name(name) + .enable_all() + .build() + .unwrap(); + let guard = runtime.enter(); + let endpoint = Endpoint::new( + EndpointConfig::default(), + Some(server_config), + socket, + Arc::new(TokioRuntime), + ); + let endpoint = match endpoint { + Ok(endpoint) => endpoint, + Err(err) => { + let _ = + init_sender.send(Err(BlsQuicDatagramServerError::EndpointFailed(err))); + return; + } + }; + let key_updater = Arc::new(BlsQuicDatagramServerKeyUpdater { + endpoint: endpoint.clone(), + }); + let _ = init_sender.send(Ok(key_updater)); + drop(guard); + + runtime.block_on(run_server( + endpoint, + packet_sender, + is_staked_peer, + banlist, + banlist_eviction_receiver, + cancel, + )); + } + }) + .unwrap(); + + let key_updater = match init_receiver.recv().unwrap() { + Ok(key_updater) => key_updater, + Err(err) => { + thread.join().unwrap(); + return Err(err); + } + }; + + Ok(SpawnBlsQuicDatagramServerResult { + thread, + key_updater, + banlist, + }) +} + +fn configure_server( + identity_keypair: &Keypair, +) -> Result { + let (cert, priv_key) = new_dummy_x509_certificate(identity_keypair); + let mut server_tls_config = + tls_server_config_builder().with_single_cert(vec![cert], priv_key)?; + server_tls_config.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()]; + server_tls_config.key_log = Arc::new(KeyLogFile::new()); + + let mut server_config = + ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(server_tls_config)?)); + server_config.migration(false); + + let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); + transport_config.max_concurrent_uni_streams(0u32.into()); + transport_config.max_concurrent_bidi_streams(0u32.into()); + transport_config.datagram_receive_buffer_size(Some(DATAGRAM_RECEIVE_BUFFER_SIZE)); + transport_config.max_idle_timeout(Some(IdleTimeout::try_from(QUIC_MAX_TIMEOUT).unwrap())); + transport_config.enable_segmentation_offload(false); + + Ok(server_config) +} + +async fn run_server( + endpoint: Endpoint, + packet_sender: Sender, + is_staked_peer: StakedPeerChecker, + banlist: Arc, + mut banlist_eviction_receiver: tokio::sync::mpsc::Receiver, + cancel: CancellationToken, +) { + let (server_event_sender, mut server_event_receiver) = + mpsc::channel(MAX_BLS_DATAGRAM_CONNECTIONS); + let mut active_connections = 0usize; + let mut peer_connection_counts = HashMap::::new(); + let mut banlist_eviction_receiver_closed = false; + + loop { + tokio::select! { + maybe_incoming = endpoint.accept() => { + let Some(incoming) = maybe_incoming else { + break; + }; + tokio::spawn(handle_incoming( + incoming, + is_staked_peer.clone(), + banlist.clone(), + server_event_sender.clone(), + )); + } + maybe_event = server_event_receiver.recv() => { + let Some(event) = maybe_event else { + break; + }; + handle_server_event( + event, + &packet_sender, + &banlist, + &mut active_connections, + &mut peer_connection_counts, + &server_event_sender, + cancel.clone(), + ); + } + maybe_banned_pubkey = banlist_eviction_receiver.recv(), if !banlist_eviction_receiver_closed => { + banlist_eviction_receiver_closed = maybe_banned_pubkey.is_none(); + } + _ = cancel.cancelled() => break, + } + } + endpoint.close(VarInt::from_u32(0), b"shutdown"); +} + +async fn handle_incoming( + incoming: Incoming, + is_staked_peer: StakedPeerChecker, + banlist: Arc, + server_event_sender: mpsc::Sender, +) { + let connection = match incoming.accept() { + Ok(connecting) => match timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await { + Ok(Ok(connection)) => connection, + Ok(Err(err)) => { + debug!("votor QUIC datagram handshake failed: {err}"); + return; + } + Err(_) => { + debug!("votor QUIC datagram handshake timed out"); + return; + } + }, + Err(err) => { + debug!("votor QUIC datagram accept failed: {err}"); + return; + } + }; + + let remote_addr = connection.remote_address(); + let Some(remote_pubkey) = get_remote_pubkey(&connection) else { + close_disallowed(&connection); + return; + }; + if !is_staked_peer(&remote_pubkey) || banlist.is_banned(&remote_pubkey) { + close_disallowed(&connection); + return; + } + + let event = ServerEvent::Accepted { + connection, + remote_addr, + remote_pubkey, + }; + let _ = server_event_sender.send(event).await; +} + +fn handle_server_event( + event: ServerEvent, + packet_sender: &Sender, + banlist: &Arc, + active_connections: &mut usize, + peer_connection_counts: &mut HashMap, + server_event_sender: &mpsc::Sender, + cancel: CancellationToken, +) { + match event { + ServerEvent::Accepted { + connection, + remote_addr, + remote_pubkey, + } => { + if *active_connections >= MAX_BLS_DATAGRAM_CONNECTIONS + || peer_connection_counts + .get(&remote_pubkey) + .copied() + .unwrap_or_default() + >= MAX_BLS_DATAGRAM_CONNECTIONS_PER_PEER + { + connection.close( + CONNECTION_CLOSE_CODE_TOO_MANY.into(), + CONNECTION_CLOSE_REASON_TOO_MANY, + ); + return; + } + + *active_connections += 1; + *peer_connection_counts.entry(remote_pubkey).or_default() += 1; + tokio::spawn(handle_connection_lifecycle( + connection, + remote_addr, + remote_pubkey, + packet_sender.clone(), + banlist.clone(), + server_event_sender.clone(), + cancel, + )); + } + ServerEvent::Closed { remote_pubkey } => { + *active_connections = active_connections.saturating_sub(1); + let Some(count) = peer_connection_counts.get_mut(&remote_pubkey) else { + return; + }; + *count = count.saturating_sub(1); + if *count == 0 { + peer_connection_counts.remove(&remote_pubkey); + } + } + } +} + +async fn handle_connection_lifecycle( + connection: Connection, + remote_addr: SocketAddr, + remote_pubkey: Pubkey, + packet_sender: Sender, + banlist: Arc, + server_event_sender: mpsc::Sender, + cancel: CancellationToken, +) { + handle_connection( + connection, + remote_addr, + remote_pubkey, + packet_sender, + banlist, + cancel, + ) + .await; + let _ = server_event_sender + .send(ServerEvent::Closed { remote_pubkey }) + .await; +} + +async fn handle_connection( + connection: Connection, + remote_addr: SocketAddr, + remote_pubkey: Pubkey, + packet_sender: Sender, + banlist: Arc, + cancel: CancellationToken, +) { + let receive_rate_limiter = TokenBucket::new( + BLS_DATAGRAM_RATE_LIMIT_BURST, + BLS_DATAGRAM_RATE_LIMIT_BURST, + MAX_BLS_DATAGRAMS_PER_SECOND_PER_CONNECTION, + ); + let dos_rate_limiter = TokenBucket::new( + BLS_DATAGRAM_DOS_BURST, + BLS_DATAGRAM_DOS_BURST, + MAX_BLS_DATAGRAMS_PER_SECOND_PER_CONNECTION, + ); + let mut ban_check = tokio::time::interval(BLS_DATAGRAM_BAN_CHECK_INTERVAL); + ban_check.tick().await; + + loop { + if banlist.is_banned(&remote_pubkey) { + close_disallowed(&connection); + return; + } + + let datagram = tokio::select! { + datagram = connection.read_datagram() => match datagram { + Ok(datagram) => datagram, + Err(err) => { + debug!("votor QUIC datagram read failed from {remote_addr}: {err}"); + return; + } + }, + _ = cancel.cancelled() => return, + _ = ban_check.tick() => continue, + }; + if banlist.is_banned(&remote_pubkey) { + close_disallowed(&connection); + return; + } + + if dos_rate_limiter.consume_tokens(1).is_err() { + if !banlist.ban(remote_pubkey, BLS_DATAGRAM_DOS_BAN_TIMEOUT) { + warn!( + "banned votor QUIC datagram sender {remote_pubkey} at {remote_addr} for \ + receive rate abuse" + ); + } + close_disallowed(&connection); + return; + } + + if receive_rate_limiter.consume_tokens(1).is_err() { + debug!("dropping rate-limited votor QUIC datagram from {remote_addr}"); + continue; + } + + if datagram.len() > PACKET_DATA_SIZE { + debug!( + "dropping oversized votor QUIC datagram from {remote_addr}: {} bytes", + datagram.len() + ); + continue; + } + + let packet = datagram_to_packet(datagram, remote_addr, remote_pubkey); + if let Err(err) = packet_sender.try_send(PacketBatch::Single(packet)) { + match err { + TrySendError::Full(_) => { + debug!("dropping votor QUIC datagram from {remote_addr}: packet channel full"); + } + TrySendError::Disconnected(_) => { + connection.close( + CONNECTION_CLOSE_CODE_PACKET_CHANNEL_CLOSED.into(), + CONNECTION_CLOSE_REASON_PACKET_CHANNEL_CLOSED, + ); + return; + } + } + } + } +} + +fn datagram_to_packet( + datagram: Bytes, + remote_addr: SocketAddr, + remote_pubkey: Pubkey, +) -> BytesPacket { + let mut meta = Meta::default(); + meta.size = datagram.len(); + meta.set_socket_addr(&remote_addr); + meta.set_from_staked_node(true); + meta.set_remote_pubkey(remote_pubkey); + BytesPacket::new(datagram, meta) +} + +fn close_disallowed(connection: &Connection) { + connection.close( + CONNECTION_CLOSE_CODE_DISALLOWED.into(), + CONNECTION_CLOSE_REASON_DISALLOWED, + ); +} + +#[cfg(test)] +mod tests { + use { + super::*, + quinn::{ClientConfig, TransportConfig, crypto::rustls::QuicClientConfig}, + solana_net_utils::sockets::bind_to_localhost_unique, + solana_signer::Signer, + solana_tls_utils::{ + QuicClientCertificate, socket_addr_to_quic_server_name, tls_client_config_builder, + }, + std::collections::HashMap, + }; + + #[test] + fn test_datagram_server_receives_staked_packet() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgram", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + + let payload = Bytes::from_static(b"votor datagram"); + send_test_datagram(server_addr, &client_keypair, payload.clone()); + + let packet_batch = packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("datagram packet"); + let packet = packet_batch.first().expect("packet"); + assert_eq!(packet.meta().remote_pubkey(), Some(client_keypair.pubkey())); + assert!(packet.meta().is_from_staked_node()); + assert_eq!(packet.data(..).unwrap(), payload.as_ref()); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_limits_connections_per_peer() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(10); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramPeerCap", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + let mut held_connections = Vec::new(); + + for i in 0..MAX_BLS_DATAGRAM_CONNECTIONS_PER_PEER { + let payload = Bytes::from(format!("accepted-{i}").into_bytes()); + let client = runtime.block_on(async { + let client = connect_test_client(server_addr, &client_keypair).await; + client.1.send_datagram(payload.clone()).unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; + client + }); + held_connections.push(client); + let packet_batch = packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("accepted datagram packet"); + let packet = packet_batch.first().expect("packet"); + assert_eq!(packet.meta().remote_pubkey(), Some(client_keypair.pubkey())); + assert_eq!(packet.data(..).unwrap(), payload.as_ref()); + } + + let client = runtime.block_on(async { + let client = connect_test_client(server_addr, &client_keypair).await; + client + .1 + .send_datagram(Bytes::from_static(b"rejected")) + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + client + }); + held_connections.push(client); + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(500)) + .is_err() + ); + + drop(held_connections); + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_closes_banned_connection() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramBanClose", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + let (_endpoint, connection) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + + connection + .send_datagram(Bytes::from_static(b"probe")) + .unwrap(); + packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("probe datagram packet"); + + server + .banlist + .ban(client_keypair.pubkey(), Duration::from_secs(30)); + let _ = connection.send_datagram(Bytes::from_static(b"after-ban")); + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(500)) + .is_err() + ); + runtime + .block_on(async { + tokio::time::timeout(Duration::from_secs(5), connection.closed()).await + }) + .expect("banned connection should close"); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_accepts_peer_after_ban_expires() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramBanExpires", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + + server + .banlist + .ban(client_keypair.pubkey(), Duration::from_millis(200)); + send_test_datagram(server_addr, &client_keypair, Bytes::from_static(b"banned")); + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(500)) + .is_err() + ); + + std::thread::sleep(Duration::from_millis(250)); + send_test_datagram( + server_addr, + &client_keypair, + Bytes::from_static(b"unbanned"), + ); + + let packet_batch = packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("unbanned datagram packet"); + let packet = packet_batch.first().expect("packet"); + assert_eq!(packet.meta().remote_pubkey(), Some(client_keypair.pubkey())); + assert_eq!(packet.data(..).unwrap(), b"unbanned"); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_rate_limits_peer_burst() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::unbounded(); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramRateLimit", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + let (_endpoint, connection) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + let burst = (BLS_DATAGRAM_RATE_LIMIT_BURST as usize) * 4; + + runtime.block_on(async { + for i in 0..burst { + connection + .send_datagram(Bytes::from(format!("burst-{i}").into_bytes())) + .unwrap(); + } + tokio::time::sleep(Duration::from_millis(500)).await; + }); + + let mut delivered = 0usize; + while packet_receiver + .recv_timeout(Duration::from_millis(20)) + .is_ok() + { + delivered = delivered.saturating_add(1); + } + + assert!(delivered < burst); + assert!(delivered <= BLS_DATAGRAM_RATE_LIMIT_BURST as usize + 20); + assert!(!server.banlist.is_banned(&client_keypair.pubkey())); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_rejects_unstaked_peer() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = staked_peer_checker(HashMap::new()); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramUnstaked", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + let (_endpoint, connection) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + + let _ = connection.send_datagram(Bytes::from_static(b"unstaked")); + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(500)) + .is_err() + ); + runtime + .block_on(async { + tokio::time::timeout(Duration::from_secs(5), connection.closed()).await + }) + .expect("unstaked connection should close"); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_rejects_missing_client_identity() { + let server_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = Arc::new(|_: &Pubkey| true) as StakedPeerChecker; + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramNoCert", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + if let Some((_endpoint, connection)) = runtime.block_on(connect_test_client_with_config( + server_addr, + test_client_config_without_client_cert(), + )) { + let _ = connection.send_datagram(Bytes::from_static(b"no-cert")); + runtime + .block_on(async { + tokio::time::timeout(Duration::from_secs(5), connection.closed()).await + }) + .expect("connection without client identity should close"); + } + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(500)) + .is_err() + ); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_rejects_wrong_alpn_and_continues() { + let server_keypair = Keypair::new(); + let bad_client_keypair = Keypair::new(); + let good_client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = staked_peer_checker(HashMap::from([ + (bad_client_keypair.pubkey(), 100), + (good_client_keypair.pubkey(), 100), + ])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramWrongAlpn", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build() + .unwrap(); + + if let Some((_endpoint, connection)) = runtime.block_on(connect_test_client_with_config( + server_addr, + test_client_config_with_alpn(&bad_client_keypair, b"not-votor"), + )) { + let _ = connection.send_datagram(Bytes::from_static(b"wrong-alpn")); + } + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(500)) + .is_err() + ); + + send_test_datagram( + server_addr, + &good_client_keypair, + Bytes::from_static(b"valid-after-wrong-alpn"), + ); + let packet_batch = packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("valid datagram after wrong ALPN attempt"); + let packet = packet_batch.first().expect("packet"); + assert_eq!( + packet.meta().remote_pubkey(), + Some(good_client_keypair.pubkey()) + ); + assert_eq!(packet.data(..).unwrap(), b"valid-after-wrong-alpn"); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_sender_rejects_oversized_datagram_and_keeps_connection_open() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramOversized", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + let (_endpoint, connection) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + + assert!(matches!( + connection.send_datagram(Bytes::from(vec![7; PACKET_DATA_SIZE + 1])), + Err(quinn::SendDatagramError::TooLarge) + )); + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(500)) + .is_err() + ); + + connection + .send_datagram(Bytes::from_static(b"after-oversized")) + .unwrap(); + let packet_batch = packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("valid packet after oversized datagram"); + let packet = packet_batch.first().expect("packet"); + assert_eq!(packet.data(..).unwrap(), b"after-oversized"); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_drops_when_packet_channel_full_and_keeps_connection_open() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramPktFull", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + let (_endpoint, connection) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + + runtime.block_on(async { + connection + .send_datagram(Bytes::from_static(b"first")) + .unwrap(); + connection + .send_datagram(Bytes::from_static(b"dropped-full")) + .unwrap(); + tokio::time::sleep(Duration::from_millis(200)).await; + }); + let packet_batch = packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("first packet"); + let packet = packet_batch.first().expect("packet"); + assert_eq!(packet.data(..).unwrap(), b"first"); + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(300)) + .is_err() + ); + + connection + .send_datagram(Bytes::from_static(b"after-full")) + .unwrap(); + let packet_batch = packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("packet after channel-full drop"); + let packet = packet_batch.first().expect("packet"); + assert_eq!(packet.data(..).unwrap(), b"after-full"); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_closes_when_packet_channel_disconnected_without_exiting_thread() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + drop(packet_receiver); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramPktClosed", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + let (_endpoint, connection) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + + let _ = connection.send_datagram(Bytes::from_static(b"disconnected")); + runtime + .block_on(async { + tokio::time::timeout(Duration::from_secs(5), connection.closed()).await + }) + .expect("connection should close when packet channel is disconnected"); + std::thread::sleep(Duration::from_millis(100)); + assert!(!server.thread.is_finished()); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_closes_idle_banned_connection_on_tick() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramIdleBan", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + let (_endpoint, connection) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + + connection + .send_datagram(Bytes::from_static(b"before-idle-ban")) + .unwrap(); + packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("probe datagram packet"); + server + .banlist + .ban(client_keypair.pubkey(), Duration::from_secs(30)); + + runtime + .block_on(async { + tokio::time::timeout(Duration::from_secs(5), connection.closed()).await + }) + .expect("idle banned connection should close on ban-check tick"); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_rejects_reconnect_while_banned() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramBanReconnect", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + + server + .banlist + .ban(client_keypair.pubkey(), Duration::from_secs(30)); + let (_endpoint, connection) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + let _ = connection.send_datagram(Bytes::from_static(b"banned-reconnect")); + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(500)) + .is_err() + ); + runtime + .block_on(async { + tokio::time::timeout(Duration::from_secs(5), connection.closed()).await + }) + .expect("reconnected banned peer should be closed"); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_same_pubkey_flood_does_not_evict_existing_connections() { + let server_keypair = Keypair::new(); + let client_keypair = Arc::new(Keypair::new()); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(16); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramSamePubkeyFlood", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + let mut held_connections = Vec::new(); + + for i in 0..MAX_BLS_DATAGRAM_CONNECTIONS_PER_PEER { + let (_endpoint, connection) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + connection + .send_datagram(Bytes::from(format!("held-{i}").into_bytes())) + .unwrap(); + held_connections.push(connection); + packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("held connection probe"); + } + + runtime.block_on(async { + let mut handles = Vec::new(); + for i in 0..16usize { + let payload = Bytes::from(format!("flood-{i}").into_bytes()); + handles.push(tokio::spawn(try_send_test_datagram( + server_addr, + Arc::clone(&client_keypair), + payload, + ))); + } + for handle in handles { + let _ = handle.await; + } + }); + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(500)) + .is_err() + ); + + held_connections[0] + .send_datagram(Bytes::from_static(b"after-flood")) + .unwrap(); + let packet_batch = packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("held connection should survive same-pubkey flood"); + let packet = packet_batch.first().expect("packet"); + assert_eq!(packet.data(..).unwrap(), b"after-flood"); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_handles_many_concurrent_staked_connections() { + const PEERS: usize = 32; + let server_keypair = Keypair::new(); + let client_keypairs: Vec<_> = (0..PEERS).map(|_| Keypair::new()).collect(); + let stakes = client_keypairs + .iter() + .map(|keypair| (keypair.pubkey(), 100)) + .collect::>(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(PEERS); + let is_staked_peer = staked_peer_checker(stakes); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramManyStaked", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + + runtime.block_on(async { + let mut handles = Vec::new(); + for (i, client_keypair) in client_keypairs.into_iter().enumerate() { + let payload = Bytes::from(format!("peer-{i}").into_bytes()); + handles.push(tokio::spawn(try_send_test_datagram( + server_addr, + Arc::new(client_keypair), + payload, + ))); + } + for handle in handles { + handle.await.unwrap(); + } + }); + + let mut delivered = 0usize; + while delivered < PEERS { + let packet_batch = packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("concurrent staked datagram"); + let packet = packet_batch.first().expect("packet"); + assert!(packet.meta().remote_pubkey().is_some()); + assert!(packet.meta().is_from_staked_node()); + delivered += 1; + } + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_unstaked_flood_does_not_starve_staked_peer() { + const UNSTAKED_PEERS: usize = 32; + let server_keypair = Keypair::new(); + let staked_client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::bounded(1); + let is_staked_peer = + staked_peer_checker(HashMap::from([(staked_client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramUnstakedFlood", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + + runtime.block_on(async { + let mut handles = Vec::new(); + for i in 0..UNSTAKED_PEERS { + handles.push(tokio::spawn(try_send_test_datagram( + server_addr, + Arc::new(Keypair::new()), + Bytes::from(format!("unstaked-{i}").into_bytes()), + ))); + } + for handle in handles { + let _ = handle.await; + } + }); + assert!( + packet_receiver + .recv_timeout(Duration::from_millis(500)) + .is_err() + ); + + send_test_datagram( + server_addr, + &staked_client_keypair, + Bytes::from_static(b"staked-after-flood"), + ); + let packet_batch = packet_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("staked datagram after unstaked flood"); + let packet = packet_batch.first().expect("packet"); + assert_eq!( + packet.meta().remote_pubkey(), + Some(staked_client_keypair.pubkey()) + ); + assert_eq!(packet.data(..).unwrap(), b"staked-after-flood"); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + #[test] + fn test_datagram_server_rate_limits_each_allowed_connection() { + let server_keypair = Keypair::new(); + let client_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let (packet_sender, packet_receiver) = crossbeam_channel::unbounded(); + let is_staked_peer = staked_peer_checker(HashMap::from([(client_keypair.pubkey(), 100)])); + let cancel = CancellationToken::new(); + let server = spawn_bls_quic_datagram_server( + "testBlsDgramTwoConnRateLimit", + server_socket, + &server_keypair, + packet_sender, + is_staked_peer, + cancel.clone(), + ) + .unwrap(); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build() + .unwrap(); + let (_endpoint_a, connection_a) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + let (_endpoint_b, connection_b) = + runtime.block_on(connect_test_client(server_addr, &client_keypair)); + let burst = (BLS_DATAGRAM_RATE_LIMIT_BURST as usize) * 4; + + runtime.block_on(async { + for i in 0..burst { + connection_a + .send_datagram(Bytes::from(format!("a-{i}").into_bytes())) + .unwrap(); + connection_b + .send_datagram(Bytes::from(format!("b-{i}").into_bytes())) + .unwrap(); + } + tokio::time::sleep(Duration::from_millis(500)).await; + }); + + let mut delivered = 0usize; + while packet_receiver + .recv_timeout(Duration::from_millis(20)) + .is_ok() + { + delivered = delivered.saturating_add(1); + } + + assert!(delivered < burst * 2); + assert!(delivered <= (BLS_DATAGRAM_RATE_LIMIT_BURST as usize * 2) + 40); + assert!(!server.banlist.is_banned(&client_keypair.pubkey())); + + cancel.cancel(); + server.thread.join().unwrap(); + } + + fn staked_peer_checker(stakes: HashMap) -> StakedPeerChecker { + Arc::new(move |pubkey| stakes.get(pubkey).copied().unwrap_or_default() > 0) + } + + async fn connect_test_client( + server_addr: SocketAddr, + client_keypair: &Keypair, + ) -> (Endpoint, Connection) { + connect_test_client_with_config(server_addr, test_client_config(client_keypair)) + .await + .expect("test client connection") + } + + async fn connect_test_client_with_config( + server_addr: SocketAddr, + client_config: ClientConfig, + ) -> Option<(Endpoint, Connection)> { + let client_socket = bind_to_localhost_unique().unwrap(); + let mut endpoint = Endpoint::new( + EndpointConfig::default(), + None, + client_socket, + Arc::new(TokioRuntime), + ) + .unwrap(); + endpoint.set_default_client_config(client_config); + let server_name = socket_addr_to_quic_server_name(server_addr); + let connecting = endpoint.connect(server_addr, &server_name).ok()?; + let connection = connecting.await.ok()?; + Some((endpoint, connection)) + } + + async fn try_send_test_datagram( + server_addr: SocketAddr, + client_keypair: Arc, + payload: Bytes, + ) { + if let Some((_endpoint, connection)) = + connect_test_client_with_config(server_addr, test_client_config(&client_keypair)).await + { + let _ = connection.send_datagram(payload); + tokio::time::sleep(Duration::from_millis(50)).await; + } + } + + fn send_test_datagram(server_addr: SocketAddr, client_keypair: &Keypair, payload: Bytes) { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + runtime.block_on(async move { + let (_endpoint, connection) = connect_test_client(server_addr, client_keypair).await; + connection.send_datagram(payload).unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; + }); + } + + fn test_client_config(keypair: &Keypair) -> ClientConfig { + test_client_config_with_alpn(keypair, ALPN_TPU_PROTOCOL_ID) + } + + fn test_client_config_with_alpn(keypair: &Keypair, alpn: &[u8]) -> ClientConfig { + let client_certificate = QuicClientCertificate::new(Some(keypair)); + let mut crypto = tls_client_config_builder() + .with_client_auth_cert(vec![client_certificate.certificate], client_certificate.key) + .unwrap(); + crypto.alpn_protocols = vec![alpn.to_vec()]; + test_client_config_from_crypto(crypto) + } + + fn test_client_config_without_client_cert() -> ClientConfig { + let mut crypto = tls_client_config_builder().with_no_client_auth(); + crypto.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()]; + test_client_config_from_crypto(crypto) + } + + fn test_client_config_from_crypto(crypto: rustls::ClientConfig) -> ClientConfig { + let mut transport_config = TransportConfig::default(); + transport_config.datagram_receive_buffer_size(Some(DATAGRAM_RECEIVE_BUFFER_SIZE)); + let mut config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(crypto).unwrap())); + config.transport_config(Arc::new(transport_config)); + config + } +} + +#[cfg(all(test, feature = "shuttle-test"))] +mod shuttle_tests { + use std::collections::{HashMap, HashSet}; + + #[derive(Clone, Copy, Debug)] + enum ModelEvent { + Accepted { peer: u8, connection_id: u8 }, + Closed { peer: u8, connection_id: u8 }, + Banned(u8), + Datagram(u8), + } + + #[derive(Debug)] + struct ConnectionControllerModel { + active_connections: usize, + peer_connection_ids: HashMap>, + admitted_connection_ids: HashSet, + banned: HashSet, + forwarded: usize, + banned_datagrams_dropped: usize, + dropped: usize, + global_cap: usize, + per_peer_cap: usize, + } + + impl ConnectionControllerModel { + fn new(global_cap: usize, per_peer_cap: usize) -> Self { + Self { + active_connections: 0, + peer_connection_ids: HashMap::new(), + admitted_connection_ids: HashSet::new(), + banned: HashSet::new(), + forwarded: 0, + banned_datagrams_dropped: 0, + dropped: 0, + global_cap, + per_peer_cap, + } + } + + fn apply(&mut self, event: ModelEvent) { + match event { + ModelEvent::Accepted { + peer, + connection_id, + } => { + let peer_count = self + .peer_connection_ids + .get(&peer) + .map(HashSet::len) + .unwrap_or_default(); + if self.banned.contains(&peer) + || self.active_connections >= self.global_cap + || peer_count >= self.per_peer_cap + || self.admitted_connection_ids.contains(&connection_id) + { + self.dropped += 1; + } else { + self.active_connections += 1; + self.admitted_connection_ids.insert(connection_id); + self.peer_connection_ids + .entry(peer) + .or_default() + .insert(connection_id); + } + } + ModelEvent::Closed { + peer, + connection_id, + } => { + if self.admitted_connection_ids.remove(&connection_id) { + self.active_connections = self.active_connections.saturating_sub(1); + if let Some(peer_connections) = self.peer_connection_ids.get_mut(&peer) { + peer_connections.remove(&connection_id); + if peer_connections.is_empty() { + self.peer_connection_ids.remove(&peer); + } + } + } else { + self.dropped += 1; + } + } + ModelEvent::Banned(peer) => { + self.banned.insert(peer); + } + ModelEvent::Datagram(peer) => { + if self.banned.contains(&peer) { + self.banned_datagrams_dropped += 1; + self.dropped += 1; + } else if self + .peer_connection_ids + .get(&peer) + .map(HashSet::is_empty) + .unwrap_or(true) + { + self.dropped += 1; + } else { + self.forwarded += 1; + } + } + } + self.assert_invariants(); + } + + fn assert_invariants(&self) { + assert!(self.active_connections <= self.global_cap); + assert!( + self.peer_connection_ids + .values() + .all(|connections| connections.len() <= self.per_peer_cap) + ); + let per_peer_sum = self + .peer_connection_ids + .values() + .map(HashSet::len) + .sum::(); + assert_eq!(per_peer_sum, self.active_connections); + assert_eq!(self.admitted_connection_ids.len(), self.active_connections); + } + } + + fn drain_model_events( + receiver: shuttle::sync::mpsc::Receiver, + global_cap: usize, + per_peer_cap: usize, + ) -> ConnectionControllerModel { + let mut model = ConnectionControllerModel::new(global_cap, per_peer_cap); + while let Ok(event) = receiver.recv() { + model.apply(event); + shuttle::thread::yield_now(); + } + model + } + + #[test] + fn shuttle_test_connection_caps_hold_under_accept_close_races() { + shuttle::check_random( + || { + let (sender, receiver) = shuttle::sync::mpsc::sync_channel(2); + for (connection_id, peer) in [0, 0, 0, 1, 1, 2, 3].into_iter().enumerate() { + let sender = sender.clone(); + shuttle::thread::spawn(move || { + let connection_id = connection_id as u8; + sender + .send(ModelEvent::Accepted { + peer, + connection_id, + }) + .unwrap(); + shuttle::thread::yield_now(); + sender + .send(ModelEvent::Closed { + peer, + connection_id, + }) + .unwrap(); + }); + } + drop(sender); + + let model = drain_model_events(receiver, 3, 2); + assert_eq!(model.forwarded, 0); + }, + 500, + ); + } + + #[test] + fn shuttle_test_duplicate_close_events_do_not_underflow_controller_state() { + shuttle::check_random( + || { + let (sender, receiver) = shuttle::sync::mpsc::sync_channel(1); + for event in [ + ModelEvent::Accepted { + peer: 0, + connection_id: 1, + }, + ModelEvent::Closed { + peer: 0, + connection_id: 1, + }, + ModelEvent::Closed { + peer: 0, + connection_id: 1, + }, + ModelEvent::Accepted { + peer: 0, + connection_id: 2, + }, + ModelEvent::Datagram(0), + ] { + let sender = sender.clone(); + shuttle::thread::spawn(move || { + sender.send(event).unwrap(); + }); + } + drop(sender); + + let model = drain_model_events(receiver, 1, 1); + assert!(model.active_connections <= 1); + assert!( + model + .peer_connection_ids + .get(&0) + .map(HashSet::len) + .unwrap_or_default() + <= 1 + ); + }, + 500, + ); + } + + #[test] + fn shuttle_test_ban_datagram_interleavings_never_forward_after_ban_is_observed() { + shuttle::check_random( + || { + let (sender, receiver) = shuttle::sync::mpsc::sync_channel(2); + { + let sender = sender.clone(); + shuttle::thread::spawn(move || { + sender + .send(ModelEvent::Accepted { + peer: 0, + connection_id: 1, + }) + .unwrap(); + shuttle::thread::yield_now(); + sender.send(ModelEvent::Banned(0)).unwrap(); + shuttle::thread::yield_now(); + sender.send(ModelEvent::Datagram(0)).unwrap(); + sender + .send(ModelEvent::Closed { + peer: 0, + connection_id: 1, + }) + .unwrap(); + }); + } + { + let sender = sender.clone(); + shuttle::thread::spawn(move || { + sender.send(ModelEvent::Datagram(0)).unwrap(); + }); + } + drop(sender); + + let model = drain_model_events(receiver, 2, 2); + assert!(model.banned_datagrams_dropped >= 1); + }, + 1_000, + ); + } + + #[test] + fn shuttle_test_bounded_event_channel_backpressure_does_not_deadlock_model() { + shuttle::check_random( + || { + let (sender, receiver) = shuttle::sync::mpsc::sync_channel(1); + for peer in 0..4 { + let sender = sender.clone(); + shuttle::thread::spawn(move || { + sender + .send(ModelEvent::Accepted { + peer, + connection_id: peer, + }) + .unwrap(); + sender.send(ModelEvent::Datagram(peer)).unwrap(); + sender + .send(ModelEvent::Closed { + peer, + connection_id: peer, + }) + .unwrap(); + }); + } + drop(sender); + + let model = drain_model_events(receiver, 2, 1); + assert!(model.forwarded <= 4); + assert!(model.active_connections <= 2); + }, + 500, + ); + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs index b3f6a506a2b..19fbeafd601 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -13,6 +13,7 @@ pub mod banking_simulation; pub mod banking_stage; pub mod banking_trace; pub(crate) mod block_creation_loop; +mod bls_quic_datagram; pub mod bls_sigverify; pub mod cluster_info_vote_listener; pub mod cluster_slots_service; diff --git a/core/src/staked_nodes_updater_service.rs b/core/src/staked_nodes_updater_service.rs index 471c07b8d7b..657e36bed68 100644 --- a/core/src/staked_nodes_updater_service.rs +++ b/core/src/staked_nodes_updater_service.rs @@ -1,9 +1,10 @@ use { + arc_swap::ArcSwap, solana_pubkey::Pubkey, solana_runtime::bank_forks::BankForks, solana_streamer::streamer::StakedNodes, std::{ - collections::HashMap, + collections::{HashMap, HashSet}, sync::{ Arc, RwLock, atomic::{AtomicBool, Ordering}, @@ -15,16 +16,37 @@ use { const STAKE_REFRESH_CYCLE: Duration = Duration::from_secs(5); +pub type StakedNodePubkeySet = Arc>>; + pub struct StakedNodesUpdaterService { thread_hdl: JoinHandle<()>, } +fn staked_pubkeys( + stakes: &HashMap, + overrides: &HashMap, +) -> HashSet { + let mut pubkeys: HashSet<_> = stakes + .iter() + .filter_map(|(pubkey, stake)| { + (*stake > 0 && !overrides.contains_key(pubkey)).then_some(*pubkey) + }) + .collect(); + pubkeys.extend( + overrides + .iter() + .filter_map(|(pubkey, stake)| (*stake > 0).then_some(*pubkey)), + ); + pubkeys +} + impl StakedNodesUpdaterService { pub fn new( exit: Arc, bank_forks: Arc>, staked_nodes: Arc>, staked_nodes_overrides: Arc>>, + staked_node_pubkeys: Option, ) -> Self { let thread_hdl = Builder::new() .name("solStakedNodeUd".to_string()) @@ -35,6 +57,9 @@ impl StakedNodesUpdaterService { root_bank.current_epoch_staked_nodes() }; let overrides = staked_nodes_overrides.read().unwrap().clone(); + if let Some(staked_node_pubkeys) = &staked_node_pubkeys { + staked_node_pubkeys.store(Arc::new(staked_pubkeys(&stakes, &overrides))); + } *staked_nodes.write().unwrap() = StakedNodes::new(stakes, overrides); std::thread::sleep(STAKE_REFRESH_CYCLE); } diff --git a/core/src/tpu.rs b/core/src/tpu.rs index 2fc899c9bc6..aae1c29247f 100644 --- a/core/src/tpu.rs +++ b/core/src/tpu.rs @@ -19,7 +19,7 @@ use { spawn_forwarding_stage, }, sigverify_stage::SigVerifyStage, - staked_nodes_updater_service::StakedNodesUpdaterService, + staked_nodes_updater_service::{StakedNodePubkeySet, StakedNodesUpdaterService}, tpu_entry_notifier::TpuEntryNotifier, validator::{BlockProductionMethod, GeneratorConfig}, }, @@ -141,6 +141,7 @@ impl Tpu { log_messages_bytes_limit: Option, staked_nodes: &Arc>, shared_staked_nodes_overrides: Arc>>, + staked_node_pubkeys: Option, banking_tracer_channels: Channels, tracer_thread_hdl: TracerThread, tpu_quic_server_config: SwQosQuicStreamerConfig, @@ -189,6 +190,7 @@ impl Tpu { bank_forks.clone(), staked_nodes.clone(), shared_staked_nodes_overrides, + staked_node_pubkeys, ); let Channels { diff --git a/core/src/tvu.rs b/core/src/tvu.rs index c119efc2aa4..6278a455dbc 100644 --- a/core/src/tvu.rs +++ b/core/src/tvu.rs @@ -6,6 +6,7 @@ use { admin_rpc_post_init::{KeyUpdaterType, KeyUpdaters}, banking_trace::BankingTracer, block_creation_loop::{ReplayHighestFrozen, rewards::msg_types::AddVoteMessage}, + bls_quic_datagram::{SpawnBlsQuicDatagramServerResult, spawn_bls_quic_datagram_server}, bls_sigverify::bls_sigverifier::{self, SigVerifierChannels, SigVerifierContext}, cluster_info_vote_listener::{ DuplicateConfirmedSlotsReceiver, GossipVerifiedVoteHashReceiver, @@ -24,6 +25,7 @@ use { }, replay_stage::{ReplayReceivers, ReplaySenders, ReplayStage, ReplayStageConfig}, shred_fetch_stage::{SHRED_FETCH_CHANNEL_SIZE, ShredFetchStage}, + staked_nodes_updater_service::StakedNodePubkeySet, voting_service::VotingService, warm_quic_cache_service::WarmQuicCacheService, window_service::{WindowService, WindowServiceChannels}, @@ -34,7 +36,9 @@ use { generated_cert_types::GeneratedCertTypes, vote_history::VoteHistory, vote_history_storage::VoteHistoryStorage, - voting_service::{VotingService as BLSVotingService, VotingServiceOverride}, + voting_service::{ + VotingService as BLSVotingService, VotingServiceOverride, VotingServiceTransport, + }, votor::{Votor, VotorConfig}, }, agave_votor_messages::consensus_message::Block, @@ -62,7 +66,6 @@ use { rpc_subscriptions::RpcSubscriptions, slot_status_notifier::SlotStatusNotifier, }, solana_runtime::{ - bank::MAX_ALPENGLOW_VOTE_ACCOUNTS, bank_forks::BankForks, bank_forks_controller::{BankForksCommandReceiver, BankForksController}, commitment::BlockCommitmentCache, @@ -71,12 +74,7 @@ use { validated_block_finalization::ValidatedBlockFinalizationCert, vote_sender_types::ReplayVoteSender, }, - solana_streamer::{ - evicting_sender::EvictingSender, - nonblocking::simple_qos::SimpleQosConfig, - quic::{QuicStreamerConfig, SpawnServerResult, spawn_simple_qos_server}, - streamer::StakedNodes, - }, + solana_streamer::evicting_sender::EvictingSender, solana_turbine::{XdpSender, retransmit_stage::RetransmitStage}, std::{ collections::HashSet, @@ -181,11 +179,11 @@ pub struct AlpenglowInitializationState { // For BLS streamer setup pub cancel: CancellationToken, - pub staked_nodes: Arc>, pub key_notifiers: Arc>, + pub bls_datagram_staked_nodes: StakedNodePubkeySet, // For BLS voting service - pub bls_connection_cache: Arc, + pub bls_transport: VotingServiceTransport, pub voting_service_test_override: Option, } @@ -265,9 +263,9 @@ impl Tvu { votor_event_sender, votor_event_receiver, cancel, - staked_nodes, key_notifiers, - bls_connection_cache, + bls_datagram_staked_nodes, + bls_transport, voting_service_test_override, highest_finalized, } = votor_init; @@ -284,38 +282,23 @@ impl Tvu { let bls_sigverify_threads = if let Some(bls_socket) = bls_socket { let (bls_packet_sender, bls_packet_receiver) = bounded(MAX_ALPENGLOW_PACKET_NUM); - let ( - SpawnServerResult { - endpoints: _, - thread: bls_streamer_t, - key_updater: bls_key_updater, - }, - banlist, - ) = { - let quic_server_params = QuicStreamerConfig { - num_threads: NonZeroUsize::new(4.min(num_cpus::get())).unwrap(), - ..Default::default() - }; - let qos_config = SimpleQosConfig { - max_streams_per_second: 30, - // Cap by # of active validators (some overhead for epoch boundaries) - max_staked_connections: MAX_ALPENGLOW_VOTE_ACCOUNTS * 2, - // Two staked connection per validator to account for hotspares - max_connections_per_peer: 2, - }; - spawn_simple_qos_server( - "solQuicBLS", - "quic_streamer_bls", - vec![bls_socket.into()], - &cluster_info.keypair(), - bls_packet_sender, - staked_nodes, - quic_server_params, - qos_config, - cancel, - ) - .unwrap() + let is_staked_peer = { + let staked_nodes = bls_datagram_staked_nodes.clone(); + Arc::new(move |pubkey: &Pubkey| staked_nodes.load().contains(pubkey)) }; + let SpawnBlsQuicDatagramServerResult { + thread: bls_streamer_t, + key_updater: bls_key_updater, + banlist, + } = spawn_bls_quic_datagram_server( + "solQuicBLS", + bls_socket, + &cluster_info.keypair(), + bls_packet_sender, + is_staked_peer, + cancel, + ) + .unwrap(); // sigverifier let sharable_banks = bank_forks.read().unwrap().sharable_banks(); @@ -595,7 +578,7 @@ impl Tvu { bls_receiver, cluster_info.clone(), vote_history_storage, - bls_connection_cache, + bls_transport, bank_forks.clone(), voting_service_test_override, ); @@ -714,7 +697,9 @@ pub mod tests { event::{VotorEventReceiver, VotorEventSender}, vote_history::VoteHistory, vote_history_storage::NullVoteHistoryStorage, + voting_service::{QuicDatagramClientKeyUpdater, QuicDatagramSenderConfig}, }, + arc_swap::ArcSwap, serial_test::serial, solana_gossip::{cluster_info::ClusterInfo, node::Node}, solana_hash::Hash, @@ -799,10 +784,12 @@ pub mod tests { DEFAULT_TPU_CONNECTION_POOL_SIZE, ) }; - let bls_connection_cache = ConnectionCache::new_quic_for_tests( - "connection_cache_bls_quic", - DEFAULT_TPU_CONNECTION_POOL_SIZE, - ); + let bls_datagram_key_updater = + Arc::new(QuicDatagramClientKeyUpdater::new(&cref1.keypair())); + let bls_transport = VotingServiceTransport::QuicDatagram(QuicDatagramSenderConfig { + client_socket: target1.sockets.quic_alpenglow_client, + key_updater: bls_datagram_key_updater, + }); let replay_highest_frozen = Arc::new(ReplayHighestFrozen::default()); let (leader_window_info_sender, _leader_window_info_receiver) = unbounded(); let (optimistic_parent_sender, optimistic_parent_receiver) = unbounded(); @@ -815,8 +802,8 @@ pub mod tests { ))); let (votor_event_sender, votor_event_receiver): (VotorEventSender, VotorEventReceiver) = unbounded(); - let staked_nodes = Arc::new(RwLock::new(StakedNodes::default())); let key_notifiers = Arc::new(RwLock::new(KeyUpdaters::default())); + let bls_datagram_staked_nodes = Arc::new(ArcSwap::from_pointee(HashSet::default())); let cancel = CancellationToken::new(); thread::spawn({ let cancel = cancel.clone(); @@ -899,9 +886,9 @@ pub mod tests { votor_event_sender, votor_event_receiver, cancel, - staked_nodes, key_notifiers, - bls_connection_cache: Arc::new(bls_connection_cache), + bls_datagram_staked_nodes, + bls_transport, voting_service_test_override: None, highest_finalized: Arc::new(RwLock::new(None)), bank_forks_controller, diff --git a/core/src/validator.rs b/core/src/validator.rs index 95d2466f09e..718824834b0 100644 --- a/core/src/validator.rs +++ b/core/src/validator.rs @@ -36,10 +36,14 @@ use { agave_votor::{ vote_history::{VoteHistory, VoteHistoryError}, vote_history_storage::{NullVoteHistoryStorage, VoteHistoryStorage}, - voting_service::VotingServiceOverride, + voting_service::{ + QuicDatagramClientKeyUpdater, QuicDatagramSenderConfig, VotingServiceOverride, + VotingServiceTransport, + }, }, agave_xdp::transmitter::{Transmitter, TransmitterBuilder}, anyhow::{Result, anyhow}, + arc_swap::ArcSwap, crossbeam_channel::{Receiver, bounded, unbounded}, serde::{Deserialize, Serialize}, solana_account::ReadableAccount, @@ -1159,6 +1163,7 @@ impl Validator { let max_slots = Arc::new(MaxSlots::default()); let staked_nodes = Arc::new(RwLock::new(StakedNodes::default())); + let bls_datagram_staked_nodes = Arc::new(ArcSwap::from_pointee(HashSet::default())); let mut tpu_transactions_forwards_client_sockets = Some(node.sockets.tpu_transaction_forwarding_clients); @@ -1187,30 +1192,17 @@ impl Validator { )) }; - let bls_connection_cache = Arc::new(ConnectionCache::new_with_client_options( - "connection_cache_bls_quic", - // BLS consensus messaging is extremely low throughput (5 PPS). Even during standstill operations - // we wouldn't expect more than a 100 PPS. 1 connection is enough. - 1, /* connection_pool_size */ - Some(node.sockets.quic_alpenglow_client), - Some(( - &identity_keypair, - node.info - .alpenglow() - .ok_or_else(|| { - ValidatorError::Other(String::from( - "Invalid QUIC address for Alpenglow BLS", - )) - })? - .ip(), - )), - Some((&staked_nodes, &identity_keypair.pubkey())), - )); + let bls_datagram_key_updater = + Arc::new(QuicDatagramClientKeyUpdater::new(&identity_keypair)); + let bls_transport = VotingServiceTransport::QuicDatagram(QuicDatagramSenderConfig { + client_socket: node.sockets.quic_alpenglow_client, + key_updater: bls_datagram_key_updater.clone(), + }); let key_notifiers = Arc::new(RwLock::new(KeyUpdaters::default())); - key_notifiers.write().unwrap().add( - KeyUpdaterType::BlsConnectionCache, - bls_connection_cache.clone(), - ); + key_notifiers + .write() + .unwrap() + .add(KeyUpdaterType::BlsDatagramClient, bls_datagram_key_updater); // test-validator crate may start the validator in a tokio runtime // context which forces us to use the same runtime because a nested @@ -1648,9 +1640,9 @@ impl Validator { votor_event_sender: votor_event_sender.clone(), votor_event_receiver, cancel: cancel.clone(), - staked_nodes: staked_nodes.clone(), key_notifiers: key_notifiers.clone(), - bls_connection_cache, + bls_datagram_staked_nodes: bls_datagram_staked_nodes.clone(), + bls_transport, voting_service_test_override: config.voting_service_test_override.clone(), highest_finalized, }, @@ -1709,6 +1701,7 @@ impl Validator { config.runtime_config.log_messages_bytes_limit, &staked_nodes, config.staked_nodes_overrides.clone(), + Some(bls_datagram_staked_nodes), banking_tracer_channels, tracer_thread, tpu_quic_server_config, diff --git a/dev-bins/Cargo.lock b/dev-bins/Cargo.lock index 75b2eb92c8d..a12defa208a 100644 --- a/dev-bins/Cargo.lock +++ b/dev-bins/Cargo.lock @@ -349,13 +349,17 @@ dependencies = [ "agave-logger", "agave-math-utils", "agave-votor-messages", + "arc-swap", "bitvec", + "bytes", "crossbeam-channel", "itertools 0.14.0", "lazy-lru", "log", "parking_lot 0.12.5", "qualifier_attr", + "quinn", + "rustls", "serde", "serde_bytes", "solana-accounts-db", @@ -380,11 +384,13 @@ dependencies = [ "solana-signer-store", "solana-streamer", "solana-time-utils", + "solana-tls-utils", "solana-transaction", "solana-transaction-error", "solana-vote", "solana-vote-program", "thiserror 2.0.18", + "tokio", "wincode", ] @@ -6665,6 +6671,7 @@ dependencies = [ "num_cpus", "num_enum", "qualifier_attr", + "quinn", "rand 0.9.4", "rand_chacha 0.9.0", "rayon", diff --git a/programs/sbf/Cargo.lock b/programs/sbf/Cargo.lock index 0cfee5b8e99..bfe7555a1c4 100644 --- a/programs/sbf/Cargo.lock +++ b/programs/sbf/Cargo.lock @@ -337,13 +337,17 @@ dependencies = [ "agave-logger", "agave-math-utils", "agave-votor-messages", + "arc-swap", "bitvec", + "bytes", "crossbeam-channel", "itertools 0.14.0", "lazy-lru", "log", "parking_lot 0.12.2", "qualifier_attr", + "quinn", + "rustls", "serde", "serde_bytes", "solana-accounts-db", @@ -368,11 +372,13 @@ dependencies = [ "solana-signer-store", "solana-streamer", "solana-time-utils", + "solana-tls-utils", "solana-transaction", "solana-transaction-error", "solana-vote", "solana-vote-program", "thiserror 2.0.18", + "tokio", "wincode", ] @@ -6512,6 +6518,7 @@ dependencies = [ "num_cpus", "num_enum", "qualifier_attr", + "quinn", "rand 0.9.4", "rand_chacha 0.9.0", "rayon", diff --git a/validator/src/admin_rpc_service.rs b/validator/src/admin_rpc_service.rs index 13366d1598e..07a7d9106b0 100644 --- a/validator/src/admin_rpc_service.rs +++ b/validator/src/admin_rpc_service.rs @@ -1373,7 +1373,7 @@ mod tests { KeyUpdaterType::Forward, KeyUpdaterType::RpcService, KeyUpdaterType::Bls, - KeyUpdaterType::BlsConnectionCache, + KeyUpdaterType::BlsDatagramClient, ]) ); let mut io = MetaIoHandler::default(); diff --git a/votor/Cargo.toml b/votor/Cargo.toml index 50794e2696d..6b66fd2a8ce 100644 --- a/votor/Cargo.toml +++ b/votor/Cargo.toml @@ -32,13 +32,17 @@ frozen-abi = [ agave-logger = { workspace = true } agave-math-utils = { workspace = true } agave-votor-messages = { workspace = true } +arc-swap = { workspace = true } bitvec = { workspace = true } +bytes = { workspace = true } crossbeam-channel = { workspace = true } itertools = { workspace = true } lazy-lru = { workspace = true } log = { workspace = true } parking_lot = { workspace = true } qualifier_attr = { workspace = true } +quinn = { workspace = true } +rustls = { workspace = true } serde = { workspace = true } serde_bytes = { workspace = true } solana-accounts-db = { workspace = true } @@ -69,11 +73,13 @@ solana-signer = { workspace = true } solana-signer-store = { workspace = true } solana-streamer = { workspace = true } solana-time-utils = { workspace = true } +solana-tls-utils = { workspace = true } solana-transaction = { workspace = true } solana-transaction-error = { workspace = true } solana-vote = { workspace = true } solana-vote-program = { workspace = true } thiserror = { workspace = true } +tokio = { workspace = true, features = ["full"] } wincode = { workspace = true, features = ["alloc"] } [dev-dependencies] diff --git a/votor/src/lib.rs b/votor/src/lib.rs index e3cf0a52fa9..be58bab6c01 100644 --- a/votor/src/lib.rs +++ b/votor/src/lib.rs @@ -12,6 +12,7 @@ mod consensus_pool_service; pub mod event; mod event_handler; pub mod generated_cert_types; +mod quic_datagram_sender; pub mod root_utils; mod staked_validators_cache; mod timer_manager; diff --git a/votor/src/quic_datagram_sender.rs b/votor/src/quic_datagram_sender.rs new file mode 100644 index 00000000000..112bae43aeb --- /dev/null +++ b/votor/src/quic_datagram_sender.rs @@ -0,0 +1,417 @@ +use { + arc_swap::ArcSwap, + bytes::Bytes, + quinn::{ + ClientConfig, ConnectError, Connection, ConnectionError, Endpoint, EndpointConfig, + IdleTimeout, SendDatagramError, TokioRuntime, TransportConfig, + crypto::rustls::QuicClientConfig, + }, + rustls::KeyLogFile, + solana_keypair::Keypair, + solana_streamer::{nonblocking::quic::ALPN_TPU_PROTOCOL_ID, packet::PACKET_DATA_SIZE}, + solana_tls_utils::{ + NotifyKeyUpdate, QuicClientCertificate, socket_addr_to_quic_server_name, + tls_client_config_builder, + }, + std::{ + collections::HashMap, + net::{SocketAddr, UdpSocket}, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, + time::Duration, + }, + tokio::{task::JoinHandle, time::timeout}, +}; + +const QUIC_MAX_TIMEOUT: Duration = Duration::from_secs(10); +const QUIC_KEEP_ALIVE: Duration = Duration::from_secs(1); +const QUIC_CONNECTION_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(2); +const DATAGRAM_RECEIVE_BUFFER_SIZE: usize = PACKET_DATA_SIZE * 64; +const CONNECTION_CLOSE_CODE: u32 = 0; +const CONNECTION_CLOSE_REASON_KEY_UPDATE: &[u8] = b"key_update"; + +pub struct QuicDatagramSenderConfig { + pub client_socket: UdpSocket, + pub key_updater: Arc, +} + +pub struct QuicDatagramClientKeyUpdater { + client_certificate: ArcSwap, + dirty: AtomicBool, +} + +impl QuicDatagramClientKeyUpdater { + pub fn new(keypair: &Keypair) -> Self { + Self { + client_certificate: ArcSwap::new(Arc::new(QuicClientCertificate::new(Some(keypair)))), + dirty: AtomicBool::new(false), + } + } + + fn client_config(&self) -> ClientConfig { + let client_certificate = self.client_certificate.load_full(); + client_config(client_certificate.as_ref()) + } + + fn take_dirty(&self) -> bool { + self.dirty.swap(false, Ordering::AcqRel) + } +} + +impl NotifyKeyUpdate for QuicDatagramClientKeyUpdater { + fn update_key(&self, key: &Keypair) -> Result<(), Box> { + self.client_certificate + .store(Arc::new(QuicClientCertificate::new(Some(key)))); + self.dirty.store(true, Ordering::Release); + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum QuicDatagramSenderError { + #[error("QUIC endpoint creation failed: {0}")] + EndpointFailed(#[from] std::io::Error), +} + +#[derive(Debug, thiserror::Error)] +enum ConnectAndSendError { + #[error("connect failed: {0}")] + Connect(#[from] ConnectError), + #[error("connection failed: {0}")] + Connection(#[from] ConnectionError), + #[error("datagram send failed: {0}")] + SendDatagram(#[from] SendDatagramError), + #[error("connection timed out")] + Timeout, +} + +enum PeerConnection { + Connected(Connection), + Connecting(JoinHandle>), +} + +pub(crate) struct QuicDatagramSender { + endpoint: Endpoint, + key_updater: Arc, + connections: HashMap, +} + +impl QuicDatagramSender { + pub(crate) fn new(config: QuicDatagramSenderConfig) -> Result { + let QuicDatagramSenderConfig { + client_socket, + key_updater, + } = config; + let mut endpoint = Endpoint::new( + EndpointConfig::default(), + None, + client_socket, + Arc::new(TokioRuntime), + )?; + endpoint.set_default_client_config(key_updater.client_config()); + Ok(Self { + endpoint, + key_updater, + connections: HashMap::new(), + }) + } + + pub(crate) async fn send(&mut self, addr: SocketAddr, data: &[u8]) { + self.apply_key_update_if_needed(); + + let data = Bytes::copy_from_slice(data); + match self.connections.remove(&addr) { + Some(PeerConnection::Connected(connection)) => { + self.send_on_connection(addr, connection, data); + } + Some(PeerConnection::Connecting(handle)) => { + if handle.is_finished() { + match handle.await { + Ok(Ok(connection)) => self.send_on_connection(addr, connection, data), + Ok(Err(err)) => { + warn!("Failed to connect votor QUIC datagram peer {addr}: {err}"); + self.start_connecting(addr, data); + } + Err(err) => { + warn!("Votor QUIC datagram connect task failed for {addr}: {err}"); + self.start_connecting(addr, data); + } + } + } else { + debug!("dropping votor QUIC datagram to {addr}: connection in progress"); + self.connections + .insert(addr, PeerConnection::Connecting(handle)); + } + } + None => self.start_connecting(addr, data), + } + } + + fn apply_key_update_if_needed(&mut self) { + if !self.key_updater.take_dirty() { + return; + } + self.endpoint + .set_default_client_config(self.key_updater.client_config()); + for (_, connection) in self.connections.drain() { + match connection { + PeerConnection::Connected(connection) => connection.close( + CONNECTION_CLOSE_CODE.into(), + CONNECTION_CLOSE_REASON_KEY_UPDATE, + ), + PeerConnection::Connecting(handle) => handle.abort(), + } + } + } + + fn send_on_connection(&mut self, addr: SocketAddr, connection: Connection, data: Bytes) { + match connection.send_datagram(data.clone()) { + Ok(()) => { + self.connections + .insert(addr, PeerConnection::Connected(connection)); + } + Err(SendDatagramError::ConnectionLost(err)) => { + warn!("Lost votor QUIC datagram connection to {addr}: {err}"); + self.start_connecting(addr, data); + } + Err(err) => { + warn!("Failed to send votor QUIC datagram to {addr}: {err}"); + self.connections + .insert(addr, PeerConnection::Connected(connection)); + } + } + } + + fn start_connecting(&mut self, addr: SocketAddr, data: Bytes) { + let endpoint = self.endpoint.clone(); + let handle = tokio::spawn(connect_and_send(endpoint, addr, data)); + self.connections + .insert(addr, PeerConnection::Connecting(handle)); + } +} + +fn client_config(client_certificate: &QuicClientCertificate) -> ClientConfig { + let mut crypto = tls_client_config_builder() + .with_client_auth_cert( + vec![client_certificate.certificate.clone()], + client_certificate.key.clone_key(), + ) + .expect("valid votor QUIC client certificate"); + crypto.enable_early_data = true; + crypto.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()]; + crypto.key_log = Arc::new(KeyLogFile::new()); + + let mut transport_config = TransportConfig::default(); + transport_config.max_idle_timeout(Some(IdleTimeout::try_from(QUIC_MAX_TIMEOUT).unwrap())); + transport_config.keep_alive_interval(Some(QUIC_KEEP_ALIVE)); + transport_config.send_fairness(false); + transport_config.datagram_receive_buffer_size(Some(DATAGRAM_RECEIVE_BUFFER_SIZE)); + + let mut config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(crypto).unwrap())); + config.transport_config(Arc::new(transport_config)); + config +} + +async fn connect_and_send( + endpoint: Endpoint, + addr: SocketAddr, + data: Bytes, +) -> Result { + let server_name = socket_addr_to_quic_server_name(addr); + let connecting = endpoint.connect(addr, &server_name)?; + let connection = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting) + .await + .map_err(|_| ConnectAndSendError::Timeout)??; + connection.send_datagram(data)?; + Ok(connection) +} + +#[cfg(test)] +mod tests { + use { + super::*, + quinn::{ServerConfig, crypto::rustls::QuicServerConfig}, + solana_net_utils::sockets::bind_to_localhost_unique, + solana_signer::Signer, + solana_tls_utils::{ + get_remote_pubkey, new_dummy_x509_certificate, tls_server_config_builder, + }, + }; + + #[test] + fn test_send_drops_datagrams_while_connecting() { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + runtime.block_on(async { + let server_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let endpoint = Endpoint::new( + EndpointConfig::default(), + Some(test_server_config(&server_keypair)), + server_socket, + Arc::new(TokioRuntime), + ) + .unwrap(); + let client_keypair = Keypair::new(); + let mut sender = QuicDatagramSender::new(QuicDatagramSenderConfig { + client_socket: bind_to_localhost_unique().unwrap(), + key_updater: Arc::new(QuicDatagramClientKeyUpdater::new(&client_keypair)), + }) + .unwrap(); + + sender.send(server_addr, b"first").await; + sender.send(server_addr, b"second").await; + + let incoming = endpoint.accept().await.unwrap(); + let connection = timeout(Duration::from_secs(5), incoming) + .await + .unwrap() + .unwrap(); + let first = timeout(Duration::from_secs(5), connection.read_datagram()) + .await + .unwrap() + .unwrap(); + let second = timeout(Duration::from_millis(200), connection.read_datagram()).await; + + assert_eq!(first.as_ref(), b"first"); + assert!(second.is_err()); + }); + } + + #[test] + fn test_send_drops_all_datagrams_while_connecting() { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + runtime.block_on(async { + let server_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let endpoint = Endpoint::new( + EndpointConfig::default(), + Some(test_server_config(&server_keypair)), + server_socket, + Arc::new(TokioRuntime), + ) + .unwrap(); + let client_keypair = Keypair::new(); + let mut sender = QuicDatagramSender::new(QuicDatagramSenderConfig { + client_socket: bind_to_localhost_unique().unwrap(), + key_updater: Arc::new(QuicDatagramClientKeyUpdater::new(&client_keypair)), + }) + .unwrap(); + + for i in 0..32usize { + sender + .send(server_addr, format!("queued-while-dialing-{i}").as_bytes()) + .await; + } + + let incoming = endpoint.accept().await.unwrap(); + let connection = timeout(Duration::from_secs(5), incoming) + .await + .unwrap() + .unwrap(); + let first = timeout(Duration::from_secs(5), connection.read_datagram()) + .await + .unwrap() + .unwrap(); + let second = timeout(Duration::from_millis(200), connection.read_datagram()).await; + + assert_eq!(first.as_ref(), b"queued-while-dialing-0"); + assert!(second.is_err()); + }); + } + + #[test] + fn test_key_update_closes_cached_connection_and_reconnects_with_new_identity() { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + runtime.block_on(async { + let server_keypair = Keypair::new(); + let server_socket = bind_to_localhost_unique().unwrap(); + let server_addr = server_socket.local_addr().unwrap(); + let endpoint = Endpoint::new( + EndpointConfig::default(), + Some(test_server_config(&server_keypair)), + server_socket, + Arc::new(TokioRuntime), + ) + .unwrap(); + let old_client_keypair = Keypair::new(); + let new_client_keypair = Keypair::new(); + let key_updater = Arc::new(QuicDatagramClientKeyUpdater::new(&old_client_keypair)); + let mut sender = QuicDatagramSender::new(QuicDatagramSenderConfig { + client_socket: bind_to_localhost_unique().unwrap(), + key_updater: key_updater.clone(), + }) + .unwrap(); + + sender.send(server_addr, b"before-rotate").await; + let incoming = endpoint.accept().await.unwrap(); + let old_connection = timeout(Duration::from_secs(5), incoming) + .await + .unwrap() + .unwrap(); + let first = timeout(Duration::from_secs(5), old_connection.read_datagram()) + .await + .unwrap() + .unwrap(); + assert_eq!(first.as_ref(), b"before-rotate"); + assert_eq!( + get_remote_pubkey(&old_connection), + Some(old_client_keypair.pubkey()) + ); + + key_updater.update_key(&new_client_keypair).unwrap(); + sender.send(server_addr, b"after-rotate").await; + + timeout(Duration::from_secs(5), old_connection.closed()) + .await + .expect("old connection should be closed on key update"); + let incoming = endpoint.accept().await.unwrap(); + let new_connection = timeout(Duration::from_secs(5), incoming) + .await + .unwrap() + .unwrap(); + let second = timeout(Duration::from_secs(5), new_connection.read_datagram()) + .await + .unwrap() + .unwrap(); + assert_eq!(second.as_ref(), b"after-rotate"); + assert_eq!( + get_remote_pubkey(&new_connection), + Some(new_client_keypair.pubkey()) + ); + }); + } + + fn test_server_config(keypair: &Keypair) -> ServerConfig { + let (cert, priv_key) = new_dummy_x509_certificate(keypair); + let mut server_tls_config = tls_server_config_builder() + .with_single_cert(vec![cert], priv_key) + .unwrap(); + server_tls_config.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()]; + let mut server_config = ServerConfig::with_crypto(Arc::new( + QuicServerConfig::try_from(server_tls_config).unwrap(), + )); + server_config.migration(false); + let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); + transport_config.max_concurrent_uni_streams(0u32.into()); + transport_config.max_concurrent_bidi_streams(0u32.into()); + transport_config.datagram_receive_buffer_size(Some(DATAGRAM_RECEIVE_BUFFER_SIZE)); + transport_config.max_idle_timeout(Some(IdleTimeout::try_from(QUIC_MAX_TIMEOUT).unwrap())); + server_config + } +} diff --git a/votor/src/voting_service.rs b/votor/src/voting_service.rs index fb0202b1cee..4eafe60caac 100644 --- a/votor/src/voting_service.rs +++ b/votor/src/voting_service.rs @@ -1,5 +1,6 @@ use { crate::{ + quic_datagram_sender::{QuicDatagramSender, QuicDatagramSenderError}, staked_validators_cache::StakedValidatorsCache, vote_history_storage::{SavedVoteHistoryVersions, VoteHistoryStorage}, }, @@ -7,15 +8,12 @@ use { certificate::Certificate, consensus_message::{ConsensusMessage, VoteMessage}, }, - crossbeam_channel::Receiver, - solana_client::connection_cache::ConnectionCache, + crossbeam_channel::{Receiver, TryRecvError}, solana_clock::Slot, - solana_connection_cache::client_connection::ClientConnection, solana_gossip::cluster_info::ClusterInfo, solana_measure::measure::Measure, solana_pubkey::Pubkey, solana_runtime::bank_forks::BankForks, - solana_transaction_error::TransportError, std::{ collections::HashMap, net::SocketAddr, @@ -42,14 +40,10 @@ pub enum BLSOp { }, } -fn send_message( - buf: Vec, - socket: &SocketAddr, - connection_cache: &Arc, -) -> Result<(), TransportError> { - let client = connection_cache.get_connection(socket); +pub use crate::quic_datagram_sender::{QuicDatagramClientKeyUpdater, QuicDatagramSenderConfig}; - client.send_data_async(Arc::new(buf)) +pub enum VotingServiceTransport { + QuicDatagram(QuicDatagramSenderConfig), } pub struct VotingService { @@ -122,7 +116,7 @@ impl VotingService { bls_receiver: Receiver, cluster_info: Arc, vote_history_storage: Arc, - connection_cache: Arc, + transport: VotingServiceTransport, bank_forks: Arc>, test_override: Option, ) -> Self { @@ -137,36 +131,96 @@ impl VotingService { let thread_hdl = Builder::new() .name("solVotorVoteSvc".to_string()) .spawn(move || { - let mut staked_validators_cache = StakedValidatorsCache::new( - bank_forks.clone(), + let staked_validators_cache = StakedValidatorsCache::new( + bank_forks, Duration::from_secs(STAKED_VALIDATORS_CACHE_TTL_S), STAKED_VALIDATORS_CACHE_NUM_EPOCH_TARGET, false, alpenglow_port_override, ); - info!("AlpenglowVotingService has started"); - while let Ok(bls_op) = bls_receiver.recv() { - Self::handle_bls_op( - &cluster_info, - vote_history_storage.as_ref(), - bls_op, - connection_cache.clone(), - &additional_listeners, - &mut staked_validators_cache, - ); + match transport { + VotingServiceTransport::QuicDatagram(config) => Self::run_quic_datagram_loop( + config, + bls_receiver, + cluster_info, + vote_history_storage, + additional_listeners, + staked_validators_cache, + ), } - info!("AlpenglowVotingService has stopped"); }) .unwrap(); Self { thread_hdl } } - fn broadcast_consensus_message( + fn run_quic_datagram_loop( + config: QuicDatagramSenderConfig, + bls_receiver: Receiver, + cluster_info: Arc, + vote_history_storage: Arc, + additional_listeners: Vec, + mut staked_validators_cache: StakedValidatorsCache, + ) { + let runtime = tokio::runtime::Builder::new_current_thread() + .thread_name("solVotorVoteRt") + .enable_all() + .build() + .unwrap(); + let guard = runtime.enter(); + let mut sender = match QuicDatagramSender::new(config) { + Ok(sender) => sender, + Err(err) => { + Self::log_datagram_sender_start_error(err); + return; + } + }; + drop(guard); + + runtime.block_on(async move { + info!("AlpenglowVotingService has started"); + loop { + let mut handled_message = false; + loop { + match bls_receiver.try_recv() { + Ok(bls_op) => { + handled_message = true; + Self::handle_bls_op( + &cluster_info, + vote_history_storage.as_ref(), + bls_op, + &mut sender, + &additional_listeners, + &mut staked_validators_cache, + ) + .await; + } + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => { + info!("AlpenglowVotingService has stopped"); + return; + } + } + } + + if handled_message { + tokio::task::yield_now().await; + } else { + tokio::time::sleep(Duration::from_millis(1)).await; + } + } + }); + } + + fn log_datagram_sender_start_error(err: QuicDatagramSenderError) { + error!("Failed to start AlpenglowVotingService QUIC datagram sender: {err}"); + } + + async fn broadcast_consensus_message( slot: Slot, cluster_info: &ClusterInfo, message: &ConsensusMessage, - connection_cache: Arc, + sender: &mut QuicDatagramSender, additional_listeners: &[SocketAddr], staked_validators_cache: &mut StakedValidatorsCache, ) { @@ -184,21 +238,16 @@ impl VotingService { .iter() .chain(staked_validator_alpenglow_sockets.iter()); - // We use send_message in a loop right now because we worry that sending packets too fast - // will cause a packet spike and overwhelm the network. If we later find out that this is - // not an issue, we can optimize this by using multi_targret_send or similar methods. for socket in sockets { - if let Err(e) = send_message(buf.clone(), socket, &connection_cache) { - warn!("Failed to send alpenglow message to {socket}: {e:?}"); - } + sender.send(*socket, &buf).await; } } - fn handle_bls_op( + async fn handle_bls_op( cluster_info: &ClusterInfo, vote_history_storage: &dyn VoteHistoryStorage, bls_op: BLSOp, - connection_cache: Arc, + sender: &mut QuicDatagramSender, additional_listeners: &[SocketAddr], staked_validators_cache: &mut StakedValidatorsCache, ) { @@ -220,10 +269,11 @@ impl VotingService { slot, cluster_info, &msg, - connection_cache, + sender, additional_listeners, staked_validators_cache, - ); + ) + .await; } BLSOp::PushCertificate { certificate } => { let slot = certificate.cert_type.slot(); @@ -232,10 +282,11 @@ impl VotingService { slot, cluster_info, &message, - connection_cache, + sender, additional_listeners, staked_validators_cache, - ); + ) + .await; } } } @@ -257,6 +308,11 @@ mod tests { consensus_message::{ConsensusMessage, VoteMessage}, vote::Vote, }, + quinn::{ + Endpoint, EndpointConfig, IdleTimeout, ServerConfig, TokioRuntime, + crypto::rustls::QuicServerConfig, + }, + rustls::KeyLogFile, solana_bls_signatures::{BLS_SIGNATURE_AFFINE_SIZE, Signature as BLSSignature}, solana_gossip::{cluster_info::ClusterInfo, contact_info::ContactInfo}, solana_keypair::Keypair, @@ -269,24 +325,21 @@ mod tests { }, }, solana_signer::Signer, - solana_streamer::{ - nonblocking::swqos::SwQosConfig, - quic::{QuicStreamerConfig, SpawnServerResult, spawn_stake_weighted_qos_server}, - streamer::StakedNodes, - }, + solana_streamer::{nonblocking::quic::ALPN_TPU_PROTOCOL_ID, packet::PACKET_DATA_SIZE}, + solana_tls_utils::{new_dummy_x509_certificate, tls_server_config_builder}, std::{ - net::SocketAddr, - sync::{Arc, RwLock}, + net::{SocketAddr, UdpSocket}, + sync::Arc, + thread, + time::Duration, }, test_case::test_case, - tokio_util::sync::CancellationToken, }; fn create_voting_service( bls_receiver: Receiver, listener: SocketAddr, ) -> (VotingService, Vec) { - // Create 10 node validatorvotekeypairs vec let validator_keypairs = (0..10) .map(|_| ValidatorVoteKeypairs::new_rand()) .collect::>(); @@ -299,21 +352,23 @@ mod tests { let bank_forks = BankForks::new_rw_arc(bank0); let keypair = Keypair::new(); let contact_info = ContactInfo::new_localhost(&keypair.pubkey(), 0); + let key_updater = Arc::new(QuicDatagramClientKeyUpdater::new(&keypair)); let cluster_info = ClusterInfo::new( contact_info, Arc::new(keypair), SocketAddrSpace::Unspecified, ); + let transport = VotingServiceTransport::QuicDatagram(QuicDatagramSenderConfig { + client_socket: bind_to_localhost_unique().unwrap(), + key_updater, + }); ( VotingService::new( bls_receiver, Arc::new(cluster_info), Arc::new(NullVoteHistoryStorage::default()), - Arc::new(ConnectionCache::new_quic( - "TestAlpenglowConnectionCache", - 10, - )), + transport, bank_forks, Some(VotingServiceOverride { additional_listeners: vec![listener], @@ -324,6 +379,56 @@ mod tests { ) } + fn spawn_quic_datagram_receiver(socket: UdpSocket) -> crossbeam_channel::Receiver> { + let (sender, receiver) = crossbeam_channel::bounded(1); + let (ready_sender, ready_receiver) = crossbeam_channel::bounded(1); + thread::spawn(move || { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + let guard = runtime.enter(); + let endpoint = Endpoint::new( + EndpointConfig::default(), + Some(test_server_config(&Keypair::new())), + socket, + Arc::new(TokioRuntime), + ) + .unwrap(); + drop(guard); + ready_sender.send(()).unwrap(); + runtime.block_on(async move { + let incoming = endpoint.accept().await.unwrap(); + let connection = incoming.await.unwrap(); + let datagram = connection.read_datagram().await.unwrap(); + sender.send(datagram.to_vec()).unwrap(); + }); + }); + ready_receiver.recv().unwrap(); + receiver + } + + fn test_server_config(keypair: &Keypair) -> ServerConfig { + let (cert, priv_key) = new_dummy_x509_certificate(keypair); + let mut server_tls_config = tls_server_config_builder() + .with_single_cert(vec![cert], priv_key) + .unwrap(); + server_tls_config.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()]; + server_tls_config.key_log = Arc::new(KeyLogFile::new()); + let mut server_config = ServerConfig::with_crypto(Arc::new( + QuicServerConfig::try_from(server_tls_config).unwrap(), + )); + server_config.migration(false); + let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); + transport_config.max_concurrent_uni_streams(0u32.into()); + transport_config.max_concurrent_bidi_streams(0u32.into()); + transport_config.datagram_receive_buffer_size(Some(PACKET_DATA_SIZE * 64)); + transport_config.max_idle_timeout(Some( + IdleTimeout::try_from(Duration::from_secs(30)).unwrap(), + )); + server_config + } + #[test_case(BLSOp::PushVote { vote: Arc::new(VoteMessage { vote: Vote::new_skip_vote(5), @@ -350,51 +455,19 @@ mod tests { fn test_send_message(bls_op: BLSOp, expected_message: ConsensusMessage) { agave_logger::setup(); let (bls_sender, bls_receiver) = crossbeam_channel::unbounded(); - // Create listener thread on a random port we allocated and return SocketAddr to create VotingService - - // Bind to a random UDP port let socket = bind_to_localhost_unique().unwrap(); let listener_addr = socket.local_addr().unwrap(); + let datagram_receiver = spawn_quic_datagram_receiver(socket); + let (voting_service, _validator_keypairs) = + create_voting_service(bls_receiver, listener_addr); - // Create VotingService with the listener address - let (_, validator_keypairs) = create_voting_service(bls_receiver, listener_addr); - - // Send a BLS message via the VotingService assert!(bls_sender.send(bls_op).is_ok()); - // Start a quick streamer to handle quick control packets - let (sender, receiver) = crossbeam_channel::unbounded(); - let stakes = validator_keypairs - .iter() - .map(|x| (x.node_keypair.pubkey(), 100)) - .collect(); - let staked_nodes: Arc> = Arc::new(RwLock::new(StakedNodes::new( - Arc::new(stakes), - HashMap::::default(), // overrides - ))); - let cancel = CancellationToken::new(); - let SpawnServerResult { - endpoints: _, - thread: quic_server_thread, - key_updater: _, - } = spawn_stake_weighted_qos_server( - "AlpenglowLocalClusterTest", - "voting_service_test", - [socket.into()], - &Keypair::new(), - sender, - staked_nodes, - QuicStreamerConfig::default_for_tests(), - SwQosConfig::default(), - cancel.clone(), - ) - .unwrap(); - - let packets = receiver.recv().unwrap(); - let packet = packets.first().expect("No packets received"); - let received_message = packet - .deserialize_slice::(..) - .unwrap_or_else(|err| { + let datagram = datagram_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("No datagram received"); + let received_message = + wincode::deserialize::(&datagram).unwrap_or_else(|err| { panic!( "Failed to deserialize BLSMessage: {:?} {:?}", size_of::(), @@ -402,7 +475,7 @@ mod tests { ) }); assert_eq!(received_message, expected_message); - cancel.cancel(); - quic_server_thread.join().unwrap(); + drop(bls_sender); + voting_service.join().unwrap(); } }