Skip to content

Commit

Permalink
Add safe external calls option
Browse files Browse the repository at this point in the history
  • Loading branch information
smonicas committed Jan 25, 2024
1 parent 7b4d571 commit d452ff4
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 18 deletions.
5 changes: 5 additions & 0 deletions src/cli/commands/detect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ pub struct DetectArgs {
#[arg(long, num_args(0..))]
contract_path: Option<Vec<String>>,

/// Functions name that are safe when called (e.g. they don't cause a reentrancy)
#[arg(long, num_args(0..))]
safe_external_calls: Option<Vec<String>>,

/// Detectors to run
#[arg(long, num_args(0..), conflicts_with_all(["exclude", "exclude_informational", "exclude_low", "exclude_medium", "exclude_high"]))]
detect: Option<Vec<String>>,
Expand Down Expand Up @@ -53,6 +57,7 @@ impl From<&DetectArgs> for CoreOpts {
target: args.target.clone(),
corelib: args.corelib.clone(),
contract_path: args.contract_path.clone(),
safe_external_calls: args.safe_external_calls.clone(),
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/cli/commands/print/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ pub struct PrintArgs {
#[arg(long, num_args(0..))]
contract_path: Option<Vec<String>>,

/// Functions name that are safe when called (e.g. they don't cause a reentrancy)
#[arg(long, num_args(0..))]
safe_external_calls: Option<Vec<String>>,

/// Which functions to run the printer (all, user-functions)
#[arg(short, long, default_value_t = Filter::UserFunctions)]
filter: Filter,
Expand All @@ -33,6 +37,7 @@ impl From<&PrintArgs> for CoreOpts {
target: args.target.clone(),
corelib: args.corelib.clone(),
contract_path: args.contract_path.clone(),
safe_external_calls: args.safe_external_calls.clone(),
}
}
}
Expand Down
12 changes: 11 additions & 1 deletion src/core/core_unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ pub struct CoreOpts {
pub target: PathBuf,
pub corelib: Option<PathBuf>,
pub contract_path: Option<Vec<String>>,
pub safe_external_calls: Option<Vec<String>>,
}

pub struct CoreUnit {
compilation_units: Vec<CompilationUnit>,
safe_external_calls: Option<Vec<String>>,
}

impl CoreUnit {
pub fn new(opts: CoreOpts) -> Result<Self> {
let safe_external_calls = opts.safe_external_calls.clone();
let program_compiled = compile(opts)?;
let compilation_units = program_compiled
.par_iter()
Expand All @@ -31,10 +34,17 @@ impl CoreUnit {
compilation_unit
})
.collect();
Ok(CoreUnit { compilation_units })
Ok(CoreUnit {
compilation_units,
safe_external_calls,
})
}

pub fn get_compilation_units(&self) -> &Vec<CompilationUnit> {
&self.compilation_units
}

pub fn get_safe_external_calls(&self) -> &Option<Vec<String>> {
&self.safe_external_calls
}
}
19 changes: 15 additions & 4 deletions src/detectors/read_only_reentrancy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ impl Detector for ReadOnlyReentrancy {
} = bb_info.1
{
for call in reentrancy_info.external_calls.iter() {
let external_function_call = format!(
"{}",
call.get_external_call().as_ref().unwrap().get_statement()
);

if let Some(safe_external_calls) = core.get_safe_external_calls() {
if safe_external_calls
.iter()
.any(|f_name| external_function_call.contains(f_name))
{
continue;
}
}

for written_variable in reentrancy_info.storage_variables_written.iter()
{
let written_variable_name = written_variable
Expand All @@ -82,10 +96,7 @@ impl Detector for ReadOnlyReentrancy {
message: format!(
"Read only reentrancy in {}\n\tExternal call {} done in {}\n\tVariable written after {} in {}",
view_function,
call.get_external_call()
.as_ref()
.unwrap()
.get_statement(),
external_function_call,
call.get_function(),
written_variable
.get_storage_variable_written()
Expand Down
19 changes: 15 additions & 4 deletions src/detectors/reentrancy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ impl Detector for Reentrancy {
} = bb_info.1
{
for call in reentrancy_info.external_calls.iter() {
let external_function_call = format!(
"{}",
call.get_external_call().as_ref().unwrap().get_statement()
);

if let Some(safe_external_calls) = core.get_safe_external_calls() {
if safe_external_calls
.iter()
.any(|f_name| external_function_call.contains(f_name))
{
continue;
}
}

if let Some(current_vars_read_before_call) = reentrancy_info
.variables_read_before_calls
.iter()
Expand Down Expand Up @@ -79,10 +93,7 @@ impl Detector for Reentrancy {
message: format!(
"Reentrancy in {}\n\tExternal call {} done in {}\n\tVariable written after {} in {}.",
f.name(),
call.get_external_call()
.as_ref()
.unwrap()
.get_statement(),
external_function_call,
call.get_function(),
written_variable
.get_storage_variable_written()
Expand Down
19 changes: 15 additions & 4 deletions src/detectors/reentrancy_benign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ impl Detector for ReentrancyBenign {
} = bb_info.1
{
for call in reentrancy_info.external_calls.iter() {
let external_function_call = format!(
"{}",
call.get_external_call().as_ref().unwrap().get_statement()
);

if let Some(safe_external_calls) = core.get_safe_external_calls() {
if safe_external_calls
.iter()
.any(|f_name| external_function_call.contains(f_name))
{
continue;
}
}

if let Some(current_vars_read_before_call) = reentrancy_info
.variables_read_before_calls
.iter()
Expand Down Expand Up @@ -79,10 +93,7 @@ impl Detector for ReentrancyBenign {
message: format!(
"Reentrancy in {}\n\tExternal call {} done in {}\n\tVariable written after {} in {}.",
f.name(),
call.get_external_call()
.as_ref()
.unwrap()
.get_statement(),
external_function_call,
call.get_function(),
written_variable
.get_storage_variable_written()
Expand Down
19 changes: 15 additions & 4 deletions src/detectors/reentrancy_events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,28 @@ impl Detector for ReentrancyEvents {
{
for event in reentrancy_info.events.iter() {
for call in reentrancy_info.external_calls.iter() {
let external_function_call = format!(
"{}",
call.get_external_call().as_ref().unwrap().get_statement()
);

if let Some(safe_external_calls) = core.get_safe_external_calls() {
if safe_external_calls
.iter()
.any(|f_name| external_function_call.contains(f_name))
{
continue;
}
}

results.insert(Result {
name: self.name().to_string(),
impact: self.impact(),
confidence: self.confidence(),
message: format!(
"Reentrancy in {}\n\tExternal call {} done in {}\n\tEvent emitted after {} in {}.",
f.name(),
call.get_external_call()
.as_ref()
.unwrap()
.get_statement(),
external_function_call,
call.get_function(),
event.get_event_emitted().as_ref().unwrap().get_statement(),
event.get_function()
Expand Down
9 changes: 8 additions & 1 deletion tests/detectors/read_only_reentrancy.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[starknet::interface]
trait IAnotherContract<T> {
fn foo(self: @T, a: felt252);
fn safe_foo(self: @T, a: felt252);
}

#[starknet::contract]
Expand Down Expand Up @@ -32,8 +33,14 @@ mod TestContract {
}

#[external(v0)]
fn ok(ref self: ContractState, address: ContractAddress) {
fn good1(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.foo(4);
}

#[external(v0)]
fn good2(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.safe_foo(4);
self.a.write(4);
}

}
8 changes: 8 additions & 0 deletions tests/detectors/reentrancy.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[starknet::interface]
trait IAnotherContract<T> {
fn foo(self: @T, a: felt252);
fn safe_foo(self: @T, a: felt252);
}

#[starknet::contract]
Expand All @@ -22,6 +23,13 @@ mod TestContract {
IAnotherContractDispatcher { contract_address: address }.foo(a);
}

#[external(v0)]
fn good2(ref self: ContractState, address: ContractAddress) {
let a = self.a.read();
IAnotherContractDispatcher { contract_address: address }.safe_foo(a);
self.a.write(4);
}

#[external(v0)]
fn bad1(ref self: ContractState, address: ContractAddress) {
let a = self.a.read();
Expand Down
7 changes: 7 additions & 0 deletions tests/detectors/reentrancy_benign.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[starknet::interface]
trait IAnotherContract<T> {
fn foo(self: @T, a: felt252);
fn safe_foo(self: @T, a: felt252);
}

#[starknet::contract]
Expand All @@ -22,6 +23,12 @@ mod TestContract {
IAnotherContractDispatcher { contract_address: address }.foo(a);
}

#[external(v0)]
fn good2(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.safe_foo(4);
self.a.write(4);
}

#[external(v0)]
fn bad1(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.foo(4);
Expand Down
7 changes: 7 additions & 0 deletions tests/detectors/reentrancy_events.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[starknet::interface]
trait IAnotherContract<T> {
fn foo(self: @T, a: felt252);
fn safe_foo(self: @T, a: felt252);
}

#[starknet::contract]
Expand All @@ -27,6 +28,12 @@ mod TestContract {
IAnotherContractDispatcher { contract_address: address }.foo(4);
}

#[external(v0)]
fn good2(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.safe_foo(4);
self.emit(MyEvent { });
}

#[external(v0)]
fn bad1(ref self: ContractState, address: ContractAddress) {
IAnotherContractDispatcher { contract_address: address }.foo(4);
Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ fn test_detectors() {
env::var("CARGO_MANIFEST_DIR").unwrap() + "/corelib/src",
)),
contract_path: None,
safe_external_calls: Some(vec!["::safe_foo".to_string()]),
};
let core = CoreUnit::new(opts).unwrap();
let mut results = get_detectors()
Expand Down

0 comments on commit d452ff4

Please sign in to comment.