Skip to content

Commit

Permalink
Merge branch 'vishvananda:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
daveset authored Oct 6, 2023
2 parents f911b1f + a4fcbb7 commit 5a5eea8
Show file tree
Hide file tree
Showing 39 changed files with 2,033 additions and 146 deletions.
22 changes: 22 additions & 0 deletions chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package netlink

import (
"fmt"
)

// Chain contains the attributes of a Chain
type Chain struct {
Parent uint32
Chain uint32
}

func (c Chain) String() string {
return fmt.Sprintf("{Parent: %d, Chain: %d}", c.Parent, c.Chain)
}

func NewChain(parent uint32, chain uint32) Chain {
return Chain{
Parent: parent,
Chain: chain,
}
}
112 changes: 112 additions & 0 deletions chain_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package netlink

import (
"github.com/vishvananda/netlink/nl"
"golang.org/x/sys/unix"
)

// ChainDel will delete a chain from the system.
func ChainDel(link Link, chain Chain) error {
// Equivalent to: `tc chain del $chain`
return pkgHandle.ChainDel(link, chain)
}

// ChainDel will delete a chain from the system.
// Equivalent to: `tc chain del $chain`
func (h *Handle) ChainDel(link Link, chain Chain) error {
return h.chainModify(unix.RTM_DELCHAIN, 0, link, chain)
}

// ChainAdd will add a chain to the system.
// Equivalent to: `tc chain add`
func ChainAdd(link Link, chain Chain) error {
return pkgHandle.ChainAdd(link, chain)
}

// ChainAdd will add a chain to the system.
// Equivalent to: `tc chain add`
func (h *Handle) ChainAdd(link Link, chain Chain) error {
return h.chainModify(
unix.RTM_NEWCHAIN,
unix.NLM_F_CREATE|unix.NLM_F_EXCL,
link,
chain)
}

func (h *Handle) chainModify(cmd, flags int, link Link, chain Chain) error {
req := h.newNetlinkRequest(cmd, flags|unix.NLM_F_ACK)
index := int32(0)
if link != nil {
base := link.Attrs()
h.ensureIndex(base)
index = int32(base.Index)
}
msg := &nl.TcMsg{
Family: nl.FAMILY_ALL,
Ifindex: index,
Parent: chain.Parent,
}
req.AddData(msg)
req.AddData(nl.NewRtAttr(nl.TCA_CHAIN, nl.Uint32Attr(chain.Chain)))

_, err := req.Execute(unix.NETLINK_ROUTE, 0)
return err
}

// ChainList gets a list of chains in the system.
// Equivalent to: `tc chain list`.
// The list can be filtered by link.
func ChainList(link Link, parent uint32) ([]Chain, error) {
return pkgHandle.ChainList(link, parent)
}

// ChainList gets a list of chains in the system.
// Equivalent to: `tc chain list`.
// The list can be filtered by link.
func (h *Handle) ChainList(link Link, parent uint32) ([]Chain, error) {
req := h.newNetlinkRequest(unix.RTM_GETCHAIN, unix.NLM_F_DUMP)
index := int32(0)
if link != nil {
base := link.Attrs()
h.ensureIndex(base)
index = int32(base.Index)
}
msg := &nl.TcMsg{
Family: nl.FAMILY_ALL,
Ifindex: index,
Parent: parent,
}
req.AddData(msg)

msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWCHAIN)
if err != nil {
return nil, err
}

var res []Chain
for _, m := range msgs {
msg := nl.DeserializeTcMsg(m)

attrs, err := nl.ParseRouteAttr(m[msg.Len():])
if err != nil {
return nil, err
}

// skip chains from other interfaces
if link != nil && msg.Ifindex != index {
continue
}

var chain Chain
for _, attr := range attrs {
switch attr.Attr.Type {
case nl.TCA_CHAIN:
chain.Chain = native.Uint32(attr.Value)
chain.Parent = parent
}
}
res = append(res, chain)
}

return res, nil
}
77 changes: 77 additions & 0 deletions chain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//go:build linux
// +build linux

package netlink

import (
"testing"
)

func TestChainAddDel(t *testing.T) {
tearDown := setUpNetlinkTest(t)
defer tearDown()
if err := LinkAdd(&Ifb{LinkAttrs{Name: "foo"}}); err != nil {
t.Fatal(err)
}
if err := LinkAdd(&Ifb{LinkAttrs{Name: "bar"}}); err != nil {
t.Fatal(err)
}
link, err := LinkByName("foo")
if err != nil {
t.Fatal(err)
}
if err := LinkSetUp(link); err != nil {
t.Fatal(err)
}
qdisc := &Ingress{
QdiscAttrs: QdiscAttrs{
LinkIndex: link.Attrs().Index,
Handle: MakeHandle(0xffff, 0),
Parent: HANDLE_INGRESS,
},
}
if err := QdiscAdd(qdisc); err != nil {
t.Fatal(err)
}
qdiscs, err := SafeQdiscList(link)
if err != nil {
t.Fatal(err)
}
if len(qdiscs) != 1 {
t.Fatal("Failed to add qdisc")
}
_, ok := qdiscs[0].(*Ingress)
if !ok {
t.Fatal("Qdisc is the wrong type")
}
chainVal := new(uint32)
*chainVal = 20
chain := NewChain(HANDLE_INGRESS, *chainVal)
err = ChainAdd(link, chain)
if err != nil {
t.Fatal(err)
}
chains, err := ChainList(link, HANDLE_INGRESS)
if err != nil {
t.Fatal(err)
}
if len(chains) != 1 {
t.Fatal("Failed to add chain")
}
if chains[0].Chain != *chainVal {
t.Fatal("Incorrect chain added")
}
if chains[0].Parent != HANDLE_INGRESS {
t.Fatal("Incorrect chain parent")
}
if err := ChainDel(link, chain); err != nil {
t.Fatal(err)
}
chains, err = ChainList(link, HANDLE_INGRESS)
if err != nil {
t.Fatal(err)
}
if len(chains) != 0 {
t.Fatal("Failed to remove chain")
}
}
17 changes: 17 additions & 0 deletions cmd/ipset-test/main.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build linux
// +build linux

package main
Expand Down Expand Up @@ -28,6 +29,7 @@ var (
"listall": {cmdListAll, "list all ipsets", 0},
"add": {cmdAddDel(netlink.IpsetAdd), "add entry", 2},
"del": {cmdAddDel(netlink.IpsetDel), "delete entry", 2},
"test": {cmdTest, "test whether an entry is in a set or not", 2},
}

timeoutVal *uint32
Expand Down Expand Up @@ -140,6 +142,21 @@ func cmdAddDel(f func(string, *netlink.IPSetEntry) error) func([]string) {
}
}

func cmdTest(args []string) {
setName := args[0]
element := args[1]
ip := net.ParseIP(element)
entry := &netlink.IPSetEntry{
Timeout: timeoutVal,
IP: ip,
Comment: *comment,
Replace: *replace,
}
exist, err := netlink.IpsetTest(setName, entry)
check(err)
log.Printf("existence: %t\n", exist)
}

// panic on error
func check(err error) {
if err != nil {
Expand Down
72 changes: 68 additions & 4 deletions conntrack_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,26 @@ type ConntrackFlow struct {
TimeStart uint64
TimeStop uint64
TimeOut uint32
Labels []byte
}

func (s *ConntrackFlow) String() string {
// conntrack cmd output:
// udp 17 src=127.0.0.1 dst=127.0.0.1 sport=4001 dport=1234 packets=5 bytes=532 [UNREPLIED] src=127.0.0.1 dst=127.0.0.1 sport=1234 dport=4001 packets=10 bytes=1078 mark=0
// udp 17 src=127.0.0.1 dst=127.0.0.1 sport=4001 dport=1234 packets=5 bytes=532 [UNREPLIED] src=127.0.0.1 dst=127.0.0.1 sport=1234 dport=4001 packets=10 bytes=1078 mark=0 labels=0x00000000050012ac4202010000000000
// start=2019-07-26 01:26:21.557800506 +0000 UTC stop=1970-01-01 00:00:00 +0000 UTC timeout=30(sec)
start := time.Unix(0, int64(s.TimeStart))
stop := time.Unix(0, int64(s.TimeStop))
timeout := int32(s.TimeOut)
return fmt.Sprintf("%s\t%d src=%s dst=%s sport=%d dport=%d packets=%d bytes=%d\tsrc=%s dst=%s sport=%d dport=%d packets=%d bytes=%d mark=0x%x start=%v stop=%v timeout=%d(sec)",
res := fmt.Sprintf("%s\t%d src=%s dst=%s sport=%d dport=%d packets=%d bytes=%d\tsrc=%s dst=%s sport=%d dport=%d packets=%d bytes=%d mark=0x%x ",
nl.L4ProtoMap[s.Forward.Protocol], s.Forward.Protocol,
s.Forward.SrcIP.String(), s.Forward.DstIP.String(), s.Forward.SrcPort, s.Forward.DstPort, s.Forward.Packets, s.Forward.Bytes,
s.Reverse.SrcIP.String(), s.Reverse.DstIP.String(), s.Reverse.SrcPort, s.Reverse.DstPort, s.Reverse.Packets, s.Reverse.Bytes,
s.Mark, start, stop, timeout)
s.Mark)
if len(s.Labels) > 0 {
res += fmt.Sprintf("labels=0x%x ", s.Labels)
}
res += fmt.Sprintf("start=%v stop=%v timeout=%d(sec)", start, stop, timeout)
return res
}

// This method parse the ip tuple structure
Expand Down Expand Up @@ -306,6 +312,12 @@ func parseConnectionMark(r *bytes.Reader) (mark uint32) {
return
}

func parseConnectionLabels(r *bytes.Reader) (label []byte) {
label = make([]byte, 16) // netfilter defines 128 bit labels value
binary.Read(r, nl.NativeEndian(), &label)
return
}

func parseRawData(data []byte) *ConntrackFlow {
s := &ConntrackFlow{}
// First there is the Nfgenmsg header
Expand Down Expand Up @@ -351,6 +363,8 @@ func parseRawData(data []byte) *ConntrackFlow {
switch t {
case nl.CTA_MARK:
s.Mark = parseConnectionMark(reader)
case nl.CTA_LABELS:
s.Labels = parseConnectionLabels(reader)
case nl.CTA_TIMEOUT:
s.TimeOut = parseTimeOut(reader)
case nl.CTA_STATUS, nl.CTA_USE, nl.CTA_ID:
Expand Down Expand Up @@ -406,6 +420,8 @@ const (
ConntrackReplyAnyIP // Match source or destination reply IP
ConntrackOrigSrcPort // --orig-port-src port Source port in original direction
ConntrackOrigDstPort // --orig-port-dst port Destination port in original direction
ConntrackMatchLabels // --label label1,label2 Labels used in entry
ConntrackUnmatchLabels // --label label1,label2 Labels not used in entry
ConntrackNatSrcIP = ConntrackReplySrcIP // deprecated use instead ConntrackReplySrcIP
ConntrackNatDstIP = ConntrackReplyDstIP // deprecated use instead ConntrackReplyDstIP
ConntrackNatAnyIP = ConntrackReplyAnyIP // deprecated use instead ConntrackReplyAnyIP
Expand All @@ -421,6 +437,7 @@ type ConntrackFilter struct {
ipNetFilter map[ConntrackFilterType]*net.IPNet
portFilter map[ConntrackFilterType]uint16
protoFilter uint8
labelFilter map[ConntrackFilterType][][]byte
}

// AddIPNet adds a IP subnet to the conntrack filter
Expand Down Expand Up @@ -474,10 +491,34 @@ func (f *ConntrackFilter) AddProtocol(proto uint8) error {
return nil
}

// AddLabels adds the provided list (zero or more) of labels to the conntrack filter
// ConntrackFilterType here can be either:
// 1) ConntrackMatchLabels: This matches every flow that has a label value (len(flow.Labels) > 0)
// against the list of provided labels. If `flow.Labels` contains ALL the provided labels
// it is considered a match. This can be used when you want to match flows that contain
// one or more labels.
// 2) ConntrackUnmatchLabels: This matches every flow that has a label value (len(flow.Labels) > 0)
// against the list of provided labels. If `flow.Labels` does NOT contain ALL the provided labels
// it is considered a match. This can be used when you want to match flows that don't contain
// one or more labels.
func (f *ConntrackFilter) AddLabels(tp ConntrackFilterType, labels [][]byte) error {
if len(labels) == 0 {
return errors.New("Invalid length for provided labels")
}
if f.labelFilter == nil {
f.labelFilter = make(map[ConntrackFilterType][][]byte)
}
if _, ok := f.labelFilter[tp]; ok {
return errors.New("Filter attribute already present")
}
f.labelFilter[tp] = labels
return nil
}

// MatchConntrackFlow applies the filter to the flow and returns true if the flow matches the filter
// false otherwise
func (f *ConntrackFilter) MatchConntrackFlow(flow *ConntrackFlow) bool {
if len(f.ipNetFilter) == 0 && len(f.portFilter) == 0 && f.protoFilter == 0 {
if len(f.ipNetFilter) == 0 && len(f.portFilter) == 0 && f.protoFilter == 0 && len(f.labelFilter) == 0 {
// empty filter always not match
return false
}
Expand Down Expand Up @@ -531,6 +572,29 @@ func (f *ConntrackFilter) MatchConntrackFlow(flow *ConntrackFlow) bool {
}
}

// Label filter
if len(f.labelFilter) > 0 {
if len(flow.Labels) > 0 {
// --label label1,label2 in conn entry;
// every label passed should be contained in flow.Labels for a match to be true
if elem, found := f.labelFilter[ConntrackMatchLabels]; match && found {
for _, label := range elem {
match = match && (bytes.Contains(flow.Labels, label))
}
}
// --label label1,label2 in conn entry;
// every label passed should be not contained in flow.Labels for a match to be true
if elem, found := f.labelFilter[ConntrackUnmatchLabels]; match && found {
for _, label := range elem {
match = match && !(bytes.Contains(flow.Labels, label))
}
}
} else {
// flow doesn't contain labels, so it doesn't contain or notContain any provided matches
match = false
}
}

return match
}

Expand Down
Loading

0 comments on commit 5a5eea8

Please sign in to comment.