Skip to content

Commit

Permalink
Do not reuse dns.Envelope variable
Browse files Browse the repository at this point in the history
Sending a reference to a dns.Envelope (&env) to a channel and then
resetting the env variable to a fresh envelope causes the reference sent
over the channel to also reference an empty envelope.

Now we maintain a separate slice of dns.RR records and create a new
envelope on each send instead. Passing the dns.RR slice over the channel
is OK as even if it internally points to data, the slice header itself
will be copied and is not affected when we reset "rrs" to a new empty
slice.
  • Loading branch information
eest committed Jun 3, 2024
1 parent e456757 commit fcd4535
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions xfr.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,35 +81,35 @@ func (td *TemData) RpzAxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er

count := 0
send_count := 0
env := dns.Envelope{}
rrs := []dns.RR{}

td.Rpz.Axfr.SOA.Serial = td.Rpz.CurrentSerial
env.RR = append(env.RR, dns.RR(&td.Rpz.Axfr.SOA))
rrs = append(rrs, dns.RR(&td.Rpz.Axfr.SOA))
// total_sent := 1
var total_sent int

env.RR = append(env.RR, td.Rpz.Axfr.NSrrs...)
rrs = append(rrs, td.Rpz.Axfr.NSrrs...)

for _, rpzn := range td.Rpz.Axfr.Data {
env.RR = append(env.RR, *rpzn.RR) // should do proper slice magic instead
rrs = append(rrs, *rpzn.RR) // should do proper slice magic instead
count++
if count >= 500 {
send_count++
total_sent += len(env.RR)
// fmt.Printf("Sending %d RRs\n", len(env.RR))
outbound_xfr <- &env
// fmt.Printf("Sent %d RRs: done\n", len(env.RR))
env = dns.Envelope{}
total_sent += len(rrs)
// fmt.Printf("Sending %d RRs\n", len(rrs))
outbound_xfr <- &dns.Envelope{RR: rrs}
rrs = []dns.RR{}
// fmt.Printf("Sent %d RRs: done\n", len(rrs))
count = 0
}
}

env.RR = append(env.RR, dns.RR(&td.Rpz.Axfr.SOA)) // trailing SOA
rrs = append(rrs, dns.RR(&td.Rpz.Axfr.SOA)) // trailing SOA

total_sent += len(env.RR)
total_sent += len(rrs)
// td.Logger.Printf("RpzAxfrOut: Zone %s: Sending final %d RRs (including trailing SOA, total sent %d)\n",
// zone, len(env.RR), total_sent)
outbound_xfr <- &env
// zone, len(rrs), total_sent)
outbound_xfr <- &dns.Envelope{RR: rrs}

close(outbound_xfr)
wg.Wait() // wait until everything is written out
Expand Down Expand Up @@ -214,12 +214,12 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er
wg.Done()
}()

env := dns.Envelope{}
rrs := []dns.RR{}

var total_sent int

td.Rpz.Axfr.SOA.Serial = td.Rpz.CurrentSerial
env.RR = append(env.RR, dns.RR(&td.Rpz.Axfr.SOA))
rrs = append(rrs, dns.RR(&td.Rpz.Axfr.SOA))

var totcount, count int
var finalSerial uint32
Expand All @@ -235,23 +235,23 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er
if td.Debug {
td.Logger.Printf("IxfrOut: adding FROMSOA to output: %s", fromsoa.String())
}
env.RR = append(env.RR, fromsoa)
rrs = append(rrs, fromsoa)
count++
td.Logger.Printf("RpzIxfrOut: IXFR[%d,%d] has %d RRs in the removal list",
ixfr.FromSerial, ixfr.ToSerial, len(ixfr.Removed))
for _, tn := range ixfr.Removed {
if td.Debug {
td.Logger.Printf("DEL: adding RR to ixfr output: %s", tn.Name)
}
env.RR = append(env.RR, *tn.RR) // should do proper slice magic instead
rrs = append(rrs, *tn.RR) // should do proper slice magic instead
count++
if count >= 500 {
td.Logger.Printf("Sending %d RRs\n", len(env.RR))
for _, rr := range env.RR {
td.Logger.Printf("Sending %d RRs\n", len(rrs))
for _, rr := range rrs {
td.Logger.Printf("SEND DELS: %s", rr.String())
}
outbound_xfr <- &env
env = dns.Envelope{}
outbound_xfr <- &dns.Envelope{RR: rrs}
rrs = []dns.RR{}
totcount += count
count = 0
}
Expand All @@ -261,42 +261,42 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er
if td.Debug {
td.Logger.Printf("RpzIxfrOut: adding TOSOA to output: %s", tosoa.String())
}
env.RR = append(env.RR, tosoa)
rrs = append(rrs, tosoa)
count++
td.Logger.Printf("RpzIxfrOut: IXFR[%d,%d] has %d RRs in the added list",
ixfr.FromSerial, ixfr.ToSerial, len(ixfr.Added))
for _, tn := range ixfr.Added {
if td.Debug {
td.Logger.Printf("ADD: adding RR to ixfr output: %s", tn.Name)
}
env.RR = append(env.RR, *tn.RR) // should do proper slice magic instead
rrs = append(rrs, *tn.RR) // should do proper slice magic instead
count++
if count >= 500 {
td.Logger.Printf("Sending %d RRs\n", len(env.RR))
for _, rr := range env.RR {
td.Logger.Printf("Sending %d RRs\n", len(rrs))
for _, rr := range rrs {
td.Logger.Printf("SEND ADDS: %s", rr.String())
}
outbound_xfr <- &env
// fmt.Printf("Sent %d RRs: done\n", len(env.RR))
env = dns.Envelope{}
outbound_xfr <- &dns.Envelope{RR: rrs}
// fmt.Printf("Sent %d RRs: done\n", len(rrs))
rrs = []dns.RR{}
totcount += count
count = 0
}
}
}
}

env.RR = append(env.RR, dns.RR(&td.Rpz.Axfr.SOA)) // trailing SOA
rrs = append(rrs, dns.RR(&td.Rpz.Axfr.SOA)) // trailing SOA

total_sent += len(env.RR)
total_sent += len(rrs)
td.Logger.Printf("RpzIxfrOut: Zone %s: Sending final %d RRs (including trailing SOA, total sent %d)\n",
zone, len(env.RR), total_sent)
zone, len(rrs), total_sent)

// td.Logger.Printf("Sending %d RRs\n", len(env.RR))
// for _, rr := range env.RR {
// td.Logger.Printf("Sending %d RRs\n", len(rrs))
// for _, rr := range rrs {
// td.Logger.Printf("SEND FINAL: %s", rr.String())
// }
outbound_xfr <- &env
outbound_xfr <- &dns.Envelope{RR: rrs}

close(outbound_xfr)
wg.Wait() // wait until everything is written out
Expand Down

0 comments on commit fcd4535

Please sign in to comment.