diff --git a/pkg/roachprod/vm/gce/dns.go b/pkg/roachprod/vm/gce/dns.go index be4803ad16ad..f614e9602125 100644 --- a/pkg/roachprod/vm/gce/dns.go +++ b/pkg/roachprod/vm/gce/dns.go @@ -88,19 +88,16 @@ type dnsProvider struct { mu syncutil.Mutex records map[string][]vm.DNSRecord } + recordLock struct { + mu syncutil.Mutex + locks map[string]*syncutil.Mutex + } execFn ExecFn resolvers []*net.Resolver } func NewDNSProvider() *dnsProvider { - var gcloudMu syncutil.Mutex return NewDNSProviderWithExec(func(cmd *exec.Cmd) ([]byte, error) { - // Limit to one gcloud command at a time. At this time we are unsure if it's - // safe to make concurrent calls to the `gcloud` CLI to mutate DNS records - // in the same zone. We don't mutate the same record in parallel, but we do - // mutate different records in the same zone. See: #122180 for more details. - gcloudMu.Lock() - defer gcloudMu.Unlock() return cmd.CombinedOutput() }) } @@ -116,6 +113,10 @@ func NewDNSProviderWithExec(execFn ExecFn) *dnsProvider { mu syncutil.Mutex records map[string][]vm.DNSRecord }{records: make(map[string][]vm.DNSRecord)}, + recordLock: struct { + mu syncutil.Mutex + locks map[string]*syncutil.Mutex + }{locks: make(map[string]*syncutil.Mutex)}, execFn: execFn, resolvers: googleDNSResolvers(), } @@ -131,64 +132,78 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord } for name, recordGroup := range recordsByName { - existingRecords, err := n.lookupSRVRecords(ctx, name) - if err != nil { - return err - } - command := "create" - if len(existingRecords) > 0 { - command = "update" - } + err := n.withRecordLock(name, func() error { + existingRecords, err := n.lookupSRVRecords(ctx, name) + if err != nil { + return err + } + command := "create" + if len(existingRecords) > 0 { + command = "update" + } - // Combine old and new records using a map to deduplicate with the record - // data as the key. - combinedRecords := make(map[string]vm.DNSRecord) - for _, record := range existingRecords { - combinedRecords[record.Data] = record - } - for _, record := range recordGroup { - combinedRecords[record.Data] = record - } + // Combine old and new records using a map to deduplicate with the record + // data as the key. + combinedRecords := make(map[string]vm.DNSRecord) + for _, record := range existingRecords { + combinedRecords[record.Data] = record + } + for _, record := range recordGroup { + combinedRecords[record.Data] = record + } - // We assume that all records in a group have the same name, type, and ttl. - // TODO(herko): Add error checking to ensure that the above is the case. - firstRecord := recordGroup[0] - data := maps.Keys(combinedRecords) - sort.Strings(data) - args := []string{"--project", n.dnsProject, "dns", "record-sets", command, name, - "--type", string(firstRecord.Type), - "--ttl", strconv.Itoa(firstRecord.TTL), - "--zone", n.managedZone, - "--rrdatas", strings.Join(data, ","), - } - cmd := exec.CommandContext(ctx, "gcloud", args...) - out, err := n.execFn(cmd) - if err != nil { - // Clear the cache entry if the operation failed, as the records may - // have been partially updated. - n.clearCacheEntry(name) - return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel) - } - // If fastDNS is enabled, we need to wait for the records to become available - // on the Google DNS servers. - if config.FastDNS { - err = n.waitForRecordsAvailable(ctx, maps.Values(combinedRecords)...) + // We assume that all records in a group have the same name, type, and ttl. + // TODO(herko): Add error checking to ensure that the above is the case. + firstRecord := recordGroup[0] + data := maps.Keys(combinedRecords) + sort.Strings(data) + args := []string{"--project", n.dnsProject, "dns", "record-sets", command, name, + "--type", string(firstRecord.Type), + "--ttl", strconv.Itoa(firstRecord.TTL), + "--zone", n.managedZone, + "--rrdatas", strings.Join(data, ","), + } + cmd := exec.CommandContext(ctx, "gcloud", args...) + out, err := n.execFn(cmd) if err != nil { - return err + // Clear the cache entry if the operation failed, as the records may + // have been partially updated. + n.clearCacheEntry(name) + return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel) } + // If fastDNS is enabled, we need to wait for the records to become available + // on the Google DNS servers. + if config.FastDNS { + err = n.waitForRecordsAvailable(ctx, maps.Values(combinedRecords)...) + if err != nil { + return err + } + } + n.updateCache(name, maps.Values(combinedRecords)) + return err + + }) + if err != nil { + return errors.Wrapf(err, "failed to update records for %s", name) } - n.updateCache(name, maps.Values(combinedRecords)) } return nil } // LookupSRVRecords implements the vm.DNSProvider interface. func (n *dnsProvider) LookupSRVRecords(ctx context.Context, name string) ([]vm.DNSRecord, error) { - if config.FastDNS { - rIdx := randutil.FastUint32() % uint32(len(n.resolvers)) - return n.fastLookupSRVRecords(ctx, n.resolvers[rIdx], name, true) - } - return n.lookupSRVRecords(ctx, name) + var records []vm.DNSRecord + var err error + err = n.withRecordLock(name, func() error { + if config.FastDNS { + rIdx := randutil.FastUint32() % uint32(len(n.resolvers)) + records, err = n.fastLookupSRVRecords(ctx, n.resolvers[rIdx], name, true) + return err + } + records, err = n.lookupSRVRecords(ctx, name) + return err + }) + return records, err } // ListRecords implements the vm.DNSProvider interface. @@ -199,17 +214,23 @@ func (n *dnsProvider) ListRecords(ctx context.Context) ([]vm.DNSRecord, error) { // DeleteRecordsByName implements the vm.DNSProvider interface. func (n *dnsProvider) DeleteRecordsByName(ctx context.Context, names ...string) error { for _, name := range names { - args := []string{"--project", n.dnsProject, "dns", "record-sets", "delete", name, - "--type", string(vm.SRV), - "--zone", n.managedZone, - } - cmd := exec.CommandContext(ctx, "gcloud", args...) - out, err := n.execFn(cmd) - // Clear the cache entry regardless of the outcome. As the records may - // have been partially deleted. - n.clearCacheEntry(name) + err := n.withRecordLock(name, func() error { + args := []string{"--project", n.dnsProject, "dns", "record-sets", "delete", name, + "--type", string(vm.SRV), + "--zone", n.managedZone, + } + cmd := exec.CommandContext(ctx, "gcloud", args...) + out, err := n.execFn(cmd) + // Clear the cache entry regardless of the outcome. As the records may + // have been partially deleted. + n.clearCacheEntry(name) + if err != nil { + return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel) + } + return nil + }) if err != nil { - return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel) + return err } } return nil @@ -323,6 +344,26 @@ func (n *dnsProvider) listSRVRecords( return records, nil } +// lockRecordByName locks the record with the given name and returns a function +// that can be used to unlock it. The lock is used to prevent concurrent +// operations on the same record. +func (n *dnsProvider) withRecordLock(name string, f func() error) error { + recordMutex := func() *syncutil.Mutex { + n.recordLock.mu.Lock() + defer n.recordLock.mu.Unlock() + normalisedName := n.normaliseName(name) + mutex, ok := n.recordLock.locks[normalisedName] + if !ok { + mutex = new(syncutil.Mutex) + n.recordLock.locks[normalisedName] = mutex + } + return mutex + }() + recordMutex.Lock() + defer recordMutex.Unlock() + return f() +} + func (n *dnsProvider) updateCache(name string, records []vm.DNSRecord) { n.recordsCache.mu.Lock() defer n.recordsCache.mu.Unlock()