diff --git a/tcp_sink.go b/tcp_sink.go
index dbe447a..4900540 100644
--- a/tcp_sink.go
+++ b/tcp_sink.go
@@ -13,11 +13,6 @@ import (
 	logger "github.com/sirupsen/logrus"
 )
 
-// TODO(btc): add constructor that accepts functional options in order to allow
-// users to choose the constants that work best for them. (Leave the existing
-// c'tor for backwards compatibility)
-// e.g. `func NewTCPStatsdSinkWithOptions(opts ...Option) Sink`
-
 const (
 	defaultRetryInterval = time.Second * 3
 	defaultDialTimeout   = defaultRetryInterval / 2
@@ -30,19 +25,65 @@ const (
 	chanSize = approxMaxMemBytes / defaultBufferSize
 )
 
+// A SinkOption configures a Sink.
+type SinkOption interface {
+	apply(*tcpStatsdSink)
+}
+
+// sinkOptionFunc wraps a func so it satisfies the SinkOption interface.
+type sinkOptionFunc func(*tcpStatsdSink)
+
+func (f sinkOptionFunc) apply(sink *tcpStatsdSink) {
+	f(sink)
+}
+
+// WithStatsdHost sets the host of the statsd sink; otherwise the host is
+// read from the environment variable "STATSD_HOST".
+func WithStatsdHost(host string) SinkOption {
+	return sinkOptionFunc(func(sink *tcpStatsdSink) {
+		sink.statsdHost = host
+	})
+}
+
+// WithStatsdPort sets the port of the statsd sink; otherwise the port is
+// read from the environment variable "STATSD_PORT".
+func WithStatsdPort(port int) SinkOption {
+	return sinkOptionFunc(func(sink *tcpStatsdSink) {
+		sink.statsdPort = port
+	})
+}
+
+// WithLogger configures the sink to use the provided logger; otherwise
+// the standard logrus logger is used.
+func WithLogger(log *logger.Logger) SinkOption {
+	// TODO (CEV): use the zap.Logger
+	return sinkOptionFunc(func(sink *tcpStatsdSink) {
+		sink.log = log
+	})
+}
+
 // NewTCPStatsdSink returns a FlushableSink that is backed by a buffered writer
 // and a separate goroutine that flushes those buffers to a statsd connection.
-func NewTCPStatsdSink() FlushableSink {
+func NewTCPStatsdSink(opts ...SinkOption) FlushableSink {
 	outc := make(chan *bytes.Buffer, chanSize) // TODO(btc): parameterize
 	writer := sinkWriter{
 		outc: outc,
 	}
+	// TODO (CEV): this auto loading from the env is bad and should be removed.
+	conf := GetSettings()
 	s := &tcpStatsdSink{
 		outc: outc,
 		// TODO(btc): parameterize size
 		bufWriter: bufio.NewWriterSize(&writer, defaultBufferSize), // arbitrarily buffered
 		doFlush:   make(chan chan struct{}, 8),
+		// CEV: default to the standard logger to match the legacy implementation.
+		log:        logger.StandardLogger(),
+		statsdHost: conf.StatsdHost,
+		statsdPort: conf.StatsdPort,
+	}
+	for _, opt := range opts {
+		opt.apply(s)
 	}
 	go s.run()
 	return s
@@ -55,6 +96,9 @@ type tcpStatsdSink struct {
 	bufWriter    *bufio.Writer
 	doFlush      chan chan struct{}
 	droppedBytes uint64
+	log          *logger.Logger
+	statsdHost   string
+	statsdPort   int
 }
 
 type sinkWriter struct {
@@ -92,11 +136,26 @@ func (s *tcpStatsdSink) flush() error {
 	return err
 }
 
+func (s *tcpStatsdSink) drainFlushQueue() {
+	// Limit the number of items we'll flush to prevent this from possibly
+	// hanging when the flush channel is saturated with sends.
+	doFlush := s.doFlush
+	n := cap(doFlush) * 8
+	for i := 0; i < n; i++ {
+		select {
+		case ch := <-doFlush:
+			close(ch)
+		default:
+			return
+		}
+	}
+}
+
 // s.mu should be held
 func (s *tcpStatsdSink) handleFlushErrorSize(err error, dropped int) {
 	d := uint64(dropped)
 	if (s.droppedBytes+d)%logOnEveryNDroppedBytes > s.droppedBytes%logOnEveryNDroppedBytes {
-		logger.WithField("total_dropped_bytes", s.droppedBytes+d).
+		s.log.WithField("total_dropped_bytes", s.droppedBytes+d).
 			WithField("dropped_bytes", d).
 			Error(err)
 	}
@@ -177,20 +236,29 @@ func (s *tcpStatsdSink) FlushTimer(name string, value float64) {
 }
 
 func (s *tcpStatsdSink) run() {
-	conf := GetSettings()
-	addr := net.JoinHostPort(conf.StatsdHost, strconv.Itoa(conf.StatsdPort))
+	addr := net.JoinHostPort(s.statsdHost, strconv.Itoa(s.statsdPort))
+
+	var reconnectFailed bool // true if last reconnect failed
 	t := time.NewTicker(flushInterval)
 	defer t.Stop()
 	for {
 		if s.conn == nil {
 			if err := s.connect(addr); err != nil {
-				logger.Warnf("statsd connection error: %s", err)
+				s.log.Warnf("statsd connection error: %s", err)
+
+				// If the previous reconnect attempt failed, drain the flush
+				// queue to prevent Flush() from blocking indefinitely.
+				if reconnectFailed {
+					s.drainFlushQueue()
+				}
+				reconnectFailed = true
 				// TODO (CEV): don't sleep on the first retry
-				time.Sleep(3 * time.Second)
+				time.Sleep(defaultRetryInterval)
 				continue
 			}
+			reconnectFailed = false
 		}
 
 		select {
diff --git a/tcp_sink_test.go b/tcp_sink_test.go
index cd4c4b1..03af481 100644
--- a/tcp_sink_test.go
+++ b/tcp_sink_test.go
@@ -2,13 +2,19 @@ package stats
 
 import (
 	"bufio"
+	"bytes"
+	"context"
 	"fmt"
 	"io"
 	"io/ioutil"
 	"net"
 	"os"
+	"os/exec"
+	"path/filepath"
+	"runtime"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
@@ -482,6 +488,488 @@ func TestTCPStatsdSink_Flush(t *testing.T) {
 	})
 }
 
+// Test that drainFlushQueue() does not hang when there are continuous
+// flush requests.
+func TestTCPStatsdSink_DrainFlushQueue(t *testing.T) {
+	s := &tcpStatsdSink{
+		doFlush: make(chan chan struct{}, 8),
+	}
+
+	sent := new(int64)
+
+	// Saturate the flush channel
+
+	done := make(chan struct{})
+	defer close(done)
+
+	for i := 0; i < runtime.NumCPU(); i++ {
+		go func() {
+			for {
+				select {
+				case s.doFlush <- make(chan struct{}):
+					atomic.AddInt64(sent, 1)
+				case <-done:
+					return
+				}
+			}
+		}()
+	}
+
+	// Wait for the flush channel to fill
+	for len(s.doFlush) < cap(s.doFlush) {
+		runtime.Gosched()
+	}
+
+	flushed := make(chan struct{})
+	go func() {
+		s.drainFlushQueue()
+		close(flushed)
+	}()
+
+	// We will flush up to cap(s.doFlush) * 8 items, so the max number
+	// of sends will be that plus the capacity of the buffer.
+	maxSends := cap(s.doFlush)*8 + cap(s.doFlush)
+
+	select {
+	case <-flushed:
+		n := int(atomic.LoadInt64(sent))
+		switch {
+		case n < cap(s.doFlush):
+			// This should be impossible since we fill the channel
+			// before calling drainFlushQueue().
+			t.Errorf("Sent less than %d items: %d", cap(s.doFlush), n)
+		case n > maxSends:
+			// This should be nearly impossible to get without inserting
+			// runtime.Gosched() into the flush/drain loop.
+			t.Errorf("Sent more than %d items: %d", maxSends, n)
+		}
+	case <-time.After(time.Second / 2):
+		// 500ms is really generous, it should return almost immediately.
+ t.Error("drainFlushQueue did not return in time") + } +} + +type tcpTestSink struct { + ll *net.TCPListener + addr *net.TCPAddr + mu sync.Mutex // buf lock + buf bytes.Buffer + stats chan string + done chan struct{} // closed when read loop exits +} + +func newTCPTestSink(t testing.TB) *tcpTestSink { + l, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, + }) + if err != nil { + t.Fatal("ListenTCP:", err) + } + s := &tcpTestSink{ + ll: l, + addr: l.Addr().(*net.TCPAddr), + stats: make(chan string, 64), + done: make(chan struct{}), + } + go s.run(t) + return s +} + +func (s *tcpTestSink) writeStat(line []byte) { + select { + case s.stats <- string(line): + default: + } + s.mu.Lock() + s.buf.Write(line) + s.mu.Unlock() +} + +func (s *tcpTestSink) run(t testing.TB) { + defer close(s.done) + buf := bufio.NewReader(nil) + for { + conn, err := s.ll.AcceptTCP() + if err != nil { + // Log errors other than poll.ErrNetClosing, which is an + // internal error so we have to match against it's string. + if !strings.Contains(err.Error(), "use of closed network connection") { + t.Logf("Error: accept: %v", err) + } + return + } + // read stats line by line + buf.Reset(conn) + for { + b, e := buf.ReadBytes('\n') + if len(b) > 0 { + s.writeStat(b) + } + if e != nil { + if e != io.EOF { + err = e + } + break + } + } + if buf.Buffered() != 0 { + buf.WriteTo(&s.buf) + } + if err != nil { + t.Errorf("Error: reading stats: %v", err) + } + } +} + +func (s *tcpTestSink) Restart(t testing.TB, resetBuffer bool) { + if err := s.Close(); err != nil { + if !strings.Contains(err.Error(), "use of closed network connection") { + t.Fatal(err) + } + } + select { + case <-s.done: + // Ok + case <-time.After(time.Second * 3): + t.Fatal("timeout waiting for run loop to exit") + } + l, err := net.ListenTCP(s.addr.Network(), s.addr) + if err != nil { + t.Fatalf("restarting connection: %v", err) + } + if resetBuffer { + s.buf.Reset() + } + *s = tcpTestSink{ + ll: l, + addr: s.addr, + buf: s.buf, + stats: make(chan string, 64), + done: make(chan struct{}), + } + go s.run(t) +} + +func (s *tcpTestSink) Close() error { + select { + case <-s.done: + return nil // closed + default: + return s.ll.Close() + } +} + +func (s *tcpTestSink) WaitForStat(t testing.TB, timeout time.Duration) string { + t.Helper() + if timeout <= 0 { + timeout = defaultRetryInterval * 2 + } + to := time.NewTimer(timeout) + defer to.Stop() + select { + case s := <-s.stats: + return s + case <-to.C: + t.Fatalf("timeout waiting to receive stat: %s", timeout) + } + return "" +} + +func (s *tcpTestSink) Stats() <-chan string { + return s.stats +} + +func (s *tcpTestSink) Bytes() []byte { + s.mu.Lock() + b := append([]byte(nil), s.buf.Bytes()...) + s.mu.Unlock() + return b +} + +func (s *tcpTestSink) String() string { + s.mu.Lock() + str := s.buf.String() + s.mu.Unlock() + return str +} + +func (s *tcpTestSink) Address() *net.TCPAddr { + return s.addr +} + +func mergeEnv(extra ...string) []string { + var prefixes []string + for _, s := range extra { + n := strings.IndexByte(s, '=') + prefixes = append(prefixes, s[:n+1]) + } + ignore := func(s string) bool { + for _, pfx := range prefixes { + if strings.HasPrefix(s, pfx) { + return true + } + } + return false + } + + env := os.Environ() + a := env[:0] + for _, s := range env { + if !ignore(s) { + a = append(a, s) + } + } + return append(a, extra...) +} + +// CommandEnv returns the environment variables for an *exec.Cmd to use +// with this test sink. 
+func (s *tcpTestSink) CommandEnv() []string {
+	return mergeEnv(
+		fmt.Sprintf("STATSD_PORT=%d", s.Address().Port),
+		fmt.Sprintf("STATSD_HOST=%s", s.Address().IP.String()),
+		"GOSTATS_FLUSH_INTERVAL_SECONDS=1",
+	)
+}
+
+func discardLogger() *logger.Logger {
+	log := logger.New()
+	log.Out = ioutil.Discard
+	return log
+}
+
+func TestTCPStatsdSink(t *testing.T) {
+	setup := func(t *testing.T, stop bool) (*tcpTestSink, *tcpStatsdSink) {
+		ts := newTCPTestSink(t)
+
+		if stop {
+			if err := ts.Close(); err != nil {
+				t.Fatal(err)
+			}
+		}
+
+		sink := NewTCPStatsdSink(
+			WithLogger(discardLogger()),
+			WithStatsdHost(ts.Address().IP.String()),
+			WithStatsdPort(ts.Address().Port),
+		).(*tcpStatsdSink)
+
+		return ts, sink
+	}
+
+	t.Run("StatTypes", func(t *testing.T) {
+		var expected = [...]string{
+			"counter:1|c\n",
+			"gauge:1|g\n",
+			"timer_int:1|ms\n",
+			"timer_float:1.230000|ms\n",
+		}
+
+		ts, sink := setup(t, false)
+		defer ts.Close()
+
+		sink.FlushCounter("counter", 1)
+		sink.FlushGauge("gauge", 1)
+		sink.FlushTimer("timer_int", 1)
+		sink.FlushTimer("timer_float", 1.23)
+		sink.Flush()
+
+		for _, exp := range expected {
+			stat := ts.WaitForStat(t, time.Millisecond*50)
+			if stat != exp {
+				t.Errorf("stats got: %q want: %q", stat, exp)
+			}
+		}
+
+		// make sure there aren't any extra stats we're missing
+		exp := strings.Join(expected[:], "")
+		buf := ts.String()
+		if buf != exp {
+			t.Errorf("stats buffer\ngot:\n%q\nwant:\n%q\n", buf, exp)
+		}
+	})
+
+	// Make sure that stats are immediately flushed so that stats from fast
+	// exiting programs are not lost.
+	t.Run("ImmediateFlush", func(t *testing.T) {
+		const expected = "counter:1|c\n"
+
+		ts, sink := setup(t, false)
+		defer ts.Close()
+
+		sink.FlushCounter("counter", 1)
+		sink.Flush()
+
+		stat := ts.WaitForStat(t, time.Millisecond*50)
+		if stat != expected {
+			t.Errorf("stats got: %q want: %q", stat, expected)
+		}
+	})
+
+	// Test that we can successfully reconnect and flush a stat.
+	t.Run("Reconnect", func(t *testing.T) {
+		if testing.Short() {
+			t.Skip("Skipping: short test")
+		}
+		t.Parallel()
+
+		const expected = "counter:1|c\n"
+
+		ts, sink := setup(t, true)
+		defer ts.Close()
+
+		sink.FlushCounter("counter", 1)
+
+		flushed := make(chan struct{})
+		go func() {
+			flushed <- struct{}{}
+			sink.Flush()
+			close(flushed)
+		}()
+
+		<-flushed // wait till we're ready
+		ts.Restart(t, true)
+
+		stat := ts.WaitForStat(t, defaultRetryInterval*2)
+		if stat != expected {
+			t.Fatalf("stats got: %q want: %q", stat, expected)
+		}
+
+		// Make sure our flush call returned
+		select {
+		case <-flushed:
+		case <-time.After(time.Millisecond * 100):
+			// The flushed channel should be closed by this point,
+			// but this was failing in CI on go1.12 due to timing
+			// issues so we relax the constraint and give it 100ms.
+			t.Error("Flush() did not return")
+		}
+	})
+
+	// Test that when reconnecting fails, calls to Flush() do not block
+	// indefinitely.
+ t.Run("ReconnectFailure", func(t *testing.T) { + if testing.Short() { + t.Skip("Skipping: short test") + } + t.Parallel() + + ts, sink := setup(t, true) + defer ts.Close() + + sink.FlushCounter("counter", 1) + + const N = 16 + flushCount := new(int64) + flushed := make(chan struct{}) + go func() { + wg := new(sync.WaitGroup) + wg.Add(N) + for i := 0; i < N; i++ { + go func() { + sink.Flush() + atomic.AddInt64(flushCount, 1) + wg.Done() + }() + } + wg.Wait() + close(flushed) + }() + + // Make sure our flush call returned + select { + case <-flushed: + // Ok + case <-time.After(defaultRetryInterval * 2): + t.Fatalf("Only %d of %d Flush() calls succeeded", + atomic.LoadInt64(flushCount), N) + } + }) +} + +func buildBinary(t testing.TB, path string) (string, func()) { + var binaryName string + if strings.HasSuffix(path, ".go") { + // foo/bar/main.go => bar + binaryName = filepath.Base(filepath.Dir(path)) + } else { + filepath.Base(path) + } + + tmpdir, err := ioutil.TempDir("", "gostats-") + if err != nil { + t.Fatalf("creating tempdir: %v", err) + } + output := filepath.Join(tmpdir, binaryName) + + out, err := exec.Command("go", "build", "-o", output, path).CombinedOutput() + if err != nil { + t.Fatalf("failed to build %s: %s\n### output:\n%s\n###\n", + path, err, strings.TrimSpace(string(out))) + } + + cleanup := func() { + os.RemoveAll(tmpdir) + } + return output, cleanup +} + +func TestTCPStatsdSink_Integration(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fastExitExe, deleteBinary := buildBinary(t, "testdata/fast_exit/fast_exit.go") + defer deleteBinary() + + // Test the stats of a fast exiting program are captured. + t.Run("FastExit", func(t *testing.T) { + ts := newTCPTestSink(t) + defer ts.Close() + + cmd := exec.CommandContext(ctx, fastExitExe) + cmd.Env = ts.CommandEnv() + + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("Running command: %s\n### output:\n%s\n###\n", + fastExitExe, strings.TrimSpace(string(out))) + } + + stats := ts.String() + const expected = "test.fast.exit.counter:1|c\n" + if stats != expected { + t.Errorf("stats: got: %q want: %q", stats, expected) + } + }) + + // Test that Flush() does not hang if the TCP sink is in a reconnect loop + t.Run("Reconnect", func(t *testing.T) { + ts := newTCPTestSink(t) + defer ts.Close() + + cmd := exec.CommandContext(ctx, fastExitExe) + cmd.Env = ts.CommandEnv() + + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + errCh := make(chan error, 1) + go func() { errCh <- cmd.Wait() }() + + select { + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + case <-time.After(defaultRetryInterval * 2): + t.Fatal("Timed out waiting for command to exit") + } + }) +} + type nopWriter struct{} func (nopWriter) Write(b []byte) (int, error) { diff --git a/testdata/fast_exit/.gitignore b/testdata/fast_exit/.gitignore new file mode 100644 index 0000000..cfe711d --- /dev/null +++ b/testdata/fast_exit/.gitignore @@ -0,0 +1,2 @@ +fast_exit +*.exe diff --git a/testdata/fast_exit/fast_exit.go b/testdata/fast_exit/fast_exit.go new file mode 100644 index 0000000..ad18904 --- /dev/null +++ b/testdata/fast_exit/fast_exit.go @@ -0,0 +1,11 @@ +package main + +import stats "github.com/lyft/gostats" + +// Test that the stats of a fast exiting program (such as one that immediately +// errors on startup) can send stats. +func main() { + store := stats.NewDefaultStore() + store.Scope("test.fast.exit").NewCounter("counter").Inc() + store.Flush() +}