diff --git a/proto/pocket/supplier/query.proto b/proto/pocket/supplier/query.proto index 1af7fdb25..016eec26f 100644 --- a/proto/pocket/supplier/query.proto +++ b/proto/pocket/supplier/query.proto @@ -38,10 +38,9 @@ service Query { option (google.api.http).get = "/pocket/supplier/claims"; } - // TODO_UPNEXT(@Olshansk): Update these endpoints after implementing proof persistence // Queries a list of Proof items. rpc Proof (QueryGetProofRequest) returns (QueryGetProofResponse) { - option (google.api.http).get = "/pocket/supplier/proof/{index}"; + option (google.api.http).get = "/pocket/supplier/proof/{session_id}/{supplier_address}"; } rpc AllProofs (QueryAllProofsRequest) returns (QueryAllProofsResponse) { option (google.api.http).get = "/pocket/supplier/proof"; @@ -99,7 +98,8 @@ message QueryAllClaimsResponse { } message QueryGetProofRequest { - string index = 1; + string session_id = 1; + string supplier_address = 2 [(cosmos_proto.scalar) = "cosmos.AddressString"]; } message QueryGetProofResponse { @@ -108,6 +108,11 @@ message QueryGetProofResponse { message QueryAllProofsRequest { cosmos.base.query.v1beta1.PageRequest pagination = 1; + oneof filter { + string supplier_address = 2; + string session_id = 3; + uint64 session_end_height = 4; + } } message QueryAllProofsResponse { diff --git a/x/supplier/client/cli/flags.go b/x/supplier/client/cli/flags.go new file mode 100644 index 000000000..1755bbf2d --- /dev/null +++ b/x/supplier/client/cli/flags.go @@ -0,0 +1,7 @@ +package cli + +const ( + FlagSessionEndHeight = "session-end-height" + FlagSessionId = "session-id" + FlagSupplierAddress = "supplier-address" +) diff --git a/x/supplier/client/cli/query_claim.go b/x/supplier/client/cli/query_claim.go index 6d110c0fa..949ac0d6e 100644 --- a/x/supplier/client/cli/query_claim.go +++ b/x/supplier/client/cli/query_claim.go @@ -14,12 +14,6 @@ import ( // Prevent strconv unused error var _ = strconv.IntSize -const ( - FlagSessionEndHeight = "session-end-height" - FlagSessionId = "session-id" - FlagSupplierAddress = "supplier-address" -) - // AddPaginationFlagsToCmd adds common pagination flags to cmd func AddClaimFilterFlags(cmd *cobra.Command) { cmd.Flags().Uint64(FlagSessionEndHeight, 0, "claims whose session ends at this height will be returned") diff --git a/x/supplier/client/cli/query_proof.go b/x/supplier/client/cli/query_proof.go index e36d0f13f..091c4403c 100644 --- a/x/supplier/client/cli/query_proof.go +++ b/x/supplier/client/cli/query_proof.go @@ -1,6 +1,8 @@ package cli import ( + "fmt" + "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client/flags" "github.com/spf13/cobra" @@ -8,28 +10,50 @@ import ( "github.com/pokt-network/poktroll/x/supplier/types" ) +// AddProofFilterFlagsToCmd adds common pagination flags to cmd +func AddProofFilterFlagsToCmd(cmd *cobra.Command) { + cmd.Flags().Uint64(FlagSessionEndHeight, 0, "proofs whose session ends at this height will be returned") + cmd.Flags().String(FlagSessionId, "", "proofs matching this session id will be returned") + cmd.Flags().String(FlagSupplierAddress, "", "proofs submitted by suppliers matching this address will be returned") +} + func CmdListProof() *cobra.Command { cmd := &cobra.Command{ Use: "list-proofs", Short: "list all proofs", + Long: `List all the proofs that the node being queried has in its state. + +The proofs can be optionally filtered by one of --session-end-height --session-id or --supplier-address flags + +Example: +$ poktrolld q proof list-proofs --node $(POCKET_NODE) --home=$(POKTROLLD_HOME) +$ poktrolld q proof list-proofs --session-id --node $(POCKET_NODE) --home=$(POKTROLLD_HOME) +$ poktrolld q proof list-proofs --session-end-height --node $(POCKET_NODE) --home=$(POKTROLLD_HOME) +$ poktrolld q proof list-proofs --supplier-address --node $(POCKET_NODE) --home=$(POKTROLLD_HOME)`, + Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - clientCtx, err := client.GetClientQueryContext(cmd) + pageReq, err := client.ReadPageRequest(cmd.Flags()) if err != nil { return err } - pageReq, err := client.ReadPageRequest(cmd.Flags()) - if err != nil { + req := &types.QueryAllProofsRequest{ + Pagination: pageReq, + } + if err := updateProofsFilter(cmd, req); err != nil { + return err + } + if err := req.ValidateBasic(); err != nil { return err } - queryClient := types.NewQueryClient(clientCtx) - - params := &types.QueryAllProofsRequest{ - Pagination: pageReq, + clientCtx, err := client.GetClientQueryContext(cmd) + if err != nil { + return err } + queryClient := types.NewQueryClient(clientCtx) - res, err := queryClient.AllProofs(cmd.Context(), params) + res, err := queryClient.AllProofs(cmd.Context(), req) if err != nil { return err } @@ -38,19 +62,36 @@ func CmdListProof() *cobra.Command { }, } + AddProofFilterFlagsToCmd(cmd) flags.AddPaginationFlagsToCmd(cmd, cmd.Use) flags.AddQueryFlagsToCmd(cmd) return cmd } -// TODO_UPNEXT(@Olshansk): Remove the dependency on index which was part of the default scaffolding behaviour func CmdShowProof() *cobra.Command { cmd := &cobra.Command{ - Use: "show-proof ", - Short: "shows a proof", - Args: cobra.ExactArgs(1), + Use: "show-proof ", + Short: "shows a specific proof", + Long: `List a specific proof that the node being queried has access to. + +A unique proof can be defined via a session_id that a given supplier participated in. + +Example: +$ poktrolld --home=$(POKTROLLD_HOME) q proof show-proofs --node $(POCKET_NODE)`, + Args: cobra.ExactArgs(2), RunE: func(cmd *cobra.Command, args []string) (err error) { + sessionId := args[0] + supplierAddr := args[1] + + getProofRequest := &types.QueryGetProofRequest{ + SessionId: sessionId, + SupplierAddress: supplierAddr, + } + if err := getProofRequest.ValidateBasic(); err != nil { + return err + } + clientCtx, err := client.GetClientQueryContext(cmd) if err != nil { return err @@ -58,13 +99,7 @@ func CmdShowProof() *cobra.Command { queryClient := types.NewQueryClient(clientCtx) - argIndex := args[0] - - params := &types.QueryGetProofRequest{ - Index: argIndex, - } - - res, err := queryClient.Proof(cmd.Context(), params) + res, err := queryClient.Proof(cmd.Context(), getProofRequest) if err != nil { return err } @@ -77,3 +112,54 @@ func CmdShowProof() *cobra.Command { return cmd } + +// updateProofsFilter updates the proofs filter request based on the flags set provided +func updateProofsFilter(cmd *cobra.Command, req *types.QueryAllProofsRequest) error { + sessionId, _ := cmd.Flags().GetString(FlagSessionId) + supplierAddr, _ := cmd.Flags().GetString(FlagSupplierAddress) + sessionEndHeight, _ := cmd.Flags().GetUint64(FlagSessionEndHeight) + + // Preparing a shared error in case more than one flag was set + err := fmt.Errorf("can only specify one flag filter but got sessionId (%s), supplierAddr (%s) and sessionEngHeight (%d)", sessionId, supplierAddr, sessionEndHeight) + + // Use the session id as the filter + if sessionId != "" { + // If the session id is set, then the other flags must not be set + if supplierAddr != "" || sessionEndHeight > 0 { + return err + } + // Set the session id filter + req.Filter = &types.QueryAllProofsRequest_SessionId{ + SessionId: sessionId, + } + return nil + } + + // Use the supplier address as the filter + if supplierAddr != "" { + // If the supplier address is set, then the other flags must not be set + if sessionId != "" || sessionEndHeight > 0 { + return err + } + // Set the supplier address filter + req.Filter = &types.QueryAllProofsRequest_SupplierAddress{ + SupplierAddress: supplierAddr, + } + return nil + } + + // Use the session end height as the filter + if sessionEndHeight > 0 { + // If the session end height is set, then the other flags must not be set + if sessionId != "" || supplierAddr != "" { + return err + } + // Set the session end height filter + req.Filter = &types.QueryAllProofsRequest_SessionEndHeight{ + SessionEndHeight: sessionEndHeight, + } + return nil + } + + return nil +} diff --git a/x/supplier/client/cli/tx_submit_proof.go b/x/supplier/client/cli/tx_submit_proof.go index 0c3ba758f..32a0604d3 100644 --- a/x/supplier/client/cli/tx_submit_proof.go +++ b/x/supplier/client/cli/tx_submit_proof.go @@ -2,12 +2,13 @@ package cli import ( "encoding/base64" - "encoding/json" "strconv" "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client/flags" "github.com/cosmos/cosmos-sdk/client/tx" + "github.com/cosmos/cosmos-sdk/codec" + cdctypes "github.com/cosmos/cosmos-sdk/codec/types" "github.com/spf13/cobra" sessiontypes "github.com/pokt-network/poktroll/x/session/types" @@ -24,12 +25,19 @@ func CmdSubmitProof() *cobra.Command { Short: "Broadcast message submit-proof", Args: cobra.ExactArgs(2), RunE: func(cmd *cobra.Command, args []string) (err error) { - argSessionHeader := new(sessiontypes.SessionHeader) - err = json.Unmarshal([]byte(args[0]), argSessionHeader) + sessionHeaderEncodedStr := args[0] + smstProofEncodedStr := args[1] + + // Get the session header + sessionHeaderBz, err := base64.StdEncoding.DecodeString(sessionHeaderEncodedStr) if err != nil { return err } - argSmstProof, err := base64.StdEncoding.DecodeString(args[1]) + sessionHeader := &sessiontypes.SessionHeader{} + cdc := codec.NewProtoCodec(cdctypes.NewInterfaceRegistry()) + cdc.MustUnmarshalJSON(sessionHeaderBz, sessionHeader) + + smstProof, err := base64.StdEncoding.DecodeString(smstProofEncodedStr) if err != nil { return err } @@ -41,8 +49,8 @@ func CmdSubmitProof() *cobra.Command { msg := types.NewMsgSubmitProof( clientCtx.GetFromAddress().String(), - argSessionHeader, - argSmstProof, + sessionHeader, + smstProof, ) if err := msg.ValidateBasic(); err != nil { return err diff --git a/x/supplier/keeper/msg_server_create_claim.go b/x/supplier/keeper/msg_server_create_claim.go index 05e21a2e4..3f4b0df30 100644 --- a/x/supplier/keeper/msg_server_create_claim.go +++ b/x/supplier/keeper/msg_server_create_claim.go @@ -5,7 +5,6 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" - sessiontypes "github.com/pokt-network/poktroll/x/session/types" suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" ) @@ -20,51 +19,18 @@ func (k msgServer) CreateClaim(goCtx context.Context, msg *suppliertypes.MsgCrea return nil, err } - sessionReq := &sessiontypes.QueryGetSessionRequest{ - ApplicationAddress: msg.GetSessionHeader().GetApplicationAddress(), - Service: msg.GetSessionHeader().GetService(), - BlockHeight: msg.GetSessionHeader().GetSessionStartBlockHeight(), - } - sessionRes, err := k.Keeper.sessionKeeper.GetSession(goCtx, sessionReq) + session, err := k.queryAndValidateSessionHeader( + goCtx, + msg.GetSessionHeader(), + msg.GetSupplierAddress(), + ) if err != nil { return nil, err } logger. With( - "session_id", sessionRes.GetSession().GetSessionId(), - "session_end_height", msg.GetSessionHeader().GetSessionEndBlockHeight(), - "supplier", msg.GetSupplierAddress(), - ). - Debug("got sessionId for claim") - - if sessionRes.Session.SessionId != msg.SessionHeader.SessionId { - return nil, suppliertypes.ErrSupplierInvalidSessionId.Wrapf( - "claimed sessionRes ID does not match on-chain sessionRes ID; expected %q, got %q", - sessionRes.Session.SessionId, - msg.SessionHeader.SessionId, - ) - } - - var found bool - for _, supplier := range sessionRes.GetSession().GetSuppliers() { - if supplier.Address == msg.GetSupplierAddress() { - found = true - break - } - } - - if !found { - return nil, suppliertypes.ErrSupplierNotFound.Wrapf( - "supplier address %q in session ID %q", - msg.GetSupplierAddress(), - sessionRes.GetSession().GetSessionId(), - ) - } - - logger. - With( - "session_id", sessionRes.GetSession().GetSessionId(), + "session_id", session.GetSessionId(), "session_end_height", msg.GetSessionHeader().GetSessionEndBlockHeight(), "supplier", msg.GetSupplierAddress(), ). @@ -88,6 +54,9 @@ func (k msgServer) CreateClaim(goCtx context.Context, msg *suppliertypes.MsgCrea SessionHeader: msg.GetSessionHeader(), RootHash: msg.RootHash, } + + // TODO_BLOCKER: check if this claim already exists and return an appropriate error + // in any case where the supplier should no longer be able to update the given proof. k.Keeper.UpsertClaim(ctx, claim) logger. diff --git a/x/supplier/keeper/proof.go b/x/supplier/keeper/proof.go index aec052d5b..3aacbf273 100644 --- a/x/supplier/keeper/proof.go +++ b/x/supplier/keeper/proof.go @@ -1,6 +1,8 @@ package keeper import ( + "fmt" + "github.com/cosmos/cosmos-sdk/store/prefix" sdk "github.com/cosmos/cosmos-sdk/types" @@ -9,52 +11,52 @@ import ( // UpsertProof inserts or updates a specific proof in the store by index. func (k Keeper) UpsertProof(ctx sdk.Context, proof types.Proof) { - store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefix(types.ProofKeyPrefix)) - b := k.cdc.MustMarshal(&proof) - // TODO_NEXT(@bryanchriswhite #141): Refactor keys to support multiple indices. - store.Set(types.ProofKey( - proof.GetSessionHeader().GetSessionId(), - ), b) + logger := k.Logger(ctx).With("method", "UpsertProof") + + proofBz := k.cdc.MustMarshal(&proof) + parentStore := ctx.KVStore(k.storeKey) + + // Update the primary store containing the proof object. + primaryStore := prefix.NewStore(parentStore, types.KeyPrefix(types.ProofPrimaryKeyPrefix)) + sessionId := proof.GetSessionHeader().GetSessionId() + primaryKey := types.ProofPrimaryKey(sessionId, proof.GetSupplierAddress()) + primaryStore.Set(primaryKey, proofBz) + + logger.Info(fmt.Sprintf("upserted proof for supplier %s with primaryKey %s", proof.GetSupplierAddress(), primaryKey)) + + // Update the address index: supplierAddress -> [ProofPrimaryKey] + addressStoreIndex := prefix.NewStore(parentStore, types.KeyPrefix(types.ProofSupplierAddressPrefix)) + addressKey := types.ProofSupplierAddressKey(proof.GetSupplierAddress(), primaryKey) + addressStoreIndex.Set(addressKey, primaryKey) + + logger.Info(fmt.Sprintf("indexed Proof for supplier %s with primaryKey %s", proof.GetSupplierAddress(), primaryKey)) + + // Update the session end height index: sessionEndHeight -> [ProofPrimaryKey] + sessionHeightStoreIndex := prefix.NewStore(parentStore, types.KeyPrefix(types.ProofSessionEndHeightPrefix)) + sessionEndHeight := proof.GetSessionHeader().GetSessionEndBlockHeight() + heightKey := types.ProofSupplierEndSessionHeightKey(sessionEndHeight, primaryKey) + sessionHeightStoreIndex.Set(heightKey, primaryKey) } // GetProof returns a proof from its index -func (k Keeper) GetProof( - ctx sdk.Context, - sessionId string, - -) (val types.Proof, found bool) { - store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefix(types.ProofKeyPrefix)) - - // TODO_NEXT(@bryanchriswhite #141): Refactor proof keys to support multiple indices. - b := store.Get(types.ProofKey( - sessionId, - )) - if b == nil { - return val, false - } - - k.cdc.MustUnmarshal(b, &val) - return val, true +func (k Keeper) GetProof(ctx sdk.Context, sessionId, supplierAdd string) (val types.Proof, found bool) { + primaryKey := types.ProofPrimaryKey(sessionId, supplierAdd) + return k.getProofByPrimaryKey(ctx, primaryKey) } // RemoveProof removes a proof from the store -func (k Keeper) RemoveProof( - ctx sdk.Context, - // TODO_NEXT(@bryanchriswhite #141): Refactor proof keys to support multiple indices. - index string, - -) { - store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefix(types.ProofKeyPrefix)) - // TODO_NEXT(@bryanchriswhite #141): Refactor proof keys to support multiple indices. - store.Delete(types.ProofKey( - index, - )) +func (k Keeper) RemoveProof(ctx sdk.Context, sessionId, supplierAddr string) { + parentStore := ctx.KVStore(k.storeKey) + proofPrimaryStore := prefix.NewStore(parentStore, types.KeyPrefix(types.ProofPrimaryKeyPrefix)) + proofPrimaryKey := types.ProofPrimaryKey(sessionId, supplierAddr) + proofPrimaryStore.Delete(proofPrimaryKey) } // GetAllProofs returns all proof func (k Keeper) GetAllProofs(ctx sdk.Context) (list []types.Proof) { - store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefix(types.ProofKeyPrefix)) - iterator := sdk.KVStorePrefixIterator(store, []byte{}) + parentStore := ctx.KVStore(k.storeKey) + primaryStore := prefix.NewStore(parentStore, types.KeyPrefix(types.ProofPrimaryKeyPrefix)) + iterator := sdk.KVStorePrefixIterator(primaryStore, []byte{}) defer iterator.Close() @@ -66,3 +68,16 @@ func (k Keeper) GetAllProofs(ctx sdk.Context) (list []types.Proof) { return } + +// getProofByPrimaryKey is a helper that retrieves, if exists, the Proof associated with the key provided +func (k Keeper) getProofByPrimaryKey(ctx sdk.Context, primaryKey []byte) (val types.Proof, found bool) { + store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefix(types.ProofPrimaryKeyPrefix)) + + proofBz := store.Get(primaryKey) + if proofBz == nil { + return val, false + } + + k.cdc.MustUnmarshal(proofBz, &val) + return val, true +} diff --git a/x/supplier/keeper/proof_test.go b/x/supplier/keeper/proof_test.go index 28d73ca7f..bf34f7027 100644 --- a/x/supplier/keeper/proof_test.go +++ b/x/supplier/keeper/proof_test.go @@ -45,8 +45,10 @@ func TestProofGet(t *testing.T) { keeper, ctx := keepertest.SupplierKeeper(t, nil) proofs := createNProofs(keeper, ctx, 10) for _, proof := range proofs { - rst, found := keeper.GetProof(ctx, + rst, found := keeper.GetProof( + ctx, proof.GetSessionHeader().GetSessionId(), + proof.GetSupplierAddress(), ) require.True(t, found) require.Equal(t, @@ -57,14 +59,11 @@ func TestProofGet(t *testing.T) { } func TestProofRemove(t *testing.T) { keeper, ctx := keepertest.SupplierKeeper(t, nil) - items := createNProofs(keeper, ctx, 10) - for _, item := range items { - keeper.RemoveProof(ctx, - item.GetSessionHeader().GetSessionId(), - ) - _, found := keeper.GetProof(ctx, - item.GetSessionHeader().GetSessionId(), - ) + proofs := createNProofs(keeper, ctx, 10) + for _, proof := range proofs { + sessionId := proof.GetSessionHeader().GetSessionId() + keeper.RemoveProof(ctx, sessionId, proof.GetSupplierAddress()) + _, found := keeper.GetProof(ctx, sessionId, proof.GetSupplierAddress()) require.False(t, found) } } diff --git a/x/supplier/keeper/query_proof.go b/x/supplier/keeper/query_proof.go index 880d5c67b..cb8642679 100644 --- a/x/supplier/keeper/query_proof.go +++ b/x/supplier/keeper/query_proof.go @@ -2,6 +2,7 @@ package keeper import ( "context" + "fmt" "github.com/cosmos/cosmos-sdk/store/prefix" sdk "github.com/cosmos/cosmos-sdk/types" @@ -17,19 +18,57 @@ func (k Keeper) AllProofs(goCtx context.Context, req *types.QueryAllProofsReques return nil, status.Error(codes.InvalidArgument, "invalid request") } - var proofs []types.Proof - ctx := sdk.UnwrapSDKContext(goCtx) + if err := req.ValidateBasic(); err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + ctx := sdk.UnwrapSDKContext(goCtx) store := ctx.KVStore(k.storeKey) - proofStore := prefix.NewStore(store, types.KeyPrefix(types.ProofKeyPrefix)) + var ( + // isCustomIndex is used to determined if we'll be using the store that points + // to the actual Claim values, or a secondary index that points to the primary keys. + isCustomIndex bool + keyPrefix []byte + ) + + switch filter := req.Filter.(type) { + case *types.QueryAllProofsRequest_SupplierAddress: + isCustomIndex = true + keyPrefix = types.KeyPrefix(types.ProofSupplierAddressPrefix) + keyPrefix = append(keyPrefix, []byte(filter.SupplierAddress)...) + case *types.QueryAllProofsRequest_SessionEndHeight: + isCustomIndex = true + keyPrefix = types.KeyPrefix(types.ProofSessionEndHeightPrefix) + keyPrefix = append(keyPrefix, []byte(fmt.Sprintf("%d", filter.SessionEndHeight))...) + case *types.QueryAllProofsRequest_SessionId: + isCustomIndex = false + keyPrefix = types.KeyPrefix(types.ProofPrimaryKeyPrefix) + keyPrefix = append(keyPrefix, []byte(filter.SessionId)...) + default: + isCustomIndex = false + keyPrefix = types.KeyPrefix(types.ProofPrimaryKeyPrefix) + } + proofStore := prefix.NewStore(store, keyPrefix) + + var proofs []types.Proof pageRes, err := query.Paginate(proofStore, req.Pagination, func(key []byte, value []byte) error { - var proof types.Proof - if err := k.cdc.Unmarshal(value, &proof); err != nil { - return err + if isCustomIndex { + // We retrieve the primaryKey, and need to query the actual proof before decoding it. + proof, proofFound := k.getProofByPrimaryKey(ctx, value) + if proofFound { + proofs = append(proofs, proof) + } + } else { + // The value is an encoded proof. + var proof types.Proof + if err := k.cdc.Unmarshal(value, &proof); err != nil { + return err + } + + proofs = append(proofs, proof) } - proofs = append(proofs, proof) return nil }) @@ -42,16 +81,20 @@ func (k Keeper) AllProofs(goCtx context.Context, req *types.QueryAllProofsReques func (k Keeper) Proof(goCtx context.Context, req *types.QueryGetProofRequest) (*types.QueryGetProofResponse, error) { if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") + err := types.ErrSupplierInvalidQueryRequest.Wrap("request cannot be nil") + return nil, status.Error(codes.InvalidArgument, err.Error()) } + + if err := req.ValidateBasic(); err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + ctx := sdk.UnwrapSDKContext(goCtx) - val, found := k.GetProof( - ctx, - req.Index, - ) + val, found := k.GetProof(ctx, req.GetSessionId(), req.GetSupplierAddress()) if !found { - return nil, status.Error(codes.NotFound, "not found") + err := types.ErrSupplierProofNotFound.Wrapf("session ID %q and supplier %q", req.SessionId, req.SupplierAddress) + return nil, status.Error(codes.NotFound, err.Error()) } return &types.QueryGetProofResponse{Proof: val}, nil diff --git a/x/supplier/keeper/query_proof_test.go b/x/supplier/keeper/query_proof_test.go index dcdec4c4e..787b23d68 100644 --- a/x/supplier/keeper/query_proof_test.go +++ b/x/supplier/keeper/query_proof_test.go @@ -12,6 +12,7 @@ import ( keepertest "github.com/pokt-network/poktroll/testutil/keeper" "github.com/pokt-network/poktroll/testutil/nullify" + "github.com/pokt-network/poktroll/testutil/sample" "github.com/pokt-network/poktroll/x/supplier/types" ) @@ -21,44 +22,108 @@ var _ = strconv.IntSize func TestProofQuerySingle(t *testing.T) { keeper, ctx := keepertest.SupplierKeeper(t, nil) wctx := sdk.WrapSDKContext(ctx) - msgs := createNProofs(keeper, ctx, 2) + proofs := createNProofs(keeper, ctx, 2) + + var randSupplierAddr = sample.AccAddress() tests := []struct { - desc string - request *types.QueryGetProofRequest - response *types.QueryGetProofResponse - err error + desc string + request *types.QueryGetProofRequest + response *types.QueryGetProofResponse + expectedErr error }{ { desc: "First", request: &types.QueryGetProofRequest{ - Index: msgs[0].GetSessionHeader().GetSessionId(), + SessionId: proofs[0].GetSessionHeader().GetSessionId(), + SupplierAddress: proofs[0].SupplierAddress, }, - response: &types.QueryGetProofResponse{Proof: msgs[0]}, + response: &types.QueryGetProofResponse{Proof: proofs[0]}, }, { desc: "Second", request: &types.QueryGetProofRequest{ - Index: msgs[1].GetSessionHeader().GetSessionId(), + SessionId: proofs[1].GetSessionHeader().GetSessionId(), + SupplierAddress: proofs[1].SupplierAddress, + }, + response: &types.QueryGetProofResponse{Proof: proofs[1]}, + }, + { + desc: "Proof Not Found - Random SessionId", + request: &types.QueryGetProofRequest{ + SessionId: "not a real session id", + SupplierAddress: proofs[0].GetSupplierAddress(), + }, + expectedErr: status.Error( + codes.NotFound, + types.ErrSupplierProofNotFound.Wrapf( + "session ID %q and supplier %q", + "not a real session id", + proofs[0].GetSupplierAddress(), + ).Error(), + ), + }, + { + desc: "Proof Not Found - Random Supplier Address", + request: &types.QueryGetProofRequest{ + SessionId: proofs[0].GetSessionHeader().GetSessionId(), + SupplierAddress: randSupplierAddr, }, - response: &types.QueryGetProofResponse{Proof: msgs[1]}, + expectedErr: status.Error( + codes.NotFound, + types.ErrSupplierProofNotFound.Wrapf( + "session ID %q and supplier %q", + proofs[0].GetSessionHeader().GetSessionId(), + randSupplierAddr, + ).Error(), + ), }, { - desc: "KeyNotFound", + desc: "InvalidRequest - Missing SessionId", request: &types.QueryGetProofRequest{ - Index: strconv.Itoa(100000), + // SessionId: Intentionally Omitted + SupplierAddress: proofs[0].GetSupplierAddress(), }, - err: status.Error(codes.NotFound, "not found"), + expectedErr: status.Error( + codes.InvalidArgument, + types.ErrSupplierInvalidSessionId.Wrapf( + "invalid session ID for proof being retrieved %s", + "", + ).Error(), + ), }, { - desc: "InvalidRequest", - err: status.Error(codes.InvalidArgument, "invalid request"), + desc: "InvalidRequest - Missing SupplierAddress", + request: &types.QueryGetProofRequest{ + SessionId: proofs[0].GetSessionHeader().GetSessionId(), + // SupplierAddress: Intentionally Omitted, + }, + expectedErr: status.Error( + codes.InvalidArgument, + types.ErrSupplierInvalidAddress.Wrap( + "invalid supplier address for proof being retrieved ; (empty address string is not allowed)", + ).Error(), + ), + }, + { + desc: "InvalidRequest - nil QueryGetProofRequest", + request: nil, + expectedErr: status.Error( + codes.InvalidArgument, + types.ErrSupplierInvalidQueryRequest.Wrap( + "request cannot be nil", + ).Error(), + ), }, } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { response, err := keeper.Proof(wctx, tc.request) - if tc.err != nil { - require.ErrorIs(t, err, tc.err) + if tc.expectedErr != nil { + actualStatus, ok := status.FromError(err) + require.True(t, ok) + + require.ErrorIs(t, actualStatus.Err(), tc.expectedErr) + require.ErrorContains(t, err, tc.expectedErr.Error()) } else { require.NoError(t, err) require.Equal(t, diff --git a/x/supplier/keeper/session.go b/x/supplier/keeper/session.go new file mode 100644 index 000000000..af24e52f0 --- /dev/null +++ b/x/supplier/keeper/session.go @@ -0,0 +1,83 @@ +package keeper + +import ( + "context" + + sdk "github.com/cosmos/cosmos-sdk/types" + + sessiontypes "github.com/pokt-network/poktroll/x/session/types" + sharedtypes "github.com/pokt-network/poktroll/x/shared/types" + suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" +) + +// queryAndValidateSessionHeader ensures that a session with the sessionID of the given session +// header exists and that this session includes the supplier with the given address. +func (k msgServer) queryAndValidateSessionHeader( + goCtx context.Context, + sessionHeader *sessiontypes.SessionHeader, + supplierAddr string, +) (*sessiontypes.Session, error) { + ctx := sdk.UnwrapSDKContext(goCtx) + logger := k.Logger(ctx).With("method", "SubmitProof") + + sessionReq := &sessiontypes.QueryGetSessionRequest{ + ApplicationAddress: sessionHeader.GetApplicationAddress(), + Service: sessionHeader.GetService(), + BlockHeight: sessionHeader.GetSessionStartBlockHeight(), + } + + // Get the on-chain session for the ground-truth against which the given + // session header is to be validated. + sessionRes, err := k.Keeper.sessionKeeper.GetSession(goCtx, sessionReq) + if err != nil { + return nil, err + } + onChainSession := sessionRes.GetSession() + + logger. + With( + "session_id", onChainSession.GetSessionId(), + "session_end_height", sessionHeader.GetSessionEndBlockHeight(), + "supplier", supplierAddr, + ). + Debug("got sessionId for proof") + + // Ensure that the given session header's session ID matches the on-chain onChainSession ID. + if sessionHeader.GetSessionId() != onChainSession.GetSessionId() { + return nil, suppliertypes.ErrSupplierInvalidSessionId.Wrapf( + "session ID does not match on-chain session ID; expected %q, got %q", + onChainSession.GetSessionId(), + sessionHeader.GetSessionId(), + ) + } + + // NB: it is redundant to assert that the service ID in the request matches the + // on-chain session service ID because the session is queried using the service + // ID as a parameter. Either a different session (i.e. different session ID) + // or an error would be returned depending on whether an application/supplier + // pair exists for the given service ID or not, respectively. + + // Ensure the given supplier is in the onChainSession supplier list. + if found := foundSupplier( + sessionRes.GetSession().GetSuppliers(), + supplierAddr, + ); !found { + return nil, suppliertypes.ErrSupplierNotFound.Wrapf( + "supplier address %q not found in session ID %q", + supplierAddr, + sessionHeader.GetSessionId(), + ) + } + + return onChainSession, nil +} + +// foundSupplier ensures that the given supplier address is in the given list of suppliers. +func foundSupplier(suppliers []*sharedtypes.Supplier, supplierAddr string) bool { + for _, supplier := range suppliers { + if supplier.Address == supplierAddr { + return true + } + } + return false +} diff --git a/x/supplier/types/key_proof.go b/x/supplier/types/key_proof.go index 6307f4ccb..f74bb7fa9 100644 --- a/x/supplier/types/key_proof.go +++ b/x/supplier/types/key_proof.go @@ -5,20 +5,32 @@ import "encoding/binary" var _ binary.ByteOrder const ( - // ProofKeyPrefix is the prefix to retrieve all Proof - ProofKeyPrefix = "Proof/value/" + // ProofPrimaryKeyPrefix is the prefix to retrieve the entire Proof object (the primary store) + ProofPrimaryKeyPrefix = "Proof/value/" + + // ProofSupplierAddressPrefix is the key to retrieve a Proof's Primary Key from the Address index + ProofSupplierAddressPrefix = "Proof/address/" + + // ProofSessionEndHeightPrefix is the key to retrieve a Proof's Primary Key from the Height index + ProofSessionEndHeightPrefix = "Proof/height/" ) -// ProofKey returns the store key to retrieve a Proof from the index fields -// TODO_UPNEXT(@Olshansk): Implement a similar indexing strategy for Proofs as we do for Claims -func ProofKey( - index string, -) []byte { - var key []byte +// ProofPrimaryKey returns the primary store key used to retrieve a Proof by creating a composite key of the sessionId and supplierAddr. +func ProofPrimaryKey(sessionId, supplierAddr string) []byte { + // We are guaranteed uniqueness of the primary key if it's a composite of the (sessionId, supplierAddr). + // because every supplier can only have one Proof per session. + return KeyComposite([]byte(sessionId), []byte(supplierAddr)) +} + +// ProofSupplierAddressKey returns the key used to iterate through Proofs given a supplier Address. +func ProofSupplierAddressKey(supplierAddr string, primaryKey []byte) []byte { + return KeyComposite([]byte(supplierAddr), primaryKey) +} - indexBytes := []byte(index) - key = append(key, indexBytes...) - key = append(key, []byte("/")...) +// ProofSupplierEndSessionHeightKey returns the key used to iterate through Proofs given a session end height. +func ProofSupplierEndSessionHeightKey(sessionEndHeight int64, primaryKey []byte) []byte { + heightBz := make([]byte, 8) + binary.BigEndian.PutUint64(heightBz, uint64(sessionEndHeight)) - return key + return KeyComposite(heightBz, primaryKey) } diff --git a/x/supplier/types/query_validation.go b/x/supplier/types/query_validation.go index 5eb2a7bb5..c7a352bd6 100644 --- a/x/supplier/types/query_validation.go +++ b/x/supplier/types/query_validation.go @@ -23,13 +23,13 @@ func (query *QueryGetClaimRequest) ValidateBasic() error { if query.SessionId == "" { return ErrSupplierInvalidSessionId.Wrapf("invalid session ID for claim being retrieved %s", query.SessionId) } + return nil } // ValidateBasic performs basic (non-state-dependant) validation on a QueryAllClaimsRequest. func (query *QueryAllClaimsRequest) ValidateBasic() error { - // TODO_TECHDEBT: update function signature to receive a context. - logger := polylog.Ctx(context.TODO()) + logger := polylog.Ctx(context.Background()) switch filter := query.Filter.(type) { case *QueryAllClaimsRequest_SupplierAddress: @@ -38,19 +38,57 @@ func (query *QueryAllClaimsRequest) ValidateBasic() error { } case *QueryAllClaimsRequest_SessionId: - // TODO_TECHDEBT: Validate the session ID once we have a deterministic way to generate it logger.Warn(). Str("session_id", filter.SessionId). - Msg("TODO: SessionID check in claim request validation is currently a noop") + Msg("TODO_TECHDEBT: Validate the session ID once we have a deterministic way to generate it") case *QueryAllClaimsRequest_SessionEndHeight: if filter.SessionEndHeight < 0 { return ErrSupplierInvalidSessionEndHeight.Wrapf("invalid session end height for claims being retrieved %d", filter.SessionEndHeight) } + } + + return nil +} + +func (query *QueryGetProofRequest) ValidateBasic() error { + // Validate the supplier address + if _, err := sdk.AccAddressFromBech32(query.SupplierAddress); err != nil { + return ErrSupplierInvalidAddress.Wrapf("invalid supplier address for proof being retrieved %s; (%v)", query.SupplierAddress, err) + } + + // TODO_TECHDEBT: Validate the session ID once we have a deterministic way to generate it + if query.SessionId == "" { + return ErrSupplierInvalidSessionId.Wrapf("invalid session ID for proof being retrieved %s", query.SessionId) + } + + return nil +} + +func (query *QueryAllProofsRequest) ValidateBasic() error { + // TODO_TECHDEBT: update function signature to receive a context. + logger := polylog.Ctx(context.TODO()) + + switch filter := query.Filter.(type) { + case *QueryAllProofsRequest_SupplierAddress: + if _, err := sdk.AccAddressFromBech32(filter.SupplierAddress); err != nil { + return ErrSupplierInvalidAddress.Wrapf("invalid supplier address for proofs being retrieved %s; (%v)", filter.SupplierAddress, err) + } + + case *QueryAllProofsRequest_SessionId: + logger.Warn(). + Str("session_id", filter.SessionId). + Msg("TODO_TECHDEBT: Validate the session ID once we have a deterministic way to generate it") + + case *QueryAllProofsRequest_SessionEndHeight: + if filter.SessionEndHeight < 0 { + return ErrSupplierInvalidSessionEndHeight.Wrapf("invalid session end height for proofs being retrieved %d", filter.SessionEndHeight) + } default: // No filter is set logger.Debug().Msg("No specific filter set when requesting claims") } + return nil }