diff --git a/cmd/milvus/util.go b/cmd/milvus/util.go index 35068a6d320dc..baa94db368b12 100644 --- a/cmd/milvus/util.go +++ b/cmd/milvus/util.go @@ -20,6 +20,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/cmd/roles" + "github.com/milvus-io/milvus/internal/coordinator/coordclient" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" @@ -171,7 +172,12 @@ func GetMilvusRoles(args []string, flags *flag.FlagSet) *roles.MilvusRoles { fmt.Fprintf(os.Stderr, "Unknown server type = %s\n%s", serverType, getHelp()) os.Exit(-1) } - + coordclient.EnableLocalClientRole(&coordclient.LocalClientRoleConfig{ + ServerType: serverType, + EnableQueryCoord: role.EnableQueryCoord, + EnableDataCoord: role.EnableDataCoord, + EnableRootCoord: role.EnableRootCoord, + }) return role } diff --git a/internal/coordinator/coordclient/datacoord.go b/internal/coordinator/coordclient/datacoord.go new file mode 100644 index 0000000000000..ba2cb99cb16e3 --- /dev/null +++ b/internal/coordinator/coordclient/datacoord.go @@ -0,0 +1,418 @@ +package coordclient + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +var _ types.DataCoordClient = &dataCoordLocalClientImpl{} + +// newDataCoordLocalClient creates a new local client for data coordinator server. +func newDataCoordLocalClient() *dataCoordLocalClientImpl { + return &dataCoordLocalClientImpl{ + localDataCoordServer: syncutil.NewFuture[datapb.DataCoordServer](), + } +} + +// dataCoordLocalClientImpl is used to implement a local client for data coordinator server. +// We need to merge all the coordinator into one server, so use those client to erase the rpc layer between different coord. +type dataCoordLocalClientImpl struct { + localDataCoordServer *syncutil.Future[datapb.DataCoordServer] +} + +func (c *dataCoordLocalClientImpl) setReadyServer(server datapb.DataCoordServer) { + c.localDataCoordServer.Set(server) +} + +func (c *dataCoordLocalClientImpl) waitForReady(ctx context.Context) (datapb.DataCoordServer, error) { + return c.localDataCoordServer.GetWithContext(ctx) +} + +func (c *dataCoordLocalClientImpl) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetComponentStates(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetTimeTickChannel(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetStatisticsChannel(ctx, in) +} + +func (c *dataCoordLocalClientImpl) Flush(ctx context.Context, in *datapb.FlushRequest, opts ...grpc.CallOption) (*datapb.FlushResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.Flush(ctx, in) +} + +func (c *dataCoordLocalClientImpl) AssignSegmentID(ctx context.Context, in *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.AssignSegmentID(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetSegmentInfo(ctx context.Context, in *datapb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*datapb.GetSegmentInfoResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetSegmentInfo(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetSegmentStates(ctx context.Context, in *datapb.GetSegmentStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentStatesResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetSegmentStates(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetInsertBinlogPaths(ctx context.Context, in *datapb.GetInsertBinlogPathsRequest, opts ...grpc.CallOption) (*datapb.GetInsertBinlogPathsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetInsertBinlogPaths(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetCollectionStatistics(ctx context.Context, in *datapb.GetCollectionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetCollectionStatisticsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetCollectionStatistics(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetPartitionStatistics(ctx context.Context, in *datapb.GetPartitionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetPartitionStatisticsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetPartitionStatistics(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetSegmentInfoChannel(ctx context.Context, in *datapb.GetSegmentInfoChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetSegmentInfoChannel(ctx, in) +} + +func (c *dataCoordLocalClientImpl) SaveBinlogPaths(ctx context.Context, in *datapb.SaveBinlogPathsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.SaveBinlogPaths(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetRecoveryInfo(ctx context.Context, in *datapb.GetRecoveryInfoRequest, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetRecoveryInfo(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetRecoveryInfoV2(ctx context.Context, in *datapb.GetRecoveryInfoRequestV2, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponseV2, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetRecoveryInfoV2(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetFlushedSegments(ctx context.Context, in *datapb.GetFlushedSegmentsRequest, opts ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetFlushedSegments(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetSegmentsByStates(ctx context.Context, in *datapb.GetSegmentsByStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentsByStatesResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetSegmentsByStates(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetFlushAllState(ctx context.Context, in *milvuspb.GetFlushAllStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushAllStateResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetFlushAllState(ctx, in) +} + +func (c *dataCoordLocalClientImpl) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ShowConfigurations(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetMetrics(ctx, in) +} + +func (c *dataCoordLocalClientImpl) ManualCompaction(ctx context.Context, in *milvuspb.ManualCompactionRequest, opts ...grpc.CallOption) (*milvuspb.ManualCompactionResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ManualCompaction(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetCompactionState(ctx context.Context, in *milvuspb.GetCompactionStateRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionStateResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetCompactionState(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetCompactionStateWithPlans(ctx context.Context, in *milvuspb.GetCompactionPlansRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionPlansResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetCompactionStateWithPlans(ctx, in) +} + +func (c *dataCoordLocalClientImpl) WatchChannels(ctx context.Context, in *datapb.WatchChannelsRequest, opts ...grpc.CallOption) (*datapb.WatchChannelsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.WatchChannels(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetFlushState(ctx context.Context, in *datapb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetFlushState(ctx, in) +} + +func (c *dataCoordLocalClientImpl) DropVirtualChannel(ctx context.Context, in *datapb.DropVirtualChannelRequest, opts ...grpc.CallOption) (*datapb.DropVirtualChannelResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DropVirtualChannel(ctx, in) +} + +func (c *dataCoordLocalClientImpl) SetSegmentState(ctx context.Context, in *datapb.SetSegmentStateRequest, opts ...grpc.CallOption) (*datapb.SetSegmentStateResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.SetSegmentState(ctx, in) +} + +func (c *dataCoordLocalClientImpl) UpdateSegmentStatistics(ctx context.Context, in *datapb.UpdateSegmentStatisticsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.UpdateSegmentStatistics(ctx, in) +} + +func (c *dataCoordLocalClientImpl) UpdateChannelCheckpoint(ctx context.Context, in *datapb.UpdateChannelCheckpointRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.UpdateChannelCheckpoint(ctx, in) +} + +func (c *dataCoordLocalClientImpl) MarkSegmentsDropped(ctx context.Context, in *datapb.MarkSegmentsDroppedRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.MarkSegmentsDropped(ctx, in) +} + +func (c *dataCoordLocalClientImpl) BroadcastAlteredCollection(ctx context.Context, in *datapb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.BroadcastAlteredCollection(ctx, in) +} + +func (c *dataCoordLocalClientImpl) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CheckHealth(ctx, in) +} + +func (c *dataCoordLocalClientImpl) CreateIndex(ctx context.Context, in *indexpb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CreateIndex(ctx, in) +} + +func (c *dataCoordLocalClientImpl) AlterIndex(ctx context.Context, in *indexpb.AlterIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.AlterIndex(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetIndexState(ctx context.Context, in *indexpb.GetIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetIndexState(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetSegmentIndexState(ctx context.Context, in *indexpb.GetSegmentIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetSegmentIndexStateResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetSegmentIndexState(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetIndexInfos(ctx context.Context, in *indexpb.GetIndexInfoRequest, opts ...grpc.CallOption) (*indexpb.GetIndexInfoResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetIndexInfos(ctx, in) +} + +func (c *dataCoordLocalClientImpl) DropIndex(ctx context.Context, in *indexpb.DropIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DropIndex(ctx, in) +} + +func (c *dataCoordLocalClientImpl) DescribeIndex(ctx context.Context, in *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DescribeIndex(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetIndexStatistics(ctx context.Context, in *indexpb.GetIndexStatisticsRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStatisticsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetIndexStatistics(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetIndexBuildProgress(ctx context.Context, in *indexpb.GetIndexBuildProgressRequest, opts ...grpc.CallOption) (*indexpb.GetIndexBuildProgressResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetIndexBuildProgress(ctx, in) +} + +func (c *dataCoordLocalClientImpl) ListIndexes(ctx context.Context, in *indexpb.ListIndexesRequest, opts ...grpc.CallOption) (*indexpb.ListIndexesResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ListIndexes(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GcConfirm(ctx context.Context, in *datapb.GcConfirmRequest, opts ...grpc.CallOption) (*datapb.GcConfirmResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GcConfirm(ctx, in) +} + +func (c *dataCoordLocalClientImpl) ReportDataNodeTtMsgs(ctx context.Context, in *datapb.ReportDataNodeTtMsgsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ReportDataNodeTtMsgs(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GcControl(ctx context.Context, in *datapb.GcControlRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GcControl(ctx, in) +} + +func (c *dataCoordLocalClientImpl) ImportV2(ctx context.Context, in *internalpb.ImportRequestInternal, opts ...grpc.CallOption) (*internalpb.ImportResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ImportV2(ctx, in) +} + +func (c *dataCoordLocalClientImpl) GetImportProgress(ctx context.Context, in *internalpb.GetImportProgressRequest, opts ...grpc.CallOption) (*internalpb.GetImportProgressResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetImportProgress(ctx, in) +} + +func (c *dataCoordLocalClientImpl) ListImports(ctx context.Context, in *internalpb.ListImportsRequestInternal, opts ...grpc.CallOption) (*internalpb.ListImportsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ListImports(ctx, in) +} + +func (c *dataCoordLocalClientImpl) Close() error { + return nil +} diff --git a/internal/coordinator/coordclient/datacoord_test.go b/internal/coordinator/coordclient/datacoord_test.go new file mode 100644 index 0000000000000..6b42f4135e8be --- /dev/null +++ b/internal/coordinator/coordclient/datacoord_test.go @@ -0,0 +1,308 @@ +package coordclient + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" +) + +func TestDataCoordLocalClient(t *testing.T) { + c := newDataCoordLocalClient() + c.setReadyServer(datapb.UnimplementedDataCoordServer{}) + + _, err := c.GetComponentStates(context.Background(), &milvuspb.GetComponentStatesRequest{}) + assert.Error(t, err) + + _, err = c.GetTimeTickChannel(context.Background(), &internalpb.GetTimeTickChannelRequest{}) + assert.Error(t, err) + + _, err = c.GetStatisticsChannel(context.Background(), &internalpb.GetStatisticsChannelRequest{}) + assert.Error(t, err) + + _, err = c.Flush(context.Background(), &datapb.FlushRequest{}) + assert.Error(t, err) + + _, err = c.AssignSegmentID(context.Background(), &datapb.AssignSegmentIDRequest{}) + assert.Error(t, err) + + _, err = c.GetSegmentInfo(context.Background(), &datapb.GetSegmentInfoRequest{}) + assert.Error(t, err) + + _, err = c.GetSegmentStates(context.Background(), &datapb.GetSegmentStatesRequest{}) + assert.Error(t, err) + + _, err = c.GetInsertBinlogPaths(context.Background(), &datapb.GetInsertBinlogPathsRequest{}) + assert.Error(t, err) + + _, err = c.GetCollectionStatistics(context.Background(), &datapb.GetCollectionStatisticsRequest{}) + assert.Error(t, err) + + _, err = c.GetPartitionStatistics(context.Background(), &datapb.GetPartitionStatisticsRequest{}) + assert.Error(t, err) + + _, err = c.GetSegmentInfoChannel(context.Background(), &datapb.GetSegmentInfoChannelRequest{}) + assert.Error(t, err) + + _, err = c.SaveBinlogPaths(context.Background(), &datapb.SaveBinlogPathsRequest{}) + assert.Error(t, err) + + _, err = c.GetRecoveryInfo(context.Background(), &datapb.GetRecoveryInfoRequest{}) + assert.Error(t, err) + + _, err = c.GetRecoveryInfoV2(context.Background(), &datapb.GetRecoveryInfoRequestV2{}) + assert.Error(t, err) + + _, err = c.GetFlushedSegments(context.Background(), &datapb.GetFlushedSegmentsRequest{}) + assert.Error(t, err) + + _, err = c.GetSegmentsByStates(context.Background(), &datapb.GetSegmentsByStatesRequest{}) + assert.Error(t, err) + + _, err = c.GetFlushAllState(context.Background(), &milvuspb.GetFlushAllStateRequest{}) + assert.Error(t, err) + + _, err = c.ShowConfigurations(context.Background(), &internalpb.ShowConfigurationsRequest{}) + assert.Error(t, err) + + _, err = c.GetMetrics(context.Background(), &milvuspb.GetMetricsRequest{}) + assert.Error(t, err) + + _, err = c.ManualCompaction(context.Background(), &milvuspb.ManualCompactionRequest{}) + assert.Error(t, err) + + _, err = c.GetCompactionState(context.Background(), &milvuspb.GetCompactionStateRequest{}) + assert.Error(t, err) + + _, err = c.GetCompactionStateWithPlans(context.Background(), &milvuspb.GetCompactionPlansRequest{}) + assert.Error(t, err) + + _, err = c.WatchChannels(context.Background(), &datapb.WatchChannelsRequest{}) + assert.Error(t, err) + + _, err = c.GetFlushState(context.Background(), &datapb.GetFlushStateRequest{}) + assert.Error(t, err) + + _, err = c.DropVirtualChannel(context.Background(), &datapb.DropVirtualChannelRequest{}) + assert.Error(t, err) + + _, err = c.SetSegmentState(context.Background(), &datapb.SetSegmentStateRequest{}) + assert.Error(t, err) + + _, err = c.UpdateSegmentStatistics(context.Background(), &datapb.UpdateSegmentStatisticsRequest{}) + assert.Error(t, err) + + _, err = c.UpdateChannelCheckpoint(context.Background(), &datapb.UpdateChannelCheckpointRequest{}) + assert.Error(t, err) + + _, err = c.MarkSegmentsDropped(context.Background(), &datapb.MarkSegmentsDroppedRequest{}) + assert.Error(t, err) + + _, err = c.BroadcastAlteredCollection(context.Background(), &datapb.AlterCollectionRequest{}) + assert.Error(t, err) + + _, err = c.CheckHealth(context.Background(), &milvuspb.CheckHealthRequest{}) + assert.Error(t, err) + + _, err = c.CreateIndex(context.Background(), &indexpb.CreateIndexRequest{}) + assert.Error(t, err) + + _, err = c.AlterIndex(context.Background(), &indexpb.AlterIndexRequest{}) + assert.Error(t, err) + + _, err = c.GetIndexState(context.Background(), &indexpb.GetIndexStateRequest{}) + assert.Error(t, err) + + _, err = c.GetSegmentIndexState(context.Background(), &indexpb.GetSegmentIndexStateRequest{}) + assert.Error(t, err) + + _, err = c.GetIndexInfos(context.Background(), &indexpb.GetIndexInfoRequest{}) + assert.Error(t, err) + + _, err = c.DropIndex(context.Background(), &indexpb.DropIndexRequest{}) + assert.Error(t, err) + + _, err = c.DescribeIndex(context.Background(), &indexpb.DescribeIndexRequest{}) + assert.Error(t, err) + + _, err = c.GetIndexStatistics(context.Background(), &indexpb.GetIndexStatisticsRequest{}) + assert.Error(t, err) + + _, err = c.GetIndexBuildProgress(context.Background(), &indexpb.GetIndexBuildProgressRequest{}) + assert.Error(t, err) + + _, err = c.ListIndexes(context.Background(), &indexpb.ListIndexesRequest{}) + assert.Error(t, err) + + _, err = c.GcConfirm(context.Background(), &datapb.GcConfirmRequest{}) + assert.Error(t, err) + + _, err = c.ReportDataNodeTtMsgs(context.Background(), &datapb.ReportDataNodeTtMsgsRequest{}) + assert.Error(t, err) + + _, err = c.GcControl(context.Background(), &datapb.GcControlRequest{}) + assert.Error(t, err) + + _, err = c.ImportV2(context.Background(), &internalpb.ImportRequestInternal{}) + assert.Error(t, err) + + _, err = c.GetImportProgress(context.Background(), &internalpb.GetImportProgressRequest{}) + assert.Error(t, err) + + _, err = c.ListImports(context.Background(), &internalpb.ListImportsRequestInternal{}) + assert.Error(t, err) +} + +func TestDataCoordLocalClientWithTimeout(t *testing.T) { + c := newDataCoordLocalClient() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := c.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + assert.Error(t, err) + + _, err = c.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + assert.Error(t, err) + + _, err = c.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + assert.Error(t, err) + + _, err = c.Flush(ctx, &datapb.FlushRequest{}) + assert.Error(t, err) + + _, err = c.AssignSegmentID(ctx, &datapb.AssignSegmentIDRequest{}) + assert.Error(t, err) + + _, err = c.GetSegmentInfo(ctx, &datapb.GetSegmentInfoRequest{}) + assert.Error(t, err) + + _, err = c.GetSegmentStates(ctx, &datapb.GetSegmentStatesRequest{}) + assert.Error(t, err) + + _, err = c.GetInsertBinlogPaths(ctx, &datapb.GetInsertBinlogPathsRequest{}) + assert.Error(t, err) + + _, err = c.GetCollectionStatistics(ctx, &datapb.GetCollectionStatisticsRequest{}) + assert.Error(t, err) + + _, err = c.GetPartitionStatistics(ctx, &datapb.GetPartitionStatisticsRequest{}) + assert.Error(t, err) + + _, err = c.GetSegmentInfoChannel(ctx, &datapb.GetSegmentInfoChannelRequest{}) + assert.Error(t, err) + + _, err = c.SaveBinlogPaths(ctx, &datapb.SaveBinlogPathsRequest{}) + assert.Error(t, err) + + _, err = c.GetRecoveryInfo(ctx, &datapb.GetRecoveryInfoRequest{}) + assert.Error(t, err) + + _, err = c.GetRecoveryInfoV2(ctx, &datapb.GetRecoveryInfoRequestV2{}) + assert.Error(t, err) + + _, err = c.GetFlushedSegments(ctx, &datapb.GetFlushedSegmentsRequest{}) + assert.Error(t, err) + + _, err = c.GetSegmentsByStates(ctx, &datapb.GetSegmentsByStatesRequest{}) + assert.Error(t, err) + + _, err = c.GetFlushAllState(ctx, &milvuspb.GetFlushAllStateRequest{}) + assert.Error(t, err) + + _, err = c.ShowConfigurations(ctx, &internalpb.ShowConfigurationsRequest{}) + assert.Error(t, err) + + _, err = c.GetMetrics(ctx, &milvuspb.GetMetricsRequest{}) + assert.Error(t, err) + + _, err = c.ManualCompaction(ctx, &milvuspb.ManualCompactionRequest{}) + assert.Error(t, err) + + _, err = c.GetCompactionState(ctx, &milvuspb.GetCompactionStateRequest{}) + assert.Error(t, err) + + _, err = c.GetCompactionStateWithPlans(ctx, &milvuspb.GetCompactionPlansRequest{}) + assert.Error(t, err) + + _, err = c.WatchChannels(ctx, &datapb.WatchChannelsRequest{}) + assert.Error(t, err) + + _, err = c.GetFlushState(ctx, &datapb.GetFlushStateRequest{}) + assert.Error(t, err) + + _, err = c.DropVirtualChannel(ctx, &datapb.DropVirtualChannelRequest{}) + assert.Error(t, err) + + _, err = c.SetSegmentState(ctx, &datapb.SetSegmentStateRequest{}) + assert.Error(t, err) + + _, err = c.UpdateSegmentStatistics(ctx, &datapb.UpdateSegmentStatisticsRequest{}) + assert.Error(t, err) + + _, err = c.UpdateChannelCheckpoint(ctx, &datapb.UpdateChannelCheckpointRequest{}) + assert.Error(t, err) + + _, err = c.MarkSegmentsDropped(ctx, &datapb.MarkSegmentsDroppedRequest{}) + assert.Error(t, err) + + _, err = c.BroadcastAlteredCollection(ctx, &datapb.AlterCollectionRequest{}) + assert.Error(t, err) + + _, err = c.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) + assert.Error(t, err) + + _, err = c.CreateIndex(ctx, &indexpb.CreateIndexRequest{}) + assert.Error(t, err) + + _, err = c.AlterIndex(ctx, &indexpb.AlterIndexRequest{}) + assert.Error(t, err) + + _, err = c.GetIndexState(ctx, &indexpb.GetIndexStateRequest{}) + assert.Error(t, err) + + _, err = c.GetSegmentIndexState(ctx, &indexpb.GetSegmentIndexStateRequest{}) + assert.Error(t, err) + + _, err = c.GetIndexInfos(ctx, &indexpb.GetIndexInfoRequest{}) + assert.Error(t, err) + + _, err = c.DropIndex(ctx, &indexpb.DropIndexRequest{}) + assert.Error(t, err) + + _, err = c.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{}) + assert.Error(t, err) + + _, err = c.GetIndexStatistics(ctx, &indexpb.GetIndexStatisticsRequest{}) + assert.Error(t, err) + + _, err = c.GetIndexBuildProgress(ctx, &indexpb.GetIndexBuildProgressRequest{}) + assert.Error(t, err) + + _, err = c.ListIndexes(ctx, &indexpb.ListIndexesRequest{}) + assert.Error(t, err) + + _, err = c.GcConfirm(ctx, &datapb.GcConfirmRequest{}) + assert.Error(t, err) + + _, err = c.ReportDataNodeTtMsgs(ctx, &datapb.ReportDataNodeTtMsgsRequest{}) + assert.Error(t, err) + + _, err = c.GcControl(ctx, &datapb.GcControlRequest{}) + assert.Error(t, err) + + _, err = c.ImportV2(ctx, &internalpb.ImportRequestInternal{}) + assert.Error(t, err) + + _, err = c.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{}) + assert.Error(t, err) + + _, err = c.ListImports(ctx, &internalpb.ListImportsRequestInternal{}) + assert.Error(t, err) + + c.Close() +} diff --git a/internal/coordinator/coordclient/querycoord.go b/internal/coordinator/coordclient/querycoord.go new file mode 100644 index 0000000000000..a9dccda249ded --- /dev/null +++ b/internal/coordinator/coordclient/querycoord.go @@ -0,0 +1,345 @@ +package coordclient + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +var _ types.QueryCoordClient = &queryCoordLocalClientImpl{} + +// newQueryCoordLocalClient creates a new local client for query coordinator server. +func newQueryCoordLocalClient() *queryCoordLocalClientImpl { + return &queryCoordLocalClientImpl{ + localQueryCoordServer: syncutil.NewFuture[querypb.QueryCoordServer](), + } +} + +// queryCoordLocalClientImpl is used to implement a local client for query coordinator server. +// We need to merge all the coordinator into one server, so use those client to erase the rpc layer between different coord. +type queryCoordLocalClientImpl struct { + localQueryCoordServer *syncutil.Future[querypb.QueryCoordServer] +} + +func (c *queryCoordLocalClientImpl) setReadyServer(server querypb.QueryCoordServer) { + c.localQueryCoordServer.Set(server) +} + +func (c *queryCoordLocalClientImpl) waitForReady(ctx context.Context) (querypb.QueryCoordServer, error) { + return c.localQueryCoordServer.GetWithContext(ctx) +} + +func (c *queryCoordLocalClientImpl) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetComponentStates(ctx, in) +} + +func (c *queryCoordLocalClientImpl) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetTimeTickChannel(ctx, in) +} + +func (c *queryCoordLocalClientImpl) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetStatisticsChannel(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ShowCollections(ctx context.Context, in *querypb.ShowCollectionsRequest, opts ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ShowCollections(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ShowPartitions(ctx context.Context, in *querypb.ShowPartitionsRequest, opts ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ShowPartitions(ctx, in) +} + +func (c *queryCoordLocalClientImpl) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.LoadPartitions(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ReleasePartitions(ctx, in) +} + +func (c *queryCoordLocalClientImpl) LoadCollection(ctx context.Context, in *querypb.LoadCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.LoadCollection(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ReleaseCollection(ctx, in) +} + +func (c *queryCoordLocalClientImpl) SyncNewCreatedPartition(ctx context.Context, in *querypb.SyncNewCreatedPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.SyncNewCreatedPartition(ctx, in) +} + +func (c *queryCoordLocalClientImpl) GetPartitionStates(ctx context.Context, in *querypb.GetPartitionStatesRequest, opts ...grpc.CallOption) (*querypb.GetPartitionStatesResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetPartitionStates(ctx, in) +} + +func (c *queryCoordLocalClientImpl) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetSegmentInfo(ctx, in) +} + +func (c *queryCoordLocalClientImpl) LoadBalance(ctx context.Context, in *querypb.LoadBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.LoadBalance(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ShowConfigurations(ctx, in) +} + +func (c *queryCoordLocalClientImpl) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetMetrics(ctx, in) +} + +func (c *queryCoordLocalClientImpl) GetReplicas(ctx context.Context, in *milvuspb.GetReplicasRequest, opts ...grpc.CallOption) (*milvuspb.GetReplicasResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetReplicas(ctx, in) +} + +func (c *queryCoordLocalClientImpl) GetShardLeaders(ctx context.Context, in *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetShardLeaders(ctx, in) +} + +func (c *queryCoordLocalClientImpl) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CheckHealth(ctx, in) +} + +func (c *queryCoordLocalClientImpl) CreateResourceGroup(ctx context.Context, in *milvuspb.CreateResourceGroupRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CreateResourceGroup(ctx, in) +} + +func (c *queryCoordLocalClientImpl) UpdateResourceGroups(ctx context.Context, in *querypb.UpdateResourceGroupsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.UpdateResourceGroups(ctx, in) +} + +func (c *queryCoordLocalClientImpl) DropResourceGroup(ctx context.Context, in *milvuspb.DropResourceGroupRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DropResourceGroup(ctx, in) +} + +func (c *queryCoordLocalClientImpl) TransferNode(ctx context.Context, in *milvuspb.TransferNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.TransferNode(ctx, in) +} + +func (c *queryCoordLocalClientImpl) TransferReplica(ctx context.Context, in *querypb.TransferReplicaRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.TransferReplica(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ListResourceGroups(ctx context.Context, in *milvuspb.ListResourceGroupsRequest, opts ...grpc.CallOption) (*milvuspb.ListResourceGroupsResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ListResourceGroups(ctx, in) +} + +func (c *queryCoordLocalClientImpl) DescribeResourceGroup(ctx context.Context, in *querypb.DescribeResourceGroupRequest, opts ...grpc.CallOption) (*querypb.DescribeResourceGroupResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DescribeResourceGroup(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ListCheckers(ctx context.Context, in *querypb.ListCheckersRequest, opts ...grpc.CallOption) (*querypb.ListCheckersResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ListCheckers(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ActivateChecker(ctx context.Context, in *querypb.ActivateCheckerRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ActivateChecker(ctx, in) +} + +func (c *queryCoordLocalClientImpl) DeactivateChecker(ctx context.Context, in *querypb.DeactivateCheckerRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DeactivateChecker(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ListQueryNode(ctx context.Context, in *querypb.ListQueryNodeRequest, opts ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ListQueryNode(ctx, in) +} + +func (c *queryCoordLocalClientImpl) GetQueryNodeDistribution(ctx context.Context, in *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetQueryNodeDistribution(ctx, in) +} + +func (c *queryCoordLocalClientImpl) SuspendBalance(ctx context.Context, in *querypb.SuspendBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.SuspendBalance(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ResumeBalance(ctx context.Context, in *querypb.ResumeBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ResumeBalance(ctx, in) +} + +func (c *queryCoordLocalClientImpl) SuspendNode(ctx context.Context, in *querypb.SuspendNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.SuspendNode(ctx, in) +} + +func (c *queryCoordLocalClientImpl) ResumeNode(ctx context.Context, in *querypb.ResumeNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ResumeNode(ctx, in) +} + +func (c *queryCoordLocalClientImpl) TransferSegment(ctx context.Context, in *querypb.TransferSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.TransferSegment(ctx, in) +} + +func (c *queryCoordLocalClientImpl) TransferChannel(ctx context.Context, in *querypb.TransferChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.TransferChannel(ctx, in) +} + +func (c *queryCoordLocalClientImpl) CheckQueryNodeDistribution(ctx context.Context, in *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CheckQueryNodeDistribution(ctx, in) +} + +func (c *queryCoordLocalClientImpl) UpdateLoadConfig(ctx context.Context, in *querypb.UpdateLoadConfigRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := c.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.UpdateLoadConfig(ctx, in) +} + +func (c *queryCoordLocalClientImpl) Close() error { + return nil +} diff --git a/internal/coordinator/coordclient/querycoord_test.go b/internal/coordinator/coordclient/querycoord_test.go new file mode 100644 index 0000000000000..40bd54c18ee1c --- /dev/null +++ b/internal/coordinator/coordclient/querycoord_test.go @@ -0,0 +1,253 @@ +package coordclient + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/querypb" +) + +func TestQueryCoordLocalClient(t *testing.T) { + c := newQueryCoordLocalClient() + c.setReadyServer(querypb.UnimplementedQueryCoordServer{}) + + ctx := context.Background() + + _, err := c.GetComponentStates(ctx, nil) + assert.Error(t, err) + + _, err = c.GetTimeTickChannel(ctx, nil) + assert.Error(t, err) + + _, err = c.GetStatisticsChannel(ctx, nil) + assert.Error(t, err) + + _, err = c.ShowCollections(ctx, nil) + assert.Error(t, err) + + _, err = c.ShowPartitions(ctx, nil) + assert.Error(t, err) + + _, err = c.LoadPartitions(ctx, nil) + assert.Error(t, err) + + _, err = c.ReleasePartitions(ctx, nil) + assert.Error(t, err) + + _, err = c.LoadCollection(ctx, nil) + assert.Error(t, err) + + _, err = c.ReleaseCollection(ctx, nil) + assert.Error(t, err) + + _, err = c.SyncNewCreatedPartition(ctx, nil) + assert.Error(t, err) + + _, err = c.GetPartitionStates(ctx, nil) + assert.Error(t, err) + + _, err = c.GetSegmentInfo(ctx, nil) + assert.Error(t, err) + + _, err = c.LoadBalance(ctx, nil) + assert.Error(t, err) + + _, err = c.ShowConfigurations(ctx, nil) + assert.Error(t, err) + + _, err = c.GetMetrics(ctx, nil) + assert.Error(t, err) + + _, err = c.GetReplicas(ctx, nil) + assert.Error(t, err) + + _, err = c.GetShardLeaders(ctx, nil) + assert.Error(t, err) + + _, err = c.CheckHealth(ctx, nil) + assert.Error(t, err) + + _, err = c.CreateResourceGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.UpdateResourceGroups(ctx, nil) + assert.Error(t, err) + + _, err = c.DropResourceGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.TransferNode(ctx, nil) + assert.Error(t, err) + + _, err = c.TransferReplica(ctx, nil) + assert.Error(t, err) + + _, err = c.ListResourceGroups(ctx, nil) + assert.Error(t, err) + + _, err = c.DescribeResourceGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.ListCheckers(ctx, nil) + assert.Error(t, err) + + _, err = c.ActivateChecker(ctx, nil) + assert.Error(t, err) + + _, err = c.DeactivateChecker(ctx, nil) + assert.Error(t, err) + + _, err = c.ListQueryNode(ctx, nil) + assert.Error(t, err) + + _, err = c.GetQueryNodeDistribution(ctx, nil) + assert.Error(t, err) + + _, err = c.SuspendBalance(ctx, nil) + assert.Error(t, err) + + _, err = c.ResumeBalance(ctx, nil) + assert.Error(t, err) + + _, err = c.SuspendNode(ctx, nil) + assert.Error(t, err) + + _, err = c.ResumeNode(ctx, nil) + assert.Error(t, err) + + _, err = c.TransferSegment(ctx, nil) + assert.Error(t, err) + + _, err = c.TransferChannel(ctx, nil) + assert.Error(t, err) + + _, err = c.CheckQueryNodeDistribution(ctx, nil) + assert.Error(t, err) + + _, err = c.UpdateLoadConfig(ctx, nil) + assert.Error(t, err) + + c.Close() +} + +func TestQueryCoordLocalClientWithTimeout(t *testing.T) { + c := newQueryCoordLocalClient() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := c.GetComponentStates(ctx, nil) + assert.Error(t, err) + + _, err = c.GetTimeTickChannel(ctx, nil) + assert.Error(t, err) + + _, err = c.GetStatisticsChannel(ctx, nil) + assert.Error(t, err) + + _, err = c.ShowCollections(ctx, nil) + assert.Error(t, err) + + _, err = c.ShowPartitions(ctx, nil) + assert.Error(t, err) + + _, err = c.LoadPartitions(ctx, nil) + assert.Error(t, err) + + _, err = c.ReleasePartitions(ctx, nil) + assert.Error(t, err) + + _, err = c.LoadCollection(ctx, nil) + assert.Error(t, err) + + _, err = c.ReleaseCollection(ctx, nil) + assert.Error(t, err) + + _, err = c.SyncNewCreatedPartition(ctx, nil) + assert.Error(t, err) + + _, err = c.GetPartitionStates(ctx, nil) + assert.Error(t, err) + + _, err = c.GetSegmentInfo(ctx, nil) + assert.Error(t, err) + + _, err = c.LoadBalance(ctx, nil) + assert.Error(t, err) + + _, err = c.ShowConfigurations(ctx, nil) + assert.Error(t, err) + + _, err = c.GetMetrics(ctx, nil) + assert.Error(t, err) + + _, err = c.GetReplicas(ctx, nil) + assert.Error(t, err) + + _, err = c.GetShardLeaders(ctx, nil) + assert.Error(t, err) + + _, err = c.CheckHealth(ctx, nil) + assert.Error(t, err) + + _, err = c.CreateResourceGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.UpdateResourceGroups(ctx, nil) + assert.Error(t, err) + + _, err = c.DropResourceGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.TransferNode(ctx, nil) + assert.Error(t, err) + + _, err = c.TransferReplica(ctx, nil) + assert.Error(t, err) + + _, err = c.ListResourceGroups(ctx, nil) + assert.Error(t, err) + + _, err = c.DescribeResourceGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.ListCheckers(ctx, nil) + assert.Error(t, err) + + _, err = c.ActivateChecker(ctx, nil) + assert.Error(t, err) + + _, err = c.DeactivateChecker(ctx, nil) + assert.Error(t, err) + + _, err = c.ListQueryNode(ctx, nil) + assert.Error(t, err) + + _, err = c.GetQueryNodeDistribution(ctx, nil) + assert.Error(t, err) + + _, err = c.SuspendBalance(ctx, nil) + assert.Error(t, err) + + _, err = c.ResumeBalance(ctx, nil) + assert.Error(t, err) + + _, err = c.SuspendNode(ctx, nil) + assert.Error(t, err) + + _, err = c.ResumeNode(ctx, nil) + assert.Error(t, err) + + _, err = c.TransferSegment(ctx, nil) + assert.Error(t, err) + + _, err = c.TransferChannel(ctx, nil) + assert.Error(t, err) + + _, err = c.CheckQueryNodeDistribution(ctx, nil) + assert.Error(t, err) + + _, err = c.UpdateLoadConfig(ctx, nil) + assert.Error(t, err) +} diff --git a/internal/coordinator/coordclient/registry.go b/internal/coordinator/coordclient/registry.go new file mode 100644 index 0000000000000..558381594b8d4 --- /dev/null +++ b/internal/coordinator/coordclient/registry.go @@ -0,0 +1,120 @@ +package coordclient + +import ( + "context" + "fmt" + + "go.uber.org/zap" + + dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" + qcc "github.com/milvus-io/milvus/internal/distributed/querycoord/client" + rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// localClient is a client that can access local server directly +type localClient struct { + queryCoordClient *queryCoordLocalClientImpl + dataCoordClient *dataCoordLocalClientImpl + rootCoordClient *rootCoordLocalClientImpl +} + +var ( + enableLocal *LocalClientRoleConfig // a global map to store all can be local accessible roles. + glocalClient *localClient +) + +func init() { + enableLocal = &LocalClientRoleConfig{} + glocalClient = &localClient{ + queryCoordClient: newQueryCoordLocalClient(), + dataCoordClient: newDataCoordLocalClient(), + rootCoordClient: newRootCoordLocalClient(), + } +} + +type LocalClientRoleConfig struct { + ServerType string + EnableQueryCoord bool + EnableDataCoord bool + EnableRootCoord bool +} + +// EnableLocalClientRole init localable roles +func EnableLocalClientRole(cfg *LocalClientRoleConfig) { + if cfg.ServerType != typeutil.StandaloneRole && cfg.ServerType != typeutil.MixtureRole { + return + } + enableLocal = cfg +} + +// RegisterQueryCoordServer register query coord server +func RegisterQueryCoordServer(server querypb.QueryCoordServer) { + if !enableLocal.EnableQueryCoord { + return + } + glocalClient.queryCoordClient.setReadyServer(server) + log.Info("register query coord server", zap.Any("enableLocalClient", enableLocal)) +} + +// RegsterDataCoordServer register data coord server +func RegisterDataCoordServer(server datapb.DataCoordServer) { + if !enableLocal.EnableDataCoord { + return + } + glocalClient.dataCoordClient.setReadyServer(server) + log.Info("register data coord server", zap.Any("enableLocalClient", enableLocal)) +} + +// RegisterRootCoordServer register root coord server +func RegisterRootCoordServer(server rootcoordpb.RootCoordServer) { + if !enableLocal.EnableRootCoord { + return + } + glocalClient.rootCoordClient.setReadyServer(server) + log.Info("register root coord server", zap.Any("enableLocalClient", enableLocal)) +} + +// GetQueryCoordClient return query coord client +func GetQueryCoordClient(ctx context.Context) types.QueryCoordClient { + if enableLocal.EnableQueryCoord { + return glocalClient.queryCoordClient + } + // TODO: we should make a singleton here. but most unittest rely on a dedicated client. + queryCoordClient, err := qcc.NewClient(ctx) + if err != nil { + panic(fmt.Sprintf("get query coord client failed: %v", err)) + } + return queryCoordClient +} + +// GetDataCoordClient return data coord client +func GetDataCoordClient(ctx context.Context) types.DataCoordClient { + if enableLocal.EnableDataCoord { + return glocalClient.dataCoordClient + } + // TODO: we should make a singleton here. but most unittest rely on a dedicated client. + dataCoordClient, err := dcc.NewClient(ctx) + if err != nil { + panic(fmt.Sprintf("get data coord client failed: %v", err)) + } + return dataCoordClient +} + +// GetRootCoordClient return root coord client +func GetRootCoordClient(ctx context.Context) types.RootCoordClient { + if enableLocal.EnableRootCoord { + return glocalClient.rootCoordClient + } + // TODO: we should make a singleton here. but most unittest rely on a dedicated client. + rootCoordClient, err := rcc.NewClient(ctx) + if err != nil { + panic(fmt.Sprintf("get root coord client failed: %v", err)) + } + return rootCoordClient +} diff --git a/internal/coordinator/coordclient/registry_test.go b/internal/coordinator/coordclient/registry_test.go new file mode 100644 index 0000000000000..f16918b72bad5 --- /dev/null +++ b/internal/coordinator/coordclient/registry_test.go @@ -0,0 +1,71 @@ +package coordclient + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func TestRegistry(t *testing.T) { + assert.False(t, enableLocal.EnableQueryCoord) + assert.False(t, enableLocal.EnableDataCoord) + assert.False(t, enableLocal.EnableRootCoord) + + EnableLocalClientRole(&LocalClientRoleConfig{ + ServerType: typeutil.RootCoordRole, + EnableQueryCoord: true, + EnableDataCoord: true, + EnableRootCoord: true, + }) + assert.False(t, enableLocal.EnableQueryCoord) + assert.False(t, enableLocal.EnableDataCoord) + assert.False(t, enableLocal.EnableRootCoord) + + RegisterRootCoordServer(&rootcoordpb.UnimplementedRootCoordServer{}) + RegisterDataCoordServer(&datapb.UnimplementedDataCoordServer{}) + RegisterQueryCoordServer(&querypb.UnimplementedQueryCoordServer{}) + assert.False(t, glocalClient.dataCoordClient.localDataCoordServer.Ready()) + assert.False(t, glocalClient.queryCoordClient.localQueryCoordServer.Ready()) + assert.False(t, glocalClient.rootCoordClient.localRootCoordServer.Ready()) + + enableLocal = &LocalClientRoleConfig{} + + EnableLocalClientRole(&LocalClientRoleConfig{ + ServerType: typeutil.StandaloneRole, + EnableQueryCoord: true, + EnableDataCoord: true, + EnableRootCoord: true, + }) + assert.True(t, enableLocal.EnableDataCoord) + assert.True(t, enableLocal.EnableQueryCoord) + assert.True(t, enableLocal.EnableRootCoord) + + RegisterRootCoordServer(&rootcoordpb.UnimplementedRootCoordServer{}) + RegisterDataCoordServer(&datapb.UnimplementedDataCoordServer{}) + RegisterQueryCoordServer(&querypb.UnimplementedQueryCoordServer{}) + assert.True(t, glocalClient.dataCoordClient.localDataCoordServer.Ready()) + assert.True(t, glocalClient.queryCoordClient.localQueryCoordServer.Ready()) + assert.True(t, glocalClient.rootCoordClient.localRootCoordServer.Ready()) + + enableLocal = &LocalClientRoleConfig{} + + EnableLocalClientRole(&LocalClientRoleConfig{ + ServerType: typeutil.MixtureRole, + EnableQueryCoord: true, + EnableDataCoord: true, + EnableRootCoord: true, + }) + assert.True(t, enableLocal.EnableDataCoord) + assert.True(t, enableLocal.EnableQueryCoord) + assert.True(t, enableLocal.EnableRootCoord) + + assert.NotNil(t, GetQueryCoordClient(context.Background())) + assert.NotNil(t, GetDataCoordClient(context.Background())) + assert.NotNil(t, GetRootCoordClient(context.Background())) +} diff --git a/internal/coordinator/coordclient/rootcoord.go b/internal/coordinator/coordclient/rootcoord.go new file mode 100644 index 0000000000000..224d7c4942c76 --- /dev/null +++ b/internal/coordinator/coordclient/rootcoord.go @@ -0,0 +1,466 @@ +package coordclient + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +var _ types.RootCoordClient = &rootCoordLocalClientImpl{} + +// newRootCoordLocalClient creates a new local client for root coordinator server. +func newRootCoordLocalClient() *rootCoordLocalClientImpl { + return &rootCoordLocalClientImpl{ + localRootCoordServer: syncutil.NewFuture[rootcoordpb.RootCoordServer](), + } +} + +// rootCoordLocalClientImpl is used to implement a local client for root coordinator server. +// We need to merge all the coordinator into one server, so use those client to erase the rpc layer between different coord. +type rootCoordLocalClientImpl struct { + localRootCoordServer *syncutil.Future[rootcoordpb.RootCoordServer] +} + +func (r *rootCoordLocalClientImpl) waitForReady(ctx context.Context) (rootcoordpb.RootCoordServer, error) { + return r.localRootCoordServer.GetWithContext(ctx) +} + +func (r *rootCoordLocalClientImpl) setReadyServer(server rootcoordpb.RootCoordServer) { + r.localRootCoordServer.Set(server) +} + +func (r *rootCoordLocalClientImpl) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetComponentStates(ctx, in) +} + +func (r *rootCoordLocalClientImpl) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetTimeTickChannel(ctx, in) +} + +func (r *rootCoordLocalClientImpl) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetStatisticsChannel(ctx, in) +} + +func (r *rootCoordLocalClientImpl) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CreateCollection(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DropCollection(ctx, in) +} + +func (r *rootCoordLocalClientImpl) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.HasCollection(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DescribeCollection(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DescribeCollectionInternal(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DescribeCollectionInternal(ctx, in) +} + +func (r *rootCoordLocalClientImpl) CreateAlias(ctx context.Context, in *milvuspb.CreateAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CreateAlias(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DropAlias(ctx context.Context, in *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DropAlias(ctx, in) +} + +func (r *rootCoordLocalClientImpl) AlterAlias(ctx context.Context, in *milvuspb.AlterAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.AlterAlias(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DescribeAlias(ctx context.Context, in *milvuspb.DescribeAliasRequest, opts ...grpc.CallOption) (*milvuspb.DescribeAliasResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DescribeAlias(ctx, in) +} + +func (r *rootCoordLocalClientImpl) ListAliases(ctx context.Context, in *milvuspb.ListAliasesRequest, opts ...grpc.CallOption) (*milvuspb.ListAliasesResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ListAliases(ctx, in) +} + +func (r *rootCoordLocalClientImpl) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ShowCollections(ctx, in) +} + +func (r *rootCoordLocalClientImpl) AlterCollection(ctx context.Context, in *milvuspb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.AlterCollection(ctx, in) +} + +func (r *rootCoordLocalClientImpl) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CreatePartition(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DropPartition(ctx, in) +} + +func (r *rootCoordLocalClientImpl) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.HasPartition(ctx, in) +} + +func (r *rootCoordLocalClientImpl) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ShowPartitions(ctx, in) +} + +func (r *rootCoordLocalClientImpl) ShowPartitionsInternal(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ShowPartitionsInternal(ctx, in) +} + +func (r *rootCoordLocalClientImpl) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ShowSegments(ctx, in) +} + +func (r *rootCoordLocalClientImpl) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.AllocTimestamp(ctx, in) +} + +func (r *rootCoordLocalClientImpl) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.AllocID(ctx, in) +} + +func (r *rootCoordLocalClientImpl) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.UpdateChannelTimeTick(ctx, in) +} + +func (r *rootCoordLocalClientImpl) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.InvalidateCollectionMetaCache(ctx, in) +} + +func (r *rootCoordLocalClientImpl) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ShowConfigurations(ctx, in) +} + +func (r *rootCoordLocalClientImpl) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetMetrics(ctx, in) +} + +func (r *rootCoordLocalClientImpl) CreateCredential(ctx context.Context, in *internalpb.CredentialInfo, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CreateCredential(ctx, in) +} + +func (r *rootCoordLocalClientImpl) UpdateCredential(ctx context.Context, in *internalpb.CredentialInfo, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.UpdateCredential(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DeleteCredential(ctx context.Context, in *milvuspb.DeleteCredentialRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DeleteCredential(ctx, in) +} + +func (r *rootCoordLocalClientImpl) ListCredUsers(ctx context.Context, in *milvuspb.ListCredUsersRequest, opts ...grpc.CallOption) (*milvuspb.ListCredUsersResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ListCredUsers(ctx, in) +} + +func (r *rootCoordLocalClientImpl) GetCredential(ctx context.Context, in *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.GetCredential(ctx, in) +} + +func (r *rootCoordLocalClientImpl) CreateRole(ctx context.Context, in *milvuspb.CreateRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CreateRole(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DropRole(ctx context.Context, in *milvuspb.DropRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DropRole(ctx, in) +} + +func (r *rootCoordLocalClientImpl) OperateUserRole(ctx context.Context, in *milvuspb.OperateUserRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.OperateUserRole(ctx, in) +} + +func (r *rootCoordLocalClientImpl) SelectRole(ctx context.Context, in *milvuspb.SelectRoleRequest, opts ...grpc.CallOption) (*milvuspb.SelectRoleResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.SelectRole(ctx, in) +} + +func (r *rootCoordLocalClientImpl) SelectUser(ctx context.Context, in *milvuspb.SelectUserRequest, opts ...grpc.CallOption) (*milvuspb.SelectUserResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.SelectUser(ctx, in) +} + +func (r *rootCoordLocalClientImpl) OperatePrivilege(ctx context.Context, in *milvuspb.OperatePrivilegeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.OperatePrivilege(ctx, in) +} + +func (r *rootCoordLocalClientImpl) SelectGrant(ctx context.Context, in *milvuspb.SelectGrantRequest, opts ...grpc.CallOption) (*milvuspb.SelectGrantResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.SelectGrant(ctx, in) +} + +func (r *rootCoordLocalClientImpl) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest, opts ...grpc.CallOption) (*internalpb.ListPolicyResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ListPolicy(ctx, in) +} + +func (r *rootCoordLocalClientImpl) BackupRBAC(ctx context.Context, in *milvuspb.BackupRBACMetaRequest, opts ...grpc.CallOption) (*milvuspb.BackupRBACMetaResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.BackupRBAC(ctx, in) +} + +func (r *rootCoordLocalClientImpl) RestoreRBAC(ctx context.Context, in *milvuspb.RestoreRBACMetaRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.RestoreRBAC(ctx, in) +} + +func (r *rootCoordLocalClientImpl) CreatePrivilegeGroup(ctx context.Context, in *milvuspb.CreatePrivilegeGroupRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CreatePrivilegeGroup(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DropPrivilegeGroup(ctx context.Context, in *milvuspb.DropPrivilegeGroupRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DropPrivilegeGroup(ctx, in) +} + +func (r *rootCoordLocalClientImpl) ListPrivilegeGroups(ctx context.Context, in *milvuspb.ListPrivilegeGroupsRequest, opts ...grpc.CallOption) (*milvuspb.ListPrivilegeGroupsResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ListPrivilegeGroups(ctx, in) +} + +func (r *rootCoordLocalClientImpl) OperatePrivilegeGroup(ctx context.Context, in *milvuspb.OperatePrivilegeGroupRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.OperatePrivilegeGroup(ctx, in) +} + +func (r *rootCoordLocalClientImpl) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CheckHealth(ctx, in) +} + +func (r *rootCoordLocalClientImpl) RenameCollection(ctx context.Context, in *milvuspb.RenameCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.RenameCollection(ctx, in) +} + +func (r *rootCoordLocalClientImpl) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.CreateDatabase(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DropDatabase(ctx, in) +} + +func (r *rootCoordLocalClientImpl) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.ListDatabases(ctx, in) +} + +func (r *rootCoordLocalClientImpl) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.DescribeDatabase(ctx, in) +} + +func (r *rootCoordLocalClientImpl) AlterDatabase(ctx context.Context, in *rootcoordpb.AlterDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + s, err := r.waitForReady(ctx) + if err != nil { + return nil, err + } + return s.AlterDatabase(ctx, in) +} + +func (r *rootCoordLocalClientImpl) Close() error { + return nil +} diff --git a/internal/coordinator/coordclient/rootcoord_test.go b/internal/coordinator/coordclient/rootcoord_test.go new file mode 100644 index 0000000000000..a4fc81c611304 --- /dev/null +++ b/internal/coordinator/coordclient/rootcoord_test.go @@ -0,0 +1,322 @@ +package coordclient + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" +) + +func TestRootCoordLocalClient(t *testing.T) { + c := newRootCoordLocalClient() + c.setReadyServer(rootcoordpb.UnimplementedRootCoordServer{}) + + ctx := context.Background() + + _, err := c.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + assert.Error(t, err) + + _, err = c.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{}) + assert.Error(t, err) + + _, err = c.DropCollection(ctx, &milvuspb.DropCollectionRequest{}) + assert.Error(t, err) + + _, err = c.HasCollection(ctx, &milvuspb.HasCollectionRequest{}) + assert.Error(t, err) + + _, err = c.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{}) + assert.Error(t, err) + + _, err = c.CreateAlias(ctx, &milvuspb.CreateAliasRequest{}) + assert.Error(t, err) + + _, err = c.DropAlias(ctx, &milvuspb.DropAliasRequest{}) + assert.Error(t, err) + + _, err = c.AlterAlias(ctx, &milvuspb.AlterAliasRequest{}) + assert.Error(t, err) + + _, err = c.DescribeAlias(ctx, &milvuspb.DescribeAliasRequest{}) + assert.Error(t, err) + + _, err = c.ListAliases(ctx, &milvuspb.ListAliasesRequest{}) + assert.Error(t, err) + + _, err = c.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) + assert.Error(t, err) + + _, err = c.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{}) + assert.Error(t, err) + + _, err = c.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{}) + assert.Error(t, err) + + _, err = c.DropPartition(ctx, &milvuspb.DropPartitionRequest{}) + assert.Error(t, err) + + _, err = c.HasPartition(ctx, &milvuspb.HasPartitionRequest{}) + assert.Error(t, err) + + _, err = c.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{}) + assert.Error(t, err) + + _, err = c.ShowSegments(ctx, &milvuspb.ShowSegmentsRequest{}) + assert.Error(t, err) + + _, err = c.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{}) + assert.Error(t, err) + + _, err = c.AllocID(ctx, &rootcoordpb.AllocIDRequest{}) + assert.Error(t, err) + + _, err = c.UpdateChannelTimeTick(ctx, &internalpb.ChannelTimeTickMsg{}) + assert.Error(t, err) + + _, err = c.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) + assert.Error(t, err) + + _, err = c.ShowConfigurations(ctx, &internalpb.ShowConfigurationsRequest{}) + assert.Error(t, err) + + _, err = c.GetMetrics(ctx, &milvuspb.GetMetricsRequest{}) + assert.Error(t, err) + + _, err = c.CreateCredential(ctx, &internalpb.CredentialInfo{}) + assert.Error(t, err) + + _, err = c.UpdateCredential(ctx, &internalpb.CredentialInfo{}) + assert.Error(t, err) + + _, err = c.DeleteCredential(ctx, &milvuspb.DeleteCredentialRequest{}) + assert.Error(t, err) + + _, err = c.ListCredUsers(ctx, &milvuspb.ListCredUsersRequest{}) + assert.Error(t, err) + + _, err = c.GetCredential(ctx, &rootcoordpb.GetCredentialRequest{}) + assert.Error(t, err) + + _, err = c.CreateRole(ctx, nil) + assert.Error(t, err) + + _, err = c.DropRole(ctx, nil) + assert.Error(t, err) + + _, err = c.OperateUserRole(ctx, nil) + assert.Error(t, err) + + _, err = c.SelectRole(ctx, nil) + assert.Error(t, err) + + _, err = c.SelectUser(ctx, nil) + assert.Error(t, err) + + _, err = c.OperatePrivilege(ctx, nil) + assert.Error(t, err) + + _, err = c.SelectGrant(ctx, nil) + assert.Error(t, err) + + _, err = c.ListPolicy(ctx, nil) + assert.Error(t, err) + + _, err = c.BackupRBAC(ctx, nil) + assert.Error(t, err) + + _, err = c.RestoreRBAC(ctx, nil) + assert.Error(t, err) + + _, err = c.CreatePrivilegeGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.DropPrivilegeGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.ListPrivilegeGroups(ctx, nil) + assert.Error(t, err) + + _, err = c.OperatePrivilegeGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.CheckHealth(ctx, nil) + assert.Error(t, err) + + _, err = c.RenameCollection(ctx, nil) + assert.Error(t, err) + + _, err = c.CreateDatabase(ctx, nil) + assert.Error(t, err) + + _, err = c.DropDatabase(ctx, nil) + assert.Error(t, err) + + _, err = c.ListDatabases(ctx, nil) + assert.Error(t, err) + + _, err = c.DescribeDatabase(ctx, nil) + assert.Error(t, err) + + _, err = c.AlterDatabase(ctx, nil) + assert.Error(t, err) + + c.Close() +} + +func TestRootCoordLocalClientWithTimeout(t *testing.T) { + c := newRootCoordLocalClient() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := c.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + assert.Error(t, err) + + _, err = c.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{}) + assert.Error(t, err) + + _, err = c.DropCollection(ctx, &milvuspb.DropCollectionRequest{}) + assert.Error(t, err) + + _, err = c.HasCollection(ctx, &milvuspb.HasCollectionRequest{}) + assert.Error(t, err) + + _, err = c.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{}) + assert.Error(t, err) + + _, err = c.CreateAlias(ctx, &milvuspb.CreateAliasRequest{}) + assert.Error(t, err) + + _, err = c.DropAlias(ctx, &milvuspb.DropAliasRequest{}) + assert.Error(t, err) + + _, err = c.AlterAlias(ctx, &milvuspb.AlterAliasRequest{}) + assert.Error(t, err) + + _, err = c.DescribeAlias(ctx, &milvuspb.DescribeAliasRequest{}) + assert.Error(t, err) + + _, err = c.ListAliases(ctx, &milvuspb.ListAliasesRequest{}) + assert.Error(t, err) + + _, err = c.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) + assert.Error(t, err) + + _, err = c.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{}) + assert.Error(t, err) + + _, err = c.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{}) + assert.Error(t, err) + + _, err = c.DropPartition(ctx, &milvuspb.DropPartitionRequest{}) + assert.Error(t, err) + + _, err = c.HasPartition(ctx, &milvuspb.HasPartitionRequest{}) + assert.Error(t, err) + + _, err = c.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{}) + assert.Error(t, err) + + _, err = c.ShowSegments(ctx, &milvuspb.ShowSegmentsRequest{}) + assert.Error(t, err) + + _, err = c.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{}) + assert.Error(t, err) + + _, err = c.AllocID(ctx, &rootcoordpb.AllocIDRequest{}) + assert.Error(t, err) + + _, err = c.UpdateChannelTimeTick(ctx, &internalpb.ChannelTimeTickMsg{}) + assert.Error(t, err) + + _, err = c.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) + assert.Error(t, err) + + _, err = c.ShowConfigurations(ctx, &internalpb.ShowConfigurationsRequest{}) + assert.Error(t, err) + + _, err = c.GetMetrics(ctx, &milvuspb.GetMetricsRequest{}) + assert.Error(t, err) + + _, err = c.CreateCredential(ctx, &internalpb.CredentialInfo{}) + assert.Error(t, err) + + _, err = c.UpdateCredential(ctx, &internalpb.CredentialInfo{}) + assert.Error(t, err) + + _, err = c.DeleteCredential(ctx, &milvuspb.DeleteCredentialRequest{}) + assert.Error(t, err) + + _, err = c.ListCredUsers(ctx, &milvuspb.ListCredUsersRequest{}) + assert.Error(t, err) + + _, err = c.GetCredential(ctx, &rootcoordpb.GetCredentialRequest{}) + assert.Error(t, err) + + _, err = c.CreateRole(ctx, nil) + assert.Error(t, err) + + _, err = c.DropRole(ctx, nil) + assert.Error(t, err) + + _, err = c.OperateUserRole(ctx, nil) + assert.Error(t, err) + + _, err = c.SelectRole(ctx, nil) + assert.Error(t, err) + + _, err = c.SelectUser(ctx, nil) + assert.Error(t, err) + + _, err = c.OperatePrivilege(ctx, nil) + assert.Error(t, err) + + _, err = c.SelectGrant(ctx, nil) + assert.Error(t, err) + + _, err = c.ListPolicy(ctx, nil) + assert.Error(t, err) + + _, err = c.BackupRBAC(ctx, nil) + assert.Error(t, err) + + _, err = c.RestoreRBAC(ctx, nil) + assert.Error(t, err) + + _, err = c.CreatePrivilegeGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.DropPrivilegeGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.ListPrivilegeGroups(ctx, nil) + assert.Error(t, err) + + _, err = c.OperatePrivilegeGroup(ctx, nil) + assert.Error(t, err) + + _, err = c.CheckHealth(ctx, nil) + assert.Error(t, err) + + _, err = c.RenameCollection(ctx, nil) + assert.Error(t, err) + + _, err = c.CreateDatabase(ctx, nil) + assert.Error(t, err) + + _, err = c.DropDatabase(ctx, nil) + assert.Error(t, err) + + _, err = c.ListDatabases(ctx, nil) + assert.Error(t, err) + + _, err = c.DescribeDatabase(ctx, nil) + assert.Error(t, err) + + _, err = c.AlterDatabase(ctx, nil) + assert.Error(t, err) +} diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index ecff0a56e3d26..273cafe3d52c9 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -35,10 +35,10 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" globalIDAllocator "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/coordinator/coordclient" "github.com/milvus-io/milvus/internal/datacoord/broker" datanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" indexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client" - rootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/kv/tikv" "github.com/milvus-io/milvus/internal/metastore/kv/datacoord" @@ -252,7 +252,7 @@ func defaultIndexNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) } func defaultRootCoordCreatorFunc(ctx context.Context) (types.RootCoordClient, error) { - return rootcoordclient.NewClient(ctx) + return coordclient.GetRootCoordClient(ctx), nil } // QuitSignal returns signal when server quits diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index 9b3f07f4c2d77..6639f52a367a7 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/coordinator/coordclient" "github.com/milvus-io/milvus/internal/datacoord" "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -200,6 +201,7 @@ func (s *Server) startGrpcLoop() { grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler())) indexpb.RegisterIndexCoordServer(s.grpcServer, s) datapb.RegisterDataCoordServer(s.grpcServer, s) + coordclient.RegisterDataCoordServer(s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) if err := s.grpcServer.Serve(s.listener); err != nil { s.grpcErrChan <- err diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go index 25b903c4edc9c..77d4cc2596a49 100644 --- a/internal/distributed/querycoord/service.go +++ b/internal/distributed/querycoord/service.go @@ -31,8 +31,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" - rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/coordinator/coordclient" "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -169,11 +168,7 @@ func (s *Server) init() error { // --- Master Server Client --- if s.rootCoord == nil { - s.rootCoord, err = rcc.NewClient(s.loopCtx) - if err != nil { - log.Error("QueryCoord try to new RootCoord client failed", zap.Error(err)) - panic(err) - } + s.rootCoord = coordclient.GetRootCoordClient(s.loopCtx) } // wait for master init or healthy @@ -191,11 +186,7 @@ func (s *Server) init() error { // --- Data service client --- if s.dataCoord == nil { - s.dataCoord, err = dcc.NewClient(s.loopCtx) - if err != nil { - log.Error("QueryCoord try to new DataCoord client failed", zap.Error(err)) - panic(err) - } + s.dataCoord = coordclient.GetDataCoordClient(s.loopCtx) } log.Info("QueryCoord try to wait for DataCoord ready") @@ -258,6 +249,7 @@ func (s *Server) startGrpcLoop() { grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()), ) querypb.RegisterQueryCoordServer(s.grpcServer, s) + coordclient.RegisterQueryCoordServer(s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) if err := s.grpcServer.Serve(s.listener); err != nil { diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index 4f245d1b07654..e99a81dd50cf9 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -31,8 +31,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" - qcc "github.com/milvus-io/milvus/internal/distributed/querycoord/client" + "github.com/milvus-io/milvus/internal/coordinator/coordclient" "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" @@ -72,8 +71,8 @@ type Server struct { dataCoord types.DataCoordClient queryCoord types.QueryCoordClient - newDataCoordClient func() types.DataCoordClient - newQueryCoordClient func() types.QueryCoordClient + newDataCoordClient func(ctx context.Context) types.DataCoordClient + newQueryCoordClient func(ctx context.Context) types.QueryCoordClient } func (s *Server) DescribeDatabase(ctx context.Context, request *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) { @@ -157,21 +156,8 @@ func (s *Server) Prepare() error { } func (s *Server) setClient() { - s.newDataCoordClient = func() types.DataCoordClient { - dsClient, err := dcc.NewClient(s.ctx) - if err != nil { - panic(err) - } - return dsClient - } - - s.newQueryCoordClient = func() types.QueryCoordClient { - qsClient, err := qcc.NewClient(s.ctx) - if err != nil { - panic(err) - } - return qsClient - } + s.newDataCoordClient = coordclient.GetDataCoordClient + s.newQueryCoordClient = coordclient.GetQueryCoordClient } // Run initializes and starts RootCoord's grpc service. @@ -234,7 +220,7 @@ func (s *Server) init() error { if s.newDataCoordClient != nil { log.Info("RootCoord start to create DataCoord client") - dataCoord := s.newDataCoordClient() + dataCoord := s.newDataCoordClient(s.ctx) s.dataCoord = dataCoord if err := s.rootCoord.SetDataCoordClient(dataCoord); err != nil { panic(err) @@ -243,7 +229,7 @@ func (s *Server) init() error { if s.newQueryCoordClient != nil { log.Info("RootCoord start to create QueryCoord client") - queryCoord := s.newQueryCoordClient() + queryCoord := s.newQueryCoordClient(s.ctx) s.queryCoord = queryCoord if err := s.rootCoord.SetQueryCoordClient(queryCoord); err != nil { panic(err) @@ -305,6 +291,7 @@ func (s *Server) startGrpcLoop() { )), grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler())) rootcoordpb.RegisterRootCoordServer(s.grpcServer, s) + coordclient.RegisterRootCoordServer(s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) if err := s.grpcServer.Serve(s.listener); err != nil { diff --git a/internal/distributed/rootcoord/service_test.go b/internal/distributed/rootcoord/service_test.go index 43c9ba6ba0534..5965b47883c2c 100644 --- a/internal/distributed/rootcoord/service_test.go +++ b/internal/distributed/rootcoord/service_test.go @@ -142,13 +142,13 @@ func TestRun(t *testing.T) { mockDataCoord := mocks.NewMockDataCoordClient(t) mockDataCoord.EXPECT().Close().Return(nil) - svr.newDataCoordClient = func() types.DataCoordClient { + svr.newDataCoordClient = func(_ context.Context) types.DataCoordClient { return mockDataCoord } mockQueryCoord := mocks.NewMockQueryCoordClient(t) mockQueryCoord.EXPECT().Close().Return(nil) - svr.newQueryCoordClient = func() types.QueryCoordClient { + svr.newQueryCoordClient = func(_ context.Context) types.QueryCoordClient { return mockQueryCoord } @@ -238,7 +238,7 @@ func TestServerRun_DataCoordClientInitErr(t *testing.T) { mockDataCoord := mocks.NewMockDataCoordClient(t) mockDataCoord.EXPECT().Close().Return(nil) - server.newDataCoordClient = func() types.DataCoordClient { + server.newDataCoordClient = func(_ context.Context) types.DataCoordClient { return mockDataCoord } err = server.Prepare() @@ -268,7 +268,7 @@ func TestServerRun_DataCoordClientStartErr(t *testing.T) { mockDataCoord := mocks.NewMockDataCoordClient(t) mockDataCoord.EXPECT().Close().Return(nil) - server.newDataCoordClient = func() types.DataCoordClient { + server.newDataCoordClient = func(_ context.Context) types.DataCoordClient { return mockDataCoord } err = server.Prepare() @@ -298,7 +298,7 @@ func TestServerRun_QueryCoordClientInitErr(t *testing.T) { mockQueryCoord := mocks.NewMockQueryCoordClient(t) mockQueryCoord.EXPECT().Close().Return(nil) - server.newQueryCoordClient = func() types.QueryCoordClient { + server.newQueryCoordClient = func(_ context.Context) types.QueryCoordClient { return mockQueryCoord } err = server.Prepare() @@ -328,7 +328,7 @@ func TestServer_QueryCoordClientStartErr(t *testing.T) { mockQueryCoord := mocks.NewMockQueryCoordClient(t) mockQueryCoord.EXPECT().Close().Return(nil) - server.newQueryCoordClient = func() types.QueryCoordClient { + server.newQueryCoordClient = func(_ context.Context) types.QueryCoordClient { return mockQueryCoord } err = server.Prepare() diff --git a/pkg/util/syncutil/future.go b/pkg/util/syncutil/future.go new file mode 100644 index 0000000000000..f13c40f58e14e --- /dev/null +++ b/pkg/util/syncutil/future.go @@ -0,0 +1,56 @@ +package syncutil + +import ( + "context" +) + +// Future is a future value that can be set and retrieved. +type Future[T any] struct { + ch chan struct{} + value T +} + +// NewFuture creates a new future. +func NewFuture[T any]() *Future[T] { + return &Future[T]{ + ch: make(chan struct{}), + } +} + +// Set sets the value of the future. +func (f *Future[T]) Set(value T) { + f.value = value + close(f.ch) +} + +// GetWithContext retrieves the value of the future if set, otherwise block until set or the context is done. +func (f *Future[T]) GetWithContext(ctx context.Context) (T, error) { + select { + case <-ctx.Done(): + var val T + return val, ctx.Err() + case <-f.ch: + return f.value, nil + } +} + +// Get retrieves the value of the future if set, otherwise block until set. +func (f *Future[T]) Get() T { + <-f.ch + return f.value +} + +// Done returns a channel that is closed when the future is set. +func (f *Future[T]) Done() <-chan struct{} { + return f.ch +} + +// Ready returns true if the future is set. +func (f *Future[T]) Ready() bool { + select { + case <-f.ch: + return true + default: + return false + } +} diff --git a/pkg/util/syncutil/future_test.go b/pkg/util/syncutil/future_test.go new file mode 100644 index 0000000000000..3e0c567789218 --- /dev/null +++ b/pkg/util/syncutil/future_test.go @@ -0,0 +1,51 @@ +package syncutil + +import ( + "testing" + "time" +) + +func TestFuture_SetAndGet(t *testing.T) { + f := NewFuture[int]() + go func() { + time.Sleep(1 * time.Second) // Simulate some work + f.Set(42) + }() + + val := f.Get() + if val != 42 { + t.Errorf("Expected value 42, got %d", val) + } +} + +func TestFuture_Done(t *testing.T) { + f := NewFuture[string]() + go func() { + f.Set("done") + }() + + select { + case <-f.Done(): + // Success + case <-time.After(20 * time.Millisecond): + t.Error("Expected future to be done within 2 seconds") + } +} + +func TestFuture_Ready(t *testing.T) { + f := NewFuture[float64]() + go func() { + time.Sleep(20 * time.Millisecond) // Simulate some work + f.Set(3.14) + }() + + if f.Ready() { + t.Error("Expected future not to be ready immediately") + } + + <-f.Done() // Wait for the future to be set + + if !f.Ready() { + t.Error("Expected future to be ready after being set") + } +}