diff --git a/src/Custodian.sol b/src/Custodian.sol index 2cda6333..96946be3 100644 --- a/src/Custodian.sol +++ b/src/Custodian.sol @@ -25,12 +25,14 @@ import {Pricing} from "src/pricing/Pricing.sol"; import {LoanManager} from "src/LoanManager.sol"; import "forge-std/console.sol"; import {ConduitHelper} from "src/ConduitHelper.sol"; +import {StarPortLib} from "src/lib/StarPortLib.sol"; contract Custodian is ContractOffererInterface, TokenReceiverInterface, ConduitHelper { + using {StarPortLib.getId} for LoanManager.Loan; LoanManager public immutable LM; address public immutable seaport; event SeaportCompatibleContractDeployed(); @@ -80,7 +82,7 @@ contract Custodian is // we burn the loan on repayment in generateOrder, but in ratify order where we would trigger any post settlement actions // we burn it here so that in the case it was minted and an owner is set for settlement their pointer can still be utilized // in this case we are not a repayment we have burnt the loan in the generate order for a repayment - uint256 loanId = LM.getLoanIdFromLoan(loan); + uint256 loanId = loan.getId(); if (LM.active(loanId)) { if ( SettlementHandler(loan.terms.handler).execute(loan) != diff --git a/src/LoanManager.sol b/src/LoanManager.sol index a884a04e..ea8b4f74 100644 --- a/src/LoanManager.sol +++ b/src/LoanManager.sol @@ -47,6 +47,7 @@ import {ConduitHelper} from "src/ConduitHelper.sol"; contract LoanManager is ERC721, ContractOffererInterface, ConduitHelper { using FixedPointMathLib for uint256; using {StarPortLib.toReceivedItems} for SpentItem[]; + using {StarPortLib.getId} for LoanManager.Loan; ConsiderationInterface public constant seaport = ConsiderationInterface(0x2e234DAe75C793f67A35089C9d99245E1C58470b); @@ -256,12 +257,8 @@ contract LoanManager is ERC721, ContractOffererInterface, ConduitHelper { _settle(loan); } - function getLoanIdFromLoan(Loan memory loan) public pure returns (uint256) { - return uint256(keccak256(abi.encode(loan))); - } - function _settle(Loan memory loan) internal { - uint256 tokenId = getLoanIdFromLoan(loan); + uint256 tokenId = loan.getId(); if (!_issued(tokenId)) { revert InvalidLoan(tokenId); } @@ -414,7 +411,7 @@ contract LoanManager is ERC721, ContractOffererInterface, ConduitHelper { function _issueLoanManager(Loan memory loan, bool mint) internal { bytes memory encodedLoan = abi.encode(loan); - uint256 loanId = uint256(keccak256(encodedLoan)); + uint256 loanId = loan.getId(); _setExtraData(loanId, uint8(FieldFlags.ACTIVE)); if (mint) { @@ -564,7 +561,11 @@ contract LoanManager is ERC721, ContractOffererInterface, ConduitHelper { // used for any additional payments beyond consideration and carry ReceivedItem[] memory additionalPayment - ) = Pricing(loan.terms.pricing).isValidRefinance(loan, newPricingData); + ) = Pricing(loan.terms.pricing).isValidRefinance( + loan, + newPricingData, + msg.sender + ); ReceivedItem[] memory refinanceConsideration = _mergeConsiderations( considerationPayment, diff --git a/src/handlers/AstariaV1SettlementHandler.sol b/src/handlers/AstariaV1SettlementHandler.sol index 02d96dc2..3d707aa3 100644 --- a/src/handlers/AstariaV1SettlementHandler.sol +++ b/src/handlers/AstariaV1SettlementHandler.sol @@ -9,8 +9,11 @@ import { import {BaseHook} from "src/hooks/BaseHook.sol"; import {BaseRecall} from "src/hooks/BaseRecall.sol"; import {DutchAuctionHandler} from "src/handlers/DutchAuctionHandler.sol"; +import {StarPortLib} from "src/lib/StarPortLib.sol"; contract AstariaV1SettlementHandler is DutchAuctionHandler { + using {StarPortLib.getId} for LoanManager.Loan; + constructor(LoanManager LM_) DutchAuctionHandler(LM_) {} function getSettlement( @@ -22,9 +25,7 @@ contract AstariaV1SettlementHandler is DutchAuctionHandler { override returns (ReceivedItem[] memory, address restricted) { - (address recaller, ) = BaseRecall(loan.terms.hook).recalls( - LM.getLoanIdFromLoan(loan) - ); + (address recaller, ) = BaseRecall(loan.terms.hook).recalls(loan.getId()); if (recaller == loan.issuer) { return (new ReceivedItem[](0), recaller); diff --git a/src/hooks/AstariaV1SettlementHook.sol b/src/hooks/AstariaV1SettlementHook.sol index 0c8b4d21..29c6b7b4 100644 --- a/src/hooks/AstariaV1SettlementHook.sol +++ b/src/hooks/AstariaV1SettlementHook.sol @@ -4,15 +4,18 @@ import {LoanManager} from "src/LoanManager.sol"; import {BaseRecall} from "src/hooks/BaseRecall.sol"; import "forge-std/console2.sol"; import {BaseHook} from "src/hooks/BaseHook.sol"; +import {StarPortLib} from "src/lib/StarPortLib.sol"; contract AstariaV1SettlementHook is BaseHook, BaseRecall { + using {StarPortLib.getId} for LoanManager.Loan; + constructor(LoanManager LM_) BaseRecall(LM_) {} function isActive( LoanManager.Loan calldata loan ) external view override returns (bool) { Details memory details = abi.decode(loan.terms.hookData, (Details)); - uint256 tokenId = LM.getLoanIdFromLoan(loan); + uint256 tokenId = loan.getId(); return !(uint256(recalls[tokenId].start) + details.recallWindow > block.timestamp); @@ -22,7 +25,7 @@ contract AstariaV1SettlementHook is BaseHook, BaseRecall { LoanManager.Loan calldata loan ) external view override returns (bool) { Details memory details = abi.decode(loan.terms.hookData, (Details)); - uint256 tokenId = LM.getLoanIdFromLoan(loan); + uint256 tokenId = loan.getId(); Recall memory recall = recalls[tokenId]; return (recall.start + details.recallWindow > block.timestamp) && diff --git a/src/hooks/BaseRecall.sol b/src/hooks/BaseRecall.sol index a3f358af..7df88a99 100644 --- a/src/hooks/BaseRecall.sol +++ b/src/hooks/BaseRecall.sol @@ -26,10 +26,11 @@ import { } from "seaport-types/src/interfaces/ConduitInterface.sol"; import {FixedPointMathLib} from "solady/src/utils/FixedPointMathLib.sol"; +import {StarPortLib} from "src/lib/StarPortLib.sol"; abstract contract BaseRecall is ConduitHelper { using FixedPointMathLib for uint256; - + using {StarPortLib.getId} for LoanManager.Loan; event Recalled(uint256 loandId, address recaller, uint256 end); event Withdraw(uint256 loanId, address withdrawer); LoanManager LM; @@ -69,7 +70,7 @@ abstract contract BaseRecall is ConduitHelper { LoanManager.Loan calldata loan ) external view returns (uint256) { Details memory details = abi.decode(loan.terms.hookData, (Details)); - uint256 loanId = LM.getLoanIdFromLoan(loan); + uint256 loanId = loan.getId(); // calculates the porportion of time elapsed, then multiplies times the max rate return details.recallMax.mulWad( diff --git a/src/pricing/AstariaV1Pricing.sol b/src/pricing/AstariaV1Pricing.sol index 6743dc0c..22b202f0 100644 --- a/src/pricing/AstariaV1Pricing.sol +++ b/src/pricing/AstariaV1Pricing.sol @@ -8,9 +8,11 @@ import {AstariaV1SettlementHook} from "src/hooks/AstariaV1SettlementHook.sol"; import {BaseRecall} from "src/hooks/BaseRecall.sol"; import {FixedPointMathLib} from "solady/src/utils/FixedPointMathLib.sol"; +import {StarPortLib} from "src/lib/StarPortLib.sol"; contract AstariaV1Pricing is CompoundInterestPricing { using FixedPointMathLib for uint256; + using {StarPortLib.getId} for LoanManager.Loan; constructor(LoanManager LM_) Pricing(LM_) {} @@ -18,7 +20,8 @@ contract AstariaV1Pricing is CompoundInterestPricing { function isValidRefinance( LoanManager.Loan memory loan, - bytes memory newPricingData + bytes memory newPricingData, + address caller ) external view @@ -31,7 +34,7 @@ contract AstariaV1Pricing is CompoundInterestPricing { ) { // borrowers can refinance a loan at any time - if (msg.sender != loan.borrower) { + if (caller != loan.borrower) { // check if a recall is occuring AstariaV1SettlementHook hook = AstariaV1SettlementHook(loan.terms.hook); Details memory newDetails = abi.decode(newPricingData, (Details)); @@ -46,7 +49,7 @@ contract AstariaV1Pricing is CompoundInterestPricing { uint256 proportion; address payable receiver = payable(loan.issuer); - uint256 loanId = LM.getLoanIdFromLoan(loan); + uint256 loanId = loan.getId(); // scenario where the recaller is not penalized // recaller stake is refunded if (newDetails.rate > oldDetails.rate) { diff --git a/src/pricing/BasePricing.sol b/src/pricing/BasePricing.sol index 8d756cd0..02b6598f 100644 --- a/src/pricing/BasePricing.sol +++ b/src/pricing/BasePricing.sol @@ -12,7 +12,7 @@ import {StarPortLib} from "src/lib/StarPortLib.sol"; abstract contract BasePricing is Pricing { using FixedPointMathLib for uint256; - using StarPortLib for LoanManager.Loan; + using {StarPortLib.getId} for LoanManager.Loan; struct Details { uint256 rate; uint256 carryRate; diff --git a/src/pricing/BaseRecallPricing.sol b/src/pricing/BaseRecallPricing.sol index 8e1be115..1b1cfaac 100644 --- a/src/pricing/BaseRecallPricing.sol +++ b/src/pricing/BaseRecallPricing.sol @@ -13,7 +13,8 @@ import {StarPortLib} from "src/lib/StarPortLib.sol"; abstract contract BaseRecallPricing is BasePricing { function isValidRefinance( LoanManager.Loan memory loan, - bytes memory newPricingData + bytes memory newPricingData, + address caller ) external view diff --git a/src/pricing/Pricing.sol b/src/pricing/Pricing.sol index c188b858..d65bfb79 100644 --- a/src/pricing/Pricing.sol +++ b/src/pricing/Pricing.sol @@ -8,7 +8,7 @@ import "seaport/lib/seaport-sol/src/lib/ReceivedItemLib.sol"; abstract contract Pricing { LoanManager LM; error InvalidRefinance(); - + constructor(LoanManager LM_) { LM = LM_; } @@ -19,10 +19,15 @@ abstract contract Pricing { function isValidRefinance( LoanManager.Loan memory loan, - bytes memory newPricingData + bytes memory newPricingData, + address caller ) external view virtual - returns (ReceivedItem[] memory, ReceivedItem[] memory, ReceivedItem[] memory); + returns ( + ReceivedItem[] memory, + ReceivedItem[] memory, + ReceivedItem[] memory + ); } diff --git a/src/pricing/SimpleInterestPricing.sol b/src/pricing/SimpleInterestPricing.sol index c5134492..366154b7 100644 --- a/src/pricing/SimpleInterestPricing.sol +++ b/src/pricing/SimpleInterestPricing.sol @@ -20,7 +20,8 @@ contract SimpleInterestPricing is BasePricing { function isValidRefinance( LoanManager.Loan memory loan, - bytes memory newPricingData + bytes memory newPricingData, + address caller ) external view diff --git a/test/TestAstariaV1Loan.sol b/test/TestAstariaV1Loan.sol index 485aec73..c324ca1c 100644 --- a/test/TestAstariaV1Loan.sol +++ b/test/TestAstariaV1Loan.sol @@ -4,7 +4,10 @@ import {BaseRecall} from "src/hooks/BaseRecall.sol"; // import {Base} from "src/pricing/CompoundInterestPricing.sol"; // import {AstariaV1Pricing} from "src/pricing/AstariaV1Pricing.sol"; import "forge-std/console2.sol"; +import {StarPortLib} from "src/lib/StarPortLib.sol"; + contract TestAstariaV1Loan is AstariaV1Test { + using {StarPortLib.getId} for LoanManager.Loan; function testNewLoanERC721CollateralDefaultTermsRecall() public { Custodian custody = Custodian(LM.defaultCustodian()); @@ -46,7 +49,6 @@ contract TestAstariaV1Loan is AstariaV1Test { collateral: ConsiderationItemLib.toSpentItemArray(selectedCollateral), debt: debt }); - bool isTrusted = true; LoanManager.Loan memory loan = newLoan( NewLoanData( @@ -57,8 +59,11 @@ contract TestAstariaV1Loan is AstariaV1Test { Originator(UO), selectedCollateral ); - uint256 loanId = LM.getLoanIdFromLoan(loan); - assertTrue(LM.active(loanId), "LoanId not in active state after a new loan"); + uint256 loanId = loan.getId(); + assertTrue( + LM.active(loanId), + "LoanId not in active state after a new loan" + ); { vm.startPrank(recaller.addr); @@ -86,22 +91,40 @@ contract TestAstariaV1Loan is AstariaV1Test { uint256 stake; { uint256 balanceBefore = erc20s[0].balanceOf(recaller.addr); - uint256 recallContractBalanceBefore = erc20s[0].balanceOf(address(hook)); - BaseRecall.Details memory details = abi.decode(loan.terms.hookData, (BaseRecall.Details)); + uint256 recallContractBalanceBefore = erc20s[0].balanceOf(address(hook)); + BaseRecall.Details memory details = abi.decode( + loan.terms.hookData, + (BaseRecall.Details) + ); vm.warp(block.timestamp + details.honeymoon); vm.startPrank(recaller.addr); BaseRecall recallContract = BaseRecall(address(hook)); recallContract.recall(loan, recallerConduit); vm.stopPrank(); - + uint256 balanceAfter = erc20s[0].balanceOf(recaller.addr); - uint256 recallContractBalanceAfter = erc20s[0].balanceOf(address(hook)); + uint256 recallContractBalanceAfter = erc20s[0].balanceOf(address(hook)); - BasePricing.Details memory pricingDetails = abi.decode(loan.terms.pricingData, (BasePricing.Details)); - stake = BasePricing(address(pricing)).calculateInterest(details.recallStakeDuration, loan.debt[0].amount, pricingDetails.rate); - assertEq(balanceBefore, balanceAfter + stake, "Recaller balance not transfered correctly"); - assertEq(recallContractBalanceBefore + stake, recallContractBalanceAfter, "Balance not transfered to recall contract correctly"); + BasePricing.Details memory pricingDetails = abi.decode( + loan.terms.pricingData, + (BasePricing.Details) + ); + stake = BasePricing(address(pricing)).calculateInterest( + details.recallStakeDuration, + loan.debt[0].amount, + pricingDetails.rate + ); + assertEq( + balanceBefore, + balanceAfter + stake, + "Recaller balance not transfered correctly" + ); + assertEq( + recallContractBalanceBefore + stake, + recallContractBalanceAfter, + "Balance not transfered to recall contract correctly" + ); } { BaseRecall recallContract = BaseRecall(address(hook)); @@ -109,7 +132,11 @@ contract TestAstariaV1Loan is AstariaV1Test { uint64 start; (recallerAddr, start) = recallContract.recalls(loanId); - assertEq(recaller.addr, recallerAddr, "Recaller address logged incorrectly"); + assertEq( + recaller.addr, + recallerAddr, + "Recaller address logged incorrectly" + ); assertEq(start, block.timestamp, "Recall start logged incorrectly"); } { @@ -139,29 +166,51 @@ contract TestAstariaV1Loan is AstariaV1Test { uint256 newLenderBefore = erc20s[0].balanceOf(refinancer.addr); uint256 oldLenderBefore = erc20s[0].balanceOf(lender.addr); uint256 recallerBefore = erc20s[0].balanceOf(recaller.addr); - BaseRecall.Details memory details = abi.decode(loan.terms.hookData, (BaseRecall.Details)); + BaseRecall.Details memory details = abi.decode( + loan.terms.hookData, + (BaseRecall.Details) + ); vm.startPrank(refinancer.addr); vm.warp(block.timestamp + (details.recallWindow / 2)); LM.refinance( loan, abi.encode( - BasePricing.Details({ - rate: details.recallMax / 2, - carryRate: 0 - }) + BasePricing.Details({rate: details.recallMax / 2, carryRate: 0}) ), refinancerConduit ); vm.stopPrank(); uint256 delta_t = block.timestamp - loan.start; - BasePricing.Details memory pricingDetails = abi.decode(loan.terms.pricingData, (BasePricing.Details)); - uint256 interest = BasePricing(address(pricing)).calculateInterest(delta_t, loan.debt[0].amount, pricingDetails.rate); + BasePricing.Details memory pricingDetails = abi.decode( + loan.terms.pricingData, + (BasePricing.Details) + ); + uint256 interest = BasePricing(address(pricing)).calculateInterest( + delta_t, + loan.debt[0].amount, + pricingDetails.rate + ); uint256 newLenderAfter = erc20s[0].balanceOf(refinancer.addr); uint256 oldLenderAfter = erc20s[0].balanceOf(lender.addr); - assertEq(oldLenderAfter, oldLenderBefore + loan.debt[0].amount + interest, "Payment to old lender calculated incorrectly"); - assertEq(newLenderAfter, newLenderBefore - (loan.debt[0].amount + interest + stake), "Payment from new lender calculated incorrectly"); - assertEq(recallerBefore + stake, erc20s[0].balanceOf(recaller.addr), "Recaller did not recover stake as expected"); - assertTrue(LM.inactive(loanId), "LoanId not properly flipped to inactive after refinance"); + assertEq( + oldLenderAfter, + oldLenderBefore + loan.debt[0].amount + interest, + "Payment to old lender calculated incorrectly" + ); + assertEq( + newLenderAfter, + newLenderBefore - (loan.debt[0].amount + interest + stake), + "Payment from new lender calculated incorrectly" + ); + assertEq( + recallerBefore + stake, + erc20s[0].balanceOf(recaller.addr), + "Recaller did not recover stake as expected" + ); + assertTrue( + LM.inactive(loanId), + "LoanId not properly flipped to inactive after refinance" + ); } { uint256 withdrawerBalanceBefore = erc20s[0].balanceOf(address(this)); @@ -172,8 +221,16 @@ contract TestAstariaV1Loan is AstariaV1Test { recallContract.withdraw(loan, payable(address(this))); uint256 withdrawerBalanceAfter = erc20s[0].balanceOf(address(this)); uint256 recallContractBalanceAfter = erc20s[0].balanceOf(address(hook)); - assertEq(withdrawerBalanceBefore + stake, withdrawerBalanceAfter, "Withdrawer did not recover stake as expected"); - assertEq(recallContractBalanceBefore - stake, recallContractBalanceAfter, "BaseRecall did not return the stake as expected"); + assertEq( + withdrawerBalanceBefore + stake, + withdrawerBalanceAfter, + "Withdrawer did not recover stake as expected" + ); + assertEq( + recallContractBalanceBefore - stake, + recallContractBalanceAfter, + "BaseRecall did not return the stake as expected" + ); } } -} \ No newline at end of file +} diff --git a/test/TestExoticLoans.t.sol b/test/TestExoticLoans.t.sol index e2e94840..0e881e09 100644 --- a/test/TestExoticLoans.t.sol +++ b/test/TestExoticLoans.t.sol @@ -86,7 +86,8 @@ contract SwapPricing is Pricing { function isValidRefinance( LoanManager.Loan memory loan, - bytes memory newPricingData + bytes memory newPricingData, + address caller ) external view override returns (ReceivedItem[] memory, ReceivedItem[] memory, ReceivedItem[] memory) { return (new ReceivedItem[](0), new ReceivedItem[](0), new ReceivedItem[](0)); } diff --git a/test/TestLoanCombinations.t.sol b/test/TestLoanCombinations.t.sol index a39b5821..a6cfa886 100644 --- a/test/TestLoanCombinations.t.sol +++ b/test/TestLoanCombinations.t.sol @@ -2,9 +2,12 @@ import "./StarPortTest.sol"; import {FixedPointMathLib} from "solady/src/utils/FixedPointMathLib.sol"; import { LibString } from "solady/src/utils/LibString.sol"; +import { StarPortLib } from "src/lib/StarPortLib.sol"; + import "forge-std/console.sol"; contract TestLoanCombinations is StarPortTest { + using {StarPortLib.getId} for LoanManager.Loan; // TODO test liquidations function testLoan721for20SimpleInterestDutchFixedRepay() public { LoanManager.Terms memory terms = LoanManager.Terms({ @@ -31,7 +34,7 @@ contract TestLoanCombinations is StarPortTest { assertTrue(erc20s[0].balanceOf(borrower.addr) > initial20Balance, "Borrower did not receive ERC20"); - uint256 loanId = LM.getLoanIdFromLoan(loan); + uint256 loanId = loan.getId(); assertTrue(LM.active(loanId), "LoanId not in active state after a new loan"); skip(10 days);