diff --git a/go/cmd/vtctldclient/command/root.go b/go/cmd/vtctldclient/command/root.go index 0fc39423bc7..048b0786cb4 100644 --- a/go/cmd/vtctldclient/command/root.go +++ b/go/cmd/vtctldclient/command/root.go @@ -218,7 +218,7 @@ func getClientForCommand(cmd *cobra.Command) (vtctldclient.VtctldClient, error) server = "" } - return vtctldclient.New(VtctldClientProtocol, server) + return vtctldclient.New(cmd.Context(), VtctldClientProtocol, server) } func init() { diff --git a/go/cmd/vtctldclient/command/vreplication/common/utils_test.go b/go/cmd/vtctldclient/command/vreplication/common/utils_test.go index a8a0df2e9b2..39de482da2c 100644 --- a/go/cmd/vtctldclient/command/vreplication/common/utils_test.go +++ b/go/cmd/vtctldclient/command/vreplication/common/utils_test.go @@ -148,7 +148,7 @@ func SetupLocalVtctldClient(t *testing.T, ctx context.Context, cells ...string) vtctld := grpcvtctldserver.NewVtctldServer(vtenv.NewTestEnv(), ts) localvtctldclient.SetServer(vtctld) command.VtctldClientProtocol = "local" - client, err := vtctldclient.New(command.VtctldClientProtocol, "") + client, err := vtctldclient.New(ctx, command.VtctldClientProtocol, "") require.NoError(t, err, "failed to create local vtctld client which uses an internal vtctld server") common.SetClient(client) } diff --git a/go/cmd/vtctldclient/command/vreplication/vdiff/vdiff_env_test.go b/go/cmd/vtctldclient/command/vreplication/vdiff/vdiff_env_test.go index b42c2c55072..9c98338bf67 100644 --- a/go/cmd/vtctldclient/command/vreplication/vdiff/vdiff_env_test.go +++ b/go/cmd/vtctldclient/command/vreplication/vdiff/vdiff_env_test.go @@ -89,7 +89,7 @@ func newTestVDiffEnv(t testing.TB, ctx context.Context, sourceShards, targetShar // Generate a unique dialer name. dialerName := fmt.Sprintf("VDiffTest-%s-%d", t.Name(), rand.IntN(1000000000)) - tabletconn.RegisterDialer(dialerName, func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer(dialerName, func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { env.mu.Lock() defer env.mu.Unlock() if qs, ok := env.tablets[int(tablet.Alias.Uid)]; ok { diff --git a/go/cmd/vttestserver/cli/main_test.go b/go/cmd/vttestserver/cli/main_test.go index 0ea0e6b7c19..75597ffe687 100644 --- a/go/cmd/vttestserver/cli/main_test.go +++ b/go/cmd/vttestserver/cli/main_test.go @@ -238,6 +238,9 @@ func TestCanGetKeyspaces(t *testing.T) { conf := config defer resetConfig(conf) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + clusterInstance, err := startCluster() assert.NoError(t, err) defer clusterInstance.TearDown() @@ -248,13 +251,16 @@ func TestCanGetKeyspaces(t *testing.T) { } }() - assertGetKeyspaces(t, clusterInstance) + assertGetKeyspaces(ctx, t, clusterInstance) } func TestExternalTopoServerConsul(t *testing.T) { conf := config defer resetConfig(conf) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + // Start a single consul in the background. cmd, serverAddr := startConsul(t) defer func() { @@ -273,7 +279,7 @@ func TestExternalTopoServerConsul(t *testing.T) { assert.NoError(t, err) defer cluster.TearDown() - assertGetKeyspaces(t, cluster) + assertGetKeyspaces(ctx, t, cluster) } func TestMtlsAuth(t *testing.T) { @@ -445,12 +451,12 @@ func randomPort() int { return int(v + 10000) } -func assertGetKeyspaces(t *testing.T, cluster vttest.LocalCluster) { - client, err := vtctlclient.New(fmt.Sprintf("localhost:%v", cluster.GrpcPort())) +func assertGetKeyspaces(ctx context.Context, t *testing.T, cluster vttest.LocalCluster) { + client, err := vtctlclient.New(ctx, fmt.Sprintf("localhost:%v", cluster.GrpcPort())) assert.NoError(t, err) defer client.Close() stream, err := client.ExecuteVtctlCommand( - context.Background(), + ctx, []string{ "GetKeyspaces", "--server", diff --git a/go/test/endtoend/cluster/cluster_process.go b/go/test/endtoend/cluster/cluster_process.go index 3ef4e8a1b3b..0fc5edef1bb 100644 --- a/go/test/endtoend/cluster/cluster_process.go +++ b/go/test/endtoend/cluster/cluster_process.go @@ -897,7 +897,7 @@ func (cluster *LocalProcessCluster) ExecOnTablet(ctx context.Context, vttablet * return nil, err } - conn, err := tabletconn.GetDialer()(tablet, grpcclient.FailFast(false)) + conn, err := tabletconn.GetDialer()(ctx, tablet, grpcclient.FailFast(false)) if err != nil { return nil, err } @@ -940,7 +940,7 @@ func (cluster *LocalProcessCluster) StreamTabletHealth(ctx context.Context, vtta return nil, err } - conn, err := tabletconn.GetDialer()(tablet, grpcclient.FailFast(false)) + conn, err := tabletconn.GetDialer()(ctx, tablet, grpcclient.FailFast(false)) if err != nil { return nil, err } @@ -975,7 +975,7 @@ func (cluster *LocalProcessCluster) StreamTabletHealthUntil(ctx context.Context, return err } - conn, err := tabletconn.GetDialer()(tablet, grpcclient.FailFast(false)) + conn, err := tabletconn.GetDialer()(ctx, tablet, grpcclient.FailFast(false)) if err != nil { return err } diff --git a/go/test/endtoend/encryption/encryptedtransport/encrypted_transport_test.go b/go/test/endtoend/encryption/encryptedtransport/encrypted_transport_test.go index b076006ec2c..9147b7b9080 100644 --- a/go/test/endtoend/encryption/encryptedtransport/encrypted_transport_test.go +++ b/go/test/endtoend/encryption/encryptedtransport/encrypted_transport_test.go @@ -177,7 +177,7 @@ func TestSecureTransport(t *testing.T) { setCreds(t, "vtgate-client-1", "vtgate-server") ctx := context.Background() request := getRequest("select * from vt_insert_test") - vc, err := getVitessClient(grpcAddress) + vc, err := getVitessClient(ctx, grpcAddress) require.NoError(t, err) qr, err := vc.Execute(ctx, request) @@ -188,7 +188,7 @@ func TestSecureTransport(t *testing.T) { // 'vtgate client 2' is not authorized to access vt_insert_test setCreds(t, "vtgate-client-2", "vtgate-server") request = getRequest("select * from vt_insert_test") - vc, err = getVitessClient(grpcAddress) + vc, err = getVitessClient(ctx, grpcAddress) require.NoError(t, err) qr, err = vc.Execute(ctx, request) require.NoError(t, err) @@ -217,7 +217,7 @@ func useEffectiveCallerID(ctx context.Context, t *testing.T) { setSSLInfoEmpty() // get vitess client - vc, err := getVitessClient(grpcAddress) + vc, err := getVitessClient(ctx, grpcAddress) require.NoError(t, err) // test with empty effective caller Id @@ -266,7 +266,7 @@ func useEffectiveGroups(ctx context.Context, t *testing.T) { setSSLInfoEmpty() // get vitess client - vc, err := getVitessClient(grpcAddress) + vc, err := getVitessClient(ctx, grpcAddress) require.NoError(t, err) // test with empty effective caller Id @@ -452,12 +452,12 @@ func tabletConnExtraArgs(name string) []string { return args } -func getVitessClient(addr string) (vtgateservicepb.VitessClient, error) { +func getVitessClient(ctx context.Context, addr string) (vtgateservicepb.VitessClient, error) { opt, err := grpcclient.SecureDialOption(grpcCert, grpcKey, grpcCa, "", grpcName) if err != nil { return nil, err } - cc, err := grpcclient.Dial(addr, grpcclient.FailFast(false), opt) + cc, err := grpcclient.DialContext(ctx, addr, grpcclient.FailFast(false), opt) if err != nil { return nil, err } diff --git a/go/test/endtoend/mysqlctld/mysqlctld_test.go b/go/test/endtoend/mysqlctld/mysqlctld_test.go index 52be2fa4323..328bc563377 100644 --- a/go/test/endtoend/mysqlctld/mysqlctld_test.go +++ b/go/test/endtoend/mysqlctld/mysqlctld_test.go @@ -164,7 +164,7 @@ func TestAutoDetect(t *testing.T) { } func TestVersionString(t *testing.T) { - client, err := mysqlctlclient.New("unix", primaryTablet.MysqlctldProcess.SocketFile) + client, err := mysqlctlclient.New(context.Background(), "unix", primaryTablet.MysqlctldProcess.SocketFile) require.NoError(t, err) version, err := client.VersionString(context.Background()) require.NoError(t, err) @@ -172,7 +172,7 @@ func TestVersionString(t *testing.T) { } func TestReadBinlogFilesTimestamps(t *testing.T) { - client, err := mysqlctlclient.New("unix", primaryTablet.MysqlctldProcess.SocketFile) + client, err := mysqlctlclient.New(context.Background(), "unix", primaryTablet.MysqlctldProcess.SocketFile) require.NoError(t, err) _, err = client.ReadBinlogFilesTimestamps(context.Background(), &mysqlctl.ReadBinlogFilesTimestampsRequest{}) require.ErrorContains(t, err, "empty binlog list in ReadBinlogFilesTimestampsRequest") diff --git a/go/test/endtoend/reparent/prssettingspool/main_test.go b/go/test/endtoend/reparent/prssettingspool/main_test.go index a9f4312caea..872f1867c77 100644 --- a/go/test/endtoend/reparent/prssettingspool/main_test.go +++ b/go/test/endtoend/reparent/prssettingspool/main_test.go @@ -104,13 +104,13 @@ func TestSettingsPoolWithTXAndPRS(t *testing.T) { // prs should happen without any error. text, err := rutils.Prs(t, clusterInstance, tablets[1]) require.NoError(t, err, text) - rutils.WaitForTabletToBeServing(t, clusterInstance, tablets[0], 1*time.Minute) + rutils.WaitForTabletToBeServing(ctx, t, clusterInstance, tablets[0], 1*time.Minute) defer func() { // reset state text, err = rutils.Prs(t, clusterInstance, tablets[0]) require.NoError(t, err, text) - rutils.WaitForTabletToBeServing(t, clusterInstance, tablets[1], 1*time.Minute) + rutils.WaitForTabletToBeServing(ctx, t, clusterInstance, tablets[1], 1*time.Minute) }() // no error should occur and it should go to the right tablet. @@ -134,12 +134,12 @@ func TestSettingsPoolWithoutTXAndPRS(t *testing.T) { // prs should happen without any error. text, err := rutils.Prs(t, clusterInstance, tablets[1]) require.NoError(t, err, text) - rutils.WaitForTabletToBeServing(t, clusterInstance, tablets[0], 1*time.Minute) + rutils.WaitForTabletToBeServing(ctx, t, clusterInstance, tablets[0], 1*time.Minute) defer func() { // reset state text, err = rutils.Prs(t, clusterInstance, tablets[0]) require.NoError(t, err, text) - rutils.WaitForTabletToBeServing(t, clusterInstance, tablets[1], 1*time.Minute) + rutils.WaitForTabletToBeServing(ctx, t, clusterInstance, tablets[1], 1*time.Minute) }() // no error should occur and it should go to the right tablet. diff --git a/go/test/endtoend/reparent/utils/utils.go b/go/test/endtoend/reparent/utils/utils.go index fb782e69ea4..91fa4c66e3c 100644 --- a/go/test/endtoend/reparent/utils/utils.go +++ b/go/test/endtoend/reparent/utils/utils.go @@ -728,11 +728,11 @@ func CheckReplicationStatus(ctx context.Context, t *testing.T, tablet *cluster.V } } -func WaitForTabletToBeServing(t *testing.T, clusterInstance *cluster.LocalProcessCluster, tablet *cluster.Vttablet, timeout time.Duration) { +func WaitForTabletToBeServing(ctx context.Context, t *testing.T, clusterInstance *cluster.LocalProcessCluster, tablet *cluster.Vttablet, timeout time.Duration) { vTablet, err := clusterInstance.VtctldClientProcess.GetTablet(tablet.Alias) require.NoError(t, err) - tConn, err := tabletconn.GetDialer()(vTablet, false) + tConn, err := tabletconn.GetDialer()(ctx, vTablet, false) require.NoError(t, err) newCtx, cancel := context.WithTimeout(context.Background(), timeout) diff --git a/go/vt/binlog/binlogplayer/binlog_player.go b/go/vt/binlog/binlogplayer/binlog_player.go index ea2c9c63a51..711ea29a2e9 100644 --- a/go/vt/binlog/binlogplayer/binlog_player.go +++ b/go/vt/binlog/binlogplayer/binlog_player.go @@ -329,7 +329,7 @@ func (blp *BinlogPlayer) applyEvents(ctx context.Context) error { return fmt.Errorf("no binlog player client factory named %v", binlogPlayerProtocol) } blplClient := clientFactory() - err = blplClient.Dial(blp.tablet) + err = blplClient.Dial(ctx, blp.tablet) if err != nil { err := fmt.Errorf("error dialing binlog server: %v", err) log.Error(err) diff --git a/go/vt/binlog/binlogplayer/client.go b/go/vt/binlog/binlogplayer/client.go index d234a439845..3aaad1a705c 100644 --- a/go/vt/binlog/binlogplayer/client.go +++ b/go/vt/binlog/binlogplayer/client.go @@ -53,7 +53,7 @@ type BinlogTransactionStream interface { // Client is the interface all clients must satisfy type Client interface { // Dial a server - Dial(tablet *topodatapb.Tablet) error + Dial(ctx context.Context, tablet *topodatapb.Tablet) error // Close the connection Close() diff --git a/go/vt/binlog/binlogplayer/framework_test.go b/go/vt/binlog/binlogplayer/framework_test.go index 4bb61aa70a9..5455e7cc1bf 100644 --- a/go/vt/binlog/binlogplayer/framework_test.go +++ b/go/vt/binlog/binlogplayer/framework_test.go @@ -46,7 +46,7 @@ func newFakeBinlogClient() *fakeBinlogClient { return globalFBC } -func (fbc *fakeBinlogClient) Dial(tablet *topodatapb.Tablet) error { +func (fbc *fakeBinlogClient) Dial(ctx context.Context, tablet *topodatapb.Tablet) error { fbc.lastTablet = tablet return nil } diff --git a/go/vt/binlog/binlogplayertest/player.go b/go/vt/binlog/binlogplayertest/player.go index e3468f92913..028f027ab3d 100644 --- a/go/vt/binlog/binlogplayertest/player.go +++ b/go/vt/binlog/binlogplayertest/player.go @@ -17,13 +17,12 @@ limitations under the License. package binlogplayertest import ( + "context" "fmt" "reflect" "strings" "testing" - "context" - "google.golang.org/protobuf/proto" "vitess.io/vitess/go/vt/binlog/binlogplayer" @@ -227,8 +226,8 @@ func (fake *FakeBinlogStreamer) HandlePanic(err *error) { } // Run runs the test suite -func Run(t *testing.T, bpc binlogplayer.Client, tablet *topodatapb.Tablet, fake *FakeBinlogStreamer) { - if err := bpc.Dial(tablet); err != nil { +func Run(ctx context.Context, t *testing.T, bpc binlogplayer.Client, tablet *topodatapb.Tablet, fake *FakeBinlogStreamer) { + if err := bpc.Dial(ctx, tablet); err != nil { t.Fatalf("Dial failed: %v", err) } diff --git a/go/vt/binlog/grpcbinlogplayer/player.go b/go/vt/binlog/grpcbinlogplayer/player.go index 1d5111aa5b0..014860ccdaf 100644 --- a/go/vt/binlog/grpcbinlogplayer/player.go +++ b/go/vt/binlog/grpcbinlogplayer/player.go @@ -52,14 +52,14 @@ type client struct { c binlogservicepb.UpdateStreamClient } -func (client *client) Dial(tablet *topodatapb.Tablet) error { +func (client *client) Dial(ctx context.Context, tablet *topodatapb.Tablet) error { addr := netutil.JoinHostPort(tablet.Hostname, tablet.PortMap["grpc"]) var err error opt, err := grpcclient.SecureDialOption(cert, key, ca, crl, name) if err != nil { return err } - client.cc, err = grpcclient.Dial(addr, grpcclient.FailFast(true), opt) + client.cc, err = grpcclient.DialContext(ctx, addr, grpcclient.FailFast(true), opt) if err != nil { return err } diff --git a/go/vt/binlog/grpcbinlogplayer/player_test.go b/go/vt/binlog/grpcbinlogplayer/player_test.go index bde54cd2113..b290782f015 100644 --- a/go/vt/binlog/grpcbinlogplayer/player_test.go +++ b/go/vt/binlog/grpcbinlogplayer/player_test.go @@ -17,6 +17,7 @@ limitations under the License. package grpcbinlogplayer import ( + "context" "net" "testing" @@ -48,9 +49,11 @@ func TestGRPCBinlogStreamer(t *testing.T) { // Create a GRPC client to talk to the fake tablet c := &client{} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // and send it to the test suite - binlogplayertest.Run(t, c, &topodatapb.Tablet{ + binlogplayertest.Run(ctx, t, c, &topodatapb.Tablet{ Hostname: host, PortMap: map[string]int32{ "grpc": int32(port), diff --git a/go/vt/discovery/fake_healthcheck.go b/go/vt/discovery/fake_healthcheck.go index 1c83de5b149..d1bde350276 100644 --- a/go/vt/discovery/fake_healthcheck.go +++ b/go/vt/discovery/fake_healthcheck.go @@ -229,7 +229,7 @@ func (fhc *FakeHealthCheck) ReplaceTablet(old, new *topodatapb.Tablet) { } // TabletConnection returns the TabletConn of the given tablet. -func (fhc *FakeHealthCheck) TabletConnection(alias *topodatapb.TabletAlias, target *querypb.Target) (queryservice.QueryService, error) { +func (fhc *FakeHealthCheck) TabletConnection(ctx context.Context, alias *topodatapb.TabletAlias, target *querypb.Target) (queryservice.QueryService, error) { aliasStr := topoproto.TabletAliasString(alias) fhc.mu.RLock() defer fhc.mu.RUnlock() diff --git a/go/vt/discovery/healthcheck.go b/go/vt/discovery/healthcheck.go index f37c9ad1d8b..46d92c7364e 100644 --- a/go/vt/discovery/healthcheck.go +++ b/go/vt/discovery/healthcheck.go @@ -214,7 +214,7 @@ type HealthCheck interface { WaitForAllServingTablets(ctx context.Context, targets []*query.Target) error // TabletConnection returns the TabletConn of the given tablet. - TabletConnection(alias *topodata.TabletAlias, target *query.Target) (queryservice.QueryService, error) + TabletConnection(ctx context.Context, alias *topodata.TabletAlias, target *query.Target) (queryservice.QueryService, error) // RegisterStats registers the connection counts stats RegisterStats() @@ -828,7 +828,7 @@ func (hc *HealthCheckImpl) GetTabletHealth(kst KeyspaceShardTabletType, alias *t } // TabletConnection returns the Connection to a given tablet. -func (hc *HealthCheckImpl) TabletConnection(alias *topodata.TabletAlias, target *query.Target) (queryservice.QueryService, error) { +func (hc *HealthCheckImpl) TabletConnection(ctx context.Context, alias *topodata.TabletAlias, target *query.Target) (queryservice.QueryService, error) { hc.mu.Lock() thc := hc.healthByAlias[tabletAliasString(topoproto.TabletAliasString(alias))] hc.mu.Unlock() @@ -836,7 +836,7 @@ func (hc *HealthCheckImpl) TabletConnection(alias *topodata.TabletAlias, target // TODO: test that throws this error return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "tablet: %v is either down or nonexistent", alias) } - return thc.Connection(hc), nil + return thc.Connection(ctx, hc), nil } // getAliasByCell should only be called while holding hc.mu diff --git a/go/vt/discovery/healthcheck_test.go b/go/vt/discovery/healthcheck_test.go index 31376bd8c7d..c87ba699234 100644 --- a/go/vt/discovery/healthcheck_test.go +++ b/go/vt/discovery/healthcheck_test.go @@ -1283,7 +1283,7 @@ func TestDebugURLFormatting(t *testing.T) { require.Contains(t, wr.String(), expectedURL, "output missing formatted URL") } -func tabletDialer(tablet *topodatapb.Tablet, _ grpcclient.FailFast) (queryservice.QueryService, error) { +func tabletDialer(ctx context.Context, tablet *topodatapb.Tablet, _ grpcclient.FailFast) (queryservice.QueryService, error) { connMapMu.Lock() defer connMapMu.Unlock() diff --git a/go/vt/discovery/tablet_health_check.go b/go/vt/discovery/tablet_health_check.go index fc3ab242210..64450f4c8c6 100644 --- a/go/vt/discovery/tablet_health_check.go +++ b/go/vt/discovery/tablet_health_check.go @@ -128,7 +128,7 @@ func (thc *tabletHealthCheck) setServingState(serving bool, reason string) { // stream streams healthcheck responses to callback. func (thc *tabletHealthCheck) stream(ctx context.Context, hc *HealthCheckImpl, callback func(*query.StreamHealthResponse) error) error { - conn := thc.Connection(hc) + conn := thc.Connection(ctx, hc) if conn == nil { // This signals the caller to retry return nil @@ -141,10 +141,10 @@ func (thc *tabletHealthCheck) stream(ctx context.Context, hc *HealthCheckImpl, c return err } -func (thc *tabletHealthCheck) Connection(hc *HealthCheckImpl) queryservice.QueryService { +func (thc *tabletHealthCheck) Connection(ctx context.Context, hc *HealthCheckImpl) queryservice.QueryService { thc.connMu.Lock() defer thc.connMu.Unlock() - return thc.connectionLocked(hc) + return thc.connectionLocked(ctx, hc) } func healthCheckDialerFactory(hc *HealthCheckImpl) func(ctx context.Context, addr string) (net.Conn, error) { @@ -162,14 +162,14 @@ func healthCheckDialerFactory(hc *HealthCheckImpl) func(ctx context.Context, add } } -func (thc *tabletHealthCheck) connectionLocked(hc *HealthCheckImpl) queryservice.QueryService { +func (thc *tabletHealthCheck) connectionLocked(ctx context.Context, hc *HealthCheckImpl) queryservice.QueryService { if thc.Conn == nil { withDialerContextOnce.Do(func() { grpcclient.RegisterGRPCDialOptions(func(opts []grpc.DialOption) ([]grpc.DialOption, error) { return append(opts, grpc.WithContextDialer(healthCheckDialerFactory(hc))), nil }) }) - conn, err := tabletconn.GetDialer()(thc.Tablet, grpcclient.FailFast(true)) + conn, err := tabletconn.GetDialer()(ctx, thc.Tablet, grpcclient.FailFast(true)) if err != nil { thc.LastError = err return nil diff --git a/go/vt/discovery/tablet_picker.go b/go/vt/discovery/tablet_picker.go index 4cfb6020fe6..fd1ff64a3ce 100644 --- a/go/vt/discovery/tablet_picker.go +++ b/go/vt/discovery/tablet_picker.go @@ -457,7 +457,7 @@ func (tp *TabletPicker) GetMatchingTablets(ctx context.Context) []*topo.TabletIn log.Warningf("Tablet picker failed to load tablet %v", tabletAlias) } else if topoproto.IsTypeInList(tabletInfo.Type, tp.tabletTypes) { // Try to connect to the tablet and confirm that it's usable. - if conn, err := tabletconn.GetDialer()(tabletInfo.Tablet, grpcclient.FailFast(true)); err == nil { + if conn, err := tabletconn.GetDialer()(ctx, tabletInfo.Tablet, grpcclient.FailFast(true)); err == nil { // Ensure that the tablet is healthy and serving. shortCtx, cancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) defer cancel() diff --git a/go/vt/grpcclient/client.go b/go/vt/grpcclient/client.go index 7524298514e..e9209277b7c 100644 --- a/go/vt/grpcclient/client.go +++ b/go/vt/grpcclient/client.go @@ -93,13 +93,6 @@ func RegisterGRPCDialOptions(grpcDialOptionsFunc func(opts []grpc.DialOption) ([ grpcDialOptions = append(grpcDialOptions, grpcDialOptionsFunc) } -// Dial creates a grpc connection to the given target. -// failFast is a non-optional parameter because callers are required to specify -// what that should be. -func Dial(target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - return DialContext(context.Background(), target, failFast, opts...) -} - // DialContext creates a grpc connection to the given target. Setup steps are // covered by the context deadline, and, if WithBlock is specified in the dial // options, connection establishment steps are covered by the context as well. diff --git a/go/vt/grpcclient/client_test.go b/go/vt/grpcclient/client_test.go index 40b03bef2f6..369ec8da17b 100644 --- a/go/vt/grpcclient/client_test.go +++ b/go/vt/grpcclient/client_test.go @@ -41,33 +41,36 @@ func TestDialErrors(t *testing.T) { } wantErr := "Unavailable" for _, address := range addresses { - gconn, err := Dial(address, true, grpc.WithTransportCredentials(insecure.NewCredentials())) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + gconn, err := DialContext(ctx, address, true, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { + cancel() t.Fatal(err) } vtg := vtgateservicepb.NewVitessClient(gconn) - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) _, err = vtg.Execute(ctx, &vtgatepb.ExecuteRequest{}) cancel() gconn.Close() if err == nil || !strings.Contains(err.Error(), wantErr) { - t.Errorf("Dial(%s, FailFast=true): %v, must contain %s", address, err, wantErr) + t.Errorf("DialContext(%s, FailFast=true): %v, must contain %s", address, err, wantErr) } } wantErr = "DeadlineExceeded" for _, address := range addresses { - gconn, err := Dial(address, false, grpc.WithTransportCredentials(insecure.NewCredentials())) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + gconn, err := DialContext(ctx, address, false, grpc.WithTransportCredentials(insecure.NewCredentials())) + cancel() if err != nil { t.Fatal(err) } vtg := vtgateservicepb.NewVitessClient(gconn) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond) _, err = vtg.Execute(ctx, &vtgatepb.ExecuteRequest{}) cancel() gconn.Close() if err == nil || !strings.Contains(err.Error(), wantErr) { - t.Errorf("Dial(%s, FailFast=false): %v, must contain %s", address, err, wantErr) + t.Errorf("DialContext(%s, FailFast=false): %v, must contain %s", address, err, wantErr) } } } diff --git a/go/vt/mysqlctl/grpcmysqlctlclient/client.go b/go/vt/mysqlctl/grpcmysqlctlclient/client.go index 150402a8c44..027f6709eb6 100644 --- a/go/vt/mysqlctl/grpcmysqlctlclient/client.go +++ b/go/vt/mysqlctl/grpcmysqlctlclient/client.go @@ -40,9 +40,10 @@ type client struct { c mysqlctlpb.MysqlCtlClient } -func factory(network, addr string) (mysqlctlclient.MysqlctlClient, error) { +func factory(ctx context.Context, network, addr string) (mysqlctlclient.MysqlctlClient, error) { // create the RPC client - cc, err := grpcclient.Dial( + cc, err := grpcclient.DialContext( + ctx, addr, grpcclient.FailFast(false), grpc.WithTransportCredentials(insecure.NewCredentials()), diff --git a/go/vt/mysqlctl/mysqlctlclient/interface.go b/go/vt/mysqlctl/mysqlctlclient/interface.go index 4ab03a9df5b..e6f15b230db 100644 --- a/go/vt/mysqlctl/mysqlctlclient/interface.go +++ b/go/vt/mysqlctl/mysqlctlclient/interface.go @@ -70,7 +70,7 @@ type MysqlctlClient interface { } // Factory functions are registered by client implementations. -type Factory func(network, addr string) (MysqlctlClient, error) +type Factory func(ctx context.Context, network, addr string) (MysqlctlClient, error) var factories = make(map[string]Factory) @@ -83,10 +83,10 @@ func RegisterFactory(name string, factory Factory) { } // New creates a client implementation as specified by a flag. -func New(network, addr string) (MysqlctlClient, error) { +func New(ctx context.Context, network, addr string) (MysqlctlClient, error) { factory, ok := factories[protocol] if !ok { return nil, fmt.Errorf("unknown mysqlctl client protocol: %v", protocol) } - return factory(network, addr) + return factory(ctx, network, addr) } diff --git a/go/vt/mysqlctl/mysqld.go b/go/vt/mysqlctl/mysqld.go index 433ccc7de64..4a18f0a9f3a 100644 --- a/go/vt/mysqlctl/mysqld.go +++ b/go/vt/mysqlctl/mysqld.go @@ -263,7 +263,7 @@ func (mysqld *Mysqld) RunMysqlUpgrade(ctx context.Context) error { // Execute as remote action on mysqlctld if requested. if socketFile != "" { log.Infof("executing Mysqld.RunMysqlUpgrade() remotely via mysqlctld server: %v", socketFile) - client, err := mysqlctlclient.New("unix", socketFile) + client, err := mysqlctlclient.New(ctx, "unix", socketFile) if err != nil { return fmt.Errorf("can't dial mysqlctld: %v", err) } @@ -332,7 +332,7 @@ func (mysqld *Mysqld) Start(ctx context.Context, cnf *Mycnf, mysqldArgs ...strin // Execute as remote action on mysqlctld if requested. if socketFile != "" { log.Infof("executing Mysqld.Start() remotely via mysqlctld server: %v", socketFile) - client, err := mysqlctlclient.New("unix", socketFile) + client, err := mysqlctlclient.New(ctx, "unix", socketFile) if err != nil { return fmt.Errorf("can't dial mysqlctld: %v", err) } @@ -594,7 +594,7 @@ func (mysqld *Mysqld) Shutdown(ctx context.Context, cnf *Mycnf, waitForMysqld bo // Execute as remote action on mysqlctld if requested. if socketFile != "" { log.Infof("executing Mysqld.Shutdown() remotely via mysqlctld server: %v", socketFile) - client, err := mysqlctlclient.New("unix", socketFile) + client, err := mysqlctlclient.New(ctx, "unix", socketFile) if err != nil { return fmt.Errorf("can't dial mysqlctld: %v", err) } @@ -965,7 +965,7 @@ func (mysqld *Mysqld) RefreshConfig(ctx context.Context, cnf *Mycnf) error { // Execute as remote action on mysqlctld if requested. if socketFile != "" { log.Infof("executing Mysqld.RefreshConfig() remotely via mysqlctld server: %v", socketFile) - client, err := mysqlctlclient.New("unix", socketFile) + client, err := mysqlctlclient.New(ctx, "unix", socketFile) if err != nil { return fmt.Errorf("can't dial mysqlctld: %v", err) } @@ -1023,7 +1023,7 @@ func (mysqld *Mysqld) ReinitConfig(ctx context.Context, cnf *Mycnf) error { // Execute as remote action on mysqlctld if requested. if socketFile != "" { log.Infof("executing Mysqld.ReinitConfig() remotely via mysqlctld server: %v", socketFile) - client, err := mysqlctlclient.New("unix", socketFile) + client, err := mysqlctlclient.New(ctx, "unix", socketFile) if err != nil { return fmt.Errorf("can't dial mysqlctld: %v", err) } @@ -1260,7 +1260,7 @@ func (mysqld *Mysqld) GetVersionString(ctx context.Context) (string, error) { // Execute as remote action on mysqlctld to use the actual running MySQL // version. if socketFile != "" { - client, err := mysqlctlclient.New("unix", socketFile) + client, err := mysqlctlclient.New(ctx, "unix", socketFile) if err != nil { return "", fmt.Errorf("can't dial mysqlctld: %v", err) } @@ -1289,7 +1289,7 @@ func (mysqld *Mysqld) GetVersionComment(ctx context.Context) (string, error) { func (mysqld *Mysqld) ApplyBinlogFile(ctx context.Context, req *mysqlctlpb.ApplyBinlogFileRequest) error { if socketFile != "" { log.Infof("executing Mysqld.ApplyBinlogFile() remotely via mysqlctld server: %v", socketFile) - client, err := mysqlctlclient.New("unix", socketFile) + client, err := mysqlctlclient.New(ctx, "unix", socketFile) if err != nil { return fmt.Errorf("can't dial mysqlctld: %v", err) } @@ -1520,7 +1520,7 @@ func (mysqld *Mysqld) ReadBinlogFilesTimestamps(ctx context.Context, req *mysqlc } if socketFile != "" { log.Infof("executing Mysqld.ReadBinlogFilesTimestamps() remotely via mysqlctld server: %v", socketFile) - client, err := mysqlctlclient.New("unix", socketFile) + client, err := mysqlctlclient.New(ctx, "unix", socketFile) if err != nil { return nil, fmt.Errorf("can't dial mysqlctld: %v", err) } diff --git a/go/vt/srvtopo/resolver.go b/go/vt/srvtopo/resolver.go index 042e291c0a6..0ccfb0fd872 100644 --- a/go/vt/srvtopo/resolver.go +++ b/go/vt/srvtopo/resolver.go @@ -41,7 +41,7 @@ type Gateway interface { queryservice.QueryService // QueryServiceByAlias returns a QueryService - QueryServiceByAlias(alias *topodatapb.TabletAlias, target *querypb.Target) (queryservice.QueryService, error) + QueryServiceByAlias(ctx context.Context, alias *topodatapb.TabletAlias, target *querypb.Target) (queryservice.QueryService, error) // GetServingKeyspaces returns list of serving keyspaces. GetServingKeyspaces() []string diff --git a/go/vt/vtadmin/api_test.go b/go/vt/vtadmin/api_test.go index bb9cd62d788..82c744b95db 100644 --- a/go/vt/vtadmin/api_test.go +++ b/go/vt/vtadmin/api_test.go @@ -1065,7 +1065,7 @@ func TestGetKeyspace(t *testing.T) { }) } - testutil.WithTestServers(t, func(t *testing.T, clients ...vtctldclient.VtctldClient) { + testutil.WithTestServers(ctx, t, func(t *testing.T, clients ...vtctldclient.VtctldClient) { clusters := make([]*cluster.Cluster, len(clients)) for i, client := range clients { clusters[i] = vtadmintestutil.BuildCluster(t, vtadmintestutil.TestClusterConfig{ @@ -1310,7 +1310,7 @@ func TestGetKeyspaces(t *testing.T) { }), } - testutil.WithTestServers(t, func(t *testing.T, clients ...vtctldclient.VtctldClient) { + testutil.WithTestServers(ctx, t, func(t *testing.T, clients ...vtctldclient.VtctldClient) { clusters := []*cluster.Cluster{ vtadmintestutil.BuildCluster(t, vtadmintestutil.TestClusterConfig{ Cluster: &vtadminpb.Cluster{ @@ -1541,7 +1541,7 @@ func TestGetSchema(t *testing.T) { testutil.AddTablets(ctx, t, tt.ts, nil, vtadmintestutil.TopodataTabletsFromVTAdminTablets(tt.tablets)...) - testutil.WithTestServer(t, vtctld, func(t *testing.T, client vtctldclient.VtctldClient) { + testutil.WithTestServer(ctx, t, vtctld, func(t *testing.T, client vtctldclient.VtctldClient) { c := vtadmintestutil.BuildCluster(t, vtadmintestutil.TestClusterConfig{ Cluster: &vtadminpb.Cluster{ Id: fmt.Sprintf("c%d", tt.clusterID), @@ -2195,7 +2195,7 @@ func TestGetSchemas(t *testing.T) { }), } - testutil.WithTestServers(t, func(t *testing.T, clients ...vtctldclient.VtctldClient) { + testutil.WithTestServers(ctx, t, func(t *testing.T, clients ...vtctldclient.VtctldClient) { clusters := make([]*cluster.Cluster, len(topos)) for cdx, toposerver := range topos { // Handle when a test doesn't define any tablets for a given cluster. @@ -2628,7 +2628,7 @@ func TestGetSrvKeyspace(t *testing.T) { return grpcvtctldserver.NewVtctldServer(vtenv.NewTestEnv(), ts) }) - testutil.WithTestServer(t, vtctldserver, func(t *testing.T, vtctldClient vtctldclient.VtctldClient) { + testutil.WithTestServer(ctx, t, vtctldserver, func(t *testing.T, vtctldClient vtctldclient.VtctldClient) { for cell, sks := range tt.cellSrvKeyspaces { err := toposerver.UpdateSrvKeyspace(ctx, cell, tt.keyspace, sks) require.NoError(t, err) @@ -2790,7 +2790,7 @@ func TestGetSrvKeyspaces(t *testing.T) { return grpcvtctldserver.NewVtctldServer(vtenv.NewTestEnv(), ts) }) - testutil.WithTestServer(t, vtctldserver, func(t *testing.T, vtctldClient vtctldclient.VtctldClient) { + testutil.WithTestServer(ctx, t, vtctldserver, func(t *testing.T, vtctldClient vtctldclient.VtctldClient) { for keyspace, sks := range tt.cellSrvKeyspaces { for cell, sk := range sks { err := toposerver.UpdateSrvKeyspace(ctx, cell, keyspace, sk) @@ -2953,7 +2953,7 @@ func TestGetSrvVSchema(t *testing.T) { return grpcvtctldserver.NewVtctldServer(vtenv.NewTestEnv(), ts) }) - testutil.WithTestServer(t, vtctldserver, func(t *testing.T, vtctldClient vtctldclient.VtctldClient) { + testutil.WithTestServer(ctx, t, vtctldserver, func(t *testing.T, vtctldClient vtctldclient.VtctldClient) { for cell, svs := range tt.cellSrvVSchemas { err := toposerver.UpdateSrvVSchema(ctx, cell, svs) require.NoError(t, err) @@ -3245,7 +3245,7 @@ func TestGetSrvVSchemas(t *testing.T) { return grpcvtctldserver.NewVtctldServer(vtenv.NewTestEnv(), ts) }) - testutil.WithTestServer(t, vtctldserver, func(t *testing.T, vtctldClient vtctldclient.VtctldClient) { + testutil.WithTestServer(ctx, t, vtctldserver, func(t *testing.T, vtctldClient vtctldclient.VtctldClient) { for cell, svs := range tt.cellSrvVSchemas { err := toposerver.UpdateSrvVSchema(ctx, cell, svs) require.NoError(t, err) @@ -5083,7 +5083,7 @@ func TestVTExplain(t *testing.T) { return grpcvtctldserver.NewVtctldServer(vtenv.NewTestEnv(), ts) }) - testutil.WithTestServer(t, vtctldserver, func(t *testing.T, vtctldClient vtctldclient.VtctldClient) { + testutil.WithTestServer(ctx, t, vtctldserver, func(t *testing.T, vtctldClient vtctldclient.VtctldClient) { if tt.srvVSchema != nil { err := toposerver.UpdateSrvVSchema(ctx, "c0_cell1", tt.srvVSchema) require.NoError(t, err) diff --git a/go/vt/vtadmin/testutil/cluster.go b/go/vt/vtadmin/testutil/cluster.go index 793ef3b2142..bd238b388a8 100644 --- a/go/vt/vtadmin/testutil/cluster.go +++ b/go/vt/vtadmin/testutil/cluster.go @@ -120,7 +120,7 @@ func BuildCluster(t testing.TB, cfg TestClusterConfig) *cluster.Cluster { clusterConf.Name = cfg.Cluster.Name clusterConf.DiscoveryImpl = discoveryTestImplName - clusterConf = clusterConf.WithVtctldTestConfigOptions(vtadminvtctldclient.WithDialFunc(func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) { + clusterConf = clusterConf.WithVtctldTestConfigOptions(vtadminvtctldclient.WithDialFunc(func(ctx context.Context, addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) { return cfg.VtctldClient, nil })).WithVtSQLTestConfigOptions(vtsql.WithDialFunc(func(c vitessdriver.Configuration) (*sql.DB, error) { return sql.OpenDB(&fakevtsql.Connector{Tablets: tablets, ShouldErr: cfg.DBConfig.ShouldErr}), nil diff --git a/go/vt/vtadmin/vtctldclient/config.go b/go/vt/vtadmin/vtctldclient/config.go index 53b6fd83a5c..4a11001c3e6 100644 --- a/go/vt/vtadmin/vtctldclient/config.go +++ b/go/vt/vtadmin/vtctldclient/config.go @@ -17,6 +17,7 @@ limitations under the License. package vtctldclient import ( + "context" "fmt" "github.com/spf13/pflag" @@ -40,7 +41,7 @@ type Config struct { ResolverOptions *resolver.Options - dialFunc func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) + dialFunc func(ctx context.Context, addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) } // ConfigOption is a function that mutates a Config. It should return the same @@ -52,7 +53,7 @@ type ConfigOption func(cfg *Config) *Config // // It is used to support dependency injection in tests, and needs to be exported // for higher-level tests (via vtadmin/testutil). -func WithDialFunc(f func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error)) ConfigOption { +func WithDialFunc(f func(ctx context.Context, addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error)) ConfigOption { return func(cfg *Config) *Config { cfg.dialFunc = f return cfg diff --git a/go/vt/vtadmin/vtctldclient/proxy.go b/go/vt/vtadmin/vtctldclient/proxy.go index abb16ac556d..c0a9773ea2c 100644 --- a/go/vt/vtadmin/vtctldclient/proxy.go +++ b/go/vt/vtadmin/vtctldclient/proxy.go @@ -63,7 +63,7 @@ type ClientProxy struct { // DialFunc is called to open a new vtctdclient connection. In production, // this should always be grpcvtctldclient.NewWithDialOpts, but it is // exported for testing purposes. - dialFunc func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) + dialFunc func(ctx context.Context, addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) resolver grpcresolver.Builder m sync.Mutex @@ -124,8 +124,7 @@ func (vtctld *ClientProxy) dial(ctx context.Context) error { opts = append(opts, grpc.WithResolvers(vtctld.resolver)) - // TODO: update dialFunc to take ctx as first arg. - client, err := vtctld.dialFunc(resolver.DialAddr(vtctld.resolver, "vtctld"), grpcclient.FailFast(false), opts...) + client, err := vtctld.dialFunc(ctx, resolver.DialAddr(vtctld.resolver, "vtctld"), grpcclient.FailFast(false), opts...) if err != nil { return err } diff --git a/go/vt/vtcombo/tablet_map.go b/go/vt/vtcombo/tablet_map.go index 78b39cc2f14..5043ec0b48e 100644 --- a/go/vt/vtcombo/tablet_map.go +++ b/go/vt/vtcombo/tablet_map.go @@ -405,7 +405,7 @@ func CreateKs( // // dialer is our tabletconn.Dialer -func dialer(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { +func dialer(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { t, ok := tabletMap[tablet.Alias.Uid] if !ok { return nil, vterrors.New(vtrpcpb.Code_UNAVAILABLE, "connection refused") diff --git a/go/vt/vtctl/grpcvtctlclient/client.go b/go/vt/vtctl/grpcvtctlclient/client.go index f0fe94ca330..c09ead0687c 100644 --- a/go/vt/vtctl/grpcvtctlclient/client.go +++ b/go/vt/vtctl/grpcvtctlclient/client.go @@ -18,9 +18,8 @@ limitations under the License. package grpcvtctlclient import ( - "time" - "context" + "time" "google.golang.org/grpc" @@ -39,13 +38,13 @@ type gRPCVtctlClient struct { c vtctlservicepb.VtctlClient } -func gRPCVtctlClientFactory(addr string) (vtctlclient.VtctlClient, error) { +func gRPCVtctlClientFactory(ctx context.Context, addr string) (vtctlclient.VtctlClient, error) { opt, err := grpcclientcommon.SecureDialOption() if err != nil { return nil, err } // create the RPC client - cc, err := grpcclient.Dial(addr, grpcclient.FailFast(false), opt) + cc, err := grpcclient.DialContext(ctx, addr, grpcclient.FailFast(false), opt) if err != nil { return nil, err } diff --git a/go/vt/vtctl/grpcvtctlclient/client_test.go b/go/vt/vtctl/grpcvtctlclient/client_test.go index d065a706c65..00ec4888e76 100644 --- a/go/vt/vtctl/grpcvtctlclient/client_test.go +++ b/go/vt/vtctl/grpcvtctlclient/client_test.go @@ -57,7 +57,7 @@ func TestVtctlServer(t *testing.T) { go server.Serve(listener) // Create a VtctlClient gRPC client to talk to the fake server - client, err := gRPCVtctlClientFactory(fmt.Sprintf("localhost:%v", port)) + client, err := gRPCVtctlClientFactory(ctx, fmt.Sprintf("localhost:%v", port)) if err != nil { t.Fatalf("Cannot create client: %v", err) } @@ -117,7 +117,7 @@ func TestVtctlAuthClient(t *testing.T) { require.NoError(t, err, "failed to set `--grpc_auth_static_client_creds=%s`", f.Name()) // Create a VtctlClient gRPC client to talk to the fake server - client, err := gRPCVtctlClientFactory(fmt.Sprintf("localhost:%v", port)) + client, err := gRPCVtctlClientFactory(ctx, fmt.Sprintf("localhost:%v", port)) if err != nil { t.Fatalf("Cannot create client: %v", err) } diff --git a/go/vt/vtctl/grpcvtctldclient/client.go b/go/vt/vtctl/grpcvtctldclient/client.go index 497867aebb0..9015fee8009 100644 --- a/go/vt/vtctl/grpcvtctldclient/client.go +++ b/go/vt/vtctl/grpcvtctldclient/client.go @@ -48,13 +48,13 @@ type gRPCVtctldClient struct { //go:generate -command grpcvtctldclient go run ../vtctldclient/codegen //go:generate grpcvtctldclient --out client_gen.go -func gRPCVtctldClientFactory(addr string) (vtctldclient.VtctldClient, error) { +func gRPCVtctldClientFactory(ctx context.Context, addr string) (vtctldclient.VtctldClient, error) { opt, err := grpcclientcommon.SecureDialOption() if err != nil { return nil, err } - conn, err := grpcclient.Dial(addr, grpcclient.FailFast(false), opt) + conn, err := grpcclient.DialContext(ctx, addr, grpcclient.FailFast(false), opt) if err != nil { return nil, err } @@ -67,8 +67,8 @@ func gRPCVtctldClientFactory(addr string) (vtctldclient.VtctldClient, error) { // NewWithDialOpts returns a vtctldclient.VtctldClient configured with the given // DialOptions. It is exported for use in vtadmin. -func NewWithDialOpts(addr string, failFast grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) { - conn, err := grpcclient.Dial(addr, failFast, opts...) +func NewWithDialOpts(ctx context.Context, addr string, failFast grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) { + conn, err := grpcclient.DialContext(ctx, addr, failFast, opts...) if err != nil { return nil, err } diff --git a/go/vt/vtctl/grpcvtctldclient/client_test.go b/go/vt/vtctl/grpcvtctldclient/client_test.go index 1de9b6c895c..cb0e1477cf7 100644 --- a/go/vt/vtctl/grpcvtctldclient/client_test.go +++ b/go/vt/vtctl/grpcvtctldclient/client_test.go @@ -45,7 +45,7 @@ func TestFindAllShardsInKeyspace(t *testing.T) { return grpcvtctldserver.NewVtctldServer(vtenv.NewTestEnv(), ts) }) - testutil.WithTestServer(t, vtctld, func(t *testing.T, client vtctldclient.VtctldClient) { + testutil.WithTestServer(ctx, t, vtctld, func(t *testing.T, client vtctldclient.VtctldClient) { ks := &vtctldatapb.Keyspace{ Name: "testkeyspace", Keyspace: &topodatapb.Keyspace{}, @@ -92,7 +92,7 @@ func TestGetKeyspace(t *testing.T) { return grpcvtctldserver.NewVtctldServer(vtenv.NewTestEnv(), ts) }) - testutil.WithTestServer(t, vtctld, func(t *testing.T, client vtctldclient.VtctldClient) { + testutil.WithTestServer(ctx, t, vtctld, func(t *testing.T, client vtctldclient.VtctldClient) { expected := &vtctldatapb.GetKeyspaceResponse{ Keyspace: &vtctldatapb.Keyspace{ Name: "testkeyspace", @@ -121,7 +121,7 @@ func TestGetKeyspaces(t *testing.T) { return grpcvtctldserver.NewVtctldServer(vtenv.NewTestEnv(), ts) }) - testutil.WithTestServer(t, vtctld, func(t *testing.T, client vtctldclient.VtctldClient) { + testutil.WithTestServer(ctx, t, vtctld, func(t *testing.T, client vtctldclient.VtctldClient) { resp, err := client.GetKeyspaces(ctx, &vtctldatapb.GetKeyspacesRequest{}) assert.NoError(t, err) assert.Empty(t, resp.Keyspaces) diff --git a/go/vt/vtctl/grpcvtctldserver/testutil/util.go b/go/vt/vtctl/grpcvtctldserver/testutil/util.go index 97638e9c41e..b685d22840b 100644 --- a/go/vt/vtctl/grpcvtctldserver/testutil/util.go +++ b/go/vt/vtctl/grpcvtctldserver/testutil/util.go @@ -41,7 +41,7 @@ import ( // implementation, then runs the test func with a client created to point at // that server. func WithTestServer( - t *testing.T, + ctx context.Context, t *testing.T, server vtctlservicepb.VtctldServer, test func(t *testing.T, client vtctldclient.VtctldClient), ) { @@ -56,7 +56,7 @@ func WithTestServer( go s.Serve(lis) defer s.Stop() - client, err := vtctldclient.New("grpc", lis.Addr().String()) + client, err := vtctldclient.New(ctx, "grpc", lis.Addr().String()) require.NoError(t, err, "cannot create vtctld client") defer client.Close() @@ -67,7 +67,7 @@ func WithTestServer( // implementations, and then runs the test func with N clients created, where // clients[i] points at servers[i]. func WithTestServers( - t *testing.T, + ctx context.Context, t *testing.T, test func(t *testing.T, clients ...vtctldclient.VtctldClient), servers ...vtctlservicepb.VtctldServer, ) { @@ -91,7 +91,7 @@ func WithTestServers( // Start up a test server for the head of our server slice, accumulate // the resulting client, and recurse on the tail of our server slice. - WithTestServer(t, servers[0], func(t *testing.T, client vtctldclient.VtctldClient) { + WithTestServer(ctx, t, servers[0], func(t *testing.T, client vtctldclient.VtctldClient) { clients = append(clients, client) withTestServers(t, servers[1:]...) }) diff --git a/go/vt/vtctl/localvtctldclient/client.go b/go/vt/vtctl/localvtctldclient/client.go index f94f1124037..abd02b7e28a 100644 --- a/go/vt/vtctl/localvtctldclient/client.go +++ b/go/vt/vtctl/localvtctldclient/client.go @@ -17,6 +17,7 @@ limitations under the License. package localvtctldclient import ( + "context" "errors" "sync" @@ -58,7 +59,7 @@ func SetServer(s vtctlservicepb.VtctldServer) { server = s } -func localVtctldClientFactory(addr string) (vtctldclient.VtctldClient, error) { +func localVtctldClientFactory(ctx context.Context, addr string) (vtctldclient.VtctldClient, error) { m.Lock() defer m.Unlock() diff --git a/go/vt/vtctl/vdiff_env_test.go b/go/vt/vtctl/vdiff_env_test.go index d09448c2866..fdcf29367cc 100644 --- a/go/vt/vtctl/vdiff_env_test.go +++ b/go/vt/vtctl/vdiff_env_test.go @@ -83,7 +83,7 @@ func newTestVDiffEnv(t testing.TB, ctx context.Context, sourceShards, targetShar // Generate a unique dialer name. dialerName := fmt.Sprintf("VDiffTest-%s-%d", t.Name(), rand.IntN(1000000000)) - tabletconn.RegisterDialer(dialerName, func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer(dialerName, func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { env.mu.Lock() defer env.mu.Unlock() if qs, ok := env.tablets[int(tablet.Alias.Uid)]; ok { diff --git a/go/vt/vtctl/vtctl_env_test.go b/go/vt/vtctl/vtctl_env_test.go index e502fbdf86a..7537eae9e8b 100644 --- a/go/vt/vtctl/vtctl_env_test.go +++ b/go/vt/vtctl/vtctl_env_test.go @@ -55,7 +55,7 @@ type testVTCtlEnv struct { var vtctlEnv *testVTCtlEnv func init() { - tabletconn.RegisterDialer("VTCtlTest", func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer("VTCtlTest", func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { vtctlEnv.mu.Lock() defer vtctlEnv.mu.Unlock() if qs, ok := vtctlEnv.tablets[int(tablet.Alias.Uid)]; ok { diff --git a/go/vt/vtctl/vtctlclient/interface.go b/go/vt/vtctl/vtctlclient/interface.go index b750cdf8db6..8de7f48097b 100644 --- a/go/vt/vtctl/vtctlclient/interface.go +++ b/go/vt/vtctl/vtctlclient/interface.go @@ -56,7 +56,7 @@ type VtctlClient interface { } // Factory functions are registered by client implementations -type Factory func(addr string) (VtctlClient, error) +type Factory func(ctx context.Context, addr string) (VtctlClient, error) var factories = make(map[string]Factory) @@ -69,10 +69,10 @@ func RegisterFactory(name string, factory Factory) { } // New allows a user of the client library to get its implementation. -func New(addr string) (VtctlClient, error) { +func New(ctx context.Context, addr string) (VtctlClient, error) { factory, ok := factories[vtctlClientProtocol] if !ok { return nil, fmt.Errorf("unknown vtctl client protocol: %v", vtctlClientProtocol) } - return factory(addr) + return factory(ctx, addr) } diff --git a/go/vt/vtctl/vtctlclient/wrapper.go b/go/vt/vtctl/vtctlclient/wrapper.go index a30dde3e8dd..d33aad5b4e3 100644 --- a/go/vt/vtctl/vtctlclient/wrapper.go +++ b/go/vt/vtctl/vtctlclient/wrapper.go @@ -17,13 +17,12 @@ limitations under the License. package vtctlclient import ( + "context" "errors" "fmt" "io" "time" - "context" - logutilpb "vitess.io/vitess/go/vt/proto/logutil" ) @@ -39,7 +38,7 @@ func RunCommandAndWait(ctx context.Context, server string, args []string, recv f return errors.New("no function closure for Event stream specified") } // create the client - client, err := New(server) + client, err := New(ctx, server) if err != nil { return fmt.Errorf("cannot dial to server %v: %v", server, err) } diff --git a/go/vt/vtctl/vtctldclient/client.go b/go/vt/vtctl/vtctldclient/client.go index 6e0c97bb8a5..4b6def326db 100644 --- a/go/vt/vtctl/vtctldclient/client.go +++ b/go/vt/vtctl/vtctldclient/client.go @@ -3,6 +3,7 @@ package vtctldclient import ( + "context" "fmt" "log" @@ -17,7 +18,7 @@ type VtctldClient interface { } // Factory is a function that creates new VtctldClients. -type Factory func(addr string) (VtctldClient, error) +type Factory func(ctx context.Context, addr string) (VtctldClient, error) var registry = map[string]Factory{} @@ -40,11 +41,11 @@ func Register(name string, factory Factory) { // global namespace to determine the protocol to use. Instead, we require // users to specify their own flag in their own (hopefully not global) namespace // to determine the protocol to pass into here. -func New(protocol string, addr string) (VtctldClient, error) { +func New(ctx context.Context, protocol string, addr string) (VtctldClient, error) { factory, ok := registry[protocol] if !ok { return nil, fmt.Errorf("unknown vtctld client protocol: %s", protocol) } - return factory(addr) + return factory(ctx, addr) } diff --git a/go/vt/vtctld/tablet_data.go b/go/vt/vtctld/tablet_data.go index 66cfed6b4a9..f9482849bee 100644 --- a/go/vt/vtctld/tablet_data.go +++ b/go/vt/vtctld/tablet_data.go @@ -113,7 +113,7 @@ func (th *tabletHealth) stream(ctx context.Context, ts *topo.Server, tabletAlias return err } - conn, err := tabletconn.GetDialer()(ti.Tablet, grpcclient.FailFast(true)) + conn, err := tabletconn.GetDialer()(ctx, ti.Tablet, grpcclient.FailFast(true)) if err != nil { return err } diff --git a/go/vt/vtgate/sandbox_test.go b/go/vt/vtgate/sandbox_test.go index 27be6442cfe..dc3c1f103af 100644 --- a/go/vt/vtgate/sandbox_test.go +++ b/go/vt/vtgate/sandbox_test.go @@ -318,7 +318,7 @@ func (sct *sandboxTopo) WatchSrvVSchema(ctx context.Context, cell string, callba }() } -func sandboxDialer(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { +func sandboxDialer(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { sand := getSandbox(tablet.Keyspace) sand.sandmu.Lock() defer sand.sandmu.Unlock() diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index 2b37c865187..8b571f7b67d 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -188,7 +188,7 @@ func (stc *ScatterConn) ExecuteMultiShard( } } - qs, err = getQueryService(rs, info, session, false) + qs, err = getQueryService(ctx, rs, info, session, false) if err != nil { return nil, err } @@ -300,11 +300,11 @@ func checkAndResetShardSession(info *shardActionInfo, err error, session *SafeSe return retry } -func getQueryService(rs *srvtopo.ResolvedShard, info *shardActionInfo, session *SafeSession, skipReset bool) (queryservice.QueryService, error) { +func getQueryService(ctx context.Context, rs *srvtopo.ResolvedShard, info *shardActionInfo, session *SafeSession, skipReset bool) (queryservice.QueryService, error) { if info.alias == nil { return rs.Gateway, nil } - qs, err := rs.Gateway.QueryServiceByAlias(info.alias, rs.Target) + qs, err := rs.Gateway.QueryServiceByAlias(ctx, info.alias, rs.Target) if err == nil || skipReset { return qs, err } @@ -386,7 +386,7 @@ func (stc *ScatterConn) StreamExecuteMulti( } } - qs, err = getQueryService(rs, info, session, false) + qs, err = getQueryService(ctx, rs, info, session, false) if err != nil { return nil, err } @@ -732,7 +732,7 @@ func (stc *ScatterConn) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedSha _ = stc.txConn.ReleaseLock(ctx, session) return nil, vterrors.Wrap(err, "Any previous held locks are released") } - qs, err := getQueryService(rs, info, nil, true) + qs, err := getQueryService(ctx, rs, info, nil, true) if err != nil { return nil, err } diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 496224c207f..1139221b659 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -146,8 +146,8 @@ func (gw *TabletGateway) setupBuffering(ctx context.Context) { } // QueryServiceByAlias satisfies the Gateway interface -func (gw *TabletGateway) QueryServiceByAlias(alias *topodatapb.TabletAlias, target *querypb.Target) (queryservice.QueryService, error) { - qs, err := gw.hc.TabletConnection(alias, target) +func (gw *TabletGateway) QueryServiceByAlias(ctx context.Context, alias *topodatapb.TabletAlias, target *querypb.Target) (queryservice.QueryService, error) { + qs, err := gw.hc.TabletConnection(ctx, alias, target) return queryservice.Wrap(qs, gw.withShardError), NewShardError(err, target) } diff --git a/go/vt/vtgate/tx_conn.go b/go/vt/vtgate/tx_conn.go index f21686d01d8..2eccdc54992 100644 --- a/go/vt/vtgate/tx_conn.go +++ b/go/vt/vtgate/tx_conn.go @@ -104,11 +104,11 @@ func (txc *TxConn) Commit(ctx context.Context, session *SafeSession) error { return txc.commitNormal(ctx, session) } -func (txc *TxConn) queryService(alias *topodatapb.TabletAlias) (queryservice.QueryService, error) { +func (txc *TxConn) queryService(ctx context.Context, alias *topodatapb.TabletAlias) (queryservice.QueryService, error) { if alias == nil { return txc.tabletGateway, nil } - return txc.tabletGateway.QueryServiceByAlias(alias, nil) + return txc.tabletGateway.QueryServiceByAlias(ctx, alias, nil) } func (txc *TxConn) commitShard(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error { @@ -117,7 +117,7 @@ func (txc *TxConn) commitShard(ctx context.Context, s *vtgatepb.Session_ShardSes } var qs queryservice.QueryService var err error - qs, err = txc.queryService(s.TabletAlias) + qs, err = txc.queryService(ctx, s.TabletAlias) if err != nil { return err } @@ -243,7 +243,7 @@ func (txc *TxConn) Rollback(ctx context.Context, session *SafeSession) error { if s.TransactionId == 0 { return nil } - qs, err := txc.queryService(s.TabletAlias) + qs, err := txc.queryService(ctx, s.TabletAlias) if err != nil { return err } @@ -279,7 +279,7 @@ func (txc *TxConn) Release(ctx context.Context, session *SafeSession) error { if s.ReservedId == 0 && s.TransactionId == 0 { return nil } - qs, err := txc.queryService(s.TabletAlias) + qs, err := txc.queryService(ctx, s.TabletAlias) if err != nil { return err } @@ -305,7 +305,7 @@ func (txc *TxConn) ReleaseLock(ctx context.Context, session *SafeSession) error if ls.ReservedId == 0 { return nil } - qs, err := txc.queryService(ls.TabletAlias) + qs, err := txc.queryService(ctx, ls.TabletAlias) if err != nil { return err } @@ -329,7 +329,7 @@ func (txc *TxConn) ReleaseAll(ctx context.Context, session *SafeSession) error { if s.ReservedId == 0 && s.TransactionId == 0 { return nil } - qs, err := txc.queryService(s.TabletAlias) + qs, err := txc.queryService(ctx, s.TabletAlias) if err != nil { return err } @@ -362,7 +362,7 @@ func (txc *TxConn) Resolve(ctx context.Context, dtid string) error { case querypb.TransactionState_PREPARE: // If state is PREPARE, make a decision to rollback and // fallthrough to the rollback workflow. - qs, err := txc.queryService(mmShard.TabletAlias) + qs, err := txc.queryService(ctx, mmShard.TabletAlias) if err != nil { return err } diff --git a/go/vt/vtgate/vstream_manager.go b/go/vt/vtgate/vstream_manager.go index c5bc460de09..e0d195853cf 100644 --- a/go/vt/vtgate/vstream_manager.go +++ b/go/vt/vtgate/vstream_manager.go @@ -541,7 +541,7 @@ func (vs *vstream) streamFromTablet(ctx context.Context, sgtid *binlogdatapb.Sha TabletType: vs.tabletType, Cell: vs.vsm.cell, } - tabletConn, err := vs.vsm.resolver.GetGateway().QueryServiceByAlias(tablet.Alias, target) + tabletConn, err := vs.vsm.resolver.GetGateway().QueryServiceByAlias(ctx, tablet.Alias, target) if err != nil { log.Errorf(err.Error()) return err diff --git a/go/vt/vttablet/grpctabletconn/conn.go b/go/vt/vttablet/grpctabletconn/conn.go index 8bb8a466b21..fe446fbec27 100644 --- a/go/vt/vttablet/grpctabletconn/conn.go +++ b/go/vt/vttablet/grpctabletconn/conn.go @@ -83,7 +83,7 @@ type gRPCQueryClient struct { var _ queryservice.QueryService = (*gRPCQueryClient)(nil) // DialTablet creates and initializes gRPCQueryClient. -func DialTablet(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { +func DialTablet(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { // create the RPC client addr := "" if grpcPort, ok := tablet.PortMap["grpc"]; ok { @@ -95,7 +95,7 @@ func DialTablet(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (querys if err != nil { return nil, err } - cc, err := grpcclient.Dial(addr, failFast, opt) + cc, err := grpcclient.DialContext(ctx, addr, failFast, opt) if err != nil { return nil, err } diff --git a/go/vt/vttablet/grpctabletconn/conn_test.go b/go/vt/vttablet/grpctabletconn/conn_test.go index 70e30e337bc..74ed85a335f 100644 --- a/go/vt/vttablet/grpctabletconn/conn_test.go +++ b/go/vt/vttablet/grpctabletconn/conn_test.go @@ -56,9 +56,13 @@ func TestGRPCTabletConn(t *testing.T) { server := grpc.NewServer() grpcqueryservice.Register(server, service) go server.Serve(listener) + defer server.Stop() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // run the test suite - tabletconntest.TestSuite(t, protocolName, &topodatapb.Tablet{ + tabletconntest.TestSuite(ctx, t, protocolName, &topodatapb.Tablet{ Keyspace: tabletconntest.TestTarget.Keyspace, Shard: tabletconntest.TestTarget.Shard, Type: tabletconntest.TestTarget.TabletType, @@ -91,6 +95,7 @@ func TestGRPCTabletAuthConn(t *testing.T) { grpcqueryservice.Register(server, service) go server.Serve(listener) + defer server.Stop() authJSON := `{ "Username": "valid", @@ -109,8 +114,10 @@ func TestGRPCTabletAuthConn(t *testing.T) { t.Fatal(err) } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // run the test suite - tabletconntest.TestSuite(t, protocolName, &topodatapb.Tablet{ + tabletconntest.TestSuite(ctx, t, protocolName, &topodatapb.Tablet{ Keyspace: tabletconntest.TestTarget.Keyspace, Shard: tabletconntest.TestTarget.Shard, Type: tabletconntest.TestTarget.TabletType, diff --git a/go/vt/vttablet/grpctmclient/client.go b/go/vt/vttablet/grpctmclient/client.go index d6f19be7d6d..1d054eb7b9c 100644 --- a/go/vt/vttablet/grpctmclient/client.go +++ b/go/vt/vttablet/grpctmclient/client.go @@ -157,7 +157,7 @@ func (client *grpcClient) dial(ctx context.Context, tablet *topodatapb.Tablet) ( if err != nil { return nil, nil, err } - cc, err := grpcclient.Dial(addr, grpcclient.FailFast(false), opt) + cc, err := grpcclient.DialContext(ctx, addr, grpcclient.FailFast(false), opt) if err != nil { return nil, nil, err } @@ -165,8 +165,8 @@ func (client *grpcClient) dial(ctx context.Context, tablet *topodatapb.Tablet) ( return tabletmanagerservicepb.NewTabletManagerClient(cc), cc, nil } -func (client *grpcClient) createTmc(addr string, opt grpc.DialOption) (*tmc, error) { - cc, err := grpcclient.Dial(addr, grpcclient.FailFast(false), opt) +func (client *grpcClient) createTmc(ctx context.Context, addr string, opt grpc.DialOption) (*tmc, error) { + cc, err := grpcclient.DialContext(ctx, addr, grpcclient.FailFast(false), opt) if err != nil { return nil, err } @@ -194,7 +194,7 @@ func (client *grpcClient) dialPool(ctx context.Context, tablet *topodatapb.Table client.mu.Unlock() for i := 0; i < cap(c); i++ { - tm, err := client.createTmc(addr, opt) + tm, err := client.createTmc(ctx, addr, opt) if err != nil { return nil, err } @@ -226,7 +226,7 @@ func (client *grpcClient) dialDedicatedPool(ctx context.Context, dialPoolGroup D } m := client.rpcDialPoolMap[dialPoolGroup] if _, ok := m[addr]; !ok { - tm, err := client.createTmc(addr, opt) + tm, err := client.createTmc(ctx, addr, opt) if err != nil { return nil, nil, err } diff --git a/go/vt/vttablet/sandboxconn/sandboxconn.go b/go/vt/vttablet/sandboxconn/sandboxconn.go index 8def73bf99e..618a87b1d81 100644 --- a/go/vt/vttablet/sandboxconn/sandboxconn.go +++ b/go/vt/vttablet/sandboxconn/sandboxconn.go @@ -586,7 +586,7 @@ func (sbc *SandboxConn) VStreamResults(ctx context.Context, target *querypb.Targ } // QueryServiceByAlias is part of the Gateway interface. -func (sbc *SandboxConn) QueryServiceByAlias(_ *topodatapb.TabletAlias, _ *querypb.Target) (queryservice.QueryService, error) { +func (sbc *SandboxConn) QueryServiceByAlias(_ context.Context, _ *topodatapb.TabletAlias, _ *querypb.Target) (queryservice.QueryService, error) { return sbc, nil } diff --git a/go/vt/vttablet/tabletconn/tablet_conn.go b/go/vt/vttablet/tabletconn/tablet_conn.go index 0c91fdd55bc..1ed806bcc53 100644 --- a/go/vt/vttablet/tabletconn/tablet_conn.go +++ b/go/vt/vttablet/tabletconn/tablet_conn.go @@ -17,6 +17,7 @@ limitations under the License. package tabletconn import ( + "context" "sync" "github.com/spf13/pflag" @@ -65,7 +66,7 @@ func init() { // timeout represents the connection timeout. If set to 0, this // connection should be established in the background and the // TabletDialer should return right away. -type TabletDialer func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) +type TabletDialer func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) var dialers = make(map[string]TabletDialer) diff --git a/go/vt/vttablet/tabletconntest/fakequeryservice.go b/go/vt/vttablet/tabletconntest/fakequeryservice.go index d3adff022e4..2efd7d330ed 100644 --- a/go/vt/vttablet/tabletconntest/fakequeryservice.go +++ b/go/vt/vttablet/tabletconntest/fakequeryservice.go @@ -709,7 +709,7 @@ func (f *FakeQueryService) VStreamResults(ctx context.Context, target *querypb.T } // QueryServiceByAlias satisfies the Gateway interface -func (f *FakeQueryService) QueryServiceByAlias(_ *topodatapb.TabletAlias, _ *querypb.Target) (queryservice.QueryService, error) { +func (f *FakeQueryService) QueryServiceByAlias(_ context.Context, _ *topodatapb.TabletAlias, _ *querypb.Target) (queryservice.QueryService, error) { panic("not implemented") } diff --git a/go/vt/vttablet/tabletconntest/tabletconntest.go b/go/vt/vttablet/tabletconntest/tabletconntest.go index b279ac53726..f8dafb0636e 100644 --- a/go/vt/vttablet/tabletconntest/tabletconntest.go +++ b/go/vt/vttablet/tabletconntest/tabletconntest.go @@ -922,7 +922,7 @@ func testStreamHealthPanics(t *testing.T, conn queryservice.QueryService, f *Fak // TestSuite runs all the tests. // If fake.TestingGateway is set, we only test the calls that can go through // a gateway. -func TestSuite(t *testing.T, protocol string, tablet *topodatapb.Tablet, fake *FakeQueryService, clientCreds *os.File) { +func TestSuite(ctx context.Context, t *testing.T, protocol string, tablet *topodatapb.Tablet, fake *FakeQueryService, clientCreds *os.File) { tests := []func(*testing.T, queryservice.QueryService, *FakeQueryService){ // positive test cases testBegin, @@ -1015,7 +1015,7 @@ func TestSuite(t *testing.T, protocol string, tablet *topodatapb.Tablet, fake *F require.NoError(t, err, "failed to set `--grpc_auth_static_client_creds=%s`", clientCreds.Name()) } - conn, err := tabletconn.GetDialer()(tablet, grpcclient.FailFast(false)) + conn, err := tabletconn.GetDialer()(ctx, tablet, grpcclient.FailFast(false)) if err != nil { t.Fatalf("dial failed: %v", err) } diff --git a/go/vt/vttablet/tabletmanager/framework_test.go b/go/vt/vttablet/tabletmanager/framework_test.go index 2d606d81597..27a3a562cd3 100644 --- a/go/vt/vttablet/tabletmanager/framework_test.go +++ b/go/vt/vttablet/tabletmanager/framework_test.go @@ -56,7 +56,7 @@ const ( ) func init() { - tabletconn.RegisterDialer("grpc", func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer("grpc", func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { return &tabletconntest.FakeQueryService{ StreamHealthResponse: &querypb.StreamHealthResponse{ Serving: true, @@ -98,7 +98,7 @@ func newTestEnv(t *testing.T, ctx context.Context, sourceKeyspace string, source tenv.tmc.sourceShards = sourceShards tenv.tmc.schema = defaultSchema - tabletconn.RegisterDialer(t.Name(), func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer(t.Name(), func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { tenv.mu.Lock() defer tenv.mu.Unlock() if qs, ok := tenv.tmc.tablets[int(tablet.Alias.Uid)]; ok { diff --git a/go/vt/vttablet/tabletmanager/vdiff/framework_test.go b/go/vt/vttablet/tabletmanager/vdiff/framework_test.go index 0676c5204be..43aa76894d4 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/framework_test.go +++ b/go/vt/vttablet/tabletmanager/vdiff/framework_test.go @@ -155,7 +155,7 @@ type LogExpectation struct { } func init() { - tabletconn.RegisterDialer("test", func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer("test", func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { vdiffenv.mu.Lock() defer vdiffenv.mu.Unlock() if qs, ok := vdiffenv.tablets[int(tablet.Alias.Uid)]; ok { @@ -164,7 +164,7 @@ func init() { return nil, fmt.Errorf("tablet %d not found", tablet.Alias.Uid) }) // TableDiffer does a default grpc dial just to be sure it can talk to the tablet. - tabletconn.RegisterDialer("grpc", func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer("grpc", func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { vdiffenv.mu.Lock() defer vdiffenv.mu.Unlock() if qs, ok := vdiffenv.tablets[int(tablet.Alias.Uid)]; ok { @@ -295,7 +295,7 @@ type fakeBinlogClient struct { lastCharset *binlogdatapb.Charset } -func (fbc *fakeBinlogClient) Dial(tablet *topodatapb.Tablet) error { +func (fbc *fakeBinlogClient) Dial(ctx context.Context, tablet *topodatapb.Tablet) error { fbc.lastTablet = tablet return nil } diff --git a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go index 1b64662e551..142e79c40d0 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go +++ b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go @@ -406,7 +406,7 @@ func (td *tableDiffer) streamOneShard(ctx context.Context, participant *shardStr td.wgShardStreamers.Done() }() participant.err = func() error { - conn, err := tabletconn.GetDialer()(participant.tablet, false) + conn, err := tabletconn.GetDialer()(ctx, participant.tablet, false) if err != nil { return err } diff --git a/go/vt/vttablet/tabletmanager/vreplication/external_connector.go b/go/vt/vttablet/tabletmanager/vreplication/external_connector.go index 873bf498c14..c53bfd2a584 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/external_connector.go +++ b/go/vt/vttablet/tabletmanager/vreplication/external_connector.go @@ -172,7 +172,7 @@ func newTabletConnector(tablet *topodatapb.Tablet) *tabletConnector { func (tc *tabletConnector) Open(ctx context.Context) error { var err error - tc.qs, err = tabletconn.GetDialer()(tc.tablet, grpcclient.FailFast(true)) + tc.qs, err = tabletconn.GetDialer()(ctx, tc.tablet, grpcclient.FailFast(true)) return err } diff --git a/go/vt/vttablet/tabletmanager/vreplication/framework_test.go b/go/vt/vttablet/tabletmanager/vreplication/framework_test.go index 262ba28187d..ec7d2d4529d 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/framework_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/framework_test.go @@ -109,7 +109,7 @@ func setFlag(flagName, flagValue string) { } func init() { - tabletconn.RegisterDialer("test", func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer("test", func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { return &fakeTabletConn{ QueryService: fakes.ErrorQueryService, tablet: tablet, @@ -364,7 +364,7 @@ type fakeBinlogClient struct { lastCharset *binlogdatapb.Charset } -func (fbc *fakeBinlogClient) Dial(tablet *topodatapb.Tablet) error { +func (fbc *fakeBinlogClient) Dial(ctx context.Context, tablet *topodatapb.Tablet) error { fbc.lastTablet = tablet return nil } diff --git a/go/vt/vttablet/tabletserver/txthrottler/mock_healthcheck_test.go b/go/vt/vttablet/tabletserver/txthrottler/mock_healthcheck_test.go index 3b298cacddf..ecc6688fb9d 100644 --- a/go/vt/vttablet/tabletserver/txthrottler/mock_healthcheck_test.go +++ b/go/vt/vttablet/tabletserver/txthrottler/mock_healthcheck_test.go @@ -210,9 +210,9 @@ func (mr *MockHealthCheckMockRecorder) Subscribe() *gomock.Call { } // TabletConnection mocks base method. -func (m *MockHealthCheck) TabletConnection(arg0 *topodata.TabletAlias, arg1 *query.Target) (queryservice.QueryService, error) { +func (m *MockHealthCheck) TabletConnection(arg0 context.Context, arg1 *topodata.TabletAlias, arg2 *query.Target) (queryservice.QueryService, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TabletConnection", arg0, arg1) + ret := m.ctrl.Call(m, "TabletConnection", arg0, arg1, arg2) ret0, _ := ret[0].(queryservice.QueryService) ret1, _ := ret[1].(error) return ret0, ret1 diff --git a/go/vt/wrangler/split.go b/go/vt/wrangler/split.go index 543d50a808d..197bfe4cc66 100644 --- a/go/vt/wrangler/split.go +++ b/go/vt/wrangler/split.go @@ -101,7 +101,7 @@ func (wr *Wrangler) WaitForFilteredReplication(ctx context.Context, keyspace, sh return fmt.Errorf("failed to run explicit healthcheck on tablet: %v err: %v", tabletInfo, err) } - conn, err := tabletconn.GetDialer()(tabletInfo.Tablet, grpcclient.FailFast(false)) + conn, err := tabletconn.GetDialer()(ctx, tabletInfo.Tablet, grpcclient.FailFast(false)) if err != nil { return fmt.Errorf("cannot connect to tablet %v: %v", alias, err) } diff --git a/go/vt/wrangler/testlib/backup_test.go b/go/vt/wrangler/testlib/backup_test.go index 0de8bfd78f3..ce46734a5e9 100644 --- a/go/vt/wrangler/testlib/backup_test.go +++ b/go/vt/wrangler/testlib/backup_test.go @@ -93,7 +93,7 @@ func testBackupRestore(t *testing.T, cDetails *compressionDetails) error { defer db.Close() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Set up mock query results. @@ -345,7 +345,7 @@ func TestBackupRestoreLagged(t *testing.T) { defer db.Close() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Set up mock query results. @@ -564,7 +564,7 @@ func TestRestoreUnreachablePrimary(t *testing.T) { defer db.Close() ts := memorytopo.NewServer(ctx, "cell1") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Set up mock query results. @@ -739,7 +739,7 @@ func TestDisableActiveReparents(t *testing.T) { defer db.Close() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Set up mock query results. diff --git a/go/vt/wrangler/testlib/copy_schema_shard_test.go b/go/vt/wrangler/testlib/copy_schema_shard_test.go index 262a93f4e23..2e113eec3ae 100644 --- a/go/vt/wrangler/testlib/copy_schema_shard_test.go +++ b/go/vt/wrangler/testlib/copy_schema_shard_test.go @@ -57,7 +57,7 @@ func copySchema(t *testing.T, useShardAsSource bool) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() if err := ts.CreateKeyspace(context.Background(), "ks", &topodatapb.Keyspace{}); err != nil { diff --git a/go/vt/wrangler/testlib/emergency_reparent_shard_test.go b/go/vt/wrangler/testlib/emergency_reparent_shard_test.go index 6cafe83b684..a23562153e2 100644 --- a/go/vt/wrangler/testlib/emergency_reparent_shard_test.go +++ b/go/vt/wrangler/testlib/emergency_reparent_shard_test.go @@ -52,7 +52,7 @@ func TestEmergencyReparentShard(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a primary, a couple good replicas diff --git a/go/vt/wrangler/testlib/external_reparent_test.go b/go/vt/wrangler/testlib/external_reparent_test.go index fdc1ca664ee..59d7c05d0f3 100644 --- a/go/vt/wrangler/testlib/external_reparent_test.go +++ b/go/vt/wrangler/testlib/external_reparent_test.go @@ -51,7 +51,7 @@ func TestTabletExternallyReparentedBasic(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create an old primary, a new primary, two good replicas, one bad replica diff --git a/go/vt/wrangler/testlib/permissions_test.go b/go/vt/wrangler/testlib/permissions_test.go index ba110a30d87..35c92a0233d 100644 --- a/go/vt/wrangler/testlib/permissions_test.go +++ b/go/vt/wrangler/testlib/permissions_test.go @@ -49,7 +49,7 @@ func TestPermissions(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() primary := NewFakeTablet(t, wr, "cell1", 0, topodatapb.TabletType_PRIMARY, nil) diff --git a/go/vt/wrangler/testlib/planned_reparent_shard_test.go b/go/vt/wrangler/testlib/planned_reparent_shard_test.go index 7069df9d3e1..65babcaf48d 100644 --- a/go/vt/wrangler/testlib/planned_reparent_shard_test.go +++ b/go/vt/wrangler/testlib/planned_reparent_shard_test.go @@ -53,7 +53,7 @@ func TestPlannedReparentShardNoPrimaryProvided(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a primary, a couple good replicas @@ -169,7 +169,7 @@ func TestPlannedReparentShardNoError(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a primary, a couple good replicas @@ -305,7 +305,7 @@ func TestPlannedReparentInitialization(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a few replicas. @@ -391,7 +391,7 @@ func TestPlannedReparentShardWaitForPositionFail(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a primary, a couple good replicas @@ -499,7 +499,7 @@ func TestPlannedReparentShardWaitForPositionTimeout(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a primary, a couple good replicas @@ -605,7 +605,7 @@ func TestPlannedReparentShardRelayLogError(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a primary, a couple good replicas @@ -685,7 +685,7 @@ func TestPlannedReparentShardRelayLogErrorStartReplication(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a primary, a couple good replicas @@ -770,7 +770,7 @@ func TestPlannedReparentShardPromoteReplicaFail(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a primary, a couple good replicas @@ -910,7 +910,7 @@ func TestPlannedReparentShardSamePrimary(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a primary, a couple good replicas diff --git a/go/vt/wrangler/testlib/shard_test.go b/go/vt/wrangler/testlib/shard_test.go index 400071d9e3c..7528a220d1f 100644 --- a/go/vt/wrangler/testlib/shard_test.go +++ b/go/vt/wrangler/testlib/shard_test.go @@ -37,7 +37,7 @@ func TestDeleteShardCleanup(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // Create a primary, a couple good replicas diff --git a/go/vt/wrangler/testlib/version_test.go b/go/vt/wrangler/testlib/version_test.go index c0ea92a5b46..ea13a43bc8f 100644 --- a/go/vt/wrangler/testlib/version_test.go +++ b/go/vt/wrangler/testlib/version_test.go @@ -72,7 +72,7 @@ func TestVersion(t *testing.T) { defer cancel() ts := memorytopo.NewServer(ctx, "cell1", "cell2") wr := wrangler.New(vtenv.NewTestEnv(), logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() // couple tablets is enough diff --git a/go/vt/wrangler/testlib/vtctl_pipe.go b/go/vt/wrangler/testlib/vtctl_pipe.go index 44a6931870a..594290e4023 100644 --- a/go/vt/wrangler/testlib/vtctl_pipe.go +++ b/go/vt/wrangler/testlib/vtctl_pipe.go @@ -53,7 +53,7 @@ type VtctlPipe struct { } // NewVtctlPipe creates a new VtctlPipe based on the given topo server. -func NewVtctlPipe(t *testing.T, ts *topo.Server) *VtctlPipe { +func NewVtctlPipe(ctx context.Context, t *testing.T, ts *topo.Server) *VtctlPipe { // Register all vtctl commands servenvInitialized.Do(func() { // make sure we use the right protocol @@ -81,7 +81,7 @@ func NewVtctlPipe(t *testing.T, ts *topo.Server) *VtctlPipe { go server.Serve(listener) // Create a VtctlClient gRPC client to talk to the fake server - client, err := vtctlclient.New(listener.Addr().String()) + client, err := vtctlclient.New(ctx, listener.Addr().String()) if err != nil { t.Fatalf("Cannot create client: %v", err) } diff --git a/go/vt/wrangler/testlib/vtctl_topo_test.go b/go/vt/wrangler/testlib/vtctl_topo_test.go index a13535f4111..325d629c1ff 100644 --- a/go/vt/wrangler/testlib/vtctl_topo_test.go +++ b/go/vt/wrangler/testlib/vtctl_topo_test.go @@ -62,7 +62,7 @@ func TestVtctlTopoCommands(t *testing.T) { if err := ts.CreateKeyspace(context.Background(), "ks2", &topodatapb.Keyspace{KeyspaceType: topodatapb.KeyspaceType_SNAPSHOT}); err != nil { t.Fatalf("CreateKeyspace() failed: %v", err) } - vp := NewVtctlPipe(t, ts) + vp := NewVtctlPipe(ctx, t, ts) defer vp.Close() tmp := t.TempDir() diff --git a/go/vt/wrangler/traffic_switcher_env_test.go b/go/vt/wrangler/traffic_switcher_env_test.go index 7705ee49f45..4e58024785d 100644 --- a/go/vt/wrangler/traffic_switcher_env_test.go +++ b/go/vt/wrangler/traffic_switcher_env_test.go @@ -159,7 +159,7 @@ func newTestTableMigraterCustom(ctx context.Context, t *testing.T, sourceShards, } dialerName := fmt.Sprintf("TrafficSwitcherTest-%s-%d", t.Name(), rand.IntN(1000000000)) - tabletconn.RegisterDialer(dialerName, func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer(dialerName, func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { tme.mu.Lock() defer tme.mu.Unlock() allPrimaries := append(tme.sourcePrimaries, tme.targetPrimaries...) @@ -425,7 +425,7 @@ func newTestTablePartialMigrater(ctx context.Context, t *testing.T, shards, shar } dialerName := fmt.Sprintf("TrafficSwitcherTest-%s-%d", t.Name(), rand.IntN(1000000000)) - tabletconn.RegisterDialer(dialerName, func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer(dialerName, func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { tme.mu.Lock() defer tme.mu.Unlock() for _, ft := range append(tme.sourcePrimaries, tme.targetPrimaries...) { @@ -590,7 +590,7 @@ func newTestShardMigrater(ctx context.Context, t *testing.T, sourceShards, targe } dialerName := fmt.Sprintf("TrafficSwitcherTest-%s-%d", t.Name(), rand.IntN(1000000000)) - tabletconn.RegisterDialer(dialerName, func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer(dialerName, func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { tme.mu.Lock() defer tme.mu.Unlock() for _, ft := range append(tme.sourcePrimaries, tme.targetPrimaries...) { diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 2196152b122..8145d1c9e51 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -969,7 +969,7 @@ func (df *vdiff) streamOne(ctx context.Context, keyspace, shard string, particip // Wrap the streaming in a separate function so we can capture the error. // This shows that the error will be set before the channels are closed. participant.err = func() error { - conn, err := tabletconn.GetDialer()(participant.tablet, grpcclient.FailFast(false)) + conn, err := tabletconn.GetDialer()(ctx, participant.tablet, grpcclient.FailFast(false)) if err != nil { return err } diff --git a/go/vt/wrangler/vdiff_env_test.go b/go/vt/wrangler/vdiff_env_test.go index ac30736c999..5d1967770ce 100644 --- a/go/vt/wrangler/vdiff_env_test.go +++ b/go/vt/wrangler/vdiff_env_test.go @@ -82,7 +82,7 @@ func newTestVDiffEnv(t testing.TB, ctx context.Context, sourceShards, targetShar // Generate a unique dialer name. dialerName := fmt.Sprintf("VDiffTest-%s-%d", t.Name(), rand.IntN(1000000000)) - tabletconn.RegisterDialer(dialerName, func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer(dialerName, func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { env.mu.Lock() defer env.mu.Unlock() if qs, ok := env.tablets[int(tablet.Alias.Uid)]; ok { diff --git a/go/vt/wrangler/wrangler_env_test.go b/go/vt/wrangler/wrangler_env_test.go index 04231fb7bf3..2b174bee176 100644 --- a/go/vt/wrangler/wrangler_env_test.go +++ b/go/vt/wrangler/wrangler_env_test.go @@ -74,7 +74,7 @@ func newWranglerTestEnv(t testing.TB, ctx context.Context, sourceShards, targetS // Generate a unique dialer name. dialerName := fmt.Sprintf("WranglerTest-%s-%d", t.Name(), rand.IntN(1000000000)) - tabletconn.RegisterDialer(dialerName, func(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { + tabletconn.RegisterDialer(dialerName, func(ctx context.Context, tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (queryservice.QueryService, error) { env.mu.Lock() defer env.mu.Unlock() if qs, ok := env.tmc.tablets[int(tablet.Alias.Uid)]; ok { diff --git a/go/vtbench/client.go b/go/vtbench/client.go index 1a6751a62db..3e3ef3c495d 100644 --- a/go/vtbench/client.go +++ b/go/vtbench/client.go @@ -137,7 +137,7 @@ func (c *grpcVttabletConn) connect(ctx context.Context, cp ConnParams) error { Keyspace: keyspace, } var err error - qs, err = tabletconn.GetDialer()(&tablet, true) + qs, err = tabletconn.GetDialer()(ctx, &tablet, true) if err != nil { return err } diff --git a/tools/rowlog/rowlog.go b/tools/rowlog/rowlog.go index 8092159c6b6..34d16a1777b 100644 --- a/tools/rowlog/rowlog.go +++ b/tools/rowlog/rowlog.go @@ -496,7 +496,7 @@ func getPosition(ctx context.Context, server, keyspace, shard string) (string, e } func execVtctl(ctx context.Context, server string, args []string) ([]string, error) { - client, err := vtctlclient.New(server) + client, err := vtctlclient.New(ctx, server) if err != nil { fmt.Println(err) return nil, err