Skip to content

Commit

Permalink
feat: check that the compiled class hash matches the supplied class (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ArniStarkware authored Jul 1, 2024
1 parent d7017e8 commit 1e568dc
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 4 deletions.
6 changes: 6 additions & 0 deletions crates/gateway/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use blockifier::transaction::errors::TransactionExecutionError;
use cairo_vm::types::errors::program_errors::ProgramError;
use serde_json::{Error as SerdeError, Value};
use starknet_api::block::{BlockNumber, GasPrice};
use starknet_api::core::CompiledClassHash;
use starknet_api::transaction::{Resource, ResourceBounds};
use starknet_api::StarknetApiError;
use thiserror::Error;
Expand All @@ -19,6 +20,11 @@ use crate::compiler_version::{VersionId, VersionIdError};
pub enum GatewayError {
#[error(transparent)]
CompilationError(#[from] starknet_sierra_compile::compile::CompilationUtilError),
#[error(
"The supplied compiled class hash {supplied:?} does not match the hash of the Casm class \
compiled from the supplied Sierra {hash_result:?}."
)]
CompiledClassHashMismatch { supplied: CompiledClassHash, hash_result: CompiledClassHash },
#[error(transparent)]
DeclaredContractClassError(#[from] ContractClassError),
#[error(transparent)]
Expand Down
11 changes: 11 additions & 0 deletions crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use axum::extract::State;
use axum::routing::{get, post};
use axum::{Json, Router};
use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractClassV1};
use blockifier::execution::execution_utils::felt_to_stark_felt;
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_api::transaction::TransactionHash;
use starknet_mempool_types::communication::SharedMempoolClient;
Expand Down Expand Up @@ -157,6 +159,15 @@ pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResu
}
};

let hash_result =
CompiledClassHash(felt_to_stark_felt(&casm_contract_class.compiled_class_hash()));
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}

// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
Expand Down
62 changes: 61 additions & 1 deletion crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
use std::sync::Arc;

use assert_matches::assert_matches;
use axum::body::{Bytes, HttpBody};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use blockifier::context::ChainInfo;
use blockifier::execution::contract_class::ContractClass;
use blockifier::test_utils::CairoVersion;
use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError;
use rstest::{fixture, rstest};
use starknet_api::rpc_transaction::RPCTransaction;
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_api::transaction::TransactionHash;
use starknet_mempool::communication::create_mempool_server;
use starknet_mempool::mempool::Mempool;
use starknet_mempool_types::communication::{MempoolClientImpl, MempoolRequestAndResponseSender};
use starknet_sierra_compile::compile::CompilationUtilError;
use tokio::sync::mpsc::channel;
use tokio::task;

use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig};
use crate::errors::GatewayError;
use crate::gateway::{add_tx, compile_contract_class, AppState, SharedMempoolClient};
use crate::starknet_api_test_utils::{declare_tx, deploy_account_tx, invoke_tx};
use crate::state_reader_test_utils::{
Expand Down Expand Up @@ -103,6 +109,60 @@ async fn test_add_tx(
assert_eq!(tx_hash, serde_json::from_slice(response_bytes).unwrap());
}

#[test]
fn test_compile_contract_class_compiled_class_hash_missmatch() {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
let expected_hash_result = tx.compiled_class_hash;
let supplied_hash = CompiledClassHash::default();

tx.compiled_class_hash = supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
if supplied == supplied_hash && hash_result == expected_hash_result
);
}

#[test]
fn test_compile_contract_class_bad_sierra() {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
// Truncate the sierra program to trigger an error.
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
AllowedLibfuncsError::SierraProgramError
))
)
}

#[test]
fn test_compile_contract_class() {
let declare_tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(declare_tx) => declare_tx
);
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
let contract_class = &declare_tx_v3.contract_class;

let class_info = compile_contract_class(&declare_tx).unwrap();
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
assert_eq!(class_info.abi_length(), contract_class.abi.len());
}

async fn to_bytes(res: Response) -> Bytes {
res.into_body().collect().await.unwrap().to_bytes()
}
Expand Down
9 changes: 7 additions & 2 deletions crates/gateway/src/starknet_api_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use starknet_api::transaction::{
TransactionSignature, TransactionVersion,
};
use starknet_api::{calldata, stark_felt};
use test_utils::{get_absolute_path, CONTRACT_CLASS_FILE, TEST_FILES_FOLDER};
use test_utils::{
get_absolute_path, COMPILED_CLASS_HASH_OF_CONTRACT_CLASS, CONTRACT_CLASS_FILE,
TEST_FILES_FOLDER,
};

use crate::{declare_tx_args, deploy_account_tx_args, invoke_tx_args};

Expand Down Expand Up @@ -97,6 +100,7 @@ pub fn declare_tx() -> RPCTransaction {
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
let json_file_path = Path::new(CONTRACT_CLASS_FILE);
let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap();
let compiled_class_hash = CompiledClassHash(stark_felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS));

let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
let account_address = account_contract.get_instance_address(0);
Expand All @@ -108,7 +112,8 @@ pub fn declare_tx() -> RPCTransaction {
sender_address: account_address,
resource_bounds: executable_resource_bounds_mapping(),
nonce,
contract_class
class_hash: compiled_class_hash,
contract_class,
))
}

Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use crate::stateful_transaction_validator::StatefulTransactionValidator;
declare_tx(),
local_test_state_reader_factory(CairoVersion::Cairo1, false),
Ok(TransactionHash(StarkFelt::try_from(
"0x0278ed2700d5a30254a6b895d4e1140438d7d1a3b2b2ce0c096a9d5ee1c61f39"
"0x02da54b89e00d2e201f8e3ed2bcc715a69e89aefdce88aff2d2facb8dec55c0a"
).unwrap()))
)]
#[case::invalid_tx(
Expand Down
2 changes: 2 additions & 0 deletions crates/test_utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::path::{Path, PathBuf};

pub const TEST_FILES_FOLDER: &str = "crates/test_utils/test_files";
pub const CONTRACT_CLASS_FILE: &str = "contract_class.json";
pub const COMPILED_CLASS_HASH_OF_CONTRACT_CLASS: &str =
"0x01e4f1248860f32c336f93f2595099aaa4959be515e40b75472709ef5243ae17";

/// Returns the absolute path from the project root.
pub fn get_absolute_path(relative_path: &str) -> PathBuf {
Expand Down

0 comments on commit 1e568dc

Please sign in to comment.