diff --git a/src/UniswapFlashLoan.sol b/src/UniswapFlashLoan.sol index e8324e0..0897fa1 100644 --- a/src/UniswapFlashLoan.sol +++ b/src/UniswapFlashLoan.sol @@ -70,6 +70,8 @@ contract UniswapFlashLoan is IUniswapV3FlashCallback, QuarkScript { * @param data FlashLoanCallbackPayload encoded to bytes passed from IUniswapV3Pool.flash(); contains scripts info to execute before repaying the flash loan */ function uniswapV3FlashCallback(uint256 fee0, uint256 fee1, bytes calldata data) external { + disallowCallback(); + FlashLoanCallbackPayload memory input = abi.decode(data, (FlashLoanCallbackPayload)); IUniswapV3Pool pool = IUniswapV3Pool(PoolAddress.computeAddress(UniswapFactoryAddress.getAddress(), input.poolKey)); diff --git a/src/UniswapFlashSwapExactOut.sol b/src/UniswapFlashSwapExactOut.sol index 2e8f500..b42df87 100644 --- a/src/UniswapFlashSwapExactOut.sol +++ b/src/UniswapFlashSwapExactOut.sol @@ -74,6 +74,8 @@ contract UniswapFlashSwapExactOut is IUniswapV3SwapCallback, QuarkScript { * @param data FlashSwap encoded to bytes passed from UniswapV3Pool.swap(); contains script info to execute (possibly with checks) before returning the owed amount */ function uniswapV3SwapCallback(int256 amount0Delta, int256 amount1Delta, bytes calldata data) external { + disallowCallback(); + FlashSwapExactOutInput memory input = abi.decode(data, (FlashSwapExactOutInput)); IUniswapV3Pool pool = IUniswapV3Pool(PoolAddress.computeAddress(UniswapFactoryAddress.getAddress(), input.poolKey)); diff --git a/test/UniswapFlashLoan.t.sol b/test/UniswapFlashLoan.t.sol index 9c35a4b..072eac7 100644 --- a/test/UniswapFlashLoan.t.sol +++ b/test/UniswapFlashLoan.t.sol @@ -187,6 +187,49 @@ contract UniswapFlashLoanTest is Test { assertEq(IComet(comet).borrowBalanceOf(address(wallet)), 1000e6); } + function testRevertsForSecondCallback() public { + vm.pauseGasMetering(); + QuarkWallet wallet = QuarkWallet(factory.create(alice, address(0))); + address[] memory callContracts = new address[](1); + bytes[] memory callDatas = new bytes[](1); + // Call into the wallet and try to execute the fallback function again using the callback mechanism + callContracts[0] = address(wallet); + callDatas[0] = abi.encodeWithSelector( + Ethcall.run.selector, + address(wallet), + abi.encodeCall(UniswapFlashLoan.uniswapV3FlashCallback, (100, 500, bytes(""))), + 0 + ); + QuarkWallet.QuarkOperation memory op = new QuarkOperationHelper().newBasicOpWithCalldata( + wallet, + uniswapFlashLoan, + abi.encodeWithSelector( + UniswapFlashLoan.run.selector, + UniswapFlashLoan.UniswapFlashLoanPayload({ + token0: USDC, + token1: DAI, + fee: 100, + amount0: 1000e6, + amount1: 0, + callContract: multicallAddress, + callData: abi.encodeWithSelector(Multicall.run.selector, callContracts, callDatas) + }) + ), + ScriptType.ScriptAddress + ); + bytes memory signature = new SignatureHelper().signOp(alicePrivateKey, wallet, op); + vm.resumeGasMetering(); + vm.expectRevert( + abi.encodeWithSelector( + Multicall.MulticallError.selector, + 0, + callContracts[0], + abi.encodeWithSelector(QuarkWallet.NoActiveCallback.selector) + ) + ); + wallet.executeQuarkOperation(op, signature); + } + function testRevertsForInvalidCaller() public { vm.pauseGasMetering(); QuarkWallet wallet = QuarkWallet(factory.create(alice, address(0))); diff --git a/test/UniswapFlashSwapExactOut.t.sol b/test/UniswapFlashSwapExactOut.t.sol index 4db67c9..919ee28 100644 --- a/test/UniswapFlashSwapExactOut.t.sol +++ b/test/UniswapFlashSwapExactOut.t.sol @@ -118,6 +118,51 @@ contract UniswapFlashSwapExactOutTest is Test { assertEq(IComet(comet).borrowBalanceOf(address(wallet)), borrowAmountOfUSDC); } + function testRevertsForSecondCallback() public { + vm.pauseGasMetering(); + QuarkWallet wallet = QuarkWallet(factory.create(alice, address(0))); + // Set up some funds for test + deal(WETH, address(wallet), 10 ether); + address[] memory callContracts = new address[](1); + bytes[] memory callDatas = new bytes[](1); + // Call into the wallet and try to execute the fallback function again using the callback mechanism + callContracts[0] = address(wallet); + callDatas[0] = abi.encodeWithSelector( + Ethcall.run.selector, + address(wallet), + abi.encodeCall(UniswapFlashSwapExactOut.uniswapV3SwapCallback, (100, 500, bytes(""))), + 0 + ); + QuarkWallet.QuarkOperation memory op = new QuarkOperationHelper().newBasicOpWithCalldata( + wallet, + uniswapFlashSwapExactOut, + abi.encodeWithSelector( + UniswapFlashSwapExactOut.run.selector, + UniswapFlashSwapExactOut.UniswapFlashSwapExactOutPayload({ + tokenOut: WETH, + tokenIn: USDC, + fee: 500, + amountOut: 1 ether, + sqrtPriceLimitX96: 0, + callContract: multicallAddress, + callData: abi.encodeWithSelector(Multicall.run.selector, callContracts, callDatas) + }) + ), + ScriptType.ScriptAddress + ); + bytes memory signature = new SignatureHelper().signOp(alicePrivateKey, wallet, op); + vm.resumeGasMetering(); + vm.expectRevert( + abi.encodeWithSelector( + Multicall.MulticallError.selector, + 0, + callContracts[0], + abi.encodeWithSelector(QuarkWallet.NoActiveCallback.selector) + ) + ); + wallet.executeQuarkOperation(op, signature); + } + function testInvalidCallerFlashSwap() public { vm.pauseGasMetering(); QuarkWallet wallet = QuarkWallet(factory.create(alice, address(0)));