diff --git a/src/ITokenFlow.sol b/src/ITokenFlow.sol index 97cb7b9..f887462 100644 --- a/src/ITokenFlow.sol +++ b/src/ITokenFlow.sol @@ -1,13 +1,11 @@ -// SPDX-License-Identifier: MIT +// SPDX-License-Identifier: MIT pragma solidity >=0.8.0; - struct Constraint { address token; int256 value; } - /// @title IFlowScope /// @notice A flow scope is a contract that is called by the TokenFlow contract to execute a transaction. /// @dev The flow scope is free to transfer any token of the payer, as long as they're repaid by the end of the flow. @@ -17,7 +15,8 @@ interface IFlowScope { /// @param constraints The netflows constraints of the token flow. /// @param payer The payer of the token flow. Whoever is paying for the token flow. /// @param data Data to be passed to the entrypoint. - function enter(bytes28 selectorExtension, Constraint[] calldata constraints, address payer, bytes calldata data) external; + function enter(bytes28 selectorExtension, Constraint[] calldata constraints, address payer, bytes calldata data) + external; } /// @notice Reverts when the netflows constraints are violated. @@ -50,7 +49,4 @@ interface ITokenFlow { /// @notice A helper function to get the current flow payer. function payer() external view returns (address); - - - } diff --git a/src/TransientNetflows.sol b/src/TransientNetflows.sol index c4d1b09..6f3d468 100644 --- a/src/TransientNetflows.sol +++ b/src/TransientNetflows.sol @@ -3,18 +3,17 @@ pragma solidity ^0.8.28; import {EfficientHashLib} from "solady/utils/EfficientHashLib.sol"; - /// @notice An helper library to manage netflows in a transient storage. /// @dev the netflows are stored as an array of (token, value) pairs, with length in the first slot of the netflows. library TransientNetflows { /// @notice The slot where the nonce for this set of netflows is stored. /// Equivalent to bytes32(uint256(keccak256("TokenFlow.netflows")) - 1) - bytes32 internal constant NETFLOWS_SLOT = 0xb8ea23bb4fe1252fa49dff7d6168221ebfea7b5c55753f63740c76a259eb8f88; - + bytes32 internal constant NETFLOWS_SLOT = 0xb8ea23bb4fe1252fa49dff7d6168221ebfea7b5c55753f63740c76a259eb8f88; /// @notice The slot where the counter of negative netflows is stored. /// Equivalent to bytes32(uint256(keccak256("TokenFlow.negativeNetflowsCounter")) - 1) - bytes32 internal constant NEGATIVE_NETFLOWS_COUNTER_SLOT = 0x14f6a9c5e25725efcb69b4d15bdae41110c6a38bf78cda4b45b3539514d3fc55; + bytes32 internal constant NEGATIVE_NETFLOWS_COUNTER_SLOT = + 0x14f6a9c5e25725efcb69b4d15bdae41110c6a38bf78cda4b45b3539514d3fc55; /// @notice Sets the netflow for a token. If the netflow is not present, it is created. /// @param token The token to set the netflow for. @@ -70,9 +69,8 @@ library TransientNetflows { } } - - function deriveAddressSlot(address token) internal view returns (bytes32 ) { - uint slot; + function deriveAddressSlot(address token) internal view returns (bytes32) { + uint256 slot; assembly ("memory-safe") { slot := tload(NETFLOWS_SLOT) } diff --git a/test/TokenFlow.t.sol b/test/TokenFlow.t.sol index e9f1c2f..8482df6 100644 --- a/test/TokenFlow.t.sol +++ b/test/TokenFlow.t.sol @@ -9,7 +9,7 @@ import {MockERC20} from "./mocks/MockERC20.sol"; import {MockFlowScope} from "./mocks/MockFlowScope.sol"; contract TokenFlowTest is Test { - TokenFlow tokenFlow; + TokenFlow tokenFlow; MockERC20 token1; MockERC20 token2; MockFlowScope flowScope; @@ -23,15 +23,16 @@ contract TokenFlowTest is Test { token1.approve(address(tokenFlow), type(uint256).max); token2.approve(address(tokenFlow), type(uint256).max); + vm.startPrank(address(flowScope)); + token1.approve(address(tokenFlow), type(uint256).max); + token2.approve(address(tokenFlow), type(uint256).max); + vm.stopPrank(); } function test_avoidCollisionWithERC20() public { Constraint[] memory constraints = new Constraint[](0); - bytes memory data = abi.encodeCall( - ERC20.transferFrom, - (address(alice), address(this), 1 ether) - ); + bytes memory data = abi.encodeCall(ERC20.transferFrom, (address(alice), address(this), 1 ether)); vm.expectRevert(); tokenFlow.main(constraints, IFlowScope(address(token1)), data); @@ -58,4 +59,58 @@ contract TokenFlowTest is Test { // FlowScope will re-enter the tokenFlow contract, which will revert tokenFlow.main(constraints, flowScope, ""); } + + function test_revertOnBadNetflows() public { + Constraint[] memory constraints = new Constraint[](1); + constraints[0] = Constraint({ + token: address(token1), + value: -1 ether // Negative constraint that should fail + }); + + vm.expectRevert(BadNetflows.selector); + tokenFlow.main(constraints, flowScope, ""); + } + + function test_bubbleUpErrors() public { + Constraint[] memory constraints = new Constraint[](0); + + // Setup MockFlowScope to revert with custom error + flowScope.addRevert("CustomError"); + + vm.expectRevert("CustomError"); + tokenFlow.main(constraints, flowScope, ""); + } + + function test_multipleConstraints() public { + Constraint[] memory constraints = new Constraint[](2); + constraints[0] = Constraint({token: address(token1), value: 1 ether}); + constraints[1] = Constraint({token: address(token2), value: 2 ether}); + + // Test multiple token constraints + tokenFlow.main(constraints, flowScope, ""); + } + + function test_emptyConstraints() public { + Constraint[] memory constraints = new Constraint[](0); + + // Should succeed with no constraints + tokenFlow.main(constraints, flowScope, ""); + } + + function test_moveInOutSequence() public { + deal(address(token1), address(this), 10 ether); + deal(address(token1), address(flowScope), 10 ether); + // empty constraints so final netflows must be >= 0 + Constraint[] memory constraints = new Constraint[](0); + + // Add sequence of moveIn/moveOut that nets to zero + flowScope.addMoveIn(address(token1), 1 ether, alice); + flowScope.addMoveOut(address(token1), 1 ether); + + tokenFlow.main(constraints, flowScope, ""); + + assertEq(token1.balanceOf(alice), 1 ether); + assertEq(token1.balanceOf(address(flowScope)), 9 ether); + assertEq(token1.balanceOf(address(this)), 10 ether); + } } diff --git a/test/TransientNetflows.t.sol b/test/TransientNetflows.t.sol index aa666fb..13dc8cc 100644 --- a/test/TransientNetflows.t.sol +++ b/test/TransientNetflows.t.sol @@ -4,23 +4,18 @@ pragma solidity ^0.8.28; import {Test, console2 as console} from "forge-std/Test.sol"; import {TransientNetflows} from "src/TransientNetflows.sol"; - - - contract TransientNetflowsTest is Test { - function test_insert(address token, int256 amount) public { TransientNetflows.insert(token, amount); bytes32 slot = TransientNetflows.deriveAddressSlot(token); - int value; + int256 value; assembly { value := tload(slot) } assertEq(value, amount, "incorrect amount"); } - function test_get(address token, int256 amount) public { TransientNetflows.insert(token, amount); int256 value = TransientNetflows.get(token); @@ -35,7 +30,6 @@ contract TransientNetflowsTest is Test { assertEq(TransientNetflows.get(address(2)), 0, "incorrect amount"); } - function test_are_positive(address token1, address token2, int256 amount1, int256 amount2) public { vm.assume(token1 != token2); vm.assume(amount1 > 0); @@ -44,7 +38,7 @@ contract TransientNetflowsTest is Test { TransientNetflows.insert(token1, amount1); assertTrue(TransientNetflows.arePositive(), "should be positive with single positive amount"); - TransientNetflows.insert(token2, amount2); + TransientNetflows.insert(token2, amount2); assertFalse(TransientNetflows.arePositive(), "should be negative with one negative amount"); TransientNetflows.clear(); @@ -53,4 +47,49 @@ contract TransientNetflowsTest is Test { TransientNetflows.insert(token1, amount2); assertFalse(TransientNetflows.arePositive(), "should be negative after reinserting negative"); } + + function test_addOverflow() public { + // Test adding beyond int256 bounds + TransientNetflows.insert(address(1), type(int256).max); + + vm.expectRevert(); + TransientNetflows.add(address(1), 1); + } + + function test_multipleTokens() public { + address[] memory tokens = new address[](3); + tokens[0] = address(1); + tokens[1] = address(2); + tokens[2] = address(3); + + for (uint256 i = 0; i < tokens.length; i++) { + TransientNetflows.insert(tokens[i], int256(i + 1)); + } + + for (uint256 i = 0; i < tokens.length; i++) { + assertEq(TransientNetflows.get(tokens[i]), int256(i + 1)); + } + + TransientNetflows.clear(); + + for (uint256 i = 0; i < tokens.length; i++) { + assertEq(TransientNetflows.get(tokens[i]), 0); + } + } + + function test_negativeNetflows() public { + TransientNetflows.insert(address(1), -1); + assertFalse(TransientNetflows.arePositive()); + } + + function test_are_positive_empty() public { + // Test arePositive() with no entries + assertTrue(TransientNetflows.arePositive(), "should be positive with no entries"); + } + + function test_add_to_nonexistent() public { + // Test adding to a token that hasn't been inserted + TransientNetflows.add(address(1), 1 ether); + assertEq(TransientNetflows.get(address(1)), 1 ether, "should create new entry"); + } } diff --git a/test/mocks/MockERC20.sol b/test/mocks/MockERC20.sol index f24f4be..5d60e97 100644 --- a/test/mocks/MockERC20.sol +++ b/test/mocks/MockERC20.sol @@ -12,7 +12,6 @@ contract MockERC20 is ERC20 { _name = name_; _symbol = symbol_; _decimals = decimals_; - } function name() public view override returns (string memory) { diff --git a/test/mocks/MockFlowScope.sol b/test/mocks/MockFlowScope.sol index 26aa66e..c3c82ee 100644 --- a/test/mocks/MockFlowScope.sol +++ b/test/mocks/MockFlowScope.sol @@ -34,14 +34,14 @@ enum InstructionType { contract MockFlowScope is IFlowScope { ITokenFlow tokenFlow; - uint instructionIndex; + uint256 instructionIndex; InstructionType[] instructionTypes; - - uint moveInIndex; - uint moveOutIndex; - uint reentryIndex; - uint revertIndex; - + + uint256 moveInIndex; + uint256 moveOutIndex; + uint256 reentryIndex; + uint256 revertIndex; + MoveIn[] moveInInstructions; MoveOut[] moveOutInstructions; Reentry[] reentryInstructions; @@ -64,10 +64,10 @@ contract MockFlowScope is IFlowScope { instructionTypes.push(InstructionType.MoveOut); } - function addReentry(IFlowScope flowScope, bytes calldata data) external { + function addReentry(IFlowScope flowScope, bytes calldata data) external { reentryInstructions.push(Reentry({flowScope: flowScope, data: data})); instructionTypes.push(InstructionType.Reentry); - } + } function addRevert(string calldata reason) external { revertInstructions.push(Revert({reason: reason})); @@ -75,29 +75,33 @@ contract MockFlowScope is IFlowScope { } function enter( - bytes28 /* selectorExtension */, + bytes28, /* selectorExtension */ Constraint[] calldata constraints, - address /* payer */, + address, /* payer */ bytes calldata /* data */ ) external { - // Execute current instruction - if (instructionIndex < instructionTypes.length) { - InstructionType iType = instructionTypes[instructionIndex]; - + // Execute all instructions + for (uint256 i = 0; i < instructionTypes.length; i++) { + InstructionType iType = instructionTypes[i]; + if (iType == InstructionType.MoveIn) { - tokenFlow.moveIn(moveInInstructions[moveInIndex].token, moveInInstructions[moveInIndex].amount, moveInInstructions[moveInIndex].recipient); + tokenFlow.moveIn( + moveInInstructions[moveInIndex].token, + moveInInstructions[moveInIndex].amount, + moveInInstructions[moveInIndex].recipient + ); moveInIndex++; } else if (iType == InstructionType.MoveOut) { tokenFlow.moveOut(moveOutInstructions[moveOutIndex].token, moveOutInstructions[moveOutIndex].amount); moveOutIndex++; } else if (iType == InstructionType.Reentry) { - ITokenFlow(msg.sender).main(constraints, reentryInstructions[reentryIndex].flowScope, reentryInstructions[reentryIndex].data); + ITokenFlow(msg.sender).main( + constraints, reentryInstructions[reentryIndex].flowScope, reentryInstructions[reentryIndex].data + ); reentryIndex++; } else if (iType == InstructionType.Revert) { revert(revertInstructions[revertIndex].reason); } - - instructionIndex++; } } -} \ No newline at end of file +}