Skip to content

Commit

Permalink
* restructured tem/sources.go
Browse files Browse the repository at this point in the history
* implemented support for pruning the IXFR chain

* implemented support for AXFR fallback if the IXFR chain
  isn't long enough to handle an IXFR request

* changed various data structures as needed
  • Loading branch information
johanix committed May 25, 2024
1 parent b89d6b7 commit 9a5736b
Show file tree
Hide file tree
Showing 11 changed files with 563 additions and 470 deletions.
113 changes: 113 additions & 0 deletions bootstrap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright (c) DNS TAPIR
*/
package main

import (
"bytes"
"encoding/gob"
"encoding/json"
"fmt"
"log"
"net/http"
"time"

"github.com/dnstapir/tapir"
"github.com/ryanuber/columnize"
"github.com/spf13/viper"
)

func (td *TemData) BootstrapMqttSource(s *tapir.WBGlist, src SourceConf) (*tapir.WBGlist, error) {
// Initialize the API client
api := &tapir.ApiClient{
BaseUrl: fmt.Sprintf(src.BootstrapUrl, src.Bootstrap[0]), // Must specify a valid BaseUrl
ApiKey: src.BootstrapKey,
AuthMethod: "X-API-Key",
}

cd := viper.GetString("certs.certdir")
if cd == "" {
log.Fatalf("Error: missing config key: certs.certdir")
}
// cert := cd + "/" + certname
cert := cd + "/" + "tem"
tlsConfig, err := tapir.NewClientConfig(viper.GetString("certs.cacertfile"), cert+".key", cert+".crt")
if err != nil {
TEMExiter("BootstrapMqttSource: Error: Could not set up TLS: %v", err)
}
// XXX: Need to verify that the server cert is valid for the bootstrap server
tlsConfig.InsecureSkipVerify = true
err = api.SetupTLS(tlsConfig)
if err != nil {
return nil, fmt.Errorf("Error setting up TLS for the API client: %v", err)
}

// Iterate over the bootstrap servers
for _, server := range src.Bootstrap {
api.BaseUrl = fmt.Sprintf(src.BootstrapUrl, server)

// Send an API ping command
pr, err := api.SendPing(0, false)
if err != nil {
td.Logger.Printf("Ping to MQTT bootstrap server %s failed: %v", server, err)
continue
}

uptime := time.Now().Sub(pr.BootTime).Round(time.Second)
td.Logger.Printf("MQTT bootstrap server %s uptime: %v. It has processed %d MQTT messages", server, uptime, 17)

status, buf, err := api.RequestNG(http.MethodPost, "/bootstrap", tapir.BootstrapPost{
Command: "export-greylist",
ListName: src.Name,
}, true)
if err != nil {
fmt.Printf("Error from RequestNG: %v\n", err)
// return nil, fmt.Errorf("Error from RequestNG: %v", err)
continue
}

if status != http.StatusOK {
fmt.Printf("HTTP Error: %s\n", buf)
// return nil, fmt.Errorf("HTTP Error: %s", buf)
continue
}

var greylist tapir.WBGlist
decoder := gob.NewDecoder(bytes.NewReader(buf))
err = decoder.Decode(&greylist)
if err != nil {
// fmt.Printf("Error decoding greylist data: %v\n", err)
// If decoding the gob failed, perhaps we received a tapir.BootstrapResponse instead?
var br tapir.BootstrapResponse
err = json.Unmarshal(buf, &br)
if err != nil {
td.Logger.Printf("Error decoding bootstrap response from %s: %v. Giving up.\n", server, err)
continue
}
if br.Error {
td.Logger.Printf("Bootstrap server %s responded with error: %s (instead of GOB blob)", server, br.ErrorMsg)
}
if len(br.Msg) != 0 {
td.Logger.Printf("Bootstrap server %s responded: %s (instead of GOB blob)", server, br.Msg)
}
// return nil, fmt.Errorf("Command Error: %s", br.ErrorMsg)
continue
}

if td.Debug {
fmt.Printf("%v\n", greylist)
fmt.Printf("Names present in greylist %s:\n", src.Name)
out := []string{"Name|Time added|TTL|Tags"}
for _, n := range greylist.Names {
out = append(out, fmt.Sprintf("%s|%v|%v|%v", n.Name, n.TimeAdded.Format(tapir.TimeLayout), n.TTL, n.TagMask))
}
fmt.Printf("%s\n", columnize.SimpleFormat(out))
}

// Successfully received and decoded bootstrap data
return &greylist, nil
}

// If no bootstrap server succeeded
return nil, fmt.Errorf("All bootstrap servers failed")
}
27 changes: 12 additions & 15 deletions dnshandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package main

import (
"log"
"net"
"strings"

_ "github.com/mattn/go-sqlite3"
Expand Down Expand Up @@ -46,16 +47,6 @@ func DnsEngine(conf *Config) error {
return nil
}

func xxxGetKeepFunc(zone string) (string, func(uint16) bool) {
switch viper.GetString("service.filter") {
case "dnssec":
return "dnssec", tapir.DropDNSSECp
case "dnssec+zonemd":
return "dnssec+zonemd", tapir.DropDNSSECZONEMDp
}
return "none", func(t uint16) bool { return true }
}

func createHandler(conf *Config) func(w dns.ResponseWriter, r *dns.Msg) {

td := conf.TemData
Expand Down Expand Up @@ -148,6 +139,12 @@ func (td *TemData) RpzResponder(w dns.ResponseWriter, r *dns.Msg, qtype uint16,
zd := td.Rpz.Axfr.ZoneData
// XXX: we need this, but later var glue tapir.RRset

downstream, _, err := net.SplitHostPort(w.RemoteAddr().String())
if err != nil {
lg.Printf("RpzResponder: Error from net.SplitHostPort(): %v", err)
return nil
}

switch qtype {
case dns.TypeAXFR:
lg.Printf("We have the zone %s, so let's try to serve it", td.Rpz.ZoneName)
Expand All @@ -157,23 +154,23 @@ func (td *TemData) RpzResponder(w dns.ResponseWriter, r *dns.Msg, qtype uint16,
// td.Logger.Printf("RpzResponder: sending zone %s with %d body RRs to XfrOut",
// zd.ZoneName, len(zd.RRs))

serial, _, err := td.RpzAxfrOut(w, r)
_, _, err := td.RpzAxfrOut(w, r)
if err != nil {
lg.Printf("RpzResponder: error from RpzAxfrOut() serving zone %s: %v", zd.ZoneName, err)
}
td.mu.Lock()
td.Downstreams.Serial = serial
td.mu.Unlock()

return nil

case dns.TypeIXFR:
lg.Printf("RpzResponder: %s is our RPZ output", td.Rpz.ZoneName)

serial, _, err := td.RpzIxfrOut(w, r)
if err != nil {
lg.Printf("RpzResponder: error from RpzIxfrOut() serving zone %s: %v", zd.ZoneName, err)
}

td.mu.Lock()
td.Downstreams.Serial = serial
td.DownstreamSerials[downstream] = serial // track the highest known serial for each downstream
td.mu.Unlock()
return nil
case dns.TypeSOA:
Expand Down
10 changes: 6 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ func (td *TemData) SaveRpzSerial() error {
if serialFile == "" {
log.Fatalf("TEMExiter:No serial cache file specified")
}
serialData := []byte(fmt.Sprintf("%d", td.Rpz.CurrentSerial))
err := os.WriteFile(serialFile, serialData, 0644)
// serialData := []byte(fmt.Sprintf("%d", td.Rpz.CurrentSerial))
// err := os.WriteFile(serialFile, serialData, 0644)
serialYaml := fmt.Sprintf("current_serial: %d\n", td.Rpz.CurrentSerial)
err := os.WriteFile(serialFile, []byte(serialYaml), 0644)
if err != nil {
log.Printf("Error writing current serial to file: %v", err)
log.Printf("Error writing YAML serial to file: %v", err)
} else {
log.Printf("Saved current serial %d to file %s", td.Downstreams.Serial, serialFile)
log.Printf("Saved current serial %d to file %s", td.Rpz.CurrentSerial, serialFile)
}
return err
}
Expand Down
3 changes: 2 additions & 1 deletion mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ func (td *TemData) ProcessTapirUpdate(tpkg tapir.MqttPkg) (bool, error) {
tname.Name, wbgl.Name, tname.TimeAdded.Format(tapir.TimeLayout), tname.TTL)

// Time that the name will be removed from the list
reptime := tname.TimeAdded.Add(ttl).Truncate(td.ReaperInterval)
// must ensure that reapertime is at least ReaperInterval into the future
reptime := tname.TimeAdded.Add(ttl).Truncate(td.ReaperInterval).Add(td.ReaperInterval)

// Ensure that there are no prior removal events for this name
for reaperTime, namesMap := range wbgl.ReaperData {
Expand Down
185 changes: 185 additions & 0 deletions policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*
* Copyright (c) DNS TAPIR
*/

package main

import (
"log"
"net"
"os"
"strconv"
"strings"

"github.com/dnstapir/tapir"
"github.com/spf13/viper"
"gopkg.in/yaml.v3"
)

type TemOutput struct {
Active bool
Name string
Description string
Type string // listtype, usually "greylist"
Format string // i.e. rpz, etc
Downstream string
}

type TemOutputs struct {
Outputs map[string]TemOutput
}

func (td *TemData) ParseOutputs() error {
td.Logger.Printf("ParseOutputs: reading outputs from %s", tapir.TemOutputsCfgFile)
cfgdata, err := os.ReadFile(tapir.TemOutputsCfgFile)
if err != nil {
log.Fatalf("Error from ReadFile(%s): %v", tapir.TemOutputsCfgFile, err)
}

var oconf = TemOutputs{
Outputs: make(map[string]TemOutput),
}

// td.Logger.Printf("ParseOutputs: config read: %s", cfgdata)
err = yaml.Unmarshal(cfgdata, &oconf)
if err != nil {
log.Fatalf("Error from yaml.Unmarshal(OutputsConfig): %v", err)
}

td.Logger.Printf("ParseOutputs: found %d outputs", len(oconf.Outputs))
for name, v := range oconf.Outputs {
td.Logger.Printf("ParseOutputs: output %s: type %s, format %s, downstream %s",
name, v.Type, v.Format, v.Downstream)
}

for name, output := range oconf.Outputs {
if output.Active && strings.ToLower(output.Format) == "rpz" {
td.Logger.Printf("Output %s: Adding RPZ downstream %s to list of Notify receivers", name, output.Downstream)
addr, port, err := net.SplitHostPort(output.Downstream)
if err != nil {
td.Logger.Printf("Invalid downstream address %s: %v", output.Downstream, err)
continue
}
if net.ParseIP(addr) == nil {
td.Logger.Printf("Invalid IP address %s", addr)
continue
}
portInt, err := strconv.Atoi(port)
if err != nil {
td.Logger.Printf("Invalid port %s: %v", port, err)
continue
}
td.Downstreams[addr] = RpzDownstream{Address: addr, Port: portInt}
}
}
// Read the current value of td.Downstreams.Serial from a text file
serialFile := viper.GetString("output.rpz.serialcache")

if serialFile != "" {
serialData, err := os.ReadFile(serialFile)
if err != nil {
td.Logger.Printf("Error reading serial from file %s: %v", serialFile, err)
td.Rpz.CurrentSerial = 1
} else {
var serialYaml struct {
CurrentSerial uint32 `yaml:"current_serial"`
}
err = yaml.Unmarshal(serialData, &serialYaml)
if err != nil {
td.Logger.Printf("Error unmarshalling YAML serial data: %v", err)
td.Rpz.CurrentSerial = 1
} else {
td.Rpz.CurrentSerial = serialYaml.CurrentSerial
td.Logger.Printf("Loaded serial %d from file %s", td.Rpz.CurrentSerial, serialFile)
}
}
} else {
td.Logger.Printf("No serial cache file specified, starting serial at 1")
td.Rpz.CurrentSerial = 1
}
// td.Rpz.CurrentSerial = td.Downstreams.Serial
return nil
}

// Note: we onlygethere when we know that this name is only greylisted
// so no need tocheckfor white- or blacklisting
func (td *TemData) ComputeRpzGreylistAction(name string) tapir.Action {

var greyHits = map[string]*tapir.TapirName{}
for listname, list := range td.Lists["greylist"] {
switch list.Format {
case "map":
if v, exists := list.Names[name]; exists {
// td.Logger.Printf("ComputeRpzGreylistAction: found %s in greylist %s (%d names)",
// name, listname, len(list.Names))
greyHits[listname] = &v
}
// case "trie":
// if list.Trie.Search(name) != nil {
// greyHits = append(greyHits, v)
// }
default:
TEMExiter("Unknown greylist format %s", list.Format)
}
}
if len(greyHits) >= td.Policy.Greylist.NumSources {
td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s is in %d or more sources, action is %s",
name, td.Policy.Greylist.NumSources, tapir.ActionToString[td.Policy.Greylist.NumSourcesAction])
return td.Policy.Greylist.NumSourcesAction
}
td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s is in %d sources, not enough for action", name, len(greyHits))

if _, exists := greyHits["dns-tapir"]; exists {
numtapirtags := greyHits["dns-tapir"].TagMask.NumTags()
if numtapirtags >= td.Policy.Greylist.NumTapirTags {
td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s has more than %d tapir tags, action is %s",
name, td.Policy.Greylist.NumTapirTags, tapir.ActionToString[td.Policy.Greylist.NumTapirTagsAction])
return td.Policy.Greylist.NumTapirTagsAction
}
td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s has %d tapir tags, not enough for action", name, numtapirtags)
}
td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s is present in %d greylists, but does not trigger any action",
name, len(greyHits))
return td.Policy.WhitelistAction
}

// Decision to block a greylisted name:
// 1. More than N tags present
// 2. Name is present in more than M sources
// 3. Name

func ApplyGreyPolicy(name string, v *tapir.TapirName) string {
var rpzaction string
if v.HasAction(tapir.NXDOMAIN) {
rpzaction = "."
} else if v.HasAction(tapir.NODATA) {
rpzaction = "*."
} else if v.HasAction(tapir.DROP) {
rpzaction = "rpz-drop."
} else if v.TagMask != 0 {
log.Printf("there are tags")
rpzaction = "rpz-drop."
}

return rpzaction
}

func (td *TemData) ComputeRpzAction(name string) tapir.Action {
if td.Whitelisted(name) {
if td.Debug {
td.Policy.Logger.Printf("ComputeRpzAction: name %s is whitelisted, action is %s", name, tapir.ActionToString[td.Policy.WhitelistAction])
}
return td.Policy.WhitelistAction
} else if td.Blacklisted(name) {
if td.Debug {
td.Policy.Logger.Printf("ComputeRpzAction: name %s is blacklisted, action is %s", name, tapir.ActionToString[td.Policy.BlacklistAction])
}
return td.Policy.BlacklistAction
} else if td.Greylisted(name) {
if td.Debug {
td.Policy.Logger.Printf("ComputeRpzAction: name %s is greylisted, needs further evaluation to determine action", name)
}
return td.ComputeRpzGreylistAction(name) // This is not complete, only a placeholder for now.
}
return tapir.WHITELIST
}
Loading

0 comments on commit 9a5736b

Please sign in to comment.