diff --git a/tools/pd-simulator/main.go b/tools/pd-simulator/main.go
index 05763cc93b8..12254c1a947 100644
--- a/tools/pd-simulator/main.go
+++ b/tools/pd-simulator/main.go
@@ -25,7 +25,6 @@ import (
 	"github.com/BurntSushi/toml"
 	"github.com/pingcap/log"
 	flag "github.com/spf13/pflag"
-	pdHttp "github.com/tikv/pd/client/http"
 	"github.com/tikv/pd/pkg/schedule/schedulers"
 	"github.com/tikv/pd/pkg/statistics"
 	"github.com/tikv/pd/pkg/utils/logutil"
@@ -93,7 +92,6 @@ func main() {
 
 func run(simCase string, simConfig *sc.SimConfig) {
 	if *pdAddr != "" {
-		simulator.PDHTTPClient = pdHttp.NewClient("pd-simulator", []string{*pdAddr})
 		simStart(*pdAddr, *statusAddress, simCase, simConfig)
 	} else {
 		local, clean := NewSingleServer(context.Background(), simConfig)
@@ -107,7 +105,6 @@ func run(simCase string, simConfig *sc.SimConfig) {
 			}
 			time.Sleep(100 * time.Millisecond)
 		}
-		simulator.PDHTTPClient = pdHttp.NewClient("pd-simulator", []string{local.GetAddr()})
 		simStart(local.GetAddr(), "", simCase, simConfig, clean)
 	}
 }
diff --git a/tools/pd-simulator/simulator/client.go b/tools/pd-simulator/simulator/client.go
index 0bbbebe4602..f5bd379d17e 100644
--- a/tools/pd-simulator/simulator/client.go
+++ b/tools/pd-simulator/simulator/client.go
@@ -16,6 +16,7 @@ package simulator
 
 import (
 	"context"
+	"fmt"
 	"strconv"
 	"strings"
 	"sync"
@@ -24,6 +25,7 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/kvproto/pkg/metapb"
 	"github.com/pingcap/kvproto/pkg/pdpb"
+	pd "github.com/tikv/pd/client"
 	pdHttp "github.com/tikv/pd/client/http"
 	"github.com/tikv/pd/pkg/core"
 	"github.com/tikv/pd/pkg/utils/typeutil"
@@ -37,32 +39,41 @@ import (
 // Client is a PD (Placement Driver) client.
 // It should not be used after calling Close().
 type Client interface {
-	GetClusterID(ctx context.Context) uint64
-	AllocID(ctx context.Context) (uint64, error)
-	Bootstrap(ctx context.Context, store *metapb.Store, region *metapb.Region) error
-	PutStore(ctx context.Context, store *metapb.Store) error
-	StoreHeartbeat(ctx context.Context, stats *pdpb.StoreStats) error
-	RegionHeartbeat(ctx context.Context, region *core.RegionInfo) error
-	PutPDConfig(*sc.PDConfig) error
+	AllocID(context.Context) (uint64, error)
+	PutStore(context.Context, *metapb.Store) error
+	StoreHeartbeat(context.Context, *pdpb.StoreStats) error
+	RegionHeartbeat(context.Context, *core.RegionInfo) error
+	HeartbeatStreamLoop()
+	ChangeConn(*grpc.ClientConn) error
 	Close()
}
 
 const (
 	pdTimeout             = time.Second
 	maxInitClusterRetries = 100
+	// Wait interval and retry count used when fetching the new leader URL.
+	leaderChangedWaitTime = 100 * time.Millisecond
+	retryTimes            = 10
 )
 
 var (
 	// errFailInitClusterID is returned when failed to load clusterID from all supplied PD addresses.
 	errFailInitClusterID = errors.New("[pd] failed to get cluster id")
 	PDHTTPClient         pdHttp.Client
+	sd                   pd.ServiceDiscovery
+	ClusterID            uint64
 )
 
+// requestHeader returns a request header carrying the cached ClusterID.
+func requestHeader() *pdpb.RequestHeader {
+	return &pdpb.RequestHeader{
+		ClusterId: ClusterID,
+	}
+}
+
 type client struct {
-	url        string
 	tag        string
-	clusterID  uint64
 	clientConn *grpc.ClientConn
 
 	reportRegionHeartbeatCh  chan *core.RegionInfo
@@ -74,29 +85,15 @@ type client struct {
 }
 
 // NewClient creates a PD client.
-func NewClient(pdAddr string, tag string) (Client, <-chan *pdpb.RegionHeartbeatResponse, error) {
-	simutil.Logger.Info("create pd client with endpoints", zap.String("tag", tag), zap.String("pd-address", pdAddr))
+func NewClient(tag string) (Client, <-chan *pdpb.RegionHeartbeatResponse, error) {
 	ctx, cancel := context.WithCancel(context.Background())
 	c := &client{
-		url:                      pdAddr,
 		reportRegionHeartbeatCh:  make(chan *core.RegionInfo, 1),
 		receiveRegionHeartbeatCh: make(chan *pdpb.RegionHeartbeatResponse, 1),
 		ctx:                      ctx,
 		cancel:                   cancel,
 		tag:                      tag,
 	}
-	cc, err := c.createConn()
-	if err != nil {
-		return nil, nil, err
-	}
-	c.clientConn = cc
-	if err := c.initClusterID(); err != nil {
-		return nil, nil, err
-	}
-	simutil.Logger.Info("init cluster id", zap.String("tag", c.tag), zap.Uint64("cluster-id", c.clusterID))
-	c.wg.Add(1)
-	go c.heartbeatStreamLoop()
-
 	return c, c.receiveRegionHeartbeatCh, nil
 }
 
@@ -104,39 +101,18 @@ func (c *client) pdClient() pdpb.PDClient {
 	return pdpb.NewPDClient(c.clientConn)
 }
 
-func (c *client) initClusterID() error {
-	ctx, cancel := context.WithCancel(c.ctx)
-	defer cancel()
-	for i := 0; i < maxInitClusterRetries; i++ {
-		members, err := c.getMembers(ctx)
-		if err != nil || members.GetHeader() == nil {
-			simutil.Logger.Error("failed to get cluster id", zap.String("tag", c.tag), zap.Error(err))
-			continue
-		}
-		c.clusterID = members.GetHeader().GetClusterId()
-		return nil
-	}
-
-	return errors.WithStack(errFailInitClusterID)
-}
-
-func (c *client) getMembers(ctx context.Context) (*pdpb.GetMembersResponse, error) {
-	members, err := c.pdClient().GetMembers(ctx, &pdpb.GetMembersRequest{})
+func createConn(url string) (*grpc.ClientConn, error) {
+	cc, err := grpc.Dial(strings.TrimPrefix(url, "http://"), grpc.WithTransportCredentials(insecure.NewCredentials()))
 	if err != nil {
 		return nil, errors.WithStack(err)
 	}
-	if members.GetHeader().GetError() != nil {
-		return nil, errors.WithStack(errors.New(members.GetHeader().GetError().String()))
-	}
-	return members, nil
+	return cc, nil
 }
 
-func (c *client) createConn() (*grpc.ClientConn, error) {
-	cc, err := grpc.Dial(strings.TrimPrefix(c.url, "http://"), grpc.WithTransportCredentials(insecure.NewCredentials()))
-	if err != nil {
-		return nil, errors.WithStack(err)
-	}
-	return cc, nil
+// ChangeConn replaces the client's gRPC connection, typically after a PD leader change.
+func (c *client) ChangeConn(cc *grpc.ClientConn) error {
+	c.clientConn = cc
+	simutil.Logger.Info("change pd client with endpoints", zap.String("tag", c.tag), zap.String("pd-address", cc.Target()))
+	return nil
 }
 
 func (c *client) createHeartbeatStream() (pdpb.PD_RegionHeartbeatClient, context.Context, context.CancelFunc) {
@@ -166,7 +142,8 @@ func (c *client) createHeartbeatStream() (pdpb.PD_RegionHeartbeatClient, context
 	return stream, ctx, cancel
 }
 
-func (c *client) heartbeatStreamLoop() {
+// HeartbeatStreamLoop keeps the region heartbeat stream alive and reconnects to the new leader when the stream breaks.
+func (c *client) HeartbeatStreamLoop() {
+	c.wg.Add(1)
 	defer c.wg.Done()
 	for {
 		stream, ctx, cancel := c.createHeartbeatStream()
@@ -187,6 +164,23 @@ func (c *client) heartbeatStreamLoop() {
 			return
 		}
 		wg.Wait()
+
+		// Refresh the leader connection so the heartbeat stream can be recreated.
+		for i := 0; i < retryTimes; i++ {
+			sd.ScheduleCheckMemberChanged()
+			time.Sleep(leaderChangedWaitTime)
+			if client := sd.GetServiceClient(); client != nil {
+				_, conn, err := getLeaderURL(ctx, client.GetClientConn())
+				if err != nil {
+					simutil.Logger.Error("[HeartbeatStreamLoop] failed to get leader URL", zap.Error(err))
+					continue
+				}
+				if err = c.ChangeConn(conn); err == nil {
+					break
+				}
+			}
+		}
+		simutil.Logger.Info("recreate heartbeat stream", zap.String("tag", c.tag))
 	}
 }
@@ -196,6 +190,7 @@ func (c *client) receiveRegionHeartbeat(ctx context.Context, stream pdpb.PD_Regi
 		resp, err := stream.Recv()
 		if err != nil {
 			errCh <- err
+			simutil.Logger.Error("receive regionHeartbeat error", zap.String("tag", c.tag), zap.Error(err))
 			return
 		}
 		select {
@@ -213,7 +208,7 @@ func (c *client) reportRegionHeartbeat(ctx context.Context, stream pdpb.PD_Regio
 		case r := <-c.reportRegionHeartbeatCh:
 			region := r.Clone()
 			request := &pdpb.RegionHeartbeatRequest{
-				Header:          c.requestHeader(),
+				Header:          requestHeader(),
 				Region:          region.GetMeta(),
 				Leader:          region.GetLeader(),
 				DownPeers:       region.GetDownPeers(),
@@ -227,6 +222,7 @@ func (c *client) reportRegionHeartbeat(ctx context.Context, stream pdpb.PD_Regio
 			if err != nil {
 				errCh <- err
 				simutil.Logger.Error("report regionHeartbeat error", zap.String("tag", c.tag), zap.Error(err))
+				return
 			}
 		case <-ctx.Done():
 			return
@@ -235,6 +231,11 @@ func (c *client) reportRegionHeartbeat(ctx context.Context, stream pdpb.PD_Regio
 }
 
 func (c *client) Close() {
+	if c.cancel == nil {
+		simutil.Logger.Info("pd client has been closed", zap.String("tag", c.tag))
+		return
+	}
+	simutil.Logger.Info("closing pd client", zap.String("tag", c.tag))
 	c.cancel()
 	c.wg.Wait()
@@ -243,14 +244,10 @@ func (c *client) Close() {
 	}
 }
 
-func (c *client) GetClusterID(context.Context) uint64 {
-	return c.clusterID
-}
-
 func (c *client) AllocID(ctx context.Context) (uint64, error) {
 	ctx, cancel := context.WithTimeout(ctx, pdTimeout)
 	resp, err := c.pdClient().AllocID(ctx, &pdpb.AllocIDRequest{
-		Header: c.requestHeader(),
+		Header: requestHeader(),
 	})
 	cancel()
 	if err != nil {
@@ -262,57 +259,259 @@ func (c *client) AllocID(ctx context.Context) (uint64, error) {
 	return resp.GetId(), nil
 }
 
-func (c *client) Bootstrap(ctx context.Context, store *metapb.Store, region *metapb.Region) error {
+func (c *client) PutStore(ctx context.Context, store *metapb.Store) error {
 	ctx, cancel := context.WithTimeout(ctx, pdTimeout)
-	defer cancel()
-	req := &pdpb.IsBootstrappedRequest{
-		Header: &pdpb.RequestHeader{
-			ClusterId: c.clusterID,
-		},
-	}
-	resp, err := c.pdClient().IsBootstrapped(ctx, req)
-	if resp.GetBootstrapped() {
-		simutil.Logger.Fatal("failed to bootstrap, server is not clean")
-	}
-	if err != nil {
-		return err
-	}
 	newStore := typeutil.DeepClone(store, core.StoreFactory)
-	newRegion := typeutil.DeepClone(region, core.RegionFactory)
-
-	res, err := c.pdClient().Bootstrap(ctx, &pdpb.BootstrapRequest{
-		Header: c.requestHeader(),
+	resp, err := c.pdClient().PutStore(ctx, &pdpb.PutStoreRequest{
+		Header: requestHeader(),
 		Store:  newStore,
-		Region: newRegion,
 	})
+	cancel()
 	if err != nil {
 		return err
 	}
-	if res.GetHeader().GetError() != nil {
-		return errors.Errorf("bootstrap failed: %s", resp.GetHeader().GetError().String())
+	if resp.Header.GetError() != nil {
+		simutil.Logger.Error("put store error", zap.Reflect("error", resp.Header.GetError()))
+		return nil
 	}
 	return nil
 }
 
-func (c *client) PutStore(ctx context.Context, store *metapb.Store) error {
+func (c *client) StoreHeartbeat(ctx context.Context, stats *pdpb.StoreStats) error {
 	ctx, cancel := context.WithTimeout(ctx, pdTimeout)
-	newStore := typeutil.DeepClone(store, core.StoreFactory)
-	resp, err := c.pdClient().PutStore(ctx, &pdpb.PutStoreRequest{
-		Header: c.requestHeader(),
-		Store:  newStore,
+	newStats := typeutil.DeepClone(stats, core.StoreStatsFactory)
+	resp, err := c.pdClient().StoreHeartbeat(ctx, &pdpb.StoreHeartbeatRequest{
+		Header: requestHeader(),
+		Stats:  newStats,
 	})
 	cancel()
 	if err != nil {
 		return err
 	}
 	if resp.Header.GetError() != nil {
-		simutil.Logger.Error("put store error", zap.Reflect("error", resp.Header.GetError()))
+		simutil.Logger.Error("store heartbeat error", zap.Reflect("error", resp.Header.GetError()))
 		return nil
 	}
 	return nil
 }
 
-func (c *client) PutPDConfig(config *sc.PDConfig) error {
+func (c *client) RegionHeartbeat(_ context.Context, region *core.RegionInfo) error {
+	c.reportRegionHeartbeatCh <- region
+	return nil
+}
+
+// RetryClient wraps a Client and, when a request fails, retries it after
+// refreshing the connection to the current PD leader.
+type RetryClient struct {
+	client     Client
+	retryCount int
+}
+
+// NewRetryClient creates the PD client for the given node, wraps it in a
+// RetryClient, and starts its heartbeat stream loop.
+func NewRetryClient(node *Node) *RetryClient {
+	// Initialize the PD client and attach it to the node.
+	tag := fmt.Sprintf("store %d", node.Store.Id)
+	var (
+		client                   Client
+		receiveRegionHeartbeatCh <-chan *pdpb.RegionHeartbeatResponse
+		err                      error
+	)
+
+	// The client should wait if the PD server is not ready.
+	for i := 0; i < maxInitClusterRetries; i++ {
+		client, receiveRegionHeartbeatCh, err = NewClient(tag)
+		if err == nil {
+			break
+		}
+		time.Sleep(time.Second)
+	}
+
+	if err != nil {
+		simutil.Logger.Fatal("create client failed", zap.Error(err))
+	}
+	node.client = client
+
+	// Initialize the RetryClient.
+	retryClient := &RetryClient{
+		client:     client,
+		retryCount: retryTimes,
+	}
+	// Resolve the leader URL first: force one failing request so that
+	// requestWithRetry looks up the leader and fixes the connection.
+	retryClient.requestWithRetry(func() (any, error) {
+		return nil, errors.New("retry to create client")
+	})
+	// Start the heartbeat stream.
+	node.receiveRegionHeartbeatCh = receiveRegionHeartbeatCh
+	go client.HeartbeatStreamLoop()
+
+	return retryClient
+}
+
+func (rc *RetryClient) requestWithRetry(f func() (any, error)) (any, error) {
+	// Execute the function directly first.
+	if res, err := f(); err == nil {
+		return res, nil
+	}
+	// On failure, wait for a possible leader change, refresh the leader
+	// connection, and retry.
+	for i := 0; i < rc.retryCount; i++ {
+		sd.ScheduleCheckMemberChanged()
+		time.Sleep(leaderChangedWaitTime)
+		if client := sd.GetServiceClient(); client != nil {
+			_, conn, err := getLeaderURL(context.Background(), client.GetClientConn())
+			if err != nil {
+				simutil.Logger.Error("[retry] failed to get leader URL", zap.Error(err))
+				return nil, err
+			}
+			if err = rc.client.ChangeConn(conn); err != nil {
+				simutil.Logger.Error("failed to change connection", zap.Error(err))
+				return nil, err
+			}
+			return f()
+		}
+	}
+	return nil, errors.New("failed to retry")
+}
+
+// getLeaderURL fetches the member list, caches the cluster ID, and returns the
+// leader URL together with a fresh connection to the leader.
+func getLeaderURL(ctx context.Context, conn *grpc.ClientConn) (string, *grpc.ClientConn, error) {
+	pdCli := pdpb.NewPDClient(conn)
+	members, err := pdCli.GetMembers(ctx, &pdpb.GetMembersRequest{})
+	if err != nil {
+		return "", nil, err
+	}
+	if members.GetHeader().GetError() != nil {
+		return "", nil, errors.New(members.GetHeader().GetError().String())
+	}
+	ClusterID = members.GetHeader().GetClusterId()
+	if ClusterID == 0 {
+		return "", nil, errors.New("cluster id is 0")
+	}
+	if members.GetLeader() == nil {
+		return "", nil, errors.New("leader is nil")
+	}
+	leaderURL := members.GetLeader().ClientUrls[0]
+	conn, err = createConn(leaderURL)
+	return leaderURL, conn, err
+}
+
+func (rc *RetryClient) AllocID(ctx context.Context) (uint64, error) {
+	res, err := rc.requestWithRetry(func() (any, error) {
+		id, err := rc.client.AllocID(ctx)
+		return id, err
+	})
+	if err != nil {
+		return 0, err
+	}
+	return res.(uint64), nil
+}
+
+func (rc *RetryClient) PutStore(ctx context.Context, store *metapb.Store) error {
+	_, err := rc.requestWithRetry(func() (any, error) {
+		err := rc.client.PutStore(ctx, store)
+		return nil, err
+	})
+	return err
+}
+
+func (rc *RetryClient) StoreHeartbeat(ctx context.Context, stats *pdpb.StoreStats) error {
+	_, err := rc.requestWithRetry(func() (any, error) {
+		err := rc.client.StoreHeartbeat(ctx, stats)
+		return nil, err
+	})
+	return err
+}
+
+func (rc *RetryClient) RegionHeartbeat(ctx context.Context, region *core.RegionInfo) error {
+	_, err := rc.requestWithRetry(func() (any, error) {
+		err := rc.client.RegionHeartbeat(ctx, region)
+		return nil, err
+	})
+	return err
+}
+
+func (*RetryClient) ChangeConn(_ *grpc.ClientConn) error {
+	panic("unimplemented")
+}
+
+func (rc *RetryClient) HeartbeatStreamLoop() {
+	rc.client.HeartbeatStreamLoop()
+}
+
+func (rc *RetryClient) Close() {
+	rc.client.Close()
+}
+
+// Bootstrap bootstraps the cluster directly through the given PD addresses,
+// because PDServiceDiscovery cannot be started before the cluster is bootstrapped.
+func Bootstrap(ctx context.Context, pdAddrs string, store *metapb.Store, region *metapb.Region) (
+	leaderURL string, pdCli pdpb.PDClient, err error) {
+	urls := strings.Split(pdAddrs, ",")
+	if len(urls) == 0 {
+		return "", nil, errors.New("empty pd address")
+	}
+
+retry:
+	for i := 0; i < maxInitClusterRetries; i++ {
+		time.Sleep(100 * time.Millisecond)
+		for _, url := range urls {
+			conn, err := createConn(url)
+			if err != nil {
+				continue
+			}
+			leaderURL, conn, err = getLeaderURL(ctx, conn)
+			if err != nil {
+				continue
+			}
+			pdCli = pdpb.NewPDClient(conn)
+			break retry
+		}
+	}
+	if ClusterID == 0 {
+		return "", nil, errors.WithStack(errFailInitClusterID)
+	}
+	simutil.Logger.Info("get cluster id successfully", zap.Uint64("cluster-id", ClusterID))
+
+	// Check if the cluster is already bootstrapped.
+	ctx, cancel := context.WithTimeout(ctx, pdTimeout)
+	defer cancel()
+	req := &pdpb.IsBootstrappedRequest{
+		Header: requestHeader(),
+	}
+	resp, err := pdCli.IsBootstrapped(ctx, req)
+	if resp.GetBootstrapped() {
+		simutil.Logger.Fatal("failed to bootstrap, server is not clean")
+	}
+	if err != nil {
+		return "", nil, err
+	}
+	// Bootstrap the cluster, retrying on transient failures.
+	newStore := typeutil.DeepClone(store, core.StoreFactory)
+	newRegion := typeutil.DeepClone(region, core.RegionFactory)
+	var res *pdpb.BootstrapResponse
+	for i := 0; i < maxInitClusterRetries; i++ {
+		res, err = pdCli.Bootstrap(ctx, &pdpb.BootstrapRequest{
+			Header: requestHeader(),
+			Store:  newStore,
+			Region: newRegion,
+		})
+		if err != nil {
+			continue
+		}
+		if res.GetHeader().GetError() != nil {
+			continue
+		}
+		break
+	}
+	if err != nil {
+		return "", nil, err
+	}
+	if res.GetHeader().GetError() != nil {
+		return "", nil, errors.New(res.GetHeader().GetError().String())
+	}
+
+	return leaderURL, pdCli, nil
+}
+
+// The following functions go through PDHTTPClient, the PD HTTP API client
+// shared by the simulator.
+
+func PutPDConfig(config *sc.PDConfig) error {
 	if len(config.PlacementRules) > 0 {
 		ruleOps := make([]*pdHttp.RuleOp, 0)
 		for _, rule := range config.PlacementRules {
@@ -321,7 +520,7 @@ func (c *client) PutPDConfig(config *sc.PDConfig) error {
 				Action: pdHttp.RuleOpAdd,
 			})
 		}
-		err := PDHTTPClient.SetPlacementRuleInBatch(c.ctx, ruleOps)
+		err := PDHTTPClient.SetPlacementRuleInBatch(context.Background(), ruleOps)
 		if err != nil {
 			return err
 		}
@@ -330,7 +529,7 @@ func (c *client) PutPDConfig(config *sc.PDConfig) error {
 	if len(config.LocationLabels) > 0 {
 		data := make(map[string]any)
 		data["location-labels"] = config.LocationLabels
-		err := PDHTTPClient.SetConfig(c.ctx, data)
+		err := PDHTTPClient.SetConfig(context.Background(), data)
 		if err != nil {
 			return err
 		}
@@ -339,35 +538,6 @@ func (c *client) PutPDConfig(config *sc.PDConfig) error {
 	return nil
 }
 
-func (c *client) StoreHeartbeat(ctx context.Context, stats *pdpb.StoreStats) error {
-	ctx, cancel := context.WithTimeout(ctx, pdTimeout)
-	newStats := typeutil.DeepClone(stats, core.StoreStatsFactory)
-	resp, err := c.pdClient().StoreHeartbeat(ctx, &pdpb.StoreHeartbeatRequest{
-		Header: c.requestHeader(),
-		Stats:  newStats,
-	})
-	cancel()
-	if err != nil {
-		return err
-	}
-	if resp.Header.GetError() != nil {
-		simutil.Logger.Error("store heartbeat error", zap.Reflect("error", resp.Header.GetError()))
-		return nil
-	}
-	return nil
-}
-
-func (c *client) RegionHeartbeat(_ context.Context, region *core.RegionInfo) error {
-	c.reportRegionHeartbeatCh <- region
-	return nil
-}
-
-func (c *client) requestHeader() *pdpb.RequestHeader {
-	return &pdpb.RequestHeader{
-		ClusterId: c.clusterID,
-	}
-}
-
 func ChooseToHaltPDSchedule(halt bool) {
 	PDHTTPClient.SetConfig(context.Background(), map[string]any{
 		"schedule.halt-scheduling": strconv.FormatBool(halt),
diff --git a/tools/pd-simulator/simulator/conn.go b/tools/pd-simulator/simulator/conn.go
index 4be8a2b76dc..b1000c0f17b 100644
--- a/tools/pd-simulator/simulator/conn.go
+++ b/tools/pd-simulator/simulator/conn.go
@@ -22,19 +22,17 @@ import (
 
 // Connection records the information of connection among nodes.
 type Connection struct {
-	pdAddr string
-	Nodes  map[uint64]*Node
+	Nodes map[uint64]*Node
 }
 
 // NewConnection creates nodes according to the configuration and returns the connection among nodes.
-func NewConnection(simCase *cases.Case, pdAddr string, storeConfig *config.SimConfig) (*Connection, error) {
+func NewConnection(simCase *cases.Case, storeConfig *config.SimConfig) (*Connection, error) {
 	conn := &Connection{
-		pdAddr: pdAddr,
-		Nodes:  make(map[uint64]*Node),
+		Nodes: make(map[uint64]*Node),
 	}
 	for _, store := range simCase.Stores {
-		node, err := NewNode(store, pdAddr, storeConfig)
+		node, err := NewNode(store, storeConfig)
 		if err != nil {
 			return nil, err
 		}
diff --git a/tools/pd-simulator/simulator/drive.go b/tools/pd-simulator/simulator/drive.go
index 700dd58f87a..0296710b705 100644
--- a/tools/pd-simulator/simulator/drive.go
+++ b/tools/pd-simulator/simulator/drive.go
@@ -20,12 +20,16 @@ import (
 	"net/http/pprof"
 	"path"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 
 	"github.com/pingcap/errors"
 	"github.com/pingcap/kvproto/pkg/metapb"
+	"github.com/pingcap/kvproto/pkg/pdpb"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
+	pd "github.com/tikv/pd/client"
+	pdHttp "github.com/tikv/pd/client/http"
 	"github.com/tikv/pd/pkg/core"
 	"github.com/tikv/pd/pkg/utils/typeutil"
 	"github.com/tikv/pd/tools/pd-simulator/simulator/cases"
@@ -42,7 +46,6 @@ type Driver struct {
 	pdAddr        string
 	statusAddress string
 	simCase       *cases.Case
-	client        Client
 	tickCount     int64
 	eventRunner   *EventRunner
 	raftEngine    *RaftEngine
@@ -71,7 +74,7 @@ func NewDriver(pdAddr, statusAddress, caseName string, simConfig *config.SimConf
 
 // Prepare initializes cluster information, bootstraps cluster and starts nodes.
 func (d *Driver) Prepare() error {
-	conn, err := NewConnection(d.simCase, d.pdAddr, d.simConfig)
+	conn, err := NewConnection(d.simCase, d.simConfig)
 	if err != nil {
 		return err
 	}
@@ -79,22 +82,27 @@ func (d *Driver) Prepare() error {
 	d.raftEngine = NewRaftEngine(d.simCase, d.conn, d.simConfig)
 	d.eventRunner = NewEventRunner(d.simCase.Events, d.raftEngine)
 
-	d.updateNodeAvailable()
 	if d.statusAddress != "" {
 		go d.runHTTPServer()
 	}
+
+	if err = d.allocID(); err != nil {
+		return err
+	}
+
+	return d.Start()
+}
+
+// allocID bootstraps the cluster and fast-forwards PD's ID allocator past the
+// IDs already used by the simulator.
+func (d *Driver) allocID() error {
 	// Bootstrap.
 	store, region, err := d.GetBootstrapInfo(d.raftEngine)
 	if err != nil {
 		return err
 	}
-	d.client = d.conn.Nodes[store.GetId()].client
-	ctx, cancel := context.WithTimeout(context.Background(), pdTimeout)
-	err = d.client.Bootstrap(ctx, store, region)
-	cancel()
+	leaderURL, pdCli, err := Bootstrap(context.Background(), d.pdAddr, store, region)
 	if err != nil {
 		simutil.Logger.Fatal("bootstrap error", zap.Error(err))
 	} else {
@@ -107,15 +115,14 @@ func (d *Driver) Prepare() error {
 	requestTimeout := 10 * time.Second
 	etcdTimeout := 3 * time.Second
 	etcdClient, err := clientv3.New(clientv3.Config{
-		Endpoints:   []string{d.pdAddr},
+		Endpoints:   []string{leaderURL},
 		DialTimeout: etcdTimeout,
 	})
 	if err != nil {
 		return err
 	}
-	ctx, cancel = context.WithTimeout(context.Background(), requestTimeout)
-	clusterID := d.client.GetClusterID(ctx)
-	rootPath := path.Join("/pd", strconv.FormatUint(clusterID, 10))
+	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
+	rootPath := path.Join("/pd", strconv.FormatUint(ClusterID, 10))
 	allocIDPath := path.Join(rootPath, "alloc_id")
 	_, err = etcdClient.Put(ctx, allocIDPath, string(typeutil.Uint64ToBytes(maxID+1000)))
 	if err != nil {
@@ -125,22 +132,34 @@ func (d *Driver) Prepare() error {
 	cancel()
 
 	for {
-		var id uint64
-		id, err = d.client.AllocID(context.Background())
+		var resp *pdpb.AllocIDResponse
+		resp, err = pdCli.AllocID(context.Background(), &pdpb.AllocIDRequest{
+			Header: requestHeader(),
+		})
 		if err != nil {
 			return errors.WithStack(err)
 		}
-		if id > maxID {
+		if resp.Id > maxID {
 			simutil.IDAllocator.ResetID()
 			break
 		}
 	}
+	return nil
+}
 
-	err = d.Start()
-	if err != nil {
+// updateNodesClient initializes the PD service discovery, the PD HTTP client,
+// and a RetryClient for every node.
+func (d *Driver) updateNodesClient() error {
+	urls := strings.Split(d.pdAddr, ",")
+	ctx, cancel := context.WithCancel(context.Background())
+	sd = pd.NewDefaultPDServiceDiscovery(ctx, cancel, urls, nil)
+	if err := sd.Init(); err != nil {
 		return err
 	}
+	// Init PD HTTP client.
+	PDHTTPClient = pdHttp.NewClientWithServiceDiscovery("pd-simulator", sd)
+	for _, node := range d.conn.Nodes {
+		node.client = NewRetryClient(node)
+	}
 	return nil
 }
 
@@ -174,19 +193,18 @@ func (d *Driver) Check() bool {
 
 // Start starts all nodes.
 func (d *Driver) Start() error {
+	if err := d.updateNodesClient(); err != nil {
+		return err
+	}
+
 	for _, n := range d.conn.Nodes {
 		err := n.Start()
 		if err != nil {
 			return err
 		}
 	}
-	d.ChangePDConfig()
-	return nil
-}
 
-// ChangePDConfig changes pd config
-func (d *Driver) ChangePDConfig() error {
-	d.client.PutPDConfig(d.pdConfig)
+	PutPDConfig(d.pdConfig)
 	return nil
 }
diff --git a/tools/pd-simulator/simulator/event.go b/tools/pd-simulator/simulator/event.go
index 20c75b58384..8e01a8f5f40 100644
--- a/tools/pd-simulator/simulator/event.go
+++ b/tools/pd-simulator/simulator/event.go
@@ -182,7 +182,7 @@ func (*AddNode) Run(raft *RaftEngine, _ int64) bool {
 		Capacity:  uint64(config.RaftStore.Capacity),
 		Version:   config.StoreVersion,
 	}
-	n, err := NewNode(s, raft.conn.pdAddr, config)
+	n, err := NewNode(s, config)
 	if err != nil {
 		simutil.Logger.Error("create node failed", zap.Error(err))
 		return false
@@ -190,6 +190,8 @@ func (*AddNode) Run(raft *RaftEngine, _ int64) bool {
 	raft.conn.Nodes[s.ID] = n
 	n.raftEngine = raft
 
+	n.client = NewRetryClient(n)
+
 	err = n.Start()
 	if err != nil {
 		delete(raft.conn.Nodes, s.ID)
diff --git a/tools/pd-simulator/simulator/node.go b/tools/pd-simulator/simulator/node.go
index fe8dc74a944..8238a6486c1 100644
--- a/tools/pd-simulator/simulator/node.go
+++ b/tools/pd-simulator/simulator/node.go
@@ -42,23 +42,24 @@ const (
 type Node struct {
 	*metapb.Store
 	syncutil.RWMutex
-	stats                    *info.StoreStats
-	tick                     uint64
-	wg                       sync.WaitGroup
-	tasks                    map[uint64]*Task
+	stats             *info.StoreStats
+	tick              uint64
+	wg                sync.WaitGroup
+	tasks             map[uint64]*Task
+	ctx               context.Context
+	cancel            context.CancelFunc
+	raftEngine        *RaftEngine
+	limiter           *ratelimit.RateLimiter
+	sizeMutex         syncutil.Mutex
+	hasExtraUsedSpace bool
+	snapStats         []*pdpb.SnapshotStat
+	// PD client
 	client                   Client
 	receiveRegionHeartbeatCh <-chan *pdpb.RegionHeartbeatResponse
-	ctx                      context.Context
-	cancel                   context.CancelFunc
-	raftEngine               *RaftEngine
-	limiter                  *ratelimit.RateLimiter
-	sizeMutex                syncutil.Mutex
-	hasExtraUsedSpace        bool
-	snapStats                []*pdpb.SnapshotStat
 }
 
 // NewNode returns a Node.
-func NewNode(s *cases.Store, pdAddr string, config *sc.SimConfig) (*Node, error) {
+func NewNode(s *cases.Store, config *sc.SimConfig) (*Node, error) {
 	ctx, cancel := context.WithCancel(context.Background())
 	store := &metapb.Store{
 		Id:      s.ID,
@@ -75,40 +76,19 @@ func NewNode(s *cases.Store, pdAddr string, config *sc.SimConfig) (*Node, error)
 			Available: uint64(config.RaftStore.Capacity),
 		},
 	}
-	tag := fmt.Sprintf("store %d", s.ID)
-	var (
-		client                   Client
-		receiveRegionHeartbeatCh <-chan *pdpb.RegionHeartbeatResponse
-		err                      error
-	)
-	// Client should wait if PD server is not ready.
-	for i := 0; i < maxInitClusterRetries; i++ {
-		client, receiveRegionHeartbeatCh, err = NewClient(pdAddr, tag)
-		if err == nil {
-			break
-		}
-		time.Sleep(time.Second)
-	}
-
-	if err != nil {
-		cancel()
-		return nil, err
-	}
 	ratio := config.Speed()
 	speed := config.StoreIOMBPerSecond * units.MiB * int64(ratio)
 	return &Node{
-		Store:                    store,
-		stats:                    stats,
-		client:                   client,
-		ctx:                      ctx,
-		cancel:                   cancel,
-		tasks:                    make(map[uint64]*Task),
-		receiveRegionHeartbeatCh: receiveRegionHeartbeatCh,
-		limiter:                  ratelimit.NewRateLimiter(float64(speed), int(speed)),
-		tick:                     uint64(rand.Intn(storeHeartBeatPeriod)),
-		hasExtraUsedSpace:        s.HasExtraUsedSpace,
-		snapStats:                make([]*pdpb.SnapshotStat, 0),
+		Store:             store,
+		stats:             stats,
+		ctx:               ctx,
+		cancel:            cancel,
+		tasks:             make(map[uint64]*Task),
+		limiter:           ratelimit.NewRateLimiter(float64(speed), int(speed)),
+		tick:              uint64(rand.Intn(storeHeartBeatPeriod)),
+		hasExtraUsedSpace: s.HasExtraUsedSpace,
+		snapStats:         make([]*pdpb.SnapshotStat, 0),
 	}, nil
 }
 
@@ -205,7 +185,7 @@ func (n *Node) storeHeartBeat() {
 	n.stats.SnapshotStats = stats
 	err := n.client.StoreHeartbeat(ctx, &n.stats.StoreStats)
 	if err != nil {
-		simutil.Logger.Info("report heartbeat error",
+		simutil.Logger.Info("report store heartbeat error",
 			zap.Uint64("node-id", n.GetId()),
 			zap.Error(err))
 	}
@@ -230,7 +210,7 @@ func (n *Node) regionHeartBeat() {
 			ctx, cancel := context.WithTimeout(n.ctx, pdTimeout)
 			err := n.client.RegionHeartbeat(ctx, region)
 			if err != nil {
-				simutil.Logger.Info("report heartbeat error",
+				simutil.Logger.Info("report region heartbeat error",
 					zap.Uint64("node-id", n.Id),
 					zap.Uint64("region-id", region.GetID()),
 					zap.Error(err))
@@ -247,7 +227,7 @@ func (n *Node) reportRegionChange() {
 		ctx, cancel := context.WithTimeout(n.ctx, pdTimeout)
 		err := n.client.RegionHeartbeat(ctx, region)
 		if err != nil {
-			simutil.Logger.Info("report heartbeat error",
+			simutil.Logger.Info("report region change heartbeat error",
 				zap.Uint64("node-id", n.Id),
 				zap.Uint64("region-id", region.GetID()),
 				zap.Error(err))
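
The core pattern this patch introduces, shared by RetryClient.requestWithRetry and HeartbeatStreamLoop, is "try once, then refresh the leader connection and retry". The sketch below distills that flow into a standalone, dependency-free Go program; demoClient and findLeader are illustrative stand-ins (the real code goes through pd.ServiceDiscovery, getLeaderURL, and Client.ChangeConn), and the constants only mirror the retryTimes and leaderChangedWaitTime values above.

package main

import (
	"errors"
	"fmt"
	"time"
)

const (
	retryTimes            = 10
	leaderChangedWaitTime = 100 * time.Millisecond
)

// demoClient stands in for the simulator's gRPC client.
type demoClient struct{ leader string }

// ChangeConn swaps the connection target after a leader change.
func (c *demoClient) ChangeConn(leader string) { c.leader = leader }

// findLeader stands in for getLeaderURL: ask any reachable member for the
// current leader's URL.
func findLeader() (string, error) { return "http://127.0.0.1:2379", nil }

// requestWithRetry loosely mirrors RetryClient.requestWithRetry: one direct
// attempt, then up to retryTimes rounds of "wait, re-resolve the leader,
// reconnect, retry". (Simplified: the patched code returns after the first
// successful reconnect instead of retrying every round.)
func requestWithRetry(c *demoClient, f func() (any, error)) (any, error) {
	if res, err := f(); err == nil {
		return res, nil
	}
	for i := 0; i < retryTimes; i++ {
		// Give the cluster time to elect and publish a new leader.
		time.Sleep(leaderChangedWaitTime)
		leader, err := findLeader()
		if err != nil {
			continue // member list not reachable yet; try again
		}
		c.ChangeConn(leader)
		if res, err := f(); err == nil {
			return res, nil
		}
	}
	return nil, errors.New("request failed after all retries")
}

func main() {
	c := &demoClient{}
	attempts := 0
	res, err := requestWithRetry(c, func() (any, error) {
		attempts++
		if attempts < 3 {
			return nil, errors.New("not leader") // e.g. the request hit a follower
		}
		return uint64(1024), nil
	})
	fmt.Println(res, err, c.leader) // 1024 <nil> http://127.0.0.1:2379
}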