fix(resharding): fix Chain::get_shards_to_state_sync() (#12617)
This function was left unimplemented for the resharding v3 case, so here
we implement it. For each shard ID, we state sync it if we will be
tracking it next epoch but track it in neither the current epoch nor the
previous one. If we don't track it in the current epoch but did track it
in the previous one, there's no need to state sync, since we can just
keep applying chunks to stay caught up.

This logic doesn't reference resharding, since it should be the same in
either case, but it fixes the bug where both state sync and the
resharding code might believe they're responsible for generating the
child shard state after a resharding. That can happen in the rare case
where we track the parent in epoch `T-1`, don't track the child in `T`,
and will track the child in `T+1`. In that case, start_resharding() is
called at the end of epoch `T-1` and initiates the child shard splitting
(this is true even if we don't track either child in the current epoch).
Resharding then proceeds as usual, we start applying chunks for the
child when it's ready, and no state sync needs to be involved.
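
The decision rule above boils down to three per-shard booleans. Below is a
minimal sketch of just that rule; it is illustrative only, the names are
hypothetical and not the actual nearcore API, where these answers come from
ShardTracker queries against the last block of epoch T:

    // Illustrative only: the per-shard decision rule described above,
    // reduced to three booleans.
    fn needs_state_sync(
        will_track_next_epoch: bool, // will we track the shard in epoch T+1?
        tracks_this_epoch: bool,     // do we track it in epoch T?
        tracked_prev_epoch: bool,    // did we track it (or its parent) in epoch T-1?
    ) -> bool {
        will_track_next_epoch && !tracks_this_epoch && !tracked_prev_epoch
    }

    fn main() {
        // The rare shuffle case described above: tracked in T-1, untracked in T,
        // tracked again in T+1 -> no state sync; keep applying chunks instead.
        assert!(!needs_state_sync(true, false, true));
        // Newly assigned shard with no prior state -> state sync.
        assert!(needs_state_sync(true, false, false));
        // Already tracking it in T -> no state sync.
        assert!(!needs_state_sync(true, true, false));
    }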
marcelo-gonzalez authored Jan 6, 2025
1 parent 445ee6f commit 0b5398e
Showing 2 changed files with 125 additions and 71 deletions.
164 changes: 93 additions & 71 deletions chain/chain/src/chain.rs
@@ -795,28 +795,27 @@ impl Chain {
     fn get_state_sync_info(
         &self,
         me: &Option<AccountId>,
-        epoch_first_block: &Block,
+        epoch_id: &EpochId,
+        block_hash: &CryptoHash,
+        prev_hash: &CryptoHash,
+        prev_prev_hash: &CryptoHash,
     ) -> Result<Option<StateSyncInfo>, Error> {
-        let prev_hash = *epoch_first_block.header().prev_hash();
         let shards_to_state_sync = Chain::get_shards_to_state_sync(
             self.epoch_manager.as_ref(),
             &self.shard_tracker,
             me,
-            &prev_hash,
+            prev_hash,
+            prev_prev_hash,
         )?;
         if shards_to_state_sync.is_empty() {
             Ok(None)
         } else {
             debug!(target: "chain", "Downloading state for {:?}, I'm {:?}", shards_to_state_sync, me);
-            let epoch_id = epoch_first_block.header().epoch_id();
             let protocol_version = self.epoch_manager.get_epoch_protocol_version(epoch_id)?;
             // Note that this block is the first block in an epoch because this function is only called
             // in get_catchup_and_state_sync_infos() when that is the case.
-            let state_sync_info = StateSyncInfo::new(
-                protocol_version,
-                *epoch_first_block.header().hash(),
-                shards_to_state_sync,
-            );
+            let state_sync_info =
+                StateSyncInfo::new(protocol_version, *block_hash, shards_to_state_sync);
             Ok(Some(state_sync_info))
         }
     }
@@ -2271,7 +2270,7 @@ impl Chain {
         // First real I/O expense.
         let prev = self.get_previous_header(header)?;
         let prev_hash = *prev.hash();
-        let prev_prev_hash = *prev.prev_hash();
+        let prev_prev_hash = prev.prev_hash();
         let gas_price = prev.next_gas_price();
         let prev_random_value = *prev.random_value();
         let prev_height = prev.height();
@@ -2281,8 +2280,13 @@
             return Err(Error::InvalidBlockHeight(prev_height));
         }

-        let (is_caught_up, state_sync_info) =
-            self.get_catchup_and_state_sync_infos(header, prev_hash, prev_prev_hash, me, block)?;
+        let (is_caught_up, state_sync_info) = self.get_catchup_and_state_sync_infos(
+            header.epoch_id(),
+            header.hash(),
+            &prev_hash,
+            prev_prev_hash,
+            me,
+        )?;

         self.check_if_challenged_block_on_chain(header)?;

@@ -2375,29 +2379,32 @@

     fn get_catchup_and_state_sync_infos(
         &self,
-        header: &BlockHeader,
-        prev_hash: CryptoHash,
-        prev_prev_hash: CryptoHash,
+        epoch_id: &EpochId,
+        block_hash: &CryptoHash,
+        prev_hash: &CryptoHash,
+        prev_prev_hash: &CryptoHash,
         me: &Option<AccountId>,
-        block: &MaybeValidated<Block>,
     ) -> Result<(bool, Option<StateSyncInfo>), Error> {
-        if self.epoch_manager.is_next_block_epoch_start(&prev_hash)? {
-            debug!(target: "chain", block_hash=?header.hash(), "block is the first block of an epoch");
-            if !self.prev_block_is_caught_up(&prev_prev_hash, &prev_hash)? {
-                // The previous block is not caught up for the next epoch relative to the previous
-                // block, which is the current epoch for this block, so this block cannot be applied
-                // at all yet, needs to be orphaned
-                return Err(Error::Orphan);
-            }
-
-            // For the first block of the epoch we check if we need to start download states for
-            // shards that we will care about in the next epoch. If there is no state to be downloaded,
-            // we consider that we are caught up, otherwise not
-            let state_sync_info = self.get_state_sync_info(me, block)?;
-            Ok((state_sync_info.is_none(), state_sync_info))
-        } else {
-            Ok((self.prev_block_is_caught_up(&prev_prev_hash, &prev_hash)?, None))
+        if !self.epoch_manager.is_next_block_epoch_start(prev_hash)? {
+            return Ok((self.prev_block_is_caught_up(prev_prev_hash, prev_hash)?, None));
         }
+        if !self.prev_block_is_caught_up(prev_prev_hash, prev_hash)? {
+            // The previous block is not caught up for the next epoch relative to the previous
+            // block, which is the current epoch for this block, so this block cannot be applied
+            // at all yet, needs to be orphaned
+            return Err(Error::Orphan);
+        }
+
+        // For the first block of the epoch we check if we need to start download states for
+        // shards that we will care about in the next epoch. If there is no state to be downloaded,
+        // we consider that we are caught up, otherwise not
+        let state_sync_info =
+            self.get_state_sync_info(me, epoch_id, block_hash, prev_hash, prev_prev_hash)?;
+        debug!(
+            target: "chain", %block_hash, shards_to_sync=?state_sync_info.as_ref().map(|s| s.shards()),
+            "Checked for shards to sync for epoch T+1 upon processing first block of epoch T"
+        );
+        Ok((state_sync_info.is_none(), state_sync_info))
     }

     pub fn prev_block_is_caught_up(
@@ -2425,56 +2432,71 @@
         shard_tracker: &ShardTracker,
         me: &Option<AccountId>,
         parent_hash: &CryptoHash,
+        prev_prev_hash: &CryptoHash,
     ) -> Result<Vec<ShardId>, Error> {
         let epoch_id = epoch_manager.get_epoch_id_from_prev_block(parent_hash)?;
-        Ok((epoch_manager.shard_ids(&epoch_id)?)
-            .into_iter()
-            .filter(|shard_id| {
-                Self::should_catch_up_shard(
-                    epoch_manager,
-                    shard_tracker,
-                    me,
-                    parent_hash,
-                    *shard_id,
-                )
-            })
-            .collect())
+        let mut shards_to_sync = Vec::new();
+        for shard_id in epoch_manager.shard_ids(&epoch_id)? {
+            if Self::should_catch_up_shard(
+                epoch_manager,
+                shard_tracker,
+                me,
+                &epoch_id,
+                parent_hash,
+                prev_prev_hash,
+                shard_id,
+            )? {
+                shards_to_sync.push(shard_id)
+            }
+        }
+        Ok(shards_to_sync)
     }

+    /// Returns whether we need to initiate state sync for the given `shard_id` for the epoch
+    /// beginning after the block `epoch_last_block`. If that epoch is epoch T, the logic is:
+    /// - will track the shard in epoch T+1
+    /// - AND not tracking it in T
+    /// - AND didn't track it in T-1
+    /// We check that we didn't track it in T-1 because if so, and we're in the relatively rare case
+    /// where we'll go from tracking it to not tracking it and back to tracking it in consecutive epochs,
+    /// then we can just continue to apply chunks as if we were tracking it in epoch T, and there's no need to state sync.
     fn should_catch_up_shard(
         epoch_manager: &dyn EpochManagerAdapter,
         shard_tracker: &ShardTracker,
         me: &Option<AccountId>,
-        parent_hash: &CryptoHash,
+        epoch_id: &EpochId,
+        epoch_last_block: &CryptoHash,
+        epoch_last_block_prev: &CryptoHash,
         shard_id: ShardId,
-    ) -> bool {
-        let result = epoch_manager.will_shard_layout_change(parent_hash);
-        let will_shard_layout_change = match result {
-            Ok(_will_shard_layout_change) => {
-                // TODO(#11881): before state sync is fixed, we don't catch up
-                // split shards. Assume that all needed shards are tracked
-                // already.
-                // will_shard_layout_change,
-                false
-            }
-            Err(err) => {
-                // TODO(resharding) This is a problem, if this happens the node
-                // will not perform resharding and fall behind the network.
-                tracing::error!(target: "chain", ?err, "failed to check if shard layout will change");
-                false
-            }
-        };
-        // if shard layout will change the next epoch, we should catch up the shard regardless
-        // whether we already have the shard's state this epoch, because we need to generate
-        // new states for shards split from the current shard for the next epoch
-        let will_care_about_shard =
-            shard_tracker.will_care_about_shard(me.as_ref(), parent_hash, shard_id, true);
-        let does_care_about_shard =
-            shard_tracker.care_about_shard(me.as_ref(), parent_hash, shard_id, true);
+    ) -> Result<bool, Error> {
+        // Won't care about it next epoch, no need to state sync it.
+        if !shard_tracker.will_care_about_shard(me.as_ref(), epoch_last_block, shard_id, true) {
+            return Ok(false);
+        }
+        // Currently tracking the shard, so no need to state sync it.
+        if shard_tracker.care_about_shard(me.as_ref(), epoch_last_block, shard_id, true) {
+            return Ok(false);
+        }

-        tracing::debug!(target: "chain", does_care_about_shard, will_care_about_shard, will_shard_layout_change, "should catch up shard");
+        // Now we need to state sync it unless we were tracking the parent in the previous epoch,
+        // in which case we don't need to because we already have the state, and can just continue applying chunks
+        if epoch_id == &EpochId::default() {
+            return Ok(true);
+        }

-        will_care_about_shard && (will_shard_layout_change || !does_care_about_shard)
+        let (_layout, parent_shard_id, _index) =
+            epoch_manager.get_prev_shard_id_from_prev_hash(epoch_last_block, shard_id)?;
+        // Note that here passing `epoch_last_block_prev` to care_about_shard() will have us check whether we were tracking it in
+        // the previous epoch, because it is the "parent_hash" of the last block of the previous epoch.
+        // TODO: consider refactoring these ShardTracker functions to accept an epoch_id
+        // to make this less tricky.
+        let tracked_before = shard_tracker.care_about_shard(
+            me.as_ref(),
+            epoch_last_block_prev,
+            parent_shard_id,
+            true,
+        );
+        Ok(!tracked_before)
     }

     /// Check if any block with missing chunk is ready to be processed and start processing these blocks
32 changes: 32 additions & 0 deletions integration-tests/src/test_loop/tests/resharding_v3.rs
@@ -686,6 +686,38 @@ fn test_resharding_v3_shard_shuffling() {
     test_resharding_v3_base(params);
 }

+/// This tests an edge case where we track the parent in the pre-resharding epoch, then we
+/// track an unrelated shard in the first epoch after resharding, then we track a child of the resharding
+/// in the next epoch after that. In that case we don't want to state sync because we can just perform
+/// the resharding and continue applying chunks for the child in the first epoch post-resharding.
+#[test]
+fn test_resharding_v3_shard_shuffling_untrack_then_track() {
+    let account_in_stable_shard: AccountId = "account0".parse().unwrap();
+    let split_boundary_account: AccountId = NEW_BOUNDARY_ACCOUNT.parse().unwrap();
+    let base_shard_layout = get_base_shard_layout(DEFAULT_SHARD_LAYOUT_VERSION);
+    let new_shard_layout =
+        ShardLayout::derive_shard_layout(&base_shard_layout, split_boundary_account.clone());
+    let parent_shard_id = base_shard_layout.account_id_to_shard_id(&split_boundary_account);
+    let child_shard_id = new_shard_layout.account_id_to_shard_id(&split_boundary_account);
+    let unrelated_shard_id = new_shard_layout.account_id_to_shard_id(&account_in_stable_shard);
+
+    let tracked_shard_sequence =
+        vec![parent_shard_id, parent_shard_id, unrelated_shard_id, child_shard_id];
+    let num_clients = 8;
+    let tracked_shard_schedule = TrackedShardSchedule {
+        client_index: (num_clients - 1) as usize,
+        schedule: shard_sequence_to_schedule(tracked_shard_sequence),
+    };
+    let params = TestReshardingParametersBuilder::default()
+        .shuffle_shard_assignment_for_chunk_producers(true)
+        .num_clients(num_clients)
+        .tracked_shard_schedule(Some(tracked_shard_schedule))
+        // TODO(resharding): uncomment after fixing test_resharding_v3_state_cleanup()
+        //.add_loop_action(check_state_cleanup_after_resharding(tracked_shard_schedule))
+        .build();
+    test_resharding_v3_base(params);
+}
+
 #[test]
 fn test_resharding_v3_shard_shuffling_intense() {
     let chunk_ranges_to_drop = HashMap::from([(0, -1..2), (1, -3..0), (2, -3..3), (3, 0..1)]);
