diff --git a/hopper.go b/hopper.go index dfb44bc..5e6218e 100644 --- a/hopper.go +++ b/hopper.go @@ -43,12 +43,12 @@ const ( ) type ( - // Listener defines a server which will be waiting to accept incoming connections + // Listener defines a server which will be waiting to accept incoming UDP connections Listener struct { logger *log.Logger // logger crypterIn BlockCrypt // crypter for incoming packets crypterOut BlockCrypt // crypter for outgoing packets - conn *net.UDPConn // the underlying packet connection + conn *net.UDPConn // the socket to listen on timeout time.Duration // session timeout sockbuf int // socket buffer size @@ -110,7 +110,7 @@ func ListenWithOptions(laddr string, target string, sockbuf int, timeout time.Du return l, nil } -// Start the listener +// Start the listener and wait until it's closed, it returns when the socket is closed. func (l *Listener) Start() { go l.switcher() @@ -138,10 +138,6 @@ func (l *Listener) packetIn(data []byte, raddr net.Addr) { } if packetOk { - l.incomingConnectionsLock.RLock() - conn, ok := l.incomingConnections[raddr.String()] - l.incomingConnectionsLock.RUnlock() - // encrypt or re-encrypt the packet if crypterOut is set(with new nonce) if l.crypterOut != nil { dataOut := make([]byte, len(data)+nonceSize) @@ -151,6 +147,11 @@ func (l *Listener) packetIn(data []byte, raddr net.Addr) { data = dataOut } + // load the connection from the incoming connections + l.incomingConnectionsLock.RLock() + conn, ok := l.incomingConnections[raddr.String()] + l.incomingConnectionsLock.RUnlock() + if ok { // existing connection l.watcher.WriteTimeout(nil, conn, data, time.Now().Add(l.timeout)) } else { // new connection @@ -161,24 +162,22 @@ func (l *Listener) packetIn(data []byte, raddr net.Addr) { return } + // add the connection to the incoming connections + l.addClient(raddr, conn) // log new connection log.Printf("new connection from %s to %s", raddr.String(), l.nextHop) + // watch the connection // the context is the address of incoming packet - // register the address ctx := raddr - l.incomingConnectionsLock.Lock() - l.incomingConnections[raddr.String()] = conn - l.incomingConnectionsLock.Unlock() - - // watch the connection l.watcher.ReadTimeout(ctx, conn, make([]byte, mtuLimit), time.Now().Add(l.timeout)) l.watcher.WriteTimeout(nil, conn, data, time.Now().Add(l.timeout)) // write needs not to specify the context(where the packet from) } } } -// packet switcher from clients to targets +// switcher handles the proxy connections to the next hop. +// It acts like a proxy multiplexer. func (l *Listener) switcher() { for { results, err := l.watcher.WaitIO() @@ -190,49 +189,58 @@ func (l *Listener) switcher() { for _, res := range results { switch res.Operation { case gaio.OpWrite: - // write to target complete + // done writting to proxy connection. if res.Error != nil { - l.logger.Printf("gaio write error: %+v", res) - l.cleanClient(res.Conn.RemoteAddr()) + l.logger.Printf("[switcher]write error: %#v", res) + l.removeClient(res.Conn.RemoteAddr()) continue } case gaio.OpRead: - if res.Error != nil { // any error discontinues the connection - l.logger.Printf("gaio read error: %+v", res) - l.cleanClient(res.Conn.RemoteAddr()) + // any read error from the proxy connection cleans the other side(client). + if res.Error != nil { + l.logger.Printf("[switcher]read error: %#v", res) + l.removeClient(res.Conn.RemoteAddr()) continue } - // received data from the next hop - dataFromTarget := res.Buffer[:res.Size] + // received data from the proxy connection. + dataFromProxy := res.Buffer[:res.Size] - // decrypt data from target if crypterOut is set + // decrypt data from the proxy connection if crypterOut is set. if l.crypterOut != nil { - l.crypterOut.Decrypt(dataFromTarget, dataFromTarget) - dataFromTarget = dataFromTarget[nonceSize:] + l.crypterOut.Decrypt(dataFromProxy, dataFromProxy) + dataFromProxy = dataFromProxy[nonceSize:] } - // re-encrypt data to client if crypterIn is set + // re-encrypt data if crypterIn is set. if l.crypterIn != nil { - data := make([]byte, len(dataFromTarget)+nonceSize) - copy(data[nonceSize:], dataFromTarget) + data := make([]byte, len(dataFromProxy)+nonceSize) + copy(data[nonceSize:], dataFromProxy) _, _ = io.ReadFull(rand.Reader, data[:nonceSize]) l.crypterIn.Encrypt(data, data) - dataFromTarget = data + dataFromProxy = data } - // forward data to client - l.conn.WriteTo(dataFromTarget, res.Context.(net.Addr)) + // forward the data to client via the listener. + l.conn.WriteTo(dataFromProxy, res.Context.(net.Addr)) - // fire another read-request to the connection + // fire next read-request to the proxy connection. l.watcher.ReadTimeout(res.Context, res.Conn, make([]byte, mtuLimit), time.Now().Add(l.timeout)) } } } } -func (l *Listener) cleanClient(raddr net.Addr) { +// addClient adds the client to the incoming connections map. +func (l *Listener) addClient(raddr net.Addr, conn net.Conn) { + l.incomingConnectionsLock.Lock() + l.incomingConnections[raddr.String()] = conn + l.incomingConnectionsLock.Unlock() +} + +// removeClient removes the client from the incoming connections map. +func (l *Listener) removeClient(raddr net.Addr) { l.incomingConnectionsLock.Lock() delete(l.incomingConnections, raddr.String()) l.incomingConnectionsLock.Unlock()