Skip to content

Commit

Permalink
rework expansion for contract to use ConnectedAccount
Browse files Browse the repository at this point in the history
  • Loading branch information
glihm committed Oct 10, 2023
1 parent ca6d5f0 commit 06e1d3a
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 67 deletions.
25 changes: 19 additions & 6 deletions examples/abigen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,29 @@ use starknet::{
signers::{LocalWallet, SigningKey},
};

use std::sync::Arc;

// Generate the bindings for the contract and also includes
// all the structs and enums present in the ABI with the exact
// same name.
abigen!(TokenContract, "./examples/contracts_abis/mini_erc20.json");

#[tokio::main]
async fn main() {
let provider = Arc::new(SequencerGatewayProvider::starknet_alpha_goerli());
let provider = SequencerGatewayProvider::starknet_alpha_goerli();
println!("provider {:?}", provider);
let eth_goerli_token_address = FieldElement::from_hex_be(
"0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7",
)
.unwrap();

let token_contract = TokenContract::new(eth_goerli_token_address, Arc::clone(&provider));
// If you only plan to call views functions, you can use the `Reader`, which
// only requires a provider along with your contract address.
let token_contract = TokenContractReader::new(eth_goerli_token_address, &provider);

// To call a view, there is no need to initialize an account. You can directly
// use the name of the method in the ABI to realize the call.
let balance: u256 = token_contract
.balanceOf(&ContractAddress(
FieldElement::from_hex_be("YOUR_ACCOUNT_ADDRESS_HEX_HERE").unwrap(),
FieldElement::from_hex_be("YOUR_HEX_CONTRACT_ADDRESS_HERE").unwrap(),
))
.await
.expect("Call to get balance failed");
Expand All @@ -55,7 +56,19 @@ async fn main() {
ExecutionEncoding::Legacy,
);

let token_contract = token_contract.with_account(Arc::new(account));
// The `TokenContract` also contains a reader field that you can use if you need both
// to call external and views with the same instance.
let token_contract = TokenContract::new(eth_goerli_token_address, &account);

// Example here of querying again the balance, using the internal reader of the
// contract setup with an account.
token_contract
.reader
.balanceOf(&ContractAddress(
FieldElement::from_hex_be("YOUR_HEX_CONTRACT_ADDRESS_HERE").unwrap(),
))
.await
.expect("Call to get balance failed");

let _ = token_contract
.approve(
Expand Down
13 changes: 8 additions & 5 deletions examples/abigen_events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ async fn main() {
let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(rpc_url.clone())));

let signer = LocalWallet::from(SigningKey::from_secret_scalar(
FieldElement::from_hex_be("YOUR_PRIVATE_KEY_IN_HEX_HERE").unwrap(),
FieldElement::from_hex_be("0x1800000000300000180000000000030000000000003006001800006600")
.unwrap(),
));
let address = FieldElement::from_hex_be("YOUR_ACCOUNT_CONTRACT_ADDRESS_IN_HEX_HERE").unwrap();
let address = FieldElement::from_hex_be(
"0x517ececd29116499f4a1b64b094da79ba08dfd54a3edaa316134c41f8160973",
)
.unwrap();
let account = SingleOwnerAccount::new(
provider.clone(),
signer,
Expand All @@ -30,10 +34,9 @@ async fn main() {
ExecutionEncoding::Legacy,
);

let contract_address = FieldElement::from_hex_be("CONTRACT_ADDRESS_HEX").unwrap();
let contract_address = FieldElement::from_hex_be("YOUR_CONTRACT_ADDRESS_HEX").unwrap();

let event_contract =
Contract::new(contract_address, Arc::clone(&provider)).with_account(Arc::new(account));
let event_contract = Contract::new(contract_address, &account);

// Let emits some events by calling two externals.
event_contract
Expand Down
53 changes: 27 additions & 26 deletions starknet-macros/src/abigen/expand/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,53 @@
//! default configuration for provider and account, if any.
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;

use syn::Ident;

use super::utils;

pub struct CairoContract;

impl CairoContract {
pub fn expand(contract_name: Ident) -> TokenStream2 {
quote! {
let reader = utils::str_to_ident(format!("{}Reader", contract_name).as_str());
let q = quote! {

#[derive(Debug)]
pub struct #contract_name<P>
where
P: starknet::providers::Provider + Send + Sync, <P as starknet::providers::Provider>::Error: 'static
{
pub struct #contract_name<'a, A: starknet::accounts::ConnectedAccount + Sync> {
pub address: starknet::core::types::FieldElement,
pub provider: std::sync::Arc<P>,
pub account: std::option::Option<std::sync::Arc<starknet::accounts::SingleOwnerAccount<std::sync::Arc<P>, starknet::signers::LocalWallet>>>,
pub account: &'a A,
pub reader: #reader<'a, A::Provider>,
}

impl<'a, A: starknet::accounts::ConnectedAccount + Sync> #contract_name<'a, A> {
pub fn new(address: starknet::core::types::FieldElement, account: &'a A) -> Self {
let reader = #reader::new(address, account.provider());
Self { address, account, reader }
}
}

#[derive(Debug)]
pub struct #reader<'a, P: Provider + Sync> {
pub address: starknet::core::types::FieldElement,
pub provider: &'a P,
call_block_id: starknet::core::types::BlockId,
}

impl<P> #contract_name<P>
where
P: starknet::providers::Provider + Send + Sync, <P as starknet::providers::Provider>::Error: 'static
{
impl<'a, P: starknet::providers::Provider + Sync> #reader<'a, P> {
pub fn new(
address: starknet::core::types::FieldElement,
provider: std::sync::Arc<P>,
) -> Self {
Self {
address,
provider: std::sync::Arc::clone(&provider),
account: None,
call_block_id: starknet::core::types::BlockId::Tag(starknet::core::types::BlockTag::Pending),
}
}

pub fn with_account(mut self, account: std::sync::Arc<starknet::accounts::SingleOwnerAccount<std::sync::Arc<P>, starknet::signers::LocalWallet>>,
provider: &'a P,
) -> Self {
self.account = Some(std::sync::Arc::clone(&account));
self
let call_block_id = starknet::core::types::BlockId::Tag(starknet::core::types::BlockTag::Pending);
Self { address, provider, call_block_id }
}

pub fn set_call_block_id(mut self, block_id: starknet::core::types::BlockId) {
self.call_block_id = block_id;
}
}
}
};

q
}
}
22 changes: 3 additions & 19 deletions starknet-macros/src/abigen/expand/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ impl Expandable for CairoFunction {
},
StateMutability::External => {
quote!(-> Result<starknet::core::types::InvokeTransactionResult,
starknet::accounts::AccountError<starknet::accounts::single_owner::SignError<starknet::signers::local_wallet::SignError>, <P as starknet::providers::Provider>::Error>
>
starknet::accounts::AccountError<A::SignError, <A::Provider as starknet::providers::Provider>::Error>>
)
}
};
Expand Down Expand Up @@ -120,15 +119,6 @@ impl Expandable for CairoFunction {
use starknet::contract::abi::CairoType;
use starknet::accounts::Account;

// TODO: I don't know how to easily store the SingleOwnerAccount
// and it's generic types without complexifiying the whole typing.
// So it's constructed at every call. There is surely a better approach.
let account = match &self.account {
Some(a) => std::sync::Arc::clone(&a),
// TODO: better error handling here.
_ => panic!("Account is required to send invoke transactions")
};

let mut calldata = vec![];
#(#serializations)*

Expand All @@ -138,14 +128,8 @@ impl Expandable for CairoFunction {
calldata,
}];

let execution = account.execute(calls).fee_estimate_multiplier(2f64);
// TODO: we can have manual fee here, or it can also be estimate only.
let max_fee = execution.estimate_fee().await?.overall_fee.into();

execution
.max_fee(max_fee)
.send()
.await
// TODO: add a way for fee estimation and max fee to be parametrizable.
self.account.execute(calls).send().await
}
},
}
Expand Down
38 changes: 27 additions & 11 deletions starknet-macros/src/abigen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::collections::HashMap;

use starknet_contract::abi::cairo_types::{CAIRO_BASIC_ENUMS, CAIRO_BASIC_STRUCTS};
use starknet_contract::abi::parser::{CairoEnum, CairoEvent, CairoFunction, CairoStruct};
use starknet_core::types::contract::AbiEntry;
use starknet_core::types::contract::{AbiEntry, StateMutability};

mod expand;
use expand::contract::CairoContract;
Expand All @@ -27,6 +27,8 @@ use expand::{Expandable, ExpandableEvent};
mod contract_abi;
use contract_abi::ContractAbi;

use crate::abigen::expand::utils;

pub fn abigen_internal(input: TokenStream) -> TokenStream {
let contract_abi = parse_macro_input!(input as ContractAbi);
let contract_name = contract_abi.name;
Expand All @@ -38,11 +40,19 @@ pub fn abigen_internal(input: TokenStream) -> TokenStream {

let mut structs: HashMap<String, CairoStruct> = HashMap::new();
let mut enums: HashMap<String, CairoEnum> = HashMap::new();
let mut functions = vec![];
let mut views = vec![];
let mut externals = vec![];
let mut events = vec![];

for entry in &abi {
parse_entry(entry, &mut structs, &mut enums, &mut functions, &mut events);
parse_entry(
entry,
&mut structs,
&mut enums,
&mut externals,
&mut views,
&mut events,
);
}

for (_, cs) in structs {
Expand All @@ -60,12 +70,14 @@ pub fn abigen_internal(input: TokenStream) -> TokenStream {
tokens.push(ev.expand_impl(&events));
}

let reader = utils::str_to_ident(format!("{}Reader", contract_name).as_str());
tokens.push(quote! {
impl<P> #contract_name<P>
where
P: starknet::providers::Provider + Send + Sync, <P as starknet::providers::Provider>::Error: 'static
{
#(#functions)*
impl<'a, A: starknet::accounts::ConnectedAccount + Sync> #contract_name<'a, A> {
#(#externals)*
}

impl<'a, P: starknet::providers::Provider + Sync> #reader<'a, P> {
#(#views)*
}
});

Expand All @@ -80,7 +92,8 @@ fn parse_entry(
entry: &AbiEntry,
structs: &mut HashMap<String, CairoStruct>,
enums: &mut HashMap<String, CairoEnum>,
functions: &mut Vec<TokenStream2>,
externals: &mut Vec<TokenStream2>,
views: &mut Vec<TokenStream2>,
events: &mut Vec<CairoEvent>,
) {
match entry {
Expand Down Expand Up @@ -115,7 +128,10 @@ fn parse_entry(
// From this statement, we can safely assume that any function name is
// unique.
let cf = CairoFunction::new(&f.name, f.state_mutability.clone(), &f.inputs, &f.outputs);
functions.push(cf.expand_impl());
match f.state_mutability {
StateMutability::View => views.push(cf.expand_impl()),
StateMutability::External => externals.push(cf.expand_impl()),
}
}
AbiEntry::Event(ev) => {
if let Some(cev) = CairoEvent::new(ev) {
Expand All @@ -124,7 +140,7 @@ fn parse_entry(
}
AbiEntry::Interface(interface) => {
for entry in &interface.items {
parse_entry(entry, structs, enums, functions, events);
parse_entry(entry, structs, enums, externals, views, events);
}
}
_ => (),
Expand Down

0 comments on commit 06e1d3a

Please sign in to comment.