From d83c76b15d6b7c3a12afb0d9be6e42f01a87fc05 Mon Sep 17 00:00:00 2001 From: GregTheDev Date: Wed, 3 Jan 2024 12:39:58 -0700 Subject: [PATCH] update v1Status and BaseRecall to validate pricing address --- src/status/AstariaV1Status.sol | 21 +++++++++++++++++++-- src/status/BaseRecall.sol | 13 +++++++++++++ test/AstariaV1Test.sol | 2 ++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/status/AstariaV1Status.sol b/src/status/AstariaV1Status.sol index 7d53f12..8093fd7 100644 --- a/src/status/AstariaV1Status.sol +++ b/src/status/AstariaV1Status.sol @@ -19,15 +19,22 @@ import {Validation} from "starport-core/lib/Validation.sol"; import {BasePricing} from "starport-core/pricing/BasePricing.sol"; import {BaseRecall} from "v1-core/status/BaseRecall.sol"; import {BaseStatus} from "v1-core/status/BaseStatus.sol"; +import {Ownable} from "solady/src/auth/Ownable.sol"; -contract AstariaV1Status is BaseStatus, BaseRecall { +contract AstariaV1Status is BaseStatus, BaseRecall, Ownable { using {StarportLib.getId} for Starport.Loan; + mapping(address => bool) public isValidPricing; + + error InvalidPricingContract(); + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ /* CONSTRUCTOR */ /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ - constructor(Starport SP_) BaseRecall(SP_) {} + constructor(Starport SP_) BaseRecall(SP_) { + _initializeOwner(msg.sender); + } /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ /* EXTERNAL FUNCTIONS */ @@ -59,4 +66,14 @@ contract AstariaV1Status is BaseStatus, BaseRecall { } return valid ? Validation.validate.selector : bytes4(0xFFFFFFFF); } + + function validatePricingContract(address pricingContract) internal virtual override { + if (!isValidPricing[pricingContract]) { + revert InvalidPricingContract(); + } + } + + function setValidPricing(address pricing, bool valid) external onlyOwner { + isValidPricing[pricing] = valid; + } } diff --git a/src/status/BaseRecall.sol b/src/status/BaseRecall.sol index 19a7e16..8c13d17 100644 --- a/src/status/BaseRecall.sol +++ b/src/status/BaseRecall.sol @@ -21,6 +21,7 @@ import {ItemType} from "seaport-types/src/lib/ConsiderationEnums.sol"; import {ConsiderationInterface} from "seaport-types/src/interfaces/ConsiderationInterface.sol"; import {ERC20} from "solady/src/tokens/ERC20.sol"; import {FixedPointMathLib} from "solady/src/utils/FixedPointMathLib.sol"; +import {Ownable} from "solady/src/auth/Ownable.sol"; abstract contract BaseRecall { using FixedPointMathLib for uint256; @@ -109,6 +110,14 @@ abstract contract BaseRecall { return (details.recallMax * ratio) / baseAdjustment; } + /** + * @dev Implement to validate the pricing contract. + * @dev Malicious pricing contracts can lie about the withdrawable recall amount. + * @dev This function should revert if the pricing contract is invalid. + * @param pricingContract The pricing contract to validate + */ + function validatePricingContract(address pricingContract) internal virtual; + /** * @dev Recalls a loan * @param loan The loan to recall @@ -129,6 +138,8 @@ abstract contract BaseRecall { revert RecallAlreadyExists(); } + validatePricingContract(loan.terms.pricingContract); + AdditionalTransfer[] memory recallConsideration = _generateRecallConsideration( msg.sender, loan, 0, details.recallStakeDuration, 0, msg.sender, payable(address(this)) ); @@ -153,6 +164,8 @@ abstract contract BaseRecall { revert LoanHasNotBeenRefinanced(); } + validatePricingContract(loan.terms.pricingContract); + Recall storage recall = recalls[loanId]; address recaller = recall.recaller; // Ensure that a recall exists for the provided tokenId, ensure that the recall diff --git a/test/AstariaV1Test.sol b/test/AstariaV1Test.sol index d978f1b..3e5909f 100644 --- a/test/AstariaV1Test.sol +++ b/test/AstariaV1Test.sol @@ -41,6 +41,8 @@ contract AstariaV1Test is StarportTest { settlement = new AstariaV1Settlement(SP); vm.label(address(settlement), "V1Settlement"); status = new AstariaV1Status(SP); + status.setValidPricing(address(pricing), true); + vm.label(address(status), "V1Status"); lenderEnforcer = new AstariaV1LenderEnforcer();