Skip to content

Commit

Permalink
fix: activator redirect loops
Browse files Browse the repository at this point in the history
To avoid redirect loops, the activator now uses a known port for the
connection to the backend so we can disable redirects for these
packets.

Additionally, this splits the bpf program into two separate programs for
ingress and egress, mainly to make things easier to understand, but it
also makes the egress path shorter.
  • Loading branch information
ctrox committed Jan 9, 2024
1 parent f04dab3 commit 91c67d7
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 144 deletions.
101 changes: 72 additions & 29 deletions activator/activator.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ type Server struct {

type OnAccept func() error

func NewServer(ctx context.Context, ports []uint16, nn ns.NetNS) (*Server, error) {
func NewServer(ctx context.Context, ports []uint16, nn ns.NetNS, ifaces ...string) (*Server, error) {
if len(ifaces) == 0 {
return nil, fmt.Errorf("no interfaces have been supplied, at least one is required")
}

s := &Server{
quit: make(chan interface{}),
ports: ports,
Expand All @@ -41,9 +45,7 @@ func NewServer(ctx context.Context, ports []uint16, nn ns.NetNS) (*Server, error
}

if err := nn.Do(func(_ ns.NetNS) error {
// TODO: is this really always eth0?
// we need loopback for port-forwarding to work
objs, close, err := initBPF("lo", "eth0")
objs, close, err := initBPF(ifaces...)
if err != nil {
return err
}
Expand All @@ -64,7 +66,7 @@ func (s *Server) Start(ctx context.Context, onAccept OnAccept) error {
return err
}

log.G(ctx).Infof("redirecting port %d -> %d", port, proxyPort)
log.G(ctx).Debugf("redirecting port %d -> %d", port, proxyPort)
if err := s.RedirectPort(port, uint16(proxyPort)); err != nil {
return fmt.Errorf("redirecting port: %w", err)
}
Expand Down Expand Up @@ -97,8 +99,6 @@ func (s *Server) listen(ctx context.Context, port uint16, onAccept OnAccept) (in
addr := "0.0.0.0:0"
cfg := net.ListenConfig{}

// for some reason, sometimes the address will still be in use after
// checkpointing, so we wrap the listen in a retry.
var listener net.Listener
if err := s.ns.Do(func(_ ns.NetNS) error {
l, err := cfg.Listen(ctx, "tcp4", addr)
Expand All @@ -113,7 +113,7 @@ func (s *Server) listen(ctx context.Context, port uint16, onAccept OnAccept) (in
return 0, err
}

log.G(ctx).Infof("listening on %s in ns %s", listener.Addr(), s.ns.Path())
log.G(ctx).Debugf("listening on %s in ns %s", listener.Addr(), s.ns.Path())

s.firstAccept = sync.Once{}
s.onAccept = onAccept
Expand All @@ -130,7 +130,7 @@ func (s *Server) listen(ctx context.Context, port uint16, onAccept OnAccept) (in
}

func (s *Server) Stop(ctx context.Context) {
log.G(ctx).Info("stopping activator")
log.G(ctx).Debugf("stopping activator")

if s.proxyCancel != nil {
s.proxyCancel()
Expand All @@ -143,7 +143,7 @@ func (s *Server) Stop(ctx context.Context) {
s.bpfCloseFunc()

s.wg.Wait()
log.G(ctx).Info("activator stopped")
log.G(ctx).Debugf("activator stopped")
}

func (s *Server) serve(ctx context.Context, listener net.Listener, port uint16) {
Expand All @@ -169,7 +169,7 @@ func (s *Server) serve(ctx context.Context, listener net.Listener, port uint16)
} else {
wg.Add(1)
go func() {
log.G(ctx).Info("accepting connection")
log.G(ctx).Debug("accepting connection")
s.handleConection(ctx, conn, port)
wg.Done()
}()
Expand All @@ -180,48 +180,51 @@ func (s *Server) serve(ctx context.Context, listener net.Listener, port uint16)
func (s *Server) handleConection(ctx context.Context, conn net.Conn, port uint16) {
defer conn.Close()

s.firstAccept.Do(func() {
if err := s.onAccept(); err != nil {
log.G(ctx).Errorf("accept function: %s", err)
return
}
})

tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok {
log.G(ctx).Errorf("unable to get TCP Addr from remote addr: %T", conn.RemoteAddr())
return
}

log.G(ctx).Infof("registering connection on port %d", tcpAddr.Port)
log.G(ctx).Debugf("registering connection on remote port %d", tcpAddr.Port)
if err := s.registerConnection(uint16(tcpAddr.Port)); err != nil {
log.G(ctx).Errorf("error registering fade out port: %s", err)
log.G(ctx).Errorf("error registering connection: %s", err)
return
}

log.G(ctx).Printf("proxying connection to program at localhost:%d", port)
s.firstAccept.Do(func() {
if err := s.onAccept(); err != nil {
log.G(ctx).Errorf("accept function: %s", err)
return
}
})

initialConn, err := s.connect(ctx, port)
backendConn, err := s.connect(ctx, port)
if err != nil {
log.G(ctx).Errorf("error establishing connection: %s", err)
return
}
defer initialConn.Close()
defer backendConn.Close()

log.G(ctx).Println("dial succeeded", initialConn.RemoteAddr().String())
log.G(ctx).Println("dial succeeded", backendConn.RemoteAddr().String())

requestContext, cancel := context.WithTimeout(ctx, s.proxyTimeout)
s.proxyCancel = cancel
defer cancel()
if err := proxy(requestContext, conn, initialConn); err != nil {
if err := proxy(requestContext, conn, backendConn); err != nil {
log.G(ctx).Errorf("error proxying request: %s", err)
}

if err := s.removeConnection(uint16(tcpAddr.Port)); err != nil {
log.G(ctx).Errorf("error removing connection: %s", err)
return
}

log.G(ctx).Println("connection closed", conn.RemoteAddr().String())
}

func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) {
var initialConn net.Conn
var err error
var backendConn net.Conn

ticker := time.NewTicker(time.Millisecond)
defer ticker.Stop()
Expand All @@ -237,7 +240,29 @@ func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) {
}

if err := s.ns.Do(func(_ ns.NetNS) error {
initialConn, err = net.DialTimeout("tcp4", fmt.Sprintf("localhost:%d", port), s.connectTimeout)
// to ensure we don't create a redirect loop we need to know
// the local port of our connection to the activated process.
// We reserve a free port, store it in the disable bpf map and
// then use it to make the connection.
backendConnPort, err := freePort()
if err != nil {
return fmt.Errorf("unable to get free port: %w", err)
}

log.G(ctx).Debugf("registering backend connection port %d in bpf map", backendConnPort)
if err := s.disableRedirect(uint16(backendConnPort)); err != nil {
return err
}

addr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("127.0.0.1:%d", backendConnPort))
if err != nil {
return err
}
d := net.Dialer{
LocalAddr: addr,
Timeout: s.connectTimeout,
}
backendConn, err = d.Dial("tcp4", fmt.Sprintf("localhost:%d", port))
return err
}); err != nil {
var serr syscall.Errno
Expand All @@ -248,7 +273,7 @@ func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) {
return nil, fmt.Errorf("unable to connect to process: %s", err)
}

return initialConn, nil
return backendConn, nil
}
}
}
Expand Down Expand Up @@ -281,3 +306,21 @@ func copy(done chan struct{}, errors chan error, dst io.Writer, src io.Reader) {
errors <- err
}
}

// freePort asks the kernel for a currently unused TCP port on the IPv4
// loopback address by binding a temporary listener to port 0, then closes
// that listener and returns the port number it was assigned.
//
// The port is only held while the temporary listener is open. After this
// function returns nothing reserves it anymore, so callers should bind to it
// promptly and be prepared for the bind to fail.
//
// It returns an error if the listener cannot be created, if its address is
// not a *net.TCPAddr, or if closing the listener fails.
func freePort() (int, error) {
	listener, err := net.Listen("tcp4", "127.0.0.1:0")
	if err != nil {
		return 0, err
	}

	// Capture the address once, before the listener is closed.
	listenAddr := listener.Addr()
	addr, ok := listenAddr.(*net.TCPAddr)
	if !ok {
		// Release the socket before bailing out so it is not leaked. The
		// type mismatch is the error worth reporting, so a failure to close
		// is deliberately ignored here.
		_ = listener.Close()
		return 0, fmt.Errorf("addr is not a net.TCPAddr: %T", listenAddr)
	}

	if err := listener.Close(); err != nil {
		return 0, err
	}

	return addr.Port, nil
}
93 changes: 30 additions & 63 deletions activator/activator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
Expand All @@ -16,67 +15,58 @@ import (
"github.com/ctrox/zeropod/socket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netns"
)

func TestActivator(t *testing.T) {
t.Skip("broken for the time being")
require.NoError(t, socket.MountBPFFS(socket.BPFFSPath))

newns, err := netns.NewNamed("test")
require.NoError(t, err)
defer newns.Close()

nn, err := ns.GetCurrentNS()
require.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())

port, _, err := getFreePorts()
port, err := freePort()
require.NoError(t, err)

s, err := NewServer(ctx, []uint16{uint16(port)}, nn)
s, err := NewServer(ctx, []uint16{uint16(port)}, nn, "lo")
require.NoError(t, err)

response := "ok"
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, response)
}))

err = s.Start(ctx,
func() error {
l, err := net.Listen("tcp4", fmt.Sprintf(":%d", port))
if err != nil {
log.Fatal(err)
}

// NewUnstartedServer creates a listener. Close that listener and replace
// with the one we created.
ts.Listener.Close()
ts.Listener = l
ts.Start()
log.Printf("listening on :%d", port)

t.Cleanup(func() {
ts.Close()
})

if err := s.DisableRedirects(); err != nil {
return fmt.Errorf("could not disable redirects: %w", err)
}

return nil
},
)
require.NoError(t, err)
defer s.Stop(ctx)
defer cancel()
err = s.Start(ctx, func() error {
// simulate a delay until our server is started
time.Sleep(time.Millisecond * 200)
l, err := net.Listen("tcp4", fmt.Sprintf(":%d", port))
require.NoError(t, err)

if err := s.DisableRedirects(); err != nil {
return fmt.Errorf("could not disable redirects: %w", err)
}

time.Sleep(time.Hour)
// replace listener of server
ts.Listener.Close()
ts.Listener = l
ts.Start()
t.Logf("listening on :%d", port)

t.Cleanup(func() {
ts.Close()
})

return nil
})
require.NoError(t, err)
t.Cleanup(func() {
s.Stop(ctx)
cancel()
})

c := &http.Client{Timeout: time.Second}

parallelReqs := 6
parallelReqs := 10
wg := sync.WaitGroup{}
for _, port := range []int{port} {
port := port
Expand All @@ -89,34 +79,11 @@ func TestActivator(t *testing.T) {
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)

assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, response, string(b))
t.Log(string(b))
}()
}
}
wg.Wait()
}

// getFreePorts returns two currently unused TCP ports on the loopback
// address. Both temporary listeners are open at the same time while the
// ports are read, so the two returned ports differ from each other.
//
// Both listeners are closed before returning on every path. The ports are
// not reserved afterwards, so callers should bind to them promptly.
//
// It returns an error if either listener cannot be created, if either
// address is not a *net.TCPAddr, or if closing either listener fails.
func getFreePorts() (int, int, error) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return 0, 0, err
	}
	listener2, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		// Do not leak the first listener when the second one fails; the
		// listen error is the one worth reporting.
		_ = listener.Close()
		return 0, 0, err
	}

	// Read both addresses while the listeners are still open, using checked
	// assertions so an unexpected address type becomes an error, not a panic.
	addr, ok := listener.Addr().(*net.TCPAddr)
	addr2, ok2 := listener2.Addr().(*net.TCPAddr)

	// Close both listeners unconditionally so neither is leaked, even if
	// closing the first one fails.
	closeErr := listener.Close()
	closeErr2 := listener2.Close()

	if !ok {
		return 0, 0, fmt.Errorf("addr is not a net.TCPAddr: %T", listener.Addr())
	}
	if !ok2 {
		return 0, 0, fmt.Errorf("addr is not a net.TCPAddr: %T", listener2.Addr())
	}
	if closeErr != nil {
		return 0, 0, closeErr
	}
	if closeErr2 != nil {
		return 0, 0, closeErr2
	}

	return addr.Port, addr2.Port, nil
}
17 changes: 13 additions & 4 deletions activator/bpf.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ func initBPF(ifaces ...string) (*bpfObjects, func(), error) {
Handle: 1,
Protocol: unix.ETH_P_ALL,
},
Fd: objs.TcRedirector.FD(),
Name: objs.TcRedirector.String(),
Fd: objs.TcRedirectIngress.FD(),
Name: objs.TcRedirectIngress.String(),
DirectAction: true,
}
egress := ingress
egress.Parent = netlink.HANDLE_MIN_EGRESS
egress.Fd = objs.TcRedirectEgress.FD()
egress.Name = objs.TcRedirectEgress.String()

if err := netlink.FilterReplace(&ingress); err != nil {
return nil, nil, fmt.Errorf("failed to replace tc filter: %w", err)
Expand Down Expand Up @@ -111,10 +113,10 @@ func pinPath() string {

// RedirectPort redirects port "from" to port "to" on ingress and port "to" back to port "from" on egress.
func (a *Server) RedirectPort(from, to uint16) error {
if err := a.bpfObjs.Redirects.Put(&from, &to); err != nil {
if err := a.bpfObjs.IngressRedirects.Put(&from, &to); err != nil {
return fmt.Errorf("unable to put ports %d -> %d into bpf map: %w", from, to, err)
}
if err := a.bpfObjs.Redirects.Put(&to, &from); err != nil {
if err := a.bpfObjs.EgressRedirects.Put(&to, &from); err != nil {
return fmt.Errorf("unable to put ports %d -> %d into bpf map: %w", to, from, err)
}
return nil
Expand All @@ -127,6 +129,13 @@ func (a *Server) registerConnection(port uint16) error {
return nil
}

// removeConnection deletes the entry keyed by port from the
// ActiveConnections bpf map. Any failure reported by the map is wrapped
// with the port number and returned to the caller.
func (a *Server) removeConnection(port uint16) error {
	err := a.bpfObjs.ActiveConnections.Delete(&port)
	if err == nil {
		return nil
	}
	return fmt.Errorf("unable to delete port %d in bpf map: %w", port, err)
}

func (a *Server) disableRedirect(port uint16) error {
if err := a.bpfObjs.DisableRedirect.Put(&port, uint8(1)); err != nil {
return fmt.Errorf("unable to put %d into bpf map: %w", port, err)
Expand Down
Loading

0 comments on commit 91c67d7

Please sign in to comment.