Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
refactor(concurrency): use test_state instead of test_state_reader in…
Browse files Browse the repository at this point in the history
… testing (#1943)
  • Loading branch information
OriStarkware authored Jun 4, 2024
1 parent efaa69f commit 79b6ae8
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 43 deletions.
4 changes: 2 additions & 2 deletions crates/blockifier/src/concurrency/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ macro_rules! default_scheduler {

// TODO(meshi, 01/06/2024): Consider making this a macro.
pub fn safe_versioned_state_for_testing(
block_state: DictStateReader,
) -> ThreadSafeVersionedState<DictStateReader> {
block_state: CachedState<DictStateReader>,
) -> ThreadSafeVersionedState<CachedState<DictStateReader>> {
ThreadSafeVersionedState::new(VersionedState::new(block_state))
}

Expand Down
30 changes: 15 additions & 15 deletions crates/blockifier/src/concurrency/versioned_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ use crate::{compiled_class_hash, deploy_account_tx_args, nonce, storage_key};
pub fn safe_versioned_state(
contract_address: ContractAddress,
class_hash: ClassHash,
) -> ThreadSafeVersionedState<DictStateReader> {
) -> ThreadSafeVersionedState<CachedState<DictStateReader>> {
let init_state = DictStateReader {
address_to_class_hash: HashMap::from([(contract_address, class_hash)]),
..Default::default()
};
safe_versioned_state_for_testing(init_state)
safe_versioned_state_for_testing(CachedState::new(init_state))
}

#[test]
Expand Down Expand Up @@ -282,7 +282,7 @@ fn test_run_parallel_txs() {
fn test_validate_reads(
contract_address: ContractAddress,
class_hash: ClassHash,
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader>>,
) {
let storage_key = storage_key!("0x10");

Expand Down Expand Up @@ -328,12 +328,12 @@ fn test_validate_reads(
fn test_apply_writes(
contract_address: ContractAddress,
class_hash: ClassHash,
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader>>,
) {
let mut versioned_proxy_states: Vec<VersionedStateProxy<DictStateReader>> =
let mut versioned_proxy_states: Vec<VersionedStateProxy<CachedState<DictStateReader>>> =
(0..2).map(|i| safe_versioned_state.pin_version(i)).collect();
let mut transactional_states: Vec<
TransactionalState<'_, VersionedStateProxy<DictStateReader>>,
TransactionalState<'_, VersionedStateProxy<CachedState<DictStateReader>>>,
> = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();

// Transaction 0 class hash.
Expand Down Expand Up @@ -364,12 +364,12 @@ fn test_apply_writes(
fn test_apply_writes_reexecute_scenario(
contract_address: ContractAddress,
class_hash: ClassHash,
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader>>,
) {
let mut versioned_proxy_states: Vec<VersionedStateProxy<DictStateReader>> =
let mut versioned_proxy_states: Vec<VersionedStateProxy<CachedState<DictStateReader>>> =
(0..2).map(|i| safe_versioned_state.pin_version(i)).collect();
let mut transactional_states: Vec<
TransactionalState<'_, VersionedStateProxy<DictStateReader>>,
TransactionalState<'_, VersionedStateProxy<CachedState<DictStateReader>>>,
> = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();

// Transaction 0 class hash.
Expand Down Expand Up @@ -400,13 +400,13 @@ fn test_apply_writes_reexecute_scenario(
#[rstest]
fn test_delete_writes(
#[values(0, 1, 2)] tx_index_to_delete_writes: TxIndex,
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader>>,
) {
let num_of_txs = 3;
let mut versioned_proxy_states: Vec<VersionedStateProxy<DictStateReader>> =
let mut versioned_proxy_states: Vec<VersionedStateProxy<CachedState<DictStateReader>>> =
(0..num_of_txs).map(|i| safe_versioned_state.pin_version(i)).collect();
let mut transactional_states: Vec<
TransactionalState<'_, VersionedStateProxy<DictStateReader>>,
TransactionalState<'_, VersionedStateProxy<CachedState<DictStateReader>>>,
> = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();

// Setting 2 instances of the contract to ensure `delete_writes` removes information from
Expand Down Expand Up @@ -465,7 +465,7 @@ fn test_delete_writes(

#[rstest]
fn test_delete_writes_completeness(
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader>>,
) {
let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1);
let state_maps_writes = StateMaps {
Expand Down Expand Up @@ -527,13 +527,13 @@ fn test_delete_writes_completeness(

#[rstest]
fn test_versioned_proxy_state_flow(
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader>>,
) {
let contract_address = contract_address!("0x1");
let class_hash = ClassHash(stark_felt!(27_u8));

let mut block_state = CachedState::from(DictStateReader::default());
let mut versioned_proxy_states: Vec<VersionedStateProxy<DictStateReader>> =
let mut versioned_proxy_states: Vec<VersionedStateProxy<CachedState<DictStateReader>>> =
(0..4).map(|i| safe_versioned_state.pin_version(i)).collect();

let mut transactional_states = Vec::with_capacity(4);
Expand Down
23 changes: 9 additions & 14 deletions crates/blockifier/src/concurrency/worker_logic_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::state::cached_state::StateMaps;
use crate::state::state_api::StateReader;
use crate::test_utils::contracts::FeatureContract;
use crate::test_utils::declare::declare_tx;
use crate::test_utils::initial_test_state::test_state_reader;
use crate::test_utils::initial_test_state::test_state;
use crate::test_utils::{
create_calldata, CairoVersion, NonceManager, BALANCE, MAX_FEE, MAX_L1_GAS_AMOUNT,
MAX_L1_GAS_PRICE, TEST_ERC20_CONTRACT_ADDRESS,
Expand All @@ -43,9 +43,8 @@ fn test_worker_execute() {
let chain_info = &block_context.chain_info;

// Create the state.
let state_reader =
test_state_reader(chain_info, BALANCE, &[(account_contract, 1), (test_contract, 1)]);
let safe_versioned_state = safe_versioned_state_for_testing(state_reader);
let state = test_state(chain_info, BALANCE, &[(account_contract, 1), (test_contract, 1)]);
let safe_versioned_state = safe_versioned_state_for_testing(state);

// Create transactions.
let test_contract_address = test_contract.get_instance_address(0);
Expand Down Expand Up @@ -216,9 +215,8 @@ fn test_worker_validate() {
let chain_info = &block_context.chain_info;

// Create the state.
let state_reader =
test_state_reader(chain_info, BALANCE, &[(account_contract, 1), (test_contract, 1)]);
let safe_versioned_state = safe_versioned_state_for_testing(state_reader);
let state = test_state(chain_info, BALANCE, &[(account_contract, 1), (test_contract, 1)]);
let safe_versioned_state = safe_versioned_state_for_testing(state);

// Create transactions.
let test_contract_address = test_contract.get_instance_address(0);
Expand Down Expand Up @@ -320,11 +318,8 @@ pub fn test_add_fee_to_sequencer_balance(
let tx_index = 0;
let block_context = BlockContext::create_for_account_testing_with_concurrency_mode(true);
let account = FeatureContract::Empty(CairoVersion::Cairo1);
let safe_versioned_state = safe_versioned_state_for_testing(test_state_reader(
&block_context.chain_info,
0,
&[(account, 1)],
));
let safe_versioned_state =
safe_versioned_state_for_testing(test_state(&block_context.chain_info, 0, &[(account, 1)]));
let mut tx_versioned_state = safe_versioned_state.pin_version(tx_index);
let (sequencer_balance_key_low, sequencer_balance_key_high) =
get_sequencer_balance_keys(&block_context);
Expand Down Expand Up @@ -371,8 +366,8 @@ fn test_deploy_before_declare() {
let block_context = BlockContext::create_for_account_testing_with_concurrency_mode(true);
let chain_info = &block_context.chain_info;
let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
let state_reader = test_state_reader(chain_info, BALANCE, &[(account_contract, 2)]);
let safe_versioned_state = safe_versioned_state_for_testing(state_reader);
let state = test_state(chain_info, BALANCE, &[(account_contract, 2)]);
let safe_versioned_state = safe_versioned_state_for_testing(state);

// Create transactions.
let account_address_0 = account_contract.get_instance_address(0);
Expand Down
15 changes: 3 additions & 12 deletions crates/blockifier/src/test_utils/initial_test_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ pub fn fund_account(
/// * "Declares" the input list of contracts.
/// * "Deploys" the requested number of instances of each input contract.
/// * Makes each input account contract privileged.
pub fn test_state_reader(
pub fn test_state(
chain_info: &ChainInfo,
initial_balances: u128,
contract_instances: &[(FeatureContract, u16)],
) -> DictStateReader {
) -> CachedState<DictStateReader> {
let mut class_hash_to_class = HashMap::new();
let mut address_to_class_hash = HashMap::new();

Expand Down Expand Up @@ -81,14 +81,5 @@ pub fn test_state_reader(
}
}

state_reader
}

/// Initializes a state for testing, with the output of test_state_reader as the initial state.
pub fn test_state(
chain_info: &ChainInfo,
initial_balances: u128,
contract_instances: &[(FeatureContract, u16)],
) -> CachedState<DictStateReader> {
CachedState::from(test_state_reader(chain_info, initial_balances, contract_instances))
CachedState::from(state_reader)
}

0 comments on commit 79b6ae8

Please sign in to comment.