Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor storage #158

Merged
merged 1 commit into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions internal/dns/dns.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 Blink Labs Software
// Copyright 2024 Blink Labs Software
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
Expand Down Expand Up @@ -294,15 +294,21 @@ func findNameserversForDomain(
lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".")
// Convert to canonical form for consistency
lookupDomainName = dns.CanonicalName(lookupDomainName)
nameservers, err := state.GetState().LookupDomain(lookupDomainName)
nsRecords, err := state.GetState().LookupRecords([]string{"NS"}, lookupDomainName)
if err != nil {
return "", nil, err
}
if nameservers != nil {
if len(nsRecords) > 0 {
ret := map[string][]net.IP{}
for k, v := range nameservers {
k = dns.Fqdn(k)
ret[k] = append(ret[k], net.ParseIP(v))
for _, nsRecord := range nsRecords {
// Get matching A/AAAA records for NS entry
aRecords, err := state.GetState().LookupRecords([]string{"A", "AAAA"}, nsRecord.Rhs)
if err != nil {
return "", nil, err
}
for _, aRecord := range aRecords {
ret[nsRecord.Rhs] = append(ret[nsRecord.Rhs], net.ParseIP(aRecord.Rhs))
}
}
return dns.Fqdn(lookupDomainName), ret, nil
}
Expand Down
27 changes: 12 additions & 15 deletions internal/indexer/indexer.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 Blink Labs Software
// Copyright 2024 Blink Labs Software
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
Expand Down Expand Up @@ -234,23 +234,20 @@ func (i *Indexer) handleEvent(evt event.Event) error {
continue
}
}
nameServers := map[string]string{}
// Convert domain records into our storage format
tmpRecords := []state.DomainRecord{}
for _, record := range dnsDomain.Records {
recordName := strings.Trim(
string(record.Lhs),
`.`,
)
// NOTE: we're losing information here, but we need to revamp the storage
// format before we can use it. We're also making the assumption that all
// records are for nameservers
switch strings.ToUpper(string(record.Type)) {
case "A", "AAAA":
nameServers[recordName] = string(record.Rhs)
default:
continue
tmpRecord := state.DomainRecord{
Lhs: string(record.Lhs),
Type: string(record.Type),
Rhs: string(record.Rhs),
}
if record.Ttl.HasValue() {
tmpRecord.Ttl = int(record.Ttl.Value)
}
tmpRecords = append(tmpRecords, tmpRecord)
}
if err := state.GetState().UpdateDomain(domainName, nameServers); err != nil {
if err := state.GetState().UpdateDomain(domainName, tmpRecords); err != nil {
return err
}
logger.Infof(
Expand Down
120 changes: 76 additions & 44 deletions internal/state/state.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 Blink Labs Software
// Copyright 2024 Blink Labs Software
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
Expand All @@ -7,8 +7,11 @@
package state

import (
"bytes"
"encoding/gob"
"errors"
"fmt"
"slices"
"strconv"
"strings"
"time"
Expand All @@ -28,6 +31,13 @@ type State struct {
gcTimer *time.Ticker
}

type DomainRecord struct {
Lhs string
Type string
Ttl int
Rhs string
}

var globalState = &State{}

func (s *State) Load() error {
Expand Down Expand Up @@ -151,69 +161,91 @@ func (s *State) GetCursor() (uint64, string, error) {

func (s *State) UpdateDomain(
domainName string,
nameServers map[string]string,
records []DomainRecord,
) error {
logger := logging.GetLogger()
err := s.db.Update(func(txn *badger.Txn) error {
// Delete old records for domain
keyPrefix := []byte(fmt.Sprintf("domain_%s_", domainName))
it := txn.NewIterator(badger.DefaultIteratorOptions)
defer it.Close()
for it.Seek(keyPrefix); it.ValidForPrefix(keyPrefix); it.Next() {
item := it.Item()
k := item.Key()
if err := txn.Delete(k); err != nil {
return err
}
logger.Debug(
fmt.Sprintf(
"deleted record for domain %s with key: %s",
domainName,
k,
),
)
}
// Add new records
for nameServer, ipAddress := range nameServers {
recordKeys := make([]string, 0)
for recordIdx, record := range records {
key := fmt.Sprintf(
"domain_%s_nameserver_%s",
domainName,
nameServer,
"r_%s_%s_%d",
strings.ToUpper(record.Type),
strings.Trim(record.Lhs, `.`),
wolf31o2 marked this conversation as resolved.
Show resolved Hide resolved
recordIdx,
)
if err := txn.Set([]byte(key), []byte(ipAddress)); err != nil {
recordKeys = append(recordKeys, key)
var gobBuf bytes.Buffer
gobEnc := gob.NewEncoder(&gobBuf)
if err := gobEnc.Encode(&record); err != nil {
return err
}
recordVal := gobBuf.Bytes()[:]
if err := txn.Set([]byte(key), recordVal); err != nil {
return err
}
logger.Debug(
fmt.Sprintf(
"added record for domain %s: %s: %s",
"added record for domain %s: %s: %s: %s",
domainName,
nameServer,
ipAddress,
record.Type,
record.Lhs,
record.Rhs,
),
)
}
// Delete old records in tracking key that are no longer present after this update
domainRecordsKey := []byte(fmt.Sprintf("d_%s_records", domainName))
domainRecordsItem, err := txn.Get(domainRecordsKey)
if err != nil {
if !errors.Is(err, badger.ErrKeyNotFound) {
return err
}
} else {
domainRecordsVal, err := domainRecordsItem.ValueCopy(nil)
if err != nil {
return err
}
domainRecordsSplit := strings.Split(string(domainRecordsVal), ",")
for _, tmpRecordKey := range domainRecordsSplit {
if !slices.Contains(recordKeys, tmpRecordKey) {
if err := txn.Delete([]byte(tmpRecordKey)); err != nil {
return err
}
}
}
}
// Update tracking key with new record keys
recordKeysJoin := strings.Join(recordKeys, ",")
if err := txn.Set(domainRecordsKey, []byte(recordKeysJoin)); err != nil {
return err
}
return nil
})
return err
}

func (s *State) LookupDomain(domainName string) (map[string]string, error) {
ret := map[string]string{}
keyPrefix := []byte(fmt.Sprintf("domain_%s_nameserver_", domainName))
func (s *State) LookupRecords(recordTypes []string, recordName string) ([]DomainRecord, error) {
ret := []DomainRecord{}
recordName = strings.Trim(recordName, `.`)
err := s.db.View(func(txn *badger.Txn) error {
it := txn.NewIterator(badger.DefaultIteratorOptions)
defer it.Close()
for it.Seek(keyPrefix); it.ValidForPrefix(keyPrefix); it.Next() {
item := it.Item()
k := item.Key()
keyParts := strings.Split(string(k), "_")
nameServer := keyParts[len(keyParts)-1]
err := item.Value(func(v []byte) error {
ret[nameServer] = string(v)
return nil
})
if err != nil {
return err
for _, recordType := range recordTypes {
keyPrefix := []byte(fmt.Sprintf("r_%s_%s_", strings.ToUpper(recordType), recordName))
it := txn.NewIterator(badger.DefaultIteratorOptions)
defer it.Close()
for it.Seek(keyPrefix); it.ValidForPrefix(keyPrefix); it.Next() {
item := it.Item()
val, err := item.ValueCopy(nil)
if err != nil {
return err
}
gobBuf := bytes.NewReader(val)
gobDec := gob.NewDecoder(gobBuf)
var tmpRecord DomainRecord
if err := gobDec.Decode(&tmpRecord); err != nil {
return err
}
ret = append(ret, tmpRecord)
}
}
return nil
Expand Down