Skip to content
This repository has been archived by the owner on Mar 29, 2024. It is now read-only.

Commit

Permalink
Add support for binding Piers to specific addresses
Browse files Browse the repository at this point in the history
  • Loading branch information
dhaavi committed Oct 11, 2023
1 parent 5a0145a commit d208af1
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 83 deletions.
2 changes: 1 addition & 1 deletion captain/public.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func loadPublicIdentity() (err error) {
publicIdentity.Hub.Info.IPv6 != nil,
)
if cfgOptionBindToAdvertised() {
conf.SetConnectAddr(publicIdentity.Hub.Info.IPv4, publicIdentity.Hub.Info.IPv6)
conf.SetBindAddr(publicIdentity.Hub.Info.IPv4, publicIdentity.Hub.Info.IPv6)
}

// Set Home Hub before updating the hub on the map, as this would trigger a
Expand Down
83 changes: 56 additions & 27 deletions conf/networks.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,53 +29,82 @@ func HubHasIPv6() bool {
}

var (
connectIPv4 net.IP
connectIPv6 net.IP
connectIPLock sync.Mutex
bindIPv4 net.IP
bindIPv6 net.IP
bindIPLock sync.Mutex
)

// SetConnectAddr sets the preferred connect (bind) addresses.
func SetConnectAddr(ip4, ip6 net.IP) {
connectIPLock.Lock()
defer connectIPLock.Unlock()
// SetBindAddr sets the preferred connect (bind) addresses.
func SetBindAddr(ip4, ip6 net.IP) {
bindIPLock.Lock()
defer bindIPLock.Unlock()

connectIPv4 = ip4
connectIPv6 = ip6
bindIPv4 = ip4
bindIPv6 = ip6
}

// GetConnectAddr returns an address with the preferred connect (bind)
// addresses for the given dial network.
// The dial network must have a suffix specify the IP version.
func GetConnectAddr(dialNetwork string) net.Addr {
connectIPLock.Lock()
defer connectIPLock.Unlock()
// BindAddrIsSet returns whether any bind address is set.
func BindAddrIsSet() bool {
bindIPLock.Lock()
defer bindIPLock.Unlock()

return bindIPv4 != nil || bindIPv6 != nil
}

// GetBindAddr returns an address with the preferred binding address for the
// given dial network.
// The dial network must have a suffix specifying the IP version.
func GetBindAddr(dialNetwork string) net.Addr {
bindIPLock.Lock()
defer bindIPLock.Unlock()

switch dialNetwork {
case "ip4":
if connectIPv4 != nil {
return &net.IPAddr{IP: connectIPv4}
if bindIPv4 != nil {
return &net.IPAddr{IP: bindIPv4}
}
case "ip6":
if connectIPv6 != nil {
return &net.IPAddr{IP: connectIPv6}
if bindIPv6 != nil {
return &net.IPAddr{IP: bindIPv6}
}
case "tcp4":
if connectIPv4 != nil {
return &net.TCPAddr{IP: connectIPv4}
if bindIPv4 != nil {
return &net.TCPAddr{IP: bindIPv4}
}
case "tcp6":
if connectIPv6 != nil {
return &net.TCPAddr{IP: connectIPv6}
if bindIPv6 != nil {
return &net.TCPAddr{IP: bindIPv6}
}
case "udp4":
if connectIPv4 != nil {
return &net.UDPAddr{IP: connectIPv4}
if bindIPv4 != nil {
return &net.UDPAddr{IP: bindIPv4}
}
case "udp6":
if connectIPv6 != nil {
return &net.UDPAddr{IP: connectIPv6}
if bindIPv6 != nil {
return &net.UDPAddr{IP: bindIPv6}
}
}

return nil
}

// GetBindIPs returns the preferred binding IPs.
// Returns a slice with a single nil IP if no preferred binding IPs are set.
func GetBindIPs() []net.IP {
bindIPLock.Lock()
defer bindIPLock.Unlock()

switch {
case bindIPv4 == nil && bindIPv6 == nil:
// Match most common case first.
return []net.IP{nil}
case bindIPv4 != nil && bindIPv6 != nil:
return []net.IP{bindIPv4, bindIPv6}
case bindIPv4 != nil:
return []net.IP{bindIPv4}
case bindIPv6 != nil:
return []net.IP{bindIPv6}
}

return []net.IP{nil}
}
13 changes: 11 additions & 2 deletions crew/op_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func (op *ConnectOp) setup(session *terminal.Session) {
}
dialer := &net.Dialer{
Timeout: 10 * time.Second,
LocalAddr: conf.GetConnectAddr(dialNet),
LocalAddr: conf.GetBindAddr(dialNet),
FallbackDelay: -1, // Disables Fast Fallback from IPv6 to IPv4.
KeepAlive: -1, // Disable keep-alive.
}
Expand Down Expand Up @@ -410,6 +410,8 @@ func (op *ConnectOp) connWriter(_ context.Context) error {
}()

defer func() {
// Signal that we are done with writing.
close(op.doneWriting)
// Close connection.
_ = op.conn.Close()
}()
Expand Down Expand Up @@ -522,7 +524,14 @@ func (op *ConnectOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Erro
// If the op was ended remotely, write all remaining received data.
// If the op was ended locally, don't bother writing remaining data.
if err.IsExternal() {
<-op.doneWriting
select {
case <-op.doneWriting:
default:
select {
case <-op.doneWriting:
case <-time.After(5 * time.Second):
}
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion patrol/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func CheckHTTPSConnection(ctx context.Context, network, domain string) (statusCo
}
dialer := &net.Dialer{
Timeout: 15 * time.Second,
LocalAddr: conf.GetConnectAddr(network),
LocalAddr: conf.GetBindAddr(network),
FallbackDelay: -1, // Disables Fast Fallback from IPv6 to IPv4.
KeepAlive: -1, // Disable keep-alive.
}
Expand Down
5 changes: 2 additions & 3 deletions ships/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func launchHTTPShip(ctx context.Context, transport *hub.Transport, ip net.IP) (S
}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
LocalAddr: conf.GetConnectAddr(dialNet),
LocalAddr: conf.GetBindAddr(dialNet),
FallbackDelay: -1, // Disables Fast Fallback from IPv6 to IPv4.
KeepAlive: -1, // Disable keep-alive.
}
Expand Down Expand Up @@ -209,11 +209,10 @@ func establishHTTPPier(transport *hub.Transport, dockingRequests chan Ship) (Pie
pier.initBase()

// Register handler.
listener, err := addHTTPHandler(transport.Port, path, pier.ServeHTTP)
err := addHTTPHandler(transport.Port, path, pier.ServeHTTP)
if err != nil {
return nil, fmt.Errorf("failed to add HTTP handler: %w", err)
}
pier.listener = listener

return pier, nil
}
Expand Down
61 changes: 38 additions & 23 deletions ships/http_shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (
"net/http"
"sync"
"time"

"github.com/safing/portbase/log"
"github.com/safing/spn/conf"
)

type sharedServer struct {
listener net.Listener
server *http.Server
server *http.Server

handlers map[string]http.HandlerFunc
handlersLock sync.RWMutex
Expand Down Expand Up @@ -45,10 +47,10 @@ var (
sharedHTTPServersLock sync.Mutex
)

func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) (ln net.Listener, err error) {
func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) error {
// Check params.
if port == 0 {
return nil, errors.New("cannot listen on port 0")
return errors.New("cannot listen on port 0")
}

// Default to root path.
Expand All @@ -69,12 +71,12 @@ func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) (ln net.
// Check if path is already registered.
_, ok := shared.handlers[path]
if ok {
return nil, errors.New("path already registered")
return errors.New("path already registered")
}

// Else, register handler at path.
shared.handlers[path] = handler
return shared.listener, nil
return nil
}

// Shared server does not exist - create one.
Expand All @@ -99,28 +101,41 @@ func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) (ln net.
}
shared.server = server

// Start listener.
shared.listener, err = net.Listen("tcp", server.Addr)
if err != nil {
return nil, fmt.Errorf("failed to listen: %w", err)
// Start listeners.
bindIPs := conf.GetBindIPs()
listeners := make([]net.Listener, 0, len(bindIPs))
for _, bindIP := range bindIPs {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: bindIP,
Port: int(port),
})
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}

listeners = append(listeners, listener)
log.Infof("spn/ships: http transport pier established on %s", listener.Addr())
}

// Add shared http server to list.
sharedHTTPServers[port] = shared

// Start server in service worker.
module.StartServiceWorker(
fmt.Sprintf("shared http server listener on port %d", port), 0,
func(ctx context.Context) error {
err := shared.server.Serve(shared.listener)
if !errors.Is(http.ErrServerClosed, err) {
return err
}
return nil
},
)

return shared.listener, nil
// Start servers in service workers.
for _, listener := range listeners {
serviceListener := listener
module.StartServiceWorker(
fmt.Sprintf("shared http server listener on %s", listener.Addr()), 0,
func(ctx context.Context) error {
err := shared.server.Serve(serviceListener)
if !errors.Is(http.ErrServerClosed, err) {
return err
}
return nil
},
)
}

return nil
}

func removeHTTPHandler(port uint16, path string) error {
Expand Down
8 changes: 4 additions & 4 deletions ships/http_shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ func TestSharedHTTP(t *testing.T) { //nolint:paralleltest // Test checks global
const testPort = 65100

// Register multiple handlers.
_, err := addHTTPHandler(testPort, "", ServeInfoPage)
err := addHTTPHandler(testPort, "", ServeInfoPage)
assert.NoError(t, err, "should be able to share http listener")
_, err = addHTTPHandler(testPort, "/test", ServeInfoPage)
err = addHTTPHandler(testPort, "/test", ServeInfoPage)
assert.NoError(t, err, "should be able to share http listener")
_, err = addHTTPHandler(testPort, "/test2", ServeInfoPage)
err = addHTTPHandler(testPort, "/test2", ServeInfoPage)
assert.NoError(t, err, "should be able to share http listener")
_, err = addHTTPHandler(testPort, "/", ServeInfoPage)
err = addHTTPHandler(testPort, "/", ServeInfoPage)
assert.Error(t, err, "should fail to register path twice")

// Unregister
Expand Down
16 changes: 5 additions & 11 deletions ships/pier.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ type Pier interface {
// Transport returns the transport used for this ship.
Transport() *hub.Transport

// Addr returns the underlying network address used by the listener.
Addr() net.Addr

// Abolish closes the underlying listener and cleans up any related resources.
Abolish()
}
Expand Down Expand Up @@ -50,8 +47,8 @@ func EstablishPier(transport *hub.Transport, dockingRequests chan Ship) (Pier, e
type PierBase struct {
// transport holds the transport definition of the pier.
transport *hub.Transport
// listener is the actual underlying listener.
listener net.Listener
// listeners holds the actual underlying listeners.
listeners []net.Listener

// dockingRequests is used to report new connections to the higher layer.
dockingRequests chan Ship
Expand All @@ -75,14 +72,11 @@ func (pier *PierBase) Transport() *hub.Transport {
return pier.transport
}

// Addr returns the underlying network address used by the listener.
func (pier *PierBase) Addr() net.Addr {
return pier.listener.Addr()
}

// Abolish closes the underlying listener and cleans up any related resources.
func (pier *PierBase) Abolish() {
if pier.abolishing.SetToIf(false, true) {
_ = pier.listener.Close()
for _, listener := range pier.listeners {
_ = listener.Close()
}
}
}
Loading

0 comments on commit d208af1

Please sign in to comment.