diff --git a/src/accounts/MinimalBatchExecutor.sol b/src/accounts/MinimalBatchExecutor.sol new file mode 100644 index 000000000..aa2b836d4 --- /dev/null +++ b/src/accounts/MinimalBatchExecutor.sol @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.4; + +/// @notice Minimal batch executor mixin. +/// @author Solady (https://github.com/vectorized/solady/blob/main/src/accounts/MinimalBatchExecutor.sol) +abstract contract MinimalBatchExecutor { + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* STRUCTS */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev Call struct for the `execute` function. + struct Call { + address target; + uint256 value; + bytes data; + } + + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* FUNCTIONS TO OVERRIDE */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev Ensures that `execute` can only be called by the correct caller or `authData`. + function _authorizeExecute(Call[] calldata calls, bytes calldata authData) internal virtual; + + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* EXECUTE */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev Executes the `calls` and returns the results. + /// Reverts and bubbles up error if any call fails. + function execute(Call[] calldata calls, bytes calldata authData) + public + payable + virtual + returns (bytes[] memory results) + { + _authorizeExecute(calls, authData); + return _execute(calls); + } + + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* SIGNALING */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev This function is provided for frontends to detect support. + function minimalBatchExecutorVersion() public pure virtual returns (uint256) { + return 1; // This number may change. + } + + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* INTERNAL HELPERS */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev Executes the `calls` and returns the results. + /// Reverts and bubbles up error if any call fails. + function _execute(Call[] calldata calls) internal virtual returns (bytes[] memory results) { + /// @solidity memory-safe-assembly + assembly { + results := mload(0x40) // Grab the free memory pointer. + mstore(results, calls.length) // Store the length of results. + mstore(0x40, add(add(results, 0x20), shl(5, calls.length))) // Allocate memory. + } + for (uint256 i; i != calls.length;) { + address target; + uint256 value; + bytes calldata data; + /// @solidity memory-safe-assembly + assembly { + let c := add(calls.offset, calldataload(add(calls.offset, shl(5, i)))) + target := calldataload(c) + value := calldataload(add(c, 0x20)) + let o := add(c, calldataload(add(c, 0x40))) + data.offset := add(o, 0x20) + data.length := calldataload(o) + i := add(i, 1) + } + bytes memory r = _execute(target, value, data); + /// @solidity memory-safe-assembly + assembly { + mstore(add(results, shl(5, i)), r) // Set `results[i]` to `r`. + } + } + } + + /// @dev Executes the `calls` and returns the result. + /// Reverts and bubbles up error if any call fails. + function _execute(address target, uint256 value, bytes calldata data) + internal + virtual + returns (bytes memory result) + { + /// @solidity memory-safe-assembly + assembly { + result := mload(0x40) // Grab the free memory pointer. + calldatacopy(result, data.offset, data.length) + if iszero(call(gas(), target, value, result, data.length, codesize(), 0x00)) { + // Bubble up the revert if the call reverts. + returndatacopy(result, 0x00, returndatasize()) + revert(result, returndatasize()) + } + mstore(result, returndatasize()) // Store the length. + let o := add(result, 0x20) + returndatacopy(o, 0x00, returndatasize()) // Copy the returndata. + mstore(0x40, add(o, returndatasize())) // Allocate the memory. + } + } +} diff --git a/test/MinimalBatchExecutor.t.sol b/test/MinimalBatchExecutor.t.sol new file mode 100644 index 000000000..27e3f67a3 --- /dev/null +++ b/test/MinimalBatchExecutor.t.sol @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.4; + +import "./utils/SoladyTest.sol"; +import { + MinimalBatchExecutor, + MockMinimalBatchExecutor +} from "./utils/mocks/MockMinimalBatchExecutor.sol"; +import {LibClone} from "../src/utils/LibClone.sol"; + +contract MinimalBatchExecutorTest is SoladyTest { + error CustomError(); + + MockMinimalBatchExecutor mbe; + + address target; + + function setUp() public { + mbe = new MockMinimalBatchExecutor(); + target = LibClone.clone(address(this)); + } + + function revertsWithCustomError() external payable { + revert CustomError(); + } + + function returnsBytes(bytes memory b) external payable returns (bytes memory) { + return b; + } + + function returnsHash(bytes memory b) external payable returns (bytes32) { + return keccak256(b); + } + + function testMinimalBatchExecutor() public { + vm.deal(address(this), 1 ether); + + MinimalBatchExecutor.Call[] memory calls = new MinimalBatchExecutor.Call[](2); + + calls[0].target = target; + calls[0].value = 123; + calls[0].data = abi.encodeWithSignature("returnsBytes(bytes)", "hehe"); + + calls[1].target = target; + calls[1].value = 789; + calls[1].data = abi.encodeWithSignature("returnsHash(bytes)", "lol"); + + bytes[] memory results = mbe.execute{value: _totalValue(calls)}(calls, ""); + + assertEq(results.length, 2); + assertEq(abi.decode(results[0], (bytes)), "hehe"); + assertEq(abi.decode(results[1], (bytes32)), keccak256("lol")); + } + + function testMinimalBatchExecutorForRevert() public { + MinimalBatchExecutor.Call[] memory calls = new MinimalBatchExecutor.Call[](1); + calls[0].target = target; + calls[0].value = 0; + calls[0].data = abi.encodeWithSignature("revertsWithCustomError()"); + + vm.expectRevert(CustomError.selector); + mbe.execute{value: _totalValue(calls)}(calls, ""); + } + + struct Payload { + bytes data; + uint256 mode; + } + + function testMinimalBatchExecutor(bytes32) public { + vm.deal(address(this), 1 ether); + + MinimalBatchExecutor.Call[] memory calls = + new MinimalBatchExecutor.Call[](_randomUniform() & 3); + Payload[] memory payloads = new Payload[](calls.length); + + for (uint256 i; i < calls.length; ++i) { + calls[i].target = target; + calls[i].value = _randomUniform() & 0xff; + bytes memory data = _truncateBytes(_randomBytes(), 0x1ff); + payloads[i].data = data; + if (_randomChance(2)) { + payloads[i].mode = 0; + calls[i].data = abi.encodeWithSignature("returnsBytes(bytes)", data); + } else { + payloads[i].mode = 1; + calls[i].data = abi.encodeWithSignature("returnsHash(bytes)", data); + } + } + + bytes[] memory results = mbe.executeDirect{value: _totalValue(calls)}(calls); + for (uint256 i; i < calls.length; ++i) { + if (payloads[i].mode == 0) { + assertEq(abi.decode(results[i], (bytes)), payloads[i].data); + } else { + assertEq(abi.decode(results[i], (bytes32)), keccak256(payloads[i].data)); + } + } + + if (calls.length != 0 && _randomChance(32)) { + calls[_randomUniform() % calls.length].data = + abi.encodeWithSignature("revertsWithCustomError()"); + vm.expectRevert(CustomError.selector); + mbe.executeDirect{value: _totalValue(calls)}(calls); + } + } + + function _totalValue(MinimalBatchExecutor.Call[] memory calls) + internal + pure + returns (uint256 result) + { + unchecked { + for (uint256 i; i < calls.length; ++i) { + result += calls[i].value; + } + } + } +} diff --git a/test/utils/mocks/MockMinimalBatchExecutor.sol b/test/utils/mocks/MockMinimalBatchExecutor.sol new file mode 100644 index 000000000..068b42d3b --- /dev/null +++ b/test/utils/mocks/MockMinimalBatchExecutor.sol @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.4; + +import {MinimalBatchExecutor} from "../../../src/accounts/MinimalBatchExecutor.sol"; +import {Brutalizer} from "../Brutalizer.sol"; + +/// @dev WARNING! This mock is strictly intended for testing purposes only. +/// Do NOT copy anything here into production code unless you really know what you are doing. +contract MockMinimalBatchExecutor is MinimalBatchExecutor, Brutalizer { + function _authorizeExecute(Call[] calldata calls, bytes calldata authData) + internal + virtual + override + {} + + function executeDirect(Call[] calldata calls) + public + payable + virtual + returns (bytes[] memory results) + { + _misalignFreeMemoryPointer(); + _brutalizeMemory(); + results = _execute(calls); + _checkMemory(); + } +}