Skip to content

Commit

Permalink
feat(mempool_test_utils,starknet_integration_tests): prepare funded b…
Browse files Browse the repository at this point in the history
…ut undeployed accounts
  • Loading branch information
yair-starkware committed Jan 2, 2025
1 parent fa1c91c commit dce8816
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 22 deletions.
6 changes: 3 additions & 3 deletions crates/mempool_test_utils/src/starknet_api_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ impl MultiAccountTransactionGenerator {
})
}

pub fn accounts(&self) -> Vec<Contract> {
self.account_tx_generators.iter().map(|tx_gen| &tx_gen.account).copied().collect()
pub fn accounts(&self) -> &[AccountTransactionGenerator] {
self.account_tx_generators.as_slice()
}
}

Expand All @@ -272,7 +272,7 @@ impl MultiAccountTransactionGenerator {
/// with room for future extensions.
///
/// TODO: add more transaction generation methods as needed.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct AccountTransactionGenerator {
pub account: Contract,
nonce_manager: SharedNonceManager,
Expand Down
11 changes: 7 additions & 4 deletions crates/starknet_integration_tests/src/flow_test_setup.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::net::SocketAddr;

use blockifier::context::ChainInfo;
use mempool_test_utils::starknet_api_test_utils::{Contract, MultiAccountTransactionGenerator};
use mempool_test_utils::starknet_api_test_utils::{
AccountTransactionGenerator,
MultiAccountTransactionGenerator,
};
use papyrus_network::network_manager::BroadcastTopicChannels;
use papyrus_protobuf::consensus::{ProposalPart, StreamMessage};
use starknet_api::rpc_transaction::RpcTransaction;
Expand Down Expand Up @@ -66,7 +69,7 @@ impl FlowTestSetup {

// Create nodes one after the other in order to make sure the ports are not overlapping.
let sequencer_0 = FlowSequencerSetup::new(
accounts.clone(),
accounts.to_vec(),
SEQUENCER_0,
chain_info.clone(),
sequencer_0_consensus_manager_config,
Expand All @@ -76,7 +79,7 @@ impl FlowTestSetup {
.await;

let sequencer_1 = FlowSequencerSetup::new(
accounts,
accounts.to_vec(),
SEQUENCER_1,
chain_info,
sequencer_1_consensus_manager_config,
Expand Down Expand Up @@ -114,7 +117,7 @@ pub struct FlowSequencerSetup {
impl FlowSequencerSetup {
#[instrument(skip(accounts, chain_info, consensus_manager_config), level = "debug")]
pub async fn new(
accounts: Vec<Contract>,
accounts: Vec<AccountTransactionGenerator>,
sequencer_index: usize,
chain_info: ChainInfo,
consensus_manager_config: ConsensusManagerConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use std::net::SocketAddr;
use std::path::PathBuf;

use blockifier::context::ChainInfo;
use mempool_test_utils::starknet_api_test_utils::{Contract, MultiAccountTransactionGenerator};
use mempool_test_utils::starknet_api_test_utils::{
AccountTransactionGenerator,
MultiAccountTransactionGenerator,
};
use papyrus_storage::{StorageConfig, StorageReader};
use starknet_api::rpc_transaction::RpcTransaction;
use starknet_api::transaction::TransactionHash;
Expand Down Expand Up @@ -61,7 +64,7 @@ impl IntegrationTestSetup {
let consensus_manager_config = consensus_manager_configs.remove(0);
let mempool_p2p_config = mempool_p2p_configs.remove(0);
let sequencer = IntegrationSequencerSetup::new(
accounts.clone(),
accounts.to_vec(),
sequencer_id,
chain_info.clone(),
consensus_manager_config,
Expand Down Expand Up @@ -141,7 +144,7 @@ pub struct IntegrationSequencerSetup {
impl IntegrationSequencerSetup {
#[instrument(skip(accounts, chain_info, consensus_manager_config), level = "debug")]
pub async fn new(
accounts: Vec<Contract>,
accounts: Vec<AccountTransactionGenerator>,
sequencer_index: usize,
chain_info: ChainInfo,
consensus_manager_config: ConsensusManagerConfig,
Expand Down
54 changes: 43 additions & 11 deletions crates/starknet_integration_tests/src/state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use blockifier::test_utils::{CairoVersion, RunnableCairo1, BALANCE};
use blockifier::versioned_constants::VersionedConstants;
use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass;
use indexmap::IndexMap;
use mempool_test_utils::starknet_api_test_utils::Contract;
use mempool_test_utils::starknet_api_test_utils::{AccountTransactionGenerator, Contract};
use papyrus_common::pending_classes::PendingClasses;
use papyrus_rpc::{run_server, RpcConfig};
use papyrus_storage::body::BodyStorageWriter;
Expand Down Expand Up @@ -57,7 +57,10 @@ pub struct StorageTestSetup {
}

impl StorageTestSetup {
pub fn new(test_defined_accounts: Vec<Contract>, chain_info: &ChainInfo) -> Self {
pub fn new(
test_defined_accounts: Vec<AccountTransactionGenerator>,
chain_info: &ChainInfo,
) -> Self {
let ((_, mut batcher_storage_writer), batcher_storage_config, batcher_storage_file_handle) =
TestStorageBuilder::default()
.scope(StorageScope::StateOnly)
Expand Down Expand Up @@ -86,7 +89,7 @@ impl StorageTestSetup {
fn create_test_state(
storage_writer: &mut StorageWriter,
chain_info: &ChainInfo,
test_defined_accounts: Vec<Contract>,
test_defined_accounts: Vec<AccountTransactionGenerator>,
) {
let into_contract = |contract: FeatureContract| Contract {
contract,
Expand Down Expand Up @@ -115,7 +118,7 @@ fn create_test_state(
fn initialize_papyrus_test_state(
storage_writer: &mut StorageWriter,
chain_info: &ChainInfo,
test_defined_accounts: Vec<Contract>,
test_defined_accounts: Vec<AccountTransactionGenerator>,
default_test_contracts: Vec<Contract>,
erc20_contract: Contract,
) {
Expand All @@ -126,8 +129,11 @@ fn initialize_papyrus_test_state(
&erc20_contract,
);

let contract_classes_to_retrieve =
test_defined_accounts.into_iter().chain(default_test_contracts).chain([erc20_contract]);
let contract_classes_to_retrieve = test_defined_accounts
.into_iter()
.map(|acc| acc.account)
.chain(default_test_contracts)
.chain([erc20_contract]);
let sierra_vec: Vec<_> = prepare_sierra_classes(contract_classes_to_retrieve.clone());
let (cairo0_contract_classes, cairo1_contract_classes) =
prepare_compiled_contract_classes(contract_classes_to_retrieve);
Expand All @@ -143,7 +149,7 @@ fn initialize_papyrus_test_state(

fn prepare_state_diff(
chain_info: &ChainInfo,
test_defined_accounts: &[Contract],
test_defined_accounts: &[AccountTransactionGenerator],
default_test_contracts: &[Contract],
erc20_contract: &Contract,
) -> ThinStateDiff {
Expand All @@ -161,7 +167,17 @@ fn prepare_state_diff(
// state_diff_builder.set_contracts(accounts_defined_in_the_test).declare().fund();
// ```
// or use declare txs and transfers for both.
state_diff_builder.inject_accounts_into_state(test_defined_accounts);
let (deployed_accounts, undeployed_accounts): (Vec<_>, Vec<_>) =
test_defined_accounts.iter().partition(|account| account.is_deployed());

let deployed_accounts_contracts: Vec<_> =
deployed_accounts.iter().map(|acc| acc.account).collect();
let undeployed_accounts_contracts: Vec<_> =
undeployed_accounts.iter().map(|acc| acc.account).collect();

state_diff_builder.inject_deployed_accounts_into_state(deployed_accounts_contracts.as_slice());
state_diff_builder
.inject_undeployed_accounts_into_state(undeployed_accounts_contracts.as_slice());

state_diff_builder.build()
}
Expand Down Expand Up @@ -385,9 +401,25 @@ impl<'a> ThinStateDiffBuilder<'a> {
self
}

// TODO(deploy_account_support): delete method once we have batcher with execution.
fn inject_accounts_into_state(&mut self, accounts_defined_in_the_test: &'a [Contract]) {
self.set_contracts(accounts_defined_in_the_test).declare().deploy().fund();
fn inject_deployed_accounts_into_state(
&mut self,
deployed_accounts_defined_in_the_test: &'a [Contract],
) {
self.set_contracts(deployed_accounts_defined_in_the_test).declare().deploy().fund();

// Set nonces as 1 in the state so that subsequent invokes can pass validation.
self.nonces = self
.deployed_contracts
.iter()
.map(|(&address, _)| (address, Nonce(Felt::ONE)))
.collect();
}

fn inject_undeployed_accounts_into_state(
&mut self,
undeployed_accounts_defined_in_the_test: &'a [Contract],
) {
self.set_contracts(undeployed_accounts_defined_in_the_test).declare().fund();

// Set nonces as 1 in the state so that subsequent invokes can pass validation.
self.nonces = self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async fn setup(
configure_tracing().await;
let accounts = tx_generator.accounts();
let chain_info = create_chain_info();
let storage_for_test = StorageTestSetup::new(accounts, &chain_info);
let storage_for_test = StorageTestSetup::new(accounts.to_vec(), &chain_info);
let mut available_ports = AvailablePorts::new(test_identifier.into(), 0);

// Derive the configuration for the mempool node.
Expand Down

0 comments on commit dce8816

Please sign in to comment.