diff --git a/go/vt/vtctl/vtctl_test.go b/go/vt/vtctl/vtctl_test.go index eb6a5f5941f..76514ec02b1 100644 --- a/go/vt/vtctl/vtctl_test.go +++ b/go/vt/vtctl/vtctl_test.go @@ -200,12 +200,12 @@ func TestMoveTables(t *testing.T) { expectResults: func() { env.tmc.setVRResults( target.tablet, - fmt.Sprintf("select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)", + fmt.Sprintf("select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (%d) and id in (select max(id) from _vt.copy_state where vrepl_id in (%d) group by vrepl_id, table_name)", vrID, vrID), sqltypes.MakeTestResult(sqltypes.MakeTestFields( - "table_name|lastpk", - "varchar|varbinary"), - fmt.Sprintf("%s|", table), + "vrepl_id|table_name|lastpk", + "int64|varchar|varbinary"), + fmt.Sprintf("%d|%s|", vrID, table), ), ) env.tmc.setDBAResults( @@ -260,12 +260,12 @@ func TestMoveTables(t *testing.T) { expectResults: func() { env.tmc.setVRResults( target.tablet, - fmt.Sprintf("select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)", + fmt.Sprintf("select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (%d) and id in (select max(id) from _vt.copy_state where vrepl_id in (%d) group by vrepl_id, table_name)", vrID, vrID), sqltypes.MakeTestResult(sqltypes.MakeTestFields( - "table_name|lastpk", - "varchar|varbinary"), - fmt.Sprintf("%s|", table), + "vrepl_id|table_name|lastpk", + "int64|varchar|varbinary"), + fmt.Sprintf("%d|%s|", vrID, table), ), ) env.tmc.setDBAResults( @@ -320,7 +320,7 @@ func TestMoveTables(t *testing.T) { expectResults: func() { env.tmc.setVRResults( target.tablet, - fmt.Sprintf("select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)", + fmt.Sprintf("select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (%d) and id in (select max(id) from _vt.copy_state where vrepl_id in (%d) group by vrepl_id, table_name)", vrID, vrID), &sqltypes.Result{}, ) diff --git a/go/vt/wrangler/traffic_switcher_env_test.go b/go/vt/wrangler/traffic_switcher_env_test.go index c8ec71dba96..efd740e8f97 100644 --- a/go/vt/wrangler/traffic_switcher_env_test.go +++ b/go/vt/wrangler/traffic_switcher_env_test.go @@ -20,6 +20,8 @@ import ( "context" "fmt" "math/rand" + "strconv" + "strings" "sync" "testing" "time" @@ -56,7 +58,7 @@ import ( const ( streamInfoQuery = "select id, source, message, cell, tablet_types, workflow_type, workflow_sub_type, defer_secondary_keys from _vt.vreplication where workflow='%s' and db_name='vt_%s'" streamExtInfoQuery = "select id, source, pos, stop_pos, max_replication_lag, state, db_name, time_updated, transaction_timestamp, time_heartbeat, time_throttled, component_throttled, message, tags, workflow_type, workflow_sub_type, defer_secondary_keys, rows_copied from _vt.vreplication where db_name = 'vt_%s' and workflow = '%s'" - copyStateQuery = "select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)" + copyStateQuery = "select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (%s) and id in (select max(id) from _vt.copy_state where vrepl_id in (%s) group by vrepl_id, table_name)" maxValForSequence = "select max(`id`) as maxval from `vt_%s`.`%s`" ) @@ -298,6 +300,7 @@ func newTestTableMigraterCustom(ctx context.Context, t *testing.T, sourceShards, for i, targetShard := range targetShards { var streamInfoRows []string var streamExtInfoRows []string + var vreplIDs []string for j, sourceShard := range sourceShards { bls := &binlogdatapb.BinlogSource{ Keyspace: "ks1", @@ -314,8 +317,10 @@ func newTestTableMigraterCustom(ctx context.Context, t *testing.T, sourceShards, } streamInfoRows = append(streamInfoRows, fmt.Sprintf("%d|%v||||1|0|0", j+1, bls)) streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks1|%d|%d|0|0||1||0", j+1, now, now)) - tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult) + vreplIDs = append(vreplIDs, strconv.FormatInt(int64(j+1), 10)) } + vreplIDsJoined := strings.Join(vreplIDs, ",") + tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult) tme.dbTargetClients[i].addInvariant(streamInfoKs2, sqltypes.MakeTestResult(sqltypes.MakeTestFields( "id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys", "int64|varchar|varchar|varchar|varchar|int64|int64|int64"), @@ -332,6 +337,7 @@ func newTestTableMigraterCustom(ctx context.Context, t *testing.T, sourceShards, for i, sourceShard := range sourceShards { var streamInfoRows []string + var vreplIDs []string for j, targetShard := range targetShards { bls := &binlogdatapb.BinlogSource{ Keyspace: "ks2", @@ -347,8 +353,10 @@ func newTestTableMigraterCustom(ctx context.Context, t *testing.T, sourceShards, }, } streamInfoRows = append(streamInfoRows, fmt.Sprintf("%d|%v||||1|0|0", j+1, bls)) - tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult) + vreplIDs = append(vreplIDs, strconv.FormatInt(int64(j+1), 10)) } + vreplIDsJoined := strings.Join(vreplIDs, ",") + tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult) tme.dbSourceClients[i].addInvariant(reverseStreamInfoKs1, sqltypes.MakeTestResult(sqltypes.MakeTestFields( "id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys", "int64|varchar|varchar|varchar|varchar|int64|int64|int64"), @@ -470,6 +478,7 @@ func newTestTablePartialMigrater(ctx context.Context, t *testing.T, shards, shar for _, shardToMove := range shardsToMove { var streamInfoRows []string var streamExtInfoRows []string + var vreplIDs []string if shardToMove == shard { bls := &binlogdatapb.BinlogSource{ Keyspace: "ks1", @@ -486,8 +495,10 @@ func newTestTablePartialMigrater(ctx context.Context, t *testing.T, shards, shar } streamInfoRows = append(streamInfoRows, fmt.Sprintf("%d|%v||||1|0|0", i+1, bls)) streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks1|%d|%d|0|0|||1||0", i+1, now, now)) + vreplIDs = append(vreplIDs, strconv.FormatInt(int64(i+1), 10)) } - tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, i+1, i+1), noResult) + vreplIDsJoined := strings.Join(vreplIDs, ",") + tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult) tme.dbTargetClients[i].addInvariant(streamInfoKs2, sqltypes.MakeTestResult(sqltypes.MakeTestFields( "id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys", "int64|varchar|varchar|varchar|varchar|int64|int64|int64"), @@ -506,6 +517,7 @@ func newTestTablePartialMigrater(ctx context.Context, t *testing.T, shards, shar for i, shard := range shards { for _, shardToMove := range shardsToMove { var streamInfoRows []string + var vreplIDs []string if shardToMove == shard { bls := &binlogdatapb.BinlogSource{ Keyspace: "ks2", @@ -521,8 +533,10 @@ func newTestTablePartialMigrater(ctx context.Context, t *testing.T, shards, shar }, } streamInfoRows = append(streamInfoRows, fmt.Sprintf("%d|%v||||1|0|0", i+1, bls)) - tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, i+1, i+1), noResult) + vreplIDs = append(vreplIDs, strconv.FormatInt(int64(i+1), 10)) } + vreplIDsJoined := strings.Join(vreplIDs, ",") + tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult) tme.dbSourceClients[i].addInvariant(reverseStreamInfoKs1, sqltypes.MakeTestResult(sqltypes.MakeTestFields( "id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys", "int64|varchar|varchar|varchar|varchar|int64|int64|int64"), @@ -632,6 +646,7 @@ func newTestShardMigrater(ctx context.Context, t *testing.T, sourceShards, targe for i, targetShard := range targetShards { var rows, rowsRdOnly []string var streamExtInfoRows []string + var vreplIDs []string for j, sourceShard := range sourceShards { if !key.KeyRangeIntersect(tme.targetKeyRanges[i], tme.sourceKeyRanges[j]) { continue @@ -649,8 +664,10 @@ func newTestShardMigrater(ctx context.Context, t *testing.T, sourceShards, targe rows = append(rows, fmt.Sprintf("%d|%v||||1|0|0", j+1, bls)) rowsRdOnly = append(rows, fmt.Sprintf("%d|%v|||RDONLY|1|0|0", j+1, bls)) streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks1|%d|%d|0|0|||", j+1, now, now)) - tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult) + vreplIDs = append(vreplIDs, strconv.FormatInt(int64(j+1), 10)) } + vreplIDsJoined := strings.Join(vreplIDs, ",") + tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult) tme.dbTargetClients[i].addInvariant(streamInfoKs, sqltypes.MakeTestResult(sqltypes.MakeTestFields( "id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys", "int64|varchar|varchar|varchar|varchar|int64|int64|int64"), @@ -670,11 +687,14 @@ func newTestShardMigrater(ctx context.Context, t *testing.T, sourceShards, targe tme.targetKeyspace = "ks" for i, dbclient := range tme.dbSourceClients { var streamExtInfoRows []string + var vreplIDs []string dbclient.addInvariant(streamInfoKs, &sqltypes.Result{}) for j := range targetShards { streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks|%d|%d|0|0|||", j+1, now, now)) - tme.dbSourceClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult) + vreplIDs = append(vreplIDs, strconv.FormatInt(int64(j+1), 10)) } + vreplIDsJoined := strings.Join(vreplIDs, ",") + tme.dbSourceClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult) tme.dbSourceClients[i].addInvariant(streamExtInfoKs, sqltypes.MakeTestResult(sqltypes.MakeTestFields( "id|source|pos|stop_pos|max_replication_lag|state|db_name|time_updated|transaction_timestamp|time_heartbeat|time_throttled|component_throttled|message|tags", "int64|varchar|int64|int64|int64|varchar|varchar|int64|int64|int64|int64|varchar|varchar|varchar"), diff --git a/go/vt/wrangler/vexec.go b/go/vt/wrangler/vexec.go index 0734fa7b593..0b59ff679ab 100644 --- a/go/vt/wrangler/vexec.go +++ b/go/vt/wrangler/vexec.go @@ -22,6 +22,7 @@ import ( "fmt" "math" "sort" + "strconv" "strings" "sync" "time" @@ -582,7 +583,7 @@ type ReplicationStatus struct { deferSecondaryKeys bool } -func (wr *Wrangler) getReplicationStatusFromRow(ctx context.Context, row sqltypes.RowNamedValues, primary *topo.TabletInfo) (*ReplicationStatus, string, error) { +func (wr *Wrangler) getReplicationStatusFromRow(ctx context.Context, row sqltypes.RowNamedValues, copyStates []copyState, primary *topo.TabletInfo) (*ReplicationStatus, string, error) { var err error var id int32 var timeUpdated, transactionTimestamp, timeHeartbeat, timeThrottled int64 @@ -688,11 +689,8 @@ func (wr *Wrangler) getReplicationStatusFromRow(ctx context.Context, row sqltype deferSecondaryKeys: deferSecondaryKeys, RowsCopied: rowsCopied, } - status.CopyState, err = wr.getCopyState(ctx, primary, id) - if err != nil { - return nil, "", err - } + status.CopyState = copyStates status.State = updateState(message, binlogdatapb.VReplicationWorkflowState(binlogdatapb.VReplicationWorkflowState_value[state]), status.CopyState, timeUpdated) return status, bls.Keyspace, nil } @@ -739,8 +737,27 @@ func (wr *Wrangler) getStreams(ctx context.Context, workflow, keyspace string) ( if len(nqr.Rows) == 0 { continue } + // Get all copy states for the shard. + var vreplIDs []int64 + for _, row := range nqr.Rows { + vreplID, err := row.ToInt64("id") + if err != nil { + return nil, err + } + vreplIDs = append(vreplIDs, vreplID) + } + copyStatesByVReplID, err := wr.getCopyStates(ctx, primary, vreplIDs) + if err != nil { + return nil, err + } for _, row := range nqr.Rows { - status, sk, err := wr.getReplicationStatusFromRow(ctx, row, primary) + vreplID, err := row.ToInt64("id") + if err != nil { + return nil, err + } + + copyStates := copyStatesByVReplID[vreplID] + status, sk, err := wr.getReplicationStatusFromRow(ctx, row, copyStates, primary) if err != nil { return nil, err } @@ -902,11 +919,16 @@ func (wr *Wrangler) printWorkflowList(keyspace string, workflows []string) { wr.Logger().Printf("Following workflow(s) found in keyspace %s: %v\n", keyspace, list) } -func (wr *Wrangler) getCopyState(ctx context.Context, tablet *topo.TabletInfo, id int32) ([]copyState, error) { - var cs []copyState - query := fmt.Sprintf("select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)", - id, id) - qr, err := wr.VReplicationExec(ctx, tablet.Alias, query) +func (wr *Wrangler) getCopyStates(ctx context.Context, tablet *topo.TabletInfo, ids []int64) (map[int64][]copyState, error) { + var idStrs []string + for _, id := range ids { + idStrs = append(idStrs, strconv.FormatInt(id, 10)) + } + idsStr := strings.Join(idStrs, ",") + cs := make(map[int64][]copyState) + query := fmt.Sprintf("select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (%s) and id in (select max(id) from _vt.copy_state where vrepl_id in (%s) group by vrepl_id, table_name)", + idsStr, idsStr) + qr, err := wr.tmc.VReplicationExec(ctx, tablet.Tablet, query) if err != nil { return nil, err } @@ -914,14 +936,18 @@ func (wr *Wrangler) getCopyState(ctx context.Context, tablet *topo.TabletInfo, i result := sqltypes.Proto3ToResult(qr) if result != nil { for _, row := range result.Rows { + vreplID, err := row[0].ToInt64() + if err != nil { + return nil, fmt.Errorf("failed to cast vrepl_id to int64: %v", err) + } // These fields are varbinary, but close enough - table := row[0].ToString() - lastPK := row[1].ToString() + table := row[1].ToString() + lastPK := row[2].ToString() copyState := copyState{ Table: table, LastPK: lastPK, } - cs = append(cs, copyState) + cs[vreplID] = append(cs[vreplID], copyState) } } diff --git a/go/vt/wrangler/workflow_test.go b/go/vt/wrangler/workflow_test.go index be3589a3f58..c29d5142295 100644 --- a/go/vt/wrangler/workflow_test.go +++ b/go/vt/wrangler/workflow_test.go @@ -109,18 +109,18 @@ func expectCanSwitchQueries(t *testing.T, tme *testMigraterEnv, keyspace, state "int64|varchar|int64|int64|int64|varchar|varchar|int64|int64|int64|int64|varchar|varchar|varchar"), row) copyStateResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields( - "table|lastpk", - "varchar|varchar"), - "t1|pk1", + "vrepl_id|table|lastpk", + "int64|varchar|varchar"), + "1|t1|pk1", ) for _, db := range tme.dbTargetClients { db.addInvariant(streamExtInfoKs2, replicationResult) if state == "Copying" { - db.addInvariant(fmt.Sprintf(copyStateQuery, 1, 1), copyStateResult) + db.addInvariant(fmt.Sprintf(copyStateQuery, "1", "1"), copyStateResult) } else { - db.addInvariant(fmt.Sprintf(copyStateQuery, 1, 1), noResult) + db.addInvariant(fmt.Sprintf(copyStateQuery, "1", "1"), noResult) } } } diff --git a/go/vt/wrangler/wrangler_env_test.go b/go/vt/wrangler/wrangler_env_test.go index 4dd5e342c35..109cb046633 100644 --- a/go/vt/wrangler/wrangler_env_test.go +++ b/go/vt/wrangler/wrangler_env_test.go @@ -164,12 +164,12 @@ func newWranglerTestEnv(t testing.TB, ctx context.Context, sourceShards, targetS env.tmc.setVRResults(primary.tablet, "select distinct workflow from _vt.vreplication where state != 'Stopped' and db_name = 'vt_target'", result) result = sqltypes.MakeTestResult(sqltypes.MakeTestFields( - "table|lastpk", - "varchar|varchar"), - "t1|pk1", + "vrepl_id|table|lastpk", + "int64|varchar|varchar"), + "1|t1|pk1", ) - env.tmc.setVRResults(primary.tablet, "select table_name, lastpk from _vt.copy_state where vrepl_id = 1 and id in (select max(id) from _vt.copy_state where vrepl_id = 1 group by vrepl_id, table_name)", result) + env.tmc.setVRResults(primary.tablet, "select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (1) and id in (select max(id) from _vt.copy_state where vrepl_id in (1) group by vrepl_id, table_name)", result) env.tmc.setVRResults(primary.tablet, "select id, source, pos, stop_pos, max_replication_lag, state, db_name, time_updated, transaction_timestamp, time_heartbeat, time_throttled, component_throttled, message, tags from _vt.vreplication where db_name = 'vt_target' and workflow = 'bad'", &sqltypes.Result{})