diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000..eb1a77f --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,104 @@ +name: Build + +'on': + 'push': + 'tags': + - 'v*' + 'branches': + - '*' + 'pull_request': + +jobs: + tests: + runs-on: ${{ matrix.os }} + env: + GO111MODULE: "on" + strategy: + matrix: + os: + - windows-latest + - macos-latest + - ubuntu-latest + + steps: + - uses: actions/checkout@master + + - uses: actions/setup-go@v3 + with: + go-version: 1.x + + - name: Run tests + run: |- + go test -race -v -bench=. -coverprofile=coverage.txt -covermode=atomic ./... + + - name: Upload coverage + uses: codecov/codecov-action@v3 + if: "success() && matrix.os == 'ubuntu-latest'" + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./coverage.txt + + build: + needs: + - tests + runs-on: ubuntu-latest + env: + GO111MODULE: "on" + steps: + - uses: actions/checkout@master + + - uses: actions/setup-go@v3 + with: + go-version: 1.x + + - name: Prepare environment + run: |- + RELEASE_VERSION="${GITHUB_REF##*/}" + if [[ "${RELEASE_VERSION}" != v* ]]; then RELEASE_VERSION='dev'; fi + echo "RELEASE_VERSION=\"${RELEASE_VERSION}\"" >> $GITHUB_ENV + + # Win + - run: GOOS=windows GOARCH=386 VERSION=${RELEASE_VERSION} make release + - run: GOOS=windows GOARCH=amd64 VERSION=${RELEASE_VERSION} make release + + # MacOS + - run: GOOS=darwin GOARCH=amd64 VERSION=${RELEASE_VERSION} make release + + # MacOS ARM + - run: GOOS=darwin GOARCH=arm64 VERSION=${RELEASE_VERSION} make release + + # Linux X86 + - run: GOOS=linux GOARCH=386 VERSION=${RELEASE_VERSION} make release + - run: GOOS=linux GOARCH=amd64 VERSION=${RELEASE_VERSION} make release + + # Linux ARM + - run: GOOS=linux GOARCH=arm GOARM=6 VERSION=${RELEASE_VERSION} make release + - run: GOOS=linux GOARCH=arm64 VERSION=${RELEASE_VERSION} make release + + # Linux MIPS/MIPSLE + - run: GOOS=linux GOARCH=mips GOMIPS=softfloat VERSION=${RELEASE_VERSION} make release + - run: GOOS=linux GOARCH=mipsle GOMIPS=softfloat VERSION=${RELEASE_VERSION} make release + + # FreeBSD X86 + - run: GOOS=freebsd GOARCH=386 VERSION=${RELEASE_VERSION} make release + - run: GOOS=freebsd GOARCH=amd64 VERSION=${RELEASE_VERSION} make release + + # FreeBSD ARM/ARM64 + - run: GOOS=freebsd GOARCH=arm GOARM=6 VERSION=${RELEASE_VERSION} make release + - run: GOOS=freebsd GOARCH=arm64 VERSION=${RELEASE_VERSION} make release + + - run: ls -l build/udptlspipe-* + + - name: Create release + if: startsWith(github.ref, 'refs/tags/v') + id: create_release + uses: softprops/action-gh-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + body: See [CHANGELOG.md](./CHANGELOG.md) for the list of changes. + draft: false + prerelease: false + files: | + build/udptlspipe-*.tar.gz + build/udptlspipe-*.zip diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..22d7348 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,26 @@ +name: golangci-lint +'on': + 'push': + 'tags': + - 'v*' + 'branches': + - '*' + 'pull_request': + +jobs: + golangci: + runs-on: + ${{ matrix.os }} + strategy: + matrix: + os: + - ubuntu-latest + - macos-latest + steps: + - uses: actions/checkout@v2 + - name: golangci-lint + uses: golangci/golangci-lint-action@v2.3.0 + with: + # This field is required. Dont set the patch version to always use + # the latest patch version. + version: v1.54.1 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8875a15 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.DS_Store +.idea +.vscode +build +udptlspipe diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..6b20c8e --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,66 @@ +# options for analysis running +run: + # default concurrency is a available CPU number + concurrency: 4 + + # timeout for analysis, e.g. 30s, 5m, default is 1m + deadline: 2m + + # which files to skip: they will be analyzed, but issues from them + # won't be reported. Default value is empty list, but there is + # no need to include all autogenerated files, we confidently recognize + # autogenerated files. If it's not please let us know. + skip-files: + - ".*generated.*" + - ".*_test.go" + +# all available settings of specific linters +linters-settings: + gocyclo: + min-complexity: 20 + lll: + line-length: 200 + +linters: + enable: + - errcheck + - govet + - ineffassign + - staticcheck + - unused + - dupl + - gocyclo + - goimports + - revive + - gosec + - misspell + - stylecheck + - unconvert + disable-all: true + fast: true + +issues: + exclude-use-default: false + + # List of regexps of issue texts to exclude, empty list by default. + # But independently from this option we use default exclude patterns, + # it can be disabled by `exclude-use-default: false`. To list all + # excluded by default patterns execute `golangci-lint run --help` + exclude: + # errcheck defer Close + - error return value not checked \(defer .*\.Close()\) + # errcheck: Almost all programs ignore errors on these functions and in most cases it's ok + - Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*printf?|os\.(Un)?Setenv). is not checked + # gosec: Duplicated errcheck checks + - G104 + # gosec: Expect file permissions to be 0600 or less + - G302 + # errcheck defer Close + - error return value not checked \(defer .*\.Close()\) + # gosec: False positive is triggered by 'src, err := os.ReadFile(filename)' + - Potential file inclusion via variable + # gosec: TLS InsecureSkipVerify may be true + # We have a configuration option that allows to do this + - G402 + # gosec: Use of weak random number generator + - G404 diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..cce4ca7 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,22 @@ +# udptlspipe changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog][keepachangelog], and this project +adheres to [Semantic Versioning][semver]. + +[keepachangelog]: https://keepachangelog.com/en/1.0.0/ + +[semver]: https://semver.org/spec/v2.0.0.html + +## [Unreleased] + +[unreleased]: https://github.com/ameshkov/udptlspipe/compare/v1.0.0...HEAD + +## [1.0.0] - 2024-02-02 + +### Added + +* The first version with base functionality. + +[1.0.0]: https://github.com/ameshkov/udptlspipe/releases/tag/v1.0.0 \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e1bcc4d --- /dev/null +++ b/Makefile @@ -0,0 +1,44 @@ +NAME=udptlspipe +BASE_BUILDDIR=build +BUILDNAME=$(GOOS)-$(GOARCH) +BUILDDIR=$(BASE_BUILDDIR)/$(BUILDNAME) +VERSION?=v0.0-dev +VERSIONPKG=github.com/ameshkov/udptlspipe/internal/version + +ifeq ($(GOOS),windows) + ext=.exe + archiveCmd=zip -9 -r $(NAME)-$(BUILDNAME)-$(VERSION).zip $(BUILDNAME) +else + ext= + archiveCmd=tar czpvf $(NAME)-$(BUILDNAME)-$(VERSION).tar.gz $(BUILDNAME) +endif + +.PHONY: default +default: build + +build: clean + go build -ldflags "-X $(VERSIONPKG).version=$(VERSION)" + +release: check-env-release + mkdir -p $(BUILDDIR) + cp LICENSE $(BUILDDIR)/ + cp README.md $(BUILDDIR)/ + CGO_ENABLED=0 GOOS=$(GOOS) GOARCH=$(GOARCH) go build -ldflags "-X $(VERSIONPKG).version=$(VERSION)" -o $(BUILDDIR)/$(NAME)$(ext) + cd $(BASE_BUILDDIR) ; $(archiveCmd) + +test: + go test -race -v -bench=. ./... + +clean: + go clean + rm -rf $(BASE_BUILDDIR) + +check-env-release: + @ if [ "$(GOOS)" = "" ]; then \ + echo "Environment variable GOOS not set"; \ + exit 1; \ + fi + @ if [ "$(GOARCH)" = "" ]; then \ + echo "Environment variable GOARCH not set"; \ + exit 1; \ + fi diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8457b73 --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module github.com/ameshkov/udptlspipe + +go 1.21.6 + +require ( + github.com/AdguardTeam/golibs v0.20.0 + github.com/jessevdk/go-flags v1.5.0 + github.com/stretchr/testify v1.8.4 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.15.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..dc9b177 --- /dev/null +++ b/go.sum @@ -0,0 +1,17 @@ +github.com/AdguardTeam/golibs v0.20.0 h1:A9FIdYq7wUKhFYy3z+YZ/Aw5oFUYgW+xgaVAJ0pnnPY= +github.com/AdguardTeam/golibs v0.20.0/go.mod h1:3WunclLLfrVAq7fYQRhd6f168FHOEMssnipVXCxDL/w= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= +github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go new file mode 100644 index 0000000..993e3a2 --- /dev/null +++ b/internal/cmd/cmd.go @@ -0,0 +1,39 @@ +// Package cmd is the entry point of the tool. +package cmd + +import ( + "errors" + "fmt" + "os" + + "github.com/AdguardTeam/golibs/log" + "github.com/ameshkov/udptlspipe/internal/version" + goFlags "github.com/jessevdk/go-flags" +) + +// Main is the entry point for the command-line tool. +func Main() { + if len(os.Args) == 2 && (os.Args[1] == "--version" || os.Args[1] == "-v") { + fmt.Printf("udptlspipe version: %s\n", version.Version()) + + os.Exit(0) + } + + o, err := parseOptions() + var flagErr *goFlags.Error + if errors.As(err, &flagErr) && flagErr.Type == goFlags.ErrHelp { + // This is a special case when we exit process here as we received + // --help. + os.Exit(0) + } + + if err != nil { + log.Error("Failed to parse args: %v", err) + + os.Exit(1) + } + + if o.Verbose { + log.SetLevel(log.DEBUG) + } +} diff --git a/internal/cmd/options.go b/internal/cmd/options.go new file mode 100644 index 0000000..622bd1d --- /dev/null +++ b/internal/cmd/options.go @@ -0,0 +1,45 @@ +package cmd + +import ( + "fmt" + "os" + + goFlags "github.com/jessevdk/go-flags" +) + +// Options represents command-line arguments. +type Options struct { + // ServerMode controls whether the tool works in the server mode. + // By default, the tool will work in the client mode. + ServerMode bool `short:"s" long:"pipe" description:"Enables pipe mode." optional:"yes" optional-value:"true"` + + // ListenAddr is the address the tool will be listening to. If it's in the + // pipe mode, it will listen to tcp://, if it's in the client mode, it + // will listen to udp://. + ListenAddr string `short:"l" long:"listen" description:"Address the tool will be listening to." value-name:":"` + + // DestinationAddr is the address the tool will connect to. Depending on the + // mode (pipe or client) this address has different semantics. In the + // client mode this is the address of the udptlspipe pipe. In the pipe + // mode this is the address where the received traffic will be passed. + DestinationAddr string `short:"d" long:"destination" description:"Address the tool will connect to." value-name:":"` + + // Verbose defines whether we should write the DEBUG-level log or not. + Verbose bool `short:"v" long:"verbose" description:"Verbose output (optional)." optional:"yes" optional-value:"true"` +} + +// parseOptions parses os.Args and creates the Options struct. +func parseOptions() (o *Options, err error) { + opts := &Options{} + parser := goFlags.NewParser(opts, goFlags.Default|goFlags.IgnoreUnknown) + remainingArgs, err := parser.ParseArgs(os.Args[1:]) + if err != nil { + return nil, err + } + + if len(remainingArgs) > 0 { + return nil, fmt.Errorf("unknown arguments: %v", remainingArgs) + } + + return opts, nil +} diff --git a/internal/pipe/server.go b/internal/pipe/server.go new file mode 100644 index 0000000..bc56f44 --- /dev/null +++ b/internal/pipe/server.go @@ -0,0 +1,340 @@ +// Package pipe implements the pipe logic, i.e. listening for TLS or UDP +// connections and proxying data to the target destination. +package pipe + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "os" + "sync" + "time" + + "github.com/ameshkov/udptlspipe/internal/udp" + + "github.com/AdguardTeam/golibs/log" + "github.com/ameshkov/udptlspipe/internal/tunnel" +) + +// Server represents an udptlspipe pipe. Depending on whether it is created in +// server- or client- mode, it listens to TLS or UDP connections and pipes the +// data to the destination. +type Server struct { + listenAddr string + destinationAddr string + serverMode bool + + // listen is the TLS listener for incoming connections + listen net.Listener + + // srcConns is a set that is used to track active incoming TCP connections. + srcConns map[net.Conn]struct{} + srcConnsMu *sync.Mutex + + // dstConns is a set that is used to track active connections to the proxy + // destination. + dstConns map[net.Conn]struct{} + dstConnsMu *sync.Mutex + + // Shutdown handling + // -- + + // lock protects started, tcpListener and udpListener. + lock sync.RWMutex + started bool + // wg tracks active workers. Stop won't finish until there is at least + // won't finish until there's at least one active worker. + wg sync.WaitGroup +} + +// NewServer creates a new instance of a *Server. +func NewServer(listenAddr string, destinationAddr string, serverMode bool) (s *Server, err error) { + return &Server{ + listenAddr: listenAddr, + destinationAddr: destinationAddr, + serverMode: serverMode, + srcConns: map[net.Conn]struct{}{}, + srcConnsMu: &sync.Mutex{}, + dstConns: map[net.Conn]struct{}{}, + dstConnsMu: &sync.Mutex{}, + }, nil +} + +// Addr returns the address the pipe listens to if it is started or nil. +func (s *Server) Addr() (addr net.Addr) { + if s.listen == nil { + return nil + } + + return s.listen.Addr() +} + +// Start starts the pipe, exits immediately if it failed to start +// listening. Start returns once all servers are considered up. +func (s *Server) Start() (err error) { + log.Info("Starting the pipe %s", s) + + s.lock.Lock() + defer s.lock.Unlock() + + if s.started { + return errors.New("pipe is already started") + } + + s.listen, err = s.createListener() + if err != nil { + return fmt.Errorf("failed to start pipe: %w", err) + } + + s.wg.Add(1) + go s.serve() + + s.started = true + log.Info("Server has been started") + + return nil +} + +// createListener creates a TLS listener in server mode and UDP listener in +// client mode. +func (s *Server) createListener() (l net.Listener, err error) { + if s.serverMode { + // Using a self-signed TLS certificate issued for example.org as the client + // will not verify the certificate anyways. + // TODO(ameshkov): Allow configuring the certificate. + tlsConfig := createServerTLSConfig("example.org") + l, err = tls.Listen("tcp", s.listenAddr, tlsConfig) + if err != nil { + return nil, err + } + } else { + l, err = udp.Listen("udp", s.listenAddr) + if err != nil { + return nil, err + } + } + + return l, nil +} + +// Shutdown stops the pipe and waits for all active connections to close. +func (s *Server) Shutdown(ctx context.Context) (err error) { + log.Info("Stopping the pipe %s", s) + + s.stopServeLoop() + + // Closing the udpConn thread. + log.OnCloserError(s.listen, log.DEBUG) + + // Closing active TCP connections. + s.closeConnections(s.srcConnsMu, s.srcConns) + + // Closing active UDP connections. + s.closeConnections(s.dstConnsMu, s.dstConns) + + // Wait until all worker threads finish working + err = s.waitShutdown(ctx) + + log.Info("Server has been stopped") + + return err +} + +// closeConnections closes all active connections. +func (s *Server) closeConnections(mu *sync.Mutex, conns map[net.Conn]struct{}) { + mu.Lock() + defer mu.Unlock() + + for c := range conns { + _ = c.SetReadDeadline(time.Unix(1, 0)) + + log.OnCloserError(c, log.DEBUG) + } +} + +// stopServeLoop sets the started flag to false thus stopping the serving loop. +func (s *Server) stopServeLoop() { + s.lock.Lock() + defer s.lock.Unlock() + + s.started = false +} + +// type check +var _ fmt.Stringer = (*Server)(nil) + +// String implements the fmt.Stringer interface for *Server. +func (s *Server) String() (str string) { + switch s.serverMode { + case true: + return fmt.Sprintf("tls://%s <-> udp://%s", s.listenAddr, s.destinationAddr) + default: + return fmt.Sprintf("udp://%s <-> tls://%s", s.listenAddr, s.destinationAddr) + } +} + +// serve implements the pipe logic, i.e. accepts new connections and tunnels +// data to the destination. +func (s *Server) serve() { + defer s.wg.Done() + defer log.OnPanicAndExit("serve", 1) + + defer log.OnCloserError(s.listen, log.DEBUG) + + for s.isStarted() { + err := s.acceptConn() + if err != nil { + if !s.isStarted() { + return + } + + log.Error("exit serve loop due to: %v", err) + + return + } + } +} + +// acceptConn accepts new incoming and tracks active connections. +func (s *Server) acceptConn() (err error) { + conn, err := s.listen.Accept() + if err != nil { + // This type of errors should not lead to stopping the pipe. + if errors.Is(os.ErrDeadlineExceeded, err) { + return nil + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return nil + } + + return err + } + + func() { + s.srcConnsMu.Lock() + defer s.srcConnsMu.Unlock() + + // Track the connection to allow unblocking reads on shutdown. + s.srcConns[conn] = struct{}{} + }() + + s.wg.Add(1) + go s.serveConn(conn) + + return nil +} + +// closeSrcConn closes the source connection and cleans up after it. +func (s *Server) closeSrcConn(conn net.Conn) { + log.OnCloserError(conn, log.DEBUG) + + s.srcConnsMu.Lock() + defer s.srcConnsMu.Unlock() + + delete(s.srcConns, conn) +} + +// closeDstConn closes the destination connection and cleans up after it. +func (s *Server) closeDstConn(conn net.Conn) { + // No destination connection opened yet, do nothing. + if conn != nil { + return + } + + log.OnCloserError(conn, log.DEBUG) + + s.dstConnsMu.Lock() + defer s.dstConnsMu.Unlock() + + delete(s.dstConns, conn) +} + +// serveConn processes incoming connections, opens the connection to the +// destination and tunnels data between both connections. +func (s *Server) serveConn(conn net.Conn) { + var dstConn net.Conn + + defer func() { + s.wg.Done() + + s.closeSrcConn(conn) + s.closeDstConn(dstConn) + }() + + dstConn, err := s.dialDst() + if err != nil { + log.Error("failed to connect to %s: %v", s.destinationAddr, err) + } + + func() { + s.dstConnsMu.Lock() + defer s.dstConnsMu.Unlock() + + // Track the connection to allow unblocking reads on shutdown. + s.dstConns[dstConn] = struct{}{} + }() + + var srcRw, dstRw io.ReadWriter + srcRw = conn + dstRw = dstConn + + // When the client communicates with the server it uses encoded messages so + // connection between them needs to be wrapped. In server mode it is the + // source connection, in client mode it is the destination connection. + if s.serverMode { + srcRw = tunnel.NewMsgReadWriter(srcRw) + } else { + dstRw = tunnel.NewMsgReadWriter(dstRw) + } + + tunnel.Tunnel(srcRw, dstRw) +} + +// dialDst creates a connection to the destination. Depending on the mode the +// server operates in, it is either a TLS connection or a UDP connection. +func (s *Server) dialDst() (conn net.Conn, err error) { + if s.serverMode { + return net.Dial("udp", s.destinationAddr) + } + + return tls.Dial("tcp", s.destinationAddr, &tls.Config{ + // TODO(ameshkov): Make verification possible. + InsecureSkipVerify: true, + }) +} + +// isStarted safely checks whether the pipe is started or not. +func (s *Server) isStarted() (started bool) { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.started +} + +// waitShutdown waits either until context deadline OR Server.wg. +func (s *Server) waitShutdown(ctx context.Context) (err error) { + // Using this channel to wait until all goroutines finish their work. + closed := make(chan struct{}) + go func() { + defer log.OnPanic("waitShutdown") + + // Wait until all active workers finished its work. + s.wg.Wait() + close(closed) + }() + + var ctxErr error + select { + case <-closed: + // Do nothing here. + case <-ctx.Done(): + ctxErr = ctx.Err() + } + + return ctxErr +} diff --git a/internal/pipe/server_test.go b/internal/pipe/server_test.go new file mode 100644 index 0000000..4728914 --- /dev/null +++ b/internal/pipe/server_test.go @@ -0,0 +1,70 @@ +package pipe_test + +import ( + "context" + "fmt" + "io" + "net" + "testing" + + "github.com/AdguardTeam/golibs/log" + "github.com/ameshkov/udptlspipe/internal/pipe" + "github.com/ameshkov/udptlspipe/internal/testutil" + "github.com/stretchr/testify/require" +) + +func TestServer_Start_integration(t *testing.T) { + // Start the echo server where the pipe will proxy data. + udpSrv := &testutil.UDPEchoServer{} + err := udpSrv.Start() + require.NoError(t, err) + defer log.OnCloserError(udpSrv, log.DEBUG) + + // Create the pipe server proxying data to that UDP echo server. + pipeServer, err := pipe.NewServer("127.0.0.1:0", udpSrv.Addr(), true) + require.NoError(t, err) + + // Start the pipe server, it's ready to accept new connections. + err = pipeServer.Start() + require.NoError(t, err) + defer func() { + require.NoError(t, pipeServer.Shutdown(context.Background())) + }() + + // Now create the pipe client connected to that server. + pipeClient, err := pipe.NewServer("127.0.0.1:0", pipeServer.Addr().String(), false) + require.NoError(t, err) + + // Start the pipe client, it's ready to accept new connections. + err = pipeClient.Start() + require.NoError(t, err) + defer func() { + require.NoError(t, pipeServer.Shutdown(context.Background())) + }() + + // Connect to the pipe. + pipeConn, err := net.Dial("udp", pipeClient.Addr().String()) + require.NoError(t, err) + + for i := 0; i < 10; i++ { + strMsg := fmt.Sprintf("test message %d", i) + msg := []byte(strMsg) + + // Write a message to the pipe. + _, err = pipeConn.Write(msg) + require.NoError(t, err) + + // Now read the response from the echo server. + buf := make([]byte, len(msg)) + _, err = io.ReadFull(pipeConn, buf) + require.NoError(t, err) + + // Check that the echo pipe response was received correctly. + require.Equal(t, msg, buf) + + // Now check that the echo pipe received the message correctly. + udpMsg := udpSrv.ReceivedMsg(i) + require.NotNil(t, udpMsg) + require.Equal(t, msg, udpMsg) + } +} diff --git a/internal/pipe/tlsconfig.go b/internal/pipe/tlsconfig.go new file mode 100644 index 0000000..a26b196 --- /dev/null +++ b/internal/pipe/tlsconfig.go @@ -0,0 +1,83 @@ +package pipe + +import ( + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "time" +) + +// createServerTLSConfig creates a TLS pipe configuration with the specified +// pipe name. It returns a *tls.Config that will be used by the pipe. +func createServerTLSConfig(tlsServerName string) (tlsConfig *tls.Config) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(fmt.Sprintf("cannot generate RSA key: %v", err)) + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + panic(fmt.Sprintf("failed to generate serial number: %v", err)) + } + + notBefore := time.Now() + notAfter := notBefore.Add(5 * 365 * time.Hour * 24) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"AdGuard Tests"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + template.DNSNames = append(template.DNSNames, tlsServerName) + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey) + if err != nil { + panic(fmt.Sprintf("failed to create certificate: %v", err)) + } + + certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + panic(fmt.Sprintf("failed to create certificate: %v", err)) + } + + roots := x509.NewCertPool() + roots.AppendCertsFromPEM(certPem) + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + ServerName: tlsServerName, + RootCAs: roots, + MinVersion: tls.VersionTLS12, + } + + return tlsConfig +} + +func publicKey(priv any) (pub any) { + switch k := priv.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + default: + return nil + } +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..20cab4d --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,3 @@ +// Package testutil implements a number of test structs and functions used +// in unit-tests. +package testutil diff --git a/internal/testutil/udpechoserver.go b/internal/testutil/udpechoserver.go new file mode 100644 index 0000000..9e861b7 --- /dev/null +++ b/internal/testutil/udpechoserver.go @@ -0,0 +1,105 @@ +package testutil + +import ( + "errors" + "io" + "net" + "sync" + + "github.com/ameshkov/udptlspipe/internal/udp" +) + +// UDPEchoServer is a test UDP pipe that accepts incoming connections and saves +// the information that it has received. +type UDPEchoServer struct { + listen net.Listener + received [][]byte + + mu sync.Mutex +} + +// ReceivedMsg returns the message received with the specified number of nil +// if there are no. Messages numbers start with 0. +func (s *UDPEchoServer) ReceivedMsg(num int) (b []byte) { + s.mu.Lock() + defer s.mu.Unlock() + + if len(s.received) < num { + return nil + } + + return s.received[num] +} + +// Addr returns the address the pipe listens to. +func (s *UDPEchoServer) Addr() (str string) { + if s.listen == nil { + return "" + } + + return s.listen.Addr().String() +} + +// Start starts the echo pipe. +func (s *UDPEchoServer) Start() (err error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.listen != nil { + return errors.New("pipe is already started") + } + + s.listen, err = udp.Listen("udp", "127.0.0.1:0") + + if err != nil { + return err + } + + go s.serve() + + return nil +} + +// type check +var _ io.Closer = (*UDPEchoServer)(nil) + +// Close implements the io.Closer interface for *UDPEchoServer. +func (s *UDPEchoServer) Close() (err error) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.listen.Close() +} + +// serve implements the serving loop. +func (s *UDPEchoServer) serve() { + for { + conn, err := s.listen.Accept() + if err != nil { + // For simplicity, stop the pipe here. + return + } + + go s.serveConn(conn) + } +} + +// serveConn handles one connection. +func (s *UDPEchoServer) serveConn(conn net.Conn) { + buf := make([]byte, 65536) + + for { + n, err := conn.Read(buf) + if err != nil { + return + } + + msg := make([]byte, n) + copy(msg, buf[:n]) + _, _ = conn.Write(msg) + + s.mu.Lock() + s.received = append(s.received, msg) + s.mu.Unlock() + } +} diff --git a/internal/tunnel/msgreadwriter.go b/internal/tunnel/msgreadwriter.go new file mode 100644 index 0000000..b575662 --- /dev/null +++ b/internal/tunnel/msgreadwriter.go @@ -0,0 +1,89 @@ +package tunnel + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/AdguardTeam/golibs/log" +) + +// MaxMessageLength is the maximum length that is safe to use. +// TODO(ameshkov): Make it configurable. +const MaxMessageLength = 1280 + +// MsgReadWriter is a wrapper over io.ReadWriter that encodes messages written +// to and read from the base writer. +type MsgReadWriter struct { + base io.ReadWriter +} + +// NewMsgReadWriter creates a new instance of *MsgReadWriter. +func NewMsgReadWriter(base io.ReadWriter) (rw *MsgReadWriter) { + return &MsgReadWriter{base: base} +} + +// type check +var _ io.ReadWriter = (*MsgReadWriter)(nil) + +// Read implements the io.ReadWriter interface for *MsgReadWriter. +func (rw *MsgReadWriter) Read(b []byte) (n int, err error) { + var length uint16 + err = binary.Read(rw.base, binary.BigEndian, &length) + if err != nil { + return 0, fmt.Errorf("reading binary data: %w", err) + } + + if length > MaxMessageLength { + // Warn the user that this may not work correctly. + log.Error( + "Warning: received message of length %d larger than %d, considering reducing the MTU", + length, + MaxMessageLength, + ) + } + + if len(b) < int(length) { + return 0, fmt.Errorf("message length %d is greater than the buffer size %d", length, len(b)) + } + + n, err = io.ReadFull(rw.base, b[:length]) + if err != nil { + return 0, fmt.Errorf("reading message: %w", err) + } + + return n, nil +} + +// Write implements the io.ReadWriter interface for *MsgReadWriter. +func (rw *MsgReadWriter) Write(b []byte) (n int, err error) { + if len(b) > MaxMessageLength { + // Warn the user that this may not work correctly. + log.Error( + "Warning: trying to write a message of length %d larger than %d, considering reducing the MTU", + len(b), + MaxMessageLength, + ) + } + + msg := Pack(b) + + n, err = rw.base.Write(msg) + + if err != nil { + return 0, err + } + + // Subtract the prefix length. + return n - 2, nil +} + +// Pack packs the message to be sent over the tunnel. +func Pack(b []byte) (msg []byte) { + msg = make([]byte, len(b)+2) + + binary.BigEndian.PutUint16(msg[:2], uint16(len(b))) + copy(msg[2:], b) + + return msg +} diff --git a/internal/tunnel/tunnel.go b/internal/tunnel/tunnel.go new file mode 100644 index 0000000..36f521a --- /dev/null +++ b/internal/tunnel/tunnel.go @@ -0,0 +1,54 @@ +// Package tunnel implements the tunneling logic for copying data between two +// network connections both sides. +package tunnel + +import ( + "fmt" + "io" + "sync" + + "github.com/AdguardTeam/golibs/log" +) + +// Tunnel passes data between two connections. +func Tunnel(pipeName string, left io.ReadWriter, right io.ReadWriter) { + wg := &sync.WaitGroup{} + wg.Add(2) + + go copyConn(fmt.Sprintf("%s left->right", pipeName), left, right, wg) + go copyConn(fmt.Sprintf("%s left<-right", pipeName), right, left, wg) + + wg.Wait() +} + +// copyConn copies data from reader r to writer r. +func copyConn(pipeName string, r io.Reader, w io.Writer, wg *sync.WaitGroup) { + defer wg.Done() + + buf := make([]byte, 65536) + var n int + var err error + + for { + n, err = r.Read(buf) + + if err != nil { + log.Debug("failed to read: %v", err) + + return + } + + if n == 0 { + continue + } + + log.Debug("%s: copying %d bytes", pipeName, n) + + _, err = w.Write(buf[:n]) + if err != nil { + log.Debug("failed to write: %v", err) + + return + } + } +} diff --git a/internal/udp/listener.go b/internal/udp/listener.go new file mode 100644 index 0000000..cc4d865 --- /dev/null +++ b/internal/udp/listener.go @@ -0,0 +1,256 @@ +// Package udp implements helper structures for working with UDP. +package udp + +import ( + "errors" + "net" + "sync" + "time" + + "github.com/AdguardTeam/golibs/log" +) + +// Listener is a struct that implements net.Listener interface for working +// with UDP. This is achieved by maintaining an internal "nat-like" table +// with destinations. +type Listener struct { + conn *net.UDPConn + + // natTable is a table which maps peer addresses to udpConn structs. + // Whenever a new packet is received, Listener looks up if there's + // already a udpConn for the peer address and either creates a new one + // or adds the packet to the existing one. + natTable map[string]*udpConn + natTableMu sync.Mutex + + chanAccept chan *udpConn + chanClosed chan struct{} +} + +// Listen creates a new *Listener and is supposed to be a function similar +// to net.Listen, but for UDP only. +func Listen(network, addr string) (l *Listener, err error) { + listenAddr, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return nil, err + } + + l = &Listener{ + natTable: map[string]*udpConn{}, + chanAccept: make(chan *udpConn, 256), + chanClosed: make(chan struct{}, 1), + } + + l.conn, err = net.ListenUDP(network, listenAddr) + if err != nil { + return nil, err + } + + go l.readLoop() + + return l, nil +} + +// type check. +var _ net.Listener = (*Listener)(nil) + +// Accept implements the net.Listener interface for *Listener. +func (l *Listener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.chanAccept: + return conn, nil + case <-l.chanClosed: + return nil, net.ErrClosed + } +} + +// Close implements the net.Listener interface for *Listener. +func (l *Listener) Close() (err error) { + close(l.chanClosed) + + l.natTableMu.Lock() + for _, c := range l.natTable { + log.OnCloserError(c, log.DEBUG) + } + l.natTableMu.Unlock() + + return l.conn.Close() +} + +// Addr implements the net.Listener interface for *Listener. +func (l *Listener) Addr() (addr net.Addr) { + return l.conn.LocalAddr() +} + +// readLoop implements the listener logic, it reads incoming data and passes it +// to the corresponding udpConn. When a new udpConn is created, it is written +// to the chanAccept channel. +func (l *Listener) readLoop() { + buf := make([]byte, 65536) + + for { + n, addr, err := l.conn.ReadFromUDP(buf) + + if err != nil || n == 0 { + if errors.Is(err, net.ErrClosed) { + return + } + + // TODO(ameshkov): Handle errors better here. + + continue + } + + msg := make([]byte, n) + copy(msg, buf[:n]) + l.acceptMsg(addr, msg) + } +} + +// acceptMsg passes the message to the corresponding udpConn. +func (l *Listener) acceptMsg(addr *net.UDPAddr, msg []byte) { + l.natTableMu.Lock() + defer l.natTableMu.Unlock() + + key := addr.String() + conn, _ := l.natTable[key] + if conn == nil || conn.isClosed() { + conn = newUDPConn(addr, l.conn) + l.natTable[key] = conn + + l.chanAccept <- conn + } + + conn.addMsg(msg) +} + +// udpConn represents a connection with a single peer. +type udpConn struct { + peerAddr *net.UDPAddr + conn *net.UDPConn + + remaining []byte + + closed bool + closedMu sync.Mutex + + chanMsg chan []byte + chanClosed chan struct{} +} + +// newUDPConn creates a new *udpConn for the specified peer. +func newUDPConn(peerAddr *net.UDPAddr, baseConn *net.UDPConn) (conn *udpConn) { + return &udpConn{ + peerAddr: peerAddr, + conn: baseConn, + chanMsg: make(chan []byte, 256), + chanClosed: make(chan struct{}, 1), + } +} + +// addMsg adds a new byte array that can be then read from this connection. +func (c *udpConn) addMsg(b []byte) { + c.chanMsg <- b +} + +// isClosed returns true if the connection is closed. +func (c *udpConn) isClosed() (ok bool) { + c.closedMu.Lock() + defer c.closedMu.Unlock() + + return c.closed +} + +// type check +var _ net.Conn = (*udpConn)(nil) + +// Read implements the net.Conn interface for *udpConn. +func (c *udpConn) Read(b []byte) (n int, err error) { + n = c.readRemaining(b) + if n > 0 { + return n, nil + } + + select { + case buf := <-c.chanMsg: + c.remaining = buf + n = c.readRemaining(b) + + return n, nil + case <-c.chanClosed: + return 0, net.ErrClosed + } +} + +// readRemaining reads remaining bytes that were not yet read. +func (c *udpConn) readRemaining(b []byte) (n int) { + if c.remaining == nil || len(c.remaining) == 0 { + return 0 + } + + if len(c.remaining) >= len(b) { + n = len(b) + + copy(b, c.remaining[:n]) + c.remaining = c.remaining[n:] + + return n + } + + n = len(c.remaining) + + copy(b[:n], c.remaining) + c.remaining = nil + + return n +} + +// Write implements the net.Conn interface for *udpConn. +func (c *udpConn) Write(b []byte) (n int, err error) { + return c.conn.WriteToUDP(b, c.peerAddr) +} + +// Close implements the net.Conn interface for *udpConn. +func (c *udpConn) Close() (err error) { + c.closedMu.Lock() + defer c.closedMu.Unlock() + + c.closed = true + close(c.chanClosed) + + // Do not close the underlying UDP connection as it's shared with other + // udpConn objects. + + return nil +} + +// LocalAddr implements the net.Conn interface for *udpConn. +func (c *udpConn) LocalAddr() (addr net.Addr) { + return c.conn.LocalAddr() +} + +// RemoteAddr implements the net.Conn interface for *udpConn. +func (c *udpConn) RemoteAddr() (addr net.Addr) { + return c.conn.RemoteAddr() +} + +// SetDeadline implements the net.Conn interface for *udpConn. +func (c *udpConn) SetDeadline(_ time.Time) (err error) { + // TODO(ameshkov): Implement it. + + return nil +} + +// SetReadDeadline implements the net.Conn interface for *udpConn. +func (c *udpConn) SetReadDeadline(_ time.Time) (err error) { + // TODO(ameshkov): Implement it. + + return nil +} + +// SetWriteDeadline implements the net.Conn interface for *udpConn. +func (c *udpConn) SetWriteDeadline(_ time.Time) (err error) { + // TODO(ameshkov): Implement it. + + return nil +} diff --git a/internal/version/version.go b/internal/version/version.go new file mode 100644 index 0000000..fdd7030 --- /dev/null +++ b/internal/version/version.go @@ -0,0 +1,20 @@ +// Package version exports some getters for the project's version values. +package version + +// Versions + +// These are set by the linker. Unfortunately, we cannot set constants during +// linking, and Go doesn't have a concept of immutable variables, so to be +// thorough we have to only export them through getters. +var ( + version string +) + +// Version returns the compiled-in value of this product version as a string. +func Version() (v string) { + if version == "" { + return "v0.0" + } + + return version +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..c81ecba --- /dev/null +++ b/main.go @@ -0,0 +1,8 @@ +// Package main is the program entrypoint.Ï +package main + +import "github.com/ameshkov/udptlspipe/internal/cmd" + +func main() { + cmd.Main() +}