Skip to content

Commit

Permalink
Update ETH balance go implementation to consider all ether-based chai…
Browse files Browse the repository at this point in the history
…ns (#986)

This PR also includes configuration for stream nodes that defaults to
sane values if not explicitly set. It uses the same design as existing
Xchain configuration and should be similarly forward-compatible with the
move to on-chain configuration.

This PR leaves the repo in an inconsistent state, as the client is still
evaluating eth checks on a single chain - however, the eth check is
already broken because it does not meet the spec, so I figured this was
fine.
  • Loading branch information
clemire authored Sep 9, 2024
1 parent 9e61406 commit f9daf99
Show file tree
Hide file tree
Showing 11 changed files with 276 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ interface IRuleEntitlementBase {
ERC721,
ERC1155,
ISENTITLED,
NATIVE_COIN_BALANCE
ETH_BALANCE
}

// Enum for Operation oneof operation_clause
Expand Down
88 changes: 60 additions & 28 deletions core/config/blockchain_info.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,50 @@
package config

import "time"
import (
"context"
"time"

"github.com/river-build/river/core/node/dlog"
)

type BlockchainInfo struct {
ChainId uint64
Name string
Blocktime time.Duration
ChainId uint64
Name string
// IsEtherBased is true for chains that use Ether as the currency for fees.
IsEtherBased bool
Blocktime time.Duration
}

func GetEtherBasedBlockchains(
ctx context.Context,
chains []uint64,
defaultBlockchainInfo map[uint64]BlockchainInfo,
) []uint64 {
log := dlog.FromCtx(ctx)
etherBasedChains := make([]uint64, 0, len(chains))
for _, chainId := range chains {
if info, ok := defaultBlockchainInfo[chainId]; ok && info.IsEtherBased {
etherBasedChains = append(etherBasedChains, chainId)
} else if !ok {
log.Error("Missing BlockchainInfo for chain", "chainId", chainId)
}
}
return etherBasedChains
}

func GetDefaultBlockchainInfo() map[uint64]BlockchainInfo {
return map[uint64]BlockchainInfo{
1: {
ChainId: 1,
Name: "Ethereum Mainnet",
Blocktime: 12 * time.Second,
ChainId: 1,
Name: "Ethereum Mainnet",
Blocktime: 12 * time.Second,
IsEtherBased: true,
},
11155111: {
ChainId: 11155111,
Name: "Ethereum Sepolia",
Blocktime: 12 * time.Second,
ChainId: 11155111,
Name: "Ethereum Sepolia",
Blocktime: 12 * time.Second,
IsEtherBased: true,
},
550: {
ChainId: 550,
Expand All @@ -31,39 +57,45 @@ func GetDefaultBlockchainInfo() map[uint64]BlockchainInfo {
Blocktime: 2 * time.Second,
},
8453: {
ChainId: 8453,
Name: "Base Mainnet",
Blocktime: 2 * time.Second,
ChainId: 8453,
Name: "Base Mainnet",
Blocktime: 2 * time.Second,
IsEtherBased: true,
},
84532: {
ChainId: 84532,
Name: "Base Sepolia",
Blocktime: 2 * time.Second,
ChainId: 84532,
Name: "Base Sepolia",
Blocktime: 2 * time.Second,
IsEtherBased: true,
},
137: {
ChainId: 137,
Name: "Polygon Mainnet",
Blocktime: 2 * time.Second,
},
42161: {
ChainId: 42161,
Name: "Arbitrum One",
Blocktime: 250 * time.Millisecond,
ChainId: 42161,
Name: "Arbitrum One",
Blocktime: 250 * time.Millisecond,
IsEtherBased: true,
},
10: {
ChainId: 10,
Name: "Optimism Mainnet",
Blocktime: 2 * time.Second,
ChainId: 10,
Name: "Optimism Mainnet",
Blocktime: 2 * time.Second,
IsEtherBased: true,
},
31337: {
ChainId: 31337,
Name: "Anvil Base",
Blocktime: 2 * time.Second,
ChainId: 31337,
Name: "Anvil Base",
Blocktime: 2 * time.Second,
IsEtherBased: true,
},
31338: {
ChainId: 31338,
Name: "Anvil River",
Blocktime: 2 * time.Second,
ChainId: 31338,
Name: "Anvil River",
Blocktime: 2 * time.Second,
IsEtherBased: true, // This is set for ease of testing.
},
100: {
ChainId: 100,
Expand Down
7 changes: 4 additions & 3 deletions core/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,8 @@ func parseBlockchainDurations(str string, result map[uint64]BlockchainInfo) erro
}

func (c *Config) parseChains() error {
bcDurations := GetDefaultBlockchainInfo()
err := parseBlockchainDurations(c.ChainBlocktimes, bcDurations)
defaultChainInfo := GetDefaultBlockchainInfo()
err := parseBlockchainDurations(c.ChainBlocktimes, defaultChainInfo)
if err != nil {
return err
}
Expand All @@ -493,7 +493,7 @@ func (c *Config) parseChains() error {
return WrapRiverError(Err_BAD_CONFIG, err).Message("Failed to pase chain Id").Tag("chainId", parts[0])
}

info, ok := bcDurations[chainID]
info, ok := defaultChainInfo[chainID]
if !ok {
return RiverError(Err_BAD_CONFIG, "Chain blocktime not set").Tag("chainId", chainID)
}
Expand All @@ -515,6 +515,7 @@ func (c *Config) parseChains() error {
c.XChainBlockchains = append(c.XChainBlockchains, chainID)
}
}

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion core/contracts/types/entitlement_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func ConvertV1RuleDataToV2(
fallthrough
case ERC721:
fallthrough
case NATIVE_COIN_BALANCE:
case ETH_BALANCE:
params, err := (&ThresholdParams{
Threshold: checkOp.Threshold,
}).AbiEncode()
Expand Down
4 changes: 2 additions & 2 deletions core/contracts/types/entitlement_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func TestConvertV1RuleDataToV2(t *testing.T) {
},
},
},
"NativeCoinBalance": {
"EthBalance": {
ruleData: test_util.EthBalanceCheck(15, 1500),
expected: base.IRuleEntitlementBaseRuleDataV2{
Operations: []base.IRuleEntitlementBaseOperation{
Expand All @@ -160,7 +160,7 @@ func TestConvertV1RuleDataToV2(t *testing.T) {
},
CheckOperations: []base.IRuleEntitlementBaseCheckOperationV2{
{
OpType: uint8(types.NATIVE_COIN_BALANCE),
OpType: uint8(types.ETH_BALANCE),
ChainId: big.NewInt(15),
Params: encodeThresholdParams(t, 1500),
},
Expand Down
6 changes: 3 additions & 3 deletions core/contracts/types/entitlement_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const (
ERC721
ERC1155
ISENTITLED
NATIVE_COIN_BALANCE
ETH_BALANCE
)

func (t CheckOperationType) String() string {
Expand All @@ -47,8 +47,8 @@ func (t CheckOperationType) String() string {
return "ERC1155"
case ISENTITLED:
return "ISENTITLED"
case NATIVE_COIN_BALANCE:
return "NATIVE_COIN_BALANCE"
case ETH_BALANCE:
return "ETH_BALANCE"
default:
return "UNKNOWN"
}
Expand Down
2 changes: 1 addition & 1 deletion core/contracts/types/test_util/entitlements.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func EthBalanceCheck(chainId uint64, threshold uint64) base.IRuleEntitlementBase
},
CheckOperations: []base.IRuleEntitlementBaseCheckOperation{
{
OpType: uint8(contract_types.NATIVE_COIN_BALANCE),
OpType: uint8(contract_types.ETH_BALANCE),
ChainId: new(big.Int).SetUint64(chainId),
ContractAddress: common.Address{},
Threshold: new(big.Int).SetUint64(threshold),
Expand Down
80 changes: 43 additions & 37 deletions core/xchain/entitlement/check_operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@ func validateCheckOperation(ctx context.Context, op *types.CheckOperation) error
// 3. Threshold is positive
// 4. Token ID is non-negative
log := dlog.FromCtx(ctx).With("function", "validateCheckOperation")
if op.ChainID == nil {
if op.CheckType != types.ETH_BALANCE && op.ChainID == nil {
log.Error("Entitlement check: chain ID is nil for operation", "operation", op.CheckType.String())
return fmt.Errorf("validateCheckOperation: chain ID is nil for operation %s", op.CheckType)
}

zeroAddress := common.Address{}
if op.CheckType != types.NATIVE_COIN_BALANCE && op.ContractAddress == zeroAddress {
if op.CheckType != types.ETH_BALANCE && op.ContractAddress == zeroAddress {
log.Error("Entitlement check: contract address is nil for operation", "operation", op.CheckType.String())
return fmt.Errorf(
"validateCheckOperation: contract address is nil for operation %s",
op.CheckType,
)
}

if op.CheckType == types.ERC20 || op.CheckType == types.ERC721 || op.CheckType == types.NATIVE_COIN_BALANCE {
if op.CheckType == types.ERC20 || op.CheckType == types.ERC721 || op.CheckType == types.ETH_BALANCE {
params, err := types.DecodeThresholdParams(op.Params)
if err != nil {
log.Error(
Expand Down Expand Up @@ -148,8 +148,8 @@ func (e *Evaluator) evaluateCheckOperation(
return e.evaluateErc721Operation(ctx, op, linkedWallets)
case types.ERC1155:
return e.evaluateErc1155Operation(ctx, op, linkedWallets)
case types.NATIVE_COIN_BALANCE:
return e.evaluateNativeCoinBalanceOperation(ctx, op, linkedWallets)
case types.ETH_BALANCE:
return e.evaluateEthBalanceOperation(ctx, op, linkedWallets)
case types.CheckNONE:
fallthrough
case types.MOCK:
Expand Down Expand Up @@ -232,48 +232,54 @@ func (e *Evaluator) evaluateIsEntitledOperation(
return false, nil
}

// Check balance in decimals of native token
func (e *Evaluator) evaluateNativeCoinBalanceOperation(
// Check ETH balance, in decimals, across all supported chains that use Ether as the native token for payments.
func (e *Evaluator) evaluateEthBalanceOperation(
ctx context.Context,
op *types.CheckOperation,
linkedWallets []common.Address,
) (bool, error) {
log := dlog.FromCtx(ctx).With("function", "evaluateNativeTokenBalanceOperation")
client, err := e.clients.Get(op.ChainID.Uint64())
if err != nil {
log.Error("Chain ID not found", "chainID", op.ChainID)
return false, fmt.Errorf("evaluateNativeTokenBalanceOperation: Chain ID %v not found", op.ChainID)
}
params, err := types.DecodeThresholdParams(op.Params)
if err != nil {
log.Error("evaluateNativeCoinBalance: failed to decode threshold params", "error", err)
return false, fmt.Errorf("evaluateNativeCoinBalance: failed to decode threshold params, %w", err)
}
log := dlog.FromCtx(ctx).With("function", "evaluateEthBalanceOperation")

// Accumulator for the total balance across all chains.
total := big.NewInt(0)
for _, wallet := range linkedWallets {
// Balance is returned as a representation of the balance according the denomination of the
// native token. The default decimals for most native tokens is 18, and we don't convert
// according to decimals here, but compare the threshold directly with the balance.
balance, err := client.BalanceAt(ctx, wallet, nil)

for _, chainID := range e.ethChainIds {
log.Info("Evaluating ETH balance on chain", "chainID", chainID, "wallets", linkedWallets)
client, err := e.clients.Get(chainID)
if err != nil {
log.Error("Failed to retrieve native token balance", "chain", op.ChainID, "error", err)
return false, err
log.Error("Provider for Chain ID not found", "chainID", chainID)
return false, fmt.Errorf("evaluateEthBalanceOperation: Providerfor chain ID %v not found", chainID)
}
params, err := types.DecodeThresholdParams(op.Params)
if err != nil {
log.Error("Failed to decode threshold params", "error", err)
return false, fmt.Errorf("evaluateEthBalanceOperation: failed to decode threshold params, %w", err)
}
total.Add(total, balance)

log.Info("Retrieved native token balance",
"balance", balance.String(),
"total", total.String(),
"threshold", params.Threshold.String(),
"chainID", op.ChainID.String(),
)
for _, wallet := range linkedWallets {
// Balance is returned as a representation of the balance according the denomination of the
// ETH, which is 18. We do not convert away from decimals here, but compare the threshold
// directly with the decimalized balance.
balance, err := client.BalanceAt(ctx, wallet, nil)
if err != nil {
log.Error("Failed to retrieve ETH balance", "chain", chainID, "error", err)
return false, err
}
total.Add(total, balance)

log.Info("Accumulated ETH balance for chain",
"balance", balance.String(),
"total", total.String(),
"threshold", params.Threshold.String(),
"chainID", chainID,
)

// Balance is a *big.Int
// Iteratively check if the total balance of evaluated wallets is greater than or equal to the
// threshold. Note threshold is always positive and total is non-negative.
if total.Cmp(params.Threshold) >= 0 {
return true, nil
// Balance is a *big.Int
// Iteratively check if the total balance of evaluated wallets is greater than or equal to the
// threshold. Note threshold is always positive and total is non-negative.
if total.Cmp(params.Threshold) >= 0 {
return true, nil
}
}
}
return false, nil
Expand Down
Loading

0 comments on commit f9daf99

Please sign in to comment.