diff --git a/.golangci.yml b/.golangci.yml
index 36205d71f79..a60f59e4f12 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -103,6 +103,8 @@ linters-settings:
     excludes:
       - G404
       - G601
+      - G108
+      - G114
 
 issues:
   exclude-rules:
diff --git a/cmd/kafka-consumer/main.go b/cmd/kafka-consumer/main.go
index b160d423b54..0b4d7c3b867 100644
--- a/cmd/kafka-consumer/main.go
+++ b/cmd/kafka-consumer/main.go
@@ -16,9 +16,12 @@ package main
 import (
 	"context"
 	"database/sql"
+	"errors"
 	"flag"
 	"fmt"
 	"math"
+	"net/http"
+	_ "net/http/pprof"
 	"net/url"
 	"os"
 	"os/signal"
@@ -31,8 +34,9 @@ import (
 	"time"
 
 	"github.com/IBM/sarama"
+	"github.com/edwingeng/deque"
 	"github.com/google/uuid"
-	"github.com/pingcap/errors"
+	cerror "github.com/pingcap/errors"
 	"github.com/pingcap/log"
 	"github.com/pingcap/tiflow/cdc/model"
 	"github.com/pingcap/tiflow/cdc/sink/ddlsink"
@@ -97,6 +101,8 @@ type consumerOption struct {
 
 	// upstreamTiDBDSN is the dsn of the upstream TiDB cluster
 	upstreamTiDBDSN string
+
+	enableProfiling bool
 }
 
 // Adjust the consumer option by the upstream uri passed in parameters.
@@ -193,10 +199,10 @@ func (o *consumerOption) Adjust(upstreamURI *url.URL, configFile string) error {
 		replicaConfig.Sink.Protocol = util.AddressOf(o.protocol.String())
 		err := cmdUtil.StrictDecodeFile(configFile, "kafka consumer", replicaConfig)
 		if err != nil {
-			return errors.Trace(err)
+			return cerror.Trace(err)
 		}
 		if _, err := filter.VerifyTableRules(replicaConfig.Filter); err != nil {
-			return errors.Trace(err)
+			return cerror.Trace(err)
 		}
 		o.replicaConfig = replicaConfig
 	}
@@ -238,6 +244,7 @@ func main() {
 	flag.StringVar(&consumerOption.ca, "ca", "", "CA certificate path for Kafka SSL connection")
 	flag.StringVar(&consumerOption.cert, "cert", "", "Certificate path for Kafka SSL connection")
 	flag.StringVar(&consumerOption.key, "key", "", "Private key path for Kafka SSL connection")
+	flag.BoolVar(&consumerOption.enableProfiling, "enable-profiling", false, "enable pprof profiling")
 	flag.Parse()
 
 	err := logutil.InitLogger(&logutil.Config{
@@ -293,12 +300,22 @@ func main() {
 		log.Panic("Error creating consumer group client", zap.Error(err))
 	}
 
-	wg := &sync.WaitGroup{}
+	var wg sync.WaitGroup
+	if consumerOption.enableProfiling {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if err := http.ListenAndServe(":6060", nil); err != nil {
+				log.Panic("Error starting pprof", zap.Error(err))
+			}
+		}()
+	}
+
+	wg.Add(1)
 	go func() {
 		defer wg.Done()
 		for {
 			// `Consume` should be called inside an infinite loop, when a
 			// server-side rebalance happens, the consumer session will need to be
 			// recreated to get the new claims
 			if err := client.Consume(ctx, strings.Split(consumerOption.topic, ","), consumer); err != nil {
@@ -342,19 +359,19 @@ func getPartitionNum(address []string, topic string, cfg *sarama.Config) (int32,
 	// get partition number or create topic automatically
 	admin, err := sarama.NewClusterAdmin(address, cfg)
 	if err != nil {
-		return 0, errors.Trace(err)
+		return 0, cerror.Trace(err)
 	}
 	topics, err := admin.ListTopics()
 	if err != nil {
-		return 0, errors.Trace(err)
+		return 0, cerror.Trace(err)
 	}
 	err = admin.Close()
 	if err != nil {
-		return 0, errors.Trace(err)
+		return 0, cerror.Trace(err)
 	}
 	topicDetail, exist := topics[topic]
 	if !exist {
-		return 0, errors.Errorf("can not find topic %s", topic)
+		return 0, cerror.Errorf("can not find topic %s", topic)
 	}
 	log.Info("get partition number of topic",
 		zap.String("topic", topic),
@@ -365,13 +382,13 @@ func getPartitionNum(address []string, topic string, cfg *sarama.Config) (int32,
 func waitTopicCreated(address []string, topic string, cfg *sarama.Config) error {
 	admin, err := sarama.NewClusterAdmin(address, cfg)
 	if err != nil {
-		return errors.Trace(err)
+		return cerror.Trace(err)
 	}
 	defer admin.Close()
 	for i := 0; i <= 30; i++ {
 		topics, err := admin.ListTopics()
 		if err != nil {
-			return errors.Trace(err)
+			return cerror.Trace(err)
 		}
 		if _, ok := topics[topic]; ok {
 			return nil
@@ -379,7 +396,7 @@ func waitTopicCreated(address []string, topic string, cfg *sarama.Config) error
 		log.Info("wait the topic created", zap.String("topic", topic))
 		time.Sleep(1 * time.Second)
 	}
-	return errors.Errorf("wait the topic(%s) created timeout", topic)
+	return cerror.Errorf("wait the topic(%s) created timeout", topic)
 }
 
 func newSaramaConfig(o *consumerOption) (*sarama.Config, error) {
@@ -387,7 +404,7 @@ func newSaramaConfig(o *consumerOption) (*sarama.Config, error) {
 
 	version, err := sarama.ParseKafkaVersion(o.version)
 	if err != nil {
-		return nil, errors.Trace(err)
+		return nil, cerror.Trace(err)
 	}
 
 	config.ClientID = "ticdc_kafka_sarama_consumer"
@@ -406,7 +423,7 @@ func newSaramaConfig(o *consumerOption) (*sarama.Config, error) {
 			KeyPath:  o.key,
 		}).ToTLSConfig()
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, cerror.Trace(err)
 		}
 	}
 
@@ -419,6 +436,8 @@ type partitionSinks struct {
 	tableSinksMap sync.Map
 	// resolvedTs record the maximum timestamp of the received event
 	resolvedTs uint64
+
+	flowController *flowController
 }
 
 // Consumer represents a Sarama consumer group consumer
@@ -450,6 +469,10 @@ type Consumer struct {
 	upstreamTiDB *sql.DB
 }
 
+const (
+	defaultMemoryQuotaInBytes = 2 * 1024 * 1024 * 1024 // 2GB
+)
+
 // NewConsumer creates a new cdc kafka consumer
 func NewConsumer(ctx context.Context, o *consumerOption) (*Consumer, error) {
 	c := new(Consumer)
@@ -457,7 +480,7 @@ func NewConsumer(ctx context.Context, o *consumerOption) (*Consumer, error) {
 
 	tz, err := util.GetTimezone(o.timezone)
 	if err != nil {
-		return nil, errors.Annotate(err, "can not load timezone")
+		return nil, cerror.Annotate(err, "can not load timezone")
 	}
 	config.GetGlobalServerConfig().TZ = o.timezone
 	c.tz = tz
@@ -488,7 +511,7 @@ func NewConsumer(ctx context.Context, o *consumerOption) (*Consumer, error) {
 	if o.replicaConfig != nil {
 		eventRouter, err := dispatcher.NewEventRouter(o.replicaConfig, o.protocol, o.topic, "kafka")
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, cerror.Trace(err)
 		}
 		c.eventRouter = eventRouter
 	}
@@ -496,21 +519,28 @@ func NewConsumer(ctx context.Context, o *consumerOption) (*Consumer, error) {
 	c.sinks = make([]*partitionSinks, o.partitionNum)
 	ctx, cancel := context.WithCancel(ctx)
 	errChan := make(chan error, 1)
+
+	memoryQuotaPerPartition := defaultMemoryQuotaInBytes / int(o.partitionNum)
 	for i := 0; i < int(o.partitionNum); i++ {
-		c.sinks[i] = &partitionSinks{}
+		c.sinks[i] = &partitionSinks{
+			flowController: newFlowController(uint64(memoryQuotaPerPartition)),
+		}
 	}
+	log.Info("flow controller created for each partition",
+		zap.Int32("partitionNum", o.partitionNum),
+		zap.Int("quota", memoryQuotaPerPartition))
 
 	changefeedID := model.DefaultChangeFeedID("kafka-consumer")
 	f, err := eventsinkfactory.New(ctx, changefeedID, o.downstreamURI, config.GetDefaultReplicaConfig(), errChan)
 	if err != nil {
 		cancel()
-		return nil, errors.Trace(err)
+		return nil, cerror.Trace(err)
 	}
 	c.sinkFactory = f
 
 	go func() {
 		err := <-errChan
-		if errors.Cause(err) != context.Canceled {
+		if !errors.Is(cerror.Cause(err), context.Canceled) {
log.Error("error on running consumer", zap.Error(err)) } else { log.Info("consumer exited") @@ -521,7 +551,7 @@ func NewConsumer(ctx context.Context, o *consumerOption) (*Consumer, error) { ddlSink, err := ddlsinkfactory.New(ctx, changefeedID, o.downstreamURI, config.GetDefaultReplicaConfig()) if err != nil { cancel() - return nil, errors.Trace(err) + return nil, cerror.Trace(err) } c.ddlSink = ddlSink c.ready = make(chan bool) @@ -595,14 +625,14 @@ func (c *Consumer) ConsumeClaim(session sarama.ConsumerGroupSession, claim saram case config.ProtocolAvro: schemaM, err := avro.NewConfluentSchemaManager(ctx, c.option.schemaRegistryURI, nil) if err != nil { - return errors.Trace(err) + return cerror.Trace(err) } decoder = avro.NewDecoder(c.codecConfig, schemaM, c.option.topic, c.tz) default: log.Panic("Protocol not supported", zap.Any("Protocol", c.codecConfig.Protocol)) } if err != nil { - return errors.Trace(err) + return cerror.Trace(err) } log.Info("start consume claim", @@ -613,7 +643,7 @@ func (c *Consumer) ConsumeClaim(session sarama.ConsumerGroupSession, claim saram for message := range claim.Messages() { if err := decoder.AddKeyValue(message.Key, message.Value); err != nil { log.Error("add key value to the decoder failed", zap.Error(err)) - return errors.Trace(err) + return cerror.Trace(err) } counter := 0 @@ -664,7 +694,7 @@ func (c *Consumer) ConsumeClaim(session sarama.ConsumerGroupSession, claim saram if c.eventRouter != nil { target, _, err := c.eventRouter.GetPartitionForRowChange(row, c.option.partitionNum) if err != nil { - return errors.Trace(err) + return cerror.Trace(err) } if partition != target { log.Panic("RowChangedEvent dispatched to wrong partition", @@ -706,6 +736,16 @@ func (c *Consumer) ConsumeClaim(session sarama.ConsumerGroupSession, claim saram group.Append(row) // todo: mark the offset after the DDL is fully synced to the downstream mysql. session.MarkMessage(message, "") + + size := uint64(row.ApproximateBytes()) + err = sink.flowController.consume(row.CommitTs, size) + if err != nil { + if errors.Is(err, errFlowControllerAborted) { + log.Info("flow control aborted") + return nil + } + return cerror.Trace(err) + } case model.MessageTypeResolved: ts, err := decoder.NextResolvedEvent() if err != nil { @@ -818,7 +858,7 @@ func (c *Consumer) forEachSink(fn func(sink *partitionSinks) error) error { defer c.sinksMu.Unlock() for _, sink := range c.sinks { if err := fn(sink); err != nil { - return errors.Trace(err) + return cerror.Trace(err) } } return nil @@ -849,7 +889,7 @@ func (c *Consumer) Run(ctx context.Context) error { minPartitionResolvedTs, err := c.getMinPartitionResolvedTs() if err != nil { - return errors.Trace(err) + return cerror.Trace(err) } // handle DDL @@ -859,12 +899,12 @@ func (c *Consumer) Run(ctx context.Context) error { if err := c.forEachSink(func(sink *partitionSinks) error { return syncFlushRowChangedEvents(ctx, sink, todoDDL.CommitTs) }); err != nil { - return errors.Trace(err) + return cerror.Trace(err) } // DDL can be executed, do it first. 
 			if err := c.ddlSink.WriteDDLEvent(ctx, todoDDL); err != nil {
-				return errors.Trace(err)
+				return cerror.Trace(err)
 			}
 			c.popDDL()
@@ -890,7 +930,7 @@ func (c *Consumer) Run(ctx context.Context) error {
 		if err := c.forEachSink(func(sink *partitionSinks) error {
 			return syncFlushRowChangedEvents(ctx, sink, c.globalResolvedTs)
 		}); err != nil {
-			return errors.Trace(err)
+			return cerror.Trace(err)
 		}
 	}
 }
@@ -914,9 +954,11 @@ func syncFlushRowChangedEvents(ctx context.Context, sink *partitionSinks, resolv
 				log.Error("Failed to update resolved ts", zap.Error(err))
 				return false
 			}
-			if !tableSink.(tablesink.TableSink).GetCheckpointTs().EqualOrGreater(resolvedTs) {
+			checkpoint := tableSink.(tablesink.TableSink).GetCheckpointTs()
+			if !checkpoint.EqualOrGreater(resolvedTs) {
 				flushedResolvedTs = false
 			}
+			sink.flowController.release(checkpoint.Ts)
 			return true
 		})
 		if flushedResolvedTs {
@@ -950,7 +992,7 @@ func openDB(ctx context.Context, dsn string) (*sql.DB, error) {
 	db, err := sql.Open("mysql", dsn)
 	if err != nil {
 		log.Error("open db failed", zap.Error(err))
-		return nil, errors.Trace(err)
+		return nil, cerror.Trace(err)
 	}
 
 	db.SetMaxOpenConns(10)
@@ -961,8 +1003,146 @@ func openDB(ctx context.Context, dsn string) (*sql.DB, error) {
 	defer cancel()
 	if err = db.PingContext(ctx); err != nil {
 		log.Error("ping db failed", zap.Error(err))
-		return nil, errors.Trace(err)
+		return nil, cerror.Trace(err)
 	}
 	log.Info("open db success", zap.String("dsn", dsn))
 	return db, nil
 }
+
+var (
+	errFlowControllerLargerThanQuota = errors.New("flow controller request memory larger than quota")
+	errFlowControllerAborted         = errors.New("flow controller aborted")
+)
+
+type memoryQuota struct {
+	quota uint64 // should not be changed once initialized
+
+	isAborted atomic.Bool
+
+	consumed struct {
+		sync.Mutex
+		bytes uint64
+	}
+
+	consumedCond *sync.Cond
+}
+
+// newMemoryQuota creates a new memoryQuota
+// quota: max advised memory consumption in bytes.
+func newMemoryQuota(quota uint64) *memoryQuota {
+	ret := &memoryQuota{
+		quota: quota,
+	}
+
+	ret.consumedCond = sync.NewCond(&ret.consumed)
+	return ret
+}
+
+// consumeWithBlocking is called when a hard-limit is needed. The method will
+// block until enough memory has been freed up by release.
+// Should be used with care to prevent deadlock.
+func (c *memoryQuota) consumeWithBlocking(nBytes uint64) error {
+	if nBytes >= c.quota {
+		return errFlowControllerLargerThanQuota
+	}
+
+	c.consumed.Lock()
+	defer c.consumed.Unlock()
+
+	for {
+		if c.isAborted.Load() {
+			return errFlowControllerAborted
+		}
+
+		newConsumed := c.consumed.bytes + nBytes
+		if newConsumed < c.quota {
+			break
+		}
+		c.consumedCond.Wait()
+	}
+
+	c.consumed.bytes += nBytes
+	return nil
+}
+
+// release is called when a chunk of memory is done being used.
+func (c *memoryQuota) release(nBytes uint64) {
+	c.consumed.Lock()
+
+	if c.consumed.bytes < nBytes {
+		c.consumed.Unlock()
+		log.Panic("memoryQuota: releasing more than consumed, report a bug",
+			zap.Uint64("consumed", c.consumed.bytes),
+			zap.Uint64("released", nBytes))
+	}
+
+	c.consumed.bytes -= nBytes
+	if c.consumed.bytes < c.quota {
+		c.consumed.Unlock()
+		c.consumedCond.Signal()
+		return
+	}
+
+	c.consumed.Unlock()
+}
+
+type flowController struct {
+	memoryQuota *memoryQuota
+
+	queueMu struct {
+		sync.Mutex
+		queue deque.Deque
+	}
+}
+
+type entry struct {
+	commitTs uint64
+	size     uint64
+}
+
+func newFlowController(quota uint64) *flowController {
+	return &flowController{
+		memoryQuota: newMemoryQuota(quota),
+		queueMu: struct {
+			sync.Mutex
+			queue deque.Deque
+		}{
+			queue: deque.NewDeque(),
+		},
+	}
+}
+
+func (c *flowController) consume(commitTs uint64, size uint64) error {
+	err := c.memoryQuota.consumeWithBlocking(size)
+	if err != nil {
+		return cerror.Trace(err)
+	}
+
+	c.queueMu.Lock()
+	defer c.queueMu.Unlock()
+
+	c.queueMu.queue.PushBack(&entry{
+		commitTs: commitTs,
+		size:     size,
+	})
+
+	return nil
+}
+
+func (c *flowController) release(resolvedTs uint64) {
+	var nBytesToRelease uint64
+
+	c.queueMu.Lock()
+	for c.queueMu.queue.Len() > 0 {
+		if peeked := c.queueMu.queue.Front().(*entry); peeked.commitTs <= resolvedTs {
+			nBytesToRelease += peeked.size
+			c.queueMu.queue.PopFront()
+		} else {
+			break
+		}
+	}
+	c.queueMu.Unlock()
+
+	c.memoryQuota.release(nBytesToRelease)
+}
diff --git a/pkg/sink/codec/avro/confluent_schema_registry.go b/pkg/sink/codec/avro/confluent_schema_registry.go
index 9f06fe64587..c7e0851810c 100644
--- a/pkg/sink/codec/avro/confluent_schema_registry.go
+++ b/pkg/sink/codec/avro/confluent_schema_registry.go
@@ -210,20 +210,12 @@ func (m *confluentSchemaManager) Lookup(
 	m.cacheRWLock.RLock()
 	entry, exists := m.cache[schemaName]
 	if exists && entry.schemaID.confluentSchemaID == schemaID.confluentSchemaID {
-		log.Debug("Avro schema lookup cache hit",
-			zap.String("key", schemaName),
-			zap.Int("schemaID", entry.schemaID.confluentSchemaID))
 		m.cacheRWLock.RUnlock()
 		return entry.codec, nil
 	}
 	m.cacheRWLock.RUnlock()
 
-	log.Info("Avro schema lookup cache miss",
-		zap.String("key", schemaName),
-		zap.Int("schemaID", schemaID.confluentSchemaID))
-
 	uri := m.registryURL + "/schemas/ids/" + strconv.Itoa(schemaID.confluentSchemaID)
-	log.Debug("Querying for latest schema", zap.String("uri", uri))
 
 	req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
 	if err != nil {
@@ -289,11 +281,6 @@ func (m *confluentSchemaManager) Lookup(
 
 	m.cacheRWLock.Lock()
 	m.cache[schemaName] = cacheEntry
 	m.cacheRWLock.Unlock()
-
-	log.Info("Avro schema lookup successful with cache miss",
-		zap.Int("schemaID", cacheEntry.schemaID.confluentSchemaID),
-		zap.String("schema", cacheEntry.codec.Schema()))
-
 	return cacheEntry.codec, nil
 }
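Note on the new flow-control types above: ConsumeClaim now blocks in sink.flowController.consume once a partition has roughly defaultMemoryQuotaInBytes / partitionNum bytes of un-flushed rows in flight, and syncFlushRowChangedEvents releases quota as the table sinks' checkpoints advance. The sync.Cond-based backpressure at the heart of this can be exercised in isolation; the following is a minimal standalone sketch of the same pattern (illustrative names, sizes, and timing, not code from this patch):

package main

import (
	"fmt"
	"sync"
	"time"
)

// quota is a cut-down stand-in for the patch's memoryQuota: consume blocks
// on a sync.Cond until release frees enough bytes.
type quota struct {
	mu       sync.Mutex
	cond     *sync.Cond
	capacity uint64
	used     uint64
}

func newQuota(capacity uint64) *quota {
	q := &quota{capacity: capacity}
	q.cond = sync.NewCond(&q.mu)
	return q
}

// consume waits while admitting n more bytes would reach the capacity,
// mirroring the hard-limit loop in consumeWithBlocking.
func (q *quota) consume(n uint64) {
	q.mu.Lock()
	defer q.mu.Unlock()
	for q.used+n >= q.capacity {
		q.cond.Wait()
	}
	q.used += n
}

// release returns n bytes and wakes one blocked consumer, as the patch does
// when a table sink reports an advanced checkpoint.
func (q *quota) release(n uint64) {
	q.mu.Lock()
	defer q.mu.Unlock()
	if q.used < n {
		panic("released more than consumed") // the patch log.Panics here
	}
	q.used -= n
	q.cond.Signal()
}

func main() {
	q := newQuota(1024)
	done := make(chan struct{})

	go func() {
		defer close(done)
		for i := 0; i < 4; i++ {
			q.consume(512) // from the second iteration on, blocks until a release
			fmt.Println("consumed 512 bytes")
		}
	}()

	// Simulate the sink flushing batches; each release unblocks the consumer.
	// The sleeps stand in for real flush latency and keep the demo ordering simple.
	for i := 0; i < 4; i++ {
		time.Sleep(50 * time.Millisecond)
		q.release(512)
	}
	<-done
}

The consumer goroutine stalls at the cap and resumes once per simulated flush; the patch's memoryQuota behaves the same way, with the extra commitTs-ordered deque in flowController deciding how many bytes each resolved timestamp is allowed to release.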