Skip to content

Commit

Permalink
Enforce rules entitlement checks on isEntitledToSpace (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
brianathere authored May 21, 2024
1 parent 0b2d214 commit cd6d529
Show file tree
Hide file tree
Showing 26 changed files with 265 additions and 235 deletions.
2 changes: 1 addition & 1 deletion core/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ COPY --from=builder /bin/xchain_node /usr/bin/xchain_node
RUN setcap 'cap_net_bind_service=+ep' /usr/bin/stream_node

COPY --from=builder /build/node/default_config.yaml /riveruser/stream_node/config/config.yaml
COPY --from=builder /build/xchain/default_config.yaml /riveruser/xchain_node/config/config.yaml
COPY --from=builder /build/node/default_config.yaml /riveruser/xchain_node/config/config.yaml

RUN mkdir -p /riveruser/stream_node/logs
RUN mkdir -p /riveruser/xchain_node/logs
Expand Down
105 changes: 75 additions & 30 deletions core/node/auth/auth_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"context"
"fmt"
"strings"
"sync"
"time"

Expand All @@ -13,12 +14,13 @@ import (
"github.com/river-build/river/core/node/infra"
. "github.com/river-build/river/core/node/protocol"
"github.com/river-build/river/core/node/shared"
"github.com/river-build/river/core/xchain/entitlement"

"github.com/ethereum/go-ethereum/common"
)

type ChainAuth interface {
IsEntitled(ctx context.Context, args *ChainAuthArgs) error
IsEntitled(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) error
}

var everyone = common.HexToAddress("0x1") // This represents an Ethereum address of "0x1"
Expand Down Expand Up @@ -57,11 +59,12 @@ const (
)

type ChainAuthArgs struct {
kind chainAuthKind
spaceId shared.StreamId
channelId shared.StreamId
principal common.Address
permission Permission
kind chainAuthKind
spaceId shared.StreamId
channelId shared.StreamId
principal common.Address
permission Permission
linkedWallets string // a serialized list of linked wallets to comply with the cache key constraints
}

// Replaces principal with given wallet and returns new copy of args.
Expand All @@ -71,6 +74,19 @@ func (args *ChainAuthArgs) withWallet(wallet common.Address) *ChainAuthArgs {
return &ret
}

func (args *ChainAuthArgs) withLinkedWallets(linkedWallets []common.Address) *ChainAuthArgs {
ret := *args
var builder strings.Builder
for i, addr := range linkedWallets {
if i > 0 {
builder.WriteString(",")
}
builder.WriteString(addr.Hex())
}
ret.linkedWallets = builder.String()
return &ret
}

func newArgsForEnabledSpace(spaceId shared.StreamId) *ChainAuthArgs {
return &ChainAuthArgs{
kind: chainAuthKindSpaceEnabled,
Expand Down Expand Up @@ -163,10 +179,11 @@ func NewChainAuth(
}, nil
}

func (ca *chainAuth) IsEntitled(ctx context.Context, args *ChainAuthArgs) error {
func (ca *chainAuth) IsEntitled(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) error {
// TODO: counter for cache hits here?
result, _, err := ca.entitlementCache.executeUsingCache(
ctx,
cfg,
args,
ca.checkEntitlement,
)
Expand All @@ -190,28 +207,29 @@ func (ca *chainAuth) IsEntitled(ctx context.Context, args *ChainAuthArgs) error
return nil
}

func (ca *chainAuth) isWalletEntitled(ctx context.Context, args *ChainAuthArgs) (bool, error) {
func (ca *chainAuth) isWalletEntitled(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (bool, error) {
log := dlog.FromCtx(ctx)
if args.kind == chainAuthKindSpace {
log.Debug("isWalletEntitled", "kind", "space", "args", args)
return ca.isEntitledToSpace(ctx, args)
return ca.isEntitledToSpace(ctx, cfg, args)
} else if args.kind == chainAuthKindChannel {
log.Debug("isWalletEntitled", "kind", "channel", "args", args)
return ca.isEntitledToChannel(ctx, args)
return ca.isEntitledToChannel(ctx, cfg, args)
} else {
return false, RiverError(Err_INTERNAL, "Unknown chain auth kind").Func("isWalletEntitled")
}
}

func (ca *chainAuth) isSpaceEnabledUncached(ctx context.Context, args *ChainAuthArgs) (CacheResult, error) {
func (ca *chainAuth) isSpaceEnabledUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) {
// This is awkward as we want enabled to be cached for 15 minutes, but the API returns the inverse
isDisabled, err := ca.spaceContract.IsSpaceDisabled(ctx, args.spaceId)
return &boolCacheResult{allowed: !isDisabled}, err
}

func (ca *chainAuth) checkSpaceEnabled(ctx context.Context, spaceId shared.StreamId) error {
func (ca *chainAuth) checkSpaceEnabled(ctx context.Context, cfg *config.Config, spaceId shared.StreamId) error {
isEnabled, cacheHit, err := ca.entitlementCache.executeUsingCache(
ctx,
cfg,
newArgsForEnabledSpace(spaceId),
ca.isSpaceEnabledUncached,
)
Expand All @@ -231,19 +249,21 @@ func (ca *chainAuth) checkSpaceEnabled(ctx context.Context, spaceId shared.Strea
}
}

func (ca *chainAuth) isChannelEnabledUncached(ctx context.Context, args *ChainAuthArgs) (CacheResult, error) {
func (ca *chainAuth) isChannelEnabledUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) {
// This is awkward as we want enabled to be cached for 15 minutes, but the API returns the inverse
isDisabled, err := ca.spaceContract.IsChannelDisabled(ctx, args.spaceId, args.channelId)
return &boolCacheResult{allowed: !isDisabled}, err
}

func (ca *chainAuth) checkChannelEnabled(
ctx context.Context,
cfg *config.Config,
spaceId shared.StreamId,
channelId shared.StreamId,
) error {
isEnabled, cacheHit, err := ca.entitlementCache.executeUsingCache(
ctx,
cfg,
newArgsForEnabledChannel(spaceId, channelId),
ca.isChannelEnabledUncached,
)
Expand Down Expand Up @@ -280,6 +300,7 @@ func (scr *entitlementCacheResult) IsAllowed() bool {
// If the call fails or the space is not found, the allowed flag is set to false so the negative caching time applies.
func (ca *chainAuth) getSpaceEntitlementsForPermissionUncached(
ctx context.Context,
cfg *config.Config,
args *ChainAuthArgs,
) (CacheResult, error) {
log := dlog.FromCtx(ctx)
Expand All @@ -301,11 +322,21 @@ func (ca *chainAuth) getSpaceEntitlementsForPermissionUncached(
return &entitlementCacheResult{allowed: true, entitlementData: entitlementData, owner: owner}, nil
}

func (ca *chainAuth) isEntitledToSpaceUncached(ctx context.Context, args *ChainAuthArgs) (CacheResult, error) {
func deserializeWallets(serialized string) []common.Address {
addressStrings := strings.Split(serialized, ",")
linkedWallets := make([]common.Address, len(addressStrings))
for i, addrStr := range addressStrings {
linkedWallets[i] = common.HexToAddress(addrStr)
}
return linkedWallets
}

func (ca *chainAuth) isEntitledToSpaceUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) {
log := dlog.FromCtx(ctx)
log.Debug("isEntitledToSpaceUncached", "args", args)
result, cacheHit, err := ca.entitlementManagerCache.executeUsingCache(
ctx,
cfg,
args,
ca.getSpaceEntitlementsForPermissionUncached,
)
Expand Down Expand Up @@ -333,12 +364,25 @@ func (ca *chainAuth) isEntitledToSpaceUncached(ctx context.Context, args *ChainA

entitlementData := temp.(*entitlementCacheResult) // Assuming result is of *entitlementCacheResult type
log.Debug("entitlementData", "args", args, "entitlementData", entitlementData)
for _, entitlement := range entitlementData.entitlementData {
log.Debug("entitlement", "entitlement", entitlement)
if entitlement.entitlementType == "RuleEntitlement" {
// TODO implement rule entitlment
} else if entitlement.entitlementType == "UserEntitlement" {
for _, user := range entitlement.userEntitlement {
for _, ent := range entitlementData.entitlementData {
log.Debug("entitlement", "entitlement", ent)
if ent.entitlementType == "RuleEntitlement" {
re := ent.ruleEntitlement
log.Debug("RuleEntitlement", "ruleEntitlement", re)
result, err := entitlement.EvaluateRuleData(ctx, cfg, deserializeWallets(args.linkedWallets), re)

if err != nil {
return &boolCacheResult{allowed: false}, AsRiverError(err).Func("isEntitledToSpace")
}
if result {
log.Debug("rule entitlement is true", "spaceId", args.spaceId)
return &boolCacheResult{allowed: true}, nil
} else {
log.Debug("rule entitlement is false", "spaceId", args.spaceId)
return &boolCacheResult{allowed: false}, nil
}
} else if ent.entitlementType == "UserEntitlement" {
for _, user := range ent.userEntitlement {
if user == everyone {
log.Debug("everyone is entitled to space", "spaceId", args.spaceId)
return &boolCacheResult{allowed: true}, nil
Expand All @@ -348,19 +392,19 @@ func (ca *chainAuth) isEntitledToSpaceUncached(ctx context.Context, args *ChainA
}
}
} else {
log.Warn("Invalid entitlement type", "entitlement", entitlement)
log.Warn("Invalid entitlement type", "entitlement", ent)
}
}

return &boolCacheResult{allowed: false}, nil
}

func (ca *chainAuth) isEntitledToSpace(ctx context.Context, args *ChainAuthArgs) (bool, error) {
func (ca *chainAuth) isEntitledToSpace(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (bool, error) {
if args.kind != chainAuthKindSpace {
return false, RiverError(Err_INTERNAL, "Wrong chain auth kind")
}

isEntitled, cacheHit, err := ca.entitlementCache.executeUsingCache(ctx, args, ca.isEntitledToSpaceUncached)
isEntitled, cacheHit, err := ca.entitlementCache.executeUsingCache(ctx, cfg, args, ca.isEntitledToSpaceUncached)
if err != nil {
return false, err
}
Expand All @@ -373,7 +417,7 @@ func (ca *chainAuth) isEntitledToSpace(ctx context.Context, args *ChainAuthArgs)
return isEntitled.IsAllowed(), nil
}

func (ca *chainAuth) isEntitledToChannelUncached(ctx context.Context, args *ChainAuthArgs) (CacheResult, error) {
func (ca *chainAuth) isEntitledToChannelUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) {
allowed, err := ca.spaceContract.IsEntitledToChannel(
ctx,
args.spaceId,
Expand All @@ -384,12 +428,12 @@ func (ca *chainAuth) isEntitledToChannelUncached(ctx context.Context, args *Chai
return &boolCacheResult{allowed: allowed}, err
}

func (ca *chainAuth) isEntitledToChannel(ctx context.Context, args *ChainAuthArgs) (bool, error) {
func (ca *chainAuth) isEntitledToChannel(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (bool, error) {
if args.kind != chainAuthKindChannel {
return false, RiverError(Err_INTERNAL, "Wrong chain auth kind")
}

isEntitled, cacheHit, err := ca.entitlementCache.executeUsingCache(ctx, args, ca.isEntitledToChannelUncached)
isEntitled, cacheHit, err := ca.entitlementCache.executeUsingCache(ctx, cfg, args, ca.isEntitledToChannelUncached)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -452,19 +496,19 @@ func (ca *chainAuth) checkMembership(
* If any of the operations fail before getting positive result, the whole operation fails.
* A prerequisite for this function is that one of the linked wallets is a member of the space.
*/
func (ca *chainAuth) checkEntitlement(ctx context.Context, args *ChainAuthArgs) (CacheResult, error) {
func (ca *chainAuth) checkEntitlement(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) {
log := dlog.FromCtx(ctx)

ctx, cancel := context.WithTimeout(ctx, time.Millisecond*time.Duration(ca.contractCallsTimeoutMs))
defer cancel()

if args.kind == chainAuthKindSpace {
err := ca.checkSpaceEnabled(ctx, args.spaceId)
err := ca.checkSpaceEnabled(ctx, cfg, args.spaceId)
if err != nil {
return &boolCacheResult{allowed: false}, nil
}
} else if args.kind == chainAuthKindChannel {
err := ca.checkChannelEnabled(ctx, args.spaceId, args.channelId)
err := ca.checkChannelEnabled(ctx, cfg, args.spaceId, args.channelId)
if err != nil {
return &boolCacheResult{allowed: false}, nil
}
Expand All @@ -480,6 +524,7 @@ func (ca *chainAuth) checkEntitlement(ctx context.Context, args *ChainAuthArgs)

// Add the root key to the list of wallets.
wallets = append(wallets, args.principal)
args = args.withLinkedWallets(wallets)

isMemberCtx, isMemberCancel := context.WithCancel(ctx)
defer isMemberCancel()
Expand Down Expand Up @@ -534,7 +579,7 @@ func (ca *chainAuth) checkEntitlement(ctx context.Context, args *ChainAuthArgs)
wg.Add(1)
go func(address common.Address) {
defer wg.Done()
result, err := ca.isWalletEntitled(ctx, args.withWallet(address))
result, err := ca.isWalletEntitled(ctx, cfg, args.withWallet(address))
resultsChan <- entitlementCheckResult{allowed: result, err: err}
}(wallet)
}
Expand Down
5 changes: 3 additions & 2 deletions core/node/auth/auth_impl_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ func newEntitlementManagerCache(ctx context.Context, cfg *config.ChainConfig) (*

func (ec *entitlementCache) executeUsingCache(
ctx context.Context,
cfg *config.Config,
key *ChainAuthArgs,
onMiss func(context.Context, *ChainAuthArgs) (CacheResult, error),
onMiss func(context.Context, *config.Config, *ChainAuthArgs) (CacheResult, error),
) (CacheResult, bool, error) {
// Check positive cache first
if val, ok := ec.positiveCache.Get(*key); ok {
Expand All @@ -167,7 +168,7 @@ func (ec *entitlementCache) executeUsingCache(
}

// Cache miss, execute the closure
result, err := onMiss(ctx, key)
result, err := onMiss(ctx, cfg, key)
if err != nil {
return nil, false, err
}
Expand Down
8 changes: 6 additions & 2 deletions core/node/auth/auth_impl_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ func TestCache(t *testing.T) {
ctx, cancel := test.NewTestContext()
defer cancel()

cfg := &config.Config{}

c, err := newEntitlementCache(
ctx,
&config.ChainConfig{
Expand All @@ -41,8 +43,9 @@ func TestCache(t *testing.T) {
var cacheMissForReal bool
result, cacheHit, err := c.executeUsingCache(
ctx,
cfg,
NewChainAuthArgsForChannel(spaceId, channelId, "3", PermissionWrite),
func(context.Context, *ChainAuthArgs) (CacheResult, error) {
func(context.Context, *config.Config, *ChainAuthArgs) (CacheResult, error) {
cacheMissForReal = true
return &simpleCacheResult{allowed: true}, nil
},
Expand All @@ -55,8 +58,9 @@ func TestCache(t *testing.T) {
cacheMissForReal = false
result, cacheHit, err = c.executeUsingCache(
ctx,
cfg,
NewChainAuthArgsForChannel(spaceId, channelId, "3", PermissionWrite),
func(context.Context, *ChainAuthArgs) (CacheResult, error) {
func(context.Context, *config.Config, *ChainAuthArgs) (CacheResult, error) {
cacheMissForReal = true
return &simpleCacheResult{allowed: false}, nil
},
Expand Down
4 changes: 3 additions & 1 deletion core/node/auth/fake_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package auth

import (
"context"

"github.com/river-build/river/core/node/config"
)

// This checkers always returns true, used for some testing scenarios.
Expand All @@ -13,6 +15,6 @@ type fakeChainAuth struct{}

var _ ChainAuth = (*fakeChainAuth)(nil)

func (a *fakeChainAuth) IsEntitled(ctx context.Context, args *ChainAuthArgs) error {
func (a *fakeChainAuth) IsEntitled(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) error {
return nil
}
5 changes: 2 additions & 3 deletions core/node/auth/space_contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ package auth
import (
"context"

v3 "github.com/river-build/river/core/xchain/contracts/v3"

"github.com/ethereum/go-ethereum/common"
"github.com/river-build/river/core/node/shared"
"github.com/river-build/river/core/xchain/contracts"
)

type SpaceEntitlements struct {
entitlementType string
ruleEntitlement v3.IRuleEntitlementRuleData
ruleEntitlement *contracts.IRuleData
userEntitlement []common.Address
}

Expand Down
Loading

0 comments on commit cd6d529

Please sign in to comment.