Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
refactor(concurrency): use ThreadSafeVersionedState in tests and …
Browse files Browse the repository at this point in the history
…define appropriate methods (#1857)
  • Loading branch information
noaov1 authored May 5, 2024
1 parent 96d8343 commit 9fef3de
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 29 deletions.
12 changes: 4 additions & 8 deletions crates/blockifier/src/concurrency/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use std::sync::{Arc, Mutex};

use crate::concurrency::versioned_state_proxy::VersionedState;
use crate::concurrency::versioned_state_proxy::{ThreadSafeVersionedState, VersionedState};
use crate::test_utils::dict_state_reader::DictStateReader;

#[macro_export]
Expand Down Expand Up @@ -32,10 +30,8 @@ macro_rules! default_scheduler {
}

// TODO(meshi, 01/06/2024): Consider making this a macro.
// TODO: Use `ThreadSafeVersionedState` as a return value, need to change the inner field to public,
// or add new method that gets `StateReader`.
pub fn versioned_state_for_testing(
pub fn safe_versioned_state_for_testing(
block_state: DictStateReader,
) -> Arc<Mutex<VersionedState<DictStateReader>>> {
Arc::new(Mutex::new(VersionedState::new(block_state)))
) -> ThreadSafeVersionedState<DictStateReader> {
ThreadSafeVersionedState::new(VersionedState::new(block_state))
}
8 changes: 8 additions & 0 deletions crates/blockifier/src/concurrency/versioned_state_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,17 @@ pub struct ThreadSafeVersionedState<S: StateReader>(Arc<Mutex<VersionedState<S>>
pub type LockedVersionedState<'a, S> = MutexGuard<'a, VersionedState<S>>;

impl<S: StateReader> ThreadSafeVersionedState<S> {
pub fn new(versioned_state: VersionedState<S>) -> Self {
ThreadSafeVersionedState(Arc::new(Mutex::new(versioned_state)))
}

pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy<S> {
VersionedStateProxy { tx_index, state: self.0.clone() }
}

pub fn state(&self) -> LockedVersionedState<'_, S> {
self.0.lock().expect("Failed to acquire state lock.")
}
}

pub struct VersionedStateProxy<S: StateReader> {
Expand Down
36 changes: 15 additions & 21 deletions crates/blockifier/src/concurrency/versioned_state_proxy_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use starknet_api::transaction::{Calldata, ContractAddressSalt, Fee, TransactionV
use starknet_api::{calldata, class_hash, contract_address, patricia_key, stark_felt};

use crate::abi::abi_utils::{get_fee_token_var_address, get_storage_var_address};
use crate::concurrency::test_utils::versioned_state_for_testing;
use crate::concurrency::test_utils::safe_versioned_state_for_testing;
use crate::concurrency::versioned_state_proxy::{
ThreadSafeVersionedState, VersionedState, VersionedStateProxy,
};
Expand Down Expand Up @@ -41,20 +41,18 @@ pub fn class_hash() -> ClassHash {
}

#[fixture]
pub fn versioned_state(
pub fn safe_versioned_state(
contract_address: ContractAddress,
class_hash: ClassHash,
) -> Arc<Mutex<VersionedState<DictStateReader>>> {
) -> ThreadSafeVersionedState<DictStateReader> {
let init_state = DictStateReader {
storage_view: HashMap::default(),
address_to_nonce: HashMap::default(),
address_to_class_hash: HashMap::from([(contract_address, class_hash)]),
class_hash_to_compiled_class_hash: HashMap::default(),
class_hash_to_class: HashMap::default(),
..Default::default()
};
versioned_state_for_testing(init_state)
safe_versioned_state_for_testing(init_state)
}

// TODO(OriF 15/5/24): Use `create_transactional` instead of `CachedState::from(..)` when fits.
#[test]
fn test_versioned_state_proxy() {
// Test data
Expand All @@ -67,7 +65,7 @@ fn test_versioned_state_proxy() {
let compiled_class_hash = compiled_class_hash!(29_u8);
let contract_class = test_contract.get_class();

// Create the verioned state
// Create the versioned state
let cached_state = CachedState::from(DictStateReader {
storage_view: HashMap::from([((contract_address, key), stark_felt)]),
address_to_nonce: HashMap::from([(contract_address, nonce)]),
Expand Down Expand Up @@ -275,15 +273,14 @@ fn test_run_parallel_txs() {
fn test_validate_read_set(
contract_address: ContractAddress,
class_hash: ClassHash,
versioned_state: Arc<Mutex<VersionedState<DictStateReader>>>,
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
) {
let storage_key = storage_key!("0x10");

let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state));
let transactional_state = CachedState::from(safe_versioned_state.pin_version(1));

// Validating tx index 0 always succeeds.
assert!(versioned_state.lock().unwrap().validate_read_set(0, &StateMaps::default()));
assert!(safe_versioned_state.state().validate_read_set(0, &StateMaps::default()));

assert!(transactional_state.cache.borrow().initial_reads.storage.is_empty());
transactional_state.get_storage_at(contract_address, storage_key).unwrap();
Expand All @@ -305,9 +302,8 @@ fn test_validate_read_set(
// preceding a declare flow is solved.

assert!(
versioned_state
.lock()
.unwrap()
safe_versioned_state
.state()
.validate_read_set(1, &transactional_state.cache.borrow().initial_reads)
);
}
Expand All @@ -316,9 +312,8 @@ fn test_validate_read_set(
fn test_apply_writes(
contract_address: ContractAddress,
class_hash: ClassHash,
versioned_state: Arc<Mutex<VersionedState<DictStateReader>>>,
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
) {
let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state));
let mut transactional_states: Vec<CachedState<VersionedStateProxy<DictStateReader>>> =
(0..2).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect();

Expand All @@ -334,7 +329,7 @@ fn test_apply_writes(
transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap();
assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1);

versioned_state.lock().unwrap().apply_writes(
safe_versioned_state.state().apply_writes(
0,
&transactional_states[0].cache.borrow().writes,
&transactional_states[0].class_hash_to_class.borrow().clone(),
Expand All @@ -350,9 +345,8 @@ fn test_apply_writes(
fn test_apply_writes_reexecute_scenario(
contract_address: ContractAddress,
class_hash: ClassHash,
versioned_state: Arc<Mutex<VersionedState<DictStateReader>>>,
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
) {
let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state));
let mut transactional_states: Vec<CachedState<VersionedStateProxy<DictStateReader>>> =
(0..2).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect();

Expand All @@ -364,7 +358,7 @@ fn test_apply_writes_reexecute_scenario(
// updated.
assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash);

versioned_state.lock().unwrap().apply_writes(
safe_versioned_state.state().apply_writes(
0,
&transactional_states[0].cache.borrow().writes,
&transactional_states[0].class_hash_to_class.borrow().clone(),
Expand Down

0 comments on commit 9fef3de

Please sign in to comment.