From c47dc4129a314c9ee20ea59c3e010fddab29bcf3 Mon Sep 17 00:00:00 2001 From: Bolek <1416262+bolekk@users.noreply.github.com> Date: Tue, 17 Dec 2024 08:40:22 -0800 Subject: [PATCH] [KS-602] Fix for remote exectable client not respecting context (#15721) --- core/capabilities/remote/executable/client.go | 25 +++++++++++++---- .../remote/executable/client_test.go | 28 ++++++++++++++++++- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/core/capabilities/remote/executable/client.go b/core/capabilities/remote/executable/client.go index 776ddb692ad..9f23ff4ce4a 100644 --- a/core/capabilities/remote/executable/client.go +++ b/core/capabilities/remote/executable/client.go @@ -43,6 +43,11 @@ var _ services.Service = &client{} const expiryCheckInterval = 30 * time.Second +var ( + ErrRequestExpired = errors.New("request expired by executable client") + ErrContextDoneBeforeResponseQuorum = errors.New("context done before remote client received a quorum of responses") +) + func NewClient(remoteCapabilityInfo commoncap.CapabilityInfo, localDonInfo commoncap.DON, dispatcher types.Dispatcher, requestTimeout time.Duration, lggr logger.Logger) *client { return &client{ @@ -122,7 +127,7 @@ func (c *client) expireRequests() { for messageID, req := range c.requestIDToCallerRequest { if req.Expired() { - req.Cancel(errors.New("request expired by executable client")) + req.Cancel(ErrRequestExpired) delete(c.requestIDToCallerRequest, messageID) } @@ -164,12 +169,22 @@ func (c *client) Execute(ctx context.Context, capReq commoncap.CapabilityRequest return commoncap.CapabilityResponse{}, fmt.Errorf("failed to send request: %w", err) } - resp := <-req.ResponseChan() - if resp.Err != nil { - return commoncap.CapabilityResponse{}, fmt.Errorf("error executing request: %w", resp.Err) + var respResult []byte + var respErr error + select { + case resp := <-req.ResponseChan(): + respResult = resp.Result + respErr = resp.Err + case <-ctx.Done(): + // NOTE: ClientRequest will not block on sending to ResponseChan() because that channel is buffered (with size 1) + return commoncap.CapabilityResponse{}, errors.Join(ErrContextDoneBeforeResponseQuorum, ctx.Err()) + } + + if respErr != nil { + return commoncap.CapabilityResponse{}, fmt.Errorf("error executing request: %w", respErr) } - capabilityResponse, err := pb.UnmarshalCapabilityResponse(resp.Result) + capabilityResponse, err := pb.UnmarshalCapabilityResponse(respResult) if err != nil { return commoncap.CapabilityResponse{}, fmt.Errorf("failed to unmarshal capability response: %w", err) } diff --git a/core/capabilities/remote/executable/client_test.go b/core/capabilities/remote/executable/client_test.go index 5c4da350b9e..d1602e9dc07 100644 --- a/core/capabilities/remote/executable/client_test.go +++ b/core/capabilities/remote/executable/client_test.go @@ -143,7 +143,8 @@ func Test_Client_TimesOutIfInsufficientCapabilityPeerResponses(t *testing.T) { ctx := testutils.Context(t) responseTest := func(t *testing.T, response commoncap.CapabilityResponse, responseError error) { - assert.Error(t, responseError) + require.Error(t, responseError) + require.ErrorIs(t, responseError, executable.ErrRequestExpired) } capability := &TestCapability{} @@ -165,6 +166,31 @@ func Test_Client_TimesOutIfInsufficientCapabilityPeerResponses(t *testing.T) { }) } +func Test_Client_ContextCanceledBeforeQuorumReached(t *testing.T) { + ctx, cancel := context.WithCancel(testutils.Context(t)) + + responseTest := func(t *testing.T, response commoncap.CapabilityResponse, responseError error) { + require.Error(t, responseError) + require.ErrorIs(t, responseError, executable.ErrContextDoneBeforeResponseQuorum) + } + + capability := &TestCapability{} + transmissionSchedule, err := values.NewMap(map[string]any{ + "schedule": transmission.Schedule_AllAtOnce, + "deltaStage": "20s", + }) + require.NoError(t, err) + + cancel() + testClient(t, 2, 20*time.Second, 2, 2, + capability, + func(caller commoncap.ExecutableCapability) { + executeInputs, err := values.NewMap(map[string]any{"executeValue1": "aValue1"}) + require.NoError(t, err) + executeMethod(ctx, caller, transmissionSchedule, executeInputs, responseTest, t) + }) +} + func testClient(t *testing.T, numWorkflowPeers int, workflowNodeResponseTimeout time.Duration, numCapabilityPeers int, capabilityDonF uint8, underlying commoncap.ExecutableCapability, method func(caller commoncap.ExecutableCapability)) {