Skip to content

Commit

Permalink
add an automated test for UDP port forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
cre4ture committed Aug 2, 2024
1 parent 526780d commit 9718675
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 92 deletions.
18 changes: 17 additions & 1 deletion port-forwarder/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ port_forwarding:
assert.True(t, fwd_list.IsEmpty())
}

func TestConfigWithNoProtocols2(t *testing.T) {
func TestConfigWithNoProtocols_commentedProtos(t *testing.T) {
l := logrus.New()
c := config.NewC(l)
err := c.LoadString(`
Expand All @@ -70,6 +70,22 @@ port_forwarding:
assert.True(t, fwd_list.IsEmpty())
}

func TestConfigWithNoProtocols_missing_in_out(t *testing.T) {
l := logrus.New()
c := config.NewC(l)
err := c.LoadString(`
port_forwarding:
`)
assert.Nil(t, err)

fwd_list := NewPortForwardingList()
err = ParseConfig(l, c, fwd_list)
assert.Nil(t, err)

assert.Len(t, fwd_list.configPortForwardings, 0)
assert.True(t, fwd_list.IsEmpty())
}

func TestConfigWithTcpIn(t *testing.T) {
l := logrus.New()
c := config.NewC(l)
Expand Down
121 changes: 121 additions & 0 deletions port-forwarder/port_forwarder_test_udp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package port_forwarder

import (
"net"
"testing"

"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/service"
"github.com/stretchr/testify/assert"
)

func createPortForwarderFromConfigString(l *logrus.Logger, srv *service.Service, configStr string) (*PortForwardingService, error) {
c := config.NewC(l)
err := c.LoadString(configStr)
if err != nil {
return nil, err
}

fwd_list := NewPortForwardingList()
err = ParseConfig(l, c, fwd_list)
if err != nil {
return nil, err
}

pf, err := ConstructFromInitialFwdList(srv, l, &fwd_list)
if err != nil {
return nil, err
}

err = pf.Activate()
if err != nil {
return nil, err
}

return pf, nil
}

func testUdpCommunication(
t *testing.T,
msg string,
senderConn *net.UDPConn,
toAddr net.Addr,
receiverConn *net.UDPConn,
) (senderAddr net.Addr) {
data_sent := []byte(msg)
var n int
var err error
if toAddr != nil {
n, err = senderConn.WriteTo(data_sent, toAddr)
} else {
n, err = senderConn.Write(data_sent)
}
assert.Nil(t, err)
assert.Equal(t, n, len(data_sent))

buf := make([]byte, 100)
n, senderAddr, err = receiverConn.ReadFrom(buf)
assert.Nil(t, err)
assert.Equal(t, n, len(data_sent))
assert.Equal(t, data_sent, buf[:n])
return
}

func TestUdpInOut2Clients(t *testing.T) {
l := logrus.New()
server, client := service.CreateTwoConnectedServices()
server_pf, err := createPortForwarderFromConfigString(l, server, `
port_forwarding:
inbound:
- listen_port: 4499
dial_address: 127.0.0.1:5599
protocols: [udp]
`)
assert.Nil(t, err)

assert.Len(t, server_pf.portForwardings, 1)

client_pf, err := createPortForwarderFromConfigString(l, client, `
port_forwarding:
outbound:
- listen_address: 127.0.0.1:3399
dial_address: 10.0.0.1:4499
protocols: [udp]
`)
assert.Nil(t, err)

assert.Len(t, client_pf.portForwardings, 1)

client_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3399")
assert.Nil(t, err)
server_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5599")
assert.Nil(t, err)

server_listen_conn, err := net.ListenUDP("udp", server_conn_addr)
assert.Nil(t, err)
client1_conn, err := net.DialUDP("udp", nil, client_conn_addr)
assert.Nil(t, err)
client2_conn, err := net.DialUDP("udp", nil, client_conn_addr)
assert.Nil(t, err)

client1_addr := testUdpCommunication(t, "Hello from client 1 side!",
client1_conn, nil, server_listen_conn)
assert.NotNil(t, client1_addr)
client2_addr := testUdpCommunication(t, "Hello from client two side!",
client2_conn, nil, server_listen_conn)
assert.NotNil(t, client2_addr)

testUdpCommunication(t, "Hello from server first side!",
server_listen_conn, client1_addr, client1_conn)
testUdpCommunication(t, "Hello from server second side!",
server_listen_conn, client2_addr, client2_conn)
testUdpCommunication(t, "Hello from server third side!",
server_listen_conn, client1_addr, client1_conn)

testUdpCommunication(t, "Hello from client two side AGAIN!",
client2_conn, nil, server_listen_conn)

client.Close()
server.Close()
}
92 changes: 1 addition & 91 deletions service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,103 +4,13 @@ import (
"bytes"
"context"
"errors"
"net/netip"
"testing"
"time"

"dario.cat/mergo"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
)

type m map[string]interface{}

func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{})
caB, err := caCrt.MarshalToPEM()
if err != nil {
panic(err)
}

mc := m{
"pki": m{
"ca": string(caB),
"cert": string(myPEM),
"key": string(myPrivKey),
},
//"tun": m{"disabled": true},
"firewall": m{
"outbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
},
"timers": m{
"pending_deletion_interval": 2,
"connection_alive_interval": 2,
},
"handshakes": m{
"try_interval": "200ms",
},
}

if overrides != nil {
err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = overrides
}

cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}

var c config.C
if err := c.LoadString(string(cb)); err != nil {
panic(err)
}

l := logrus.New()
s, err := New(&c, l)
if err != nil {
panic(err)
}
return s
}

func TestService(t *testing.T) {
ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
"static_host_map": m{},
"lighthouse": m{
"am_lighthouse": true,
},
"listen": m{
"host": "0.0.0.0",
"port": 4243,
},
})
b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{
"static_host_map": m{
"10.0.0.1": []string{"localhost:4243"},
},
"lighthouse": m{
"hosts": []string{"10.0.0.1"},
"interval": 1,
},
})
a, b := CreateTwoConnectedServices()

ln, err := a.Listen("tcp", ":1234")
if err != nil {
Expand Down
100 changes: 100 additions & 0 deletions service/service_testhelpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package service

import (
"net/netip"
"time"

"dario.cat/mergo"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e"
"gopkg.in/yaml.v2"
)

type m map[string]interface{}

func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{})
caB, err := caCrt.MarshalToPEM()
if err != nil {
panic(err)
}

mc := m{
"pki": m{
"ca": string(caB),
"cert": string(myPEM),
"key": string(myPrivKey),
},
//"tun": m{"disabled": true},
"firewall": m{
"outbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
},
"timers": m{
"pending_deletion_interval": 2,
"connection_alive_interval": 2,
},
"handshakes": m{
"try_interval": "200ms",
},
}

if overrides != nil {
err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = overrides
}

cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}

var c config.C
if err := c.LoadString(string(cb)); err != nil {
panic(err)
}

l := logrus.New()
s, err := New(&c, l)
if err != nil {
panic(err)
}
return s
}

func CreateTwoConnectedServices() (*Service, *Service) {
ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
"static_host_map": m{},
"lighthouse": m{
"am_lighthouse": true,
},
"listen": m{
"host": "0.0.0.0",
"port": 4243,
},
})
b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{
"static_host_map": m{
"10.0.0.1": []string{"localhost:4243"},
},
"lighthouse": m{
"hosts": []string{"10.0.0.1"},
"interval": 1,
},
})
return a, b
}

0 comments on commit 9718675

Please sign in to comment.