Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: check that the compiled class hash matches the supplied class #246

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/gateway/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use blockifier::transaction::errors::TransactionExecutionError;
use cairo_vm::types::errors::program_errors::ProgramError;
use serde_json::{Error as SerdeError, Value};
use starknet_api::block::{BlockNumber, GasPrice};
use starknet_api::core::CompiledClassHash;
use starknet_api::transaction::{Resource, ResourceBounds};
use starknet_api::StarknetApiError;
use thiserror::Error;
Expand All @@ -19,6 +20,11 @@ use crate::compiler_version::{VersionId, VersionIdError};
pub enum GatewayError {
#[error(transparent)]
CompilationError(#[from] starknet_sierra_compile::compile::CompilationUtilError),
#[error(
"The supplied compiled class hash {supplied:?} does not match the hash of the Casm class \
compiled from the supplied Sierra {hash_result:?}."
)]
CompiledClassHashMismatch { supplied: CompiledClassHash, hash_result: CompiledClassHash },
#[error(transparent)]
DeclaredContractClassError(#[from] ContractClassError),
#[error(transparent)]
Expand Down
11 changes: 11 additions & 0 deletions crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use axum::extract::State;
use axum::routing::{get, post};
use axum::{Json, Router};
use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractClassV1};
use blockifier::execution::execution_utils::felt_to_stark_felt;
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_api::transaction::TransactionHash;
use starknet_mempool_types::communication::SharedMempoolClient;
Expand Down Expand Up @@ -157,6 +159,15 @@ pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResu
}
};

let hash_result =
CompiledClassHash(felt_to_stark_felt(&casm_contract_class.compiled_class_hash()));
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}

// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
Expand Down
62 changes: 61 additions & 1 deletion crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
use std::sync::Arc;

use assert_matches::assert_matches;
use axum::body::{Bytes, HttpBody};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use blockifier::context::ChainInfo;
use blockifier::execution::contract_class::ContractClass;
use blockifier::test_utils::CairoVersion;
use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError;
use rstest::{fixture, rstest};
use starknet_api::rpc_transaction::RPCTransaction;
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_api::transaction::TransactionHash;
use starknet_mempool::communication::create_mempool_server;
use starknet_mempool::mempool::Mempool;
use starknet_mempool_types::communication::{MempoolClientImpl, MempoolRequestAndResponseSender};
use starknet_sierra_compile::compile::CompilationUtilError;
use tokio::sync::mpsc::channel;
use tokio::task;

use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig};
use crate::errors::GatewayError;
use crate::gateway::{add_tx, compile_contract_class, AppState, SharedMempoolClient};
use crate::starknet_api_test_utils::{declare_tx, deploy_account_tx, invoke_tx};
use crate::state_reader_test_utils::{
Expand Down Expand Up @@ -103,6 +109,60 @@ async fn test_add_tx(
assert_eq!(tx_hash, serde_json::from_slice(response_bytes).unwrap());
}

#[test]
fn test_compile_contract_class_compiled_class_hash_missmatch() {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
let expected_hash_result = tx.compiled_class_hash;
let supplied_hash = CompiledClassHash::default();

tx.compiled_class_hash = supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
if supplied == supplied_hash && hash_result == expected_hash_result
);
}

#[test]
fn test_compile_contract_class_bad_sierra() {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
// Truncate the sierra program to trigger an error.
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
AllowedLibfuncsError::SierraProgramError
))
)
}

#[test]
fn test_compile_contract_class() {
let declare_tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(declare_tx) => declare_tx
);
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
let contract_class = &declare_tx_v3.contract_class;

let class_info = compile_contract_class(&declare_tx).unwrap();
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
assert_eq!(class_info.abi_length(), contract_class.abi.len());
}

async fn to_bytes(res: Response) -> Bytes {
res.into_body().collect().await.unwrap().to_bytes()
}
Expand Down
9 changes: 7 additions & 2 deletions crates/gateway/src/starknet_api_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use starknet_api::transaction::{
TransactionSignature, TransactionVersion,
};
use starknet_api::{calldata, stark_felt};
use test_utils::{get_absolute_path, CONTRACT_CLASS_FILE, TEST_FILES_FOLDER};
use test_utils::{
get_absolute_path, COMPILED_CLASS_HASH_OF_CONTRACT_CLASS, CONTRACT_CLASS_FILE,
TEST_FILES_FOLDER,
};

use crate::{declare_tx_args, deploy_account_tx_args, invoke_tx_args};

Expand Down Expand Up @@ -97,6 +100,7 @@ pub fn declare_tx() -> RPCTransaction {
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
let json_file_path = Path::new(CONTRACT_CLASS_FILE);
let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap();
let compiled_class_hash = CompiledClassHash(stark_felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS));

let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
let account_address = account_contract.get_instance_address(0);
Expand All @@ -108,7 +112,8 @@ pub fn declare_tx() -> RPCTransaction {
sender_address: account_address,
resource_bounds: executable_resource_bounds_mapping(),
nonce,
contract_class
class_hash: compiled_class_hash,
contract_class,
))
}

Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use crate::stateful_transaction_validator::StatefulTransactionValidator;
declare_tx(),
local_test_state_reader_factory(CairoVersion::Cairo1, false),
Ok(TransactionHash(StarkFelt::try_from(
"0x0278ed2700d5a30254a6b895d4e1140438d7d1a3b2b2ce0c096a9d5ee1c61f39"
"0x02da54b89e00d2e201f8e3ed2bcc715a69e89aefdce88aff2d2facb8dec55c0a"
).unwrap()))
)]
#[case::invalid_tx(
Expand Down
2 changes: 2 additions & 0 deletions crates/test_utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::path::{Path, PathBuf};

pub const TEST_FILES_FOLDER: &str = "crates/test_utils/test_files";
pub const CONTRACT_CLASS_FILE: &str = "contract_class.json";
pub const COMPILED_CLASS_HASH_OF_CONTRACT_CLASS: &str =
"0x01e4f1248860f32c336f93f2595099aaa4959be515e40b75472709ef5243ae17";

/// Returns the absolute path from the project root.
pub fn get_absolute_path(relative_path: &str) -> PathBuf {
Expand Down
Loading