diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 6f06c0424..ae44e582d 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/netip" + "sync" "sync/atomic" "syscall" "unsafe" @@ -25,6 +26,7 @@ import ( type StdConn struct { sysFd int closed atomic.Bool + wg *sync.WaitGroup isV4 bool l *logrus.Logger batch int @@ -81,7 +83,14 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + return &StdConn{ + sysFd: fd, + closed: atomic.Bool{}, + wg: &sync.WaitGroup{}, + isV4: ip.Is4(), + l: l, + batch: batch, + }, err } func (u *StdConn) Rebind() error { @@ -123,6 +132,15 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { + + u.wg.Add(1) + defer func() { + u.wg.Done() + }() + if u.closed.Load() { + return + } + plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} @@ -144,7 +162,7 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - if u.closed { + if u.closed.Load() { u.l.Debug("flag for closing connection is set, exiting read loop") return } @@ -321,11 +339,20 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { } func (u *StdConn) Close() error { - //TODO: this will not interrupt the read loop - if u.closed { + if !u.closed.CompareAndSwap(false, true) { + // already closed by e.g. other thread return nil } - u.closed = true + err := syscall.Shutdown(u.sysFd, syscall.SHUT_RDWR) + if err != nil { + errno, ok := err.(syscall.Errno) + // connection might have been terminated by remote before + wasDisconnected := ok && (errno == syscall.ENOTCONN) + if !wasDisconnected { + panic(fmt.Sprintf("error while shutdown of UDP socket: %v", err)) + } + } + u.wg.Wait() return syscall.Close(u.sysFd) }