Skip to content

Commit

Permalink
[Supplier] refactor: claim & proof protobufs + (#263)
Browse files Browse the repository at this point in the history
* refactor: `NewMinedRelay` to shared testutil

* refactor: claim & proof protobuf types

* refactor: rename supplier keeper `UpsertClaim` & `UpsertProof`

* refactor: misc. claim-side improvements

* chore: add TODOs

* chore: review feedback improvements

Co-authored-by: harry <[email protected]>
Co-authored-by: Daniel Olshansky <[email protected]>

* chore: review feedback improvements

Co-authored-by: Daniel Olshansky <[email protected]>
Co-authored-by: harry <[email protected]>

* chore: add TODOs

* trigger CI

* chore: add TODO

* chore: review feedback improvements

Co-authored-by: Daniel Olshansky <[email protected]>

* fix: usage raw string literal

---------

Co-authored-by: harry <[email protected]>
Co-authored-by: Daniel Olshansky <[email protected]>
  • Loading branch information
3 people authored Jan 10, 2024
1 parent 8f61a4c commit 6b67ff9
Show file tree
Hide file tree
Showing 16 changed files with 144 additions and 90 deletions.
8 changes: 5 additions & 3 deletions proto/pocket/supplier/claim.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ package pocket.supplier;
option go_package = "github.com/pokt-network/poktroll/x/supplier/types";

import "cosmos_proto/cosmos.proto";
import "pocket/session/session.proto";

// Claim is the serialized object stored on-chain for claims pending to be proven
message Claim {
string supplier_address = 1 [(cosmos_proto.scalar) = "cosmos.AddressString"]; // the address of the supplier that submitted this claim
string session_id = 2; // session id from the SessionHeader
uint64 session_end_block_height = 3; // session end block height from the SessionHeader
bytes root_hash = 4; // smt.SMST#Root()
// The session header of the session that this claim is for.
session.SessionHeader session_header = 2;
// Root hash returned from smt.SMST#Root().
bytes root_hash = 3;
}
14 changes: 8 additions & 6 deletions proto/pocket/supplier/proof.proto
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
syntax = "proto3";
package pocket.supplier;

import "cosmos_proto/cosmos.proto";
import "pocket/session/session.proto";

option go_package = "github.com/pokt-network/poktroll/x/supplier/types";

// TODO_UPNEXT(@Olshansk): The structure below is the default (untouched) scaffolded type. Update
// and productionize it for our use case.
message Proof {
string index = 1;
string supplier_address = 2;
string session_id = 3;
string merkle_proof = 4;
string supplier_address = 1 [(cosmos_proto.scalar) = "cosmos.AddressString"];
// The session header of the session that this claim is for.
session.SessionHeader session_header = 2;
// The serialized SMST proof from the `#ClosestProof()` method.
bytes closest_merkle_proof = 3;
}

5 changes: 3 additions & 2 deletions proto/pocket/supplier/query.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ syntax = "proto3";

package pocket.supplier;

import "cosmos_proto/cosmos.proto";
import "cosmos/base/query/v1beta1/pagination.proto";
import "gogoproto/gogo.proto";
import "google/api/annotations.proto";
import "cosmos/base/query/v1beta1/pagination.proto";
import "pocket/supplier/params.proto";
import "pocket/shared/supplier.proto";
import "pocket/supplier/claim.proto";
Expand Down Expand Up @@ -76,7 +77,7 @@ message QueryAllSupplierResponse {

message QueryGetClaimRequest {
string session_id = 1;
string supplier_address = 2;
string supplier_address = 2 [(cosmos_proto.scalar) = "cosmos.AddressString"];
}

message QueryGetClaimResponse {
Expand Down
13 changes: 9 additions & 4 deletions x/supplier/client/cli/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,15 @@ func createClaim(

// TODO_TECHDEBT: Forward the actual claim in the response once the response is updated to return it.
return &types.Claim{
SupplierAddress: supplierAddr,
SessionId: sessionId,
SessionEndBlockHeight: uint64(sessionEndHeight),
RootHash: rootHash,
SupplierAddress: supplierAddr,
SessionHeader: &sessiontypes.SessionHeader{
ApplicationAddress: appAddress,
Service: &sharedtypes.Service{Id: testServiceId},
SessionId: sessionId,
SessionStartBlockHeight: sessionStartHeight,
SessionEndBlockHeight: sessionEndHeight,
},
RootHash: rootHash,
}
}

Expand Down
10 changes: 7 additions & 3 deletions x/supplier/client/cli/query_claim.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,18 @@ func CmdShowClaim() *cobra.Command {
cmd := &cobra.Command{
Use: "show-claim <session_id> <supplier_addr>",
Short: "shows a specific claim",
Long: `List a specific claim that the node being queried has access to (if it still exists)
Long: `List a specific claim that the node being queried has access to (if it still exists).
A unique claim can be defined via a session_id that a supplier participated in
A unique claim can be defined via a ` + "`session_id`" + ` that the given ` + "`supplier`" + ` participated in.
` + "`Claims`" + ` are pruned, according to protocol parameters, some time after their respective ` + "`proof`" + ` has been submitted and any dispute window has elapsed.
This is done to minimize the rate at which state accumulates by eliminating claims as a long-term factor to persistence requirements.
Example:
$ poktrolld --home=$(POKTROLLD_HOME) q claim show-claims <session_id> <supplier_address> --node $(POCKET_NODE)`,
Args: cobra.ExactArgs(2),
RunE: func(cmd *cobra.Command, args []string) (err error) {
RunE: func(cmd *cobra.Command, args []string) error {
sessionId := args[0]
supplierAddr := args[1]

Expand Down
34 changes: 17 additions & 17 deletions x/supplier/client/cli/query_claim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,29 @@ func TestClaim_Show(t *testing.T) {
sessionId string
supplierAddr string

args []string
err error
obj types.Claim
args []string
err error
claim types.Claim
}{
{
desc: "claim found",
sessionId: claims[0].SessionId,
supplierAddr: claims[0].SupplierAddress,
sessionId: claims[0].GetSessionHeader().GetSessionId(),
supplierAddr: claims[0].GetSupplierAddress(),

args: common,
obj: claims[0],
args: common,
claim: claims[0],
},
{
desc: "claim not found (wrong session ID)",
sessionId: "wrong_session_id",
supplierAddr: claims[0].SupplierAddress,
supplierAddr: claims[0].GetSupplierAddress(),

args: common,
err: status.Error(codes.NotFound, "not found"),
},
{
desc: "claim not found (wrong supplier address)",
sessionId: claims[0].SessionId,
sessionId: claims[0].GetSessionHeader().GetSessionId(),
supplierAddr: "wrong_supplier_address",

args: common,
Expand All @@ -82,10 +82,10 @@ func TestClaim_Show(t *testing.T) {
var resp types.QueryGetClaimResponse
require.NoError(t, net.Config.Codec.UnmarshalJSON(out.Bytes(), &resp))
require.NotNil(t, resp.Claim)
require.Equal(t,
nullify.Fill(&tc.obj),
nullify.Fill(&resp.Claim),
)

require.Equal(t, tc.claim.GetSupplierAddress(), resp.Claim.GetSupplierAddress())
require.Equal(t, tc.claim.GetRootHash(), resp.Claim.GetRootHash())
require.Equal(t, tc.claim.GetSessionHeader(), resp.Claim.GetSessionHeader())
}
})
}
Expand Down Expand Up @@ -187,13 +187,13 @@ func TestClaim_List(t *testing.T) {
})

t.Run("BySession", func(t *testing.T) {
sessionId := claims[0].SessionId
sessionId := claims[0].GetSessionHeader().SessionId
args := prepareArgs(nil, 0, uint64(totalClaims), true)
args = append(args, fmt.Sprintf("--%s=%s", cli.FlagSessionId, sessionId))

expectedClaims := make([]types.Claim, 0)
for _, claim := range claims {
if claim.SessionId == sessionId {
if claim.GetSessionHeader().SessionId == sessionId {
expectedClaims = append(expectedClaims, claim)
}
}
Expand All @@ -212,13 +212,13 @@ func TestClaim_List(t *testing.T) {
})

t.Run("ByHeight", func(t *testing.T) {
sessionEndHeight := claims[0].SessionEndBlockHeight
sessionEndHeight := claims[0].GetSessionHeader().GetSessionEndBlockHeight()
args := prepareArgs(nil, 0, uint64(totalClaims), true)
args = append(args, fmt.Sprintf("--%s=%d", cli.FlagSessionEndHeight, sessionEndHeight))

expectedClaims := make([]types.Claim, 0)
for _, claim := range claims {
if claim.SessionEndBlockHeight == sessionEndHeight {
if claim.GetSessionHeader().GetSessionEndBlockHeight() == sessionEndHeight {
expectedClaims = append(expectedClaims, claim)
}
}
Expand Down
19 changes: 11 additions & 8 deletions x/supplier/keeper/claim.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,20 @@ import (
"github.com/pokt-network/poktroll/x/supplier/types"
)

// InsertClaim adds a claim to the store
func (k Keeper) InsertClaim(ctx sdk.Context, claim types.Claim) {
logger := k.Logger(ctx).With("method", "InsertClaim")
// UpsertClaim inserts or updates a specific claim in the store by index.
func (k Keeper) UpsertClaim(ctx sdk.Context, claim types.Claim) {
logger := k.Logger(ctx).With("method", "UpsertClaim")

claimBz := k.cdc.MustMarshal(&claim)
parentStore := ctx.KVStore(k.storeKey)

// Update the primary store: ClaimPrimaryKey -> ClaimObject
primaryStore := prefix.NewStore(parentStore, types.KeyPrefix(types.ClaimPrimaryKeyPrefix))
primaryKey := types.ClaimPrimaryKey(claim.SessionId, claim.SupplierAddress)
sessionId := claim.GetSessionHeader().GetSessionId()
primaryKey := types.ClaimPrimaryKey(sessionId, claim.SupplierAddress)
primaryStore.Set(primaryKey, claimBz)

logger.Info(fmt.Sprintf("inserted claim for supplier %s with primaryKey %s", claim.SupplierAddress, primaryKey))
logger.Info(fmt.Sprintf("upserted claim for supplier %s with primaryKey %s", claim.SupplierAddress, primaryKey))

// Update the address index: supplierAddress -> [ClaimPrimaryKey]
addressStoreIndex := prefix.NewStore(parentStore, types.KeyPrefix(types.ClaimSupplierAddressPrefix))
Expand All @@ -33,10 +34,11 @@ func (k Keeper) InsertClaim(ctx sdk.Context, claim types.Claim) {

// Update the session end height index: sessionEndHeight -> [ClaimPrimaryKey]
sessionHeightStoreIndex := prefix.NewStore(parentStore, types.KeyPrefix(types.ClaimSessionEndHeightPrefix))
heightKey := types.ClaimSupplierEndSessionHeightKey(claim.SessionEndBlockHeight, primaryKey)
sessionEndBlockHeight := uint64(claim.GetSessionHeader().GetSessionEndBlockHeight())
heightKey := types.ClaimSupplierEndSessionHeightKey(sessionEndBlockHeight, primaryKey)
sessionHeightStoreIndex.Set(heightKey, primaryKey)

logger.Info(fmt.Sprintf("indexed claim for supplier %s at session ending height %d", claim.SupplierAddress, claim.SessionEndBlockHeight))
logger.Info(fmt.Sprintf("indexed claim for supplier %s at session ending height %d", claim.SupplierAddress, sessionEndBlockHeight))
}

// RemoveClaim removes a claim from the store
Expand All @@ -59,7 +61,8 @@ func (k Keeper) RemoveClaim(ctx sdk.Context, sessionId, supplierAddr string) {
sessionHeightStoreIndex := prefix.NewStore(parentStore, types.KeyPrefix(types.ClaimSessionEndHeightPrefix))

addressKey := types.ClaimSupplierAddressKey(claim.SupplierAddress, primaryKey)
heightKey := types.ClaimSupplierEndSessionHeightKey(claim.SessionEndBlockHeight, primaryKey)
sessionEndBlockHeight := uint64(claim.GetSessionHeader().GetSessionEndBlockHeight())
heightKey := types.ClaimSupplierEndSessionHeightKey(sessionEndBlockHeight, primaryKey)

// Delete all the entries (primary store and secondary indices)
primaryStore.Delete(primaryKey)
Expand Down
22 changes: 14 additions & 8 deletions x/supplier/keeper/claim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
keepertest "github.com/pokt-network/poktroll/testutil/keeper"
"github.com/pokt-network/poktroll/testutil/nullify"
"github.com/pokt-network/poktroll/testutil/sample"
sessiontypes "github.com/pokt-network/poktroll/x/session/types"
"github.com/pokt-network/poktroll/x/supplier/keeper"
"github.com/pokt-network/poktroll/x/supplier/types"
)
Expand All @@ -22,10 +23,12 @@ func createNClaims(keeper *keeper.Keeper, ctx sdk.Context, n int) []types.Claim
claims := make([]types.Claim, n)
for i := range claims {
claims[i].SupplierAddress = sample.AccAddress()
claims[i].SessionId = fmt.Sprintf("session-%d", i)
claims[i].SessionEndBlockHeight = uint64(i)
claims[i].SessionHeader = &sessiontypes.SessionHeader{
SessionId: fmt.Sprintf("session-%d", i),
SessionEndBlockHeight: int64(i),
}
claims[i].RootHash = []byte(fmt.Sprintf("rootHash-%d", i))
keeper.InsertClaim(ctx, claims[i])
keeper.UpsertClaim(ctx, claims[i])
}
return claims
}
Expand All @@ -35,7 +38,7 @@ func TestClaim_Get(t *testing.T) {
claims := createNClaims(keeper, ctx, 10)
for _, claim := range claims {
foundClaim, isClaimFound := keeper.GetClaim(ctx,
claim.SessionId,
claim.GetSessionHeader().GetSessionId(),
claim.SupplierAddress,
)
require.True(t, isClaimFound)
Expand All @@ -49,12 +52,13 @@ func TestClaim_Remove(t *testing.T) {
keeper, ctx := keepertest.SupplierKeeper(t, nil)
claims := createNClaims(keeper, ctx, 10)
for _, claim := range claims {
sessionId := claim.GetSessionHeader().GetSessionId()
keeper.RemoveClaim(ctx,
claim.SessionId,
sessionId,
claim.SupplierAddress,
)
_, isClaimFound := keeper.GetClaim(ctx,
claim.SessionId,
sessionId,
claim.SupplierAddress,
)
require.False(t, isClaimFound)
Expand Down Expand Up @@ -90,7 +94,8 @@ func TestClaim_GetAll_ByHeight(t *testing.T) {
claims := createNClaims(keeper, ctx, 10)

// Get all claims for a given ending session block height
allFoundClaimsEndingAtHeight := keeper.GetClaimsByHeight(ctx, claims[6].SessionEndBlockHeight)
sessionEndHeight := claims[6].GetSessionHeader().GetSessionEndBlockHeight()
allFoundClaimsEndingAtHeight := keeper.GetClaimsByHeight(ctx, uint64(sessionEndHeight))
require.ElementsMatch(t,
nullify.Fill([]types.Claim{claims[6]}),
nullify.Fill(allFoundClaimsEndingAtHeight),
Expand All @@ -102,7 +107,8 @@ func TestClaim_GetAll_BySession(t *testing.T) {
claims := createNClaims(keeper, ctx, 10)

// Get all claims for a given ending session block height
allFoundClaimsForSession := keeper.GetClaimsBySession(ctx, claims[8].SessionId)
sessionId := claims[8].GetSessionHeader().GetSessionId()
allFoundClaimsForSession := keeper.GetClaimsBySession(ctx, sessionId)
require.ElementsMatch(t,
nullify.Fill([]types.Claim{claims[8]}),
nullify.Fill(allFoundClaimsForSession),
Expand Down
18 changes: 10 additions & 8 deletions x/supplier/keeper/msg_server_create_claim.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
)

func (k msgServer) CreateClaim(goCtx context.Context, msg *suppliertypes.MsgCreateClaim) (*suppliertypes.MsgCreateClaimResponse, error) {
// TODO_BLOCKER: Prevent Claim upserts after the ClaimWindow is closed.
// TODO_BLOCKER: Validate the signature on the Claim message corresponds to the supplier before Upserting.

ctx := sdk.UnwrapSDKContext(goCtx)
logger := k.Logger(ctx).With("method", "CreateClaim")

Expand Down Expand Up @@ -79,19 +82,18 @@ func (k msgServer) CreateClaim(goCtx context.Context, msg *suppliertypes.MsgCrea
2. [ ] msg distribution validation
*/

// Construct and insert claim after all validation.
// Construct and upsert claim after all validation.
claim := suppliertypes.Claim{
SupplierAddress: msg.GetSupplierAddress(),
SessionId: msg.GetSessionHeader().GetSessionId(),
SessionEndBlockHeight: uint64(msg.GetSessionHeader().GetSessionEndBlockHeight()),
RootHash: msg.RootHash,
SupplierAddress: msg.GetSupplierAddress(),
SessionHeader: msg.GetSessionHeader(),
RootHash: msg.RootHash,
}
k.Keeper.InsertClaim(ctx, claim)
k.Keeper.UpsertClaim(ctx, claim)

logger.
With(
"session_id", claim.GetSessionId(),
"session_end_height", claim.GetSessionEndBlockHeight(),
"session_id", claim.GetSessionHeader().GetSessionId(),
"session_end_height", claim.GetSessionHeader().GetSessionEndBlockHeight(),
"supplier", claim.GetSupplierAddress(),
).
Debug("created claim")
Expand Down
8 changes: 4 additions & 4 deletions x/supplier/keeper/msg_server_create_claim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ func TestMsgServer_CreateClaim_Success(t *testing.T) {

claims := claimRes.GetClaim()
require.Lenf(t, claims, 1, "expected 1 claim, got %d", len(claims))
require.Equal(t, claimMsg.SessionHeader.SessionId, claims[0].SessionId)
require.Equal(t, claimMsg.SupplierAddress, claims[0].SupplierAddress)
require.Equal(t, uint64(claimMsg.SessionHeader.GetSessionEndBlockHeight()), claims[0].SessionEndBlockHeight)
require.Equal(t, claimMsg.RootHash, claims[0].RootHash)
require.Equal(t, claimMsg.SessionHeader.SessionId, claims[0].GetSessionHeader().GetSessionId())
require.Equal(t, claimMsg.SupplierAddress, claims[0].GetSupplierAddress())
require.Equal(t, claimMsg.SessionHeader.GetSessionEndBlockHeight(), claims[0].GetSessionHeader().GetSessionEndBlockHeight())
require.Equal(t, claimMsg.RootHash, claims[0].GetRootHash())
}

func TestMsgServer_CreateClaim_Error(t *testing.T) {
Expand Down
3 changes: 3 additions & 0 deletions x/supplier/keeper/msg_server_submit_proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
)

func (k msgServer) SubmitProof(goCtx context.Context, msg *types.MsgSubmitProof) (*types.MsgSubmitProofResponse, error) {
// TODO_BLOCKER: Prevent Proof upserts after the tokenomics module has processes the respective session.
// TODO_BLOCKER: Validate the signature on the Proof message corresponds to the supplier before Upserting.

ctx := sdk.UnwrapSDKContext(goCtx)

if err := msg.ValidateBasic(); err != nil {
Expand Down
Loading

0 comments on commit 6b67ff9

Please sign in to comment.