Skip to content

Commit

Permalink
Fix the issue with corrupting source connection
Browse files Browse the repository at this point in the history
  • Loading branch information
ameshkov committed Feb 2, 2024
1 parent 81f5ce2 commit 869234a
Showing 1 changed file with 40 additions and 16 deletions.
56 changes: 40 additions & 16 deletions internal/pipe/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ package pipe

import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
"os"
"strings"
Expand Down Expand Up @@ -318,20 +318,27 @@ func (s *Server) serveConn(conn net.Conn) {
s.dstConns[dstConn] = struct{}{}
}()

if s.serverMode && !s.checkAuth(conn) {
// Client connection has not been authorized, closing the connection.
return
var srcRw, dstRw io.ReadWriter
srcRw = conn
dstRw = dstConn

if s.serverMode {
ok, newRw := s.checkAuth(conn)
if !ok {
// Client connection has not been authorized, closing the connection.
return
}

// Replace the source reader since some bytes in the original srcRw
// may have been read as a part of authentication process.
srcRw = newRw
}

if !s.serverMode {
// Authorize the client if necessary.
s.auth(dstConn)
}

var srcRw, dstRw io.ReadWriter
srcRw = conn
dstRw = dstConn

// When the client communicates with the server it uses encoded messages so
// connection between them needs to be wrapped. In server mode it is the
// source connection, in client mode it is the destination connection.
Expand All @@ -357,10 +364,12 @@ func (s *Server) auth(dstRw io.ReadWriter) {
// checkAuth checks the first bytes sent by the client and looks for the
// password there. It also implements the active probing protection by detecting
// HTTP requests and returning a default stub HTTP response if detected.
func (s *Server) checkAuth(srcConn net.Conn) (ok bool) {
// The function returns an io.ReadWriter that should be used further to work
// with this connection.
func (s *Server) checkAuth(srcConn net.Conn) (ok bool, rw io.ReadWriter) {
if s.password == "" {
// No authentication and probing checks.
return true
return true, srcConn
}

// Give up to 60 seconds on the authentication.
Expand All @@ -370,19 +379,34 @@ func (s *Server) checkAuth(srcConn net.Conn) (ok bool) {
_ = srcConn.SetReadDeadline(time.Time{})
}()

r := textproto.NewReader(bufio.NewReader(srcConn))
// bufio.Reader may read more than requested, so it's crucial to use
// TeeReader so that we could restore the bytes that has been read.
var buf bytes.Buffer
r := bufio.NewReader(io.TeeReader(srcConn, &buf))

line, err := r.ReadLine()
lineBytes, err := r.ReadBytes('\n')
if err != nil {
log.Debug("Could not read password from the first bytes: %v", err)

return false
return false, srcConn
}
line := strings.TrimSpace(string(lineBytes))

if s.password == line {
log.Debug("Authentication successful")

return true
// Skip the line that contains the password, we don't need it anymore.
_, _ = buf.ReadBytes('\n')

// Now that authentication has been successful, return a new
// io.ReadWriter that restores the first bytes save for the password
// bytes.
rw = &multiReadWriter{
Reader: io.MultiReader(bytes.NewReader(buf.Bytes()), srcConn),
Writer: srcConn,
}

return true, rw
}

log.Debug("Authentication unsuccessful, check if probing detection is required")
Expand All @@ -391,7 +415,7 @@ func (s *Server) checkAuth(srcConn net.Conn) (ok bool) {
requestURI, proto, ok2 := strings.Cut(rest, " ")
if !ok1 || !ok2 || !strings.HasPrefix(proto, "HTTP/1") {
// Not HTTP protocol for sure, existing right away.
return false
return false, srcConn
}

log.Debug("Detected HTTP: %s %s %s", method, requestURI, proto)
Expand All @@ -414,7 +438,7 @@ func (s *Server) checkAuth(srcConn net.Conn) (ok bool) {
// Writing the stub response.
_, _ = srcConn.Write([]byte(response))

return false
return false, srcConn
}

// multiReadWriter is a helper object that's used for replacing io.ReadWriter
Expand Down

0 comments on commit 869234a

Please sign in to comment.