diff --git a/crates/blockifier/src/blockifier/stateful_validator.rs b/crates/blockifier/src/blockifier/stateful_validator.rs index c584c1f29a..a8a5648131 100644 --- a/crates/blockifier/src/blockifier/stateful_validator.rs +++ b/crates/blockifier/src/blockifier/stateful_validator.rs @@ -7,7 +7,9 @@ use starknet_api::transaction::TransactionHash; use thiserror::Error; use crate::blockifier::config::TransactionExecutorConfig; -use crate::blockifier::transaction_executor::{TransactionExecutor, TransactionExecutorError}; +use crate::blockifier::transaction_executor::{ + TransactionExecutor, TransactionExecutorError, BLOCK_STATE_ACCESS_ERR, +}; use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::CallInfo; use crate::fee::actual_cost::TransactionReceipt; @@ -108,7 +110,7 @@ impl StatefulValidator { // Run pre-validation in charge fee mode to perform fee and balance related checks. let charge_fee = true; tx.perform_pre_validation_stage( - &mut self.tx_executor.state, + self.tx_executor.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR), tx_context, charge_fee, strict_nonce_check, @@ -125,7 +127,12 @@ impl StatefulValidator { tx_info: &TransactionInfo, deploy_account_tx_hash: Option, ) -> StatefulValidatorResult { - let nonce = self.tx_executor.state.get_nonce_at(tx_info.sender_address())?; + let nonce = self + .tx_executor + .block_state + .as_ref() + .expect(BLOCK_STATE_ACCESS_ERR) + .get_nonce_at(tx_info.sender_address())?; let tx_nonce = tx_info.nonce(); let deploy_account_not_processed = @@ -151,7 +158,7 @@ impl StatefulValidator { let limit_steps_by_resources = true; let validate_call_info = tx.validate_tx( - &mut self.tx_executor.state, + self.tx_executor.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR), &mut execution_resources, tx_context.clone(), &mut remaining_gas, @@ -161,7 +168,12 @@ impl StatefulValidator { let tx_receipt = TransactionReceipt::from_account_tx( tx, &tx_context, - &self.tx_executor.state.get_actual_state_changes()?, + &self + .tx_executor + .block_state + .as_mut() + .expect(BLOCK_STATE_ACCESS_ERR) + .get_actual_state_changes()?, &execution_resources, validate_call_info.iter(), 0, diff --git a/crates/blockifier/src/blockifier/transaction_executor.rs b/crates/blockifier/src/blockifier/transaction_executor.rs index 7486d79215..220c26fefa 100644 --- a/crates/blockifier/src/blockifier/transaction_executor.rs +++ b/crates/blockifier/src/blockifier/transaction_executor.rs @@ -18,6 +18,8 @@ use crate::transaction::transactions::ExecutableTransaction; #[path = "transaction_executor_test.rs"] pub mod transaction_executor_test; +pub const BLOCK_STATE_ACCESS_ERR: &str = "Error: The block state should be `Some`."; + #[derive(Debug, Error)] pub enum TransactionExecutorError { #[error("Transaction cannot be added to the current block, block capacity reached.")] @@ -39,12 +41,16 @@ pub struct TransactionExecutor { pub config: TransactionExecutorConfig, // State-related fields. - pub state: CachedState, + // The transaction executor operates at the block level. In concurrency mode, it moves the + // block state to the worker executor - operating at the chunk level - and gets it back after + // committing the chunk. The block state is wrapped with an Option<_> to allow setting it to + // `None` while it is moved to the worker executor. + pub block_state: Option>, } impl TransactionExecutor { pub fn new( - state: CachedState, + block_state: CachedState, block_context: BlockContext, config: TransactionExecutorConfig, ) -> Self { @@ -52,8 +58,12 @@ impl TransactionExecutor { let bouncer_config = block_context.bouncer_config.clone(); // Note: the state might not be empty even at this point; it is the creator's // responsibility to tune the bouncer according to pre and post block process. - let tx_executor = - Self { block_context, bouncer: Bouncer::new(bouncer_config), config, state }; + let tx_executor = Self { + block_context, + bouncer: Bouncer::new(bouncer_config), + config, + block_state: Some(block_state), + }; log::debug!("Initialized Transaction Executor."); tx_executor @@ -67,7 +77,9 @@ impl TransactionExecutor { tx: &Transaction, charge_fee: bool, ) -> TransactionExecutorResult { - let mut transactional_state = TransactionalState::create_transactional(&mut self.state); + let mut transactional_state = TransactionalState::create_transactional( + self.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR), + ); let validate = true; let tx_execution_result = @@ -153,18 +165,24 @@ impl TransactionExecutor { // This is done by taking all the visited PCs of each contract, and compress them to one // representative for each visited segment. let visited_segments = self - .state + .block_state + .as_ref() + .expect(BLOCK_STATE_ACCESS_ERR) .visited_pcs .iter() .map(|(class_hash, class_visited_pcs)| -> TransactionExecutorResult<_> { - let contract_class = self.state.get_compiled_contract_class(*class_hash)?; + let contract_class = self + .block_state + .as_ref() + .expect(BLOCK_STATE_ACCESS_ERR) + .get_compiled_contract_class(*class_hash)?; Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?)) }) .collect::>()?; log::debug!("Final block weights: {:?}.", self.bouncer.get_accumulated_weights()); Ok(( - self.state.to_state_diff()?.into(), + self.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR).to_state_diff()?.into(), visited_segments, *self.bouncer.get_accumulated_weights(), )) diff --git a/crates/blockifier/src/blockifier/transaction_executor_test.rs b/crates/blockifier/src/blockifier/transaction_executor_test.rs index 4c75d367fa..5d3697095a 100644 --- a/crates/blockifier/src/blockifier/transaction_executor_test.rs +++ b/crates/blockifier/src/blockifier/transaction_executor_test.rs @@ -6,7 +6,9 @@ use starknet_api::stark_felt; use starknet_api::transaction::{Fee, TransactionVersion}; use crate::blockifier::config::TransactionExecutorConfig; -use crate::blockifier::transaction_executor::{TransactionExecutor, TransactionExecutorError}; +use crate::blockifier::transaction_executor::{ + TransactionExecutor, TransactionExecutorError, BLOCK_STATE_ACCESS_ERR, +}; use crate::bouncer::{Bouncer, BouncerWeights}; use crate::context::BlockContext; use crate::state::cached_state::CachedState; @@ -314,7 +316,15 @@ fn test_execute_txs_bouncing() { assert!(results[2].is_ok()); // Check state. - assert_eq!(tx_executor.state.get_nonce_at(account_address).unwrap(), nonce!(2_u32)); + assert_eq!( + tx_executor + .block_state + .as_ref() + .expect(BLOCK_STATE_ACCESS_ERR) + .get_nonce_at(account_address) + .unwrap(), + nonce!(2_u32) + ); // Check idempotency: excess transactions should not be added. let remaining_txs = &txs[expected_offset..]; @@ -328,5 +338,13 @@ fn test_execute_txs_bouncing() { assert_eq!(remaining_tx_results.len(), 2); assert!(remaining_tx_results[0].is_ok()); assert!(remaining_tx_results[1].is_ok()); - assert_eq!(tx_executor.state.get_nonce_at(account_address).unwrap(), nonce!(4_u32)); + assert_eq!( + tx_executor + .block_state + .as_ref() + .expect(BLOCK_STATE_ACCESS_ERR) + .get_nonce_at(account_address) + .unwrap(), + nonce!(4_u32) + ); } diff --git a/crates/native_blockifier/src/py_block_executor_test.rs b/crates/native_blockifier/src/py_block_executor_test.rs index 7963168c9e..bbeb3daa88 100644 --- a/crates/native_blockifier/src/py_block_executor_test.rs +++ b/crates/native_blockifier/src/py_block_executor_test.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use blockifier::blockifier::transaction_executor::BLOCK_STATE_ACCESS_ERR; use blockifier::execution::contract_class::{ContractClass, ContractClassV1}; use blockifier::state::state_api::StateReader; use cached::Cached; @@ -53,8 +54,13 @@ fn global_contract_cache_update() { assert_eq!(block_executor.global_contract_cache.lock().cache_size(), 0); - let queried_contract_class = - block_executor.tx_executor().state.get_compiled_contract_class(class_hash).unwrap(); + let queried_contract_class = block_executor + .tx_executor() + .block_state + .as_ref() + .expect(BLOCK_STATE_ACCESS_ERR) + .get_compiled_contract_class(class_hash) + .unwrap(); assert_eq!(queried_contract_class, contract_class); assert_eq!(block_executor.global_contract_cache.lock().cache_size(), 1);