Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add safe external calls option #59

Merged
merged 1 commit into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/cli/commands/detect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ pub struct DetectArgs {
#[arg(long, num_args(0..))]
contract_path: Option<Vec<String>>,

/// Functions name that are safe when called (e.g. they don't cause a reentrancy)
#[arg(long, num_args(0..))]
safe_external_calls: Option<Vec<String>>,

/// Detectors to run
#[arg(long, num_args(0..), conflicts_with_all(["exclude", "exclude_informational", "exclude_low", "exclude_medium", "exclude_high"]))]
detect: Option<Vec<String>>,
Expand Down Expand Up @@ -53,6 +57,7 @@ impl From<&DetectArgs> for CoreOpts {
target: args.target.clone(),
corelib: args.corelib.clone(),
contract_path: args.contract_path.clone(),
safe_external_calls: args.safe_external_calls.clone(),
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/cli/commands/print/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ pub struct PrintArgs {
#[arg(long, num_args(0..))]
contract_path: Option<Vec<String>>,

/// Functions name that are safe when called (e.g. they don't cause a reentrancy)
#[arg(long, num_args(0..))]
safe_external_calls: Option<Vec<String>>,

/// Which functions to run the printer (all, user-functions)
#[arg(short, long, default_value_t = Filter::UserFunctions)]
filter: Filter,
Expand All @@ -33,6 +37,7 @@ impl From<&PrintArgs> for CoreOpts {
target: args.target.clone(),
corelib: args.corelib.clone(),
contract_path: args.contract_path.clone(),
safe_external_calls: args.safe_external_calls.clone(),
}
}
}
Expand Down
12 changes: 11 additions & 1 deletion src/core/core_unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ pub struct CoreOpts {
pub target: PathBuf,
pub corelib: Option<PathBuf>,
pub contract_path: Option<Vec<String>>,
pub safe_external_calls: Option<Vec<String>>,
}

pub struct CoreUnit {
compilation_units: Vec<CompilationUnit>,
safe_external_calls: Option<Vec<String>>,
}

impl CoreUnit {
pub fn new(opts: CoreOpts) -> Result<Self> {
let safe_external_calls = opts.safe_external_calls.clone();
let program_compiled = compile(opts)?;
let compilation_units = program_compiled
.par_iter()
Expand All @@ -31,10 +34,17 @@ impl CoreUnit {
compilation_unit
})
.collect();
Ok(CoreUnit { compilation_units })
Ok(CoreUnit {
compilation_units,
safe_external_calls,
})
}

pub fn get_compilation_units(&self) -> &Vec<CompilationUnit> {
&self.compilation_units
}

pub fn get_safe_external_calls(&self) -> &Option<Vec<String>> {
&self.safe_external_calls
}
}
19 changes: 15 additions & 4 deletions src/detectors/read_only_reentrancy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ impl Detector for ReadOnlyReentrancy {
} = bb_info.1
{
for call in reentrancy_info.external_calls.iter() {
let external_function_call = format!(
"{}",
call.get_external_call().as_ref().unwrap().get_statement()
);

if let Some(safe_external_calls) = core.get_safe_external_calls() {
if safe_external_calls
.iter()
.any(|f_name| external_function_call.contains(f_name))
{
continue;
}
}

for written_variable in reentrancy_info.storage_variables_written.iter()
{
let written_variable_name = written_variable
Expand All @@ -82,10 +96,7 @@ impl Detector for ReadOnlyReentrancy {
message: format!(
"Read only reentrancy in {}\n\tExternal call {} done in {}\n\tVariable written after {} in {}",
view_function,
call.get_external_call()
.as_ref()
.unwrap()
.get_statement(),
external_function_call,
call.get_function(),
written_variable
.get_storage_variable_written()
Expand Down
19 changes: 15 additions & 4 deletions src/detectors/reentrancy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ impl Detector for Reentrancy {
} = bb_info.1
{
for call in reentrancy_info.external_calls.iter() {
let external_function_call = format!(
"{}",
call.get_external_call().as_ref().unwrap().get_statement()
);

if let Some(safe_external_calls) = core.get_safe_external_calls() {
if safe_external_calls
.iter()
.any(|f_name| external_function_call.contains(f_name))
{
continue;
}
}

if let Some(current_vars_read_before_call) = reentrancy_info
.variables_read_before_calls
.iter()
Expand Down Expand Up @@ -79,10 +93,7 @@ impl Detector for Reentrancy {
message: format!(
"Reentrancy in {}\n\tExternal call {} done in {}\n\tVariable written after {} in {}.",
f.name(),
call.get_external_call()
.as_ref()
.unwrap()
.get_statement(),
external_function_call,
call.get_function(),
written_variable
.get_storage_variable_written()
Expand Down
19 changes: 15 additions & 4 deletions src/detectors/reentrancy_benign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ impl Detector for ReentrancyBenign {
} = bb_info.1
{
for call in reentrancy_info.external_calls.iter() {
let external_function_call = format!(
"{}",
call.get_external_call().as_ref().unwrap().get_statement()
);

if let Some(safe_external_calls) = core.get_safe_external_calls() {
if safe_external_calls
.iter()
.any(|f_name| external_function_call.contains(f_name))
{
continue;
}
}

if let Some(current_vars_read_before_call) = reentrancy_info
.variables_read_before_calls
.iter()
Expand Down Expand Up @@ -79,10 +93,7 @@ impl Detector for ReentrancyBenign {
message: format!(
"Reentrancy in {}\n\tExternal call {} done in {}\n\tVariable written after {} in {}.",
f.name(),
call.get_external_call()
.as_ref()
.unwrap()
.get_statement(),
external_function_call,
call.get_function(),
written_variable
.get_storage_variable_written()
Expand Down
19 changes: 15 additions & 4 deletions src/detectors/reentrancy_events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,28 @@ impl Detector for ReentrancyEvents {
{
for event in reentrancy_info.events.iter() {
for call in reentrancy_info.external_calls.iter() {
let external_function_call = format!(
"{}",
call.get_external_call().as_ref().unwrap().get_statement()
);

if let Some(safe_external_calls) = core.get_safe_external_calls() {
if safe_external_calls
.iter()
.any(|f_name| external_function_call.contains(f_name))
{
continue;
}
}

results.insert(Result {
name: self.name().to_string(),
impact: self.impact(),
confidence: self.confidence(),
message: format!(
"Reentrancy in {}\n\tExternal call {} done in {}\n\tEvent emitted after {} in {}.",
f.name(),
call.get_external_call()
.as_ref()
.unwrap()
.get_statement(),
external_function_call,
call.get_function(),
event.get_event_emitted().as_ref().unwrap().get_statement(),
event.get_function()
Expand Down
9 changes: 8 additions & 1 deletion tests/detectors/read_only_reentrancy.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[starknet::interface]
trait IAnotherContract<T> {
fn foo(self: @T, a: felt252);
fn safe_foo(self: @T, a: felt252);
}

#[starknet::contract]
Expand Down Expand Up @@ -32,8 +33,14 @@ mod TestContract {
}

#[external(v0)]
fn ok(ref self: ContractState, address: ContractAddress) {
fn good1(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.foo(4);
}

#[external(v0)]
fn good2(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.safe_foo(4);
self.a.write(4);
}

}
8 changes: 8 additions & 0 deletions tests/detectors/reentrancy.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[starknet::interface]
trait IAnotherContract<T> {
fn foo(self: @T, a: felt252);
fn safe_foo(self: @T, a: felt252);
}

#[starknet::contract]
Expand All @@ -22,6 +23,13 @@ mod TestContract {
IAnotherContractDispatcher { contract_address: address }.foo(a);
}

#[external(v0)]
fn good2(ref self: ContractState, address: ContractAddress) {
let a = self.a.read();
IAnotherContractDispatcher { contract_address: address }.safe_foo(a);
self.a.write(4);
}

#[external(v0)]
fn bad1(ref self: ContractState, address: ContractAddress) {
let a = self.a.read();
Expand Down
7 changes: 7 additions & 0 deletions tests/detectors/reentrancy_benign.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[starknet::interface]
trait IAnotherContract<T> {
fn foo(self: @T, a: felt252);
fn safe_foo(self: @T, a: felt252);
}

#[starknet::contract]
Expand All @@ -22,6 +23,12 @@ mod TestContract {
IAnotherContractDispatcher { contract_address: address }.foo(a);
}

#[external(v0)]
fn good2(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.safe_foo(4);
self.a.write(4);
}

#[external(v0)]
fn bad1(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.foo(4);
Expand Down
7 changes: 7 additions & 0 deletions tests/detectors/reentrancy_events.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[starknet::interface]
trait IAnotherContract<T> {
fn foo(self: @T, a: felt252);
fn safe_foo(self: @T, a: felt252);
}

#[starknet::contract]
Expand All @@ -27,6 +28,12 @@ mod TestContract {
IAnotherContractDispatcher { contract_address: address }.foo(4);
}

#[external(v0)]
fn good2(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.safe_foo(4);
self.emit(MyEvent { });
}

#[external(v0)]
fn bad1(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.foo(4);
Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ fn test_detectors() {
env::var("CARGO_MANIFEST_DIR").unwrap() + "/corelib/src",
)),
contract_path: None,
safe_external_calls: Some(vec!["::safe_foo".to_string()]),
};
let core = CoreUnit::new(opts).unwrap();
let mut results = get_detectors()
Expand Down