From b1eea4ae61ba1ab1933a39f3e2992331e57c42df Mon Sep 17 00:00:00 2001 From: Israel Miller Date: Thu, 17 Aug 2023 18:36:20 -0400 Subject: [PATCH] fix: respecting context cancelation since sdk does not (#153) * fix: respecting context cancelation since sdk does not * fix: typo in error wrap * adding missing comment --- v2/sender.go | 36 ++++++++++++++++++++++++++++++++---- v2/sender_test.go | 29 ++++++++++++++++++++++++++--- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/v2/sender.go b/v2/sender.go index 98d63de4..acf95e8e 100644 --- a/v2/sender.go +++ b/v2/sender.go @@ -65,10 +65,23 @@ func (d *Sender) SendMessage(ctx context.Context, mb MessageBody, options ...fun ctx, cancel = context.WithTimeout(ctx, d.options.SendTimeout) defer cancel() } - if err := d.sbSender.SendMessage(ctx, msg, nil); err != nil { // sendMessageOptions currently does nothing - return fmt.Errorf("failed to send message: %w", err) + + errChan := make(chan error) + + go func() { + if err := d.sbSender.SendMessage(ctx, msg, nil); err != nil { // sendMessageOptions currently does nothing + errChan <- fmt.Errorf("failed to send message: %w", err) + } + errChan <- nil + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("failed to send message: %w", ctx.Err()) + case err := <-errChan: + return err } - return nil + } // ToServiceBusMessage transform a MessageBody into an azservicebus.Message. @@ -120,7 +133,22 @@ func (d *Sender) SendMessageBatch(ctx context.Context, messages []*azservicebus. return fmt.Errorf("failed to send message batch: %w", err) } - return nil + errChan := make(chan error) + + go func() { + if err := d.sbSender.SendMessageBatch(ctx, batch, nil); err != nil { + errChan <- fmt.Errorf("failed to send message batch: %w", err) + } + errChan <- nil + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("failed to send message batch: %w", ctx.Err()) + case err := <-errChan: + return err + } + } // AzSender returns the underlying azservicebus.Sender instance. diff --git a/v2/sender_test.go b/v2/sender_test.go index 55858aaa..f2d00935 100644 --- a/v2/sender_test.go +++ b/v2/sender_test.go @@ -7,12 +7,11 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "go.opentelemetry.io/otel/sdk/trace" - . "github.com/onsi/gomega" - + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" + . "github.com/onsi/gomega" ) func TestFunc_NewSender(t *testing.T) { @@ -155,6 +154,30 @@ func TestSender_WithSendTimeout(t *testing.T) { g.Expect(err).ToNot(HaveOccurred()) } +func TestSender_WithContextCanceled(t *testing.T) { + g := NewWithT(t) + sendTimeout := 1 * time.Second + azSender := &fakeAzSender{ + DoSendMessage: func(ctx context.Context, message *azservicebus.Message, options *azservicebus.SendMessageOptions) error { + time.Sleep(2 * time.Second) + return nil + }, + DoSendMessageBatch: func(ctx context.Context, messages *azservicebus.MessageBatch, options *azservicebus.SendMessageBatchOptions) error { + time.Sleep(2 * time.Second) + return nil + }, + } + sender := NewSender(azSender, &SenderOptions{ + Marshaller: &DefaultJSONMarshaller{}, + SendTimeout: sendTimeout, + }) + + err := sender.SendMessage(context.Background(), "test") + g.Expect(err).To(MatchError(context.DeadlineExceeded)) + err = sender.SendMessageBatch(context.Background(), nil) + g.Expect(err).To(MatchError(context.DeadlineExceeded)) +} + func TestSender_DisabledSendTimeout(t *testing.T) { g := NewWithT(t) sendTimeout := -1 * time.Second