Skip to content

Commit

Permalink
go/vt/wrangler: reduce VReplicationExec calls when getting copy state
Browse files: browse the repository at this point in the history
Signed-off-by: Max Englander <[email protected]>
  • Loading branch information
maxenglander committed Oct 26, 2023
1 parent f8a274d commit cfc8b01
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 39 deletions.
18 changes: 9 additions & 9 deletions go/vt/vtctl/vtctl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,12 @@ func TestMoveTables(t *testing.T) {
expectResults: func() {
env.tmc.setVRResults(
target.tablet,
fmt.Sprintf("select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)",
fmt.Sprintf("select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (%d) and id in (select max(id) from _vt.copy_state where vrepl_id in (%d) group by vrepl_id, table_name)",
vrID, vrID),
sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"table_name|lastpk",
"varchar|varbinary"),
fmt.Sprintf("%s|", table),
"vrepl_id|table_name|lastpk",
"int64|varchar|varbinary"),
fmt.Sprintf("%d|%s|", vrID, table),
),
)
env.tmc.setDBAResults(
Expand Down Expand Up @@ -260,12 +260,12 @@ func TestMoveTables(t *testing.T) {
expectResults: func() {
env.tmc.setVRResults(
target.tablet,
fmt.Sprintf("select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)",
fmt.Sprintf("select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (%d) and id in (select max(id) from _vt.copy_state where vrepl_id in (%d) group by vrepl_id, table_name)",
vrID, vrID),
sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"table_name|lastpk",
"varchar|varbinary"),
fmt.Sprintf("%s|", table),
"vrepl_id|table_name|lastpk",
"int64|varchar|varbinary"),
fmt.Sprintf("%d|%s|", vrID, table),
),
)
env.tmc.setDBAResults(
Expand Down Expand Up @@ -320,7 +320,7 @@ func TestMoveTables(t *testing.T) {
expectResults: func() {
env.tmc.setVRResults(
target.tablet,
fmt.Sprintf("select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)",
fmt.Sprintf("select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (%d) and id in (select max(id) from _vt.copy_state where vrepl_id in (%d) group by vrepl_id, table_name)",
vrID, vrID),
&sqltypes.Result{},
)
Expand Down
34 changes: 27 additions & 7 deletions go/vt/wrangler/traffic_switcher_env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"context"
"fmt"
"math/rand"
"strconv"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -56,7 +58,7 @@ import (
const (
streamInfoQuery = "select id, source, message, cell, tablet_types, workflow_type, workflow_sub_type, defer_secondary_keys from _vt.vreplication where workflow='%s' and db_name='vt_%s'"
streamExtInfoQuery = "select id, source, pos, stop_pos, max_replication_lag, state, db_name, time_updated, transaction_timestamp, time_heartbeat, time_throttled, component_throttled, message, tags, workflow_type, workflow_sub_type, defer_secondary_keys, rows_copied from _vt.vreplication where db_name = 'vt_%s' and workflow = '%s'"
copyStateQuery = "select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)"
copyStateQuery = "select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (%s) and id in (select max(id) from _vt.copy_state where vrepl_id in (%s) group by vrepl_id, table_name)"
maxValForSequence = "select max(`id`) as maxval from `vt_%s`.`%s`"
)

Expand Down Expand Up @@ -298,6 +300,7 @@ func newTestTableMigraterCustom(ctx context.Context, t *testing.T, sourceShards,
for i, targetShard := range targetShards {
var streamInfoRows []string
var streamExtInfoRows []string
var vreplIDs []string
for j, sourceShard := range sourceShards {
bls := &binlogdatapb.BinlogSource{
Keyspace: "ks1",
Expand All @@ -314,8 +317,10 @@ func newTestTableMigraterCustom(ctx context.Context, t *testing.T, sourceShards,
}
streamInfoRows = append(streamInfoRows, fmt.Sprintf("%d|%v||||1|0|0", j+1, bls))
streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks1|%d|%d|0|0||1||0", j+1, now, now))
tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult)
vreplIDs = append(vreplIDs, strconv.FormatInt(int64(j+1), 10))
}
vreplIDsJoined := strings.Join(vreplIDs, ",")
tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult)
tme.dbTargetClients[i].addInvariant(streamInfoKs2, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys",
"int64|varchar|varchar|varchar|varchar|int64|int64|int64"),
Expand All @@ -332,6 +337,7 @@ func newTestTableMigraterCustom(ctx context.Context, t *testing.T, sourceShards,

for i, sourceShard := range sourceShards {
var streamInfoRows []string
var vreplIDs []string
for j, targetShard := range targetShards {
bls := &binlogdatapb.BinlogSource{
Keyspace: "ks2",
Expand All @@ -347,8 +353,10 @@ func newTestTableMigraterCustom(ctx context.Context, t *testing.T, sourceShards,
},
}
streamInfoRows = append(streamInfoRows, fmt.Sprintf("%d|%v||||1|0|0", j+1, bls))
tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult)
vreplIDs = append(vreplIDs, strconv.FormatInt(int64(j+1), 10))
}
vreplIDsJoined := strings.Join(vreplIDs, ",")
tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult)
tme.dbSourceClients[i].addInvariant(reverseStreamInfoKs1, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys",
"int64|varchar|varchar|varchar|varchar|int64|int64|int64"),
Expand Down Expand Up @@ -470,6 +478,7 @@ func newTestTablePartialMigrater(ctx context.Context, t *testing.T, shards, shar
for _, shardToMove := range shardsToMove {
var streamInfoRows []string
var streamExtInfoRows []string
var vreplIDs []string
if shardToMove == shard {
bls := &binlogdatapb.BinlogSource{
Keyspace: "ks1",
Expand All @@ -486,8 +495,10 @@ func newTestTablePartialMigrater(ctx context.Context, t *testing.T, shards, shar
}
streamInfoRows = append(streamInfoRows, fmt.Sprintf("%d|%v||||1|0|0", i+1, bls))
streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks1|%d|%d|0|0|||1||0", i+1, now, now))
vreplIDs = append(vreplIDs, strconv.FormatInt(int64(i+1), 10))
}
tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, i+1, i+1), noResult)
vreplIDsJoined := strings.Join(vreplIDs, ",")
tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult)
tme.dbTargetClients[i].addInvariant(streamInfoKs2, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys",
"int64|varchar|varchar|varchar|varchar|int64|int64|int64"),
Expand All @@ -506,6 +517,7 @@ func newTestTablePartialMigrater(ctx context.Context, t *testing.T, shards, shar
for i, shard := range shards {
for _, shardToMove := range shardsToMove {
var streamInfoRows []string
var vreplIDs []string
if shardToMove == shard {
bls := &binlogdatapb.BinlogSource{
Keyspace: "ks2",
Expand All @@ -521,8 +533,10 @@ func newTestTablePartialMigrater(ctx context.Context, t *testing.T, shards, shar
},
}
streamInfoRows = append(streamInfoRows, fmt.Sprintf("%d|%v||||1|0|0", i+1, bls))
tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, i+1, i+1), noResult)
vreplIDs = append(vreplIDs, strconv.FormatInt(int64(i+1), 10))
}
vreplIDsJoined := strings.Join(vreplIDs, ",")
tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult)
tme.dbSourceClients[i].addInvariant(reverseStreamInfoKs1, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys",
"int64|varchar|varchar|varchar|varchar|int64|int64|int64"),
Expand Down Expand Up @@ -632,6 +646,7 @@ func newTestShardMigrater(ctx context.Context, t *testing.T, sourceShards, targe
for i, targetShard := range targetShards {
var rows, rowsRdOnly []string
var streamExtInfoRows []string
var vreplIDs []string
for j, sourceShard := range sourceShards {
if !key.KeyRangeIntersect(tme.targetKeyRanges[i], tme.sourceKeyRanges[j]) {
continue
Expand All @@ -649,8 +664,10 @@ func newTestShardMigrater(ctx context.Context, t *testing.T, sourceShards, targe
rows = append(rows, fmt.Sprintf("%d|%v||||1|0|0", j+1, bls))
rowsRdOnly = append(rows, fmt.Sprintf("%d|%v|||RDONLY|1|0|0", j+1, bls))
streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks1|%d|%d|0|0|||", j+1, now, now))
tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult)
vreplIDs = append(vreplIDs, strconv.FormatInt(int64(j+1), 10))
}
vreplIDsJoined := strings.Join(vreplIDs, ",")
tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult)
tme.dbTargetClients[i].addInvariant(streamInfoKs, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys",
"int64|varchar|varchar|varchar|varchar|int64|int64|int64"),
Expand All @@ -670,11 +687,14 @@ func newTestShardMigrater(ctx context.Context, t *testing.T, sourceShards, targe
tme.targetKeyspace = "ks"
for i, dbclient := range tme.dbSourceClients {
var streamExtInfoRows []string
var vreplIDs []string
dbclient.addInvariant(streamInfoKs, &sqltypes.Result{})
for j := range targetShards {
streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks|%d|%d|0|0|||", j+1, now, now))
tme.dbSourceClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult)
vreplIDs = append(vreplIDs, strconv.FormatInt(int64(j+1), 10))
}
vreplIDsJoined := strings.Join(vreplIDs, ",")
tme.dbSourceClients[i].addInvariant(fmt.Sprintf(copyStateQuery, vreplIDsJoined, vreplIDsJoined), noResult)
tme.dbSourceClients[i].addInvariant(streamExtInfoKs, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"id|source|pos|stop_pos|max_replication_lag|state|db_name|time_updated|transaction_timestamp|time_heartbeat|time_throttled|component_throttled|message|tags",
"int64|varchar|int64|int64|int64|varchar|varchar|int64|int64|int64|int64|varchar|varchar|varchar"),
Expand Down
54 changes: 40 additions & 14 deletions go/vt/wrangler/vexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"math"
"sort"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -582,7 +583,7 @@ type ReplicationStatus struct {
deferSecondaryKeys bool
}

func (wr *Wrangler) getReplicationStatusFromRow(ctx context.Context, row sqltypes.RowNamedValues, primary *topo.TabletInfo) (*ReplicationStatus, string, error) {
func (wr *Wrangler) getReplicationStatusFromRow(ctx context.Context, row sqltypes.RowNamedValues, copyStates []copyState, primary *topo.TabletInfo) (*ReplicationStatus, string, error) {
var err error
var id int32
var timeUpdated, transactionTimestamp, timeHeartbeat, timeThrottled int64
Expand Down Expand Up @@ -688,11 +689,8 @@ func (wr *Wrangler) getReplicationStatusFromRow(ctx context.Context, row sqltype
deferSecondaryKeys: deferSecondaryKeys,
RowsCopied: rowsCopied,
}
status.CopyState, err = wr.getCopyState(ctx, primary, id)
if err != nil {
return nil, "", err
}

status.CopyState = copyStates
status.State = updateState(message, binlogdatapb.VReplicationWorkflowState(binlogdatapb.VReplicationWorkflowState_value[state]), status.CopyState, timeUpdated)
return status, bls.Keyspace, nil
}
Expand Down Expand Up @@ -739,8 +737,27 @@ func (wr *Wrangler) getStreams(ctx context.Context, workflow, keyspace string) (
if len(nqr.Rows) == 0 {
continue
}
// Get all copy states for the shard.
var vreplIDs []int64
for _, row := range nqr.Rows {
vreplID, err := row.ToInt64("id")
if err != nil {
return nil, err
}
vreplIDs = append(vreplIDs, vreplID)
}
copyStatesByVReplID, err := wr.getCopyStates(ctx, primary, vreplIDs)
if err != nil {
return nil, err
}
for _, row := range nqr.Rows {
status, sk, err := wr.getReplicationStatusFromRow(ctx, row, primary)
vreplID, err := row.ToInt64("id")
if err != nil {
return nil, err
}

copyStates := copyStatesByVReplID[vreplID]
status, sk, err := wr.getReplicationStatusFromRow(ctx, row, copyStates, primary)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -902,26 +919,35 @@ func (wr *Wrangler) printWorkflowList(keyspace string, workflows []string) {
wr.Logger().Printf("Following workflow(s) found in keyspace %s: %v\n", keyspace, list)
}

func (wr *Wrangler) getCopyState(ctx context.Context, tablet *topo.TabletInfo, id int32) ([]copyState, error) {
var cs []copyState
query := fmt.Sprintf("select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)",
id, id)
qr, err := wr.VReplicationExec(ctx, tablet.Alias, query)
func (wr *Wrangler) getCopyStates(ctx context.Context, tablet *topo.TabletInfo, ids []int64) (map[int64][]copyState, error) {
var idStrs []string
for _, id := range ids {
idStrs = append(idStrs, strconv.FormatInt(id, 10))
}
idsStr := strings.Join(idStrs, ",")
cs := make(map[int64][]copyState)
query := fmt.Sprintf("select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (%s) and id in (select max(id) from _vt.copy_state where vrepl_id in (%s) group by vrepl_id, table_name)",
idsStr, idsStr)
qr, err := wr.tmc.VReplicationExec(ctx, tablet.Tablet, query)
if err != nil {
return nil, err
}

result := sqltypes.Proto3ToResult(qr)
if result != nil {
for _, row := range result.Rows {
vreplID, err := row[0].ToInt64()
if err != nil {
return nil, fmt.Errorf("failed to cast vrepl_id to int64: %v", err)
}
// These fields are varbinary, but close enough
table := row[0].ToString()
lastPK := row[1].ToString()
table := row[1].ToString()
lastPK := row[2].ToString()
copyState := copyState{
Table: table,
LastPK: lastPK,
}
cs = append(cs, copyState)
cs[vreplID] = append(cs[vreplID], copyState)
}
}

Expand Down
10 changes: 5 additions & 5 deletions go/vt/wrangler/workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,18 @@ func expectCanSwitchQueries(t *testing.T, tme *testMigraterEnv, keyspace, state
"int64|varchar|int64|int64|int64|varchar|varchar|int64|int64|int64|int64|varchar|varchar|varchar"),
row)
copyStateResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"table|lastpk",
"varchar|varchar"),
"t1|pk1",
"vrepl_id|table|lastpk",
"int64|varchar|varchar"),
"1|t1|pk1",
)

for _, db := range tme.dbTargetClients {
db.addInvariant(streamExtInfoKs2, replicationResult)

if state == "Copying" {
db.addInvariant(fmt.Sprintf(copyStateQuery, 1, 1), copyStateResult)
db.addInvariant(fmt.Sprintf(copyStateQuery, "1", "1"), copyStateResult)
} else {
db.addInvariant(fmt.Sprintf(copyStateQuery, 1, 1), noResult)
db.addInvariant(fmt.Sprintf(copyStateQuery, "1", "1"), noResult)
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions go/vt/wrangler/wrangler_env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,12 @@ func newWranglerTestEnv(t testing.TB, ctx context.Context, sourceShards, targetS
env.tmc.setVRResults(primary.tablet, "select distinct workflow from _vt.vreplication where state != 'Stopped' and db_name = 'vt_target'", result)

result = sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"table|lastpk",
"varchar|varchar"),
"t1|pk1",
"vrepl_id|table|lastpk",
"int64|varchar|varchar"),
"1|t1|pk1",
)

env.tmc.setVRResults(primary.tablet, "select table_name, lastpk from _vt.copy_state where vrepl_id = 1 and id in (select max(id) from _vt.copy_state where vrepl_id = 1 group by vrepl_id, table_name)", result)
env.tmc.setVRResults(primary.tablet, "select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (1) and id in (select max(id) from _vt.copy_state where vrepl_id in (1) group by vrepl_id, table_name)", result)

env.tmc.setVRResults(primary.tablet, "select id, source, pos, stop_pos, max_replication_lag, state, db_name, time_updated, transaction_timestamp, time_heartbeat, time_throttled, component_throttled, message, tags from _vt.vreplication where db_name = 'vt_target' and workflow = 'bad'", &sqltypes.Result{})

Expand Down

0 comments on commit cfc8b01

Please sign in to comment.