From 693da0e1b23cea37d178c11de21f0ff737a59204 Mon Sep 17 00:00:00 2001 From: OriStarkware <76900983+OriStarkware@users.noreply.github.com> Date: Sun, 2 Jun 2024 14:03:04 +0300 Subject: [PATCH] refactor(concurrency): use TransactionalState::create_transactional in tests (#1920) --- .../src/concurrency/versioned_state_test.rs | 86 +++++++++++-------- 1 file changed, 48 insertions(+), 38 deletions(-) diff --git a/crates/blockifier/src/concurrency/versioned_state_test.rs b/crates/blockifier/src/concurrency/versioned_state_test.rs index 22f477fee6..da093bc5dd 100644 --- a/crates/blockifier/src/concurrency/versioned_state_test.rs +++ b/crates/blockifier/src/concurrency/versioned_state_test.rs @@ -46,8 +46,6 @@ pub fn safe_versioned_state( safe_versioned_state_for_testing(init_state) } -// TODO(OriF 15/5/24): Use `TransactionalState::create_transactional` instead of -// `CachedState::from(..)` when fits. #[test] fn test_versioned_state_proxy() { // Test data @@ -215,8 +213,10 @@ fn test_run_parallel_txs() { )))); let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state)); - let mut state_1 = CachedState::from(safe_versioned_state.pin_version(1)); - let mut state_2 = CachedState::from(safe_versioned_state.pin_version(2)); + let mut versioned_state_proxy_1 = safe_versioned_state.pin_version(1); + let mut state_1 = TransactionalState::create_transactional(&mut versioned_state_proxy_1); + let mut versioned_state_proxy_2 = safe_versioned_state.pin_version(2); + let mut state_2 = TransactionalState::create_transactional(&mut versioned_state_proxy_2); // Prepare transactions let deploy_account_tx_1 = deploy_account_tx( @@ -254,30 +254,28 @@ fn test_run_parallel_txs() { let block_context_1 = block_context.clone(); let block_context_2 = block_context.clone(); // Execute transactions - let thread_handle_1 = thread::spawn(move || { - let result = account_tx_1.execute(&mut state_1, &block_context_1, true, true); - assert_eq!(result.is_err(), enforce_fee); - }); - - let thread_handle_2 = thread::spawn(move || { - account_tx_2.execute(&mut state_2, &block_context_2, true, true).unwrap(); - - // Check that the constructor wrote ctor_arg to the storage. - let storage_key = get_storage_var_address("ctor_arg", &[]); - let deployed_contract_address = calculate_contract_address( - ContractAddressSalt::default(), - class_hash, - &constructor_calldata, - ContractAddress::default(), - ) - .unwrap(); - let read_storage_arg = - state_2.get_storage_at(deployed_contract_address, storage_key).unwrap(); - assert_eq!(ctor_storage_arg, read_storage_arg); + thread::scope(|s| { + s.spawn(move || { + let result = account_tx_1.execute(&mut state_1, &block_context_1, true, true); + assert_eq!(result.is_err(), enforce_fee); + }); + s.spawn(move || { + account_tx_2.execute(&mut state_2, &block_context_2, true, true).unwrap(); + + // Check that the constructor wrote ctor_arg to the storage. + let storage_key = get_storage_var_address("ctor_arg", &[]); + let deployed_contract_address = calculate_contract_address( + ContractAddressSalt::default(), + class_hash, + &constructor_calldata, + ContractAddress::default(), + ) + .unwrap(); + let read_storage_arg = + state_2.get_storage_at(deployed_contract_address, storage_key).unwrap(); + assert_eq!(ctor_storage_arg, read_storage_arg); + }); }); - - thread_handle_1.join().unwrap(); - thread_handle_2.join().unwrap(); } #[rstest] @@ -288,7 +286,8 @@ fn test_validate_reads( ) { let storage_key = storage_key!("0x10"); - let transactional_state = CachedState::from(safe_versioned_state.pin_version(1)); + let mut version_state_proxy = safe_versioned_state.pin_version(1); + let transactional_state = TransactionalState::create_transactional(&mut version_state_proxy); // Validating tx index 0 always succeeds. assert!(safe_versioned_state.pin_version(0).validate_reads(&StateMaps::default())); @@ -331,8 +330,11 @@ fn test_apply_writes( class_hash: ClassHash, safe_versioned_state: ThreadSafeVersionedState, ) { - let mut transactional_states: Vec>> = - (0..2).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect(); + let mut versioned_proxy_states: Vec> = + (0..2).map(|i| safe_versioned_state.pin_version(i)).collect(); + let mut transactional_states: Vec< + TransactionalState<'_, VersionedStateProxy>, + > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Transaction 0 class hash. let class_hash_0 = class_hash!(76_u8); @@ -364,8 +366,11 @@ fn test_apply_writes_reexecute_scenario( class_hash: ClassHash, safe_versioned_state: ThreadSafeVersionedState, ) { - let mut transactional_states: Vec>> = - (0..2).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect(); + let mut versioned_proxy_states: Vec> = + (0..2).map(|i| safe_versioned_state.pin_version(i)).collect(); + let mut transactional_states: Vec< + TransactionalState<'_, VersionedStateProxy>, + > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Transaction 0 class hash. let class_hash_0 = class_hash!(76_u8); @@ -386,7 +391,8 @@ fn test_apply_writes_reexecute_scenario( // TODO: Use re-execution native util once it's ready. // "Re-execute" the transaction. - transactional_states[1] = CachedState::from(safe_versioned_state.pin_version(1)); + let mut versioned_state_proxy = safe_versioned_state.pin_version(1); + transactional_states[1] = TransactionalState::create_transactional(&mut versioned_state_proxy); // The class hash should be updated. assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0); } @@ -397,8 +403,12 @@ fn test_delete_writes( safe_versioned_state: ThreadSafeVersionedState, ) { let num_of_txs = 3; - let mut transactional_states: Vec>> = - (0..num_of_txs).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect(); + let mut versioned_proxy_states: Vec> = + (0..num_of_txs).map(|i| safe_versioned_state.pin_version(i)).collect(); + let mut transactional_states: Vec< + TransactionalState<'_, VersionedStateProxy>, + > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); + // Setting 2 instances of the contract to ensure `delete_writes` removes information from // multiple keys. Class hash values are not checked in this test. let contract_addresses = [ @@ -406,7 +416,7 @@ fn test_delete_writes( (contract_address!("0x200"), class_hash!(21_u8)), ]; let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1); - for tx_state in transactional_states.iter_mut() { + for (i, tx_state) in transactional_states.iter_mut().enumerate() { // Modify the `cache` member of the CachedState. for (contract_address, class_hash) in contract_addresses.iter() { tx_state.set_class_hash_at(*contract_address, *class_hash).unwrap(); @@ -415,14 +425,14 @@ fn test_delete_writes( tx_state .set_contract_class(feature_contract.get_class_hash(), feature_contract.get_class()) .unwrap(); - tx_state.state.apply_writes( + safe_versioned_state.pin_version(i).apply_writes( &tx_state.cache.borrow().writes, &tx_state.class_hash_to_class.borrow(), &HashMap::default(), ); } - transactional_states[tx_index_to_delete_writes].state.delete_writes( + safe_versioned_state.pin_version(tx_index_to_delete_writes).delete_writes( &transactional_states[tx_index_to_delete_writes].cache.borrow().writes, &transactional_states[tx_index_to_delete_writes].class_hash_to_class.borrow(), );