-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.go
163 lines (143 loc) · 4.32 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
package main
// Generic functions and types shared by the DNS server implementations.
import (
"crypto/tls"
"fmt"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/sys/unix"
"log"
"net"
"strings"
"syscall"
"time"
)
// Client is the outbound DNS client used by a Server. It is an interface so
// that a different implementation (for example a test double) can be injected.
type Client interface {
	// Dial opens a new connection to the upstream at address.
	Dial(address string) (conn *dns.Conn, err error)
	// ExchangeWithConn sends the query s over conn and returns the reply r
	// together with the round-trip time rtt.
	ExchangeWithConn(s *dns.Msg, conn *dns.Conn) (r *dns.Msg, rtt time.Duration, err error)
}
// ResponseWriter is the minimal subset of dns.ResponseWriter that the query
// handling path needs. Narrowing it to WriteMsg lets tests drive the whole
// ServeDNS path with a fake writer.
type ResponseWriter interface {
	// WriteMsg sends msg back to the client.
	WriteMsg(msg *dns.Msg) error
}
// Server is the behaviour every DNS server implementation provides.
type Server interface {
	// Handler makes every Server usable as a miekg/dns query handler.
	dns.Handler

	// HandleDNS holds the logic behind ServeDNS. It takes the narrower
	// ResponseWriter interface so the query path can be exercised in tests.
	HandleDNS(w ResponseWriter, m *dns.Msg)

	// GetConnection returns a new connection to an upstream.
	GetConnection() (ConnEntry, error)

	// RecursiveQuery runs a recursive query for domain and record type rrtype.
	// NOTE(review): the meaning of the string result is defined by the
	// implementation and is not documented here.
	RecursiveQuery(domain string, rrtype uint16) (Response, string, error)

	// RetrieveRecords returns the records for domain and rrtype, either from
	// the cache or from an upstream.
	RetrieveRecords(domain string, rrtype uint16) (Response, string, error)

	// GetDnsClient returns the server's outbound DNS client.
	GetDnsClient() Client

	// GetHostedCache returns the cache of locally hosted records.
	GetHostedCache() Cache

	// AddUpstream adds u to the server's list of upstreams.
	AddUpstream(u *Upstream)

	// GetConnectionPool returns a copy of this server's connection pool.
	GetConnectionPool() ConnPool
}
// processResults wraps the reply r in a Response labelled with domain and
// rrtype and stamped with the current time. The error result is always nil.
func processResults(r dns.Msg, domain string, rrtype uint16) (Response, error) {
	result := Response{
		Name:  domain,
		Qtype: rrtype,
		Entry: r,
	}
	result.CreationTime = time.Now()
	return result, nil
}
// sendServfail builds a reply to request r with rcode SERVFAIL and writes it
// through w. It increments LocalServfailsCounter and passes the reply to the
// query log with source "servfail". duration is forwarded to the query log
// unchanged.
//
// If writing the reply fails, the failure is logged rather than ignored.
// There is no error to return to the caller. The servfail is still counted
// and sent to the query log.
func sendServfail(w ResponseWriter, duration time.Duration, r *dns.Msg) {
	LocalServfailsCounter.Inc()
	m := &dns.Msg{}
	m.SetRcode(r, dns.RcodeServerFailure)
	if err := w.WriteMsg(m); err != nil {
		Logger.Log(NewLogMessage(
			ERROR,
			LogContext{
				"what":  "could not write servfail response to client",
				"error": err.Error(),
			},
			func() string { return fmt.Sprintf("response [%v]", m) },
		))
	}
	logQuery("servfail", duration, m)
}
// logQuery writes query-log entries describing response.
//
// One entry is emitted for every (question, answer) pair. Each answer is
// reduced to the last whitespace-separated field of the record's text form.
//
// A response with no answer records still yields one entry per question, with
// an empty "answer". Examples are the SERVFAIL reply built by sendServfail or
// an NXDOMAIN. Without this, such outcomes would never reach the query log.
//
// source is logged as "[source]" under "answerSource". duration is logged in
// time.Duration's String format. A nil response logs nothing. The returned
// error is always nil; it is kept so existing callers still compile.
func logQuery(source string, duration time.Duration, response *dns.Msg) error {
	if response == nil {
		return nil
	}

	// These values are the same for every entry, so compute them once.
	answerSource := fmt.Sprintf("[%s]", source)
	elapsed := duration.String()
	opcode := dns.OpcodeToString[response.Opcode]

	// Parse each answer once, not once per question. strings.Fields splits on
	// any run of whitespace. This matters because the text form of a record
	// separates its header fields with tabs, which a split on " " misses.
	answers := make([]string, 0, len(response.Answer))
	for _, rr := range response.Answer {
		fields := strings.Fields(rr.String())
		answer := ""
		if len(fields) > 0 {
			answer = fields[len(fields)-1]
		}
		answers = append(answers, answer)
	}
	if len(answers) == 0 {
		// Answerless responses would otherwise produce no log line at all.
		answers = append(answers, "")
	}

	for _, q := range response.Question {
		name := q.Name
		qtype := dns.Type(q.Qtype).String()
		for _, answer := range answers {
			QueryLogger.Log(LogMessage{
				Context: LogContext{
					"name":         name,
					"type":         qtype,
					"opcode":       opcode,
					"answer":       answer,
					"answerSource": answerSource,
					"duration":     elapsed,
				},
			})
		}
	}
	return nil
}
// sockoptSetter is installed as the net.Dialer Control hook. When the
// configuration has UseTfo set, it enables TCP_FASTOPEN_CONNECT on the raw
// socket.
//
// Failing to set the option is only logged: TCP Fast Open is best-effort and
// does not abort the dial. The returned error is whatever c.Control reports.
func sockoptSetter(network, address string, c syscall.RawConn) (err error) {
	config := GetConfiguration()
	enableTfo := func(fd uintptr) {
		if !config.UseTfo {
			return
		}
		setErr := unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_FASTOPEN_CONNECT, 1)
		if setErr == nil {
			return
		}
		log.Printf("could not set TCP fast open to [%s]: %s", address, setErr.Error())
	}
	return c.Control(enableTfo)
}
// buildDialer returns a net.Dialer that uses timeout as its dial timeout and
// runs sockoptSetter as its Control hook.
func buildDialer(timeout time.Duration) (dialer *net.Dialer) {
	dialer = new(net.Dialer)
	dialer.Timeout = timeout
	dialer.Control = sockoptSetter
	return dialer
}
// BuildClient constructs the upstream DNS client in "tcp-tls" mode.
//
// config.Timeout is scaled by time.Millisecond, so it holds a millisecond
// count. The resulting duration is used both as the client timeout and as the
// dialer timeout. SingleInflight is enabled. Upstream certificate verification
// is skipped only when config.SkipUpstreamVerification is set.
//
// An INFO message is logged on success. The error result is always nil.
func BuildClient() (*dns.Client, error) {
	config := GetConfiguration()
	timeout := config.Timeout * time.Millisecond

	tlsConfig := &tls.Config{InsecureSkipVerify: config.SkipUpstreamVerification}
	client := &dns.Client{
		Net:            "tcp-tls",
		Timeout:        timeout,
		Dialer:         buildDialer(timeout),
		TLSConfig:      tlsConfig,
		SingleInflight: true,
	}

	created := LogContext{
		"what": "instantiated new dns client in TLS mode",
		"next": "returning for use",
	}
	Logger.Log(LogMessage{Level: INFO, Context: created})
	return client, nil
}
// attemptExchange sends m over the connection held by ce, using client, and
// returns the upstream's reply.
//
// The exchange duration is observed on ExchangeTimer, labelled with the
// upstream address. The rtt reported by the client is passed to
// ce.AddExchange whether or not the exchange succeeded.
//
// On failure, UpstreamErrorsCounter is incremented for the upstream and the
// error is logged. The function then returns an empty, non-nil *dns.Msg
// together with the error, so the caller can try the next upstream.
//
// It assumes that the caller will close the connection upon any error.
func attemptExchange(m *dns.Msg, ce ConnEntry, client Client) (reply *dns.Msg, err error) {
	address := ce.GetAddress()
	exchangeTimer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
		ExchangeTimer.WithLabelValues(address).Observe(v)
	}))
	reply, rtt, err := client.ExchangeWithConn(m, ce.GetConn())
	exchangeTimer.ObserveDuration()
	ce.AddExchange(rtt)
	if err != nil {
		UpstreamErrorsCounter.WithLabelValues(address).Inc()
		// Guard the question lookup: a message without questions must not
		// turn an upstream failure into a panic while logging it.
		domain := ""
		if m != nil && len(m.Question) > 0 {
			domain = m.Question[0].Name
		}
		Logger.Log(NewLogMessage(
			ERROR,
			LogContext{
				"what":  fmt.Sprintf("error looking up domain [%s] on server [%s]", domain, address),
				"error": err.Error(),
			},
			func() string { return fmt.Sprintf("request [%v]", m) },
		))
		// try the next one
		return &dns.Msg{}, err
	}
	return reply, nil
}