diff --git a/e2e/tests/session.feature b/e2e/tests/session.feature index 5b36fe6c8..1623b0acf 100644 --- a/e2e/tests/session.feature +++ b/e2e/tests/session.feature @@ -1,5 +1,7 @@ Feature: Session Namespace + # TODO_TECHDEBT(@Olshansk, #180): This test requires you to run `make supplier1_stake && make app1_stake` first + # As a shorter workaround, we can also add steps that stake the application and supplier as part of the scenario. Scenario: Supplier completes claim/proof lifecycle for a valid session Given the user has the pocketd binary installed When the supplier "supplier1" has serviced a session with "5" relays for service "svc1" for application "app1" @@ -7,7 +9,10 @@ Feature: Session Namespace Then the claim created by supplier "supplier1" for service "svc1" for application "app1" should be persisted on-chain # TODO_IMPROVE: ... # And an event should be emitted... -# TODO_INCOMPLETE: add step(s) for proof validation. + And after the supplier submits a proof for the session for service "svc1" for application "app1" + Then the proof submitted by supplier "supplier1" for service "svc1" for application "app1" should be persisted on-chain +# TODO_IMPROVE: ... +# And an event should be emitted... # TODO_BLOCKER(@red-0ne): Make sure to implement and validate this test # One way to exercise this behavior is to close the `RelayMiner` port to prevent @@ -34,4 +39,4 @@ Feature: Session Namespace # And the supllier "supplier1" calls GetSession and gets session number "3" # Then the supplier "supplier1" replys to application "app1" with a "session mismatch" error relay response # And the application "app1" receives a failed relay response with a "session mismatch" error -# And the supplier "supplier1" do not update a claim for session number "1" for service "svc1" for application "app1" \ No newline at end of file +# And the supplier "supplier1" do not update a claim for session number "1" for service "svc1" for application "app1" diff --git a/e2e/tests/session_steps_test.go b/e2e/tests/session_steps_test.go index 96248e243..dff132d9c 100644 --- a/e2e/tests/session_steps_test.go +++ b/e2e/tests/session_steps_test.go @@ -10,76 +10,41 @@ import ( "strings" "time" + "cosmossdk.io/depinject" abci "github.com/cometbft/cometbft/abci/types" + "github.com/stretchr/testify/require" + + "github.com/pokt-network/poktroll/pkg/client" "github.com/pokt-network/poktroll/pkg/client/events" - "github.com/pokt-network/poktroll/pkg/either" - "github.com/pokt-network/poktroll/pkg/observable" "github.com/pokt-network/poktroll/pkg/observable/channel" "github.com/pokt-network/poktroll/testutil/testclient" suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" - "github.com/stretchr/testify/require" ) const ( - createClaimTimeoutDuration = 10 * time.Second - eitherEventsReplayBufferSize = 100 - msgClaimSenderQueryFmt = "tm.event='Tx' AND message.sender='%s' AND message.action='/pocket.supplier.MsgCreateClaim'" - testServiceId = "anvil" - eitherEventsBzReplayObsKey = "eitherEventsBzReplayObsKey" - preExistingClaimsKey = "preExistingClaimsKey" + // txEventTimeout is the duration of time to wait after sending a valid tx + // before the test should time out (fail). + txEventTimeout = 10 * time.Second + // txSenderEventSubscriptionQueryFmt is the format string which yields the + // cosmos-sdk event subscription "query" string for a given sender address. + // This is used by an events replay client to subscribe to tx events from the supplier. + // See: https://docs.cosmos.network/v0.47/learn/advanced/events#subscribing-to-events + txSenderEventSubscriptionQueryFmt = "tm.event='Tx' AND message.sender='%s'" + testEventsReplayClientBufferSize = 100 + testServiceId = "anvil" + // eventsReplayClientKey is the suite#scenarioState key for the events replay client + // which is subscribed to tx events where the tx sender is the scenario's supplier. + eventsReplayClientKey = "eventsReplayClientKey" + // preExistingClaimsKey is the suite#scenarioState key for any pre-existing + // claims when querying for all claims prior to running the scenario. + preExistingClaimsKey = "preExistingClaimsKey" + // preExistingProofsKey is the suite#scenarioState key for any pre-existing + // proofs when querying for all proofs prior to running the scenario. + preExistingProofsKey = "preExistingProofsKey" ) func (s *suite) AfterTheSupplierCreatesAClaimForTheSessionForServiceForApplication(serviceId, appName string) { - ctx, done := context.WithCancel(context.Background()) - - // TODO_CONSIDERATION: if this test suite gets more complex, it might make - // sense to refactor this key into a function that takes serviceId and appName - // as arguments and returns the key. - eitherEventsBzReplayObs := s.scenarioState[eitherEventsBzReplayObsKey].(observable.ReplayObservable[either.Bytes]) - - // TODO(#220): refactor to use EventsReplayClient once available. - channel.ForEach[either.Bytes]( - ctx, eitherEventsBzReplayObs, - func(_ context.Context, eitherEventBz either.Bytes) { - eventBz, err := eitherEventBz.ValueOrError() - require.NoError(s, err) - - if strings.Contains(string(eventBz), "jsonrpc") { - return - } - - // Unmarshal event data into a TxEventResponse object. - txEvent := &abci.TxResult{} - err = json.Unmarshal(eventBz, txEvent) - require.NoError(s, err) - - var found bool - for _, event := range txEvent.Result.Events { - for _, attribute := range event.Attributes { - if attribute.Key == "action" { - require.Equal( - s, "/pocket.supplier.MsgCreateClaim", - attribute.Value, - ) - found = true - break - } - } - if found { - break - } - } - require.Truef(s, found, "unable to find event action attribute") - - done() - }, - ) - - select { - case <-ctx.Done(): - case <-time.After(createClaimTimeoutDuration): - s.Fatal("timed out waiting for claim to be created") - } + s.waitForMessageAction("/pocket.supplier.MsgCreateClaim") } func (s *suite) TheClaimCreatedBySupplierForServiceForApplicationShouldBePersistedOnchain(supplierName, serviceId, appName string) { @@ -94,7 +59,8 @@ func (s *suite) TheClaimCreatedBySupplierForServiceForApplicationShouldBePersist require.NotNil(s, allClaimsRes) // Assert that the number of claims has increased by one. - preExistingClaims := s.scenarioState[preExistingClaimsKey].([]suppliertypes.Claim) + preExistingClaims, ok := s.scenarioState[preExistingClaimsKey].([]suppliertypes.Claim) + require.True(s, ok, "preExistingClaimsKey not found in scenarioState") // NB: We are avoiding the use of require.Len here because it provides unreadable output // TODO_TECHDEBT: Due to the speed of the blocks of the LocalNet sequencer, along with the small number // of blocks per session, multiple claims may be created throughout the duration of the test. Until @@ -119,23 +85,43 @@ func (s *suite) TheSupplierHasServicedASessionWithRelaysForServiceForApplication relayCount, err := strconv.Atoi(relayCountStr) require.NoError(s, err) - // Query for any existing claims so that we can compensate for them in the + // Query for any existing claims so that we can compare against them in // future assertions about changes in on-chain claims. allClaimsRes, err := s.supplierQueryClient.AllClaims(ctx, &suppliertypes.QueryAllClaimsRequest{}) require.NoError(s, err) s.scenarioState[preExistingClaimsKey] = allClaimsRes.Claim + // Query for any existing proofs so that we can compare against them in + // future assertions about changes in on-chain proofs. + allProofsRes, err := s.supplierQueryClient.AllProofs(ctx, &suppliertypes.QueryAllProofsRequest{}) + require.NoError(s, err) + s.scenarioState[preExistingProofsKey] = allProofsRes.Proof + // Construct an events query client to listen for tx events from the supplier. - msgSenderQuery := fmt.Sprintf(msgClaimSenderQueryFmt, accNameToAddrMap[supplierName]) + msgSenderQuery := fmt.Sprintf(txSenderEventSubscriptionQueryFmt, accNameToAddrMap[supplierName]) + + deps := depinject.Supply(events.NewEventsQueryClient(testclient.CometLocalWebsocketURL)) + eventsReplayClient, err := events.NewEventsReplayClient[*abci.TxResult]( + ctx, + deps, + msgSenderQuery, + func(eventBz []byte) (*abci.TxResult, error) { + if strings.Contains(string(eventBz), "jsonrpc") { + return nil, nil + } - // TODO_TECHDEBT(#220): refactor to use EventsReplayClient once available. - eventsQueryClient := events.NewEventsQueryClient(testclient.CometLocalWebsocketURL) - eitherEventsBzObs, err := eventsQueryClient.EventsBytes(ctx, msgSenderQuery) + // Unmarshal event data into an ABCI TxResult object. + txResult := &abci.TxResult{} + err = json.Unmarshal(eventBz, txResult) + require.NoError(s, err) + + return txResult, nil + }, + testEventsReplayClientBufferSize, + ) require.NoError(s, err) - eitherEventsBytesObs := observable.Observable[either.Bytes](eitherEventsBzObs) - eitherEventsBzRelayObs := channel.ToReplayObservable(ctx, eitherEventsReplayBufferSize, eitherEventsBytesObs) - s.scenarioState[eitherEventsBzReplayObsKey] = eitherEventsBzRelayObs + s.scenarioState[eventsReplayClientKey] = eventsReplayClient s.sendRelaysForSession( appName, @@ -145,6 +131,42 @@ func (s *suite) TheSupplierHasServicedASessionWithRelaysForServiceForApplication ) } +func (s *suite) AfterTheSupplierSubmitsAProofForTheSessionForServiceForApplication(a string, b string) { + s.waitForMessageAction("/pocket.supplier.MsgSubmitProof") +} + +func (s *suite) TheProofSubmittedBySupplierForServiceForApplicationShouldBePersistedOnchain(supplierName, serviceId, appName string) { + ctx := context.Background() + + // Retrieve all on-chain proofs for supplierName + allProofsRes, err := s.supplierQueryClient.AllProofs(ctx, &suppliertypes.QueryAllProofsRequest{ + Filter: &suppliertypes.QueryAllProofsRequest_SupplierAddress{ + SupplierAddress: accNameToAddrMap[supplierName], + }, + }) + require.NoError(s, err) + require.NotNil(s, allProofsRes) + + // Assert that the number of proofs has increased by one. + preExistingProofs, ok := s.scenarioState[preExistingProofsKey].([]suppliertypes.Proof) + require.True(s, ok, "preExistingProofsKey not found in scenarioState") + // NB: We are avoiding the use of require.Len here because it provides unreadable output + // TODO_TECHDEBT: Due to the speed of the blocks of the LocalNet sequencer, along with the small number + // of blocks per session, multiple proofs may be created throughout the duration of the test. Until + // these values are appropriately adjusted, we assert on an increase in proofs rather than +1. + require.Greater(s, len(allProofsRes.Proof), len(preExistingProofs), "number of proofs must have increased") + + // TODO_UPNEXT(@bryanchriswhite): assert that the root hash of the proof contains the correct + // SMST sum. The sum can be retrieved via the `GetSum` function exposed + // by the SMT. + + // TODO_IMPROVE: add assertions about serviceId and appName and/or incorporate + // them into the scenarioState key(s). + + proof := allProofsRes.Proof[0] + require.Equal(s, accNameToAddrMap[supplierName], proof.SupplierAddress) +} + func (s *suite) sendRelaysForSession( appName string, supplierName string, @@ -163,3 +185,42 @@ func (s *suite) sendRelaysForSession( s.TheApplicationReceivesASuccessfulRelayResponseSignedBy(appName, supplierName) } } + +// waitForMessageAction waits for an event to be observed which has the given message action. +func (s *suite) waitForMessageAction(action string) { + ctx, done := context.WithCancel(context.Background()) + + eventsReplayClient, ok := s.scenarioState[eventsReplayClientKey].(client.EventsReplayClient[*abci.TxResult]) + require.True(s, ok, "eventsReplayClientKey not found in scenarioState") + require.NotNil(s, eventsReplayClient) + + // For each observed event, **asynchronously** check if it contains the given action. + channel.ForEach[*abci.TxResult]( + ctx, eventsReplayClient.EventsSequence(ctx), + func(_ context.Context, txEvent *abci.TxResult) { + if txEvent == nil { + return + } + + // Range over each event's attributes to find the "action" attribute + // and compare its value to that of the action provided. + for _, event := range txEvent.Result.Events { + for _, attribute := range event.Attributes { + if attribute.Key == "action" { + if attribute.Value == action { + done() + return + } + } + } + } + }, + ) + + select { + case <-time.After(txEventTimeout): + s.Fatalf("timed out waiting for message with action %q", action) + case <-ctx.Done(): + s.Log("Success; message detected before timeout.") + } +} diff --git a/proto/pocket/supplier/query.proto b/proto/pocket/supplier/query.proto index 016eec26f..ac0bad9de 100644 --- a/proto/pocket/supplier/query.proto +++ b/proto/pocket/supplier/query.proto @@ -93,6 +93,7 @@ message QueryAllClaimsRequest { } message QueryAllClaimsResponse { + // TODO_IMPROVE: Rename to `Claims` (plural). repeated Claim claim = 1 [(gogoproto.nullable) = false]; cosmos.base.query.v1beta1.PageResponse pagination = 2; } @@ -116,6 +117,7 @@ message QueryAllProofsRequest { } message QueryAllProofsResponse { + // TODO_IMPROVE: Rename to `Proofs` (plural). repeated Proof proof = 1 [(gogoproto.nullable) = false]; cosmos.base.query.v1beta1.PageResponse pagination = 2; } diff --git a/x/session/keeper/query_get_session.go b/x/session/keeper/query_get_session.go index e2d4681af..cce44fc2c 100644 --- a/x/session/keeper/query_get_session.go +++ b/x/session/keeper/query_get_session.go @@ -16,7 +16,7 @@ func (k Keeper) GetSession(goCtx context.Context, req *types.QueryGetSessionRequ } if err := req.ValidateBasic(); err != nil { - return nil, err + return nil, status.Error(codes.InvalidArgument, err.Error()) } ctx := sdk.UnwrapSDKContext(goCtx) diff --git a/x/supplier/client/cli/tx_create_claim_test.go b/x/supplier/client/cli/tx_create_claim_test.go new file mode 100644 index 000000000..276b6d093 --- /dev/null +++ b/x/supplier/client/cli/tx_create_claim_test.go @@ -0,0 +1,3 @@ +package cli + +// TODO_NEXT(@bryanchriswhite #140): add comprehensive CLI test coverage for creating claims. diff --git a/x/supplier/client/cli/tx_submit_proof_test.go b/x/supplier/client/cli/tx_submit_proof_test.go new file mode 100644 index 000000000..d3cd52933 --- /dev/null +++ b/x/supplier/client/cli/tx_submit_proof_test.go @@ -0,0 +1,3 @@ +package cli + +// TODO_NEXT(@bryanchriswhite #141): add comprehensive CLI test coverage for submitting proofs. diff --git a/x/supplier/keeper/msg_server_create_claim.go b/x/supplier/keeper/msg_server_create_claim.go index 3f4b0df30..2ccd4f31e 100644 --- a/x/supplier/keeper/msg_server_create_claim.go +++ b/x/supplier/keeper/msg_server_create_claim.go @@ -4,6 +4,8 @@ import ( "context" sdk "github.com/cosmos/cosmos-sdk/types" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" ) @@ -14,6 +16,7 @@ func (k msgServer) CreateClaim(goCtx context.Context, msg *suppliertypes.MsgCrea ctx := sdk.UnwrapSDKContext(goCtx) logger := k.Logger(ctx).With("method", "CreateClaim") + logger.Debug("creating claim") if err := msg.ValidateBasic(); err != nil { return nil, err @@ -25,7 +28,7 @@ func (k msgServer) CreateClaim(goCtx context.Context, msg *suppliertypes.MsgCrea msg.GetSupplierAddress(), ) if err != nil { - return nil, err + return nil, status.Error(codes.InvalidArgument, err.Error()) } logger. diff --git a/x/supplier/keeper/msg_server_create_claim_test.go b/x/supplier/keeper/msg_server_create_claim_test.go index 93bd2b5b2..36efb5ded 100644 --- a/x/supplier/keeper/msg_server_create_claim_test.go +++ b/x/supplier/keeper/msg_server_create_claim_test.go @@ -5,6 +5,8 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" keepertest "github.com/pokt-network/poktroll/testutil/keeper" "github.com/pokt-network/poktroll/testutil/sample" @@ -16,7 +18,10 @@ import ( suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" ) -const testServiceId = "svc1" +const ( + testServiceId = "svc1" + testSessionId = "mock_session_id" +) func TestMsgServer_CreateClaim_Success(t *testing.T) { appSupplierPair := supplier.AppSupplierPair{ @@ -30,7 +35,7 @@ func TestMsgServer_CreateClaim_Success(t *testing.T) { srv := keeper.NewMsgServerImpl(*supplierKeeper) ctx := sdk.WrapSDKContext(sdkCtx) - claimMsg := newTestClaimMsg(t) + claimMsg := newTestClaimMsg(t, testSessionId) claimMsg.SupplierAddress = appSupplierPair.SupplierAddr claimMsg.SessionHeader.ApplicationAddress = appSupplierPair.AppAddr @@ -69,19 +74,25 @@ func TestMsgServer_CreateClaim_Error(t *testing.T) { { desc: "on-chain session ID must match claim msg session ID", claimMsgFn: func(t *testing.T) *types.MsgCreateClaim { - msg := newTestClaimMsg(t) + msg := newTestClaimMsg(t, "invalid_session_id") msg.SupplierAddress = appSupplierPair.SupplierAddr msg.SessionHeader.ApplicationAddress = appSupplierPair.AppAddr - msg.SessionHeader.SessionId = "invalid_session_id" return msg }, - expectedErr: types.ErrSupplierInvalidSessionId, + expectedErr: status.Error( + codes.InvalidArgument, + types.ErrSupplierInvalidSessionId.Wrapf( + "session ID does not match on-chain session ID; expected %q, got %q", + testSessionId, + "invalid_session_id", + ).Error(), + ), }, { desc: "claim msg supplier address must be in the session", claimMsgFn: func(t *testing.T) *types.MsgCreateClaim { - msg := newTestClaimMsg(t) + msg := newTestClaimMsg(t, testSessionId) msg.SessionHeader.ApplicationAddress = appSupplierPair.AppAddr // Overwrite supplier address to one not included in the session fixtures. @@ -96,21 +107,21 @@ func TestMsgServer_CreateClaim_Error(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { createClaimRes, err := srv.CreateClaim(ctx, tt.claimMsgFn(t)) - require.ErrorIs(t, err, tt.expectedErr) + require.ErrorContains(t, err, tt.expectedErr.Error()) require.Nil(t, createClaimRes) }) } } -func newTestClaimMsg(t *testing.T) *suppliertypes.MsgCreateClaim { +func newTestClaimMsg(t *testing.T, sessionId string) *suppliertypes.MsgCreateClaim { t.Helper() return suppliertypes.NewMsgCreateClaim( sample.AccAddress(), &sessiontypes.SessionHeader{ ApplicationAddress: sample.AccAddress(), - SessionStartBlockHeight: 1, - SessionId: "mock_session_id", + SessionStartBlockHeight: 0, + SessionId: sessionId, Service: &sharedtypes.Service{ Id: "svc1", Name: "svc1", diff --git a/x/supplier/keeper/msg_server_submit_proof.go b/x/supplier/keeper/msg_server_submit_proof.go index 7a764a12d..cae07a675 100644 --- a/x/supplier/keeper/msg_server_submit_proof.go +++ b/x/supplier/keeper/msg_server_submit_proof.go @@ -4,55 +4,134 @@ import ( "context" sdk "github.com/cosmos/cosmos-sdk/types" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" - "github.com/pokt-network/poktroll/x/supplier/types" + suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" ) -func (k msgServer) SubmitProof(goCtx context.Context, msg *types.MsgSubmitProof) (*types.MsgSubmitProofResponse, error) { +func (k msgServer) SubmitProof(goCtx context.Context, msg *suppliertypes.MsgSubmitProof) (*suppliertypes.MsgSubmitProofResponse, error) { // TODO_BLOCKER: Prevent Proof upserts after the tokenomics module has processes the respective session. // TODO_BLOCKER: Validate the signature on the Proof message corresponds to the supplier before Upserting. - ctx := sdk.UnwrapSDKContext(goCtx) + logger := k.Logger(ctx).With("method", "SubmitProof") + logger.Debug("submitting proof") + + /* + TODO_INCOMPLETE: Handling the message + + ## Actions (error if anything fails) + 1. Retrieve a fully hydrated `session` from on-chain store using `msg` metadata + 2. Retrieve a fully hydrated `claim` from on-chain store using `msg` metadata + 3. Retrieve `relay.Req` and `relay.Res` from deserializing `proof.ClosestValueHash` + + ## Basic Validations (metadata only) + 1. proof.sessionId == claim.sessionId + 2. msg.supplier in session.suppliers + 3. relay.Req.signer == session.appAddr + 4. relay.Res.signer == msg.supplier + + ## Msg distribution validation (governance based params) + 1. Validate Proof submission is not too early; governance-based param + pseudo-random variation + 2. Validate Proof submission is not too late; governance-based param + pseudo-random variation + + ## Relay Signature validation + 1. verify(relay.Req.Signature, appRing) + 2. verify(relay.Res.Signature, supplier.pubKey) + + ## Relay Mining validation + 1. verify(proof.path) is the expected path; pseudo-random variation using on-chain data + 2. verify(proof.ValueHash, expectedDiffictulty); governance based + 3. verify(claim.Root, proof.ClosestProof); verify the closest proof is correct + */ if err := msg.ValidateBasic(); err != nil { - return nil, err + return nil, status.Error(codes.InvalidArgument, err.Error()) } - /* - INCOMPLETE: Handling the message - - ## Validation - - ### Session validation - 1. [ ] claimed session ID == retrieved session ID - 2. [ ] this supplier is in the session's suppliers list - 3. [ ] proof signer addr == session application addr - - ### Msg distribution validation (depends on session validation) - 1. [ ] pseudo-randomize earliest block offset - 2. [ ] governance-based earliest block offset - - ### Proof validation - 1. [ ] session validation - 2. [ ] msg distribution validation - 3. [ ] claim with matching session ID exists - 4. [ ] proof path matches last committed block hash at claim height - 1 - 5. [ ] proof validates with claimed root hash - - ## Persistence - 1. [ ] submit proof message - - supplier address - - session header - - proof - - ## Accounting - 1. [ ] extract work done from root hash - 2. [ ] calculate reward/burn token with governance-based multiplier - 3. [ ] reward supplier - 4. [ ] burn application tokens - */ + if _, err := k.queryAndValidateSessionHeader( + goCtx, + msg.GetSessionHeader(), + msg.GetSupplierAddress(), + ); err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + + // Construct and insert proof after all validation. + proof := suppliertypes.Proof{ + SupplierAddress: msg.GetSupplierAddress(), + SessionHeader: msg.GetSessionHeader(), + ClosestMerkleProof: msg.Proof, + } + + if err := k.queryAndValidateClaimForProof(ctx, &proof); err != nil { + return nil, status.Error(codes.FailedPrecondition, err.Error()) + } + + // TODO_BLOCKER: check if this proof already exists and return an appropriate error + // in any case where the supplier should no longer be able to update the given proof. + k.Keeper.UpsertProof(ctx, proof) + + // TODO_BLOCKER(@bryanchriswhite, @Olshansk): Call `tokenomics.SettleSessionAccounting()` here + + logger. + With( + "session_id", proof.GetSessionHeader().GetSessionId(), + "session_end_height", proof.GetSessionHeader().GetSessionEndBlockHeight(), + "supplier", proof.GetSupplierAddress(), + ). + Debug("created proof") + + return &suppliertypes.MsgSubmitProofResponse{}, nil +} + +// queryAndValidateClaimForProof ensures that a claim corresponding to the given proof's +// session exists & has a matching supplier address and session header. +func (k msgServer) queryAndValidateClaimForProof(sdkCtx sdk.Context, proof *suppliertypes.Proof) error { + sessionId := proof.GetSessionHeader().GetSessionId() + // NB: no need to assert the testSessionId or supplier address as it is retrieved + // by respective values of the given proof. I.e., if the claim exists, then these + // values are guaranteed to match. + claim, found := k.GetClaim(sdkCtx, sessionId, proof.GetSupplierAddress()) + if !found { + return suppliertypes.ErrSupplierClaimNotFound.Wrapf("no claim found for session ID %q and supplier %q", sessionId, proof.GetSupplierAddress()) + } - _ = ctx + // Ensure session start heights match. + if claim.GetSessionHeader().GetSessionStartBlockHeight() != proof.GetSessionHeader().GetSessionStartBlockHeight() { + return suppliertypes.ErrSupplierInvalidSessionStartHeight.Wrapf( + "claim session start height %d does not match proof session start height %d", + claim.GetSessionHeader().GetSessionStartBlockHeight(), + proof.GetSessionHeader().GetSessionStartBlockHeight(), + ) + } + + // Ensure session end heights match. + if claim.GetSessionHeader().GetSessionEndBlockHeight() != proof.GetSessionHeader().GetSessionEndBlockHeight() { + return suppliertypes.ErrSupplierInvalidSessionEndHeight.Wrapf( + "claim session end height %d does not match proof session end height %d", + claim.GetSessionHeader().GetSessionEndBlockHeight(), + proof.GetSessionHeader().GetSessionEndBlockHeight(), + ) + } + + // Ensure application addresses match. + if claim.GetSessionHeader().GetApplicationAddress() != proof.GetSessionHeader().GetApplicationAddress() { + return suppliertypes.ErrSupplierInvalidAddress.Wrapf( + "claim application address %q does not match proof application address %q", + claim.GetSessionHeader().GetApplicationAddress(), + proof.GetSessionHeader().GetApplicationAddress(), + ) + } + + // Ensure service IDs match. + if claim.GetSessionHeader().GetService().GetId() != proof.GetSessionHeader().GetService().GetId() { + return suppliertypes.ErrSupplierInvalidService.Wrapf( + "claim service ID %q does not match proof service ID %q", + claim.GetSessionHeader().GetService().GetId(), + proof.GetSessionHeader().GetService().GetId(), + ) + } - return &types.MsgSubmitProofResponse{}, nil + return nil } diff --git a/x/supplier/types/errors.go b/x/supplier/types/errors.go index 5406c8e3f..61a56f566 100644 --- a/x/supplier/types/errors.go +++ b/x/supplier/types/errors.go @@ -21,4 +21,6 @@ var ( ErrSupplierInvalidQueryRequest = sdkerrors.Register(ModuleName, 11, "invalid query request") ErrSupplierClaimNotFound = sdkerrors.Register(ModuleName, 12, "claim not found") ErrSupplierProofNotFound = sdkerrors.Register(ModuleName, 13, "proof not found") + ErrSupplierInvalidProof = sdkerrors.Register(ModuleName, 14, "invalid proof") + ErrSupplierInvalidClosestMerkleProof = sdkerrors.Register(ModuleName, 15, "invalid closest merkle proof") ) diff --git a/x/supplier/types/message_create_claim.go b/x/supplier/types/message_create_claim.go index 0b20c27ec..84d2a4b3a 100644 --- a/x/supplier/types/message_create_claim.go +++ b/x/supplier/types/message_create_claim.go @@ -47,27 +47,27 @@ func (msg *MsgCreateClaim) GetSignBytes() []byte { func (msg *MsgCreateClaim) ValidateBasic() error { // Validate the supplier address - _, err := sdk.AccAddressFromBech32(msg.SupplierAddress) + _, err := sdk.AccAddressFromBech32(msg.GetSupplierAddress()) if err != nil { - return sdkerrors.Wrapf(ErrSupplierInvalidAddress, "invalid supplierAddress address (%s)", err) + return sdkerrors.Wrapf(ErrSupplierInvalidAddress, "%s", msg.GetSupplierAddress()) } // Validate the session header sessionHeader := msg.SessionHeader - if sessionHeader.SessionStartBlockHeight < 1 { - return sdkerrors.Wrapf(ErrSupplierInvalidSessionStartHeight, "invalid session start block height (%d)", sessionHeader.SessionStartBlockHeight) + if sessionHeader.SessionStartBlockHeight < 0 { + return sdkerrors.Wrapf(ErrSupplierInvalidSessionStartHeight, "%d", sessionHeader.SessionStartBlockHeight) } if len(sessionHeader.SessionId) == 0 { - return sdkerrors.Wrapf(ErrSupplierInvalidSessionId, "invalid session ID (%v)", sessionHeader.SessionId) + return sdkerrors.Wrapf(ErrSupplierInvalidSessionId, "%s", sessionHeader.SessionId) } if !sharedhelpers.IsValidService(sessionHeader.Service) { - return sdkerrors.Wrapf(ErrSupplierInvalidService, "invalid service (%v)", sessionHeader.Service) + return sdkerrors.Wrapf(ErrSupplierInvalidService, "%v", sessionHeader.Service) } // Validate the root hash // TODO_IMPROVE: Only checking to make sure a non-nil hash was provided for now, but we can validate the length as well. if len(msg.RootHash) == 0 { - return sdkerrors.Wrapf(ErrSupplierInvalidClaimRootHash, "invalid root hash (%v)", msg.RootHash) + return sdkerrors.Wrapf(ErrSupplierInvalidClaimRootHash, "%v", msg.RootHash) } return nil diff --git a/x/supplier/types/message_create_claim_test.go b/x/supplier/types/message_create_claim_test.go index c46647ebe..65a32726c 100644 --- a/x/supplier/types/message_create_claim_test.go +++ b/x/supplier/types/message_create_claim_test.go @@ -31,7 +31,7 @@ func TestMsgCreateClaim_ValidateBasic(t *testing.T) { msg: MsgCreateClaim{ SupplierAddress: sample.AccAddress(), SessionHeader: &sessiontypes.SessionHeader{ - SessionStartBlockHeight: 0, // Invalid start height + SessionStartBlockHeight: -1, // Invalid start height }, }, err: ErrSupplierInvalidSessionStartHeight, diff --git a/x/supplier/types/message_submit_proof.go b/x/supplier/types/message_submit_proof.go index ad00eb225..144872e05 100644 --- a/x/supplier/types/message_submit_proof.go +++ b/x/supplier/types/message_submit_proof.go @@ -40,10 +40,38 @@ func (msg *MsgSubmitProof) GetSignBytes() []byte { return sdk.MustSortJSON(bz) } +// ValidateBasic ensures that the bech32 address strings for the supplier and +// application addresses are valid and that the proof and service ID are not empty. +// +// TODO_TECHDEBT: Call `msg.GetSessionHeader().ValidateBasic()` once its implemented func (msg *MsgSubmitProof) ValidateBasic() error { - _, err := sdk.AccAddressFromBech32(msg.SupplierAddress) + _, err := sdk.AccAddressFromBech32(msg.GetSupplierAddress()) if err != nil { - return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid supplierAddress address (%s)", err) + return sdkerrors.ErrInvalidAddress.Wrapf( + "supplier address %q, error: %s", + msg.GetSupplierAddress(), + err, + ) } + + _, err = sdk.AccAddressFromBech32(msg.GetSessionHeader().GetApplicationAddress()) + if err != nil { + return sdkerrors.ErrInvalidAddress.Wrapf( + "application address: %q, error: %s", + msg.GetSessionHeader().GetApplicationAddress(), + err, + ) + } + + if msg.GetSessionHeader().GetService().GetId() == "" { + return ErrSupplierInvalidService.Wrap("proof service ID %q cannot be empty") + } + + if len(msg.GetProof()) == 0 { + return ErrSupplierInvalidProof.Wrap("proof cannot be empty") + } + + // TODO_BLOCKER: attempt to deserialize the proof for additional validation. + return nil } diff --git a/x/supplier/types/message_submit_proof_test.go b/x/supplier/types/message_submit_proof_test.go index 8479db05d..e7e2cbba6 100644 --- a/x/supplier/types/message_submit_proof_test.go +++ b/x/supplier/types/message_submit_proof_test.go @@ -7,34 +7,93 @@ import ( "github.com/stretchr/testify/require" "github.com/pokt-network/poktroll/testutil/sample" + sessiontypes "github.com/pokt-network/poktroll/x/session/types" + sharedtypes "github.com/pokt-network/poktroll/x/shared/types" ) -// TODO(@bryanchriswhite): Add unit tests for message validation when adding the business logic. - func TestMsgSubmitProof_ValidateBasic(t *testing.T) { + testService := &sharedtypes.Service{Id: "svc01"} + testClosestMerkleProof := []byte{1, 2, 3, 4} + tests := []struct { - name string - msg MsgSubmitProof - err error + desc string + msg MsgSubmitProof + expectedErr error }{ { - name: "invalid address", + desc: "application bech32 address is invalid", + msg: MsgSubmitProof{ + SupplierAddress: sample.AccAddress(), + SessionHeader: &sessiontypes.SessionHeader{ + ApplicationAddress: "not_a_bech32_address", + Service: testService, + SessionStartBlockHeight: 0, + SessionId: "mock_session_id", + SessionEndBlockHeight: 5, + }, + Proof: testClosestMerkleProof, + }, + expectedErr: sdkerrors.ErrInvalidAddress.Wrapf( + "application address: %q, error: %s", + "not_a_bech32_address", + "decoding bech32 failed: invalid separator index -1", + ), + }, + { + desc: "supplier bech32 address is invalid", msg: MsgSubmitProof{ - SupplierAddress: "invalid_address", + SupplierAddress: "not_a_bech32_address", + SessionHeader: &sessiontypes.SessionHeader{ + ApplicationAddress: sample.AccAddress(), + Service: testService, + SessionStartBlockHeight: 0, + SessionId: "mock_session_id", + SessionEndBlockHeight: 5, + }, + Proof: testClosestMerkleProof, }, - err: sdkerrors.ErrInvalidAddress, - }, { - name: "valid address", + expectedErr: sdkerrors.ErrInvalidAddress.Wrapf( + "supplier address %q, error: %s", + "not_a_bech32_address", + "decoding bech32 failed: invalid separator index -1", + ), + }, + { + desc: "session service ID is empty", + msg: MsgSubmitProof{ + SupplierAddress: sample.AccAddress(), + SessionHeader: &sessiontypes.SessionHeader{ + ApplicationAddress: sample.AccAddress(), + Service: &sharedtypes.Service{Id: ""}, + SessionStartBlockHeight: 0, + SessionId: "mock_session_id", + SessionEndBlockHeight: 5, + }, + Proof: testClosestMerkleProof, + }, + expectedErr: ErrSupplierInvalidService.Wrap("proof service ID %q cannot be empty"), + }, + { + desc: "valid message metadata", msg: MsgSubmitProof{ SupplierAddress: sample.AccAddress(), + SessionHeader: &sessiontypes.SessionHeader{ + ApplicationAddress: sample.AccAddress(), + Service: testService, + SessionId: "mock_session_id", + SessionStartBlockHeight: 0, + SessionEndBlockHeight: 5, + }, + Proof: testClosestMerkleProof, }, }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.desc, func(t *testing.T) { err := tt.msg.ValidateBasic() - if tt.err != nil { - require.ErrorIs(t, err, tt.err) + if tt.expectedErr != nil { + require.ErrorIs(t, err, tt.expectedErr) + require.ErrorContains(t, err, tt.expectedErr.Error()) return } require.NoError(t, err) diff --git a/x/supplier/types/query_validation.go b/x/supplier/types/query_validation.go index c7a352bd6..7111bc950 100644 --- a/x/supplier/types/query_validation.go +++ b/x/supplier/types/query_validation.go @@ -87,7 +87,7 @@ func (query *QueryAllProofsRequest) ValidateBasic() error { default: // No filter is set - logger.Debug().Msg("No specific filter set when requesting claims") + logger.Debug().Msg("No specific filter set when requesting proofs") } return nil