Skip to content

Commit

Permalink
refactor: shared query client
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanchriswhite committed Dec 13, 2024
1 parent b24b3a5 commit d630484
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 85 deletions.
6 changes: 3 additions & 3 deletions pkg/client/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ type SessionQueryClient interface {
// SharedQueryClient defines an interface that enables the querying of the
// on-chain shared module params.
type SharedQueryClient interface {
// GetParams queries the chain for the current shared module parameters.
GetParams(ctx context.Context) (*sharedtypes.Params, error)
ParamsQuerier[*sharedtypes.Params]

// GetSessionGracePeriodEndHeight returns the block height at which the grace period
// for the session that includes queryHeight elapses.
// The grace period is the number of blocks after the session ends during which relays
Expand All @@ -320,7 +320,7 @@ type SharedQueryClient interface {
// for the session that includes queryHeight can be committed for a given supplier.
GetEarliestSupplierProofCommitHeight(ctx context.Context, queryHeight int64, supplierOperatorAddr string) (int64, error)
// GetComputeUnitsToTokensMultiplier returns the multiplier used to convert compute units to tokens.
GetComputeUnitsToTokensMultiplier(ctx context.Context) (uint64, error)
GetComputeUnitsToTokensMultiplier(ctx context.Context, queryHeight int64) (uint64, error)
}

// BlockQueryClient defines an interface that enables the querying of
Expand Down
92 changes: 47 additions & 45 deletions pkg/client/query/sharedquerier.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,60 +13,60 @@ import (
var _ client.SharedQueryClient = (*sharedQuerier)(nil)

// sharedQuerier is a wrapper around the sharedtypes.QueryClient that enables the
// querying of on-chain shared information through a single exposed method
// which returns an sharedtypes.Session struct
// querying of on-chain shared information.
type sharedQuerier struct {
client.ParamsQuerier[*sharedtypes.Params]

clientConn grpc.ClientConn
sharedQuerier sharedtypes.QueryClient
blockQuerier client.BlockQueryClient
}

// NewSharedQuerier returns a new instance of a client.SharedQueryClient by
// injecting the dependecies provided by the depinject.Config.
// injecting the dependencies provided by the depinject.Config.
//
// Required dependencies:
// - clientCtx (grpc.ClientConn)
// - client.BlockQueryClient
func NewSharedQuerier(deps depinject.Config) (client.SharedQueryClient, error) {
querier := &sharedQuerier{}
func NewSharedQuerier(
deps depinject.Config,
paramsQuerierOpts ...ParamsQuerierOptionFn,
) (client.SharedQueryClient, error) {
paramsQuerierCfg := DefaultParamsQuerierConfig()
for _, opt := range paramsQuerierOpts {
opt(paramsQuerierCfg)
}

paramsQuerier, err := NewCachedParamsQuerier[*sharedtypes.Params, sharedtypes.SharedQueryClient](
deps, sharedtypes.NewSharedQueryClient,
WithModuleInfo(sharedtypes.ModuleName, sharedtypes.ErrSharedParamInvalid),
WithQueryCacheOptions(paramsQuerierCfg.CacheOpts...),
)
if err != nil {
return nil, err
}

sq := &sharedQuerier{
ParamsQuerier: paramsQuerier,
}

if err := depinject.Inject(
if err = depinject.Inject(
deps,
&querier.clientConn,
&querier.blockQuerier,
&sq.clientConn,
&sq.blockQuerier,
); err != nil {
return nil, err
}

querier.sharedQuerier = sharedtypes.NewQueryClient(querier.clientConn)
sq.sharedQuerier = sharedtypes.NewQueryClient(sq.clientConn)

return querier, nil
}

// GetParams queries & returns the shared module on-chain parameters.
//
// TODO_TECHDEBT(#543): We don't really want to have to query the params for every method call.
// Once `ModuleParamsClient` is implemented, use its replay observable's `#Last()` method
// to get the most recently (asynchronously) observed (and cached) value.
func (sq *sharedQuerier) GetParams(ctx context.Context) (*sharedtypes.Params, error) {
req := &sharedtypes.QueryParamsRequest{}
res, err := sq.sharedQuerier.Params(ctx, req)
if err != nil {
return nil, ErrQuerySessionParams.Wrapf("[%v]", err)
}
return &res.Params, nil
return sq, nil
}

// GetClaimWindowOpenHeight returns the block height at which the claim window of
// the session that includes queryHeight opens.
//
// TODO_MAINNET(#543): We don't really want to have to query the params for every method call.
// Once `ModuleParamsClient` is implemented, use its replay observable's `#Last()` method
// to get the most recently (asynchronously) observed (and cached) value.
// TODO_MAINNET(@bryanchriswhite,#543): We also don't really want to use the current value of the params. Instead,
// we should be using the value that the params had for the session which includes queryHeight.
func (sq *sharedQuerier) GetClaimWindowOpenHeight(ctx context.Context, queryHeight int64) (int64, error) {
sharedParams, err := sq.GetParams(ctx)
sharedParams, err := sq.GetParamsAtHeight(ctx, queryHeight)
if err != nil {
return 0, err
}
Expand All @@ -75,14 +75,8 @@ func (sq *sharedQuerier) GetClaimWindowOpenHeight(ctx context.Context, queryHeig

// GetProofWindowOpenHeight returns the block height at which the proof window of
// the session that includes queryHeight opens.
//
// TODO_MAINNET(#543): We don't really want to have to query the params for every method call.
// Once `ModuleParamsClient` is implemented, use its replay observable's `#Last()` method
// to get the most recently (asynchronously) observed (and cached) value.
// TODO_MAINNET(@bryanchriswhite,#543): We also don't really want to use the current value of the params. Instead,
// we should be using the value that the params had for the session which includes queryHeight.
func (sq *sharedQuerier) GetProofWindowOpenHeight(ctx context.Context, queryHeight int64) (int64, error) {
sharedParams, err := sq.GetParams(ctx)
sharedParams, err := sq.GetParamsAtHeight(ctx, queryHeight)
if err != nil {
return 0, err
}
Expand All @@ -103,7 +97,7 @@ func (sq *sharedQuerier) GetSessionGracePeriodEndHeight(
ctx context.Context,
queryHeight int64,
) (int64, error) {
sharedParams, err := sq.GetParams(ctx)
sharedParams, err := sq.GetParamsAtHeight(ctx, queryHeight)
if err != nil {
return 0, err
}
Expand All @@ -118,8 +112,12 @@ func (sq *sharedQuerier) GetSessionGracePeriodEndHeight(
// to get the most recently (asynchronously) observed (and cached) value.
// TODO_MAINNET(@bryanchriswhite, #543): We also don't really want to use the current value of the params.
// Instead, we should be using the value that the params had for the session which includes queryHeight.
func (sq *sharedQuerier) GetEarliestSupplierClaimCommitHeight(ctx context.Context, queryHeight int64, supplierOperatorAddr string) (int64, error) {
sharedParams, err := sq.GetParams(ctx)
func (sq *sharedQuerier) GetEarliestSupplierClaimCommitHeight(
ctx context.Context,
queryHeight int64,
supplierOperatorAddr string,
) (int64, error) {
sharedParams, err := sq.GetParamsAtHeight(ctx, queryHeight)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -151,8 +149,12 @@ func (sq *sharedQuerier) GetEarliestSupplierClaimCommitHeight(ctx context.Contex
// to get the most recently (asynchronously) observed (and cached) value.
// TODO_MAINNET(@bryanchriswhite, #543): We also don't really want to use the current value of the params.
// Instead, we should be using the value that the params had for the session which includes queryHeight.
func (sq *sharedQuerier) GetEarliestSupplierProofCommitHeight(ctx context.Context, queryHeight int64, supplierOperatorAddr string) (int64, error) {
sharedParams, err := sq.GetParams(ctx)
func (sq *sharedQuerier) GetEarliestSupplierProofCommitHeight(
ctx context.Context,
queryHeight int64,
supplierOperatorAddr string,
) (int64, error) {
sharedParams, err := sq.GetParamsAtHeight(ctx, queryHeight)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -180,8 +182,8 @@ func (sq *sharedQuerier) GetEarliestSupplierProofCommitHeight(ctx context.Contex
// to get the most recently (asynchronously) observed (and cached) value.
// TODO_MAINNET(@bryanchriswhite, #543): We also don't really want to use the current value of the params.
// Instead, we should be using the value that the params had for the session which includes queryHeight.
func (sq *sharedQuerier) GetComputeUnitsToTokensMultiplier(ctx context.Context) (uint64, error) {
sharedParams, err := sq.GetParams(ctx)
func (sq *sharedQuerier) GetComputeUnitsToTokensMultiplier(ctx context.Context, queryHeight int64) (uint64, error) {
sharedParams, err := sq.GetParamsAtHeight(ctx, queryHeight)
if err != nil {
return 0, err
}
Expand Down
123 changes: 123 additions & 0 deletions pkg/client/query/sharedquerier_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package query_test

import (
"context"
"testing"
"time"

"cosmossdk.io/depinject"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"

"github.com/pokt-network/poktroll/pkg/client"
"github.com/pokt-network/poktroll/pkg/client/query"
"github.com/pokt-network/poktroll/pkg/client/query/cache"
_ "github.com/pokt-network/poktroll/pkg/polylog/polyzero"
"github.com/pokt-network/poktroll/testutil/mockclient"
sharedtypes "github.com/pokt-network/poktroll/x/shared/types"
)

type SharedQuerierTestSuite struct {
suite.Suite
ctrl *gomock.Controller
ctx context.Context
querier client.SharedQueryClient
TTL time.Duration
clientConnMock *mockclient.MockClientConn
blockClientMock *mockclient.MockCometRPC
}

func TestSharedQuerierSuite(t *testing.T) {
suite.Run(t, new(SharedQuerierTestSuite))
}

func (s *SharedQuerierTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.ctx = context.Background()
s.clientConnMock = mockclient.NewMockClientConn(s.ctrl)
s.blockClientMock = mockclient.NewMockCometRPC(s.ctrl)
s.TTL = 200 * time.Millisecond

deps := depinject.Supply(s.clientConnMock, s.blockClientMock)

// Create a shared querier with test-specific cache settings.
querier, err := query.NewSharedQuerier(deps,
query.WithQueryCacheOptions(
cache.WithTTL(s.TTL),
cache.WithHistoricalMode(100),
),
)
require.NoError(s.T(), err)
require.NotNil(s.T(), querier)

s.querier = querier
}

func (s *SharedQuerierTestSuite) TearDownTest() {
s.ctrl.Finish()
}

func (s *SharedQuerierTestSuite) TestRetrievesAndCachesParamsValues() {
multiplier := uint64(1000)

s.expectMockConnToReturnParamsWithMultiplierOnce(multiplier)

// Initial get should be a cache miss.
params1, err := s.querier.GetParams(s.ctx)
s.NoError(err)
s.Equal(multiplier, params1.ComputeUnitsToTokensMultiplier)

// Second get should be a cache hit.
params2, err := s.querier.GetParams(s.ctx)
s.NoError(err)
s.Equal(multiplier, params2.ComputeUnitsToTokensMultiplier)

// Third get, after 90% of the TTL - should still be a cache hit.
time.Sleep(time.Duration(float64(s.TTL) * .9))
params3, err := s.querier.GetParams(s.ctx)
s.NoError(err)
s.Equal(multiplier, params3.ComputeUnitsToTokensMultiplier)
}

func (s *SharedQuerierTestSuite) TestHandlesCacheExpiration() {
s.expectMockConnToReturnParamsWithMultiplierOnce(2000)

params1, err := s.querier.GetParams(s.ctx)
s.NoError(err)
s.Equal(uint64(2000), params1.ComputeUnitsToTokensMultiplier)

// Wait for cache to expire
time.Sleep(300 * time.Millisecond)

// Next query should be a cache miss again.
s.expectMockConnToReturnParamsWithMultiplierOnce(3000)

params2, err := s.querier.GetParams(s.ctx)
s.NoError(err)
s.Equal(uint64(3000), params2.ComputeUnitsToTokensMultiplier)
}

// expectMockConnToReturnParamsWithMultiplerOnce registers an expectation on s.clientConnMock
// such that this test will fail if the mock connection doesn't see exactly one params request.
// When it does see the params request, it will respond with a sharedtypes.Params object where
// the ComputeUnitsToTokensMultiplier field is set to the given multiplier.
func (s *SharedQuerierTestSuite) expectMockConnToReturnParamsWithMultiplierOnce(multiplier uint64) {
s.clientConnMock.EXPECT().
Invoke(
gomock.Any(),
"/poktroll.shared.Query/Params",
gomock.Any(),
gomock.Any(),
gomock.Any(),
).
DoAndReturn(func(_ context.Context, _ string, _, reply any, _ ...grpc.CallOption) error {
resp := reply.(*sharedtypes.QueryParamsResponse)
params := sharedtypes.DefaultParams()
params.ComputeUnitsToTokensMultiplier = multiplier

resp.Params = params
return nil
}).Times(1)
}
Loading

0 comments on commit d630484

Please sign in to comment.