diff --git a/programs/sbf/Cargo.lock b/programs/sbf/Cargo.lock index 8b8c48d7db42f6..c56a532f4ae680 100644 --- a/programs/sbf/Cargo.lock +++ b/programs/sbf/Cargo.lock @@ -6293,6 +6293,7 @@ dependencies = [ "rand 0.8.5", "rustls", "smallvec", + "socket2 0.5.7", "solana-measure", "solana-metrics", "solana-perf", diff --git a/streamer/Cargo.toml b/streamer/Cargo.toml index 8ffa23299ba5b2..b6051bc604451f 100644 --- a/streamer/Cargo.toml +++ b/streamer/Cargo.toml @@ -29,6 +29,7 @@ quinn-proto = { workspace = true } rand = { workspace = true } rustls = { workspace = true, features = ["dangerous_configuration"] } smallvec = { workspace = true } +socket2 = { workspace = true } solana-measure = { workspace = true } solana-metrics = { workspace = true } solana-perf = { workspace = true } @@ -40,7 +41,6 @@ x509-parser = { workspace = true } [dev-dependencies] assert_matches = { workspace = true } -socket2 = { workspace = true } solana-logger = { workspace = true } [lib] diff --git a/streamer/src/nonblocking/mod.rs b/streamer/src/nonblocking/mod.rs index 9eed5e402c5f25..d7205e42468235 100644 --- a/streamer/src/nonblocking/mod.rs +++ b/streamer/src/nonblocking/mod.rs @@ -5,3 +5,4 @@ pub mod rate_limiter; pub mod recvmmsg; pub mod sendmmsg; mod stream_throttle; +pub mod testing_utilities; diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs index 0604973526fb55..182d54630beec9 100644 --- a/streamer/src/nonblocking/quic.rs +++ b/streamer/src/nonblocking/quic.rs @@ -1439,159 +1439,20 @@ pub mod test { use { super::*, crate::{ - nonblocking::quic::compute_max_allowed_uni_streams, + nonblocking::{ + quic::compute_max_allowed_uni_streams, + testing_utilities::{get_client_config, make_client_endpoint, setup_quic_server}, + }, quic::{MAX_STAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS}, - tls_certificates::new_dummy_x509_certificate, }, assert_matches::assert_matches, async_channel::unbounded as async_unbounded, crossbeam_channel::{unbounded, Receiver}, - quinn::{ClientConfig, IdleTimeout, TransportConfig}, - solana_sdk::{ - net::DEFAULT_TPU_COALESCE, - quic::{QUIC_KEEP_ALIVE, QUIC_MAX_TIMEOUT}, - signature::Keypair, - signer::Signer, - }, + solana_sdk::{net::DEFAULT_TPU_COALESCE, signature::Keypair, signer::Signer}, std::collections::HashMap, tokio::time::sleep, }; - struct SkipServerVerification; - - impl SkipServerVerification { - fn new() -> Arc { - Arc::new(Self) - } - } - - impl rustls::client::ServerCertVerifier for SkipServerVerification { - fn verify_server_cert( - &self, - _end_entity: &rustls::Certificate, - _intermediates: &[rustls::Certificate], - _server_name: &rustls::ServerName, - _scts: &mut dyn Iterator, - _ocsp_response: &[u8], - _now: std::time::SystemTime, - ) -> Result { - Ok(rustls::client::ServerCertVerified::assertion()) - } - } - - pub fn get_client_config(keypair: &Keypair) -> ClientConfig { - let (cert, key) = new_dummy_x509_certificate(keypair); - - let mut crypto = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_custom_certificate_verifier(SkipServerVerification::new()) - .with_client_auth_cert(vec![cert], key) - .expect("Failed to use client certificate"); - - crypto.enable_early_data = true; - crypto.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()]; - - let mut config = ClientConfig::new(Arc::new(crypto)); - - let mut transport_config = TransportConfig::default(); - let timeout = IdleTimeout::try_from(QUIC_MAX_TIMEOUT).unwrap(); - transport_config.max_idle_timeout(Some(timeout)); - transport_config.keep_alive_interval(Some(QUIC_KEEP_ALIVE)); - config.transport_config(Arc::new(transport_config)); - - config - } - - fn setup_quic_server( - option_staked_nodes: Option, - max_connections_per_peer: usize, - ) -> ( - JoinHandle<()>, - Arc, - crossbeam_channel::Receiver, - SocketAddr, - Arc, - ) { - let sockets = { - #[cfg(not(target_os = "windows"))] - { - use std::{ - os::fd::{FromRawFd, IntoRawFd}, - str::FromStr as _, - }; - (0..10) - .map(|_| { - let sock = socket2::Socket::new( - socket2::Domain::IPV4, - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - ) - .unwrap(); - sock.set_reuse_port(true).unwrap(); - sock.bind(&SocketAddr::from_str("127.0.0.1:0").unwrap().into()) - .unwrap(); - unsafe { UdpSocket::from_raw_fd(sock.into_raw_fd()) } - }) - .collect::>() - } - #[cfg(target_os = "windows")] - { - vec![UdpSocket::bind("127.0.0.1:0").unwrap()] - } - }; - - let exit = Arc::new(AtomicBool::new(false)); - let (sender, receiver) = unbounded(); - let keypair = Keypair::new(); - let server_address = sockets[0].local_addr().unwrap(); - let staked_nodes = Arc::new(RwLock::new(option_staked_nodes.unwrap_or_default())); - let SpawnNonBlockingServerResult { - endpoints: _, - stats, - thread: t, - max_concurrent_connections: _, - } = spawn_server_multi( - "quic_streamer_test", - sockets, - &keypair, - sender, - exit.clone(), - max_connections_per_peer, - staked_nodes, - MAX_STAKED_CONNECTIONS, - MAX_UNSTAKED_CONNECTIONS, - DEFAULT_MAX_STREAMS_PER_MS, - DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE, - Duration::from_secs(2), - DEFAULT_TPU_COALESCE, - ) - .unwrap(); - (t, exit, receiver, server_address, stats) - } - - pub async fn make_client_endpoint( - addr: &SocketAddr, - client_keypair: Option<&Keypair>, - ) -> Connection { - let client_socket = UdpSocket::bind("127.0.0.1:0").unwrap(); - let mut endpoint = quinn::Endpoint::new( - EndpointConfig::default(), - None, - client_socket, - Arc::new(TokioRuntime), - ) - .unwrap(); - let default_keypair = Keypair::new(); - endpoint.set_default_client_config(get_client_config( - client_keypair.unwrap_or(&default_keypair), - )); - endpoint - .connect(*addr, "localhost") - .expect("Failed in connecting") - .await - .expect("Failed in waiting") - } - pub async fn check_timeout(receiver: Receiver, server_address: SocketAddr) { let conn1 = make_client_endpoint(&server_address, None).await; let total = 30; diff --git a/streamer/src/nonblocking/testing_utilities.rs b/streamer/src/nonblocking/testing_utilities.rs new file mode 100644 index 00000000000000..140cca10b5f911 --- /dev/null +++ b/streamer/src/nonblocking/testing_utilities.rs @@ -0,0 +1,161 @@ +//! Contains utility functions to create server and client for test purposes. +use { + super::quic::{ + spawn_server_multi, SpawnNonBlockingServerResult, ALPN_TPU_PROTOCOL_ID, + DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE, DEFAULT_MAX_STREAMS_PER_MS, + }, + crate::{ + quic::{StreamerStats, MAX_STAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS}, + streamer::StakedNodes, + tls_certificates::new_dummy_x509_certificate, + }, + crossbeam_channel::unbounded, + quinn::{ClientConfig, Connection, EndpointConfig, IdleTimeout, TokioRuntime, TransportConfig}, + solana_perf::packet::PacketBatch, + solana_sdk::{ + net::DEFAULT_TPU_COALESCE, + quic::{QUIC_KEEP_ALIVE, QUIC_MAX_TIMEOUT}, + signer::keypair::Keypair, + }, + std::{ + net::{SocketAddr, UdpSocket}, + sync::{atomic::AtomicBool, Arc, RwLock}, + time::Duration, + }, + tokio::task::JoinHandle, +}; + +struct SkipServerVerification; + +impl SkipServerVerification { + fn new() -> Arc { + Arc::new(Self) + } +} + +impl rustls::client::ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &rustls::Certificate, + _intermediates: &[rustls::Certificate], + _server_name: &rustls::ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: std::time::SystemTime, + ) -> Result { + Ok(rustls::client::ServerCertVerified::assertion()) + } +} + +pub fn get_client_config(keypair: &Keypair) -> ClientConfig { + let (cert, key) = new_dummy_x509_certificate(keypair); + + let mut crypto = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(SkipServerVerification::new()) + .with_client_auth_cert(vec![cert], key) + .expect("Failed to use client certificate"); + + crypto.enable_early_data = true; + crypto.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()]; + + let mut config = ClientConfig::new(Arc::new(crypto)); + + let mut transport_config = TransportConfig::default(); + let timeout = IdleTimeout::try_from(QUIC_MAX_TIMEOUT).unwrap(); + transport_config.max_idle_timeout(Some(timeout)); + transport_config.keep_alive_interval(Some(QUIC_KEEP_ALIVE)); + config.transport_config(Arc::new(transport_config)); + + config +} + +pub fn setup_quic_server( + option_staked_nodes: Option, + max_connections_per_peer: usize, +) -> ( + JoinHandle<()>, + Arc, + crossbeam_channel::Receiver, + SocketAddr, + Arc, +) { + let sockets = { + #[cfg(not(target_os = "windows"))] + { + use std::{ + os::fd::{FromRawFd, IntoRawFd}, + str::FromStr as _, + }; + (0..10) + .map(|_| { + let sock = socket2::Socket::new( + socket2::Domain::IPV4, + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + ) + .unwrap(); + sock.set_reuse_port(true).unwrap(); + sock.bind(&SocketAddr::from_str("127.0.0.1:0").unwrap().into()) + .unwrap(); + unsafe { UdpSocket::from_raw_fd(sock.into_raw_fd()) } + }) + .collect::>() + } + #[cfg(target_os = "windows")] + { + vec![UdpSocket::bind("127.0.0.1:0").unwrap()] + } + }; + + let exit = Arc::new(AtomicBool::new(false)); + let (sender, receiver) = unbounded(); + let keypair = Keypair::new(); + let server_address = sockets[0].local_addr().unwrap(); + let staked_nodes = Arc::new(RwLock::new(option_staked_nodes.unwrap_or_default())); + let SpawnNonBlockingServerResult { + endpoints: _, + stats, + thread: t, + max_concurrent_connections: _, + } = spawn_server_multi( + "quic_streamer_test", + sockets, + &keypair, + sender, + exit.clone(), + max_connections_per_peer, + staked_nodes, + MAX_STAKED_CONNECTIONS, + MAX_UNSTAKED_CONNECTIONS, + DEFAULT_MAX_STREAMS_PER_MS, + DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE, + Duration::from_secs(2), + DEFAULT_TPU_COALESCE, + ) + .unwrap(); + (t, exit, receiver, server_address, stats) +} + +pub async fn make_client_endpoint( + addr: &SocketAddr, + client_keypair: Option<&Keypair>, +) -> Connection { + let client_socket = UdpSocket::bind("127.0.0.1:0").unwrap(); + let mut endpoint = quinn::Endpoint::new( + EndpointConfig::default(), + None, + client_socket, + Arc::new(TokioRuntime), + ) + .unwrap(); + let default_keypair = Keypair::new(); + endpoint.set_default_client_config(get_client_config( + client_keypair.unwrap_or(&default_keypair), + )); + endpoint + .connect(*addr, "localhost") + .expect("Failed in connecting") + .await + .expect("Failed in waiting") +}