Skip to content

Commit

Permalink
enhance: [10kcp] Accelerate the loading of collection (#37879)
Browse files Browse the repository at this point in the history
Remove unnecessary ListIndex and DescribeCollection RPC call during
loading.

issue: #37166,
#37630

pr: #37741

Signed-off-by: bigsheeper <[email protected]>
  • Loading branch information
bigsheeper authored Nov 21, 2024
1 parent 9e1ba07 commit ac7b485
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 136 deletions.
2 changes: 1 addition & 1 deletion internal/proto/query_coord.proto
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ message LoadPartitionsRequest {
bool refresh = 8;
// resource group names
repeated string resource_groups = 9;
repeated index.IndexInfo index_info_list = 10;
repeated index.IndexInfo index_info_list = 10; // deprecated
repeated int64 load_fields = 11;
}

Expand Down
4 changes: 2 additions & 2 deletions internal/querycoordv2/job/job_load.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func (job *LoadCollectionJob) Execute() error {
}

// 3. loadPartitions on QueryNodes
err = loadPartitions(job.ctx, job.meta, job.cluster, job.broker, true, req.GetCollectionID(), lackPartitionIDs...)
err = loadPartitions(job.ctx, job.meta, job.cluster, req.GetCollectionID(), lackPartitionIDs...)
if err != nil {
return err
}
Expand Down Expand Up @@ -400,7 +400,7 @@ func (job *LoadPartitionJob) Execute() error {
}

// 3. loadPartitions on QueryNodes
err = loadPartitions(job.ctx, job.meta, job.cluster, job.broker, true, req.GetCollectionID(), lackPartitionIDs...)
err = loadPartitions(job.ctx, job.meta, job.cluster, req.GetCollectionID(), lackPartitionIDs...)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/querycoordv2/job/job_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (job *SyncNewCreatedPartitionJob) Execute() error {
return nil
}

err := loadPartitions(job.ctx, job.meta, job.cluster, job.broker, false, req.GetCollectionID(), req.GetPartitionID())
err := loadPartitions(job.ctx, job.meta, job.cluster, req.GetCollectionID(), req.GetPartitionID())
if err != nil {
return err
}
Expand Down
105 changes: 0 additions & 105 deletions internal/querycoordv2/job/job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,6 @@ func (suite *JobSuite) SetupSuite() {

suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).
Return(nil, nil)
suite.broker.EXPECT().ListIndexes(mock.Anything, mock.Anything).
Return(nil, nil)

suite.cluster = session.NewMockCluster(suite.T())
suite.cluster.EXPECT().
Expand Down Expand Up @@ -1385,109 +1383,6 @@ func (suite *JobSuite) TestLoadCreateReplicaFailed() {
}
}

func (suite *JobSuite) TestCallLoadPartitionFailed() {
// call LoadPartitions failed at get index info
getIndexErr := fmt.Errorf("mock get index error")
suite.broker.ExpectedCalls = lo.Filter(suite.broker.ExpectedCalls, func(call *mock.Call, _ int) bool {
return call.Method != "ListIndexes"
})
for _, collection := range suite.collections {
suite.broker.EXPECT().ListIndexes(mock.Anything, collection).Return(nil, getIndexErr)
loadCollectionReq := &querypb.LoadCollectionRequest{
CollectionID: collection,
}
loadCollectionJob := NewLoadCollectionJob(
context.Background(),
loadCollectionReq,
suite.dist,
suite.meta,
suite.broker,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
suite.collectionObserver,
suite.nodeMgr,
)
suite.scheduler.Add(loadCollectionJob)
err := loadCollectionJob.Wait()
suite.T().Logf("%s", err)
suite.ErrorIs(err, getIndexErr)

loadPartitionReq := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
}
loadPartitionJob := NewLoadPartitionJob(
context.Background(),
loadPartitionReq,
suite.dist,
suite.meta,
suite.broker,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
suite.collectionObserver,
suite.nodeMgr,
)
suite.scheduler.Add(loadPartitionJob)
err = loadPartitionJob.Wait()
suite.ErrorIs(err, getIndexErr)
}

// call LoadPartitions failed at get schema
getSchemaErr := fmt.Errorf("mock get schema error")
suite.broker.ExpectedCalls = lo.Filter(suite.broker.ExpectedCalls, func(call *mock.Call, _ int) bool {
return call.Method != "DescribeCollection"
})
for _, collection := range suite.collections {
suite.broker.EXPECT().DescribeCollection(mock.Anything, collection).Return(nil, getSchemaErr)
loadCollectionReq := &querypb.LoadCollectionRequest{
CollectionID: collection,
}
loadCollectionJob := NewLoadCollectionJob(
context.Background(),
loadCollectionReq,
suite.dist,
suite.meta,
suite.broker,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
suite.collectionObserver,
suite.nodeMgr,
)
suite.scheduler.Add(loadCollectionJob)
err := loadCollectionJob.Wait()
suite.ErrorIs(err, getSchemaErr)

loadPartitionReq := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
}
loadPartitionJob := NewLoadPartitionJob(
context.Background(),
loadPartitionReq,
suite.dist,
suite.meta,
suite.broker,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
suite.collectionObserver,
suite.nodeMgr,
)
suite.scheduler.Add(loadPartitionJob)
err = loadPartitionJob.Wait()
suite.ErrorIs(err, getSchemaErr)
}

suite.broker.ExpectedCalls = lo.Filter(suite.broker.ExpectedCalls, func(call *mock.Call, _ int) bool {
return call.Method != "ListIndexes" && call.Method != "DescribeCollection"
})
suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, nil)
suite.broker.EXPECT().ListIndexes(mock.Anything, mock.Anything).Return(nil, nil)
}

func (suite *JobSuite) TestCallReleasePartitionFailed() {
ctx := context.Background()
suite.loadAll()
Expand Down
33 changes: 10 additions & 23 deletions internal/querycoordv2/job/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ import (
"time"

"github.com/samber/lo"
"go.opentelemetry.io/otel"
"go.uber.org/zap"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/checkers"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
Expand Down Expand Up @@ -71,41 +71,25 @@ func waitCollectionReleased(dist *meta.DistributionManager, checkerController *c
func loadPartitions(ctx context.Context,
meta *meta.Meta,
cluster session.Cluster,
broker meta.Broker,
withSchema bool,
collection int64,
partitions ...int64,
) error {
var err error
var schema *schemapb.CollectionSchema
if withSchema {
collectionInfo, err := broker.DescribeCollection(ctx, collection)
if err != nil {
return err
}
schema = collectionInfo.GetSchema()
}
indexes, err := broker.ListIndexes(ctx, collection)
if err != nil {
return err
}
_, span := otel.Tracer(typeutil.QueryCoordRole).Start(ctx, "loadPartitions")
defer span.End()
start := time.Now()

replicas := meta.ReplicaManager.GetByCollection(collection)
loadReq := &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadPartitions,
},
CollectionID: collection,
PartitionIDs: partitions,
Schema: schema,
IndexInfoList: indexes,
CollectionID: collection,
PartitionIDs: partitions,
}
for _, replica := range replicas {
for _, node := range replica.GetNodes() {
status, err := cluster.LoadPartitions(ctx, node, loadReq)
// There is no need to rollback LoadPartitions as the load job will fail
// and the Delegator will not be created,
// resulting in search and query requests failing due to the absence of Delegator.
// TODO: rollback LoadPartitions if failed
if err != nil {
return err
}
Expand All @@ -114,6 +98,9 @@ func loadPartitions(ctx context.Context,
}
}
}

log.Ctx(ctx).Info("load partitions done", zap.Int64("collectionID", collection),
zap.Int64s("partitionIDs", partitions), zap.Duration("dur", time.Since(start)))
return nil
}

Expand Down
25 changes: 21 additions & 4 deletions internal/querycoordv2/services_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package querycoordv2
import (
"context"
"encoding/json"
"fmt"
"sort"
"testing"
"time"
Expand Down Expand Up @@ -344,6 +345,7 @@ func (suite *ServiceSuite) TestLoadCollection() {
// Test load all collections
for _, collection := range suite.collections {
suite.expectGetRecoverInfo(collection)
suite.expectDescribeCollection()
suite.expectLoadPartitions()

req := &querypb.LoadCollectionRequest{
Expand Down Expand Up @@ -913,6 +915,7 @@ func (suite *ServiceSuite) TestLoadPartition() {
// Test load all partitions
for _, collection := range suite.collections {
suite.expectLoadPartitions()
suite.expectDescribeCollection()
suite.expectGetRecoverInfo(collection)

req := &querypb.LoadPartitionsRequest{
Expand Down Expand Up @@ -1096,6 +1099,9 @@ func (suite *ServiceSuite) TestRefreshCollection() {
// Test load all collections
suite.loadAll()

suite.expectListIndexes()
suite.expectLoadPartitions()

// Test refresh all collections again when collections are loaded. This time should fail with collection not 100% loaded.
for _, collection := range suite.collections {
suite.updateCollectionStatus(collection, querypb.LoadStatus_Loading)
Expand All @@ -1115,7 +1121,11 @@ func (suite *ServiceSuite) TestRefreshCollection() {

readyCh, err := server.targetObserver.UpdateNextTarget(id)
suite.NoError(err)
<-readyCh
select {
case <-time.After(30 * time.Second):
suite.Fail(fmt.Sprintf("update next target timeout, collection=%d", id))
case <-readyCh:
}

// Now the refresh must be done
collection := server.meta.CollectionManager.GetCollection(id)
Expand Down Expand Up @@ -1802,8 +1812,9 @@ func (suite *ServiceSuite) TestHandleNodeUp() {
func (suite *ServiceSuite) loadAll() {
ctx := context.Background()
for _, collection := range suite.collections {
suite.expectLoadPartitions()
suite.expectDescribeCollection()
suite.expectGetRecoverInfo(collection)
suite.expectLoadPartitions()
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
Expand Down Expand Up @@ -1940,12 +1951,18 @@ func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) {
}

func (suite *ServiceSuite) expectLoadPartitions() {
suite.cluster.EXPECT().LoadPartitions(mock.Anything, mock.Anything, mock.Anything).
Return(merr.Success(), nil)
}

func (suite *ServiceSuite) expectDescribeCollection() {
suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).
Return(nil, nil)
}

func (suite *ServiceSuite) expectListIndexes() {
suite.broker.EXPECT().ListIndexes(mock.Anything, mock.Anything).
Return(nil, nil)
suite.cluster.EXPECT().LoadPartitions(mock.Anything, mock.Anything, mock.Anything).
Return(merr.Success(), nil)
}

func (suite *ServiceSuite) getAllSegments(collection int64) []int64 {
Expand Down

0 comments on commit ac7b485

Please sign in to comment.