Skip to content

Commit

Permalink
feat: limit the number of shards ids in SnapshotHostInfo (#10294)
Browse files Browse the repository at this point in the history
`SnapshotHostInfo` contains a list of shard ids for which a peer has a
snapshot.
Under normal circumstances this list is pretty small - the maximum size
would be the total number of shards.
But a malicious peer could craft a `SnapshotHostInfo` message with very
large number of shards - millions of shard ids. This could cause
problems on the receiver node, so there has to be limit on the number of
shard ids in a message.

Let's limit the number of shard ids in a single message to
`MAX_SHARDS_PER_SNAPSHOT_HOST_INFO = 512`.
It's a constant that describes how many shards a single peer can have
snapshots for.
A peer doesn't keep more state snapshots than the number of tracked
shards, so this limit is reasonable. 512 shards ought to be enough for
anyone.

In an ideal world we could check the current number of shards and reject
messages that contain more shard ids than the current number, but sadly
this can't be implemented. The problem is that the receiving node might
be behind the rest of the blockchain, and the latest information just
isn't available, so it can't check what the current number of shards is.
We could reject messages in such situations, but this would lead to loss
of information.

Limiting the number of shard ids to a constant number is an okay
alternative.
  • Loading branch information
jancionear authored Dec 4, 2023
1 parent 5ec1ae5 commit bc3590c
Show file tree
Hide file tree
Showing 8 changed files with 409 additions and 24 deletions.
129 changes: 128 additions & 1 deletion chain/network/src/concurrency/rayon.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use rayon::iter::ParallelIterator;
use rayon::iter::{Either, ParallelIterator};
use std::error::Error;
use std::sync::atomic::{AtomicBool, Ordering};

/// spawns a closure on a global rayon threadpool and awaits its completion.
Expand Down Expand Up @@ -46,3 +47,129 @@ pub fn try_map<I: ParallelIterator, T: Send>(
.collect();
(res, ok.load(Ordering::Acquire))
}

/// Applies `func` to the iterated elements and collects the outputs. On the first `Error` the execution is stopped.
/// Returns the outputs collected so far and a [`Result`] ([`Result::Err`] iff any `Error` was returned).
/// Same as [`try_map`], but it operates on [`Result`] instead of [`Option`].
pub fn try_map_result<I: ParallelIterator, T: Send, E: Error + Send>(
iter: I,
func: impl Sync + Send + Fn(I::Item) -> Result<T, E>,
) -> (Vec<T>, Result<(), E>) {
// Call the function on every input value and emit some items for every result.
// On a successful call this iterator emits one item: `Some(Ok(output_value))`
// When an error occurs, the iterator emits two items: `Some(Err(the_error))` and `None``
// The `None` will later be used to tell rayon to stop the execution.
let optional_result_iter /* impl Iterator<Item = Option<Result<T, E>>> */ = iter
.map(|v| match func(v) {
Ok(val) => Either::Left(std::iter::once(Some(Ok(val)))),
Err(err) => Either::Right([Some(Err(err)), None].into_iter()),
})
.flatten_iter();

// `while_some()` monitors a stream of `Option` values and stops the execution when it spots a `None` value.
// It's used to implement the short-circuit logic - on the first error rayon will stop processing subsequent items.
let results_iter = optional_result_iter.while_some();

// Split the results into two groups - the left group contains the outputs resulting from a successful execution,
// while the right group contains errors. Collect them into two separate Vecs.
let (outputs, errors): (Vec<T>, Vec<E>) = results_iter
.map(|res| match res {
Ok(value) => Either::Left(value),
Err(error) => Either::Right(error),
})
.collect();

// Return the output and the first error (if there was any)
match errors.into_iter().next() {
Some(first_error) => (outputs, Err(first_error)),
None => (outputs, Ok(())),
}
}

#[cfg(test)]
mod tests {
use super::try_map_result;
use rayon::iter::ParallelBridge;

#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
#[error("error for testing")]
struct TestError;

/// On empty input try_map_result returns empty output and Ok(())
#[test]
fn try_map_result_empty_iter() {
let empty_array: [i32; 0] = [];

let panicking_function = |_input: i32| -> Result<i64, TestError> {
panic!();
};

let (outputs, result): (Vec<i64>, Result<(), TestError>) =
try_map_result(empty_array.into_iter().par_bridge(), panicking_function);
assert_eq!(outputs, Vec::<i64>::new());
assert_eq!(result, Ok(()));
}

/// Happy path of try_map_result - all items are successfully processed, should return the outputs and Ok(())
#[test]
fn try_map_result_success() {
let inputs = [1, 2, 3, 4, 5, 6].into_iter();
let func = |input: i32| -> Result<i64, TestError> { Ok(input as i64 + 1) };

let (mut outputs, result): (Vec<i64>, Result<(), TestError>) =
try_map_result(inputs.into_iter().par_bridge(), func);
outputs.sort();
assert_eq!(outputs, vec![2i64, 3, 4, 5, 6, 7]);
assert_eq!(result, Ok(()));
}

/// Run `try_map_result` with an infinite stream of tasks, but the 100th task returns an Error.
/// `try_map_result` should stop the execution and return some successful outputs along with the Error.
#[test]
fn try_map_result_stops_on_error() {
// Infinite iterator of inputs: 1, 2, 3, 4, 5, ...
let infinite_iter = (1..).into_iter();

let func = |input: i32| -> Result<i64, TestError> {
if input == 100 {
return Err(TestError);
}

Ok(2 * input as i64)
};

let (mut outputs, result): (Vec<i64>, Result<(), TestError>) =
try_map_result(infinite_iter.par_bridge(), func);
outputs.sort();

// The error will happen on 100th input, but other threads might produce subsequent outputs in parallel,
// so the size of outputs could be a bit larger than 100. Compare with 10_000 as a safety margin.
assert!(outputs.len() > 10);
assert!(outputs.len() < 10_000);
assert_eq!(&outputs[..10], &[2, 4, 6, 8, 10, 12, 14, 16, 18, 20]);

for output in outputs {
// All outputs should be even, func multiplies the inputs by 2.
assert_eq!(output % 2, 0);
assert!(output < 20_0000);
}
assert_eq!(result, Err(TestError));
}

/// When using try_map_result, a panic in the function will be propagated to the caller
#[test]
#[should_panic]
fn try_map_result_panic() {
let inputs = (1..1000).into_iter();

let panicking_function = |input: i32| -> Result<i64, TestError> {
if input == 100 {
panic!("Oh no the input is equal to 100");
}
Ok(2 * input as i64)
};

let (_outputs, _result) = try_map_result(inputs.par_bridge(), panicking_function);
// Should panic
}
}
13 changes: 13 additions & 0 deletions chain/network/src/network_protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,19 @@ pub struct VersionedAccountData {
/// because it may contain many unknown fields (which are dropped during parsing).
pub const MAX_ACCOUNT_DATA_SIZE_BYTES: usize = 10000; // 10kB

/// Limit on the number of shard ids in a single [`SnapshotHostInfo`](state_sync::SnapshotHostInfo) message.
/// The number of shards has to be limited, otherwise a malicious attack could fill the snapshot host cache
/// with millions of shards.
/// The assumption is that no single host is going to track state for more than 512 shards. Keeping state for
/// a shard requires significant resources, so a single peer shouldn't be able to handle too many of them.
/// If this assumption changes in the future, this limit will have to be revisited.
///
/// Warning: adjusting this constant directly will break upgradeability. A new versioned-node would not interop
/// correctly with an old-versioned node; it could send an excessively large message to an old node.
/// If we ever want to change it we will need to introduce separate send and receive limits,
/// increase the receive limit in one release then increase the send limit in the next.
pub const MAX_SHARDS_PER_SNAPSHOT_HOST_INFO: usize = 512;

impl VersionedAccountData {
/// Serializes AccountData to proto and signs it using `signer`.
/// Panics if AccountData.account_id doesn't match signer.validator_id(),
Expand Down
28 changes: 25 additions & 3 deletions chain/network/src/network_protocol/state_sync.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::MAX_SHARDS_PER_SNAPSHOT_HOST_INFO;
use crate::network_protocol::Arc;
use near_crypto::SecretKey;
use near_crypto::Signature;
Expand All @@ -21,7 +22,7 @@ pub struct SnapshotHostInfo {
pub sync_hash: CryptoHash,
/// Ordinal of the epoch of the state root
pub epoch_height: EpochHeight,
/// List of shards included in the snapshot
/// List of shards included in the snapshot.
pub shards: Vec<ShardId>,
/// Signature on (sync_hash, epoch_height, shards)
pub signature: Signature,
Expand Down Expand Up @@ -54,8 +55,18 @@ impl SnapshotHostInfo {
Self::build_hash(&self.sync_hash, &self.epoch_height, &self.shards)
}

pub(crate) fn verify(&self) -> bool {
self.signature.verify(self.hash().as_ref(), self.peer_id.public_key())
pub(crate) fn verify(&self) -> Result<(), SnapshotHostInfoVerificationError> {
// Number of shards must be limited, otherwise it'd be possible to create malicious
// messages with millions of shard ids.
if self.shards.len() > MAX_SHARDS_PER_SNAPSHOT_HOST_INFO {
return Err(SnapshotHostInfoVerificationError::TooManyShards(self.shards.len()));
}

if !self.signature.verify(self.hash().as_ref(), self.peer_id.public_key()) {
return Err(SnapshotHostInfoVerificationError::InvalidSignature);
}

Ok(())
}
}

Expand All @@ -67,3 +78,14 @@ impl SnapshotHostInfo {
pub struct SyncSnapshotHosts {
pub hosts: Vec<Arc<SnapshotHostInfo>>,
}

#[derive(thiserror::Error, Debug, PartialEq, Eq, Clone)]
pub enum SnapshotHostInfoVerificationError {
#[error("SnapshotHostInfo is signed with an invalid signature")]
InvalidSignature,
#[error(
"SnapshotHostInfo contains more shards than allowed: {0} > {} (MAX_SHARDS_PER_SNAPSHOT_HOST_INFO)",
MAX_SHARDS_PER_SNAPSHOT_HOST_INFO
)]
TooManyShards(usize),
}
12 changes: 8 additions & 4 deletions chain/network/src/peer/peer_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::accounts_data::AccountDataError;
use crate::concurrency::atomic_cell::AtomicCell;
use crate::concurrency::demux;
use crate::config::PEERS_RESPONSE_MAX_PEERS;
use crate::network_protocol::SnapshotHostInfoVerificationError;
use crate::network_protocol::{
DistanceVector, Edge, EdgeState, Encoding, OwnedAccount, ParsePeerMessageError,
PartialEdgeInfo, PeerChainInfoV2, PeerIdOrHash, PeerInfo, PeersRequest, PeersResponse,
Expand Down Expand Up @@ -1285,10 +1286,13 @@ impl PeerActor {
ctx.spawn(wrap_future(async move {
if let Some(err) = network_state.add_snapshot_hosts(msg.hosts).await {
conn.stop(Some(match err {
SnapshotHostInfoError::InvalidSignature => {
ReasonForBan::InvalidSignature
}
SnapshotHostInfoError::DuplicatePeerId => ReasonForBan::Abusive,
SnapshotHostInfoError::VerificationError(
SnapshotHostInfoVerificationError::InvalidSignature,
) => ReasonForBan::InvalidSignature,
SnapshotHostInfoError::VerificationError(
SnapshotHostInfoVerificationError::TooManyShards(_),
)
| SnapshotHostInfoError::DuplicatePeerId => ReasonForBan::Abusive,
}));
}
message_processed_event();
Expand Down
24 changes: 21 additions & 3 deletions chain/network/src/peer_manager/peer_manager_actor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::client;
use crate::config;
use crate::debug::{DebugStatus, GetDebugStatus};
use crate::network_protocol::SyncSnapshotHosts;
Expand All @@ -18,6 +17,7 @@ use crate::types::{
NetworkResponses, PeerInfo, PeerManagerMessageRequest, PeerManagerMessageResponse, PeerType,
SetChainInfo, SnapshotHostInfo,
};
use crate::{client, network_protocol};
use actix::fut::future::wrap_future;
use actix::{Actor as _, AsyncContext as _};
use anyhow::Context as _;
Expand All @@ -31,7 +31,8 @@ use near_primitives::views::{
ConnectionInfoView, EdgeView, KnownPeerStateView, NetworkGraphView, PeerStoreView,
RecentOutboundConnectionsView, SnapshotHostInfoView, SnapshotHostsView,
};
use rand::seq::IteratorRandom;
use network_protocol::MAX_SHARDS_PER_SNAPSHOT_HOST_INFO;
use rand::seq::{IteratorRandom, SliceRandom};
use rand::thread_rng;
use rand::Rng;
use std::cmp::min;
Expand Down Expand Up @@ -784,7 +785,24 @@ impl PeerManagerActor {
NetworkResponses::RouteNotFound
}
}
NetworkRequests::SnapshotHostInfo { sync_hash, epoch_height, shards } => {
NetworkRequests::SnapshotHostInfo { sync_hash, epoch_height, mut shards } => {
if shards.len() > MAX_SHARDS_PER_SNAPSHOT_HOST_INFO {
tracing::warn!("PeerManager: Sending out a SnapshotHostInfo message with {} shards, \
this is more than the allowed limit. The list of shards will be truncated. \
Please adjust the MAX_SHARDS_PER_SNAPSHOT_HOST_INFO constant ({})", shards.len(), MAX_SHARDS_PER_SNAPSHOT_HOST_INFO);

// We can's send out more than MAX_SHARDS_PER_SNAPSHOT_HOST_INFO shards because other nodes would
// ban us for abusive behavior. Let's truncate the shards vector by choosing a random subset of
// MAX_SHARDS_PER_SNAPSHOT_HOST_INFO shard ids. Choosing a random subset slightly increases the chances
// that other nodes will have snapshot sync information about all shards from some node.
shards = shards
.choose_multiple(&mut rand::thread_rng(), MAX_SHARDS_PER_SNAPSHOT_HOST_INFO)
.copied()
.collect();
}
// Sort the shards to keep things tidy
shards.sort();

// Sign the information about the locally created snapshot using the keys in the
// network config before broadcasting it
let snapshot_host_info = SnapshotHostInfo::new(
Expand Down
Loading

0 comments on commit bc3590c

Please sign in to comment.