From 4f1482cf7c614ce5db06c338c6983b2392796549 Mon Sep 17 00:00:00 2001 From: Herko Lategan Date: Mon, 10 Jun 2024 19:39:53 +0100 Subject: [PATCH] roachprod: lock dns records by name Previously, the lower level DNS implementation had no internal protection against concurrently modifying the same SRV record entry. This was because from a use case perspective this should never happen as we generally don't query and register services at the same time. This could however cause several race conditions, and it should be protected against. This change locks records based on the record that is being modified. It allows concurrent modification of different records, but will lock on record name to prevent querying and modifying the same record at any point, at least for this process. Epic: None Release Note: None --- pkg/roachprod/vm/gce/dns.go | 169 ++++++++++++++++++++++-------------- 1 file changed, 105 insertions(+), 64 deletions(-) diff --git a/pkg/roachprod/vm/gce/dns.go b/pkg/roachprod/vm/gce/dns.go index be4803ad16ad..f614e9602125 100644 --- a/pkg/roachprod/vm/gce/dns.go +++ b/pkg/roachprod/vm/gce/dns.go @@ -88,19 +88,16 @@ type dnsProvider struct { mu syncutil.Mutex records map[string][]vm.DNSRecord } + recordLock struct { + mu syncutil.Mutex + locks map[string]*syncutil.Mutex + } execFn ExecFn resolvers []*net.Resolver } func NewDNSProvider() *dnsProvider { - var gcloudMu syncutil.Mutex return NewDNSProviderWithExec(func(cmd *exec.Cmd) ([]byte, error) { - // Limit to one gcloud command at a time. At this time we are unsure if it's - // safe to make concurrent calls to the `gcloud` CLI to mutate DNS records - // in the same zone. We don't mutate the same record in parallel, but we do - // mutate different records in the same zone. See: #122180 for more details. - gcloudMu.Lock() - defer gcloudMu.Unlock() return cmd.CombinedOutput() }) } @@ -116,6 +113,10 @@ func NewDNSProviderWithExec(execFn ExecFn) *dnsProvider { mu syncutil.Mutex records map[string][]vm.DNSRecord }{records: make(map[string][]vm.DNSRecord)}, + recordLock: struct { + mu syncutil.Mutex + locks map[string]*syncutil.Mutex + }{locks: make(map[string]*syncutil.Mutex)}, execFn: execFn, resolvers: googleDNSResolvers(), } @@ -131,64 +132,78 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord } for name, recordGroup := range recordsByName { - existingRecords, err := n.lookupSRVRecords(ctx, name) - if err != nil { - return err - } - command := "create" - if len(existingRecords) > 0 { - command = "update" - } + err := n.withRecordLock(name, func() error { + existingRecords, err := n.lookupSRVRecords(ctx, name) + if err != nil { + return err + } + command := "create" + if len(existingRecords) > 0 { + command = "update" + } - // Combine old and new records using a map to deduplicate with the record - // data as the key. - combinedRecords := make(map[string]vm.DNSRecord) - for _, record := range existingRecords { - combinedRecords[record.Data] = record - } - for _, record := range recordGroup { - combinedRecords[record.Data] = record - } + // Combine old and new records using a map to deduplicate with the record + // data as the key. + combinedRecords := make(map[string]vm.DNSRecord) + for _, record := range existingRecords { + combinedRecords[record.Data] = record + } + for _, record := range recordGroup { + combinedRecords[record.Data] = record + } - // We assume that all records in a group have the same name, type, and ttl. - // TODO(herko): Add error checking to ensure that the above is the case. - firstRecord := recordGroup[0] - data := maps.Keys(combinedRecords) - sort.Strings(data) - args := []string{"--project", n.dnsProject, "dns", "record-sets", command, name, - "--type", string(firstRecord.Type), - "--ttl", strconv.Itoa(firstRecord.TTL), - "--zone", n.managedZone, - "--rrdatas", strings.Join(data, ","), - } - cmd := exec.CommandContext(ctx, "gcloud", args...) - out, err := n.execFn(cmd) - if err != nil { - // Clear the cache entry if the operation failed, as the records may - // have been partially updated. - n.clearCacheEntry(name) - return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel) - } - // If fastDNS is enabled, we need to wait for the records to become available - // on the Google DNS servers. - if config.FastDNS { - err = n.waitForRecordsAvailable(ctx, maps.Values(combinedRecords)...) + // We assume that all records in a group have the same name, type, and ttl. + // TODO(herko): Add error checking to ensure that the above is the case. + firstRecord := recordGroup[0] + data := maps.Keys(combinedRecords) + sort.Strings(data) + args := []string{"--project", n.dnsProject, "dns", "record-sets", command, name, + "--type", string(firstRecord.Type), + "--ttl", strconv.Itoa(firstRecord.TTL), + "--zone", n.managedZone, + "--rrdatas", strings.Join(data, ","), + } + cmd := exec.CommandContext(ctx, "gcloud", args...) + out, err := n.execFn(cmd) if err != nil { - return err + // Clear the cache entry if the operation failed, as the records may + // have been partially updated. + n.clearCacheEntry(name) + return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel) } + // If fastDNS is enabled, we need to wait for the records to become available + // on the Google DNS servers. + if config.FastDNS { + err = n.waitForRecordsAvailable(ctx, maps.Values(combinedRecords)...) + if err != nil { + return err + } + } + n.updateCache(name, maps.Values(combinedRecords)) + return err + + }) + if err != nil { + return errors.Wrapf(err, "failed to update records for %s", name) } - n.updateCache(name, maps.Values(combinedRecords)) } return nil } // LookupSRVRecords implements the vm.DNSProvider interface. func (n *dnsProvider) LookupSRVRecords(ctx context.Context, name string) ([]vm.DNSRecord, error) { - if config.FastDNS { - rIdx := randutil.FastUint32() % uint32(len(n.resolvers)) - return n.fastLookupSRVRecords(ctx, n.resolvers[rIdx], name, true) - } - return n.lookupSRVRecords(ctx, name) + var records []vm.DNSRecord + var err error + err = n.withRecordLock(name, func() error { + if config.FastDNS { + rIdx := randutil.FastUint32() % uint32(len(n.resolvers)) + records, err = n.fastLookupSRVRecords(ctx, n.resolvers[rIdx], name, true) + return err + } + records, err = n.lookupSRVRecords(ctx, name) + return err + }) + return records, err } // ListRecords implements the vm.DNSProvider interface. @@ -199,17 +214,23 @@ func (n *dnsProvider) ListRecords(ctx context.Context) ([]vm.DNSRecord, error) { // DeleteRecordsByName implements the vm.DNSProvider interface. func (n *dnsProvider) DeleteRecordsByName(ctx context.Context, names ...string) error { for _, name := range names { - args := []string{"--project", n.dnsProject, "dns", "record-sets", "delete", name, - "--type", string(vm.SRV), - "--zone", n.managedZone, - } - cmd := exec.CommandContext(ctx, "gcloud", args...) - out, err := n.execFn(cmd) - // Clear the cache entry regardless of the outcome. As the records may - // have been partially deleted. - n.clearCacheEntry(name) + err := n.withRecordLock(name, func() error { + args := []string{"--project", n.dnsProject, "dns", "record-sets", "delete", name, + "--type", string(vm.SRV), + "--zone", n.managedZone, + } + cmd := exec.CommandContext(ctx, "gcloud", args...) + out, err := n.execFn(cmd) + // Clear the cache entry regardless of the outcome. As the records may + // have been partially deleted. + n.clearCacheEntry(name) + if err != nil { + return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel) + } + return nil + }) if err != nil { - return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel) + return err } } return nil @@ -323,6 +344,26 @@ func (n *dnsProvider) listSRVRecords( return records, nil } +// lockRecordByName locks the record with the given name and returns a function +// that can be used to unlock it. The lock is used to prevent concurrent +// operations on the same record. +func (n *dnsProvider) withRecordLock(name string, f func() error) error { + recordMutex := func() *syncutil.Mutex { + n.recordLock.mu.Lock() + defer n.recordLock.mu.Unlock() + normalisedName := n.normaliseName(name) + mutex, ok := n.recordLock.locks[normalisedName] + if !ok { + mutex = new(syncutil.Mutex) + n.recordLock.locks[normalisedName] = mutex + } + return mutex + }() + recordMutex.Lock() + defer recordMutex.Unlock() + return f() +} + func (n *dnsProvider) updateCache(name string, records []vm.DNSRecord) { n.recordsCache.mu.Lock() defer n.recordsCache.mu.Unlock()