Skip to content

Commit

Permalink
roachprod: lock dns records by name
Browse files Browse the repository at this point in the history
Previously, the lower level DNS implementation had no internal protection
against concurrently modifying the same SRV record entry. This was because from
a use case perspective this should never happen as we generally don't query and
register services at the same time.

This could however cause several race conditions, and it should be protected
against. This change locks records based on the record that is being modified.
It allows concurrent modification of different records, but will lock on record
name to prevent querying and modifying the same record at any point, at least
for this process.

Epic: None
Release Note: None
  • Loading branch information
herkolategan committed Jul 4, 2024
1 parent 0d65492 commit 4f1482c
Showing 1 changed file with 105 additions and 64 deletions.
169 changes: 105 additions & 64 deletions pkg/roachprod/vm/gce/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,16 @@ type dnsProvider struct {
mu syncutil.Mutex
records map[string][]vm.DNSRecord
}
recordLock struct {
mu syncutil.Mutex
locks map[string]*syncutil.Mutex
}
execFn ExecFn
resolvers []*net.Resolver
}

func NewDNSProvider() *dnsProvider {
var gcloudMu syncutil.Mutex
return NewDNSProviderWithExec(func(cmd *exec.Cmd) ([]byte, error) {
// Limit to one gcloud command at a time. At this time we are unsure if it's
// safe to make concurrent calls to the `gcloud` CLI to mutate DNS records
// in the same zone. We don't mutate the same record in parallel, but we do
// mutate different records in the same zone. See: #122180 for more details.
gcloudMu.Lock()
defer gcloudMu.Unlock()
return cmd.CombinedOutput()
})
}
Expand All @@ -116,6 +113,10 @@ func NewDNSProviderWithExec(execFn ExecFn) *dnsProvider {
mu syncutil.Mutex
records map[string][]vm.DNSRecord
}{records: make(map[string][]vm.DNSRecord)},
recordLock: struct {
mu syncutil.Mutex
locks map[string]*syncutil.Mutex
}{locks: make(map[string]*syncutil.Mutex)},
execFn: execFn,
resolvers: googleDNSResolvers(),
}
Expand All @@ -131,64 +132,78 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord
}

for name, recordGroup := range recordsByName {
existingRecords, err := n.lookupSRVRecords(ctx, name)
if err != nil {
return err
}
command := "create"
if len(existingRecords) > 0 {
command = "update"
}
err := n.withRecordLock(name, func() error {
existingRecords, err := n.lookupSRVRecords(ctx, name)
if err != nil {
return err
}
command := "create"
if len(existingRecords) > 0 {
command = "update"
}

// Combine old and new records using a map to deduplicate with the record
// data as the key.
combinedRecords := make(map[string]vm.DNSRecord)
for _, record := range existingRecords {
combinedRecords[record.Data] = record
}
for _, record := range recordGroup {
combinedRecords[record.Data] = record
}
// Combine old and new records using a map to deduplicate with the record
// data as the key.
combinedRecords := make(map[string]vm.DNSRecord)
for _, record := range existingRecords {
combinedRecords[record.Data] = record
}
for _, record := range recordGroup {
combinedRecords[record.Data] = record
}

// We assume that all records in a group have the same name, type, and ttl.
// TODO(herko): Add error checking to ensure that the above is the case.
firstRecord := recordGroup[0]
data := maps.Keys(combinedRecords)
sort.Strings(data)
args := []string{"--project", n.dnsProject, "dns", "record-sets", command, name,
"--type", string(firstRecord.Type),
"--ttl", strconv.Itoa(firstRecord.TTL),
"--zone", n.managedZone,
"--rrdatas", strings.Join(data, ","),
}
cmd := exec.CommandContext(ctx, "gcloud", args...)
out, err := n.execFn(cmd)
if err != nil {
// Clear the cache entry if the operation failed, as the records may
// have been partially updated.
n.clearCacheEntry(name)
return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel)
}
// If fastDNS is enabled, we need to wait for the records to become available
// on the Google DNS servers.
if config.FastDNS {
err = n.waitForRecordsAvailable(ctx, maps.Values(combinedRecords)...)
// We assume that all records in a group have the same name, type, and ttl.
// TODO(herko): Add error checking to ensure that the above is the case.
firstRecord := recordGroup[0]
data := maps.Keys(combinedRecords)
sort.Strings(data)
args := []string{"--project", n.dnsProject, "dns", "record-sets", command, name,
"--type", string(firstRecord.Type),
"--ttl", strconv.Itoa(firstRecord.TTL),
"--zone", n.managedZone,
"--rrdatas", strings.Join(data, ","),
}
cmd := exec.CommandContext(ctx, "gcloud", args...)
out, err := n.execFn(cmd)
if err != nil {
return err
// Clear the cache entry if the operation failed, as the records may
// have been partially updated.
n.clearCacheEntry(name)
return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel)
}
// If fastDNS is enabled, we need to wait for the records to become available
// on the Google DNS servers.
if config.FastDNS {
err = n.waitForRecordsAvailable(ctx, maps.Values(combinedRecords)...)
if err != nil {
return err
}
}
n.updateCache(name, maps.Values(combinedRecords))
return err

})
if err != nil {
return errors.Wrapf(err, "failed to update records for %s", name)
}
n.updateCache(name, maps.Values(combinedRecords))
}
return nil
}

// LookupSRVRecords implements the vm.DNSProvider interface.
func (n *dnsProvider) LookupSRVRecords(ctx context.Context, name string) ([]vm.DNSRecord, error) {
if config.FastDNS {
rIdx := randutil.FastUint32() % uint32(len(n.resolvers))
return n.fastLookupSRVRecords(ctx, n.resolvers[rIdx], name, true)
}
return n.lookupSRVRecords(ctx, name)
var records []vm.DNSRecord
var err error
err = n.withRecordLock(name, func() error {
if config.FastDNS {
rIdx := randutil.FastUint32() % uint32(len(n.resolvers))
records, err = n.fastLookupSRVRecords(ctx, n.resolvers[rIdx], name, true)
return err
}
records, err = n.lookupSRVRecords(ctx, name)
return err
})
return records, err
}

// ListRecords implements the vm.DNSProvider interface.
Expand All @@ -199,17 +214,23 @@ func (n *dnsProvider) ListRecords(ctx context.Context) ([]vm.DNSRecord, error) {
// DeleteRecordsByName implements the vm.DNSProvider interface.
func (n *dnsProvider) DeleteRecordsByName(ctx context.Context, names ...string) error {
for _, name := range names {
args := []string{"--project", n.dnsProject, "dns", "record-sets", "delete", name,
"--type", string(vm.SRV),
"--zone", n.managedZone,
}
cmd := exec.CommandContext(ctx, "gcloud", args...)
out, err := n.execFn(cmd)
// Clear the cache entry regardless of the outcome. As the records may
// have been partially deleted.
n.clearCacheEntry(name)
err := n.withRecordLock(name, func() error {
args := []string{"--project", n.dnsProject, "dns", "record-sets", "delete", name,
"--type", string(vm.SRV),
"--zone", n.managedZone,
}
cmd := exec.CommandContext(ctx, "gcloud", args...)
out, err := n.execFn(cmd)
// Clear the cache entry regardless of the outcome. As the records may
// have been partially deleted.
n.clearCacheEntry(name)
if err != nil {
return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel)
}
return nil
})
if err != nil {
return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel)
return err
}
}
return nil
Expand Down Expand Up @@ -323,6 +344,26 @@ func (n *dnsProvider) listSRVRecords(
return records, nil
}

// lockRecordByName locks the record with the given name and returns a function
// that can be used to unlock it. The lock is used to prevent concurrent
// operations on the same record.
func (n *dnsProvider) withRecordLock(name string, f func() error) error {
recordMutex := func() *syncutil.Mutex {
n.recordLock.mu.Lock()
defer n.recordLock.mu.Unlock()
normalisedName := n.normaliseName(name)
mutex, ok := n.recordLock.locks[normalisedName]
if !ok {
mutex = new(syncutil.Mutex)
n.recordLock.locks[normalisedName] = mutex
}
return mutex
}()
recordMutex.Lock()
defer recordMutex.Unlock()
return f()
}

func (n *dnsProvider) updateCache(name string, records []vm.DNSRecord) {
n.recordsCache.mu.Lock()
defer n.recordsCache.mu.Unlock()
Expand Down

0 comments on commit 4f1482c

Please sign in to comment.