diff --git a/.changeset/dull-students-eat.md b/.changeset/dull-students-eat.md new file mode 100644 index 00000000000..94c4fc21ef2 --- /dev/null +++ b/.changeset/dull-students-eat.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Memory`: Add library with utilities to manipulate memory diff --git a/.changeset/sharp-scissors-drum.md b/.changeset/sharp-scissors-drum.md new file mode 100644 index 00000000000..b701eccf3fa --- /dev/null +++ b/.changeset/sharp-scissors-drum.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`LowLevelCall`: Add a library to perform low-level calls and deal with the `returndata` more granularly. diff --git a/contracts/access/manager/AuthorityUtils.sol b/contracts/access/manager/AuthorityUtils.sol index fb3018ca805..d6caeff3c61 100644 --- a/contracts/access/manager/AuthorityUtils.sol +++ b/contracts/access/manager/AuthorityUtils.sol @@ -4,6 +4,8 @@ pragma solidity ^0.8.20; import {IAuthority} from "./IAuthority.sol"; +import {Memory} from "../../utils/Memory.sol"; +import {LowLevelCall} from "../../utils/LowLevelCall.sol"; library AuthorityUtils { /** @@ -17,16 +19,21 @@ library AuthorityUtils { address target, bytes4 selector ) internal view returns (bool immediate, uint32 delay) { - (bool success, bytes memory data) = authority.staticcall( - abi.encodeCall(IAuthority.canCall, (caller, target, selector)) + Memory.Pointer ptr = Memory.getFreePointer(); + bytes memory params = abi.encodeCall(IAuthority.canCall, (caller, target, selector)); + (bool success, bytes32 immediateWord, bytes32 delayWord) = LowLevelCall.staticcallReturnBytes32Pair( + authority, + params ); - if (success) { - if (data.length >= 0x40) { - (immediate, delay) = abi.decode(data, (bool, uint32)); - } else if (data.length >= 0x20) { - immediate = abi.decode(data, (bool)); - } + Memory.setFreePointer(ptr); + + if (!success) { + return (false, 0); } - return (immediate, delay); + + return ( + uint256(immediateWord) != 0, + uint32(uint256(delayWord)) // Intentional overflow to truncate the higher 224 bits + ); } } diff --git a/contracts/mocks/CallReceiverMock.sol b/contracts/mocks/CallReceiverMock.sol index e371c7db800..496386c4a7c 100644 --- a/contracts/mocks/CallReceiverMock.sol +++ b/contracts/mocks/CallReceiverMock.sol @@ -24,6 +24,12 @@ contract CallReceiverMock { return "0x1234"; } + function mockFunctionWithArgsReturn(uint256 a, uint256 b) public payable returns (uint256, uint256) { + emit MockFunctionCalledWithArgs(a, b); + + return (a, b); + } + function mockFunctionNonPayable() public returns (string memory) { emit MockFunctionCalled(); @@ -34,6 +40,10 @@ contract CallReceiverMock { return "0x1234"; } + function mockStaticFunctionWithArgsReturn(uint256 a, uint256 b) public pure returns (uint256, uint256) { + return (a, b); + } + function mockFunctionRevertsNoReason() public payable { revert(); } diff --git a/contracts/mocks/Stateless.sol b/contracts/mocks/Stateless.sol index 846c77d98e8..6e68e4b7327 100644 --- a/contracts/mocks/Stateless.sol +++ b/contracts/mocks/Stateless.sol @@ -22,8 +22,10 @@ import {ERC165} from "../utils/introspection/ERC165.sol"; import {ERC165Checker} from "../utils/introspection/ERC165Checker.sol"; import {ERC1967Utils} from "../proxy/ERC1967/ERC1967Utils.sol"; import {ERC721Holder} from "../token/ERC721/utils/ERC721Holder.sol"; +import {LowLevelCall} from "../utils/LowLevelCall.sol"; import {Heap} from "../utils/structs/Heap.sol"; import {Math} from "../utils/math/Math.sol"; +import {Memory} from "../utils/Memory.sol"; import {MerkleProof} from "../utils/cryptography/MerkleProof.sol"; import {MessageHashUtils} from "../utils/cryptography/MessageHashUtils.sol"; import {P256} from "../utils/cryptography/P256.sol"; diff --git a/contracts/token/ERC20/extensions/ERC4626.sol b/contracts/token/ERC20/extensions/ERC4626.sol index c71b14ad48c..e1604792fc6 100644 --- a/contracts/token/ERC20/extensions/ERC4626.sol +++ b/contracts/token/ERC20/extensions/ERC4626.sol @@ -7,6 +7,8 @@ import {IERC20, IERC20Metadata, ERC20} from "../ERC20.sol"; import {SafeERC20} from "../utils/SafeERC20.sol"; import {IERC4626} from "../../../interfaces/IERC4626.sol"; import {Math} from "../../../utils/math/Math.sol"; +import {Memory} from "../../../utils/Memory.sol"; +import {LowLevelCall} from "../../../utils/LowLevelCall.sol"; /** * @dev Implementation of the ERC-4626 "Tokenized Vault Standard" as defined in @@ -75,25 +77,21 @@ abstract contract ERC4626 is ERC20, IERC4626 { * @dev Set the underlying asset contract. This must be an ERC20-compatible contract (ERC-20 or ERC-777). */ constructor(IERC20 asset_) { - (bool success, uint8 assetDecimals) = _tryGetAssetDecimals(asset_); - _underlyingDecimals = success ? assetDecimals : 18; + _underlyingDecimals = _tryGetAssetDecimalsWithFallback(asset_, 18); _asset = asset_; } - /** - * @dev Attempts to fetch the asset decimals. A return value of false indicates that the attempt failed in some way. - */ - function _tryGetAssetDecimals(IERC20 asset_) private view returns (bool, uint8) { - (bool success, bytes memory encodedDecimals) = address(asset_).staticcall( - abi.encodeCall(IERC20Metadata.decimals, ()) - ); - if (success && encodedDecimals.length >= 32) { - uint256 returnedDecimals = abi.decode(encodedDecimals, (uint256)); - if (returnedDecimals <= type(uint8).max) { - return (true, uint8(returnedDecimals)); - } - } - return (false, 0); + function _tryGetAssetDecimalsWithFallback(IERC20 asset_, uint8 defaultValue) private view returns (uint8) { + Memory.Pointer ptr = Memory.getFreePointer(); + bytes memory params = abi.encodeCall(IERC20Metadata.decimals, ()); + + (bool success, bytes32 rawValue) = LowLevelCall.staticcallReturnBytes32(address(asset_), params); + uint256 length = LowLevelCall.returnDataSize(); + uint256 value = uint256(rawValue); + + Memory.setFreePointer(ptr); + + return uint8(Math.ternary(success && length >= 0x20 && value <= type(uint8).max, value, defaultValue)); } /** diff --git a/contracts/token/ERC20/utils/SafeERC20.sol b/contracts/token/ERC20/utils/SafeERC20.sol index ed41fb042c9..16e40305ee1 100644 --- a/contracts/token/ERC20/utils/SafeERC20.sol +++ b/contracts/token/ERC20/utils/SafeERC20.sol @@ -6,6 +6,8 @@ pragma solidity ^0.8.20; import {IERC20} from "../IERC20.sol"; import {IERC1363} from "../../../interfaces/IERC1363.sol"; import {Address} from "../../../utils/Address.sol"; +import {Memory} from "../../../utils/Memory.sol"; +import {LowLevelCall} from "../../../utils/LowLevelCall.sol"; /** * @title SafeERC20 @@ -32,7 +34,9 @@ library SafeERC20 { * non-reverting calls are assumed to be successful. */ function safeTransfer(IERC20 token, address to, uint256 value) internal { + Memory.Pointer ptr = Memory.getFreePointer(); _callOptionalReturn(token, abi.encodeCall(token.transfer, (to, value))); + Memory.setFreePointer(ptr); } /** @@ -40,7 +44,9 @@ library SafeERC20 { * calling contract. If `token` returns no value, non-reverting calls are assumed to be successful. */ function safeTransferFrom(IERC20 token, address from, address to, uint256 value) internal { + Memory.Pointer ptr = Memory.getFreePointer(); _callOptionalReturn(token, abi.encodeCall(token.transferFrom, (from, to, value))); + Memory.setFreePointer(ptr); } /** @@ -72,12 +78,13 @@ library SafeERC20 { * to be set to zero before setting it to a non-zero value, such as USDT. */ function forceApprove(IERC20 token, address spender, uint256 value) internal { + Memory.Pointer ptr = Memory.getFreePointer(); bytes memory approvalCall = abi.encodeCall(token.approve, (spender, value)); - if (!_callOptionalReturnBool(token, approvalCall)) { _callOptionalReturn(token, abi.encodeCall(token.approve, (spender, 0))); _callOptionalReturn(token, approvalCall); } + Memory.setFreePointer(ptr); } /** @@ -144,21 +151,18 @@ library SafeERC20 { * This is a variant of {_callOptionalReturnBool} that reverts if call fails to meet the requirements. */ function _callOptionalReturn(IERC20 token, bytes memory data) private { - uint256 returnSize; - uint256 returnValue; + (bool success, bytes32 returnValue) = LowLevelCall.callReturnBytes32(address(token), data); + uint256 returnSize = LowLevelCall.returnDataSize(); + assembly ("memory-safe") { - let success := call(gas(), token, 0, add(data, 0x20), mload(data), 0, 0x20) - // bubble errors if iszero(success) { - let ptr := mload(0x40) - returndatacopy(ptr, 0, returndatasize()) - revert(ptr, returndatasize()) + // Bubble up revert reason + returndatacopy(data, 0, returnSize) + revert(data, returnSize) } - returnSize := returndatasize() - returnValue := mload(0) } - if (returnSize == 0 ? address(token).code.length == 0 : returnValue != 1) { + if (returnSize == 0 ? address(token).code.length == 0 : uint256(returnValue) != 1) { revert SafeERC20FailedOperation(address(token)); } } @@ -172,14 +176,8 @@ library SafeERC20 { * This is a variant of {_callOptionalReturn} that silently catches all reverts and returns a bool instead. */ function _callOptionalReturnBool(IERC20 token, bytes memory data) private returns (bool) { - bool success; - uint256 returnSize; - uint256 returnValue; - assembly ("memory-safe") { - success := call(gas(), token, 0, add(data, 0x20), mload(data), 0, 0x20) - returnSize := returndatasize() - returnValue := mload(0) - } - return success && (returnSize == 0 ? address(token).code.length > 0 : returnValue == 1); + (bool success, bytes32 returnValue) = LowLevelCall.callReturnBytes32(address(token), data); + uint256 returnSize = LowLevelCall.returnDataSize(); + return success && (returnSize == 0 ? address(token).code.length > 0 : uint256(returnValue) == 1); } } diff --git a/contracts/utils/Address.sol b/contracts/utils/Address.sol index 40f01a93dca..b940291b6c8 100644 --- a/contracts/utils/Address.sol +++ b/contracts/utils/Address.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.20; import {Errors} from "./Errors.sol"; +import {LowLevelCall} from "./LowLevelCall.sol"; /** * @dev Collection of functions related to the address type @@ -35,7 +36,7 @@ library Address { revert Errors.InsufficientBalance(address(this).balance, amount); } - (bool success, ) = recipient.call{value: amount}(""); + bool success = LowLevelCall.callRaw(recipient, "", amount); if (!success) { revert Errors.FailedCall(); } diff --git a/contracts/utils/LowLevelCall.sol b/contracts/utils/LowLevelCall.sol new file mode 100644 index 00000000000..bf552a3e78a --- /dev/null +++ b/contracts/utils/LowLevelCall.sol @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.20; + +import {Errors} from "./Errors.sol"; + +/** + * @dev Library of low level call functions that implement different calling strategies to deal with the return data. + * + * WARNING: Using this library requires an advanced understanding of Solidity and how the EVM works. It is recommended + * to use the {Address} library instead. + */ +library LowLevelCall { + /// === CALL === + + /// @dev Performs a Solidity function call using a low level `call` and ignoring the return data. + function callRaw(address target, bytes memory data) internal returns (bool success) { + return callRaw(target, data, 0); + } + + /// @dev Same as {callRaw}, but allows to specify the value to be sent in the call. + function callRaw(address target, bytes memory data, uint256 value) internal returns (bool success) { + assembly ("memory-safe") { + success := call(gas(), target, value, add(data, 0x20), mload(data), 0, 0) + } + } + + /// @dev Performs a Solidity function call using a low level `call` and returns the first 32 bytes of the result + /// in the scratch space of memory. Useful for functions that return a single-word value. + /// + /// WARNING: Do not assume that the result is zero if `success` is false. Memory can be already allocated + /// and this function doesn't zero it out. + function callReturnBytes32(address target, bytes memory data) internal returns (bool success, bytes32 result) { + return callReturnBytes32(target, data, 0); + } + + /// @dev Same as {callReturnBytes32}, but allows to specify the value to be sent in the call. + function callReturnBytes32( + address target, + bytes memory data, + uint256 value + ) internal returns (bool success, bytes32 result) { + assembly ("memory-safe") { + success := call(gas(), target, value, add(data, 0x20), mload(data), 0, 0x20) + result := mload(0) + } + } + + /// @dev Performs a Solidity function call using a low level `call` and returns the first 64 bytes of the result + /// in the scratch space of memory. Useful for functions that return a tuple of single-word values. + /// + /// WARNING: Do not assume that the results are zero if `success` is false. Memory can be already allocated + /// and this function doesn't zero it out. + function callReturnBytes32Pair( + address target, + bytes memory data + ) internal returns (bool success, bytes32 result1, bytes32 result2) { + return callReturnBytes32Pair(target, data, 0); + } + + /// @dev Same as {callReturnBytes32Pair}, but allows to specify the value to be sent in the call. + function callReturnBytes32Pair( + address target, + bytes memory data, + uint256 value + ) internal returns (bool success, bytes32 result1, bytes32 result2) { + assembly ("memory-safe") { + success := call(gas(), target, value, add(data, 0x20), mload(data), 0, 0x40) + result1 := mload(0) + result2 := mload(0x20) + } + } + + /// === STATICCALL === + + /// @dev Performs a Solidity function call using a low level `staticcall` and ignoring the return data. + function staticcallRaw(address target, bytes memory data) internal view returns (bool success) { + assembly ("memory-safe") { + success := staticcall(gas(), target, add(data, 0x20), mload(data), 0, 0) + } + } + + /// @dev Performs a Solidity function call using a low level `staticcall` and returns the first 32 bytes of the result + /// in the scratch space of memory. Useful for functions that return a single-word value. + /// + /// WARNING: Do not assume that the result is zero if `success` is false. Memory can be already allocated + /// and this function doesn't zero it out. + function staticcallReturnBytes32( + address target, + bytes memory data + ) internal view returns (bool success, bytes32 result) { + assembly ("memory-safe") { + success := staticcall(gas(), target, add(data, 0x20), mload(data), 0, 0x20) + result := mload(0) + } + } + + /// @dev Performs a Solidity function call using a low level `staticcall` and returns the first 64 bytes of the result + /// in the scratch space of memory. Useful for functions that return a tuple of single-word values. + /// + /// WARNING: Do not assume that the results are zero if `success` is false. Memory can be already allocated + /// and this function doesn't zero it out. + function staticcallReturnBytes32Pair( + address target, + bytes memory data + ) internal view returns (bool success, bytes32 result1, bytes32 result2) { + assembly ("memory-safe") { + success := staticcall(gas(), target, add(data, 0x20), mload(data), 0, 0x40) + result1 := mload(0) + result2 := mload(0x20) + } + } + + /// @dev Returns the size of the return data buffer. + function returnDataSize() internal pure returns (uint256 size) { + assembly ("memory-safe") { + size := returndatasize() + } + } +} diff --git a/contracts/utils/Memory.sol b/contracts/utils/Memory.sol new file mode 100644 index 00000000000..a0fc881e318 --- /dev/null +++ b/contracts/utils/Memory.sol @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.20; + +/// @dev Memory utility library. +library Memory { + type Pointer is bytes32; + + /// @dev Returns a memory pointer to the current free memory pointer. + function getFreePointer() internal pure returns (Pointer ptr) { + assembly ("memory-safe") { + ptr := mload(0x40) + } + } + + /// @dev Sets the free memory pointer to a specific value. + /// + /// WARNING: Everything after the pointer may be overwritten. + function setFreePointer(Pointer ptr) internal pure { + assembly ("memory-safe") { + mstore(0x40, ptr) + } + } + + /// @dev Pointer to `bytes32`. + function asBytes32(Pointer ptr) internal pure returns (bytes32) { + return Pointer.unwrap(ptr); + } + + /// @dev `bytes32` to pointer. + function asPointer(bytes32 value) internal pure returns (Pointer) { + return Pointer.wrap(value); + } +} diff --git a/contracts/utils/README.adoc b/contracts/utils/README.adoc index 0ef3e5387c8..0e2f2a96034 100644 --- a/contracts/utils/README.adoc +++ b/contracts/utils/README.adoc @@ -39,6 +39,7 @@ Miscellaneous contracts and libraries containing utility functions you can use t * {Context}: A utility for abstracting the sender and calldata in the current execution context. * {Packing}: A library for packing and unpacking multiple values into bytes32 * {Panic}: A library to revert with https://docs.soliditylang.org/en/v0.8.20/control-structures.html#panic-via-assert-and-error-via-require[Solidity panic codes]. + * {LowLevelCall}: Collection of functions to perform calls with low-level assembly. * {Comparators}: A library that contains comparator functions to use with with the {Heap} library. [NOTE] @@ -134,4 +135,6 @@ Ethereum contracts have no native concept of an interface, so applications must {{Panic}} +{{LowLevelCall}} + {{Comparators}} diff --git a/contracts/utils/cryptography/SignatureChecker.sol b/contracts/utils/cryptography/SignatureChecker.sol index 9aaa2e0716c..2a80d1017d4 100644 --- a/contracts/utils/cryptography/SignatureChecker.sol +++ b/contracts/utils/cryptography/SignatureChecker.sol @@ -5,6 +5,8 @@ pragma solidity ^0.8.20; import {ECDSA} from "./ECDSA.sol"; import {IERC1271} from "../../interfaces/IERC1271.sol"; +import {Memory} from "../Memory.sol"; +import {LowLevelCall} from "../LowLevelCall.sol"; /** * @dev Signature verification helper that can be used instead of `ECDSA.recover` to seamlessly support both ECDSA @@ -40,11 +42,12 @@ library SignatureChecker { bytes32 hash, bytes memory signature ) internal view returns (bool) { - (bool success, bytes memory result) = signer.staticcall( - abi.encodeCall(IERC1271.isValidSignature, (hash, signature)) - ); - return (success && - result.length >= 32 && - abi.decode(result, (bytes32)) == bytes32(IERC1271.isValidSignature.selector)); + Memory.Pointer ptr = Memory.getFreePointer(); + bytes memory params = abi.encodeCall(IERC1271.isValidSignature, (hash, signature)); + (bool success, bytes32 result) = LowLevelCall.staticcallReturnBytes32(signer, params); + uint256 length = LowLevelCall.returnDataSize(); + Memory.setFreePointer(ptr); + + return success && length >= 32 && result == bytes32(IERC1271.isValidSignature.selector); } } diff --git a/docs/modules/ROOT/pages/utilities.adoc b/docs/modules/ROOT/pages/utilities.adoc index b8afec4eabd..a713987cf08 100644 --- a/docs/modules/ROOT/pages/utilities.adoc +++ b/docs/modules/ROOT/pages/utilities.adoc @@ -386,3 +386,47 @@ await instance.multicall([ instance.interface.encodeFunctionData("bar") ]); ---- + +=== LowLevelCall + +The `LowLevelCall` library contains a set of functions to perform external calls with low-level assembly, allowing them to deal with the callee's `returndata` in different ways. This is especially useful to make a call in a way that is safe against return bombing (i.e. the callee allocates too much memory using a long returndata). + +The functions in the library efficiently allocates a fixed sized of the `returndata` up to 64 bytes. You can either ignore the returned data, or get 1 or 2 `bytes32` values. + +[source,solidity] +---- +using LowLevelCall for address; + +function _foo(address target, bytes memory data) internal { + bool success; + bytes32 returnValue1; + bytes32 returnValue2; + + // Ignore return data + success = target.callRaw(data); + + // Copy only 32 bytes from return data + (success, returnValue1) = target.callReturnBytes32(data); + + // Copy two (32 bytes) EVM words from returndata + (success, returnValue1, returnValue2) = target.callReturnBytes32Pair(data); +} +---- + +There are cases where you would like to check the size of the returned data, either to make sure it fits the expected size, or to check it before loading it to memory. In those case, the library also includes a function to separately check the return data size: + + +[source,solidity] +---- +using LowLevelCall for address; + +function _foo(address target, bytes memory data) internal returns (bool returnBool) { + if (!target.callRaw(data)) { + // Unsuccessful call + return false; + } + + // As long as the contract returned data, its content doesn't matter + return LowLevelCall.returnDataSize() >= 32; +} +---- diff --git a/test/utils/LowLevelCall.test.js b/test/utils/LowLevelCall.test.js new file mode 100644 index 00000000000..717ace7d607 --- /dev/null +++ b/test/utils/LowLevelCall.test.js @@ -0,0 +1,204 @@ +const { ethers } = require('hardhat'); +const { expect } = require('chai'); +const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); + +async function fixture() { + const [recipient, other] = await ethers.getSigners(); + + const mock = await ethers.deployContract('$LowLevelCall'); + const target = await ethers.deployContract('CallReceiverMock'); + const targetEther = await ethers.deployContract('EtherReceiverMock'); + + return { recipient, other, mock, target, targetEther, value: BigInt(1e18) }; +} + +describe('LowLevelCall', function () { + beforeEach(async function () { + Object.assign(this, await loadFixture(fixture)); + }); + + describe('callRaw', function () { + beforeEach(function () { + this.call = this.target.interface.encodeFunctionData('mockFunction'); + }); + + it('calls the requested function and returns true', async function () { + await expect(this.mock.$callRaw(this.target, this.call)) + .to.emit(this.target, 'MockFunctionCalled') + .to.emit(this.mock, 'return$callRaw_address_bytes') + .withArgs(true); + }); + + it('calls the requested function with value and returns true', async function () { + await this.other.sendTransaction({ to: this.mock, value: this.value }); + + const tx = this.mock['$callRaw(address,bytes,uint256)'](this.target, this.call, this.value); + await expect(tx).to.changeEtherBalance(this.target, this.value); + await expect(tx).to.emit(this.mock, 'return$callRaw_address_bytes_uint256').withArgs(true); + }); + + it("calls the requested function and returns false if the caller doesn't have enough balance", async function () { + const tx = this.mock['$callRaw(address,bytes,uint256)'](this.target, this.call, this.value); + await expect(tx).to.not.changeEtherBalance(this.target, this.value); + await expect(tx).to.emit(this.mock, 'return$callRaw_address_bytes_uint256').withArgs(false); + }); + + it('calls the requested function and returns false if the subcall reverts', async function () { + const call = this.target.interface.encodeFunctionData('mockFunctionRevertsNoReason'); + const tx = await this.mock.$callRaw(this.target, call); + expect(tx).to.emit(this.mock, 'return$callRaw_address_bytes').withArgs(false); + }); + }); + + describe('callReturnBytes32', function () { + beforeEach(function () { + this.returnValue = ethers.id('returnDataBytes32'); + this.call = this.target.interface.encodeFunctionData('mockFunctionWithArgsReturn', [ + this.returnValue, + ethers.ZeroHash, + ]); + }); + + it('calls the requested function and returns true', async function () { + await expect(this.mock.$callReturnBytes32(this.target, this.call)) + .to.emit(this.target, 'MockFunctionCalledWithArgs') + .withArgs(this.returnValue, ethers.ZeroHash) + .to.emit(this.mock, 'return$callReturnBytes32_address_bytes') + .withArgs(true, this.returnValue); + }); + + it('calls the requested function with value and returns true', async function () { + await this.other.sendTransaction({ to: this.mock, value: this.value }); + + const tx = this.mock['$callReturnBytes32(address,bytes,uint256)'](this.target, this.call, this.value); + await expect(tx).to.changeEtherBalance(this.target, this.value); + await expect(tx) + .to.emit(this.mock, 'return$callReturnBytes32_address_bytes_uint256') + .withArgs(true, this.returnValue); + }); + + it("calls the requested function and returns false if the caller doesn't have enough balance", async function () { + const tx = this.mock['$callReturnBytes32(address,bytes,uint256)'](this.target, this.call, this.value); + await expect(tx).to.not.changeEtherBalance(this.target, this.value); + await expect(tx) + .to.emit(this.mock, 'return$callReturnBytes32_address_bytes_uint256') + .withArgs(false, ethers.ZeroHash); + }); + + it('calls the requested function and returns false if the subcall reverts', async function () { + const call = this.target.interface.encodeFunctionData('mockFunctionRevertsNoReason'); + const tx = await this.mock.$callReturnBytes32(this.target, call); + expect(tx).to.emit(this.mock, 'return$callReturnBytes32_address_bytes').withArgs(false, ethers.ZeroHash); + }); + }); + + describe('callReturnBytes32Pair', function () { + beforeEach(function () { + this.returnValue1 = ethers.id('returnDataBytes32Pair1'); + this.returnValue2 = ethers.id('returnDataBytes32Pair2'); + this.call = this.target.interface.encodeFunctionData('mockFunctionWithArgsReturn', [ + this.returnValue1, + this.returnValue2, + ]); + }); + + it('calls the requested function and returns true', async function () { + await expect(this.mock.$callReturnBytes32Pair(this.target, this.call)) + .to.emit(this.target, 'MockFunctionCalledWithArgs') + .withArgs(this.returnValue1, this.returnValue2) + .to.emit(this.mock, 'return$callReturnBytes32Pair_address_bytes') + .withArgs(true, this.returnValue1, this.returnValue2); + }); + + it('calls the requested function with value and returns true', async function () { + await this.other.sendTransaction({ to: this.mock, value: this.value }); + + const tx = this.mock['$callReturnBytes32Pair(address,bytes,uint256)'](this.target, this.call, this.value); + await expect(tx).to.changeEtherBalance(this.target, this.value); + await expect(tx) + .to.emit(this.mock, 'return$callReturnBytes32Pair_address_bytes_uint256') + .withArgs(true, this.returnValue1, this.returnValue2); + }); + + it("calls the requested function and returns false if the caller doesn't have enough balance", async function () { + const tx = this.mock['$callReturnBytes32Pair(address,bytes,uint256)'](this.target, this.call, this.value); + await expect(tx).to.not.changeEtherBalance(this.target, this.value); + await expect(tx) + .to.emit(this.mock, 'return$callReturnBytes32Pair_address_bytes_uint256') + .withArgs(false, ethers.ZeroHash, ethers.ZeroHash); + }); + + it('calls the requested function and returns false if the subcall reverts', async function () { + const call = this.target.interface.encodeFunctionData('mockFunctionRevertsNoReason'); + const tx = await this.mock.$callReturnBytes32Pair(this.target, call); + expect(tx) + .to.emit(this.mock, 'return$callReturnBytes32Pair_address_bytes') + .withArgs(false, ethers.ZeroHash, ethers.ZeroHash); + }); + }); + + describe('staticcallRaw', function () { + it('calls the requested function and returns true', async function () { + const call = this.target.interface.encodeFunctionData('mockStaticFunction'); + expect(await this.mock.$staticcallRaw(this.target, call)).to.equal(true); + }); + + it('calls the requested function and returns false if the subcall reverts', async function () { + const interface = new ethers.Interface(['function mockFunctionDoesNotExist()']); + + const call = interface.encodeFunctionData('mockFunctionDoesNotExist'); + expect(await this.mock.$staticcallRaw(this.target, call)).to.equal(false); + }); + }); + + describe('staticcallReturnBytes32', function () { + beforeEach(function () { + this.returnValue = ethers.id('returnDataBytes32'); + }); + + it('calls the requested function and returns true', async function () { + const call = this.target.interface.encodeFunctionData('mockStaticFunctionWithArgsReturn', [ + this.returnValue, + ethers.ZeroHash, + ]); + expect(await this.mock.$staticcallReturnBytes32(this.target, call)).to.deep.equal([true, this.returnValue]); + }); + + it('calls the requested function and returns false if the subcall reverts', async function () { + const interface = new ethers.Interface(['function mockFunctionDoesNotExist()']); + + const call = interface.encodeFunctionData('mockFunctionDoesNotExist'); + expect(await this.mock.$staticcallReturnBytes32(this.target, call)).to.deep.equal([false, ethers.ZeroHash]); + }); + }); + + describe('staticcallReturnBytes32Pair', function () { + beforeEach(function () { + this.returnValue1 = ethers.id('returnDataBytes32Pair1'); + this.returnValue2 = ethers.id('returnDataBytes32Pair2'); + }); + + it('calls the requested function and returns true', async function () { + const call = this.target.interface.encodeFunctionData('mockStaticFunctionWithArgsReturn', [ + this.returnValue1, + this.returnValue2, + ]); + expect(await this.mock.$staticcallReturnBytes32Pair(this.target, call)).to.deep.equal([ + true, + this.returnValue1, + this.returnValue2, + ]); + }); + + it('calls the requested function and returns false if the subcall reverts', async function () { + const interface = new ethers.Interface(['function mockFunctionDoesNotExist()']); + + const call = interface.encodeFunctionData('mockFunctionDoesNotExist'); + expect(await this.mock.$staticcallReturnBytes32Pair(this.target, call)).to.deep.equal([ + false, + ethers.ZeroHash, + ethers.ZeroHash, + ]); + }); + }); +}); diff --git a/test/utils/Memory.t.sol b/test/utils/Memory.t.sol new file mode 100644 index 00000000000..4cc60b88f9c --- /dev/null +++ b/test/utils/Memory.t.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.20; + +import {Test} from "forge-std/Test.sol"; +import {Memory} from "@openzeppelin/contracts/utils/Memory.sol"; + +contract MemoryTest is Test { + using Memory for *; + + function testSymbolicGetSetFreePointer(bytes32 ptr) public { + Memory.Pointer memoryPtr = ptr.asPointer(); + Memory.setFreePointer(memoryPtr); + assertEq(Memory.getFreePointer().asBytes32(), memoryPtr.asBytes32()); + } +} diff --git a/test/utils/Memory.test.js b/test/utils/Memory.test.js new file mode 100644 index 00000000000..5698728dcfd --- /dev/null +++ b/test/utils/Memory.test.js @@ -0,0 +1,41 @@ +const { ethers } = require('hardhat'); +const { expect } = require('chai'); +const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); + +async function fixture() { + const mock = await ethers.deployContract('$Memory'); + + return { mock }; +} + +describe('Memory', function () { + beforeEach(async function () { + Object.assign(this, await loadFixture(fixture)); + }); + + describe('free pointer', function () { + it('sets memory pointer', async function () { + const ptr = '0x00000000000000000000000000000000000000000000000000000000000000a0'; + expect(await this.mock.$setFreePointer(ptr)).to.not.be.reverted; + }); + + it('gets memory pointer', async function () { + expect(await this.mock.$getFreePointer()).to.equal( + // Default pointer + '0x0000000000000000000000000000000000000000000000000000000000000080', + ); + }); + + it('asBytes32', async function () { + const ptr = ethers.toBeHex('0x1234', 32); + await this.mock.$setFreePointer(ptr); + expect(await this.mock.$asBytes32(ptr)).to.equal(ptr); + }); + + it('asPointer', async function () { + const ptr = ethers.toBeHex('0x1234', 32); + await this.mock.$setFreePointer(ptr); + expect(await this.mock.$asPointer(ptr)).to.equal(ptr); + }); + }); +});