Skip to content

Commit

Permalink
Add SnapshotStakingPool and SignedSnapshotStakingPool Audit Remed…
Browse files Browse the repository at this point in the history
…iations (#4)

* add audit remediations
  • Loading branch information
pblivin0x authored Jul 8, 2024
1 parent 93ee6d1 commit 24f2c3c
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 52 deletions.
4 changes: 4 additions & 0 deletions src/interfaces/staking/ISignedSnapshotStakingPool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ interface ISignedSnapshotStakingPool is ISnapshotStakingPool {
/// @param signature The signature of the message
function approveStaker(bytes calldata signature) external;

/// @notice Set the message to sign when staking
/// @param newMessage The new message
function setMessage(string memory newMessage) external;

/// @notice Get the hashed digest of the message to be signed for staking
/// @return The hashed bytes to be signed
function getStakeSignatureDigest() external view returns (bytes32);
Expand Down
29 changes: 24 additions & 5 deletions src/interfaces/staking/ISnapshotStakingPool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ interface ISnapshotStakingPool is IERC20 {
/// @notice Distributor of rewards
function distributor() external view returns (address);

/// @notice Snapshot delay
/// @notice The buffer time before snapshots during which staking is not allowed
function snapshotBuffer() external view returns (uint256);

/// @notice The minimum amount of time between snapshots
function snapshotDelay() external view returns (uint256);

/// @notice Last snapshot time
Expand Down Expand Up @@ -49,18 +52,18 @@ interface ISnapshotStakingPool is IERC20 {
/// @param endSnapshotId The snapshot id to end the partial claim
function claimPartial(uint256 startSnapshotId, uint256 endSnapshotId) external;

/* ========== Admin Functions ========== */

/// @notice ONLY OWNER: Update the distributor address.
/// @param newDistributor The new distributor address
function setDistributor(address newDistributor) external;

/// @notice ONLY OWNER: Update the snapshot buffer.
/// @param newSnapshotBuffer The new snapshot buffer
function setSnapshotBuffer(uint256 newSnapshotBuffer) external;

/// @notice ONLY OWNER: Update the snapshot delay. Can set to 0 to disable snapshot delay.
/// @param newSnapshotDelay The new snapshot delay
function setSnapshotDelay(uint256 newSnapshotDelay) external;

/* ========== View Functions ========== */

/// @notice Get the current snapshot id.
/// @return The current snapshot id
function getCurrentSnapshotId() external view returns (uint256);
Expand Down Expand Up @@ -100,4 +103,20 @@ interface ISnapshotStakingPool is IERC20 {
/// @notice Get the time until the next snapshot.
/// @return The time until the next snapshot
function getTimeUntilNextSnapshot() external view returns (uint256);

/// @notice Get the next snapshot time.
/// @return The next snapshot time
function getNextSnapshotTime() external view returns (uint256);

/// @notice Check if staking is allowed.
/// @return Boolean indicating if staking is allowed
function canStake() external view returns (bool);

/// @notice Get the time until the next snapshot buffer begins.
/// @return The time until the next snapshot buffer begins
function getTimeUntilNextSnapshotBuffer() external view returns (uint256);

/// @notice Get the next snapshot buffer time.
/// @return The next snapshot buffer time
function getNextSnapshotBufferTime() external view returns (uint256);
}
41 changes: 31 additions & 10 deletions src/staking/SignedSnapshotStakingPool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import {SnapshotStakingPool} from "./SnapshotStakingPool.sol";
/// @title SignedSnapshotStakingPool
/// @author Index Cooperative
/// @notice A contract for staking `stakeToken` and receiving `rewardToken` based
/// on snapshots taken when rewards are accrued.
/// on snapshots taken when rewards are accrued. Snapshots are taken at a minimum
/// interval of `snapshotDelay` seconds. Staking is not allowed `snapshotBuffer`
/// seconds before a snapshot is taken. Rewards are distributed by the `distributor`.
/// Stakers must sign an agreement `message` to stake.
contract SignedSnapshotStakingPool is ISignedSnapshotStakingPool, SnapshotStakingPool, EIP712 {
string private constant MESSAGE_TYPE = "StakeMessage(string message)";

Expand All @@ -24,6 +27,8 @@ contract SignedSnapshotStakingPool is ISignedSnapshotStakingPool, SnapshotStakin

/* EVENTS */

/// @notice Emitted when the message is changed
event MessageChanged(string newMessage);
/// @notice Emitted when a staker has message signature approved
event StakerApproved(address indexed staker);

Expand All @@ -44,6 +49,7 @@ contract SignedSnapshotStakingPool is ISignedSnapshotStakingPool, SnapshotStakin
/// @param rewardToken Instance of the reward token
/// @param stakeToken Instance of the stake token
/// @param distributor Address of the distributor
/// @param snapshotBuffer The buffer time before snapshots during which staking is not allowed
/// @param snapshotDelay The minimum amount of time between snapshots
constructor(
string memory eip712Name,
Expand All @@ -54,31 +60,39 @@ contract SignedSnapshotStakingPool is ISignedSnapshotStakingPool, SnapshotStakin
IERC20 rewardToken,
IERC20 stakeToken,
address distributor,
uint256 snapshotBuffer,
uint256 snapshotDelay
)
EIP712(eip712Name, eip712Version)
SnapshotStakingPool(name, symbol, rewardToken, stakeToken, distributor, snapshotDelay)
SnapshotStakingPool(name, symbol, rewardToken, stakeToken, distributor, snapshotBuffer, snapshotDelay)
{
message = stakeMessage;
_setMessage(stakeMessage);
}

/* STAKER FUNCTIONS */

/// @inheritdoc ISignedSnapshotStakingPool
function stake(uint256 _amount) external override(SnapshotStakingPool, ISignedSnapshotStakingPool) nonReentrant {
function stake(uint256 amount) external override(SnapshotStakingPool, ISignedSnapshotStakingPool) nonReentrant {
if (!isApprovedStaker[msg.sender]) revert NotApprovedStaker();
_stake(msg.sender, _amount);
_stake(msg.sender, amount);
}

/// @inheritdoc ISignedSnapshotStakingPool
function stake(uint256 _amount, bytes calldata _signature) external nonReentrant {
_approveStaker(msg.sender, _signature);
_stake(msg.sender, _amount);
function stake(uint256 amount, bytes calldata signature) external nonReentrant {
_approveStaker(msg.sender, signature);
_stake(msg.sender, amount);
}

/// @inheritdoc ISignedSnapshotStakingPool
function approveStaker(bytes calldata _signature) external {
_approveStaker(msg.sender, _signature);
function approveStaker(bytes calldata signature) external {
_approveStaker(msg.sender, signature);
}

/* ADMIN FUNCTIONS */

/// @inheritdoc ISignedSnapshotStakingPool
function setMessage(string memory newMessage) external onlyOwner {
_setMessage(newMessage);
}

/* VIEW FUNCTIONS */
Expand All @@ -105,4 +119,11 @@ contract SignedSnapshotStakingPool is ISignedSnapshotStakingPool, SnapshotStakin
isApprovedStaker[staker] = true;
emit StakerApproved(staker);
}

/// @dev Set the stake `message` to `newMessage`
/// @param newMessage The new message
function _setMessage(string memory newMessage) internal {
message = newMessage;
emit MessageChanged(newMessage);
}
}
111 changes: 83 additions & 28 deletions src/staking/SnapshotStakingPool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,19 @@ import {ReentrancyGuard} from "@openzeppelin/contracts/security/ReentrancyGuard.
/// @title SnapshotStakingPool
/// @author Index Cooperative
/// @notice A contract for staking `stakeToken` and receiving `rewardToken` based
/// on snapshots taken when rewards are accrued.
/// on snapshots taken when rewards are accrued. Snapshots are taken at a minimum
/// interval of `snapshotDelay` seconds. Staking is not allowed `snapshotBuffer`
/// seconds before a snapshot is taken. Rewards are distributed by the `distributor`.
contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, ReentrancyGuard {

/* ERRORS */

/// @notice Error when snapshot buffer is greater than snapshot delay
error InvalidSnapshotBuffer();
/// @notice Error when snapshot delay is less than snapshot buffer
error InvalidSnapshotDelay();
/// @notice Error when staking during snapshot buffer period
error CannotStakeDuringBuffer();
/// @notice Error when accrue is called by non-distributor
error MustBeDistributor();
/// @notice Error when trying to accrue zero rewards
Expand All @@ -38,6 +46,8 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re

/// @notice Emitted when the reward distributor is changed.
event DistributorChanged(address newDistributor);
/// @notice Emitted when the snapshot buffer is changed.
event SnapshotBufferChanged(uint256 newSnapshotBuffer);
/// @notice Emitted when the snapshot delay is changed.
event SnapshotDelayChanged(uint256 newSnapshotDelay);

Expand All @@ -57,32 +67,39 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re
/// @inheritdoc ISnapshotStakingPool
uint256[] public rewardSnapshots;
/// @inheritdoc ISnapshotStakingPool
uint256 public snapshotBuffer;
/// @inheritdoc ISnapshotStakingPool
uint256 public snapshotDelay;
/// @inheritdoc ISnapshotStakingPool
uint256 public lastSnapshotTime;

/* CONSTRUCTOR */

/// @param _name Name of the staked token
/// @param _symbol Symbol of the staked token
/// @param _rewardToken Instance of the reward token
/// @param _stakeToken Instance of the stake token
/// @param _distributor Address of the distributor
/// @param _snapshotDelay The minimum amount of time between snapshots
/// @param name Name of the staked token
/// @param symbol Symbol of the staked token
/// @param rewardToken_ Instance of the reward token
/// @param stakeToken_ Instance of the stake token
/// @param distributor_ Address of the distributor
/// @param snapshotBuffer_ The buffer time before snapshots during which staking is not allowed
/// @param snapshotDelay_ The minimum amount of time between snapshots
constructor(
string memory _name,
string memory _symbol,
IERC20 _rewardToken,
IERC20 _stakeToken,
address _distributor,
uint256 _snapshotDelay
string memory name,
string memory symbol,
IERC20 rewardToken_,
IERC20 stakeToken_,
address distributor_,
uint256 snapshotBuffer_,
uint256 snapshotDelay_
)
ERC20(_name, _symbol)
ERC20(name, symbol)
{
rewardToken = _rewardToken;
stakeToken = _stakeToken;
distributor = _distributor;
snapshotDelay = _snapshotDelay;
if (snapshotBuffer_ > snapshotDelay_) revert InvalidSnapshotBuffer();
rewardToken = rewardToken_;
stakeToken = stakeToken_;
distributor = distributor_;
snapshotBuffer = snapshotBuffer_;
snapshotDelay = snapshotDelay_;
lastSnapshotTime = block.timestamp;
}

/* MODIFIERS */
Expand Down Expand Up @@ -138,8 +155,16 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re
emit DistributorChanged(newDistributor);
}

/// @inheritdoc ISnapshotStakingPool
function setSnapshotBuffer(uint256 newSnapshotBuffer) external onlyOwner {
if (newSnapshotBuffer > snapshotDelay) revert InvalidSnapshotBuffer();
snapshotBuffer = newSnapshotBuffer;
emit SnapshotBufferChanged(newSnapshotBuffer);
}

/// @inheritdoc ISnapshotStakingPool
function setSnapshotDelay(uint256 newSnapshotDelay) external onlyOwner {
if (snapshotBuffer > newSnapshotDelay) revert InvalidSnapshotDelay();
snapshotDelay = newSnapshotDelay;
emit SnapshotDelayChanged(newSnapshotDelay);
}
Expand Down Expand Up @@ -167,19 +192,15 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re
function getPendingRewards(address account) public view returns (uint256) {
uint256 currentId = _getCurrentSnapshotId();
uint256 lastId = nextClaimId[account];
return rewardOfInRange(account, lastId, currentId);
if (lastId == 0 || currentId == 0 || lastId > currentId) return 0;
return _rewardOfInRange(account, lastId, currentId);
}

/// @inheritdoc ISnapshotStakingPool
function rewardOfInRange(address account, uint256 startSnapshotId, uint256 endSnapshotId) public view returns (uint256) {
if (startSnapshotId == 0) revert InvalidSnapshotId();
if (startSnapshotId > endSnapshotId || endSnapshotId > _getCurrentSnapshotId()) revert NonExistentSnapshotId();

uint256 rewards = 0;
for (uint256 i = startSnapshotId; i <= endSnapshotId; i++) {
rewards += _rewardOfAt(account, i);
}
return rewards;
return _rewardOfInRange(account, startSnapshotId, endSnapshotId);
}

/// @inheritdoc ISnapshotStakingPool
Expand All @@ -203,25 +224,51 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re

/// @inheritdoc ISnapshotStakingPool
function getLifetimeRewards(address account) public view returns (uint256) {
return rewardOfInRange(account, 1, _getCurrentSnapshotId());
uint256 currentId = _getCurrentSnapshotId();
if (nextClaimId[account] == 0 || currentId == 0) return 0;
return _rewardOfInRange(account, 1, currentId);
}

/// @inheritdoc ISnapshotStakingPool
function canAccrue() public view returns (bool) {
return block.timestamp >= lastSnapshotTime + snapshotDelay;
return block.timestamp >= getNextSnapshotTime();
}

/// @inheritdoc ISnapshotStakingPool
function getTimeUntilNextSnapshot() public view returns (uint256) {
if (canAccrue()) {
return 0;
}
return (lastSnapshotTime + snapshotDelay) - block.timestamp;
return getNextSnapshotTime() - block.timestamp;
}

/// @inheritdoc ISnapshotStakingPool
function getNextSnapshotTime() public view returns (uint256) {
return lastSnapshotTime + snapshotDelay;
}

/// @inheritdoc ISnapshotStakingPool
function canStake() public view returns (bool) {
return block.timestamp < getNextSnapshotBufferTime();
}

/// @inheritdoc ISnapshotStakingPool
function getTimeUntilNextSnapshotBuffer() public view returns (uint256) {
if (!canStake()) {
return 0;
}
return getNextSnapshotBufferTime() - block.timestamp;
}

/// @inheritdoc ISnapshotStakingPool
function getNextSnapshotBufferTime() public view returns (uint256) {
return getNextSnapshotTime() - snapshotBuffer;
}

/* INTERNAL FUNCTIONS */

function _stake(address account, uint256 amount) internal {
if (!canStake()) revert CannotStakeDuringBuffer();
if (nextClaimId[account] == 0) {
uint256 currentId = _getCurrentSnapshotId();
nextClaimId[account] = currentId > 0 ? currentId : 1;
Expand All @@ -243,4 +290,12 @@ contract SnapshotStakingPool is ISnapshotStakingPool, Ownable, ERC20Snapshot, Re
function _rewardOfAt(address account, uint256 snapshotId) internal view returns (uint256) {
return _rewardAt(snapshotId) * balanceOfAt(account, snapshotId) / totalSupplyAt(snapshotId);
}

function _rewardOfInRange(address account, uint256 startSnapshotId, uint256 endSnapshotId) internal view returns (uint256) {
uint256 rewards = 0;
for (uint256 i = startSnapshotId; i <= endSnapshotId; i++) {
rewards += _rewardOfAt(account, i);
}
return rewards;
}
}
3 changes: 3 additions & 0 deletions test/staking/HyEthSnapshotStakingPool.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ contract HyEthSnapshotStakingPoolTest is Test {
IDebtIssuanceModuleV2 issuanceModule = IDebtIssuanceModuleV2(issuanceModuleAddress);
IStreamingFeeModule streamingFeeModule = IStreamingFeeModule(streamingFeeModuleAddress);

uint256 public snapshotBuffer = 1 days;
uint256 public snapshotDelay = 30 days;

SnapshotStakingPool public snapshotStakingPool;
Expand All @@ -55,6 +56,7 @@ contract HyEthSnapshotStakingPoolTest is Test {
IERC20(hyEthAddress),
prtHyEth,
prtFeeSplitExtensionAddress,
snapshotBuffer,
snapshotDelay
);

Expand Down Expand Up @@ -87,6 +89,7 @@ contract HyEthSnapshotStakingPoolTest is Test {
_stake(alice.addr, 1 ether);
_stake(bob.addr, 1 ether);

vm.warp(block.timestamp + snapshotDelay + 1);
prtFeeSplitExtension.accrueFeesAndDistribute();

assert(hyEth.balanceOf(address(snapshotStakingPool)) > 0);
Expand Down
Loading

0 comments on commit 24f2c3c

Please sign in to comment.