From 7e1fa9e1159277ea42e673a4cd0d8a71f5060b53 Mon Sep 17 00:00:00 2001 From: Lev Vysotsky Date: Mon, 11 Apr 2022 17:10:15 +0300 Subject: [PATCH 1/5] Add rate limiting capabilities --- go.mod | 1 + go.sum | 2 + integration_test/integration_test.go | 116 ++++++++++++++++-- server.go | 45 ++++++- service/limiter.go | 171 +++++++++++++++++++++++++++ service/limiter_test.go | 105 ++++++++++++++++ service/limiter_testing.go | 36 ++++++ service/tcp.go | 27 +++-- service/tcp_test.go | 26 ++-- service/udp.go | 16 ++- service/udp_test.go | 6 +- 11 files changed, 516 insertions(+), 35 deletions(-) create mode 100644 service/limiter.go create mode 100644 service/limiter_test.go create mode 100644 service/limiter_testing.go diff --git a/go.mod b/go.mod index fcfa48d5..016c778e 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/prometheus/procfs v0.1.3 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 // indirect + golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect google.golang.org/protobuf v1.23.0 // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/go.sum b/go.sum index 07d7952a..c1db5c2f 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 h1:AvbQYmiaaaza3cW3QXRyPo5kYgpFIzOAfeAAN7m3qQ4= golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/integration_test/integration_test.go b/integration_test/integration_test.go index de674e7e..812f529d 100644 --- a/integration_test/integration_test.go +++ b/integration_test/integration_test.go @@ -97,6 +97,11 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) { return conn, &running } +func makeLimiter(cipherList service.CipherList) service.RateLimiter { + c := service.MakeTestRateLimiterConfig(cipherList) + return service.NewRateLimiter(&c) +} + func TestTCPEcho(t *testing.T) { echoListener, echoRunning := startTCPEchoServer(t) @@ -111,7 +116,7 @@ func TestTCPEcho(t *testing.T) { } replayCache := service.NewReplayCache(5) const testTimeout = 200 * time.Millisecond - proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout) + proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyListener) @@ -164,6 +169,103 @@ func TestTCPEcho(t *testing.T) { echoRunning.Wait() } +func TestRateLimiter(t *testing.T) { + echoListener, echoRunning := startTCPEchoServer(t) + + proxyListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + secrets 
:= ss.MakeTestSecrets(1) + cipherList, err := service.MakeTestCiphers(secrets) + if err != nil { + t.Fatal(err) + } + replayCache := service.NewReplayCache(5) + const testTimeout = 5 * time.Second + key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID + rateLimiter := service.NewRateLimiter(&service.RateLimiterConfig{ + KeyToLimits: map[string]service.KeyLimits{ + key: service.KeyLimits{ + LargeScaleLimit: 1000, + LargeScalePeriod: 5 * time.Second, + SmallScaleLimit: 100, + SmallScalePeriod: 100 * time.Millisecond, + }, + }, + }) + proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, rateLimiter) + proxy.SetTargetIPValidator(allowAll) + go proxy.Serve(proxyListener) + + proxyHost, proxyPort, err := net.SplitHostPort(proxyListener.Addr().String()) + if err != nil { + t.Fatal(err) + } + portNum, err := strconv.Atoi(proxyPort) + if err != nil { + t.Fatal(err) + } + client, err := client.NewClient(proxyHost, portNum, secrets[0], ss.TestCipher) + if err != nil { + t.Fatalf("Failed to create ShadowsocksClient: %v", err) + } + + const N = 500 + up := make([]byte, N) + for i := 0; i < N; i++ { + up[i] = byte(i) + } + { + conn, err := client.DialTCP(nil, echoListener.Addr().String()) + if err != nil { + t.Fatalf("ShadowsocksClient.DialTCP failed: %v", err) + } + start := time.Now() + n, err := conn.Write(up) + if err != nil { + t.Fatal(err) + } + if n != N { + t.Fatalf("Tried to upload %d bytes, but only sent %d", N, n) + } + + down := make([]byte, N) + n, err = conn.Read(down) + if err != nil && err != io.EOF { + t.Fatal(err) + } + if n != N { + t.Fatalf("Expected to download %d bytes, but only received %d", N, n) + } + if time.Now().Sub(start) < 600 * time.Millisecond { + t.Fatalf("Download too fast") + } + + if !bytes.Equal(up, down) { + t.Fatal("Echo mismatch") + } + + conn.Close() + } + + { + conn, err := client.DialTCP(nil, echoListener.Addr().String()) + if err != nil { + t.Fatalf("ShadowsocksClient.DialTCP failed: %v", err) + } + _, err = conn.Write(up) + if err == nil { + t.Fatalf("Expected limit error when uploading") + } + conn.Close() + } + + proxy.Stop() + echoListener.Close() + echoRunning.Wait() +} + type statusMetrics struct { metrics.NoOpMetrics sync.Mutex @@ -184,7 +286,7 @@ func TestRestrictedAddresses(t *testing.T) { require.NoError(t, err) const testTimeout = 200 * time.Millisecond testMetrics := &statusMetrics{} - proxy := service.NewTCPService(cipherList, nil, testMetrics, testTimeout) + proxy := service.NewTCPService(cipherList, nil, testMetrics, testTimeout, makeLimiter(cipherList)) go proxy.Serve(proxyListener) proxyHost, proxyPort, err := net.SplitHostPort(proxyListener.Addr().String()) @@ -266,7 +368,7 @@ func TestUDPEcho(t *testing.T) { t.Fatal(err) } testMetrics := &fakeUDPMetrics{fakeLocation: "QQ"} - proxy := service.NewUDPService(time.Hour, cipherList, testMetrics) + proxy := service.NewUDPService(time.Hour, cipherList, testMetrics, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyConn) @@ -363,7 +465,7 @@ func BenchmarkTCPThroughput(b *testing.B) { b.Fatal(err) } const testTimeout = 200 * time.Millisecond - proxy := service.NewTCPService(cipherList, nil, &metrics.NoOpMetrics{}, testTimeout) + proxy := service.NewTCPService(cipherList, nil, &metrics.NoOpMetrics{}, testTimeout, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyListener) @@ -430,7 +532,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { } replayCache := 
service.NewReplayCache(service.MaxCapacity) const testTimeout = 200 * time.Millisecond - proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout) + proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyListener) @@ -505,7 +607,7 @@ func BenchmarkUDPEcho(b *testing.B) { if err != nil { b.Fatal(err) } - proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{}) + proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{}, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyConn) @@ -554,7 +656,7 @@ func BenchmarkUDPManyKeys(b *testing.B) { if err != nil { b.Fatal(err) } - proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{}) + proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{}, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyConn) diff --git a/server.go b/server.go index c3c04121..93f9eb53 100644 --- a/server.go +++ b/server.go @@ -74,7 +74,7 @@ type SSServer struct { ports map[int]*ssPort } -func (s *SSServer) startPort(portNum int) error { +func (s *SSServer) startPort(portNum int, rateLimiterConfig *service.RateLimiterConfig) error { listener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: portNum}) if err != nil { return fmt.Errorf("Failed to start TCP on port %v: %v", portNum, err) @@ -85,9 +85,11 @@ func (s *SSServer) startPort(portNum int) error { } logger.Infof("Listening TCP and UDP on port %v", portNum) port := &ssPort{cipherList: service.NewCipherList()} + + limiter := service.NewRateLimiter(rateLimiterConfig) // TODO: Register initial data metrics at zero. - port.tcpService = service.NewTCPService(port.cipherList, &s.replayCache, s.m, tcpReadTimeout) - port.udpService = service.NewUDPService(s.natTimeout, port.cipherList, s.m) + port.tcpService = service.NewTCPService(port.cipherList, &s.replayCache, s.m, tcpReadTimeout, limiter) + port.udpService = service.NewUDPService(s.natTimeout, port.cipherList, s.m, limiter) s.ports[portNum] = port go port.tcpService.Serve(listener) go port.udpService.Serve(packetConn) @@ -120,6 +122,7 @@ func (s *SSServer) loadConfig(filename string) error { portChanges := make(map[int]int) portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. 
+ portKeyLimits := make(map[int]map[string]service.KeyLimits) for _, keyConfig := range config.Keys { portChanges[keyConfig.Port] = 1 cipherList, ok := portCiphers[keyConfig.Port] @@ -133,6 +136,19 @@ func (s *SSServer) loadConfig(filename string) error { } entry := service.MakeCipherEntry(keyConfig.ID, cipher, keyConfig.Secret) cipherList.PushBack(&entry) + var keyLimits map[string]service.KeyLimits + keyLimits, ok = portKeyLimits[keyConfig.Port] + if !ok { + keyLimits = make(map[string]service.KeyLimits) + portKeyLimits[keyConfig.Port] = keyLimits + } + if keyConfig.Limits != nil { + keyLimits[keyConfig.ID] = *keyConfig.Limits + } else if config.DefaultKeyLimits != nil { + keyLimits[keyConfig.ID] = *config.DefaultKeyLimits + } else { + keyLimits[keyConfig.ID] = noLimits + } } for port := range s.ports { portChanges[port] = portChanges[port] - 1 @@ -143,7 +159,8 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("Failed to remove port %v: %v", portNum, err) } } else if count == +1 { - if err := s.startPort(portNum); err != nil { + rateLimiterConfig := &service.RateLimiterConfig{KeyToLimits: portKeyLimits[portNum]} + if err := s.startPort(portNum, rateLimiterConfig); err != nil { return fmt.Errorf("Failed to start port %v: %v", portNum, err) } } @@ -197,7 +214,16 @@ type Config struct { Port int Cipher string Secret string + Limits *service.KeyLimits } + DefaultKeyLimits *service.KeyLimits +} + +var noLimits service.KeyLimits = service.KeyLimits{ + LargeScalePeriod: time.Millisecond, + LargeScaleLimit: 1 << 30, + SmallScalePeriod: time.Millisecond, + SmallScaleLimit: 1 << 30, } func readConfig(filename string) (*Config, error) { @@ -207,6 +233,17 @@ func readConfig(filename string) (*Config, error) { return nil, err } err = yaml.Unmarshal(configData, &config) + if err != nil { + return nil, err + } + if config.DefaultKeyLimits == nil { + config.DefaultKeyLimits = &noLimits + } + for i := range config.Keys { + if config.Keys[i].Limits == nil { + config.Keys[i].Limits = config.DefaultKeyLimits + } + } return &config, err } diff --git a/service/limiter.go b/service/limiter.go new file mode 100644 index 00000000..6167baa9 --- /dev/null +++ b/service/limiter.go @@ -0,0 +1,171 @@ +// Copyright 2018 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package service + +import ( + "context" + "fmt" + "io" + "log" + "math" + "time" + + "golang.org/x/time/rate" +) + +type KeyLimits struct { + LargeScalePeriod time.Duration + LargeScaleLimit int64 + SmallScalePeriod time.Duration + SmallScaleLimit int64 +} + +type RateLimiterConfig struct { + KeyToLimits map[string]KeyLimits +} + +type RateLimiter interface { + WrapReaderWriter(accessKey string, reader io.Reader, writer io.Writer) (io.Reader, io.Writer, error) + Allow(accessKey string, n int) error +} + +func NewRateLimiter(config *RateLimiterConfig) RateLimiter { + keyToLimiter := make(map[string]*perKeyLimiter, 0) + for accessKey, limits := range config.KeyToLimits { + keyToLimiter[accessKey] = &perKeyLimiter{ + largeScale: createLimiter(limits.LargeScalePeriod, limits.LargeScaleLimit), + smallScale: createLimiter(limits.SmallScalePeriod, limits.SmallScaleLimit), + } + } + return &rateLimiter{keyToLimiter: keyToLimiter} +} + +type rateLimiter struct { + keyToLimiter map[string]*perKeyLimiter +} + +type perKeyLimiter struct { + smallScale *rate.Limiter + largeScale *rate.Limiter +} + +// We need larger granularity, because rate.RateLimiter +// works with ints. +const tokenSizeBytes = 1024 +const maxSizeBytes = math.MaxInt32 * tokenSizeBytes + +func bytesToTokens64(n int64) int { + // Round up to avoid attack involving small reads. + if n >= maxSizeBytes { + log.Panicf("%v bytes cannot be converted to tokens", n) + } + return (int) ((n + tokenSizeBytes - 1) / tokenSizeBytes) +} + +func bytesToTokens(n int) int { + // Round up to avoid attack involving small reads. + return (n + tokenSizeBytes - 1) / tokenSizeBytes +} + +func min(a int, b int) int { + if a < b { + return a + } else { + return b + } +} + +func (l *perKeyLimiter) Wait(n int) error { + tokens := bytesToTokens(n) + if !l.largeScale.AllowN(time.Now(), tokens) { + return fmt.Errorf("exceeds large scale limit") + } + for tokens > 0 { + batch := min(tokens, int(l.smallScale.Burst())) + err := l.smallScale.WaitN(context.TODO(), batch) + if err != nil { + return err + } + tokens -= batch + } + return nil +} + +func (l *perKeyLimiter) Allow(n int) error { + tokens := bytesToTokens(n) + if !l.largeScale.AllowN(time.Now(), tokens) { + return fmt.Errorf("exceeds large-scale limit") + } + if !l.smallScale.AllowN(time.Now(), tokens) { + return fmt.Errorf("exceeds small-scale limit") + } + return nil +} + +type limitedReader struct { + reader io.Reader + limiter *perKeyLimiter +} + +func (r *limitedReader) Read(b []byte) (int, error) { + n, err := r.reader.Read(b) + if n <= 0 { + return n, err + } + waitErr := r.limiter.Wait(n) + if waitErr != nil { + return n, waitErr + } + return n, err +} + +type limitedWriter struct { + writer io.Writer + limiter *perKeyLimiter +} + +func (w *limitedWriter) Write(b []byte) (int, error) { + n, err := w.writer.Write(b) + if n <= 0 { + return n, err + } + waitErr := w.limiter.Wait(n) + if waitErr != nil { + return n, waitErr + } + return n, err +} + +func createLimiter(period time.Duration, limit int64) *rate.Limiter { + b := bytesToTokens64(limit) + r := rate.Every(period) * rate.Limit(b) + return rate.NewLimiter(r, b) +} + +func (l *rateLimiter) WrapReaderWriter(accessKey string, reader io.Reader, writer io.Writer) (io.Reader, io.Writer, error) { + limiter, ok := l.keyToLimiter[accessKey] + if !ok { + return nil, nil, fmt.Errorf("Access key %v not found", accessKey) + } + return &limitedReader{reader: reader, limiter: limiter}, &limitedWriter{writer: writer, limiter: limiter}, nil +} + +func (l 
*rateLimiter) Allow(accessKey string, n int) error { + limiter, ok := l.keyToLimiter[accessKey] + if !ok { + return fmt.Errorf("Access key %v not found", accessKey) + } + return limiter.Allow(n) +} diff --git a/service/limiter_test.go b/service/limiter_test.go new file mode 100644 index 00000000..aff3dcd2 --- /dev/null +++ b/service/limiter_test.go @@ -0,0 +1,105 @@ +package service + +import ( + "bytes" + "io" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func makeRandBuffer(n int64) *bytes.Buffer { + arr := make([]byte, n) + rand.Read(arr) + return bytes.NewBuffer(arr) +} + +func TestRateLimiter(t *testing.T) { + key1 := "key1" + key2 := "key2" + + var tok int64 = 1024 + config := RateLimiterConfig{ + KeyToLimits: map[string]KeyLimits{ + key1: KeyLimits { + LargeScalePeriod: time.Minute, + LargeScaleLimit: 10 * tok, + SmallScalePeriod: time.Second, + SmallScaleLimit: 2 * tok, + }, + key2: KeyLimits{ + LargeScalePeriod: time.Minute, + LargeScaleLimit: 10 * tok, + SmallScalePeriod: time.Second, + SmallScaleLimit: 3 * tok, + }, + }, + } + + limiter := NewRateLimiter(&config) + + src1 := makeRandBuffer(20 * tok) + src1Orig := src1.Bytes() + dst1 := &bytes.Buffer{} + + src2 := makeRandBuffer(20 * tok) + src2Orig := src2.Bytes() + dst2 := &bytes.Buffer{} + + r1, w1, err1 := limiter.WrapReaderWriter(key1, src1, dst1) + require.NoError(t, err1) + r2, w2, err2 := limiter.WrapReaderWriter(key2, src2, dst2) + require.NoError(t, err2) + + b := make([]byte, 50) + + start := time.Now() + _, err := io.ReadFull(r1, b) + require.NoError(t, err) + require.Equal(t, b, src1Orig[:len(b)]) + if time.Now().Sub(start) > 10 * time.Millisecond { + t.Errorf("read took too long") + } + + start = time.Now() + _, err = io.ReadFull(r2, b) + require.NoError(t, err) + require.Equal(t, b, src2Orig[:len(b)]) + if time.Now().Sub(start) > 10 * time.Millisecond { + t.Errorf("read took too long") + } + + start = time.Now() + size := 2 * tok + _, err = w1.Write(src1Orig[:size]) + require.NoError(t, err) + require.Equal(t, src1Orig[:size], dst1.Bytes()[:size]) + if time.Now().Sub(start) < 500 * time.Millisecond { + t.Fatalf("write took too short") + } + + allowErr := limiter.Allow(key2, int(3 * tok)) + require.NoError(t, allowErr) + + allowErr = limiter.Allow(key2, int(1 * tok)) + require.Error(t, allowErr) + + start = time.Now() + size = 3 * tok + _, err = w2.Write(src2Orig[:size]) + require.NoError(t, err) + require.Equal(t, src2Orig[:size], dst2.Bytes()[:size]) + if time.Now().Sub(start) < 500 * time.Millisecond { + t.Fatalf("write took too short") + } + + start = time.Now() + size = 7 * tok + _, err = w2.Write(src2Orig[:size]) + require.Error(t, err) + if time.Now().Sub(start) > 10 * time.Millisecond { + t.Fatalf("write took too long") + } +} \ No newline at end of file diff --git a/service/limiter_testing.go b/service/limiter_testing.go new file mode 100644 index 00000000..355fcda8 --- /dev/null +++ b/service/limiter_testing.go @@ -0,0 +1,36 @@ +// Copyright 2018 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "net" + "time" +) + +func MakeTestRateLimiterConfig(ciphers CipherList) RateLimiterConfig { + elts := ciphers.SnapshotForClientIP(net.IP{}) + keyLimits := KeyLimits{ + LargeScalePeriod: 1000 * time.Hour, + LargeScaleLimit: 1 << 30, + SmallScalePeriod: 1000 * time.Hour, + SmallScaleLimit: 1 << 30, + } + keyToLimits := make(map[string]KeyLimits) + for _, elt := range elts { + entry := elt.Value.(*CipherEntry) + keyToLimits[entry.ID] = keyLimits + } + return RateLimiterConfig{KeyToLimits: keyToLimits} +} diff --git a/service/tcp.go b/service/tcp.go index 69d71a05..5fcf5507 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -102,7 +102,6 @@ func findEntry(firstBytes []byte, ciphers []*list.Element) (*CipherEntry, *list. continue } debugTCP(id, "Found cipher at index %d", ci) - // Move the active cipher to the front, so that the search is quicker next time. return entry, elt } return nil, nil @@ -114,6 +113,7 @@ type tcpService struct { stopped bool ciphers CipherList m metrics.ShadowsocksMetrics + limiter RateLimiter running sync.WaitGroup readTimeout time.Duration // `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports. @@ -123,11 +123,13 @@ type tcpService struct { // NewTCPService creates a TCPService // `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports. -func NewTCPService(ciphers CipherList, replayCache *ReplayCache, m metrics.ShadowsocksMetrics, timeout time.Duration) TCPService { +func NewTCPService(ciphers CipherList, replayCache *ReplayCache, m metrics.ShadowsocksMetrics, + timeout time.Duration, limiter RateLimiter) TCPService { return &tcpService{ ciphers: ciphers, m: m, - readTimeout: timeout, + readTimeout: timeout, + limiter: limiter, replayCache: replayCache, targetIPValidator: onet.RequirePublicIP, } @@ -227,6 +229,17 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo var proxyMetrics metrics.ProxyMetrics clientConn := metrics.MeasureConn(clientTCPConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy) cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientTCPConn), s.ciphers) + var clientWriter io.Writer = clientConn + + var accessKey string + if cipherEntry != nil { + accessKey = cipherEntry.ID + var limiterErr error + clientReader, clientWriter, limiterErr = s.limiter.WrapReaderWriter(accessKey, clientReader, clientWriter) + if limiterErr != nil { + logger.Panicf("got unexpected error wrapping streams: %v", limiterErr) + } + } connError := func() *onet.ConnectionError { if keyErr != nil { @@ -268,7 +281,7 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo defer tgtConn.Close() logger.Debugf("proxy %s <-> %s", clientTCPConn.RemoteAddr().String(), tgtConn.RemoteAddr().String()) - ssw := ss.NewShadowsocksWriter(clientConn, cipherEntry.Cipher) + ssw := ss.NewShadowsocksWriter(clientWriter, cipherEntry.Cipher) ssw.SetSaltGenerator(cipherEntry.SaltGenerator) fromClientErrCh := make(chan error) @@ -306,11 +319,7 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo logger.Debugf("TCP Error: %v: %v", connError.Message, connError.Cause) status = connError.Status } - var id string - if cipherEntry != nil { - id = cipherEntry.ID - } - s.m.AddClosedTCPConnection(clientLocation, id, status, proxyMetrics, timeToCipher, connDuration) + 
s.m.AddClosedTCPConnection(clientLocation, accessKey, status, proxyMetrics, timeToCipher, connDuration) clientConn.Close() // Closing after the metrics are added aids integration testing. logger.Debugf("Done with status %v, duration %v", status, connDuration) } diff --git a/service/tcp_test.go b/service/tcp_test.go index a6080be6..b07ae1dd 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -277,12 +277,18 @@ func probe(serverAddr *net.TCPAddr, bytesToSend []byte) error { return nil } +func makeLimiter(cipherList CipherList) RateLimiter { + c := MakeTestRateLimiterConfig(cipherList) + return NewRateLimiter(&c) +} + func TestProbeRandom(t *testing.T) { listener := makeLocalhostListener(t) cipherList, err := MakeTestCiphers(ss.MakeTestSecrets(1)) + require.Nil(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond) + s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond, makeLimiter(cipherList)) go s.Serve(listener) // 221 is the largest random probe reported by https://gfw.report/blog/gfw_shadowsocks/ @@ -349,7 +355,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { require.Nil(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond) + s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond, makeLimiter(cipherList)) s.SetTargetIPValidator(allowAll) go s.Serve(listener) @@ -379,7 +385,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { require.Nil(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond) + s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond, makeLimiter(cipherList)) s.SetTargetIPValidator(allowAll) go s.Serve(listener) @@ -410,7 +416,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { require.Nil(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond) + s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond, makeLimiter(cipherList)) s.SetTargetIPValidator(allowAll) go s.Serve(listener) @@ -448,7 +454,7 @@ func TestProbeServerBytesModified(t *testing.T) { require.Nil(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond) + s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond, makeLimiter(cipherList)) go s.Serve(listener) initialBytes := makeServerBytes(t, cipher) @@ -473,7 +479,7 @@ func TestReplayDefense(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout) + s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout, makeLimiter(cipherList)) snapshot := cipherList.SnapshotForClientIP(nil) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.Cipher @@ -546,7 +552,7 @@ func TestReverseReplayDefense(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewTCPService(cipherList, &replayCache, testMetrics, 
testTimeout) + s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout, makeLimiter(cipherList)) snapshot := cipherList.SnapshotForClientIP(nil) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.Cipher @@ -611,7 +617,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { cipherList, err := MakeTestCiphers(ss.MakeTestSecrets(5)) require.Nil(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, testTimeout) + s := NewTCPService(cipherList, nil, testMetrics, testTimeout, makeLimiter(cipherList)) testPayload := ss.MakeTestPayload(payloadSize) done := make(chan bool) @@ -674,7 +680,7 @@ func TestTCPDoubleServe(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout) + s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout, makeLimiter(cipherList)) c := make(chan error) for i := 0; i < 2; i++ { @@ -704,7 +710,7 @@ func TestTCPEarlyStop(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout) + s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout, makeLimiter(cipherList)) if err := s.Stop(); err != nil { t.Error(err) diff --git a/service/udp.go b/service/udp.go index 42160497..8d75047b 100644 --- a/service/udp.go +++ b/service/udp.go @@ -78,11 +78,18 @@ type udpService struct { m metrics.ShadowsocksMetrics running sync.WaitGroup targetIPValidator onet.TargetIPValidator + limiter RateLimiter } // NewUDPService creates a UDPService -func NewUDPService(natTimeout time.Duration, cipherList CipherList, m metrics.ShadowsocksMetrics) UDPService { - return &udpService{natTimeout: natTimeout, ciphers: cipherList, m: m, targetIPValidator: onet.RequirePublicIP} +func NewUDPService(natTimeout time.Duration, cipherList CipherList, m metrics.ShadowsocksMetrics, limiter RateLimiter) UDPService { + return &udpService{ + natTimeout: natTimeout, + ciphers: cipherList, + m: m, + targetIPValidator: onet.RequirePublicIP, + limiter: limiter, + } } // UDPService is a running UDP shadowsocks proxy that can be stopped. 
@@ -221,6 +228,11 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { } debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr()) + limitErr := s.limiter.Allow(keyID, len(payload)) + if limitErr != nil { + debugUDPAddr(clientAddr, "Rate limite exceeded: %v", limitErr) + return onet.NewConnectionError("ERR_LIMIT", "Rate limit exceeded", limitErr) + } proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) diff --git a/service/udp_test.go b/service/udp_test.go index 12139fdc..808e8436 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -130,7 +130,7 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher clientConn := makePacketConn() metrics := &natTestMetrics{} - service := NewUDPService(timeout, ciphers, metrics) + service := NewUDPService(timeout, ciphers, metrics, makeLimiter(ciphers)) service.SetTargetIPValidator(validator) go service.Serve(clientConn) @@ -474,7 +474,7 @@ func TestUDPDoubleServe(t *testing.T) { } testMetrics := &natTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewUDPService(testTimeout, cipherList, testMetrics) + s := NewUDPService(testTimeout, cipherList, testMetrics, makeLimiter(cipherList)) c := make(chan error) for i := 0; i < 2; i++ { @@ -508,7 +508,7 @@ func TestUDPEarlyStop(t *testing.T) { } testMetrics := &natTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewUDPService(testTimeout, cipherList, testMetrics) + s := NewUDPService(testTimeout, cipherList, testMetrics, makeLimiter(cipherList)) if err := s.Stop(); err != nil { t.Error(err) From cfcee9d8a05de7e7a10eafea07432ac2f21c6a27 Mon Sep 17 00:00:00 2001 From: Lev Vysotsky Date: Wed, 13 Apr 2022 18:12:08 +0300 Subject: [PATCH 2/5] Fixes --- integration_test/integration_test.go | 144 +++++++++++++++++++++++---- server.go | 12 +-- service/limiter.go | 39 ++++---- service/limiter_test.go | 38 ++++--- service/limiter_testing.go | 8 +- service/tcp.go | 20 ++-- service/tcp_test.go | 6 +- service/udp.go | 30 ++++-- service/udp_test.go | 11 +- 9 files changed, 214 insertions(+), 94 deletions(-) diff --git a/integration_test/integration_test.go b/integration_test/integration_test.go index 812f529d..74ffb7bd 100644 --- a/integration_test/integration_test.go +++ b/integration_test/integration_test.go @@ -97,9 +97,9 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) { return conn, &running } -func makeLimiter(cipherList service.CipherList) service.RateLimiter { - c := service.MakeTestRateLimiterConfig(cipherList) - return service.NewRateLimiter(&c) +func makeLimiter(cipherList service.CipherList) service.TrafficLimiter { + c := service.MakeTestTrafficLimiterConfig(cipherList) + return service.NewTrafficLimiter(&c) } func TestTCPEcho(t *testing.T) { @@ -169,7 +169,7 @@ func TestTCPEcho(t *testing.T) { echoRunning.Wait() } -func TestRateLimiter(t *testing.T) { +func TestTrafficLimiterTCP(t *testing.T) { echoListener, echoRunning := startTCPEchoServer(t) proxyListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) @@ -183,18 +183,21 @@ func TestRateLimiter(t *testing.T) { } replayCache := service.NewReplayCache(5) const testTimeout = 5 * time.Second + key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID - rateLimiter := 
service.NewRateLimiter(&service.RateLimiterConfig{ + const tok = 1024 + trafficLimiter := service.NewTrafficLimiter(&service.TrafficLimiterConfig{ KeyToLimits: map[string]service.KeyLimits{ key: service.KeyLimits{ - LargeScaleLimit: 1000, - LargeScalePeriod: 5 * time.Second, - SmallScaleLimit: 100, + LargeScaleLimit: 80 * tok, + LargeScalePeriod: 60 * time.Second, + SmallScaleLimit: 10 * tok, SmallScalePeriod: 100 * time.Millisecond, }, }, }) - proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, rateLimiter) + + proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, trafficLimiter) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyListener) @@ -211,12 +214,9 @@ func TestRateLimiter(t *testing.T) { t.Fatalf("Failed to create ShadowsocksClient: %v", err) } - const N = 500 - up := make([]byte, N) - for i := 0; i < N; i++ { - up[i] = byte(i) - } { + const N = 20 * tok + up := ss.MakeTestPayload(N) conn, err := client.DialTCP(nil, echoListener.Addr().String()) if err != nil { t.Fatalf("ShadowsocksClient.DialTCP failed: %v", err) @@ -231,14 +231,14 @@ func TestRateLimiter(t *testing.T) { } down := make([]byte, N) - n, err = conn.Read(down) + n, err = io.ReadFull(conn, down) if err != nil && err != io.EOF { t.Fatal(err) } if n != N { t.Fatalf("Expected to download %d bytes, but only received %d", N, n) } - if time.Now().Sub(start) < 600 * time.Millisecond { + if time.Now().Sub(start) < 100*time.Millisecond { t.Fatalf("Download too fast") } @@ -250,13 +250,24 @@ func TestRateLimiter(t *testing.T) { } { + const N = 50 * tok + up := ss.MakeTestPayload(N) conn, err := client.DialTCP(nil, echoListener.Addr().String()) if err != nil { t.Fatalf("ShadowsocksClient.DialTCP failed: %v", err) } + _, err = conn.Write(up) + if err != nil { + // No write error is expected + // as proxy just discards all the input + t.Fatalf("Unexpected error: %v", err) + } + + down := make([]byte, N) + _, err = io.ReadFull(conn, down) if err == nil { - t.Fatalf("Expected limit error when uploading") + t.Fatalf("Expected read error") } conn.Close() } @@ -266,6 +277,105 @@ func TestRateLimiter(t *testing.T) { echoRunning.Wait() } +func TestTrafficLimiterUDP(t *testing.T) { + echoConn, echoRunning := startUDPEchoServer(t) + + proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + secrets := ss.MakeTestSecrets(1) + cipherList, err := service.MakeTestCiphers(secrets) + if err != nil { + t.Fatal(err) + } + testMetrics := &fakeUDPMetrics{fakeLocation: "QQ"} + + const tok = 1024 + key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID + smallScalePeriod := 100 * time.Millisecond + trafficLimiter := service.NewTrafficLimiter(&service.TrafficLimiterConfig{ + KeyToLimits: map[string]service.KeyLimits{ + key: service.KeyLimits{ + LargeScaleLimit: 100 * tok, + LargeScalePeriod: 60 * time.Second, + SmallScaleLimit: 10 * tok, + SmallScalePeriod: smallScalePeriod, + }, + }, + }) + + proxy := service.NewUDPService(time.Hour, cipherList, testMetrics, trafficLimiter) + proxy.SetTargetIPValidator(allowAll) + go proxy.Serve(proxyConn) + + proxyHost, proxyPort, err := net.SplitHostPort(proxyConn.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + portNum, err := strconv.Atoi(proxyPort) + if err != nil { + t.Fatal(err) + } + client, err := client.NewClient(proxyHost, portNum, secrets[0], ss.TestCipher) + if err != nil { + 
t.Fatalf("Failed to create ShadowsocksClient: %v", err) + } + conn, err := client.ListenUDP(nil) + if err != nil { + t.Fatalf("ShadowsocksClient.ListenUDP failed: %v", err) + } + + run := func(N int, expectReadError bool) { + up := ss.MakeTestPayload(N) + n, err := conn.WriteTo(up, echoConn.LocalAddr()) + if err != nil { + t.Fatal(err) + } + if n != N { + t.Fatalf("Tried to upload %d bytes, but only sent %d", N, n) + } + + conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + + down := make([]byte, N) + n, addr, err := conn.ReadFrom(down) + if err != nil { + if !expectReadError { + t.Fatalf("Unexpected read error: %v", err) + } + return + } else { + if expectReadError { + t.Fatalf("Expected read error") + } + } + if n != N { + t.Fatalf("Tried to download %d bytes, but only sent %d", N, n) + } + if addr.String() != echoConn.LocalAddr().String() { + t.Errorf("Reported address mismatch: %s != %s", addr.String(), echoConn.LocalAddr().String()) + } + + if !bytes.Equal(up, down) { + t.Fatal("Echo mismatch") + } + } + + for i := 0; i < 7; i++ { + run(5*tok, false) + run(5*tok, true) + time.Sleep(smallScalePeriod) + } + + run(10*tok, true) + + conn.Close() + echoConn.Close() + echoRunning.Wait() + proxy.GracefulStop() +} + type statusMetrics struct { metrics.NoOpMetrics sync.Mutex diff --git a/server.go b/server.go index 93f9eb53..61efc72a 100644 --- a/server.go +++ b/server.go @@ -74,7 +74,7 @@ type SSServer struct { ports map[int]*ssPort } -func (s *SSServer) startPort(portNum int, rateLimiterConfig *service.RateLimiterConfig) error { +func (s *SSServer) startPort(portNum int, trafficLimiterConfig *service.TrafficLimiterConfig) error { listener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: portNum}) if err != nil { return fmt.Errorf("Failed to start TCP on port %v: %v", portNum, err) @@ -86,7 +86,7 @@ func (s *SSServer) startPort(portNum int, rateLimiterConfig *service.RateLimiter logger.Infof("Listening TCP and UDP on port %v", portNum) port := &ssPort{cipherList: service.NewCipherList()} - limiter := service.NewRateLimiter(rateLimiterConfig) + limiter := service.NewTrafficLimiter(trafficLimiterConfig) // TODO: Register initial data metrics at zero. 
port.tcpService = service.NewTCPService(port.cipherList, &s.replayCache, s.m, tcpReadTimeout, limiter) port.udpService = service.NewUDPService(s.natTimeout, port.cipherList, s.m, limiter) @@ -159,8 +159,8 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("Failed to remove port %v: %v", portNum, err) } } else if count == +1 { - rateLimiterConfig := &service.RateLimiterConfig{KeyToLimits: portKeyLimits[portNum]} - if err := s.startPort(portNum, rateLimiterConfig); err != nil { + trafficLimiterConfig := &service.TrafficLimiterConfig{KeyToLimits: portKeyLimits[portNum]} + if err := s.startPort(portNum, trafficLimiterConfig); err != nil { return fmt.Errorf("Failed to start port %v: %v", portNum, err) } } @@ -221,9 +221,9 @@ type Config struct { var noLimits service.KeyLimits = service.KeyLimits{ LargeScalePeriod: time.Millisecond, - LargeScaleLimit: 1 << 30, + LargeScaleLimit: 1 << 30, SmallScalePeriod: time.Millisecond, - SmallScaleLimit: 1 << 30, + SmallScaleLimit: 1 << 30, } func readConfig(filename string) (*Config, error) { diff --git a/service/limiter.go b/service/limiter.go index 6167baa9..695d7dc6 100644 --- a/service/limiter.go +++ b/service/limiter.go @@ -27,21 +27,21 @@ import ( type KeyLimits struct { LargeScalePeriod time.Duration - LargeScaleLimit int64 + LargeScaleLimit int64 SmallScalePeriod time.Duration - SmallScaleLimit int64 + SmallScaleLimit int64 } -type RateLimiterConfig struct { +type TrafficLimiterConfig struct { KeyToLimits map[string]KeyLimits } -type RateLimiter interface { - WrapReaderWriter(accessKey string, reader io.Reader, writer io.Writer) (io.Reader, io.Writer, error) +type TrafficLimiter interface { + WrapReaderWriter(accessKey string, reader io.Reader, writer io.Writer) (io.Reader, io.Writer) Allow(accessKey string, n int) error } -func NewRateLimiter(config *RateLimiterConfig) RateLimiter { +func NewTrafficLimiter(config *TrafficLimiterConfig) TrafficLimiter { keyToLimiter := make(map[string]*perKeyLimiter, 0) for accessKey, limits := range config.KeyToLimits { keyToLimiter[accessKey] = &perKeyLimiter{ @@ -49,10 +49,10 @@ func NewRateLimiter(config *RateLimiterConfig) RateLimiter { smallScale: createLimiter(limits.SmallScalePeriod, limits.SmallScaleLimit), } } - return &rateLimiter{keyToLimiter: keyToLimiter} + return &trafficLimiter{keyToLimiter: keyToLimiter} } -type rateLimiter struct { +type trafficLimiter struct { keyToLimiter map[string]*perKeyLimiter } @@ -61,7 +61,7 @@ type perKeyLimiter struct { largeScale *rate.Limiter } -// We need larger granularity, because rate.RateLimiter +// We need larger granularity, because rate.TrafficLimiter // works with ints. 
const tokenSizeBytes = 1024 const maxSizeBytes = math.MaxInt32 * tokenSizeBytes @@ -71,7 +71,7 @@ func bytesToTokens64(n int64) int { if n >= maxSizeBytes { log.Panicf("%v bytes cannot be converted to tokens", n) } - return (int) ((n + tokenSizeBytes - 1) / tokenSizeBytes) + return (int)((n + tokenSizeBytes - 1) / tokenSizeBytes) } func bytesToTokens(n int) int { @@ -92,13 +92,14 @@ func (l *perKeyLimiter) Wait(n int) error { if !l.largeScale.AllowN(time.Now(), tokens) { return fmt.Errorf("exceeds large scale limit") } - for tokens > 0 { + waited := 0 + for waited < tokens { batch := min(tokens, int(l.smallScale.Burst())) err := l.smallScale.WaitN(context.TODO(), batch) if err != nil { return err } - tokens -= batch + waited += batch } return nil } @@ -126,7 +127,7 @@ func (r *limitedReader) Read(b []byte) (int, error) { } waitErr := r.limiter.Wait(n) if waitErr != nil { - return n, waitErr + return 0, waitErr } return n, err } @@ -143,7 +144,7 @@ func (w *limitedWriter) Write(b []byte) (int, error) { } waitErr := w.limiter.Wait(n) if waitErr != nil { - return n, waitErr + return 0, waitErr } return n, err } @@ -154,18 +155,18 @@ func createLimiter(period time.Duration, limit int64) *rate.Limiter { return rate.NewLimiter(r, b) } -func (l *rateLimiter) WrapReaderWriter(accessKey string, reader io.Reader, writer io.Writer) (io.Reader, io.Writer, error) { +func (l *trafficLimiter) WrapReaderWriter(accessKey string, reader io.Reader, writer io.Writer) (io.Reader, io.Writer) { limiter, ok := l.keyToLimiter[accessKey] if !ok { - return nil, nil, fmt.Errorf("Access key %v not found", accessKey) + logger.Panicf("Access key %v not found", accessKey) } - return &limitedReader{reader: reader, limiter: limiter}, &limitedWriter{writer: writer, limiter: limiter}, nil + return &limitedReader{reader: reader, limiter: limiter}, &limitedWriter{writer: writer, limiter: limiter} } -func (l *rateLimiter) Allow(accessKey string, n int) error { +func (l *trafficLimiter) Allow(accessKey string, n int) error { limiter, ok := l.keyToLimiter[accessKey] if !ok { - return fmt.Errorf("Access key %v not found", accessKey) + logger.Panicf("Access key %v not found", accessKey) } return limiter.Allow(n) } diff --git a/service/limiter_test.go b/service/limiter_test.go index aff3dcd2..a1bddf91 100644 --- a/service/limiter_test.go +++ b/service/limiter_test.go @@ -16,29 +16,29 @@ func makeRandBuffer(n int64) *bytes.Buffer { return bytes.NewBuffer(arr) } -func TestRateLimiter(t *testing.T) { +func TestTrafficLimiter(t *testing.T) { key1 := "key1" key2 := "key2" var tok int64 = 1024 - config := RateLimiterConfig{ + config := TrafficLimiterConfig{ KeyToLimits: map[string]KeyLimits{ - key1: KeyLimits { + key1: KeyLimits{ LargeScalePeriod: time.Minute, - LargeScaleLimit: 10 * tok, + LargeScaleLimit: 10 * tok, SmallScalePeriod: time.Second, - SmallScaleLimit: 2 * tok, + SmallScaleLimit: 2 * tok, }, key2: KeyLimits{ LargeScalePeriod: time.Minute, - LargeScaleLimit: 10 * tok, + LargeScaleLimit: 10 * tok, SmallScalePeriod: time.Second, - SmallScaleLimit: 3 * tok, + SmallScaleLimit: 3 * tok, }, }, } - limiter := NewRateLimiter(&config) + limiter := NewTrafficLimiter(&config) src1 := makeRandBuffer(20 * tok) src1Orig := src1.Bytes() @@ -48,10 +48,8 @@ func TestRateLimiter(t *testing.T) { src2Orig := src2.Bytes() dst2 := &bytes.Buffer{} - r1, w1, err1 := limiter.WrapReaderWriter(key1, src1, dst1) - require.NoError(t, err1) - r2, w2, err2 := limiter.WrapReaderWriter(key2, src2, dst2) - require.NoError(t, err2) + r1, w1 := 
limiter.WrapReaderWriter(key1, src1, dst1) + r2, w2 := limiter.WrapReaderWriter(key2, src2, dst2) b := make([]byte, 50) @@ -59,7 +57,7 @@ func TestRateLimiter(t *testing.T) { _, err := io.ReadFull(r1, b) require.NoError(t, err) require.Equal(t, b, src1Orig[:len(b)]) - if time.Now().Sub(start) > 10 * time.Millisecond { + if time.Now().Sub(start) > 10*time.Millisecond { t.Errorf("read took too long") } @@ -67,7 +65,7 @@ func TestRateLimiter(t *testing.T) { _, err = io.ReadFull(r2, b) require.NoError(t, err) require.Equal(t, b, src2Orig[:len(b)]) - if time.Now().Sub(start) > 10 * time.Millisecond { + if time.Now().Sub(start) > 10*time.Millisecond { t.Errorf("read took too long") } @@ -76,14 +74,14 @@ func TestRateLimiter(t *testing.T) { _, err = w1.Write(src1Orig[:size]) require.NoError(t, err) require.Equal(t, src1Orig[:size], dst1.Bytes()[:size]) - if time.Now().Sub(start) < 500 * time.Millisecond { + if time.Now().Sub(start) < 500*time.Millisecond { t.Fatalf("write took too short") } - allowErr := limiter.Allow(key2, int(3 * tok)) + allowErr := limiter.Allow(key2, int(3*tok)) require.NoError(t, allowErr) - allowErr = limiter.Allow(key2, int(1 * tok)) + allowErr = limiter.Allow(key2, int(1*tok)) require.Error(t, allowErr) start = time.Now() @@ -91,7 +89,7 @@ func TestRateLimiter(t *testing.T) { _, err = w2.Write(src2Orig[:size]) require.NoError(t, err) require.Equal(t, src2Orig[:size], dst2.Bytes()[:size]) - if time.Now().Sub(start) < 500 * time.Millisecond { + if time.Now().Sub(start) < 500*time.Millisecond { t.Fatalf("write took too short") } @@ -99,7 +97,7 @@ func TestRateLimiter(t *testing.T) { size = 7 * tok _, err = w2.Write(src2Orig[:size]) require.Error(t, err) - if time.Now().Sub(start) > 10 * time.Millisecond { + if time.Now().Sub(start) > 10*time.Millisecond { t.Fatalf("write took too long") } -} \ No newline at end of file +} diff --git a/service/limiter_testing.go b/service/limiter_testing.go index 355fcda8..12de3449 100644 --- a/service/limiter_testing.go +++ b/service/limiter_testing.go @@ -19,18 +19,18 @@ import ( "time" ) -func MakeTestRateLimiterConfig(ciphers CipherList) RateLimiterConfig { +func MakeTestTrafficLimiterConfig(ciphers CipherList) TrafficLimiterConfig { elts := ciphers.SnapshotForClientIP(net.IP{}) keyLimits := KeyLimits{ LargeScalePeriod: 1000 * time.Hour, - LargeScaleLimit: 1 << 30, + LargeScaleLimit: 1 << 30, SmallScalePeriod: 1000 * time.Hour, - SmallScaleLimit: 1 << 30, + SmallScaleLimit: 1 << 30, } keyToLimits := make(map[string]KeyLimits) for _, elt := range elts { entry := elt.Value.(*CipherEntry) keyToLimits[entry.ID] = keyLimits } - return RateLimiterConfig{KeyToLimits: keyToLimits} + return TrafficLimiterConfig{KeyToLimits: keyToLimits} } diff --git a/service/tcp.go b/service/tcp.go index 5fcf5507..00e3a8e4 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -113,7 +113,7 @@ type tcpService struct { stopped bool ciphers CipherList m metrics.ShadowsocksMetrics - limiter RateLimiter + limiter TrafficLimiter running sync.WaitGroup readTimeout time.Duration // `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports. @@ -124,11 +124,11 @@ type tcpService struct { // NewTCPService creates a TCPService // `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports. 
func NewTCPService(ciphers CipherList, replayCache *ReplayCache, m metrics.ShadowsocksMetrics, - timeout time.Duration, limiter RateLimiter) TCPService { + timeout time.Duration, limiter TrafficLimiter) TCPService { return &tcpService{ ciphers: ciphers, m: m, - readTimeout: timeout, + readTimeout: timeout, limiter: limiter, replayCache: replayCache, targetIPValidator: onet.RequirePublicIP, @@ -234,11 +234,7 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo var accessKey string if cipherEntry != nil { accessKey = cipherEntry.ID - var limiterErr error - clientReader, clientWriter, limiterErr = s.limiter.WrapReaderWriter(accessKey, clientReader, clientWriter) - if limiterErr != nil { - logger.Panicf("got unexpected error wrapping streams: %v", limiterErr) - } + clientReader, clientWriter = s.limiter.WrapReaderWriter(accessKey, clientReader, clientWriter) } connError := func() *onet.ConnectionError { @@ -269,9 +265,10 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo clientTCPConn.SetReadDeadline(time.Time{}) if err != nil { // Drain to prevent a close on cipher error. - io.Copy(ioutil.Discard, clientConn) + io.Copy(ioutil.Discard, clientReader) return onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err) } + logger.Debugf("address %s", clientTCPConn.RemoteAddr().String()) tgtConn, dialErr := dialTarget(tgtAddr, &proxyMetrics, s.targetIPValidator) if dialErr != nil { @@ -287,9 +284,10 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo fromClientErrCh := make(chan error) go func() { _, fromClientErr := ssr.WriteTo(tgtConn) + logger.Debugf("fromClientErr: %v", fromClientErr) if fromClientErr != nil { // Drain to prevent a close in the case of a cipher error. - io.Copy(ioutil.Discard, clientConn) + io.Copy(ioutil.Discard, clientReader) } clientConn.CloseRead() // Send FIN to target. @@ -321,7 +319,7 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo } s.m.AddClosedTCPConnection(clientLocation, accessKey, status, proxyMetrics, timeToCipher, connDuration) clientConn.Close() // Closing after the metrics are added aids integration testing. 
- logger.Debugf("Done with status %v, duration %v", status, connDuration) + logger.Debugf("Done with status %v, duration %v, metrics: %v", status, connDuration, proxyMetrics) } // Keep the connection open until we hit the authentication deadline to protect against probing attacks diff --git a/service/tcp_test.go b/service/tcp_test.go index b07ae1dd..38004105 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -277,9 +277,9 @@ func probe(serverAddr *net.TCPAddr, bytesToSend []byte) error { return nil } -func makeLimiter(cipherList CipherList) RateLimiter { - c := MakeTestRateLimiterConfig(cipherList) - return NewRateLimiter(&c) +func makeLimiter(cipherList CipherList) TrafficLimiter { + c := MakeTestTrafficLimiterConfig(cipherList) + return NewTrafficLimiter(&c) } func TestProbeRandom(t *testing.T) { diff --git a/service/udp.go b/service/udp.go index 8d75047b..deaae593 100644 --- a/service/udp.go +++ b/service/udp.go @@ -78,17 +78,17 @@ type udpService struct { m metrics.ShadowsocksMetrics running sync.WaitGroup targetIPValidator onet.TargetIPValidator - limiter RateLimiter + limiter TrafficLimiter } // NewUDPService creates a UDPService -func NewUDPService(natTimeout time.Duration, cipherList CipherList, m metrics.ShadowsocksMetrics, limiter RateLimiter) UDPService { +func NewUDPService(natTimeout time.Duration, cipherList CipherList, m metrics.ShadowsocksMetrics, limiter TrafficLimiter) UDPService { return &udpService{ - natTimeout: natTimeout, - ciphers: cipherList, - m: m, + natTimeout: natTimeout, + ciphers: cipherList, + m: m, targetIPValidator: onet.RequirePublicIP, - limiter: limiter, + limiter: limiter, } } @@ -126,7 +126,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { s.mu.Unlock() defer s.running.Done() - nm := newNATmap(s.natTimeout, s.m, &s.running) + nm := newNATmap(s.natTimeout, s.m, &s.running, s.limiter) defer nm.Close() cipherBuf := make([]byte, serverUDPBufferSize) textBuf := make([]byte, serverUDPBufferSize) @@ -354,10 +354,11 @@ type natmap struct { timeout time.Duration metrics metrics.ShadowsocksMetrics running *sync.WaitGroup + limiter TrafficLimiter } -func newNATmap(timeout time.Duration, sm metrics.ShadowsocksMetrics, running *sync.WaitGroup) *natmap { - m := &natmap{metrics: sm, running: running} +func newNATmap(timeout time.Duration, sm metrics.ShadowsocksMetrics, running *sync.WaitGroup, limiter TrafficLimiter) *natmap { + m := &natmap{metrics: sm, running: running, limiter: limiter} m.keyConn = make(map[string]*natconn) m.timeout = timeout return m @@ -403,7 +404,7 @@ func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cipher *ss. m.metrics.AddUDPNatEntry() m.running.Add(1) go func() { - timedCopy(clientAddr, clientConn, entry, keyID, m.metrics) + timedCopy(clientAddr, clientConn, entry, keyID, m.metrics, m.limiter) m.metrics.RemoveUDPNatEntry() if pc := m.del(clientAddr.String()); pc != nil { pc.Close() @@ -433,7 +434,7 @@ var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // copy from target to client until read timeout func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, - keyID string, sm metrics.ShadowsocksMetrics) { + keyID string, sm metrics.ShadowsocksMetrics, limiter TrafficLimiter) { // pkt is used for in-place encryption of downstream UDP packets, with the layout // [padding?][salt][address][body][tag][extra] // Padding is only used if the address is IPv4. 
@@ -467,6 +468,13 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco } debugUDPAddr(clientAddr, "Got response from %v", raddr) + + limitErr := limiter.Allow(keyID, bodyLen) + if limitErr != nil { + debugUDPAddr(clientAddr, "Rate limite exceeded: %v", limitErr) + return onet.NewConnectionError("ERR_LIMIT", "Rate limit exceeded", limitErr) + } + srcAddr := socks.ParseAddr(raddr.String()) addrStart := bodyStart - len(srcAddr) // `plainTextBuf` concatenates the SOCKS address and body: diff --git a/service/udp_test.go b/service/udp_test.go index 808e8436..6a7b6569 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -202,17 +202,22 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) { } func TestNATEmpty(t *testing.T) { - nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, makeLimiter(NewCipherList())) if nat.Get("foo") != nil { t.Error("Expected nil value from empty NAT map") } } func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) { - nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) + cipherList, err := MakeTestCiphers(ss.MakeTestSecrets(1)) + if err != nil { + logger.Fatal(err) + } + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, makeLimiter(cipherList)) clientConn := makePacketConn() targetConn := makePacketConn() - nat.Add(&clientAddr, clientConn, natCipher, targetConn, "ZZ", "key id") + key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*CipherEntry).ID + nat.Add(&clientAddr, clientConn, natCipher, targetConn, "ZZ", key) entry := nat.Get(clientAddr.String()) return clientConn, targetConn, entry } From 065f76d8cdb730b079d0288b9dc9fb841cf929f0 Mon Sep 17 00:00:00 2001 From: Lev Vysotsky Date: Wed, 13 Apr 2022 18:35:01 +0300 Subject: [PATCH 3/5] Optional limiters --- integration_test/integration_test.go | 8 ++++---- server.go | 30 ++++++---------------------- service/limiter.go | 23 ++++++++++++++++----- service/limiter_test.go | 6 +++--- service/limiter_testing.go | 11 ++-------- 5 files changed, 33 insertions(+), 45 deletions(-) diff --git a/integration_test/integration_test.go b/integration_test/integration_test.go index 74ffb7bd..f4f00052 100644 --- a/integration_test/integration_test.go +++ b/integration_test/integration_test.go @@ -187,8 +187,8 @@ func TestTrafficLimiterTCP(t *testing.T) { key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID const tok = 1024 trafficLimiter := service.NewTrafficLimiter(&service.TrafficLimiterConfig{ - KeyToLimits: map[string]service.KeyLimits{ - key: service.KeyLimits{ + KeyToLimits: map[string]*service.KeyLimits{ + key: &service.KeyLimits{ LargeScaleLimit: 80 * tok, LargeScalePeriod: 60 * time.Second, SmallScaleLimit: 10 * tok, @@ -295,8 +295,8 @@ func TestTrafficLimiterUDP(t *testing.T) { key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID smallScalePeriod := 100 * time.Millisecond trafficLimiter := service.NewTrafficLimiter(&service.TrafficLimiterConfig{ - KeyToLimits: map[string]service.KeyLimits{ - key: service.KeyLimits{ + KeyToLimits: map[string]*service.KeyLimits{ + key: &service.KeyLimits{ LargeScaleLimit: 100 * tok, LargeScalePeriod: 60 * time.Second, SmallScaleLimit: 10 * tok, diff --git a/server.go b/server.go index 61efc72a..2b369a8d 100644 --- a/server.go +++ b/server.go @@ -122,7 +122,7 @@ func (s *SSServer) loadConfig(filename string) error { portChanges := make(map[int]int) portCiphers := 
-	portKeyLimits := make(map[int]map[string]service.KeyLimits)
+	portKeyLimits := make(map[int]map[string]*service.KeyLimits)
 	for _, keyConfig := range config.Keys {
 		portChanges[keyConfig.Port] = 1
 		cipherList, ok := portCiphers[keyConfig.Port]
@@ -136,18 +136,15 @@ func (s *SSServer) loadConfig(filename string) error {
 		}
 		entry := service.MakeCipherEntry(keyConfig.ID, cipher, keyConfig.Secret)
 		cipherList.PushBack(&entry)
-		var keyLimits map[string]service.KeyLimits
+		var keyLimits map[string]*service.KeyLimits
 		keyLimits, ok = portKeyLimits[keyConfig.Port]
 		if !ok {
-			keyLimits = make(map[string]service.KeyLimits)
+			keyLimits = make(map[string]*service.KeyLimits)
 			portKeyLimits[keyConfig.Port] = keyLimits
 		}
-		if keyConfig.Limits != nil {
-			keyLimits[keyConfig.ID] = *keyConfig.Limits
-		} else if config.DefaultKeyLimits != nil {
-			keyLimits[keyConfig.ID] = *config.DefaultKeyLimits
-		} else {
-			keyLimits[keyConfig.ID] = noLimits
+		keyLimits[keyConfig.ID] = keyConfig.Limits
+		if keyConfig.Limits == nil {
+			keyLimits[keyConfig.ID] = config.DefaultKeyLimits
 		}
 	}
 	for port := range s.ports {
@@ -219,13 +216,6 @@ type Config struct {
 	DefaultKeyLimits *service.KeyLimits
 }
 
-var noLimits service.KeyLimits = service.KeyLimits{
-	LargeScalePeriod: time.Millisecond,
-	LargeScaleLimit:  1 << 30,
-	SmallScalePeriod: time.Millisecond,
-	SmallScaleLimit:  1 << 30,
-}
-
 func readConfig(filename string) (*Config, error) {
 	config := Config{}
 	configData, err := ioutil.ReadFile(filename)
@@ -236,14 +226,6 @@ func readConfig(filename string) (*Config, error) {
 	if err != nil {
 		return nil, err
 	}
-	if config.DefaultKeyLimits == nil {
-		config.DefaultKeyLimits = &noLimits
-	}
-	for i := range config.Keys {
-		if config.Keys[i].Limits == nil {
-			config.Keys[i].Limits = config.DefaultKeyLimits
-		}
-	}
 	return &config, err
 }
diff --git a/service/limiter.go b/service/limiter.go
index 695d7dc6..1ca6995a 100644
--- a/service/limiter.go
+++ b/service/limiter.go
@@ -33,7 +33,8 @@ type KeyLimits struct {
 }
 
 type TrafficLimiterConfig struct {
-	KeyToLimits map[string]KeyLimits
+	// If the corresponding KeyLimits is nil, it means no limits
+	KeyToLimits map[string]*KeyLimits
 }
 
 type TrafficLimiter interface {
@@ -42,12 +43,18 @@ func NewTrafficLimiter(config *TrafficLimiterConfig) TrafficLimiter {
-	keyToLimiter := make(map[string]*perKeyLimiter, 0)
+	keyToLimiter := make(map[string]*perKeyLimiter, len(config.KeyToLimits))
 	for accessKey, limits := range config.KeyToLimits {
-		keyToLimiter[accessKey] = &perKeyLimiter{
-			largeScale: createLimiter(limits.LargeScalePeriod, limits.LargeScaleLimit),
-			smallScale: createLimiter(limits.SmallScalePeriod, limits.SmallScaleLimit),
+		var limiter *perKeyLimiter
+		if limits == nil {
+			limiter = nil
+		} else {
+			limiter = &perKeyLimiter{
+				largeScale: createLimiter(limits.LargeScalePeriod, limits.LargeScaleLimit),
+				smallScale: createLimiter(limits.SmallScalePeriod, limits.SmallScaleLimit),
+			}
 		}
+		keyToLimiter[accessKey] = limiter
 	}
 	return &trafficLimiter{keyToLimiter: keyToLimiter}
 }
@@ -160,6 +167,9 @@ func (l *trafficLimiter) WrapReaderWriter(accessKey string, reader io.Reader, wr
 	if !ok {
 		logger.Panicf("Access key %v not found", accessKey)
 	}
+	if limiter == nil {
+		return reader, writer
+	}
 	return &limitedReader{reader: reader, limiter: limiter},
 		&limitedWriter{writer: writer, limiter: limiter}
 }
@@ -168,5 +178,8 @@ func (l *trafficLimiter) Allow(accessKey string, n int) error {
 	limiter, ok := l.keyToLimiter[accessKey]
 	if !ok {
logger.Panicf("Access key %v not found", accessKey) } + if limiter == nil { + return nil + } return limiter.Allow(n) } diff --git a/service/limiter_test.go b/service/limiter_test.go index a1bddf91..1ecd548d 100644 --- a/service/limiter_test.go +++ b/service/limiter_test.go @@ -22,14 +22,14 @@ func TestTrafficLimiter(t *testing.T) { var tok int64 = 1024 config := TrafficLimiterConfig{ - KeyToLimits: map[string]KeyLimits{ - key1: KeyLimits{ + KeyToLimits: map[string]*KeyLimits{ + key1: &KeyLimits{ LargeScalePeriod: time.Minute, LargeScaleLimit: 10 * tok, SmallScalePeriod: time.Second, SmallScaleLimit: 2 * tok, }, - key2: KeyLimits{ + key2: &KeyLimits{ LargeScalePeriod: time.Minute, LargeScaleLimit: 10 * tok, SmallScalePeriod: time.Second, diff --git a/service/limiter_testing.go b/service/limiter_testing.go index 12de3449..927cd27a 100644 --- a/service/limiter_testing.go +++ b/service/limiter_testing.go @@ -16,21 +16,14 @@ package service import ( "net" - "time" ) func MakeTestTrafficLimiterConfig(ciphers CipherList) TrafficLimiterConfig { elts := ciphers.SnapshotForClientIP(net.IP{}) - keyLimits := KeyLimits{ - LargeScalePeriod: 1000 * time.Hour, - LargeScaleLimit: 1 << 30, - SmallScalePeriod: 1000 * time.Hour, - SmallScaleLimit: 1 << 30, - } - keyToLimits := make(map[string]KeyLimits) + keyToLimits := make(map[string]*KeyLimits) for _, elt := range elts { entry := elt.Value.(*CipherEntry) - keyToLimits[entry.ID] = keyLimits + keyToLimits[entry.ID] = nil } return TrafficLimiterConfig{KeyToLimits: keyToLimits} } From 3b6be510dd3fbc7eb448fa473f2806cec6cf47f9 Mon Sep 17 00:00:00 2001 From: Lev Vysotsky Date: Wed, 13 Apr 2022 18:36:39 +0300 Subject: [PATCH 4/5] Fix --- service/tcp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/tcp.go b/service/tcp.go index 00e3a8e4..f7abfaac 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -319,7 +319,7 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo } s.m.AddClosedTCPConnection(clientLocation, accessKey, status, proxyMetrics, timeToCipher, connDuration) clientConn.Close() // Closing after the metrics are added aids integration testing. 
- logger.Debugf("Done with status %v, duration %v, metrics: %v", status, connDuration, proxyMetrics) + logger.Debugf("Done with status %v, duration %v", status, connDuration) } // Keep the connection open until we hit the authentication deadline to protect against probing attacks From fc55bb6aa244b3616ea399dff90134dfe9fc67d9 Mon Sep 17 00:00:00 2001 From: Lev Vysotsky Date: Wed, 13 Apr 2022 19:05:03 +0300 Subject: [PATCH 5/5] Default yaml --- config_example.yml | 11 +++++++++++ integration_test/integration_test.go | 4 ++-- server.go | 24 ++++++++++++------------ service/limiter.go | 4 ++-- service/limiter_test.go | 6 +++--- service/limiter_testing.go | 2 +- 6 files changed, 31 insertions(+), 20 deletions(-) diff --git a/config_example.yml b/config_example.yml index 8895b86d..725d8827 100644 --- a/config_example.yml +++ b/config_example.yml @@ -13,3 +13,14 @@ keys: port: 9001 cipher: chacha20-ietf-poly1305 secret: Secret2 + traffic_limits: + large_scale_limit: 100000000000 + large_scale_period: "30d" + small_scale_limit: 12800000000 + small_scale_period: "1s" + +default_traffic_limits: + large_scale_limit: 2000000000 + large_scale_period: "30d" + small_scale_limit: 128000000 + small_scale_period: "1s" diff --git a/integration_test/integration_test.go b/integration_test/integration_test.go index f4f00052..14fd7ae7 100644 --- a/integration_test/integration_test.go +++ b/integration_test/integration_test.go @@ -187,7 +187,7 @@ func TestTrafficLimiterTCP(t *testing.T) { key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID const tok = 1024 trafficLimiter := service.NewTrafficLimiter(&service.TrafficLimiterConfig{ - KeyToLimits: map[string]*service.KeyLimits{ + KeyToLimits: map[string]*service.TrafficLimits{ key: &service.KeyLimits{ LargeScaleLimit: 80 * tok, LargeScalePeriod: 60 * time.Second, @@ -295,7 +295,7 @@ func TestTrafficLimiterUDP(t *testing.T) { key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID smallScalePeriod := 100 * time.Millisecond trafficLimiter := service.NewTrafficLimiter(&service.TrafficLimiterConfig{ - KeyToLimits: map[string]*service.KeyLimits{ + KeyToLimits: map[string]*service.TrafficLimits{ key: &service.KeyLimits{ LargeScaleLimit: 100 * tok, LargeScalePeriod: 60 * time.Second, diff --git a/server.go b/server.go index 2b369a8d..41b858a7 100644 --- a/server.go +++ b/server.go @@ -122,7 +122,7 @@ func (s *SSServer) loadConfig(filename string) error { portChanges := make(map[int]int) portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. 
-	portKeyLimits := make(map[int]map[string]*service.KeyLimits)
+	portKeyLimits := make(map[int]map[string]*service.TrafficLimits)
 	for _, keyConfig := range config.Keys {
 		portChanges[keyConfig.Port] = 1
 		cipherList, ok := portCiphers[keyConfig.Port]
@@ -136,15 +136,15 @@ func (s *SSServer) loadConfig(filename string) error {
 		}
 		entry := service.MakeCipherEntry(keyConfig.ID, cipher, keyConfig.Secret)
 		cipherList.PushBack(&entry)
-		var keyLimits map[string]*service.KeyLimits
+		var keyLimits map[string]*service.TrafficLimits
 		keyLimits, ok = portKeyLimits[keyConfig.Port]
 		if !ok {
-			keyLimits = make(map[string]*service.KeyLimits)
+			keyLimits = make(map[string]*service.TrafficLimits)
 			portKeyLimits[keyConfig.Port] = keyLimits
 		}
-		keyLimits[keyConfig.ID] = keyConfig.Limits
-		if keyConfig.Limits == nil {
-			keyLimits[keyConfig.ID] = config.DefaultKeyLimits
+		keyLimits[keyConfig.ID] = keyConfig.TrafficLimits
+		if keyConfig.TrafficLimits == nil {
+			keyLimits[keyConfig.ID] = config.DefaultTrafficLimits
 		}
 	}
 	for port := range s.ports {
@@ -207,13 +207,13 @@ func RunSSServer(filename string, natTimeout time.Duration, sm metrics.Shadowsoc
 
 type Config struct {
 	Keys []struct {
-		ID     string
-		Port   int
-		Cipher string
-		Secret string
-		Limits *service.KeyLimits
+		ID            string
+		Port          int
+		Cipher        string
+		Secret        string
+		TrafficLimits *service.TrafficLimits
 	}
-	DefaultKeyLimits *service.KeyLimits
+	DefaultTrafficLimits *service.TrafficLimits
 }
 
 func readConfig(filename string) (*Config, error) {
diff --git a/service/limiter.go b/service/limiter.go
index 1ca6995a..5b068bb6 100644
--- a/service/limiter.go
+++ b/service/limiter.go
@@ -25,7 +25,7 @@ import (
 	"golang.org/x/time/rate"
 )
 
-type KeyLimits struct {
+type TrafficLimits struct {
 	LargeScalePeriod time.Duration
 	LargeScaleLimit  int64
 	SmallScalePeriod time.Duration
@@ -34,7 +34,7 @@ type KeyLimits struct {
 
 type TrafficLimiterConfig struct {
 	// If the corresponding KeyLimits is nil, it means no limits
-	KeyToLimits map[string]*KeyLimits
+	KeyToLimits map[string]*TrafficLimits
 }
 
 type TrafficLimiter interface {
diff --git a/service/limiter_test.go b/service/limiter_test.go
index 1ecd548d..afc9beb0 100644
--- a/service/limiter_test.go
+++ b/service/limiter_test.go
@@ -22,14 +22,14 @@ func TestTrafficLimiter(t *testing.T) {
 	var tok int64 = 1024
 
 	config := TrafficLimiterConfig{
-		KeyToLimits: map[string]*KeyLimits{
-			key1: &KeyLimits{
+		KeyToLimits: map[string]*TrafficLimits{
+			key1: &TrafficLimits{
 				LargeScalePeriod: time.Minute,
 				LargeScaleLimit:  10 * tok,
 				SmallScalePeriod: time.Second,
 				SmallScaleLimit:  2 * tok,
 			},
-			key2: &KeyLimits{
+			key2: &TrafficLimits{
 				LargeScalePeriod: time.Minute,
 				LargeScaleLimit:  10 * tok,
 				SmallScalePeriod: time.Second,
diff --git a/service/limiter_testing.go b/service/limiter_testing.go
index 927cd27a..6a3e6c7c 100644
--- a/service/limiter_testing.go
+++ b/service/limiter_testing.go
@@ -20,7 +20,7 @@ import (
 
 func MakeTestTrafficLimiterConfig(ciphers CipherList) TrafficLimiterConfig {
 	elts := ciphers.SnapshotForClientIP(net.IP{})
-	keyToLimits := make(map[string]*KeyLimits)
+	keyToLimits := make(map[string]*TrafficLimits)
 	for _, elt := range elts {
 		entry := elt.Value.(*CipherEntry)
 		keyToLimits[entry.ID] = nil