Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context propogation to subscribers for local pub-sub #487

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,11 @@ func (m *Message) Copy() *Message {
}
return msg
}

// CopyWithContext copies all message without Acks/Nacks.
// The context is also propagated to the copy.
func (m *Message) CopyWithContext() *Message {
msg := m.Copy()
msg.ctx = m.ctx
return msg
}
32 changes: 32 additions & 0 deletions message/message_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package message_test

import (
"context"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -113,6 +114,37 @@ func TestMessage_Copy(t *testing.T) {
assert.True(t, msg.Equals(msgCopy))
}

func TestMessage_CopyWithContext(t *testing.T) {
msg := message.NewMessage("1", []byte("foo"))
testCtx := context.Background()
testCtx = context.WithValue(testCtx, "foo", "bar")
msg.SetContext(testCtx)

msgCopy := msg.CopyWithContext()
testCtx = context.WithValue(testCtx, "foo", "baz")
testCtx = context.WithValue(testCtx, "abc", "def")

copyMsgCtx := msgCopy.Context()
assert.True(t, copyMsgCtx.Value("foo") == "bar", "expected context not being copied")
assert.False(t, copyMsgCtx.Value("abc") == "def", "non-expected context being copied")
assert.True(t, msg.Equals(msgCopy))
}

func TestMessage_CopyWithContextAndMetadata(t *testing.T) {
msg := message.NewMessage("1", []byte("foo"))
testCtx := context.Background()
testCtx = context.WithValue(testCtx, "foo", "bar")
msg.SetContext(testCtx)
msg.Metadata.Set("foo", "bar")
msgCopy := msg.CopyWithContext()

msg.Metadata.Set("foo", "baz")

copyMsgCtx := msgCopy.Context()
assert.True(t, copyMsgCtx.Value("foo") == "bar", "expected context not being copied")
assert.Equal(t, msgCopy.Metadata.Get("foo"), "bar", "did not expect changing source message's metadata to alter copy's metadata")
}

func TestMessage_CopyMetadata(t *testing.T) {
msg := message.NewMessage("1", []byte("foo"))
msg.Metadata.Set("foo", "bar")
Expand Down
2 changes: 1 addition & 1 deletion message/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type Subscriber interface {
// If message processing fails and the message should be redelivered `Nack()` should be called instead.
//
// When the provided ctx is canceled, the subscriber closes the subscription and the output channel.
// The provided ctx is passed to all produced messages.
// The provided ctx is passed to all produced messages (this is configurable for the local Pub/Sub implementation).
// When Nack or Ack is called on the message, the context of the message is canceled.
Subscribe(ctx context.Context, topic string) (<-chan *Message, error)
// Close closes all subscriptions with their output channels and flushes offsets etc. when needed.
Expand Down
34 changes: 26 additions & 8 deletions pubsub/gochannel/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ type Config struct {
// When true, Publish will block until subscriber Ack's the message.
// If there are no subscribers, Publish will not block (also when Persistent is true).
BlockPublishUntilSubscriberAck bool

// PreserveContext is a flag that determines if the context should be preserved when sending messages to subscribers.
// This behavior is different from other implementations of Publishers where data travels over the network,
// hence context can't be preserved in those cases
PreserveContext bool
}

// GoChannel is the simplest Pub/Sub implementation.
Expand Down Expand Up @@ -87,7 +92,11 @@ func (g *GoChannel) Publish(topic string, messages ...*message.Message) error {

messagesToPublish := make(message.Messages, len(messages))
for i, msg := range messages {
messagesToPublish[i] = msg.Copy()
if g.config.PreserveContext {
messagesToPublish[i] = msg.CopyWithContext()
} else {
messagesToPublish[i] = msg.Copy()
}
}

g.subscribersLock.RLock()
Expand Down Expand Up @@ -187,11 +196,12 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag
subLock.(*sync.Mutex).Lock()

s := &subscriber{
ctx: ctx,
uuid: watermill.NewUUID(),
outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer),
logger: g.logger,
closing: make(chan struct{}),
ctx: ctx,
uuid: watermill.NewUUID(),
outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer),
logger: g.logger,
closing: make(chan struct{}),
preserveContext: g.config.PreserveContext,
}

go func(s *subscriber, g *GoChannel) {
Expand Down Expand Up @@ -320,6 +330,8 @@ type subscriber struct {
logger watermill.LoggerAdapter
closed bool
closing chan struct{}

preserveContext bool
}

func (s *subscriber) Close() {
Expand All @@ -344,8 +356,14 @@ func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields wat
s.sending.Lock()
defer s.sending.Unlock()

ctx, cancelCtx := context.WithCancel(s.ctx)
defer cancelCtx()
ctx := msg.Context()

//This is getting the context from the message, not the subscriber
if !s.preserveContext {
var cancelCtx context.CancelFunc
ctx, cancelCtx = context.WithCancel(s.ctx)
defer cancelCtx()
}

SendToSubscriber:
for {
Expand Down
53 changes: 53 additions & 0 deletions pubsub/gochannel/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ func createPersistentPubSub(t *testing.T) (message.Publisher, message.Subscriber
return pubSub, pubSub
}

func createPersistentPubSubWithContextPreserved(t *testing.T) (message.Publisher, message.Subscriber) {
pubSub := gochannel.NewGoChannel(
gochannel.Config{
OutputChannelBuffer: 10000,
Persistent: true,
PreserveContext: true,
},
watermill.NewStdLogger(true, true),
)
return pubSub, pubSub
}

func TestPublishSubscribe_persistent(t *testing.T) {
tests.TestPubSub(
t,
Expand All @@ -44,6 +56,22 @@ func TestPublishSubscribe_persistent(t *testing.T) {
)
}

func TestPublishSubscribe_context_preserved(t *testing.T) {
tests.TestPubSub(
t,
tests.Features{
ConsumerGroups: false,
ExactlyOnceDelivery: true,
GuaranteedOrder: false,
Persistent: false,
RequireSingleInstance: true,
ContextPreserved: true,
},
createPersistentPubSubWithContextPreserved,
nil,
)
}

func TestPublishSubscribe_not_persistent(t *testing.T) {
messagesCount := 100
pubSub := gochannel.NewGoChannel(
Expand All @@ -63,6 +91,31 @@ func TestPublishSubscribe_not_persistent(t *testing.T) {
assert.NoError(t, pubSub.Close())
}

func TestPublishSubscribe_not_persistent_with_context(t *testing.T) {
messagesCount := 100
pubSub := gochannel.NewGoChannel(
gochannel.Config{OutputChannelBuffer: int64(messagesCount), PreserveContext: true},
watermill.NewStdLogger(true, true),
)
topicName := "test_topic_" + watermill.NewUUID()

msgs, err := pubSub.Subscribe(context.Background(), topicName)
require.NoError(t, err)

const contextKeyString = "foo"
sendMessages := tests.PublishSimpleMessagesWithContext(t, messagesCount, contextKeyString, pubSub, topicName)
receivedMsgs, _ := subscriber.BulkRead(msgs, messagesCount, time.Second)

expectedContexts := make(map[string]context.Context)
for _, msg := range sendMessages {
expectedContexts[msg.UUID] = msg.Context()
}
tests.AssertAllMessagesReceived(t, sendMessages, receivedMsgs)
tests.AssertAllMessagesHaveSameContext(t, contextKeyString, expectedContexts, receivedMsgs)

assert.NoError(t, pubSub.Close())
}

func TestPublishSubscribe_block_until_ack(t *testing.T) {
pubSub := gochannel.NewGoChannel(
gochannel.Config{BlockPublishUntilSubscriberAck: true},
Expand Down
11 changes: 11 additions & 0 deletions pubsub/tests/test_asserts.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tests

import (
"context"
"sort"
"testing"

Expand Down Expand Up @@ -92,3 +93,13 @@ func AssertMessagesMetadata(t *testing.T, key string, expectedValues map[string]

return ok
}

// AssertAllMessagesHaveSameContext checks if context of all received messages is the same as in expectedValues, if PreserveContext is enabled.
func AssertAllMessagesHaveSameContext(t *testing.T, contextKeyString string, expectedValues map[string]context.Context, received []*message.Message) {
assert.Len(t, received, len(expectedValues))
for _, msg := range received {
expectedValue := expectedValues[msg.UUID].Value(contextKey(contextKeyString)).(string)
actualValue := msg.Context().Value(contextKey(contextKeyString))
assert.Equal(t, expectedValue, actualValue)
}
}
47 changes: 43 additions & 4 deletions pubsub/tests/test_pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ type Features struct {

// GenerateTopicFunc overrides standard topic name generation.
GenerateTopicFunc func(tctx TestContext) string

// ContextPreserved should be set to true if the Pub/Sub implementation preserves the context
// of the message when it's published and consumed.
ContextPreserved bool
}

// RunOnlyFastTests returns true if -short flag was provided -race was not provided.
Expand Down Expand Up @@ -993,7 +997,14 @@ func TestSubscribeCtx(
if subscribeInitializer, ok := sub.(message.SubscribeInitializer); ok {
require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName))
}
publishedMessages := PublishSimpleMessages(t, messagesCount, pub, topicName)

var publishedMessages message.Messages
var contextKeyString = "abc"
if tCtx.Features.ContextPreserved {
publishedMessages = PublishSimpleMessagesWithContext(t, messagesCount, contextKeyString, pub, topicName)
} else {
publishedMessages = PublishSimpleMessages(t, messagesCount, pub, topicName)
}

msgsToCancel, err := sub.Subscribe(ctxWithCancel, topicName)
require.NoError(t, err)
Expand All @@ -1017,14 +1028,24 @@ ClosedLoop:
}

ctx := context.WithValue(context.Background(), contextKey("foo"), "bar")

// For mocking the output of pub-subs where context is preserved vs not preserved
expectedContexts := make(map[string]context.Context)
for _, msg := range publishedMessages {
if tCtx.Features.ContextPreserved {
expectedContexts[msg.UUID] = msg.Context()
} else {
expectedContexts[msg.UUID] = ctx
}
}

msgs, err := sub.Subscribe(ctx, topicName)
require.NoError(t, err)

receivedMessages, _ := bulkRead(tCtx, msgs, messagesCount, defaultTimeout)
AssertAllMessagesReceived(t, publishedMessages, receivedMessages)

for _, msg := range receivedMessages {
assert.EqualValues(t, "bar", msg.Context().Value(contextKey("foo")))
if tCtx.Features.ContextPreserved {
AssertAllMessagesHaveSameContext(t, contextKeyString, expectedContexts, receivedMessages)
}
}

Expand Down Expand Up @@ -1271,6 +1292,24 @@ func PublishSimpleMessages(t *testing.T, messagesCount int, publisher message.Pu
return messagesToPublish
}

// PublishSimpleMessagesWithContext publishes provided number of simple messages without a payload, but custom context
func PublishSimpleMessagesWithContext(t *testing.T, messagesCount int, contextKeyString string, publisher message.Publisher, topicName string) message.Messages {
var messagesToPublish []*message.Message

for i := 0; i < messagesCount; i++ {
id := watermill.NewUUID()

msg := message.NewMessage(id, nil)
msg.SetContext(context.WithValue(context.Background(), contextKey(contextKeyString), "bar"+strconv.Itoa(i)))
messagesToPublish = append(messagesToPublish, msg)

err := publishWithRetry(publisher, topicName, msg)
require.NoError(t, err, "cannot publish messages")
}

return messagesToPublish
}

// AddSimpleMessagesParallel publishes provided number of simple messages without a payload
// using the provided number of publishers (goroutines).
func AddSimpleMessagesParallel(t *testing.T, messagesCount int, publisher message.Publisher, topicName string, publishers int) message.Messages {
Expand Down