From a0c9bcc17092dda3607f6a39b67c9ee90a3e9e34 Mon Sep 17 00:00:00 2001 From: Pranav Bhardwaj Date: Tue, 11 Jun 2024 23:55:25 -0400 Subject: [PATCH] add partial rewards claiming --- contracts/staking/PrtStakingPool.sol | 49 +++++ test/staking/prtStakingPool.spec.ts | 272 +++++++++++++++++++++++++++ 2 files changed, 321 insertions(+) diff --git a/contracts/staking/PrtStakingPool.sol b/contracts/staking/PrtStakingPool.sol index 1223aefc..4978e79a 100644 --- a/contracts/staking/PrtStakingPool.sol +++ b/contracts/staking/PrtStakingPool.sol @@ -139,6 +139,18 @@ contract PrtStakingPool is Ownable, ERC20Snapshot, ReentrancyGuard { setToken.transfer(msg.sender, amount); } + /** + * @notice Claim partial staking rewards from pending snapshots for `msg.sender` up to `_endClaimId`. + * @param _endClaimId The snapshot id to end the partial claim + */ + function claimPartial(uint256 _endClaimId) public nonReentrant { + uint256 currentId = getCurrentId(); + uint256 amount = _getPendingPartialRewards(currentId, _endClaimId, msg.sender); + require(amount > 0, "No rewards to claim"); + lastSnapshotId[msg.sender] = _endClaimId; + setToken.transfer(msg.sender, amount); + } + /** * @notice ONLY OWNER: Update the PrtFeeSplitExtension address. */ @@ -179,6 +191,20 @@ contract PrtStakingPool is Ownable, ERC20Snapshot, ReentrancyGuard { return _getPendingRewards(currentId, _account); } + /** + * @notice Get pending partial rewards for an account. + * @param _account The address of the account + * @param _endClaimId The snapshot id to end the partial claim + * @return The pending partial rewards for the account + */ + function getPendingPartialRewards( + address _account, + uint256 _endClaimId + ) external view returns (uint256) { + uint256 currentId = getCurrentId(); + return _getPendingPartialRewards(currentId, _endClaimId, _account); + } + /** * @notice Get rewards for an account from a specific snapshot id. * @param _snapshotId The snapshot id @@ -250,6 +276,29 @@ contract PrtStakingPool is Ownable, ERC20Snapshot, ReentrancyGuard { } } + /** + * @dev Get pending partial rewards for an account. + * @param _currentId The current snapshot id + * @param _endClaimId The snapshot id to end the partial claim + * @param _account The address of the account + * @return amount The pending partial rewards for the account + */ + function _getPendingPartialRewards( + uint256 _currentId, + uint256 _endClaimId, + address _account + ) + private + view + returns (uint256 amount) + { + require(_endClaimId < _currentId, "End claim id must be less than current id"); + uint256 lastRewardId = lastSnapshotId[_account]; + for (uint256 i = lastRewardId; i < _endClaimId; i++) { + amount += _getSnapshotRewards(i, _account); + } + } + /** * @dev Get rewards for an account from a specific snapshot id. * @param _snapshotId The snapshot id diff --git a/test/staking/prtStakingPool.spec.ts b/test/staking/prtStakingPool.spec.ts index d75422a5..78c5cae6 100644 --- a/test/staking/prtStakingPool.spec.ts +++ b/test/staking/prtStakingPool.spec.ts @@ -545,6 +545,161 @@ describe.only("PrtStakingPool", () => { }); }); + describe("#claimPartial", async () => { + let bobPrtAmount: BigNumber; + let alicePrtAmount: BigNumber; + let carolPrtAmount: BigNumber; + let snap1Amount: BigNumber; + let snap2Amount: BigNumber; + let snap3Amount: BigNumber; + + let subjectSnapshotId: BigNumber; + + beforeEach(async () => { + // PRT balances (bob: 6 PRT, alice: 4 PRT, carol: 5 PRT) + bobPrtAmount = ether(6); + alicePrtAmount = ether(4); + carolPrtAmount = ether(5); + + // Snapshot rewards amounts (snap1: 1 SetToken, snap2: 1.5 SetToken, snap3: 2 SetToken) + snap1Amount = ether(1); + snap2Amount = ether(1.5); + snap3Amount = ether(2); + + // Fund bob, alice, and carol with PRT + await prt.connect(owner.wallet).transfer(bob.address, bobPrtAmount); + await prt.connect(owner.wallet).transfer(alice.address, alicePrtAmount); + await prt.connect(owner.wallet).transfer(carol.address, carolPrtAmount); + + // Approve staking pool to spend PRT + await prt.connect(bob.wallet).approve(prtStakingPool.address, bobPrtAmount); + await prt.connect(alice.wallet).approve(prtStakingPool.address, alicePrtAmount); + await prt.connect(carol.wallet).approve(prtStakingPool.address, carolPrtAmount); + + // Before snapshot 1, bob stakes PRTs + await prtStakingPool.connect(bob.wallet).stake(bobPrtAmount); + + // Take snapshot 1 + await setToken.connect(owner.wallet).transfer(feeSplitExtension.address, snap1Amount); + await setToken.connect(feeSplitExtension.wallet).approve(prtStakingPool.address, snap1Amount); + await prtStakingPool.connect(feeSplitExtension.wallet).accrue(snap1Amount); + + // After snapshot 1, alice stakes PRTs + await prtStakingPool.connect(alice.wallet).stake(alicePrtAmount); + + // After snapshot 1, carol stakes PRTs + await prtStakingPool.connect(carol.wallet).stake(carolPrtAmount); + + // Take snapshot 2 + await setToken.connect(owner.wallet).transfer(feeSplitExtension.address, snap2Amount); + await setToken.connect(feeSplitExtension.wallet).approve(prtStakingPool.address, snap2Amount); + await prtStakingPool.connect(feeSplitExtension.wallet).accrue(snap2Amount); + + // After snapshot 2, carol unstakes PRTs + await prtStakingPool.connect(carol.wallet).unstake(carolPrtAmount); + + // Take snapshot 3 + await setToken.connect(owner.wallet).transfer(feeSplitExtension.address, snap3Amount); + await setToken.connect(feeSplitExtension.wallet).approve(prtStakingPool.address, snap3Amount); + await prtStakingPool.connect(feeSplitExtension.wallet).accrue(snap3Amount); + + subjectSnapshotId = TWO; + }); + + async function subject(caller: Account): Promise { + return prtStakingPool.connect(caller.wallet).claimPartial(subjectSnapshotId); + } + + it("should transfer the pending SetToken rewards from the PrtStakingPool to the staker", async () => { + const prtStakingPoolSetTokenBalanceBefore = await setToken.balanceOf(prtStakingPool.address); + const prtHolderOneSetTokenBalanceBefore = await setToken.balanceOf(bob.address); + + const totalSupplySnap2 = await prtStakingPool.totalSupplyAt(TWO); + + // (bob) who stakes before snapshot 1 and never unstakes + const expectedBobPendingRewards = snap1Amount.add( + bobPrtAmount.mul(snap2Amount).div(totalSupplySnap2) + ); + + await subject(bob); + + const prtStakingPoolSetTokenBalanceAfter = await setToken.balanceOf(prtStakingPool.address); + const prtHolderOneSetTokenBalanceAfter = await setToken.balanceOf(bob.address); + + expect(prtStakingPoolSetTokenBalanceAfter).to.eq(prtStakingPoolSetTokenBalanceBefore.sub(expectedBobPendingRewards)); + expect(prtHolderOneSetTokenBalanceAfter).to.eq(prtHolderOneSetTokenBalanceBefore.add(expectedBobPendingRewards)); + }); + + it("should update the lastSnapshotId", async () => { + const lastSnapshotIdBefore = await prtStakingPool.lastSnapshotId(bob.address); + expect(lastSnapshotIdBefore).to.eq(0); + + await subject(bob); + + const lastSnapshotIdAfter = await prtStakingPool.lastSnapshotId(bob.address); + expect(lastSnapshotIdAfter).to.eq(2); + + const currentId = await prtStakingPool.getCurrentId(); + expect(lastSnapshotIdAfter).to.be.lt(currentId); + }); + + describe("when the user stakes after the first snapshot", async () => { + it("should still return pending rewards for staked snapshots", async () => { + const totalSupplySnap2 = await prtStakingPool.totalSupplyAt(TWO); + + // (alice) who stakes after snapshot 1 and never unstakes + const expectedAlicePendingRewards = (alicePrtAmount.mul(snap2Amount).div(totalSupplySnap2)); + + const aliceSetTokenBalanceBefore = await setToken.balanceOf(alice.address); + await subject(alice); + const aliceSetTokenBalanceAfter = await setToken.balanceOf(alice.address); + const actualAliceSetTokenChange = aliceSetTokenBalanceAfter.sub(aliceSetTokenBalanceBefore); + expect(actualAliceSetTokenChange).to.eq(expectedAlicePendingRewards); + }); + }); + + describe("when the user unstakes before the latest snapshot", async () => { + it("should still return pending rewards for staked snapshots", async () => { + const totalSupplySnap2 = await prtStakingPool.totalSupplyAt(TWO); + + // (carol) who stakes after snapshot 1 and unstakes after snapshot 2 + const expectedCarolPendingRewards = carolPrtAmount.mul(snap2Amount).div(totalSupplySnap2); + + const carolSetTokenBalanceBefore = await setToken.balanceOf(carol.address); + await subject(carol); + const carolSetTokenBalanceAfter = await setToken.balanceOf(carol.address); + const actualCarolSetTokenChange = carolSetTokenBalanceAfter.sub(carolSetTokenBalanceBefore); + expect(actualCarolSetTokenChange).to.eq(expectedCarolPendingRewards); + }); + }); + + describe("when there are no pending rewards", async () => { + it("should revert", async () => { + await expect(subject(owner)).to.be.revertedWith("No rewards to claim"); + }); + }); + + describe("when the end id not less than the current id", async () => { + beforeEach(async () => { + subjectSnapshotId = await prtStakingPool.getCurrentId(); + }); + + it("should revert", async () => { + await expect(subject(bob)).to.be.revertedWith("End claim id must be less than current id"); + }); + }); + + describe("when the rewards have been claimed", async () => { + beforeEach(async () => { + await prtStakingPool.connect(bob.wallet).claimPartial(subjectSnapshotId); + }); + + it("should return 0", async () => { + await expect(subject(bob)).to.be.revertedWith("No rewards to claim"); + }); + }); + }); + describe("#transfer", async () => { let subjectAmount: BigNumber; let subjectCaller: Account; @@ -782,4 +937,121 @@ describe.only("PrtStakingPool", () => { }); }); }); + + describe("#getPendingPartialRewards", async () => { + let bobPrtAmount: BigNumber; + let alicePrtAmount: BigNumber; + let carolPrtAmount: BigNumber; + let snap1Amount: BigNumber; + let snap2Amount: BigNumber; + let snap3Amount: BigNumber; + + let subjectSnapshotId: BigNumber; + + beforeEach(async () => { + // PRT balances (bob: 6 PRT, alice: 4 PRT, carol: 5 PRT) + bobPrtAmount = ether(6); + alicePrtAmount = ether(4); + carolPrtAmount = ether(5); + + // Snapshot rewards amounts (snap1: 1 SetToken, snap2: 1.5 SetToken, snap3: 2 SetToken) + snap1Amount = ether(1); + snap2Amount = ether(1.5); + snap3Amount = ether(2); + + // Fund bob, alice, and carol with PRT + await prt.connect(owner.wallet).transfer(bob.address, bobPrtAmount); + await prt.connect(owner.wallet).transfer(alice.address, alicePrtAmount); + await prt.connect(owner.wallet).transfer(carol.address, carolPrtAmount); + + // Approve staking pool to spend PRT + await prt.connect(bob.wallet).approve(prtStakingPool.address, bobPrtAmount); + await prt.connect(alice.wallet).approve(prtStakingPool.address, alicePrtAmount); + await prt.connect(carol.wallet).approve(prtStakingPool.address, carolPrtAmount); + + // Before snapshot 1, bob stakes PRTs + await prtStakingPool.connect(bob.wallet).stake(bobPrtAmount); + + // Take snapshot 1 + await setToken.connect(owner.wallet).transfer(feeSplitExtension.address, snap1Amount); + await setToken.connect(feeSplitExtension.wallet).approve(prtStakingPool.address, snap1Amount); + await prtStakingPool.connect(feeSplitExtension.wallet).accrue(snap1Amount); + + // After snapshot 1, alice stakes PRTs + await prtStakingPool.connect(alice.wallet).stake(alicePrtAmount); + + // After snapshot 1, carol stakes PRTs + await prtStakingPool.connect(carol.wallet).stake(carolPrtAmount); + + // Take snapshot 2 + await setToken.connect(owner.wallet).transfer(feeSplitExtension.address, snap2Amount); + await setToken.connect(feeSplitExtension.wallet).approve(prtStakingPool.address, snap2Amount); + await prtStakingPool.connect(feeSplitExtension.wallet).accrue(snap2Amount); + + // After snapshot 2, carol unstakes PRTs + await prtStakingPool.connect(carol.wallet).unstake(carolPrtAmount); + + // Take snapshot 3 + await setToken.connect(owner.wallet).transfer(feeSplitExtension.address, snap3Amount); + await setToken.connect(feeSplitExtension.wallet).approve(prtStakingPool.address, snap3Amount); + await prtStakingPool.connect(feeSplitExtension.wallet).accrue(snap3Amount); + + subjectSnapshotId = TWO; // Snapshot 2 + }); + + async function subject(account: Address): Promise { + return prtStakingPool.getPendingPartialRewards(account, subjectSnapshotId); + } + + it("should return the correct pending rewards", async () => { + const bobPendingRewards = await subject(bob.address); + const alicePendingRewards = await subject(alice.address); + const carolPendingRewards = await subject(carol.address); + + const totalSupplySnap2 = await prtStakingPool.totalSupplyAt(TWO); + + // (bob) who stakes before snapshot 1 and never unstakes + const expectedBobPendingRewards = snap1Amount.add( + bobPrtAmount.mul(snap2Amount).div(totalSupplySnap2) + ); + + // (alice) who stakes after snapshot 1 and never unstakes + const expectedAlicePendingRewards = (alicePrtAmount.mul(snap2Amount).div(totalSupplySnap2)); + + // (carol) who stakes after snapshot 1 and unstakes after snapshot 2 + const expectedCarolPendingRewards = carolPrtAmount.mul(snap2Amount).div(totalSupplySnap2); + + expect(bobPendingRewards).to.eq(expectedBobPendingRewards); + expect(alicePendingRewards).to.eq(expectedAlicePendingRewards); + expect(carolPendingRewards).to.eq(expectedCarolPendingRewards); + }); + + describe("when the rewards have been claimed", async () => { + beforeEach(async () => { + await prtStakingPool.connect(bob.wallet).claimPartial(subjectSnapshotId); + }); + + it("should return 0", async () => { + const pendingRewards = await subject(bob.address); + expect(pendingRewards).to.eq(ZERO); + }); + }); + + describe("when the user never staked", async () => { + it("should return 0", async () => { + const pendingRewards = await subject(await getRandomAddress()); + expect(pendingRewards).to.eq(ZERO); + }); + }); + + describe("when the end id not less than the current id", async () => { + beforeEach(async () => { + subjectSnapshotId = await prtStakingPool.getCurrentId(); + }); + + it("should revert", async () => { + await expect(subject(bob.address)).to.be.revertedWith("End claim id must be less than current id"); + }); + }); + }); });