From 79b6ae89746080edf569521639386e916b2d2da6 Mon Sep 17 00:00:00 2001 From: OriStarkware <76900983+OriStarkware@users.noreply.github.com> Date: Tue, 4 Jun 2024 11:20:06 +0300 Subject: [PATCH] refactor(concurrency): use test_state instead of test_state_reader in testing (#1943) --- .../blockifier/src/concurrency/test_utils.rs | 4 +-- .../src/concurrency/versioned_state_test.rs | 30 +++++++++---------- .../src/concurrency/worker_logic_test.rs | 23 ++++++-------- .../src/test_utils/initial_test_state.rs | 15 ++-------- 4 files changed, 29 insertions(+), 43 deletions(-) diff --git a/crates/blockifier/src/concurrency/test_utils.rs b/crates/blockifier/src/concurrency/test_utils.rs index 3df10eb6dd..c59e406c43 100644 --- a/crates/blockifier/src/concurrency/test_utils.rs +++ b/crates/blockifier/src/concurrency/test_utils.rs @@ -56,8 +56,8 @@ macro_rules! default_scheduler { // TODO(meshi, 01/06/2024): Consider making this a macro. pub fn safe_versioned_state_for_testing( - block_state: DictStateReader, -) -> ThreadSafeVersionedState { + block_state: CachedState, +) -> ThreadSafeVersionedState> { ThreadSafeVersionedState::new(VersionedState::new(block_state)) } diff --git a/crates/blockifier/src/concurrency/versioned_state_test.rs b/crates/blockifier/src/concurrency/versioned_state_test.rs index da093bc5dd..80fb14f4b3 100644 --- a/crates/blockifier/src/concurrency/versioned_state_test.rs +++ b/crates/blockifier/src/concurrency/versioned_state_test.rs @@ -38,12 +38,12 @@ use crate::{compiled_class_hash, deploy_account_tx_args, nonce, storage_key}; pub fn safe_versioned_state( contract_address: ContractAddress, class_hash: ClassHash, -) -> ThreadSafeVersionedState { +) -> ThreadSafeVersionedState> { let init_state = DictStateReader { address_to_class_hash: HashMap::from([(contract_address, class_hash)]), ..Default::default() }; - safe_versioned_state_for_testing(init_state) + safe_versioned_state_for_testing(CachedState::new(init_state)) } #[test] @@ -282,7 +282,7 @@ fn test_run_parallel_txs() { fn test_validate_reads( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState, + safe_versioned_state: ThreadSafeVersionedState>, ) { let storage_key = storage_key!("0x10"); @@ -328,12 +328,12 @@ fn test_validate_reads( fn test_apply_writes( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let mut versioned_proxy_states: Vec> = + let mut versioned_proxy_states: Vec>> = (0..2).map(|i| safe_versioned_state.pin_version(i)).collect(); let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>, + TransactionalState<'_, VersionedStateProxy>>, > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Transaction 0 class hash. @@ -364,12 +364,12 @@ fn test_apply_writes( fn test_apply_writes_reexecute_scenario( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let mut versioned_proxy_states: Vec> = + let mut versioned_proxy_states: Vec>> = (0..2).map(|i| safe_versioned_state.pin_version(i)).collect(); let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>, + TransactionalState<'_, VersionedStateProxy>>, > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Transaction 0 class hash. @@ -400,13 +400,13 @@ fn test_apply_writes_reexecute_scenario( #[rstest] fn test_delete_writes( #[values(0, 1, 2)] tx_index_to_delete_writes: TxIndex, - safe_versioned_state: ThreadSafeVersionedState, + safe_versioned_state: ThreadSafeVersionedState>, ) { let num_of_txs = 3; - let mut versioned_proxy_states: Vec> = + let mut versioned_proxy_states: Vec>> = (0..num_of_txs).map(|i| safe_versioned_state.pin_version(i)).collect(); let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>, + TransactionalState<'_, VersionedStateProxy>>, > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Setting 2 instances of the contract to ensure `delete_writes` removes information from @@ -465,7 +465,7 @@ fn test_delete_writes( #[rstest] fn test_delete_writes_completeness( - safe_versioned_state: ThreadSafeVersionedState, + safe_versioned_state: ThreadSafeVersionedState>, ) { let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1); let state_maps_writes = StateMaps { @@ -527,13 +527,13 @@ fn test_delete_writes_completeness( #[rstest] fn test_versioned_proxy_state_flow( - safe_versioned_state: ThreadSafeVersionedState, + safe_versioned_state: ThreadSafeVersionedState>, ) { let contract_address = contract_address!("0x1"); let class_hash = ClassHash(stark_felt!(27_u8)); let mut block_state = CachedState::from(DictStateReader::default()); - let mut versioned_proxy_states: Vec> = + let mut versioned_proxy_states: Vec>> = (0..4).map(|i| safe_versioned_state.pin_version(i)).collect(); let mut transactional_states = Vec::with_capacity(4); diff --git a/crates/blockifier/src/concurrency/worker_logic_test.rs b/crates/blockifier/src/concurrency/worker_logic_test.rs index eb1a678ead..d366e9687b 100644 --- a/crates/blockifier/src/concurrency/worker_logic_test.rs +++ b/crates/blockifier/src/concurrency/worker_logic_test.rs @@ -19,7 +19,7 @@ use crate::state::cached_state::StateMaps; use crate::state::state_api::StateReader; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; -use crate::test_utils::initial_test_state::test_state_reader; +use crate::test_utils::initial_test_state::test_state; use crate::test_utils::{ create_calldata, CairoVersion, NonceManager, BALANCE, MAX_FEE, MAX_L1_GAS_AMOUNT, MAX_L1_GAS_PRICE, TEST_ERC20_CONTRACT_ADDRESS, @@ -43,9 +43,8 @@ fn test_worker_execute() { let chain_info = &block_context.chain_info; // Create the state. - let state_reader = - test_state_reader(chain_info, BALANCE, &[(account_contract, 1), (test_contract, 1)]); - let safe_versioned_state = safe_versioned_state_for_testing(state_reader); + let state = test_state(chain_info, BALANCE, &[(account_contract, 1), (test_contract, 1)]); + let safe_versioned_state = safe_versioned_state_for_testing(state); // Create transactions. let test_contract_address = test_contract.get_instance_address(0); @@ -216,9 +215,8 @@ fn test_worker_validate() { let chain_info = &block_context.chain_info; // Create the state. - let state_reader = - test_state_reader(chain_info, BALANCE, &[(account_contract, 1), (test_contract, 1)]); - let safe_versioned_state = safe_versioned_state_for_testing(state_reader); + let state = test_state(chain_info, BALANCE, &[(account_contract, 1), (test_contract, 1)]); + let safe_versioned_state = safe_versioned_state_for_testing(state); // Create transactions. let test_contract_address = test_contract.get_instance_address(0); @@ -320,11 +318,8 @@ pub fn test_add_fee_to_sequencer_balance( let tx_index = 0; let block_context = BlockContext::create_for_account_testing_with_concurrency_mode(true); let account = FeatureContract::Empty(CairoVersion::Cairo1); - let safe_versioned_state = safe_versioned_state_for_testing(test_state_reader( - &block_context.chain_info, - 0, - &[(account, 1)], - )); + let safe_versioned_state = + safe_versioned_state_for_testing(test_state(&block_context.chain_info, 0, &[(account, 1)])); let mut tx_versioned_state = safe_versioned_state.pin_version(tx_index); let (sequencer_balance_key_low, sequencer_balance_key_high) = get_sequencer_balance_keys(&block_context); @@ -371,8 +366,8 @@ fn test_deploy_before_declare() { let block_context = BlockContext::create_for_account_testing_with_concurrency_mode(true); let chain_info = &block_context.chain_info; let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1); - let state_reader = test_state_reader(chain_info, BALANCE, &[(account_contract, 2)]); - let safe_versioned_state = safe_versioned_state_for_testing(state_reader); + let state = test_state(chain_info, BALANCE, &[(account_contract, 2)]); + let safe_versioned_state = safe_versioned_state_for_testing(state); // Create transactions. let account_address_0 = account_contract.get_instance_address(0); diff --git a/crates/blockifier/src/test_utils/initial_test_state.rs b/crates/blockifier/src/test_utils/initial_test_state.rs index f1e2dc41a8..e6dceeba4a 100644 --- a/crates/blockifier/src/test_utils/initial_test_state.rs +++ b/crates/blockifier/src/test_utils/initial_test_state.rs @@ -37,11 +37,11 @@ pub fn fund_account( /// * "Declares" the input list of contracts. /// * "Deploys" the requested number of instances of each input contract. /// * Makes each input account contract privileged. -pub fn test_state_reader( +pub fn test_state( chain_info: &ChainInfo, initial_balances: u128, contract_instances: &[(FeatureContract, u16)], -) -> DictStateReader { +) -> CachedState { let mut class_hash_to_class = HashMap::new(); let mut address_to_class_hash = HashMap::new(); @@ -81,14 +81,5 @@ pub fn test_state_reader( } } - state_reader -} - -/// Initializes a state for testing, with the output of test_state_reader as the initial state. -pub fn test_state( - chain_info: &ChainInfo, - initial_balances: u128, - contract_instances: &[(FeatureContract, u16)], -) -> CachedState { - CachedState::from(test_state_reader(chain_info, initial_balances, contract_instances)) + CachedState::from(state_reader) }