diff --git a/crates/gateway/src/errors.rs b/crates/gateway/src/errors.rs index 4ff5cfe5..02e775ea 100644 --- a/crates/gateway/src/errors.rs +++ b/crates/gateway/src/errors.rs @@ -6,6 +6,7 @@ use blockifier::state::errors::StateError; use blockifier::transaction::errors::TransactionExecutionError; use cairo_vm::types::errors::program_errors::ProgramError; use starknet_api::block::BlockNumber; +use starknet_api::core::CompiledClassHash; use starknet_api::transaction::{Resource, ResourceBounds}; use starknet_api::StarknetApiError; use thiserror::Error; @@ -18,6 +19,11 @@ use crate::compiler_version::{VersionId, VersionIdError}; pub enum GatewayError { #[error(transparent)] CompilationError(#[from] starknet_sierra_compile::compile::CompilationUtilError), + #[error( + "The supplied compiled class hash {supplied:?} does not match the hash of the Casm class \ + compiled from the supplied Sierra {hash_result:?}." + )] + CompiledClassHashMismatch { supplied: CompiledClassHash, hash_result: CompiledClassHash }, #[error(transparent)] DeclaredContractClassError(#[from] ContractClassError), #[error(transparent)] diff --git a/crates/gateway/src/gateway.rs b/crates/gateway/src/gateway.rs index 54e3f38c..12464589 100644 --- a/crates/gateway/src/gateway.rs +++ b/crates/gateway/src/gateway.rs @@ -7,6 +7,8 @@ use axum::extract::State; use axum::routing::{get, post}; use axum::{Json, Router}; use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractClassV1}; +use blockifier::execution::execution_utils::felt_to_stark_felt; +use starknet_api::core::CompiledClassHash; use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; use starknet_api::transaction::TransactionHash; use starknet_mempool_types::communication::SharedMempoolClient; @@ -157,6 +159,15 @@ pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResu } }; + let hash_result = + CompiledClassHash(felt_to_stark_felt(&casm_contract_class.compiled_class_hash())); + if hash_result != tx.compiled_class_hash { + return Err(GatewayError::CompiledClassHashMismatch { + supplied: tx.compiled_class_hash, + hash_result, + }); + } + // Convert Casm contract class to Starknet contract class directly. let blockifier_contract_class = ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index 67f7c270..dc89505d 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -1,13 +1,16 @@ use std::sync::Arc; +use assert_matches::assert_matches; use axum::body::{Bytes, HttpBody}; use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use blockifier::context::ChainInfo; +use blockifier::execution::contract_class::ContractClass; use blockifier::test_utils::CairoVersion; use rstest::{fixture, rstest}; -use starknet_api::rpc_transaction::RPCTransaction; +use starknet_api::core::CompiledClassHash; +use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; use starknet_api::transaction::TransactionHash; use starknet_mempool::communication::create_mempool_server; use starknet_mempool::mempool::Mempool; @@ -16,6 +19,7 @@ use tokio::sync::mpsc::channel; use tokio::task; use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig}; +use crate::errors::GatewayError; use crate::gateway::{add_tx, compile_contract_class, AppState, SharedMempoolClient}; use crate::starknet_api_test_utils::{declare_tx, deploy_account_tx, invoke_tx}; use crate::state_reader_test_utils::{ @@ -100,6 +104,47 @@ async fn test_add_tx( assert_eq!(tx_hash, serde_json::from_slice(response_bytes).unwrap()); } +#[test] +fn test_compile_contract_class_failure() { + let mut declare_tx_v3 = match declare_tx() { + RPCTransaction::Declare(RPCDeclareTransaction::V3(declare_tx)) => declare_tx, + _ => panic!("Invalid transaction type"), + }; + let expected_hash_result = declare_tx_v3.compiled_class_hash; + let supplied_hash = CompiledClassHash::default(); + + declare_tx_v3.compiled_class_hash = supplied_hash; + let declare_tx = RPCDeclareTransaction::V3(declare_tx_v3); + + let result = compile_contract_class(&declare_tx); + assert_matches!( + result.unwrap_err(), + GatewayError::CompiledClassHashMismatch { supplied, hash_result } + if supplied == supplied_hash && hash_result == expected_hash_result + ); +} + +#[test] +fn test_compile_contract_class() { + let declare_tx = match declare_tx() { + RPCTransaction::Declare(declare_tx) => declare_tx, + _ => panic!("Invalid transaction type"), + }; + let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; + let contract_class = &declare_tx_v3.contract_class; + + let result = compile_contract_class(&declare_tx); + assert_matches!( + result, + Ok(class_info) + if ( + matches!(class_info.contract_class(), ContractClass::V1(_)) + && class_info.sierra_program_length() == contract_class.sierra_program.len() + && class_info.abi_length() == contract_class.abi.len() + ) + ); +} + async fn to_bytes(res: Response) -> Bytes { res.into_body().collect().await.unwrap().to_bytes() } diff --git a/crates/gateway/src/starknet_api_test_utils.rs b/crates/gateway/src/starknet_api_test_utils.rs index 598ba9f6..3f58bd32 100644 --- a/crates/gateway/src/starknet_api_test_utils.rs +++ b/crates/gateway/src/starknet_api_test_utils.rs @@ -18,7 +18,10 @@ use starknet_api::transaction::{ TransactionSignature, TransactionVersion, }; use starknet_api::{calldata, stark_felt}; -use test_utils::{get_absolute_path, CONTRACT_CLASS_FILE, TEST_FILES_FOLDER}; +use test_utils::{ + get_absolute_path, COMPILED_CLASS_HASH_OF_CONTRACT_CLASS, CONTRACT_CLASS_FILE, + TEST_FILES_FOLDER, +}; use crate::{declare_tx_args, deploy_account_tx_args, invoke_tx_args}; @@ -97,6 +100,7 @@ pub fn declare_tx() -> RPCTransaction { env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir."); let json_file_path = Path::new(CONTRACT_CLASS_FILE); let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap(); + let compiled_class_hash = CompiledClassHash(stark_felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS)); let cairo_version = CairoVersion::Cairo1; let account_contract = FeatureContract::AccountWithoutValidations(cairo_version); @@ -109,7 +113,8 @@ pub fn declare_tx() -> RPCTransaction { sender_address: account_address, resource_bounds: executable_resource_bounds_mapping(), nonce, - contract_class + class_hash: compiled_class_hash, + contract_class, )) } diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index f8ead84f..5553922a 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -46,7 +46,7 @@ use crate::stateful_transaction_validator::StatefulTransactionValidator; declare_tx(), local_test_state_reader_factory(CairoVersion::Cairo1, false), Ok(TransactionHash(StarkFelt::try_from( - "0x0278ed2700d5a30254a6b895d4e1140438d7d1a3b2b2ce0c096a9d5ee1c61f39" + "0x02da54b89e00d2e201f8e3ed2bcc715a69e89aefdce88aff2d2facb8dec55c0a" ).unwrap())) )] #[case::invalid_tx( diff --git a/crates/test_utils/src/lib.rs b/crates/test_utils/src/lib.rs index 23634bc6..571ec7f3 100644 --- a/crates/test_utils/src/lib.rs +++ b/crates/test_utils/src/lib.rs @@ -3,6 +3,8 @@ use std::path::{Path, PathBuf}; pub const TEST_FILES_FOLDER: &str = "crates/test_utils/test_files"; pub const CONTRACT_CLASS_FILE: &str = "contract_class.json"; +pub const COMPILED_CLASS_HASH_OF_CONTRACT_CLASS: &str = + "0x01e4f1248860f32c336f93f2595099aaa4959be515e40b75472709ef5243ae17"; /// Returns the absolute path from the project root. pub fn get_absolute_path(relative_path: &str) -> PathBuf {