Skip to content

Commit

Permalink
enhance: refine pular related mq interfaces
Browse files Browse the repository at this point in the history
Signed-off-by: tinswzy <[email protected]>
  • Loading branch information
tinswzy committed Nov 25, 2024
1 parent 27c22d1 commit b094b6e
Show file tree
Hide file tree
Showing 48 changed files with 367 additions and 355 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack {
return make(chan *msgstream.MsgPack, 100)
}

func (mtm *mockTtMsgStream) AsProducer(channels []string) {}
func (mtm *mockTtMsgStream) AsProducer(ctx context.Context, channels []string) {}

func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
return nil
Expand All @@ -80,11 +80,11 @@ func (mtm *mockTtMsgStream) GetProduceChannels() []string {
return make([]string, 0)
}

func (mtm *mockTtMsgStream) Produce(*msgstream.MsgPack) error {
func (mtm *mockTtMsgStream) Produce(ctx context.Context, *msgstream.MsgPack) error {
return nil
}

func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
func (mtm *mockTtMsgStream) Broadcast(ctx context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return nil, nil
}

Expand Down
18 changes: 9 additions & 9 deletions internal/proxy/channels_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import (
type channelsMgr interface {
getChannels(collectionID UniqueID) ([]pChan, error)
getVChannels(collectionID UniqueID) ([]vChan, error)
getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error)
getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error)
removeDMLStream(collectionID UniqueID)
removeAllDMLStream()
}
Expand Down Expand Up @@ -172,7 +172,7 @@ func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool
return ok && streamInfos.stream != nil
}

func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
func createStream(ctx context.Context, factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
var stream msgstream.MsgStream
var err error

Expand All @@ -181,7 +181,7 @@ func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncTy
return nil, err
}

stream.AsProducer(pchans)
stream.AsProducer(ctx, pchans)
if repack != nil {
stream.SetRepackFunc(repack)
}
Expand All @@ -202,7 +202,7 @@ func decPChanMetrics(pchans []pChan) {

// createMsgStream create message stream for specified collection. Idempotent.
// If stream already exists, directly return it and no error will be returned.
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstream.MsgStream, error) {
func (mgr *singleTypeChannelsMgr) createMsgStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.mu.RLock()
infos, ok := mgr.infos[collectionID]
if ok && infos.stream != nil {
Expand All @@ -219,7 +219,7 @@ func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstr
return nil, err
}

stream, err := createStream(mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
stream, err := createStream(ctx, mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
if err != nil {
// What if stream created by other goroutines?
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
Expand Down Expand Up @@ -253,12 +253,12 @@ func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstrea

// getOrCreateStream get message stream of specified collection.
// If stream doesn't exist, call createMsgStream to create for it.
func (mgr *singleTypeChannelsMgr) getOrCreateStream(collectionID UniqueID) (msgstream.MsgStream, error) {
func (mgr *singleTypeChannelsMgr) getOrCreateStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
if stream, err := mgr.lockGetStream(collectionID); err == nil {
return stream, nil
}

return mgr.createMsgStream(collectionID)
return mgr.createMsgStream(ctx, collectionID)
}

// removeStream remove the corresponding stream of the specified collection. Idempotent.
Expand Down Expand Up @@ -315,8 +315,8 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error)
return mgr.dmlChannelsMgr.getVChannels(collectionID)
}

func (mgr *channelsMgrImpl) getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) {
return mgr.dmlChannelsMgr.getOrCreateStream(collectionID)
func (mgr *channelsMgrImpl) getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
return mgr.dmlChannelsMgr.getOrCreateStream(ctx, collectionID)
}

func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) {
Expand Down
24 changes: 12 additions & 12 deletions internal/proxy/channels_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func Test_createStream(t *testing.T) {
factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(factory, nil, nil)
_, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err)
})

Expand All @@ -223,7 +223,7 @@ func Test_createStream(t *testing.T) {
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(factory, nil, nil)
_, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err)
})

Expand All @@ -232,7 +232,7 @@ func Test_createStream(t *testing.T) {
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
_, err := createStream(factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
_, err := createStream(context.TODO(), factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return nil, nil
})
assert.NoError(t, err)
Expand All @@ -247,7 +247,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
100: {stream: newMockMsgStream()},
},
}
stream, err := m.createMsgStream(100)
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
Expand Down Expand Up @@ -275,7 +275,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
stream, err := m.createMsgStream(100)
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
}()
Expand All @@ -295,7 +295,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.createMsgStream(100)
_, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err)
})

Expand All @@ -311,7 +311,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
msgStreamFactory: factory,
repackFunc: nil,
}
_, err := m.createMsgStream(100)
_, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err)
})

Expand All @@ -328,10 +328,10 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
msgStreamFactory: factory,
repackFunc: nil,
}
stream, err := m.createMsgStream(100)
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
stream, err = m.getOrCreateStream(100)
stream, err = m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
Expand Down Expand Up @@ -365,7 +365,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
100: {stream: newMockMsgStream()},
},
}
stream, err := m.getOrCreateStream(100)
stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
Expand All @@ -377,7 +377,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getOrCreateStream(100)
_, err := m.getOrCreateStream(context.TODO(), 100)
assert.Error(t, err)
})

Expand All @@ -394,7 +394,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
msgStreamFactory: factory,
repackFunc: nil,
}
stream, err := m.getOrCreateStream(100)
stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -6310,7 +6310,7 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
Status: merr.Status(err),
}, nil
}
messageIDsMap, err := msgStream.Broadcast(msgPack)
messageIDsMap, err := msgStream.Broadcast(ctx, msgPack)
if err != nil {
log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
Expand Down
10 changes: 5 additions & 5 deletions internal/proxy/impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})

Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
Expand Down Expand Up @@ -483,7 +483,7 @@ func TestProxy_FlushAll(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})

Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
Expand Down Expand Up @@ -955,7 +955,7 @@ func TestProxyCreateDatabase(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})

t.Run("create database fail", func(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t)
Expand Down Expand Up @@ -1015,7 +1015,7 @@ func TestProxyDropDatabase(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})

t.Run("drop database fail", func(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t)
Expand Down Expand Up @@ -1496,7 +1496,7 @@ func TestProxy_ReplicateMessage(t *testing.T) {
factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return()
msgStreamObj.EXPECT().AsProducer(mock.Anything).Return()
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return()
msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return()
msgStreamObj.EXPECT().Close().Return()
mockMsgID1 := mqcommon.NewMockMessageID(t)
Expand Down
31 changes: 17 additions & 14 deletions internal/proxy/mock_channels_manager.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/proxy/mock_msgstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type mockMsgStream struct {
enableProduce func(bool)
}

func (m *mockMsgStream) AsProducer(producers []string) {
func (m *mockMsgStream) AsProducer(ctx context.Context, producers []string) {
if m.asProducer != nil {
m.asProducer(producers)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/proxy/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
return ms.msgChan
}

func (ms *simpleMockMsgStream) AsProducer(channels []string) {
func (ms *simpleMockMsgStream) AsProducer(ctx context.Context, channels []string) {
}

func (ms *simpleMockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
Expand Down Expand Up @@ -283,15 +283,15 @@ func (ms *simpleMockMsgStream) decreaseMsgCount(delta int) {
ms.increaseMsgCount(-delta)
}

func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error {
func (ms *simpleMockMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
defer ms.increaseMsgCount(1)

ms.msgChan <- pack

return nil
}

func (ms *simpleMockMsgStream) Broadcast(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
func (ms *simpleMockMsgStream) Broadcast(ctx context.Context, pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return map[string][]msgstream.MessageID{}, nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (node *Proxy) Init() error {
return err
}
node.replicateMsgStream.EnableProduce(true)
node.replicateMsgStream.AsProducer([]string{replicateMsgChannel})
node.replicateMsgStream.AsProducer(node.ctx, []string{replicateMsgChannel})

node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator, node.factory)
if err != nil {
Expand Down
Loading

0 comments on commit b094b6e

Please sign in to comment.