diff --git a/Makefile b/Makefile index de1bcf7..9ba9a73 100644 --- a/Makefile +++ b/Makefile @@ -13,17 +13,23 @@ default: ${PROG} ${PROG}: build - version.go: /bin/sh make-version.sh $(VERSION)-$(COMMIT) $(APPDATE) $(PROG) -build: version.go +build: version.go # ../tapir/tapir.pb.go $(GO) build $(GOFLAGS) -o ${PROG} +# ../tapir/tapir.pb.go: ../tapir/tapir.proto +# make -C ../tapir tapir.pb.go + linux: /bin/sh make-version.sh $(VERSION)-$(COMMIT) $(APPDATE) $(PROG) GOOS=linux GOARCH=amd64 go build $(GOFLAGS) -o ${PROG}.linux +netbsd: + /bin/sh make-version.sh $(VERSION)-$(COMMIT) $(APPDATE) $(PROG) + GOOS=netbsd GOARCH=amd64 go build $(GOFLAGS) -o ${PROG}.netbsd + gen-mqtt-msg-new-qname.go: checkout/events-mqtt-message-new_qname.json go-jsonschema checkout/events-mqtt-message-new_qname.json --package main --tags json --only-models --output gen-mqtt-msg-new-qname.go diff --git a/apihandler.go b/apihandler.go index 3fb5872..41dd71d 100644 --- a/apihandler.go +++ b/apihandler.go @@ -1,5 +1,5 @@ /* - * Johan Stenstam, johani@johani.org + * Johan Stenstam, johan.stenstam@internetstiftelsen.se */ package main @@ -183,34 +183,135 @@ func APIcommand(conf *Config) func(w http.ResponseWriter, r *http.Request) { // Msg: "Daemon was happy, but now winding down", // } - case "export-greylist-dns-tapir": - // exportGreylistDnsTapir(w, r, conf.TemData) + // End of Selection + default: + resp.Error = true + resp.ErrorMsg = fmt.Sprintf("Unknown command: %s", cp.Command) + } + } +} + +func APIbootstrap(conf *Config) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + resp := tapir.BootstrapResponse{ + Status: "ok", // only status we know, so far + Msg: "We're happy, but send more cookies", + } + + defer func() { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(resp) + if err != nil { + log.Printf("Error from json encoder: %v", err) + log.Printf("resp: %v", resp) + } + }() + + decoder := json.NewDecoder(r.Body) + var bp tapir.BootstrapPost + err := decoder.Decode(&bp) + if err != nil { + log.Println("APIbootstrap: error decoding command post:", err) + resp.Error = true + resp.ErrorMsg = fmt.Sprintf("Error decoding command post: %v", err) + return + } + + log.Printf("API: received /bootstrap request (cmd: %s) from %s.\n", bp.Command, r.RemoteAddr) + + switch bp.Command { + case "export-greylist": td := conf.TemData td.mu.RLock() defer td.mu.RUnlock() - greylist, ok := td.Lists["greylist"]["dns-tapir"] + greylist, ok := td.Lists["greylist"][bp.ListName] if !ok { resp.Error = true - resp.ErrorMsg = "Greylist 'dns-tapir' not found" + resp.ErrorMsg = fmt.Sprintf("Greylist '%s' not found", bp.ListName) return } - log.Printf("Found dns-tapir greylist: %v", greylist) - - w.Header().Set("Content-Type", "application/octet-stream") - w.Header().Set("Content-Disposition", "attachment; filename=greylist-dns-tapir.gob") + log.Printf("Found %s greylist containing %d names", bp.ListName, len(greylist.Names)) + + switch bp.Encoding { + case "gob": + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=greylist-%s.gob", bp.ListName)) + + encoder := gob.NewEncoder(w) + err := encoder.Encode(greylist) + if err != nil { + log.Printf("Error encoding greylist: %v", err) + resp.Error = true + resp.ErrorMsg = err.Error() + return + } - encoder := gob.NewEncoder(w) - err := encoder.Encode(greylist) - if err != nil { - log.Printf("Error encoding greylist: %v", err) + // case "protobuf": + // w.Header().Set("Content-Type", "application/octet-stream") + // w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=greylist-%s.protobuf", bp.ListName)) + // + // data, err := proto.Marshal(greylist) + // if err != nil { + // log.Printf("Error encoding greylist to protobuf: %v", err) + // resp.Error = true + // resp.ErrorMsg = err.Error() + // return + // } + + // _, err = w.Write(data) + // if err != nil { + // log.Printf("Error writing protobuf data to response: %v", err) + // resp.Error = true + // resp.ErrorMsg = err.Error() + // return + // } + + // case "flatbuffer": + // w.Header().Set("Content-Type", "application/octet-stream") + // w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=greylist-%s.flatbuffer", bp.ListName)) + + // builder := flatbuffers.NewBuilder(0) + // names := make([]flatbuffers.UOffsetT, len(greylist.Names)) + + // i := 0 + // for name := range greylist.Names { + // nameOffset := builder.CreateString(name) + // tapir.NameStart(builder) + // tapir.NameAddName(builder, nameOffset) + // names[i] = tapir.NameEnd(builder) + // i++ + // } + + // tapir.GreylistStartNamesVector(builder, len(names)) + // for j := len(names) - 1; j >= 0; j-- { + // builder.PrependUOffsetT(names[j]) + // } + // namesVector := builder.EndVector(len(names)) + + // tapir.GreylistStart(builder) + // tapir.GreylistAddNames(builder, namesVector) + // greylistOffset := tapir.GreylistEnd(builder) + + // builder.Finish(greylistOffset) + // buf := builder.FinishedBytes() + + // _, err := w.Write(buf) + // if err != nil { + // log.Printf("Error writing flatbuffer data to response: %v", err) + // resp.Error = true + // resp.ErrorMsg = err.Error() + // return + // } + + default: resp.Error = true - resp.ErrorMsg = err.Error() + resp.ErrorMsg = fmt.Sprintf("Unknown encoding: %s", bp.Encoding) return } default: - resp.ErrorMsg = fmt.Sprintf("Unknown command: %s", cp.Command) + resp.ErrorMsg = fmt.Sprintf("Unknown command: %s", bp.Command) resp.Error = true } } @@ -280,9 +381,9 @@ func APIdebug(conf *Config) func(w http.ResponseWriter, r *http.Request) { resp.ReaperStats = make(map[string]map[time.Time][]string) for SrcName, list := range td.Lists["greylist"] { resp.ReaperStats[SrcName] = make(map[time.Time][]string) - for ts, items := range list.ReaperData { - for _, item := range items { - resp.ReaperStats[SrcName][ts] = append(resp.ReaperStats[SrcName][ts], item.Name) + for ts, names := range list.ReaperData { + for name := range names { + resp.ReaperStats[SrcName][ts] = append(resp.ReaperStats[SrcName][ts], name) } } } @@ -333,12 +434,25 @@ func SetupRouter(conf *Config) *mux.Router { viper.GetString("apiserver.key")).Subrouter() sr.HandleFunc("/ping", tapir.APIping("tem", conf.BootTime)).Methods("POST") sr.HandleFunc("/command", APIcommand(conf)).Methods("POST") + sr.HandleFunc("/bootstrap", APIbootstrap(conf)).Methods("POST") sr.HandleFunc("/debug", APIdebug(conf)).Methods("POST") // sr.HandleFunc("/show/api", tapir.APIshowAPI(r)).Methods("GET") return r } +func SetupBootstrapRouter(conf *Config) *mux.Router { + r := mux.NewRouter().StrictSlash(true) + + sr := r.PathPrefix("/api/v1").Headers("X-API-Key", viper.GetString("apiserver.key")).Subrouter() + sr.HandleFunc("/ping", tapir.APIping("tem", conf.BootTime)).Methods("POST") + sr.HandleFunc("/bootstrap", APIbootstrap(conf)).Methods("POST") + // sr.HandleFunc("/debug", APIdebug(conf)).Methods("POST") + // sr.HandleFunc("/show/api", tapir.APIshowAPI(r)).Methods("GET") + + return r +} + func walkRoutes(router *mux.Router, address string) { log.Printf("Defined API endpoints for router on: %s\n", address) @@ -370,6 +484,10 @@ func APIdispatcher(conf *Config, done <-chan struct{}) { certfile := viper.GetString("certs.tem.cert") keyfile := viper.GetString("certs.tem.key") + bootstrapaddress := viper.GetString("bootstrapserver.address") + bootstraptlsaddress := viper.GetString("bootstrapserver.tlsaddress") + bootstraprouter := SetupBootstrapRouter(conf) + tlspossible := true _, err := os.Stat(certfile) @@ -384,7 +502,7 @@ func APIdispatcher(conf *Config, done <-chan struct{}) { tlsConfig, err := tapir.NewServerConfig(viper.GetString("certs.cacertfile"), tls.VerifyClientCertIfGiven) // Alternatives are: tls.RequireAndVerifyClientCert, tls.VerifyClientCertIfGiven, // tls.RequireAnyClientCert, tls.RequestClientCert, tls.NoClientCert - // We would like to request a client cert, but until all labgroup servers have certs we cannot do that. + if err != nil { TEMExiter("Error creating API server tls config: %v\n", err) } @@ -394,27 +512,65 @@ func APIdispatcher(conf *Config, done <-chan struct{}) { Handler: router, TLSConfig: tlsConfig, } + bootstrapTlsServer := &http.Server{ + Addr: bootstraptlsaddress, + Handler: bootstraprouter, + TLSConfig: tlsConfig, + } var wg sync.WaitGroup - go func() { - log.Println("Starting API dispatcher #1. Listening on", address) - TEMExiter(http.ListenAndServe(address, router)) - }() + log.Println("*** API: Starting API dispatcher #1. Listening on", address) + + if address != "" { + wg.Add(1) + go func(wg *sync.WaitGroup) { + log.Println("*** API: Starting API dispatcher #1. Listening on", address) + wg.Done() + TEMExiter(http.ListenAndServe(address, router)) + }(&wg) + } if tlsaddress != "" { if tlspossible { wg.Add(1) go func(wg *sync.WaitGroup) { - log.Println("Starting TLS API dispatcher #1. Listening on", tlsaddress) + log.Println("*** API: Starting TLS API dispatcher #1. Listening on", tlsaddress) + wg.Done() TEMExiter(tlsServer.ListenAndServeTLS(certfile, keyfile)) + }(&wg) + } else { + log.Printf("*** API: APIdispatcher: Error: Cannot provide TLS service without cert and key files.\n") + } + } + + if bootstrapaddress != "" { + wg.Add(1) + go func(wg *sync.WaitGroup) { + log.Println("*** API: Starting Bootstrap API dispatcher #1. Listening on", bootstrapaddress) + wg.Done() + TEMExiter(http.ListenAndServe(bootstrapaddress, bootstraprouter)) + }(&wg) + } else { + log.Println("*** API: No bootstrap address specified") + } + + if bootstraptlsaddress != "" { + if tlspossible { + wg.Add(1) + go func(wg *sync.WaitGroup) { + log.Println("*** API: Starting Bootstrap TLS API dispatcher #1. Listening on", bootstraptlsaddress) wg.Done() + TEMExiter(bootstrapTlsServer.ListenAndServeTLS(certfile, keyfile)) }(&wg) } else { - log.Printf("APIdispatch Error: Cannot provide TLS service without cert and key files.\n") + log.Printf("*** API: APIdispatcher: Error: Cannot provide Bootstrap TLS service without cert and key files.\n") } + } else { + log.Println("*** API: No bootstrap TLS address specified") } + wg.Wait() log.Println("API dispatcher: unclear how to stop the http server nicely.") } diff --git a/bootstrap.go b/bootstrap.go new file mode 100644 index 0000000..4ba7262 --- /dev/null +++ b/bootstrap.go @@ -0,0 +1,112 @@ +/* + * Copyright (c) DNS TAPIR + */ +package main + +import ( + "bytes" + "encoding/gob" + "encoding/json" + "fmt" + "log" + "net/http" + "time" + + "github.com/dnstapir/tapir" + "github.com/ryanuber/columnize" + "github.com/spf13/viper" +) + +func (td *TemData) BootstrapMqttSource(s *tapir.WBGlist, src SourceConf) (*tapir.WBGlist, error) { + // Initialize the API client + api := &tapir.ApiClient{ + BaseUrl: fmt.Sprintf(src.BootstrapUrl, src.Bootstrap[0]), // Must specify a valid BaseUrl + ApiKey: src.BootstrapKey, + AuthMethod: "X-API-Key", + } + + cd := viper.GetString("certs.certdir") + if cd == "" { + log.Fatalf("Error: missing config key: certs.certdir") + } + // cert := cd + "/" + certname + cert := cd + "/" + "tem" + tlsConfig, err := tapir.NewClientConfig(viper.GetString("certs.cacertfile"), cert+".key", cert+".crt") + if err != nil { + TEMExiter("BootstrapMqttSource: Error: Could not set up TLS: %v", err) + } + // XXX: Need to verify that the server cert is valid for the bootstrap server + tlsConfig.InsecureSkipVerify = true + err = api.SetupTLS(tlsConfig) + if err != nil { + return nil, fmt.Errorf("Error setting up TLS for the API client: %v", err) + } + + // Iterate over the bootstrap servers + for _, server := range src.Bootstrap { + api.BaseUrl = fmt.Sprintf(src.BootstrapUrl, server) + + // Send an API ping command + pr, err := api.SendPing(0, false) + if err != nil { + td.Logger.Printf("Ping to MQTT bootstrap server %s failed: %v", server, err) + continue + } + + uptime := time.Now().Sub(pr.BootTime).Round(time.Second) + td.Logger.Printf("MQTT bootstrap server %s uptime: %v. It has processed %d MQTT messages", server, uptime, 17) + + status, buf, err := api.RequestNG(http.MethodPost, "/bootstrap", tapir.BootstrapPost{ + Command: "export-greylist", + ListName: src.Name, + Encoding: "gob", // XXX: This is our default, but we'll test other encodings later + }, true) + if err != nil { + fmt.Printf("Error from RequestNG: %v\n", err) + continue + } + + if status != http.StatusOK { + fmt.Printf("HTTP Error: %s\n", buf) + continue + } + + var greylist tapir.WBGlist + decoder := gob.NewDecoder(bytes.NewReader(buf)) + err = decoder.Decode(&greylist) + if err != nil { + // fmt.Printf("Error decoding greylist data: %v\n", err) + // If decoding the gob failed, perhaps we received a tapir.BootstrapResponse instead? + var br tapir.BootstrapResponse + err = json.Unmarshal(buf, &br) + if err != nil { + td.Logger.Printf("Error decoding bootstrap response from %s: %v. Giving up.\n", server, err) + continue + } + if br.Error { + td.Logger.Printf("Bootstrap server %s responded with error: %s (instead of GOB blob)", server, br.ErrorMsg) + } + if len(br.Msg) != 0 { + td.Logger.Printf("Bootstrap server %s responded: %s (instead of GOB blob)", server, br.Msg) + } + // return nil, fmt.Errorf("Command Error: %s", br.ErrorMsg) + continue + } + + if td.Debug { + fmt.Printf("%v\n", greylist) + fmt.Printf("Names present in greylist %s:\n", src.Name) + out := []string{"Name|Time added|TTL|Tags"} + for _, n := range greylist.Names { + out = append(out, fmt.Sprintf("%s|%v|%v|%v", n.Name, n.TimeAdded.Format(tapir.TimeLayout), n.TTL, n.TagMask)) + } + fmt.Printf("%s\n", columnize.SimpleFormat(out)) + } + + // Successfully received and decoded bootstrap data + return &greylist, nil + } + + // If no bootstrap server succeeded + return nil, fmt.Errorf("All bootstrap servers failed") +} diff --git a/config.go b/config.go index a2769fe..6a0ddb7 100644 --- a/config.go +++ b/config.go @@ -1,5 +1,5 @@ /* - * Johan Stenstam, johani@johani.org + * Johan Stenstam, johan.stenstam@internetstiftelsen.se */ package main @@ -10,6 +10,8 @@ import ( "github.com/go-playground/validator/v10" "github.com/spf13/viper" + + "github.com/dnstapir/tapir" ) type Config struct { @@ -24,10 +26,10 @@ type Config struct { Verbose *bool `validate:"required"` Debug *bool `validate:"required"` } - Loggers struct { - Mqtt *log.Logger - Dnsengine *log.Logger - Policy *log.Logger + Loggers struct { + Mqtt *log.Logger + Dnsengine *log.Logger + Policy *log.Logger } Internal InternalConf TemData *TemData @@ -48,20 +50,25 @@ type ServerConf struct { } type SourceConf struct { - Active *bool `validate:"required"` - Name string `validate:"required"` - Description string `validate:"required"` - Type string `validate:"required"` - Format string `validate:"required"` - Source string `validate:"required"` - Filename string - Upstream string - Zone string + Active *bool `validate:"required"` + Name string `validate:"required"` + Description string `validate:"required"` + Type string `validate:"required"` + Format string `validate:"required"` + Source string `validate:"required"` + Topic string + ValidatorKey string + Bootstrap []string + BootstrapUrl string + BootstrapKey string + Filename string + Upstream string + Zone string } type PolicyConf struct { - Logfile string -// Logger *log.Logger + Logfile string + // Logger *log.Logger Whitelist struct { Action string `validate:"required"` } @@ -96,7 +103,7 @@ type ApiserverConf struct { type DnsengineConf struct { Address string `validate:"required"` Logfile string `validate:"required"` -// Logger *log.Logger + // Logger *log.Logger } type InternalConf struct { @@ -151,3 +158,7 @@ func ValidateBySection(config *Config, configsections map[string]interface{}, cf } return nil } + +func (td *TemData) ProcessTapirGlobalConfig(tpkg tapir.TapirMsg) { + log.Printf("TapirProcessGlobalConfig: %+v", tpkg) +} diff --git a/dnshandler.go b/dnshandler.go index 8faa067..10bd2fe 100644 --- a/dnshandler.go +++ b/dnshandler.go @@ -6,6 +6,7 @@ package main import ( "log" + "net" "strings" _ "github.com/mattn/go-sqlite3" @@ -46,16 +47,6 @@ func DnsEngine(conf *Config) error { return nil } -func xxxGetKeepFunc(zone string) (string, func(uint16) bool) { - switch viper.GetString("service.filter") { - case "dnssec": - return "dnssec", tapir.DropDNSSECp - case "dnssec+zonemd": - return "dnssec+zonemd", tapir.DropDNSSECZONEMDp - } - return "none", func(t uint16) bool { return true } -} - func createHandler(conf *Config) func(w dns.ResponseWriter, r *dns.Msg) { td := conf.TemData @@ -148,6 +139,12 @@ func (td *TemData) RpzResponder(w dns.ResponseWriter, r *dns.Msg, qtype uint16, zd := td.Rpz.Axfr.ZoneData // XXX: we need this, but later var glue tapir.RRset + downstream, _, err := net.SplitHostPort(w.RemoteAddr().String()) + if err != nil { + lg.Printf("RpzResponder: Error from net.SplitHostPort(): %v", err) + return nil + } + switch qtype { case dns.TypeAXFR: lg.Printf("We have the zone %s, so let's try to serve it", td.Rpz.ZoneName) @@ -157,14 +154,13 @@ func (td *TemData) RpzResponder(w dns.ResponseWriter, r *dns.Msg, qtype uint16, // td.Logger.Printf("RpzResponder: sending zone %s with %d body RRs to XfrOut", // zd.ZoneName, len(zd.RRs)) - serial, _, err := td.RpzAxfrOut(w, r) + _, _, err := td.RpzAxfrOut(w, r) if err != nil { lg.Printf("RpzResponder: error from RpzAxfrOut() serving zone %s: %v", zd.ZoneName, err) } - td.mu.Lock() - td.Downstreams.Serial = serial - td.mu.Unlock() + return nil + case dns.TypeIXFR: lg.Printf("RpzResponder: %s is our RPZ output", td.Rpz.ZoneName) @@ -172,8 +168,9 @@ func (td *TemData) RpzResponder(w dns.ResponseWriter, r *dns.Msg, qtype uint16, if err != nil { lg.Printf("RpzResponder: error from RpzIxfrOut() serving zone %s: %v", zd.ZoneName, err) } + td.mu.Lock() - td.Downstreams.Serial = serial + td.DownstreamSerials[downstream] = serial // track the highest known serial for each downstream td.mu.Unlock() return nil case dns.TypeSOA: diff --git a/go.mod b/go.mod index e764e76..4e5b1a9 100644 --- a/go.mod +++ b/go.mod @@ -10,13 +10,14 @@ require ( github.com/gorilla/mux v1.8.1 github.com/mattn/go-sqlite3 v1.14.22 github.com/miekg/dns v1.1.59 + github.com/ryanuber/columnize v2.1.2+incompatible github.com/smhanov/dawg v0.0.0-20220118194912-66057bdbf2e3 github.com/spf13/viper v1.18.2 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/Pashugan/trie v0.0.0-20230121015024-96f8fcbb2af1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/eclipse/paho.golang v0.20.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect @@ -35,7 +36,6 @@ require ( github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect - github.com/ryanuber/columnize v2.1.2+incompatible // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/segmentio/asm v1.2.0 // indirect @@ -56,5 +56,4 @@ require ( golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.19.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 2d2cd89..d9fa4ff 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/Pashugan/trie v0.0.0-20230121015024-96f8fcbb2af1 h1:acggebuZWlHgNzQWwdp88oRtlw8xQBlylvd/yUxxG54= -github.com/Pashugan/trie v0.0.0-20230121015024-96f8fcbb2af1/go.mod h1:pTsSapvqi2cR2qzfSwkawuSAW8vspP2C/cJ5yEwOrdU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -23,10 +21,6 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.18.0 h1:BvolUXjp4zuvkZ5YN5t7ebzbhlUtPsPm2S9NAZ5nl9U= -github.com/go-playground/validator/v10 v10.18.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= -github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4= -github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= @@ -47,14 +41,10 @@ github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= -github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8= -github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= github.com/lestrrat-go/httprc v1.0.5 h1:bsTfiH8xaKOJPrg1R+E3iE/AWZr/x0Phj9PBTG/OLUk= github.com/lestrrat-go/httprc v1.0.5/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= -github.com/lestrrat-go/jwx/v2 v2.0.19 h1:ekv1qEZE6BVct89QA+pRF6+4pCpfVrOnEJnTnT4RXoY= -github.com/lestrrat-go/jwx/v2 v2.0.19/go.mod h1:l3im3coce1lL2cDeAjqmaR+Awx+X8Ih+2k8BuHNJ4CU= github.com/lestrrat-go/jwx/v2 v2.0.21 h1:jAPKupy4uHgrHFEdjVjNkUgoBKtVDgrQPB/h55FHrR0= github.com/lestrrat-go/jwx/v2 v2.0.21/go.mod h1:09mLW8zto6bWL9GbwnqAli+ArLf+5M33QLQPDggkUWM= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= @@ -63,8 +53,6 @@ github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0V github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4= -github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY= github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= @@ -105,9 +93,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/twotwotwo/sorts v0.0.0-20160814051341-bf5c1f2b8553 h1:DRC1ubdb3ZmyyIeCSTxjZIQAnpLPfKVgYrLETQuOPjo= @@ -118,8 +106,6 @@ go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -131,14 +117,10 @@ golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+o golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.3.1-0.20200828183125-ce943fd02449/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -148,8 +130,6 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -157,8 +137,6 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= -golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/main.go b/main.go index ec3a480..7dc7f79 100644 --- a/main.go +++ b/main.go @@ -38,12 +38,14 @@ func (td *TemData) SaveRpzSerial() error { if serialFile == "" { log.Fatalf("TEMExiter:No serial cache file specified") } - serialData := []byte(fmt.Sprintf("%d", td.Rpz.CurrentSerial)) - err := os.WriteFile(serialFile, serialData, 0644) + // serialData := []byte(fmt.Sprintf("%d", td.Rpz.CurrentSerial)) + // err := os.WriteFile(serialFile, serialData, 0644) + serialYaml := fmt.Sprintf("current_serial: %d\n", td.Rpz.CurrentSerial) + err := os.WriteFile(serialFile, []byte(serialYaml), 0644) if err != nil { - log.Printf("Error writing current serial to file: %v", err) + log.Printf("Error writing YAML serial to file: %v", err) } else { - log.Printf("Saved current serial %d to file %s", td.Downstreams.Serial, serialFile) + log.Printf("Saved current serial %d to file %s", td.Rpz.CurrentSerial, serialFile) } return err } @@ -177,10 +179,13 @@ func main() { TEMExiter("Error from NewTemData: %v", err) } go td.RefreshEngine(&conf, stopch) - err = td.ParseSources() + + log.Println("*** main: Calling ParseSourcesNG()") + err = td.ParseSourcesNG() if err != nil { - TEMExiter("Error from ParseSources: %v", err) + TEMExiter("Error from ParseSourcesNG: %v", err) } + log.Println("*** main: Returned from ParseSourcesNG()") err = td.ParseOutputs() if err != nil { diff --git a/mqtt.go b/mqtt.go index 7a77d38..e235815 100644 --- a/mqtt.go +++ b/mqtt.go @@ -85,27 +85,28 @@ func (td *TemData) ProcessTapirUpdate(tpkg tapir.MqttPkg) (bool, error) { return false, fmt.Errorf("MQTT Source %s is unknown, update rejected", tpkg.Data.SrcName) } - for _, name := range tpkg.Data.Added { - ttl := time.Duration(name.TTL) * time.Second + for _, tname := range tpkg.Data.Added { + ttl := time.Duration(tname.TTL) * time.Second tmp := tapir.TapirName{ - Name: name.Name, - TimeAdded: name.TimeAdded, + Name: tname.Name, + TimeAdded: tname.TimeAdded, TTL: ttl, - TagMask: name.TagMask, + TagMask: tname.TagMask, } - wbgl.Names[name.Name] = &tmp + wbgl.Names[tname.Name] = tmp td.Logger.Printf("ProcessTapirUpdate: adding name %s to %s (TimeAdded: %s ttl: %v)", - name.Name, wbgl.Name, name.TimeAdded.Format(tapir.TimeLayout), name.TTL) + tname.Name, wbgl.Name, tname.TimeAdded.Format(tapir.TimeLayout), tname.TTL) // Time that the name will be removed from the list - reptime := name.TimeAdded.Add(ttl).Truncate(td.ReaperInterval) + // must ensure that reapertime is at least ReaperInterval into the future + reptime := tname.TimeAdded.Add(ttl).Truncate(td.ReaperInterval).Add(td.ReaperInterval) // Ensure that there are no prior removal events for this name for reaperTime, namesMap := range wbgl.ReaperData { if reaperTime.Before(reptime) { - if _, exists := namesMap[name.Name]; exists { - delete(namesMap, name.Name) + if _, exists := namesMap[tname.Name]; exists { + delete(namesMap, tname.Name) if len(namesMap) == 0 { delete(wbgl.ReaperData, reaperTime) } @@ -115,18 +116,17 @@ func (td *TemData) ProcessTapirUpdate(tpkg tapir.MqttPkg) (bool, error) { // Add the name to the removal list for the time it will be removed if wbgl.ReaperData[reptime] == nil { - wbgl.ReaperData[reptime] = make(map[string]*tapir.TapirName) + wbgl.ReaperData[reptime] = make(map[string]bool) } - wbgl.ReaperData[reptime][name.Name] = &tmp + wbgl.ReaperData[reptime][tname.Name] = true } - td.Logger.Printf("ProcessTapirUpdate: current state of %s %s ReaperData:", - tpkg.Data.ListType, wbgl.Name) + td.Logger.Printf("ProcessTapirUpdate: current state of %s %s ReaperData:", tpkg.Data.ListType, wbgl.Name) for t, v := range wbgl.ReaperData { if len(v) > 0 { td.Logger.Printf("== At time %s the following names will be removed from the dns-tapir list:", t.Format(tapir.TimeLayout)) - for _, item := range v { - td.Logger.Printf(" %s", item.Name) + for name := range v { + td.Logger.Printf(" %s", name) } } else { td.Logger.Printf("ReaperData: timekey %s is empty, deleting", t.Format(tapir.TimeLayout)) @@ -134,8 +134,8 @@ func (td *TemData) ProcessTapirUpdate(tpkg tapir.MqttPkg) (bool, error) { } } - for _, name := range tpkg.Data.Removed { - delete(wbgl.Names, name.Name) + for _, tname := range tpkg.Data.Removed { + delete(wbgl.Names, tname.Name) } ixfr, err := td.GenerateRpzIxfr(&tpkg.Data) diff --git a/policy.go b/policy.go new file mode 100644 index 0000000..d1bb12e --- /dev/null +++ b/policy.go @@ -0,0 +1,185 @@ +/* + * Copyright (c) DNS TAPIR + */ + +package main + +import ( + "log" + "net" + "os" + "strconv" + "strings" + + "github.com/dnstapir/tapir" + "github.com/spf13/viper" + "gopkg.in/yaml.v3" +) + +type TemOutput struct { + Active bool + Name string + Description string + Type string // listtype, usually "greylist" + Format string // i.e. rpz, etc + Downstream string +} + +type TemOutputs struct { + Outputs map[string]TemOutput +} + +func (td *TemData) ParseOutputs() error { + td.Logger.Printf("ParseOutputs: reading outputs from %s", tapir.TemOutputsCfgFile) + cfgdata, err := os.ReadFile(tapir.TemOutputsCfgFile) + if err != nil { + log.Fatalf("Error from ReadFile(%s): %v", tapir.TemOutputsCfgFile, err) + } + + var oconf = TemOutputs{ + Outputs: make(map[string]TemOutput), + } + + // td.Logger.Printf("ParseOutputs: config read: %s", cfgdata) + err = yaml.Unmarshal(cfgdata, &oconf) + if err != nil { + log.Fatalf("Error from yaml.Unmarshal(OutputsConfig): %v", err) + } + + td.Logger.Printf("ParseOutputs: found %d outputs", len(oconf.Outputs)) + for name, v := range oconf.Outputs { + td.Logger.Printf("ParseOutputs: output %s: type %s, format %s, downstream %s", + name, v.Type, v.Format, v.Downstream) + } + + for name, output := range oconf.Outputs { + if output.Active && strings.ToLower(output.Format) == "rpz" { + td.Logger.Printf("Output %s: Adding RPZ downstream %s to list of Notify receivers", name, output.Downstream) + addr, port, err := net.SplitHostPort(output.Downstream) + if err != nil { + td.Logger.Printf("Invalid downstream address %s: %v", output.Downstream, err) + continue + } + if net.ParseIP(addr) == nil { + td.Logger.Printf("Invalid IP address %s", addr) + continue + } + portInt, err := strconv.Atoi(port) + if err != nil { + td.Logger.Printf("Invalid port %s: %v", port, err) + continue + } + td.Downstreams[addr] = RpzDownstream{Address: addr, Port: portInt} + } + } + // Read the current value of td.Downstreams.Serial from a text file + serialFile := viper.GetString("output.rpz.serialcache") + + if serialFile != "" { + serialData, err := os.ReadFile(serialFile) + if err != nil { + td.Logger.Printf("Error reading serial from file %s: %v", serialFile, err) + td.Rpz.CurrentSerial = 1 + } else { + var serialYaml struct { + CurrentSerial uint32 `yaml:"current_serial"` + } + err = yaml.Unmarshal(serialData, &serialYaml) + if err != nil { + td.Logger.Printf("Error unmarshalling YAML serial data: %v", err) + td.Rpz.CurrentSerial = 1 + } else { + td.Rpz.CurrentSerial = serialYaml.CurrentSerial + td.Logger.Printf("Loaded serial %d from file %s", td.Rpz.CurrentSerial, serialFile) + } + } + } else { + td.Logger.Printf("No serial cache file specified, starting serial at 1") + td.Rpz.CurrentSerial = 1 + } + // td.Rpz.CurrentSerial = td.Downstreams.Serial + return nil +} + +// Note: we onlygethere when we know that this name is only greylisted +// so no need tocheckfor white- or blacklisting +func (td *TemData) ComputeRpzGreylistAction(name string) tapir.Action { + + var greyHits = map[string]*tapir.TapirName{} + for listname, list := range td.Lists["greylist"] { + switch list.Format { + case "map": + if v, exists := list.Names[name]; exists { + // td.Logger.Printf("ComputeRpzGreylistAction: found %s in greylist %s (%d names)", + // name, listname, len(list.Names)) + greyHits[listname] = &v + } + // case "trie": + // if list.Trie.Search(name) != nil { + // greyHits = append(greyHits, v) + // } + default: + TEMExiter("Unknown greylist format %s", list.Format) + } + } + if len(greyHits) >= td.Policy.Greylist.NumSources { + td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s is in %d or more sources, action is %s", + name, td.Policy.Greylist.NumSources, tapir.ActionToString[td.Policy.Greylist.NumSourcesAction]) + return td.Policy.Greylist.NumSourcesAction + } + td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s is in %d sources, not enough for action", name, len(greyHits)) + + if _, exists := greyHits["dns-tapir"]; exists { + numtapirtags := greyHits["dns-tapir"].TagMask.NumTags() + if numtapirtags >= td.Policy.Greylist.NumTapirTags { + td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s has more than %d tapir tags, action is %s", + name, td.Policy.Greylist.NumTapirTags, tapir.ActionToString[td.Policy.Greylist.NumTapirTagsAction]) + return td.Policy.Greylist.NumTapirTagsAction + } + td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s has %d tapir tags, not enough for action", name, numtapirtags) + } + td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s is present in %d greylists, but does not trigger any action", + name, len(greyHits)) + return td.Policy.WhitelistAction +} + +// Decision to block a greylisted name: +// 1. More than N tags present +// 2. Name is present in more than M sources +// 3. Name + +func ApplyGreyPolicy(name string, v *tapir.TapirName) string { + var rpzaction string + if v.HasAction(tapir.NXDOMAIN) { + rpzaction = "." + } else if v.HasAction(tapir.NODATA) { + rpzaction = "*." + } else if v.HasAction(tapir.DROP) { + rpzaction = "rpz-drop." + } else if v.TagMask != 0 { + log.Printf("there are tags") + rpzaction = "rpz-drop." + } + + return rpzaction +} + +func (td *TemData) ComputeRpzAction(name string) tapir.Action { + if td.Whitelisted(name) { + if td.Debug { + td.Policy.Logger.Printf("ComputeRpzAction: name %s is whitelisted, action is %s", name, tapir.ActionToString[td.Policy.WhitelistAction]) + } + return td.Policy.WhitelistAction + } else if td.Blacklisted(name) { + if td.Debug { + td.Policy.Logger.Printf("ComputeRpzAction: name %s is blacklisted, action is %s", name, tapir.ActionToString[td.Policy.BlacklistAction]) + } + return td.Policy.BlacklistAction + } else if td.Greylisted(name) { + if td.Debug { + td.Policy.Logger.Printf("ComputeRpzAction: name %s is greylisted, needs further evaluation to determine action", name) + } + return td.ComputeRpzGreylistAction(name) // This is not complete, only a placeholder for now. + } + return tapir.WHITELIST +} diff --git a/reaper.go b/reaper.go new file mode 100644 index 0000000..f2c8e12 --- /dev/null +++ b/reaper.go @@ -0,0 +1,78 @@ +/* + * Copyright (c) DNS TAPIR + */ +package main + +import ( + "time" + + "github.com/dnstapir/tapir" +) + +// type WBGC map[string]*tapir.WBGlist + +// 1. Iterate over all lists +// 2. Delete all items from the list that is in the ReaperData bucket for this time slot +// 3. Delete the bucket from the ReaperData map +// 4. Generate a new IXFR for the deleted items +// 5. Send the IXFR to the RPZ +func (td *TemData) Reaper(full bool) error { + timekey := time.Now().Truncate(td.ReaperInterval) + tpkg := tapir.MqttPkg{} + td.Logger.Printf("Reaper: working on time slot %s across all lists", timekey.Format(tapir.TimeLayout)) + for _, listtype := range []string{"whitelist", "greylist", "blacklist"} { + for listname, wbgl := range td.Lists[listtype] { + // This loop is here to ensure that we don't have any old data in the ReaperData bucket + // that has already passed its time slot. + for t, d := range wbgl.ReaperData { + if t.Before(timekey) { + if len(d) == 0 { + continue + } + + td.Logger.Printf("Reaper: Warning: found old reaperdata for time slot %s (that has already passed). Moving %d names to current time slot (%s)", t.Format(tapir.TimeLayout), len(d), timekey.Format(tapir.TimeLayout)) + td.mu.Lock() + if _, exist := wbgl.ReaperData[timekey]; !exist { + wbgl.ReaperData[timekey] = map[string]bool{} + } + for name, _ := range d { + wbgl.ReaperData[timekey][name] = true + } + // wbgl.ReaperData[timekey] = d + delete(wbgl.ReaperData, t) + td.mu.Unlock() + } + } + // td.Logger.Printf("Reaper: working on %s %s", listtype, listname) + if len(wbgl.ReaperData[timekey]) > 0 { + td.Logger.Printf("Reaper: list [%s][%s] has %d timekeys stored", listtype, listname, + len(wbgl.ReaperData[timekey])) + td.mu.Lock() + for name, _ := range wbgl.ReaperData[timekey] { + td.Logger.Printf("Reaper: removing %s from %s %s", name, listtype, listname) + delete(td.Lists[listtype][listname].Names, name) + delete(wbgl.ReaperData[timekey], name) + tpkg.Data.Removed = append(tpkg.Data.Removed, tapir.Domain{Name: name}) + } + // td.Logger.Printf("Reaper: %s %s now has %d items:", listtype, listname, len(td.Lists[listtype][listname].Names)) + // for name, item := range td.Lists[listtype][listname].Names { + // td.Logger.Printf("Reaper: remaining: key: %s name: %s", name, item.Name) + // } + delete(wbgl.ReaperData, timekey) + td.mu.Unlock() + } + } + } + + if len(tpkg.Data.Removed) > 0 { + ixfr, err := td.GenerateRpzIxfr(&tpkg.Data) + if err != nil { + td.Logger.Printf("Reaper: Error from GenerateRpzIxfr(): %v", err) + } + err = td.ProcessIxfrIntoAxfr(ixfr) + if err != nil { + td.Logger.Printf("Reaper: Error from ProcessIxfrIntoAxfr(): %v", err) + } + } + return nil +} diff --git a/refreshengine.go b/refreshengine.go index e9d422b..b30d33d 100644 --- a/refreshengine.go +++ b/refreshengine.go @@ -6,6 +6,8 @@ package main import ( "fmt" "log" + "net" + "strconv" "strings" "time" @@ -88,10 +90,20 @@ func (td *TemData) RefreshEngine(conf *Config, stopch chan struct{}) { for { select { case tpkg = <-TapirIntelCh: - log.Printf("RefreshEngine: Tapir IntelUpdate: (src: %s) %d additions and %d removals\n", - tpkg.Data.SrcName, len(tpkg.Data.Added), len(tpkg.Data.Removed)) - td.ProcessTapirUpdate(tpkg) - log.Printf("RefreshEngine: Tapir IntelUpdate evaluated.") + switch tpkg.Data.MsgType { + case "intel-update", "observation": + log.Printf("RefreshEngine: Tapir Observation update: (src: %s) %d additions and %d removals\n", + tpkg.Data.SrcName, len(tpkg.Data.Added), len(tpkg.Data.Removed)) + td.ProcessTapirUpdate(tpkg) + log.Printf("RefreshEngine: Tapir Observation update evaluated.") + + case "global-config": + td.ProcessTapirGlobalConfig(tpkg.Data) + log.Printf("RefreshEngine: Tapir Global Config evaluated.") + + default: + log.Printf("RefreshEngine: Tapir IntelUpdate: unknown msg type: %s", tpkg.Data.MsgType) + } case zr = <-zonerefch: zone = zr.Name @@ -376,25 +388,26 @@ func (td *TemData) RefreshEngine(conf *Config, stopch chan struct{}) { } func (td *TemData) NotifyDownstreams() error { - td.Logger.Printf("RefreshEngine: Notifying %d downstreams for RPZ zone %s", len(td.Downstreams.Downstreams), td.Rpz.ZoneName) - for _, d := range td.Downstreams.Downstreams { + td.Logger.Printf("RefreshEngine: Notifying %d downstreams for RPZ zone %s", len(td.Downstreams), td.Rpz.ZoneName) + for _, d := range td.Downstreams { m := new(dns.Msg) m.SetNotify(td.Rpz.ZoneName) td.Rpz.Axfr.SOA.Serial = td.Rpz.CurrentSerial m.Ns = append(m.Ns, dns.RR(&td.Rpz.Axfr.SOA)) - td.Logger.Printf("RefreshEngine: Notifying downstream %s about new SOA serial (%d) for RPZ zone %s", d, td.Rpz.Axfr.SOA.Serial, td.Rpz.ZoneName) - r, err := dns.Exchange(m, d) + dest := net.JoinHostPort(d.Address, strconv.Itoa(d.Port)) + td.Logger.Printf("RefreshEngine: Notifying downstream %s about new SOA serial (%d) for RPZ zone %s", dest, td.Rpz.Axfr.SOA.Serial, td.Rpz.ZoneName) + r, err := dns.Exchange(m, dest) if err != nil { // well, we tried - td.Logger.Printf("Error from downstream %s on Notify(%s): %v", d, td.Rpz.ZoneName, err) + td.Logger.Printf("Error from downstream %s on Notify(%s): %v", dest, td.Rpz.ZoneName, err) continue } if r.Opcode != dns.OpcodeNotify { // well, we tried td.Logger.Printf("Error: not a NOTIFY QR from downstream %s on Notify(%s): %s", - d, td.Rpz.ZoneName, dns.OpcodeToString[r.Opcode]) + dest, td.Rpz.ZoneName, dns.OpcodeToString[r.Opcode]) } else { - td.Logger.Printf("RefreshEngine: Downstream %s responded correctly to Notify(%s) about new SOA serial (%d)", d, td.Rpz.ZoneName, td.Rpz.Axfr.SOA.Serial) + td.Logger.Printf("RefreshEngine: Downstream %s responded correctly to Notify(%s) about new SOA serial (%d)", dest, td.Rpz.ZoneName, td.Rpz.Axfr.SOA.Serial) } } return nil diff --git a/output.go b/rpz.go similarity index 64% rename from output.go rename to rpz.go index e64cc90..521e062 100644 --- a/output.go +++ b/rpz.go @@ -5,85 +5,10 @@ package main import ( - "log" - "os" - "strconv" - "strings" - "github.com/dnstapir/tapir" "github.com/miekg/dns" - "github.com/spf13/viper" - "gopkg.in/yaml.v3" ) -type TemOutput struct { - Active bool - Name string - Description string - Type string // listtype, usually "greylist" - Format string // i.e. rpz, etc - Downstream string -} - -type TemOutputs struct { - Outputs map[string]TemOutput -} - -func (td *TemData) ParseOutputs() error { - td.Logger.Printf("ParseOutputs: reading outputs from %s", tapir.TemOutputsCfgFile) - cfgdata, err := os.ReadFile(tapir.TemOutputsCfgFile) - if err != nil { - log.Fatalf("Error from ReadFile(%s): %v", tapir.TemOutputsCfgFile, err) - } - - var oconf = TemOutputs{ - Outputs: make(map[string]TemOutput), - } - - td.Logger.Printf("ParseOutputs: config read: %s", cfgdata) - err = yaml.Unmarshal(cfgdata, &oconf) - if err != nil { - log.Fatalf("Error from yaml.Unmarshal(OutputsConfig): %v", err) - } - - td.Logger.Printf("ParseOutputs: found %d outputs", len(oconf.Outputs)) - for name, v := range oconf.Outputs { - td.Logger.Printf("ParseOutputs: output %s: type %s, format %s, downstream %s", - name, v.Type, v.Format, v.Downstream) - } - - for name, output := range oconf.Outputs { - if output.Active && strings.ToLower(output.Format) == "rpz" { - td.Logger.Printf("Output %s: Adding RPZ downstream %s to list of Notify receivers", name, output.Downstream) - td.Downstreams.Downstreams = append(td.Downstreams.Downstreams, output.Downstream) - } - } - // Read the current value of td.Downstreams.Serial from a text file - serialFile := viper.GetString("output.rpz.serialcache") - - if serialFile != "" { - serialData, err := os.ReadFile(serialFile) - if err != nil { - td.Logger.Printf("Error reading serial from file %s: %v", serialFile, err) - td.Downstreams.Serial = 1 - } else { - tmp := strings.Replace(string(serialData), "\n", "", -1) - serial, err := strconv.Atoi(tmp) - if err != nil { - td.Logger.Printf("Error converting serial data to integer: %v", err) - } else { - td.Downstreams.Serial = uint32(serial) - td.Logger.Printf("Loaded serial %d from file %s", td.Downstreams.Serial, serialFile) - } - } - } else { - td.Logger.Printf("No serial cache file specified, starting serial at 1") - td.Downstreams.Serial = 1 - } - td.Rpz.CurrentSerial = td.Downstreams.Serial - return nil -} - // XXX: Generating a complete new RPZ zone for output to downstream // Generate the RPZ output based on the currently loaded sources. @@ -155,7 +80,7 @@ func (td *TemData) GenerateRpzAxfr() error { tmp.Action = tmp.Action | v.Action grey[k] = tmp } else { - grey[k] = v + grey[k] = &v } } } @@ -168,7 +93,7 @@ func (td *TemData) GenerateRpzAxfr() error { td.Logger.Printf("GenRpzAxfr: There are a total of %d greylisted names in the sources", len(grey)) newaxfrdata := []*tapir.RpzName{} - td.Rpz.RpzMap = map[string]*tapir.RpzName{} + // td.Rpz.RpzMap = map[string]*tapir.RpzName{} for name, _ := range td.BlacklistedNames { cname := new(dns.CNAME) cname.Hdr = dns.RR_Header{ @@ -227,89 +152,6 @@ func (td *TemData) GenerateRpzAxfr() error { return err } -// Decision to block a greylisted name: -// 1. More than N tags present -// 2. Name is present in more than M sources -// 3. Name - -func ApplyGreyPolicy(name string, v *tapir.TapirName) string { - var rpzaction string - if v.HasAction(tapir.NXDOMAIN) { - rpzaction = "." - } else if v.HasAction(tapir.NODATA) { - rpzaction = "*." - } else if v.HasAction(tapir.DROP) { - rpzaction = "rpz-drop." - } else if v.TagMask != 0 { - log.Printf("there are tags") - rpzaction = "rpz-drop." - } - - return rpzaction -} - -// Note: we onlygethere when we know that this name is only greylisted -// so no need tocheckfor white- or blacklisting -func (td *TemData) ComputeRpzGreylistAction(name string) tapir.Action { - - var greyHits = map[string]*tapir.TapirName{} - for listname, list := range td.Lists["greylist"] { - switch list.Format { - case "map": - if v, exists := list.Names[name]; exists { - // td.Logger.Printf("ComputeRpzGreylistAction: found %s in greylist %s (%d names)", - // name, listname, len(list.Names)) - greyHits[listname] = v - } - // case "trie": - // if list.Trie.Search(name) != nil { - // greyHits = append(greyHits, v) - // } - default: - TEMExiter("Unknown greylist format %s", list.Format) - } - } - if len(greyHits) >= td.Policy.Greylist.NumSources { - td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s is in %d or more sources, action is %s", - name, td.Policy.Greylist.NumSources, tapir.ActionToString[td.Policy.Greylist.NumSourcesAction]) - return td.Policy.Greylist.NumSourcesAction - } - td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s is in %d sources, not enough for action", name, len(greyHits)) - - if _, exists := greyHits["dns-tapir"]; exists { - numtapirtags := greyHits["dns-tapir"].TagMask.NumTags() - if numtapirtags >= td.Policy.Greylist.NumTapirTags { - td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s has more than %d tapir tags, action is %s", - name, td.Policy.Greylist.NumTapirTags, tapir.ActionToString[td.Policy.Greylist.NumTapirTagsAction]) - return td.Policy.Greylist.NumTapirTagsAction - } - td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s has %d tapir tags, not enough for action", name, numtapirtags) - } - td.Policy.Logger.Printf("ComputeRpzGreylistAction: name %s is present in %d greylists, but does not trigger any action", - name, len(greyHits)) - return td.Policy.WhitelistAction -} - -func (td *TemData) ComputeRpzAction(name string) tapir.Action { - if td.Whitelisted(name) { - if td.Debug { - td.Policy.Logger.Printf("ComputeRpzAction: name %s is whitelisted, action is %s", name, tapir.ActionToString[td.Policy.WhitelistAction]) - } - return td.Policy.WhitelistAction - } else if td.Blacklisted(name) { - if td.Debug { - td.Policy.Logger.Printf("ComputeRpzAction: name %s is blacklisted, action is %s", name, tapir.ActionToString[td.Policy.BlacklistAction]) - } - return td.Policy.BlacklistAction - } else if td.Greylisted(name) { - if td.Debug { - td.Policy.Logger.Printf("ComputeRpzAction: name %s is greylisted, needs further evaluation to determine action", name) - } - return td.ComputeRpzGreylistAction(name) // This is not complete, only a placeholder for now. - } - return tapir.WHITELIST -} - // Generate the RPZ representation of the names in the TapirMsg combined with the currently loaded sources. // The output is a []dns.RR with the additions and removals, but without the IXFR SOA serial magic. // Algorithm: diff --git a/sources.go b/sources.go index d405646..2a8909b 100644 --- a/sources.go +++ b/sources.go @@ -4,90 +4,19 @@ package main import ( - "crypto/ecdsa" "fmt" "log" "os" "strings" - "sync" "time" "github.com/dnstapir/tapir" "github.com/miekg/dns" "github.com/smhanov/dawg" "github.com/spf13/viper" + "gopkg.in/yaml.v3" ) -type TemData struct { - mu sync.RWMutex - Lists map[string]map[string]*tapir.WBGlist - RpzRefreshCh chan RpzRefresh - RpzCommandCh chan RpzCmdData - TapirMqttEngineRunning bool - TapirMqttCmdCh chan tapir.MqttEngineCmd - TapirMqttSubCh chan tapir.MqttPkg - TapirMqttPubCh chan tapir.MqttPkg // not used ATM - Logger *log.Logger - MqttLogger *log.Logger - BlacklistedNames map[string]bool - GreylistedNames map[string]*tapir.TapirName - Policy TemPolicy - Rpz RpzData - RpzSources map[string]*tapir.ZoneData - Downstreams RpzDownstream - ReaperInterval time.Duration - MqttEngine *tapir.MqttEngine - Verbose bool - Debug bool -} - -type RpzDownstream struct { - Serial uint32 // Must track the current RPZ serial in the resolver - Downstreams []string -} - -type RpzData struct { - CurrentSerial uint32 - ZoneName string - Axfr RpzAxfr - IxfrChain []RpzIxfr // NOTE: the IxfrChain is in reverse order, newest first! - RpzZone *tapir.ZoneData - RpzMap map[string]*tapir.RpzName -} - -type RpzIxfr struct { - FromSerial uint32 - ToSerial uint32 - Removed []*tapir.RpzName - Added []*tapir.RpzName -} - -type RpzAxfr struct { - Serial uint32 - SOA dns.SOA - NSrrs []dns.RR - Data map[string]*tapir.RpzName - ZoneData *tapir.ZoneData -} - -type TemPolicy struct { - Logger *log.Logger - WhitelistAction tapir.Action - BlacklistAction tapir.Action - Greylist GreylistPolicy -} - -type GreylistPolicy struct { - NumSources int - NumSourcesAction tapir.Action - NumTapirTags int - NumTapirTagsAction tapir.Action - BlackTapirTags tapir.TagMask - BlackTapirAction tapir.Action -} - -type WBGC map[string]*tapir.WBGlist - func NewTemData(conf *Config, lg *log.Logger) (*TemData, error) { rpzdata := RpzData{ CurrentSerial: 1, @@ -96,7 +25,7 @@ func NewTemData(conf *Config, lg *log.Logger) (*TemData, error) { Axfr: RpzAxfr{ Data: map[string]*tapir.RpzName{}, }, - RpzMap: map[string]*tapir.RpzName{}, + // RpzMap: map[string]*tapir.RpzName{}, } repint := viper.GetInt("output.reaper.interval") @@ -107,7 +36,7 @@ func NewTemData(conf *Config, lg *log.Logger) (*TemData, error) { td := TemData{ Lists: map[string]map[string]*tapir.WBGlist{}, Logger: lg, - MqttLogger: conf.Loggers.Mqtt, + MqttLogger: conf.Loggers.Mqtt, RpzRefreshCh: make(chan RpzRefresh, 10), RpzCommandCh: make(chan RpzCmdData, 10), Rpz: rpzdata, @@ -119,6 +48,7 @@ func NewTemData(conf *Config, lg *log.Logger) (*TemData, error) { td.Lists["whitelist"] = make(map[string]*tapir.WBGlist, 3) td.Lists["greylist"] = make(map[string]*tapir.WBGlist, 3) td.Lists["blacklist"] = make(map[string]*tapir.WBGlist, 3) + td.Downstreams = map[string]RpzDownstream{} err := td.ParseOutputs() if err != nil { @@ -144,6 +74,7 @@ func NewTemData(conf *Config, lg *log.Logger) (*TemData, error) { } td.Policy.Greylist.NumSources = viper.GetInt("policy.greylist.numsources.limit") if td.Policy.Greylist.NumSources == 0 { + //nolint:typecheck TEMExiter("Error parsing policy: greylist.numsources.limit cannot be 0") } td.Policy.Greylist.NumSourcesAction, err = @@ -178,53 +109,22 @@ func NewTemData(conf *Config, lg *log.Logger) (*TemData, error) { return &td, nil } -// 1. Iterate over all lists -// 2. Delete all items from the list that is in the ReaperData bucket for this time slot -// 3. Delete the bucket from the ReaperData map -// 4. Generate a new IXFR for the deleted items -// 5. Send the IXFR to the RPZ -func (td *TemData) Reaper(full bool) error { - timekey := time.Now().Round(td.ReaperInterval) - tpkg := tapir.MqttPkg{} - td.Logger.Printf("Reaper: working on time slot %v across all lists", timekey) - for _, listtype := range []string{"whitelist", "greylist", "blacklist"} { - for listname, wbgl := range td.Lists[listtype] { - // td.Logger.Printf("Reaper: working on %s %s", listtype, listname) - if len(wbgl.ReaperData[timekey]) > 0 { - td.Logger.Printf("Reaper: list [%s][%s] has %d timekeys stored", listtype, listname, - len(wbgl.ReaperData[timekey])) - td.mu.Lock() - for _, item := range wbgl.ReaperData[timekey] { - td.Logger.Printf("Reaper: removing %s from %s %s", item.Name, listtype, listname) - delete(td.Lists[listtype][listname].Names, item.Name) - delete(wbgl.ReaperData[timekey], item.Name) - tpkg.Data.Removed = append(tpkg.Data.Removed, tapir.Domain{Name: item.Name}) - } - // td.Logger.Printf("Reaper: %s %s now has %d items:", listtype, listname, len(td.Lists[listtype][listname].Names)) - // for name, item := range td.Lists[listtype][listname].Names { - // td.Logger.Printf("Reaper: remaining: key: %s name: %s", name, item.Name) - // } - td.mu.Unlock() - } - } +func (td *TemData) ParseSourcesNG() error { + var srcfoo SrcFoo + configFile := tapir.TemSourcesCfgFile + data, err := os.ReadFile(configFile) + if err != nil { + return fmt.Errorf("error reading config file: %v", err) } - if len(tpkg.Data.Removed) > 0 { - ixfr, err := td.GenerateRpzIxfr(&tpkg.Data) - if err != nil { - td.Logger.Printf("Reaper: Error from GenerateRpzIxfr(): %v", err) - } - err = td.ProcessIxfrIntoAxfr(ixfr) - if err != nil { - td.Logger.Printf("Reaper: Error from ProcessIxfrIntoAxfr(): %v", err) - } + err = yaml.Unmarshal(data, &srcfoo) + if err != nil { + return fmt.Errorf("error unmarshalling YAML data: %v", err) } - return nil -} - -func (td *TemData) ParseSources() error { - sources := viper.GetStringSlice("sources.active") - log.Printf("Defined policy sources: %v", sources) + // log.Printf("ParseSourcesNG: Defined policy sources:\n") + // for name, src := range srcfoo.Sources { + // log.Printf(" %s: %s", name, src.Description) + // } td.mu.Lock() td.Lists["whitelist"]["white_catchall"] = @@ -235,8 +135,8 @@ func (td *TemData) ParseSources() error { SrcFormat: "none", Format: "map", Datasource: "Data misplaced in other sources", - Names: map[string]*tapir.TapirName{}, - ReaperData: map[time.Time]map[string]*tapir.TapirName{}, + Names: map[string]tapir.TapirName{}, + ReaperData: map[time.Time]map[string]bool{}, } td.Lists["greylist"]["grey_catchall"] = &tapir.WBGlist{ @@ -246,20 +146,17 @@ func (td *TemData) ParseSources() error { SrcFormat: "none", Format: "map", Datasource: "Data misplaced in other sources", - Names: map[string]*tapir.TapirName{}, - ReaperData: map[time.Time]map[string]*tapir.TapirName{}, + Names: map[string]tapir.TapirName{}, + ReaperData: map[time.Time]map[string]bool{}, } td.mu.Unlock() - srcs := viper.GetStringMap("sources") - td.Logger.Printf("*** ParseSources: there are %d items in spec.", len(srcs)) + srcs := srcfoo.Sources + td.Logger.Printf("*** ParseSourcesNG: there are %d items in spec.", len(srcs)) threads := 0 var rptchan = make(chan string, 5) - var mqttdetails = tapir.MqttDetails{ - ValidatorKeys: make(map[string]*ecdsa.PublicKey), - } if td.MqttEngine == nil { td.mu.Lock() @@ -271,105 +168,77 @@ func (td *TemData) ParseSources() error { } for name, src := range srcs { - switch src.(type) { - case map[string]any: - s := src.(map[string]any) - if _, exist := s["active"]; !exist { - td.Logger.Printf("*** Source \"%s\" is not active. Ignored.", name) - continue - } + if !*src.Active { + td.Logger.Printf("*** ParseSourcesNG: Source \"%s\" is not active. Ignored.", name) + continue + } + td.Logger.Printf("=== ParseSourcesNG: Source: %s (%s) will be used (list type %s)", name, src.Name, src.Type) - switch s["active"].(type) { - case bool: - if s["active"].(bool) == false { - td.Logger.Printf("*** Source \"%s\" is not active (%v). Ignored.", - name, s["active"]) - continue - } - default: - td.Logger.Printf("*** [should not happen] Source \"%s\" active key is of type %t. Ignored.", - name, s["active"]) - continue - } + var err error - td.Logger.Printf("=== Source: %s (%s) will be used (list type %s)", - name, s["name"], s["type"]) + threads++ - var params = map[string]string{} + go func(name string, src SourceConf, thread int) { + // defer func() { + //td.Logger.Printf("<--Thread %d: source \"%s\" (%s) is now complete. %d remaining", thread, name, src.Source, threads) + // }() + td.Logger.Printf("-->Thread %d: parsing source \"%s\" (source %s)", thread, name, src.Source) - for _, key := range []string{"name", "upstream", "filename", "zone", "topic", "validatorkey"} { - if tmp, ok := s[key].(string); ok { - params[key] = tmp - } else { - params[key] = "" - } - } - - threads++ newsource := tapir.WBGlist{ - Name: params["name"], - Description: s["description"].(string), - Type: s["type"].(string), - SrcFormat: s["format"].(string), - Datasource: s["source"].(string), - Names: map[string]*tapir.TapirName{}, - ReaperData: map[time.Time]map[string]*tapir.TapirName{}, - Filename: params["filename"], - RpzUpstream: params["upstream"], - RpzZoneName: dns.Fqdn(params["zone"]), + Name: src.Name, + Description: src.Description, + Type: src.Type, + SrcFormat: src.Format, + Datasource: src.Source, + Names: map[string]tapir.TapirName{}, + ReaperData: map[time.Time]map[string]bool{}, + Filename: src.Filename, + RpzUpstream: src.Upstream, + RpzZoneName: dns.Fqdn(src.Zone), } - if newsource.Datasource == "mqtt" { - key, err := tapir.FetchMqttValidatorKey(params["topic"], params["validatorkey"]) + switch src.Source { + case "mqtt": + td.Logger.Printf("ParseSourcesNG: Fetching MQTT validator key for topic %s", src.Topic) + valkey, err := tapir.FetchMqttValidatorKey(src.Topic, src.ValidatorKey) if err != nil { - td.Logger.Printf("ParseSources: Error fetching MQTT validator key for topic %s: %v", - params["topic"], err) - } else { - mqttdetails.ValidatorKeys[params["topic"]] = key + td.Logger.Printf("ParseSources: Error fetching MQTT validator key for topic %s: %v", src.Topic, err) } - } - var err error + err = td.MqttEngine.AddTopic(src.Topic, valkey) + if err != nil { + TEMExiter("Error adding topic %s to MQTT Engine: %v", src.Topic, err) + } - go func(name string, threads int) { - td.Logger.Printf("Thread: parsing source \"%s\"", name) - switch s["source"] { - case "mqtt": - err = td.MqttEngine.AddTopic(params["topic"], mqttdetails.ValidatorKeys[params["topic"]]) + newsource.Format = "map" // for now + if len(src.Bootstrap) > 0 { + td.Logger.Printf("ParseSourcesNG: The %s MQTT source has %d bootstrap servers: %v", src.Name, len(src.Bootstrap), src.Bootstrap) + tmp, err := td.BootstrapMqttSource(&newsource, src) if err != nil { - TEMExiter("Error adding topic %s to MQTT Engine: %v", params["topic"], err) + td.Logger.Printf("Error bootstrapping MQTT source %s: %v", src.Name, err) + } else { + newsource = *tmp } - - newsource.Format = "map" // for now - // td.Greylists[newsource.Name] = &newsource - td.mu.Lock() - td.Lists["greylist"][newsource.Name] = &newsource - td.Logger.Printf("Created list [greylist][%s]", newsource.Name) - td.mu.Unlock() - td.Logger.Printf("*** MQTT sources are only managed via RefreshEngine.") - rptchan <- name - case "file": - err = td.ParseLocalFile(name, &newsource, rptchan) - case "xfr": - err = td.ParseRpzFeed(name, &newsource, rptchan) - } - if err != nil { - log.Printf("Error parsing source %s (datasource %s): %v", - name, s["source"], err) } - }(name, threads) - - default: - td.Logger.Printf("*** ParseSources: Error: failed to parse source \"%s\": %v", - name, src) - } - } - - if td.MqttEngine != nil && !td.TapirMqttEngineRunning { - err := td.StartMqttEngine(td.MqttEngine) - if err != nil { - TEMExiter("Error starting MQTT Engine: %v", err) - } + td.mu.Lock() + td.Lists["greylist"][newsource.Name] = &newsource + td.Logger.Printf("Created list [greylist][%s]", newsource.Name) + td.mu.Unlock() + td.Logger.Printf("*** MQTT sources are only managed via RefreshEngine.") + rptchan <- name + case "file": + err = td.ParseLocalFile(name, &newsource, rptchan) + case "xfr": + err = td.ParseRpzFeed(name, &newsource, rptchan) + td.Logger.Printf("Thread %d: source \"%s\" (%s) now returned from ParseRpzFeed(). %d remaining", thread, name, threads) + default: + td.Logger.Printf("*** ParseSourcesNG: Error: unhandled source type %s", src.Source) + } + if err != nil { + log.Printf("Error parsing source %s (datasource %s): %v", + name, src.Source, err) + } + }(name, src, threads) } for { @@ -381,9 +250,16 @@ func (td *TemData) ParseSources() error { } } + if td.MqttEngine != nil && !td.TapirMqttEngineRunning { + err := td.StartMqttEngine(td.MqttEngine) + if err != nil { + TEMExiter("Error starting MQTT Engine: %v", err) + } + } + td.Logger.Printf("ParseSources: static sources done.") - err := td.GenerateRpzAxfr() + err = td.GenerateRpzAxfr() if err != nil { td.Logger.Printf("ParseSources: Error from GenerateRpzAxfr(): %v", err) } @@ -404,7 +280,7 @@ func (td *TemData) ParseLocalFile(sourceid string, s *tapir.WBGlist, rpt chan st switch s.SrcFormat { case "domains": - s.Names = map[string]*tapir.TapirName{} + s.Names = map[string]tapir.TapirName{} s.Format = "map" _, err := tapir.ParseText(s.Filename, s.Names, true) if err != nil { @@ -416,7 +292,7 @@ func (td *TemData) ParseLocalFile(sourceid string, s *tapir.WBGlist, rpt chan st } case "csv": - s.Names = map[string]*tapir.TapirName{} + s.Names = map[string]tapir.TapirName{} s.Format = "map" _, err := tapir.ParseCSV(s.Filename, s.Names, true) if err != nil { @@ -467,7 +343,7 @@ func (td *TemData) ParseRpzFeed(sourceid string, s *tapir.WBGlist, rpt chan stri sourceid) } - s.Names = map[string]*tapir.TapirName{} // must initialize + s.Names = map[string]tapir.TapirName{} // must initialize s.Format = "map" // s.RpzZoneName = dns.Fqdn(zone) // s.RpzUpstream = upstream @@ -483,12 +359,12 @@ func (td *TemData) ParseRpzFeed(sourceid string, s *tapir.WBGlist, rpt chan stri } <-reRpt - td.Logger.Printf("ParseRpzFeed: parsing RPZ %s complete", s.RpzZoneName) td.mu.Lock() td.Lists[s.Type][s.Name] = s td.mu.Unlock() rpt <- sourceid + td.Logger.Printf("ParseRpzFeed: parsing RPZ %s complete", s.RpzZoneName) return nil } @@ -532,13 +408,13 @@ func (td *TemData) RpzParseFuncFactory(s *tapir.WBGlist) func(*dns.RR, *tapir.Zo switch s.Type { case "whitelist": if action == tapir.WHITELIST { - s.Names[name] = &tapir.TapirName{Name: name} // drop all other actions + s.Names[name] = tapir.TapirName{Name: name} // drop all other actions } else { td.Logger.Printf("Warning: whitelist RPZ source %s has blacklisted name: %s", s.RpzZoneName, name) td.mu.Lock() td.Lists["greylist"]["grey_catchall"].Names[name] = - &tapir.TapirName{ + tapir.TapirName{ Name: name, Action: action, } // drop all other actions @@ -546,22 +422,22 @@ func (td *TemData) RpzParseFuncFactory(s *tapir.WBGlist) func(*dns.RR, *tapir.Zo } case "blacklist": if action != tapir.WHITELIST { - s.Names[name] = &tapir.TapirName{Name: name, Action: action} + s.Names[name] = tapir.TapirName{Name: name, Action: action} } else { td.Logger.Printf("Warning: blacklist RPZ source %s has whitelisted name: %s", s.RpzZoneName, name) td.mu.Lock() - td.Lists["whitelist"]["white_catchall"].Names[name] = &tapir.TapirName{Name: name} + td.Lists["whitelist"]["white_catchall"].Names[name] = tapir.TapirName{Name: name} td.mu.Unlock() } case "greylist": if action != tapir.WHITELIST { - s.Names[name] = &tapir.TapirName{Name: name, Action: action} + s.Names[name] = tapir.TapirName{Name: name, Action: action} } else { td.Logger.Printf("Warning: greylist RPZ source %s has whitelisted name: %s", s.RpzZoneName, name) td.mu.Lock() - td.Lists["whitelist"]["white_catchall"].Names[name] = &tapir.TapirName{Name: name} + td.Lists["whitelist"]["white_catchall"].Names[name] = tapir.TapirName{Name: name} td.mu.Unlock() } } @@ -569,14 +445,3 @@ func (td *TemData) RpzParseFuncFactory(s *tapir.WBGlist) func(*dns.RR, *tapir.Zo return true } } - -// Generate the RPZ output based on the currently loaded sources. -// The output is a tapir.ZoneData, but with only the RRs (i.e. a []dns.RR) populated. -// Output should consist of: -// 1. Walk all blacklists: -// a) remove any whitelisted names -// b) rest goes straight into output -// 2. Walk all greylists: -// a) collect complete grey data on each name -// b) remove any whitelisted name -// c) evalutate the grey data to make a decision on inclusion or not diff --git a/structs.go b/structs.go new file mode 100644 index 0000000..4f2ca97 --- /dev/null +++ b/structs.go @@ -0,0 +1,93 @@ +/* + * Copyright (c) DNS TAPIR + */ +package main + +import ( + "log" + "sync" + "time" + + "github.com/dnstapir/tapir" + "github.com/miekg/dns" +) + +type TemData struct { + mu sync.RWMutex + Lists map[string]map[string]*tapir.WBGlist + RpzRefreshCh chan RpzRefresh + RpzCommandCh chan RpzCmdData + TapirMqttEngineRunning bool + TapirMqttCmdCh chan tapir.MqttEngineCmd + TapirMqttSubCh chan tapir.MqttPkg + TapirMqttPubCh chan tapir.MqttPkg // not used ATM + Logger *log.Logger + MqttLogger *log.Logger + BlacklistedNames map[string]bool + GreylistedNames map[string]*tapir.TapirName + Policy TemPolicy + Rpz RpzData + RpzSources map[string]*tapir.ZoneData + Downstreams map[string]RpzDownstream // map[ipaddr]RpzDownstream + DownstreamSerials map[string]uint32 // New map to track SOA serials by address + ReaperInterval time.Duration + MqttEngine *tapir.MqttEngine + Verbose bool + Debug bool +} + +type RpzDownstream struct { + Address string + Port int + // Serial uint32 // The serial that the downstream says that it already has in the latest IXFR request + // Downstreams []string +} + +type RpzData struct { + CurrentSerial uint32 + ZoneName string + Axfr RpzAxfr + IxfrChain []RpzIxfr // NOTE: the IxfrChain is in reverse order, newest first! + // RpzZone *tapir.ZoneData + // RpzMap map[string]*tapir.RpzName +} + +type RpzIxfr struct { + FromSerial uint32 + ToSerial uint32 + Removed []*tapir.RpzName + Added []*tapir.RpzName +} + +type RpzAxfr struct { + Serial uint32 + SOA dns.SOA + NSrrs []dns.RR + Data map[string]*tapir.RpzName + ZoneData *tapir.ZoneData +} + +type TemPolicy struct { + Logger *log.Logger + WhitelistAction tapir.Action + BlacklistAction tapir.Action + Greylist GreylistPolicy +} + +type GreylistPolicy struct { + NumSources int + NumSourcesAction tapir.Action + NumTapirTags int + NumTapirTagsAction tapir.Action + BlackTapirTags tapir.TagMask + BlackTapirAction tapir.Action +} + +// type WBGC map[string]*tapir.WBGlist + +type SrcFoo struct { + Src struct { + Style string `yaml:"style"` + } `yaml:"src"` + Sources map[string]SourceConf `yaml:"sources"` +} diff --git a/tem-sources.sample.yaml b/tem-sources.sample.yaml index 08f1e74..749135a 100644 --- a/tem-sources.sample.yaml +++ b/tem-sources.sample.yaml @@ -7,6 +7,8 @@ sources: description: DNS TAPIR main intelligence feed type: greylist source: mqtt + bootstrap: [ 77.72.231.135:5454, 77.72.230.61 ] # www.axfr.net+nsb + bootstrapurl: https://%s/api/v1 format: tapir-mqtt-v1 rpztfc: name: rpz.threat-feed.com diff --git a/tem.sample.yaml b/tem.sample.yaml index f9d6dbf..4669383 100644 --- a/tem.sample.yaml +++ b/tem.sample.yaml @@ -13,17 +13,9 @@ apiserver: address: 127.0.0.1:9099 tlsaddress: 127.0.0.1:9098 -mqtt: - server: mqtt.dev.dnstapir.se:8883 - uid: johani - clientid: johani-tem - topic: events/up/johani/frobozz - cacert: /etc/dnstapir/certs/tapirCA.crt - clientcert: /etc/dnstapir/certs/mqttclient.crt - clientkey: /etc/dnstapir/certs/mqttclient-key.pem - signingkey: /etc/dnstapir/certs/mqttsigner-key.pem - validatorkey: /etc/dnstapir/certs/mqttsigner-pub.pem - qos: 2 +bootstrapserver: + address: + tlsaddress: # server: # listen: 127.0.0.1 @@ -38,6 +30,18 @@ mqtt: dnsengine: addresses: 127.0.0.1:5359 +mqtt: + server: mqtt.dev.dnstapir.se:8883 + uid: johani + clientid: johani-tem + topic: events/up/johani/frobozz + cacert: /etc/dnstapir/certs/tapirCA.crt + clientcert: /etc/dnstapir/certs/mqttclient.crt + clientkey: /etc/dnstapir/certs/mqttclient-key.pem + signingkey: /etc/dnstapir/certs/mqttsigner-key.pem + validatorkey: /etc/dnstapir/certs/mqttsigner-pub.pem + qos: 2 + output: reaper: interval: 60 # seconds, time between runs of deleting expired data diff --git a/xfr.go b/xfr.go index 3bba8c5..eb378e7 100644 --- a/xfr.go +++ b/xfr.go @@ -7,6 +7,8 @@ package main import ( "fmt" "log" + "math" + "net" "strings" "sync" @@ -160,7 +162,28 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er } } + downstream, _, err := net.SplitHostPort(w.RemoteAddr().String()) + if err != nil { + td.Logger.Printf("RpzIxfrOut: Error from net.SplitHostPort(): %v", err) + return 0, 0, err + } + + // tmp := td.Downstreams[downstream] + // tmp.Serial = curserial + + td.mu.Lock() + td.DownstreamSerials[downstream] = curserial zone := td.Rpz.ZoneName + td.mu.Unlock() + + if curserial < td.Rpz.IxfrChain[0].FromSerial { + td.Logger.Printf("RpzIxfrOut: Downstream %s claims to have RPZ %s with serial %d, but the IXFR chain starts at %d; AXFR needed", downstream, zone, curserial, td.Rpz.IxfrChain[0].FromSerial) + serial, _, err := td.RpzAxfrOut(w, r) + if err != nil { + return 0, 0, err + } + return serial, 0, nil + } if td.Verbose { td.Logger.Printf("RpzIxfrOut: Will try to serve RPZ %s to %v (%d IXFRs in chain)\n", zone, @@ -188,7 +211,7 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er var totcount, count int var finalSerial uint32 for _, ixfr := range td.Rpz.IxfrChain { - td.Logger.Printf("IxfrOut: checking client serial(%d) against IXFR[from:%d, to:%d]", + td.Logger.Printf("RpzIxfrOut: checking client serial(%d) against IXFR[from:%d, to:%d]", curserial, ixfr.FromSerial, ixfr.ToSerial) if ixfr.FromSerial >= curserial { finalSerial = ixfr.ToSerial @@ -201,7 +224,7 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er } env.RR = append(env.RR, fromsoa) count++ - td.Logger.Printf("IxfrOut: IXFR[%d,%d] has %d RRs in the removal list", + td.Logger.Printf("RpzIxfrOut: IXFR[%d,%d] has %d RRs in the removal list", ixfr.FromSerial, ixfr.ToSerial, len(ixfr.Removed)) for _, tn := range ixfr.Removed { if td.Debug { @@ -223,11 +246,11 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er tosoa := dns.Copy(dns.RR(&td.Rpz.Axfr.ZoneData.SOA)) tosoa.(*dns.SOA).Serial = ixfr.ToSerial if td.Debug { - td.Logger.Printf("IxfrOut: adding TOSOA to output: %s", tosoa.String()) + td.Logger.Printf("RpzIxfrOut: adding TOSOA to output: %s", tosoa.String()) } env.RR = append(env.RR, tosoa) count++ - td.Logger.Printf("IxfrOut: IXFR[%d,%d] has %d RRs in the added list", + td.Logger.Printf("RpzIxfrOut: IXFR[%d,%d] has %d RRs in the added list", ixfr.FromSerial, ixfr.ToSerial, len(ixfr.Added)) for _, tn := range ixfr.Added { if td.Debug { @@ -253,7 +276,7 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er env.RR = append(env.RR, dns.RR(&td.Rpz.Axfr.SOA)) // trailing SOA total_sent += len(env.RR) - td.Logger.Printf("ZoneTransferOut: Zone %s: Sending final %d RRs (including trailing SOA, total sent %d)\n", + td.Logger.Printf("RpzIxfrOut: Zone %s: Sending final %d RRs (including trailing SOA, total sent %d)\n", zone, len(env.RR), total_sent) // td.Logger.Printf("Sending %d RRs\n", len(env.RR)) @@ -266,7 +289,36 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er wg.Wait() // wait until everything is written out w.Close() // close connection - td.Logger.Printf("ZoneTransferOut: %s: Sent %d RRs (including SOA twice).", zone, total_sent) + td.Logger.Printf("RpzIxfrOut: %s: Sent %d RRs (including SOA twice).", zone, total_sent) + err = td.PruneRpzIxfrChain() + if err != nil { + td.Logger.Printf("RpzIxfrOut: Error from PruneRpzIxfrChain(): %v", err) + } return finalSerial, total_sent - 1, nil } + +func (td *TemData) PruneRpzIxfrChain() error { + lowSerial := uint32(math.MaxUint32) + for _, serial := range td.DownstreamSerials { + if serial < lowSerial { + lowSerial = serial + } + } + + indexToDeleteUpTo := -1 + for i := 0; i < len(td.Rpz.IxfrChain); i++ { + if td.Rpz.IxfrChain[i].FromSerial == lowSerial { + indexToDeleteUpTo = i - 2 + break + } + } + + if indexToDeleteUpTo >= 0 { + td.Rpz.IxfrChain = td.Rpz.IxfrChain[indexToDeleteUpTo+1:] + td.Logger.Printf("PruneRpzIxfrChain: Pruning IXFR chain up to two serials before serial %d", lowSerial) + } else { + td.Logger.Printf("PruneRpzIxfrChain: Nothing to prune from the IXFR chain") + } + return nil +}