diff --git a/crates/blockifier/src/concurrency/test_utils.rs b/crates/blockifier/src/concurrency/test_utils.rs index bebdc31138..f64d1b2a75 100644 --- a/crates/blockifier/src/concurrency/test_utils.rs +++ b/crates/blockifier/src/concurrency/test_utils.rs @@ -1,6 +1,4 @@ -use std::sync::{Arc, Mutex}; - -use crate::concurrency::versioned_state_proxy::VersionedState; +use crate::concurrency::versioned_state_proxy::{ThreadSafeVersionedState, VersionedState}; use crate::test_utils::dict_state_reader::DictStateReader; #[macro_export] @@ -32,10 +30,8 @@ macro_rules! default_scheduler { } // TODO(meshi, 01/06/2024): Consider making this a macro. -// TODO: Use `ThreadSafeVersionedState` as a return value, need to change the inner field to public, -// or add new method that gets `StateReader`. -pub fn versioned_state_for_testing( +pub fn safe_versioned_state_for_testing( block_state: DictStateReader, -) -> Arc>> { - Arc::new(Mutex::new(VersionedState::new(block_state))) +) -> ThreadSafeVersionedState { + ThreadSafeVersionedState::new(VersionedState::new(block_state)) } diff --git a/crates/blockifier/src/concurrency/versioned_state_proxy.rs b/crates/blockifier/src/concurrency/versioned_state_proxy.rs index e59f2a489e..94aa273acb 100644 --- a/crates/blockifier/src/concurrency/versioned_state_proxy.rs +++ b/crates/blockifier/src/concurrency/versioned_state_proxy.rs @@ -145,9 +145,17 @@ pub struct ThreadSafeVersionedState(Arc> pub type LockedVersionedState<'a, S> = MutexGuard<'a, VersionedState>; impl ThreadSafeVersionedState { + pub fn new(versioned_state: VersionedState) -> Self { + ThreadSafeVersionedState(Arc::new(Mutex::new(versioned_state))) + } + pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy { VersionedStateProxy { tx_index, state: self.0.clone() } } + + pub fn state(&self) -> LockedVersionedState<'_, S> { + self.0.lock().expect("Failed to acquire state lock.") + } } pub struct VersionedStateProxy { diff --git a/crates/blockifier/src/concurrency/versioned_state_proxy_test.rs b/crates/blockifier/src/concurrency/versioned_state_proxy_test.rs index 6d52548c33..a4e30cf728 100644 --- a/crates/blockifier/src/concurrency/versioned_state_proxy_test.rs +++ b/crates/blockifier/src/concurrency/versioned_state_proxy_test.rs @@ -9,7 +9,7 @@ use starknet_api::transaction::{Calldata, ContractAddressSalt, Fee, TransactionV use starknet_api::{calldata, class_hash, contract_address, patricia_key, stark_felt}; use crate::abi::abi_utils::{get_fee_token_var_address, get_storage_var_address}; -use crate::concurrency::test_utils::versioned_state_for_testing; +use crate::concurrency::test_utils::safe_versioned_state_for_testing; use crate::concurrency::versioned_state_proxy::{ ThreadSafeVersionedState, VersionedState, VersionedStateProxy, }; @@ -41,20 +41,18 @@ pub fn class_hash() -> ClassHash { } #[fixture] -pub fn versioned_state( +pub fn safe_versioned_state( contract_address: ContractAddress, class_hash: ClassHash, -) -> Arc>> { +) -> ThreadSafeVersionedState { let init_state = DictStateReader { - storage_view: HashMap::default(), - address_to_nonce: HashMap::default(), address_to_class_hash: HashMap::from([(contract_address, class_hash)]), - class_hash_to_compiled_class_hash: HashMap::default(), - class_hash_to_class: HashMap::default(), + ..Default::default() }; - versioned_state_for_testing(init_state) + safe_versioned_state_for_testing(init_state) } +// TODO(OriF 15/5/24): Use `create_transactional` instead of `CachedState::from(..)` when fits. #[test] fn test_versioned_state_proxy() { // Test data @@ -67,7 +65,7 @@ fn test_versioned_state_proxy() { let compiled_class_hash = compiled_class_hash!(29_u8); let contract_class = test_contract.get_class(); - // Create the verioned state + // Create the versioned state let cached_state = CachedState::from(DictStateReader { storage_view: HashMap::from([((contract_address, key), stark_felt)]), address_to_nonce: HashMap::from([(contract_address, nonce)]), @@ -275,15 +273,14 @@ fn test_run_parallel_txs() { fn test_validate_read_set( contract_address: ContractAddress, class_hash: ClassHash, - versioned_state: Arc>>, + safe_versioned_state: ThreadSafeVersionedState, ) { let storage_key = storage_key!("0x10"); - let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state)); let transactional_state = CachedState::from(safe_versioned_state.pin_version(1)); // Validating tx index 0 always succeeds. - assert!(versioned_state.lock().unwrap().validate_read_set(0, &StateMaps::default())); + assert!(safe_versioned_state.state().validate_read_set(0, &StateMaps::default())); assert!(transactional_state.cache.borrow().initial_reads.storage.is_empty()); transactional_state.get_storage_at(contract_address, storage_key).unwrap(); @@ -305,9 +302,8 @@ fn test_validate_read_set( // preceding a declare flow is solved. assert!( - versioned_state - .lock() - .unwrap() + safe_versioned_state + .state() .validate_read_set(1, &transactional_state.cache.borrow().initial_reads) ); } @@ -316,9 +312,8 @@ fn test_validate_read_set( fn test_apply_writes( contract_address: ContractAddress, class_hash: ClassHash, - versioned_state: Arc>>, + safe_versioned_state: ThreadSafeVersionedState, ) { - let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state)); let mut transactional_states: Vec>> = (0..2).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect(); @@ -334,7 +329,7 @@ fn test_apply_writes( transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap(); assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1); - versioned_state.lock().unwrap().apply_writes( + safe_versioned_state.state().apply_writes( 0, &transactional_states[0].cache.borrow().writes, &transactional_states[0].class_hash_to_class.borrow().clone(), @@ -350,9 +345,8 @@ fn test_apply_writes( fn test_apply_writes_reexecute_scenario( contract_address: ContractAddress, class_hash: ClassHash, - versioned_state: Arc>>, + safe_versioned_state: ThreadSafeVersionedState, ) { - let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state)); let mut transactional_states: Vec>> = (0..2).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect(); @@ -364,7 +358,7 @@ fn test_apply_writes_reexecute_scenario( // updated. assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash); - versioned_state.lock().unwrap().apply_writes( + safe_versioned_state.state().apply_writes( 0, &transactional_states[0].cache.borrow().writes, &transactional_states[0].class_hash_to_class.borrow().clone(),