diff --git a/turbine/src/cluster_nodes.rs b/turbine/src/cluster_nodes.rs index fd5b3e5e6dc454..5afaa3cd399c45 100644 --- a/turbine/src/cluster_nodes.rs +++ b/turbine/src/cluster_nodes.rs @@ -27,7 +27,7 @@ use { solana_streamer::socket::SocketAddrSpace, std::{ any::TypeId, - cmp::Reverse, + cmp::Ordering, collections::{HashMap, HashSet}, iter::repeat_with, marker::PhantomData, @@ -50,6 +50,7 @@ pub enum Error { Loopback { leader: Pubkey, shred: ShredId }, } +#[derive(Debug)] #[allow(clippy::large_enum_variant)] enum NodeId { // TVU node obtained through gossip (staked or not). @@ -307,9 +308,10 @@ fn get_nodes( stake, }), ) - .sorted_by_key(|node| Reverse((node.stake, *node.pubkey()))) - // Since sorted_by_key is stable, in case of duplicates, this - // will keep nodes with contact-info. + .sorted_unstable_by(|a, b| cmp_nodes_stake(b, a)) + // dedup_by keeps the first of consecutive elements which compare equal. + // Because if all else are equal above sort puts NodeId::ContactInfo before + // NodeId::Pubkey, this will keep nodes with contact-info. .dedup_by(|a, b| a.pubkey() == b.pubkey()) .filter_map(|node| { if !should_dedup_addrs @@ -338,6 +340,21 @@ fn get_nodes( .collect() } +// Compares Nodes by stake and tie breaks by pubkeys. +// For the same pubkey, NodeId::ContactInfo is considered > NodeId::Pubkey. +#[inline] +fn cmp_nodes_stake(a: &Node, b: &Node) -> Ordering { + a.stake + .cmp(&b.stake) + .then_with(|| a.pubkey().cmp(b.pubkey())) + .then_with(|| match (&a.node, &b.node) { + (NodeId::ContactInfo(_), NodeId::ContactInfo(_)) => Ordering::Equal, + (NodeId::ContactInfo(_), NodeId::Pubkey(_)) => Ordering::Greater, + (NodeId::Pubkey(_), NodeId::ContactInfo(_)) => Ordering::Less, + (NodeId::Pubkey(_), NodeId::Pubkey(_)) => Ordering::Equal, + }) +} + fn get_seeded_rng(leader: &Pubkey, shred: &ShredId) -> ChaChaRng { let seed = shred.seed(leader); ChaChaRng::from_seed(seed) @@ -870,4 +887,60 @@ mod tests { } } } + + #[test] + fn test_cmp_nodes_stake() { + let mut rng = rand::thread_rng(); + let pubkeys: Vec = std::iter::repeat_with(|| Pubkey::from(rng.gen::<[u8; 32]>())) + .take(50) + .collect(); + let stakes = std::iter::repeat_with(|| rng.gen_range(0..100u64)); + let stakes: HashMap = pubkeys.iter().copied().zip(stakes).collect(); + let mut nodes: Vec = std::iter::repeat_with(|| { + let pubkey = pubkeys.choose(&mut rng).copied().unwrap(); + let stake = stakes[&pubkey]; + let node = ContactInfo::new_localhost(&pubkey, /*wallclock:*/ timestamp()); + [ + Node { + node: NodeId::from(node), + stake, + }, + Node { + node: NodeId::from(pubkey), + stake, + }, + ] + }) + .flatten() + .take(10_000) + .collect(); + nodes.shuffle(&mut rng); + let nodes: Vec = nodes + .into_iter() + .sorted_unstable_by(|a, b| cmp_nodes_stake(b, a)) + .dedup_by(|a, b| a.pubkey() == b.pubkey()) + .collect(); + // Assert that stakes are non-decreasing. + for (a, b) in nodes.iter().tuple_windows() { + assert!(a.stake >= b.stake); + } + // Assert that larger pubkey tie-breaks equal stakes. + for (a, b) in nodes.iter().tuple_windows() { + if a.stake == b.stake { + assert!(a.pubkey() > b.pubkey()); + } + } + // Assert that NodeId::Pubkey are dropped in favor of + // NodeId::ContactInfo. + for node in &nodes { + assert_matches!(node.node, NodeId::ContactInfo(_)); + } + // Assert that unique pubkeys are preserved. + { + let mut pubkeys = HashSet::new(); + for node in &nodes { + assert!(pubkeys.insert(node.pubkey())); + } + } + } }