diff --git a/ipset_linux.go b/ipset_linux.go index 94177e37..30ae878e 100644 --- a/ipset_linux.go +++ b/ipset_linux.go @@ -67,12 +67,13 @@ type IpsetCreateOptions struct { Comments bool Skbinfo bool - Family uint8 - Revision uint8 - IPFrom net.IP - IPTo net.IP - PortFrom uint16 - PortTo uint16 + Family uint8 + Revision uint8 + IPFrom net.IP + IPTo net.IP + PortFrom uint16 + PortTo uint16 + MaxElements uint32 } // IpsetProtocol returns the ipset protocol version from the kernel @@ -167,6 +168,10 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_FAMILY, nl.Uint8Attr(family))) + if options.MaxElements != 0 { + data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_MAXELEM | nl.NLA_F_NET_BYTEORDER, Value: options.MaxElements}) + } + if timeout := options.Timeout; timeout != nil { data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *timeout}) } diff --git a/ipset_linux_test.go b/ipset_linux_test.go index fa9877bf..27c2e90e 100644 --- a/ipset_linux_test.go +++ b/ipset_linux_test.go @@ -673,3 +673,52 @@ func TestIpsetSwap(t *testing.T) { assertIsEmpty(ipset1) assertHasOneEntry(ipset2) } + +func nextIP(ip net.IP) { + for j := len(ip) - 1; j >= 0; j-- { + ip[j]++ + if ip[j] > 0 { + break + } + } +} + +// TestIpsetMaxElements tests that we can create an ipset containing +// 128k elements, which is double the default size (64k elements). +func TestIpsetMaxElements(t *testing.T) { + tearDown := setUpNetlinkTest(t) + defer tearDown() + + ipsetName := "my-test-ipset-max" + maxElements := uint32(128 << 10) + + err := IpsetCreate(ipsetName, "hash:ip", IpsetCreateOptions{ + Replace: true, + MaxElements: maxElements, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = IpsetDestroy(ipsetName) + }() + + ip := net.ParseIP("10.0.0.0") + for i := uint32(0); i < maxElements; i++ { + err = IpsetAdd(ipsetName, &IPSetEntry{ + IP: ip, + }) + if err != nil { + t.Fatal(err) + } + nextIP(ip) + } + + result, err := IpsetList(ipsetName) + if err != nil { + t.Fatal(err) + } + if len(result.Entries) != int(maxElements) { + t.Fatalf("expected '%d' entry be created, got '%d'", maxElements, len(result.Entries)) + } +}