Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
refactor(concurrency): use TransactionalState::create_transactional i…
Browse files Browse the repository at this point in the history
…n tests (#1920)
  • Loading branch information
OriStarkware authored Jun 2, 2024
1 parent 5fdc9ea commit 693da0e
Showing 1 changed file with 48 additions and 38 deletions.
86 changes: 48 additions & 38 deletions crates/blockifier/src/concurrency/versioned_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ pub fn safe_versioned_state(
safe_versioned_state_for_testing(init_state)
}

// TODO(OriF 15/5/24): Use `TransactionalState::create_transactional` instead of
// `CachedState::from(..)` when fits.
#[test]
fn test_versioned_state_proxy() {
// Test data
Expand Down Expand Up @@ -215,8 +213,10 @@ fn test_run_parallel_txs() {
))));

let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state));
let mut state_1 = CachedState::from(safe_versioned_state.pin_version(1));
let mut state_2 = CachedState::from(safe_versioned_state.pin_version(2));
let mut versioned_state_proxy_1 = safe_versioned_state.pin_version(1);
let mut state_1 = TransactionalState::create_transactional(&mut versioned_state_proxy_1);
let mut versioned_state_proxy_2 = safe_versioned_state.pin_version(2);
let mut state_2 = TransactionalState::create_transactional(&mut versioned_state_proxy_2);

// Prepare transactions
let deploy_account_tx_1 = deploy_account_tx(
Expand Down Expand Up @@ -254,30 +254,28 @@ fn test_run_parallel_txs() {
let block_context_1 = block_context.clone();
let block_context_2 = block_context.clone();
// Execute transactions
let thread_handle_1 = thread::spawn(move || {
let result = account_tx_1.execute(&mut state_1, &block_context_1, true, true);
assert_eq!(result.is_err(), enforce_fee);
});

let thread_handle_2 = thread::spawn(move || {
account_tx_2.execute(&mut state_2, &block_context_2, true, true).unwrap();

// Check that the constructor wrote ctor_arg to the storage.
let storage_key = get_storage_var_address("ctor_arg", &[]);
let deployed_contract_address = calculate_contract_address(
ContractAddressSalt::default(),
class_hash,
&constructor_calldata,
ContractAddress::default(),
)
.unwrap();
let read_storage_arg =
state_2.get_storage_at(deployed_contract_address, storage_key).unwrap();
assert_eq!(ctor_storage_arg, read_storage_arg);
thread::scope(|s| {
s.spawn(move || {
let result = account_tx_1.execute(&mut state_1, &block_context_1, true, true);
assert_eq!(result.is_err(), enforce_fee);
});
s.spawn(move || {
account_tx_2.execute(&mut state_2, &block_context_2, true, true).unwrap();

// Check that the constructor wrote ctor_arg to the storage.
let storage_key = get_storage_var_address("ctor_arg", &[]);
let deployed_contract_address = calculate_contract_address(
ContractAddressSalt::default(),
class_hash,
&constructor_calldata,
ContractAddress::default(),
)
.unwrap();
let read_storage_arg =
state_2.get_storage_at(deployed_contract_address, storage_key).unwrap();
assert_eq!(ctor_storage_arg, read_storage_arg);
});
});

thread_handle_1.join().unwrap();
thread_handle_2.join().unwrap();
}

#[rstest]
Expand All @@ -288,7 +286,8 @@ fn test_validate_reads(
) {
let storage_key = storage_key!("0x10");

let transactional_state = CachedState::from(safe_versioned_state.pin_version(1));
let mut version_state_proxy = safe_versioned_state.pin_version(1);
let transactional_state = TransactionalState::create_transactional(&mut version_state_proxy);

// Validating tx index 0 always succeeds.
assert!(safe_versioned_state.pin_version(0).validate_reads(&StateMaps::default()));
Expand Down Expand Up @@ -331,8 +330,11 @@ fn test_apply_writes(
class_hash: ClassHash,
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
) {
let mut transactional_states: Vec<CachedState<VersionedStateProxy<DictStateReader>>> =
(0..2).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect();
let mut versioned_proxy_states: Vec<VersionedStateProxy<DictStateReader>> =
(0..2).map(|i| safe_versioned_state.pin_version(i)).collect();
let mut transactional_states: Vec<
TransactionalState<'_, VersionedStateProxy<DictStateReader>>,
> = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();

// Transaction 0 class hash.
let class_hash_0 = class_hash!(76_u8);
Expand Down Expand Up @@ -364,8 +366,11 @@ fn test_apply_writes_reexecute_scenario(
class_hash: ClassHash,
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
) {
let mut transactional_states: Vec<CachedState<VersionedStateProxy<DictStateReader>>> =
(0..2).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect();
let mut versioned_proxy_states: Vec<VersionedStateProxy<DictStateReader>> =
(0..2).map(|i| safe_versioned_state.pin_version(i)).collect();
let mut transactional_states: Vec<
TransactionalState<'_, VersionedStateProxy<DictStateReader>>,
> = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();

// Transaction 0 class hash.
let class_hash_0 = class_hash!(76_u8);
Expand All @@ -386,7 +391,8 @@ fn test_apply_writes_reexecute_scenario(

// TODO: Use re-execution native util once it's ready.
// "Re-execute" the transaction.
transactional_states[1] = CachedState::from(safe_versioned_state.pin_version(1));
let mut versioned_state_proxy = safe_versioned_state.pin_version(1);
transactional_states[1] = TransactionalState::create_transactional(&mut versioned_state_proxy);
// The class hash should be updated.
assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0);
}
Expand All @@ -397,16 +403,20 @@ fn test_delete_writes(
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
) {
let num_of_txs = 3;
let mut transactional_states: Vec<CachedState<VersionedStateProxy<DictStateReader>>> =
(0..num_of_txs).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect();
let mut versioned_proxy_states: Vec<VersionedStateProxy<DictStateReader>> =
(0..num_of_txs).map(|i| safe_versioned_state.pin_version(i)).collect();
let mut transactional_states: Vec<
TransactionalState<'_, VersionedStateProxy<DictStateReader>>,
> = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();

// Setting 2 instances of the contract to ensure `delete_writes` removes information from
// multiple keys. Class hash values are not checked in this test.
let contract_addresses = [
(contract_address!("0x100"), class_hash!(20_u8)),
(contract_address!("0x200"), class_hash!(21_u8)),
];
let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1);
for tx_state in transactional_states.iter_mut() {
for (i, tx_state) in transactional_states.iter_mut().enumerate() {
// Modify the `cache` member of the CachedState.
for (contract_address, class_hash) in contract_addresses.iter() {
tx_state.set_class_hash_at(*contract_address, *class_hash).unwrap();
Expand All @@ -415,14 +425,14 @@ fn test_delete_writes(
tx_state
.set_contract_class(feature_contract.get_class_hash(), feature_contract.get_class())
.unwrap();
tx_state.state.apply_writes(
safe_versioned_state.pin_version(i).apply_writes(
&tx_state.cache.borrow().writes,
&tx_state.class_hash_to_class.borrow(),
&HashMap::default(),
);
}

transactional_states[tx_index_to_delete_writes].state.delete_writes(
safe_versioned_state.pin_version(tx_index_to_delete_writes).delete_writes(
&transactional_states[tx_index_to_delete_writes].cache.borrow().writes,
&transactional_states[tx_index_to_delete_writes].class_hash_to_class.borrow(),
);
Expand Down

0 comments on commit 693da0e

Please sign in to comment.