diff --git a/Makefile b/Makefile index c617396..0e692f1 100644 --- a/Makefile +++ b/Makefile @@ -11,4 +11,4 @@ build-ci: cd ./build && tar -zcvf ./proxy-fix-linux-arm64.tar.gz ./proxy-fix-linux-arm64 run: - env PORT=8080 NOHUP=1 go run . bun ./test/app.ts + env PORT=8080 NOHUP=1 TARGET=127.0.0.1:30000 go run . bun ./test/app.ts diff --git a/README.md b/README.md index e70f5db..26dea5c 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Download from releases or build it and place it to `~/.local/bin/proxfix`. ```bash PROXYFIX=proxy-fix-linux-$( [ "$(uname -m)" = "aarch64" ] && echo "arm64" || echo "amd64" ) -wget https://github.com/domcloud/proxy-fix/releases/download/v0.2.3/$PROXYFIX.tar.gz +wget https://github.com/domcloud/proxy-fix/releases/download/v0.2.4/$PROXYFIX.tar.gz tar -xf $PROXYFIX.tar.gz && mv $PROXYFIX /usr/local/bin/proxfix && rm -rf $PROXYFIX* ``` diff --git a/main.go b/main.go index 188d950..f7c8f1a 100644 --- a/main.go +++ b/main.go @@ -7,19 +7,24 @@ import ( "os" "os/exec" "strconv" + "strings" "syscall" ) -var outPort int +var outDial string func init() { var err error var pid int - outPort, pid, err = checkExistingProcess() + args := os.Args + bg := os.Getenv("NOHUP") == "1" + preferredDial := os.Getenv("TARGET") + + outDial, pid, err = checkExistingProcess() if err == nil { - if isPortListening(outPort) { - fmt.Printf("Process is already running on port %d\n", outPort) + if isPortListening(outDial) && outDial == preferredDial { + fmt.Printf("Process is already running on %d\n", outDial) return } else { fmt.Printf("Killing stale process %d \n", pid) @@ -28,14 +33,15 @@ func init() { } // No existing process found or not listening, start a new one - outPort, err = getFreePort() - if err != nil { - panic("Can't get free port") + if preferredDial == "" { + outDial, err = getFreeDial() + if err != nil { + panic("Can't get free port") + } + } else { + outDial = preferredDial } - args := os.Args - bg := os.Getenv("NOHUP") == "1" - // Check if there are additional arguments if len(args) > 1 { if bg { @@ -50,7 +56,9 @@ func init() { Pgid: 0, } cmd.Env = os.Environ() - cmd.Env = append(cmd.Env, fmt.Sprintf("PORT=%d", outPort)) + if strings.HasPrefix(outDial, LOCAL_PREFIX) { + cmd.Env = append(cmd.Env, fmt.Sprintf("PORT=%s", outDial[len(LOCAL_PREFIX):])) + } err = cmd.Run() if err != nil { fmt.Printf("Error starting command: %v\n", err) @@ -65,14 +73,16 @@ func init() { fmt.Printf("invalid PID of 0") os.Exit(1) } - writePidPortFile(pid, outPort) + writePidPortFile(pid, outDial) fmt.Printf("Started process %s with PID %d in background\n", args[1], pid) } else { cmd := exec.Command(args[1], args[2:]...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Env = os.Environ() - cmd.Env = append(cmd.Env, fmt.Sprintf("PORT=%d", outPort)) + if strings.HasPrefix(outDial, LOCAL_PREFIX) { + cmd.Env = append(cmd.Env, fmt.Sprintf("PORT=%s", outDial[len(LOCAL_PREFIX):])) + } // Start the specified command err := cmd.Start() @@ -100,7 +110,7 @@ func main() { func startProxy(address string) error { proxy := Proxy{ - DialTarget: fmt.Sprintf("localhost:%d", outPort), + DialTarget: outDial, } listener, err := net.Listen("tcp", address) if err != nil { diff --git a/util.go b/util.go index 0f63ab9..6b7d4df 100644 --- a/util.go +++ b/util.go @@ -15,16 +15,17 @@ import ( const MAX_RETRY = 20 const WAIT_RETRY = time.Second * 1 +const LOCAL_PREFIX = "127.0.0.1:" var invalidHeaderRegex = regexp.MustCompile("[^a-zA-Z0-9-]+") -func getFreePort() (port int, err error) { +func getFreeDial() (dial string, err error) { var a *net.TCPAddr if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { var l *net.TCPListener if l, err = net.ListenTCP("tcp", a); err == nil { defer l.Close() - return l.Addr().(*net.TCPAddr).Port, nil + return fmt.Sprintf("%s%d", LOCAL_PREFIX, l.Addr().(*net.TCPAddr).Port), nil } } return @@ -42,31 +43,26 @@ func getPidFile() string { return filepath.Join(os.Getenv("HOME"), "tmp", "app.pid") } -func checkExistingProcess() (port int, pid int, err error) { +func checkExistingProcess() (dial string, pid int, err error) { data, err := os.ReadFile(getPidFile()) if err != nil { - return 0, 0, err + return "", 0, err } - fields := strings.Split(string(data), ":") - if len(fields) != 2 { - return 0, 0, fmt.Errorf("invalid pid/port file format") + fields := strings.Split(string(data), ";") + if len(fields) != 2 || len(strings.Split(fields[1], ":")) < 2 { + return "", 0, fmt.Errorf("invalid pid/port file format") } pid, err = strconv.Atoi(fields[0]) if err != nil { - return 0, 0, fmt.Errorf("invalid PID format: %v", err) - } - - port, err = strconv.Atoi(fields[1]) - if err != nil { - return 0, 0, fmt.Errorf("invalid port format: %v", err) + return "", 0, fmt.Errorf("invalid PID format: %v", err) } if processExists(pid) { - return port, pid, nil + return fields[1], pid, nil } - return 0, 0, fmt.Errorf("no running process found") + return "", 0, fmt.Errorf("no running process found") } func processExists(pid int) bool { @@ -87,10 +83,10 @@ func processKill(pid int) bool { return err == nil } -func isPortListening(port int) bool { +func isPortListening(dial string) bool { retries := 0 retry: - conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + conn, err := net.Dial("tcp", dial) if err != nil { if retries < MAX_RETRY { retries += 1 @@ -105,8 +101,8 @@ retry: return true } -func writePidPortFile(pid, port int) { - err := os.WriteFile(getPidFile(), []byte(fmt.Sprintf("%d:%d", pid, port)), 0644) +func writePidPortFile(pid int, dial string) { + err := os.WriteFile(getPidFile(), []byte(fmt.Sprintf("%d;%s", pid, dial)), 0644) if err != nil { fmt.Printf("Failed to write PID/port file: %v\n", err) }