diff --git a/benchmarks_test.go b/benchmarks_test.go index 41290e1..c813cfd 100644 --- a/benchmarks_test.go +++ b/benchmarks_test.go @@ -14,16 +14,16 @@ type valueSource interface { func benchmarkAdd(b *testing.B, n int, src valueSource) { valsToAdd := make([]float64, n) - cset := newCentroidSet(100) + d := NewWithCompression(100) for i := 0; i < n; i++ { v := src.Next() valsToAdd[i] = v - cset.Add(v, 1) + d.Add(v, 1) } b.ResetTimer() for i := 0; i < b.N; i++ { - cset.Add(valsToAdd[i%n], 1) + d.Add(valsToAdd[i%n], 1) } b.StopTimer() } @@ -31,16 +31,16 @@ func benchmarkAdd(b *testing.B, n int, src valueSource) { func benchmarkQuantile(b *testing.B, n int, src valueSource) { quantilesToCheck := make([]float64, n) - cset := newCentroidSet(100) + d := NewWithCompression(100) for i := 0; i < n; i++ { v := src.Next() quantilesToCheck[i] = v - cset.Add(v, 1) + d.Add(v, 1) } b.ResetTimer() for i := 0; i < b.N; i++ { - _ = cset.Quantile(quantilesToCheck[i%n]) + _ = d.Quantile(quantilesToCheck[i%n]) } b.StopTimer() } diff --git a/fuzz.go b/fuzz.go new file mode 100644 index 0000000..59dc5e0 --- /dev/null +++ b/fuzz.go @@ -0,0 +1,42 @@ +// +build gofuzz + +package tdigest + +import ( + "bytes" + "fmt" + "log" + + "github.com/davecgh/go-spew/spew" +) + +func Fuzz(data []byte) int { + v := new(TDigest) + err := v.UnmarshalBinary(data) + if err != nil { + return 0 + } + + remarshaled, err := v.MarshalBinary() + if err != nil { + panic(err) + } + + if !bytes.HasPrefix(data, remarshaled) { + panic(fmt.Sprintf("not equal: \n%v\nvs\n%v", data, remarshaled)) + } + + for q := float64(0.1); q <= 1.0; q += 0.05 { + prev, this := v.Quantile(q-0.1), v.Quantile(q) + if prev-this > 1e-100 { // Floating point math makes this slightly imprecise. + log.Printf("v: %s", spew.Sprint(v)) + log.Printf("q: %v", q) + log.Printf("prev: %v", prev) + log.Printf("this: %v", this) + panic("quantiles should only increase") + } + } + + v.Add(1, 1) + return 1 +} diff --git a/fuzz_test.go b/fuzz_test.go new file mode 100644 index 0000000..45ddf9b --- /dev/null +++ b/fuzz_test.go @@ -0,0 +1,84 @@ +package tdigest + +import ( + "bytes" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func TestFuzzPanicRegressions(t *testing.T) { + // This test contains a list of byte sequences discovered by + // github.com/dvyukov/go-fuzz which, at one time, caused tdigest to panic. The + // test just makes sure that they no longer cause a panic. + testcase := func(crasher []byte) func(*testing.T) { + return func(t *testing.T) { + v := new(TDigest) + err := v.UnmarshalBinary(crasher) + if err != nil { + return + } + remarshaled, err := v.MarshalBinary() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + + if !bytes.HasPrefix(crasher, remarshaled) { + t.Fatalf("not equal: \n%v\nvs\n%v", crasher, remarshaled) + } + + for q := float64(0.1); q <= 1.0; q += 0.05 { + prev, this := v.Quantile(q-0.1), v.Quantile(q) + if prev-this > 1e-100 { // Floating point math makes this slightly imprecise. + t.Logf("v: %s", spew.Sprint(v)) + t.Logf("q: %v", q) + t.Logf("prev: %v", prev) + t.Logf("this: %v", this) + t.Fatal("quantiles should only increase") + } + } + + v.Add(1, 1) + } + } + t.Run("fuzz1", testcase([]byte{ + 0x01, 0x00, 0x00, 0x00, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xfc, + })) + t.Run("fuzz2", testcase([]byte{ + 0x01, 0x00, 0x00, 0x00, 0xdb, 0x46, 0x5f, 0xbd, + 0xdb, 0x46, 0x00, 0xbd, 0xe0, 0xdf, 0xca, 0xab, + 0x37, 0x31, 0x37, 0x32, 0x37, 0x33, 0x37, 0x34, + 0x37, 0x35, 0x37, 0x36, 0x37, 0x37, 0x37, 0x38, + 0x37, 0x39, 0x28, + })) + t.Run("fuzz3", testcase([]byte{ + 0x80, 0x0c, 0x01, 0x00, 0x00, 0x00, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x02, 0x00, + 0x00, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0xbf, + })) + t.Run("fuzz4", testcase([]byte{ + 0x80, 0x0c, 0x01, 0x00, 0x00, 0x00, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x02, 0x00, + 0x00, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x63, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x4e, + })) + t.Run("fuzz5", testcase([]byte{ + 0x80, 0x0c, 0x01, 0x00, 0x00, 0x00, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x02, 0x00, + 0x00, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x92, 0x00, + })) +} diff --git a/serde.go b/serde.go new file mode 100644 index 0000000..677ecb1 --- /dev/null +++ b/serde.go @@ -0,0 +1,123 @@ +package tdigest + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math" +) + +const ( + magic = int16(0xc80) + encodingVersion = int32(1) +) + +func marshalBinary(d *TDigest) ([]byte, error) { + buf := bytes.NewBuffer(nil) + w := &binaryBufferWriter{buf: buf} + w.writeValue(magic) + w.writeValue(encodingVersion) + w.writeValue(d.compression) + w.writeValue(int32(len(d.centroids))) + for _, c := range d.centroids { + w.writeValue(c.count) + w.writeValue(c.mean) + } + + if w.err != nil { + return nil, w.err + } + return buf.Bytes(), nil +} + +func unmarshalBinary(d *TDigest, p []byte) error { + var ( + mv int16 + ev int32 + n int32 + ) + r := &binaryReader{r: bytes.NewReader(p)} + r.readValue(&mv) + if r.err != nil { + return r.err + } + if mv != magic { + return fmt.Errorf("data corruption detected: invalid header magic value 0x%04x", mv) + } + r.readValue(&ev) + if r.err != nil { + return r.err + } + if ev != encodingVersion { + return fmt.Errorf("data corruption detected: invalid encoding version %d", ev) + } + r.readValue(&d.compression) + r.readValue(&n) + if r.err != nil { + return r.err + } + if n < 0 { + return fmt.Errorf("data corruption detected: number of centroids cannot be negative, have %v", n) + + } + if n > 1<<20 { + return fmt.Errorf("invalid n, cannot be greater than 2^20: %v", n) + } + d.centroids = make([]*centroid, int(n)) + for i := 0; i < int(n); i++ { + c := new(centroid) + r.readValue(&c.count) + r.readValue(&c.mean) + if r.err != nil { + return r.err + } + if c.count < 0 { + return fmt.Errorf("data corruption detected: negative count: %d", c.count) + } + if i > 0 { + prev := d.centroids[i-1] + if c.mean < prev.mean { + return fmt.Errorf("data corruption detected: centroid %d has lower mean (%v) than preceding centroid %d (%v)", i, c.mean, i-1, prev.mean) + } + } + d.centroids[i] = c + if c.count > math.MaxInt64-d.countTotal { + return fmt.Errorf("data corruption detected: centroid total size overflow") + } + d.countTotal += c.count + } + + if n := r.r.Len(); n > 0 { + return fmt.Errorf("found %d unexpected bytes trailing the tdigest", n) + } + + return nil +} + +type binaryBufferWriter struct { + buf *bytes.Buffer + err error +} + +func (w *binaryBufferWriter) writeValue(v interface{}) { + if w.err != nil { + return + } + w.err = binary.Write(w.buf, binary.LittleEndian, v) +} + +type binaryReader struct { + r *bytes.Reader + err error +} + +func (r *binaryReader) readValue(v interface{}) { + if r.err != nil { + return + } + r.err = binary.Read(r.r, binary.LittleEndian, v) + if r.err == io.EOF { + r.err = io.ErrUnexpectedEOF + } +} diff --git a/serde_test.go b/serde_test.go new file mode 100644 index 0000000..1d7dcfe --- /dev/null +++ b/serde_test.go @@ -0,0 +1,260 @@ +package tdigest + +import ( + "errors" + "io" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func TestMarshalRoundTrip(t *testing.T) { + testcase := func(in *TDigest) func(*testing.T) { + return func(t *testing.T) { + b, err := in.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary err: %v", err) + } + out := new(TDigest) + err = out.UnmarshalBinary(b) + if err != nil { + t.Fatalf("UnmarshalBinary err: %v", err) + } + if !reflect.DeepEqual(in, out) { + t.Errorf("marshaling round trip resulted in changes") + t.Logf("in: %+v", in) + t.Logf("out: %+v", out) + } + } + } + t.Run("empty", testcase(New())) + t.Run("1 value", testcase(simpleTDigest(1))) + t.Run("1000 values", testcase(simpleTDigest(1000))) +} + +func TestUnmarshalErrors(t *testing.T) { + testcase := func(in []byte, wantErr error) func(*testing.T) { + return func(t *testing.T) { + have := new(TDigest) + err := unmarshalBinary(have, in) + if err != nil { + if wantErr == nil { + t.Fatalf("unexpected unmarshal err: %v", err) + } + if err.Error() != wantErr.Error() { + t.Fatalf("wrong error, want=%q, have=%q", wantErr.Error(), err.Error()) + } else { + return + } + } else if wantErr != nil { + t.Fatalf("expected err=%q, got nil", wantErr.Error()) + } + } + } + t.Run("nil", testcase( + nil, + io.ErrUnexpectedEOF, + )) + t.Run("bad magic", testcase( + []byte{ + 0x80, 0x0d, + }, + errors.New("data corruption detected: invalid header magic value 0x0d80"), + )) + t.Run("incomplete encoding", testcase( + []byte{ + 0x80, 0x0c, + 0x00, + }, + io.ErrUnexpectedEOF, + )) + t.Run("bad encoding", testcase( + []byte{ + 0x80, 0x0c, + 0xFF, 0xFF, 0xFF, 0xFF, + }, + errors.New("data corruption detected: invalid encoding version -1"), + )) + t.Run("incomplete compression", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, + }, + io.ErrUnexpectedEOF, + )) + t.Run("incomplete n", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x00, + }, + io.ErrUnexpectedEOF, + )) + t.Run("negative n", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0xFF, 0xFF, 0xFF, 0xFF, + }, + errors.New("data corruption detected: number of centroids cannot be negative, have -1"), + )) + t.Run("huge n", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0xFF, 0xFF, 0xFF, 0x7F, + }, + errors.New("invalid n, cannot be greater than 2^20: 2147483647"), + )) + t.Run("missing centroids", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x01, 0x00, 0x00, 0x00, + }, + io.ErrUnexpectedEOF, + )) + t.Run("partial centroid", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + }, + io.ErrUnexpectedEOF, + )) + t.Run("negative count", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x01, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F, + }, + errors.New("data corruption detected: negative count: -1"), + )) + t.Run("decreasing means", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F, + }, + errors.New("data corruption detected: centroid 1 has lower mean (1) than preceding centroid 0 (2)"), + )) + t.Run("total size overflow", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x02, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, + }, + errors.New("data corruption detected: centroid total size overflow"), + )) + t.Run("trailing bytes", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, + 0x00, + }, + errors.New("found 1 unexpected bytes trailing the tdigest"), + )) +} + +func TestUnmarshal(t *testing.T) { + testcase := func(in []byte, want *TDigest) func(*testing.T) { + return func(t *testing.T) { + have := new(TDigest) + err := unmarshalBinary(have, in) + if err != nil { + t.Fatalf("unexpected unmarshal err: %v", err) + } + if !reflect.DeepEqual(have, want) { + t.Error("unmarshal did not produce expected digest") + t.Logf("want=%s", spew.Sprint(want)) + t.Logf("have=%s", spew.Sprint(have)) + } + } + } + t.Run("no centroids", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x00, 0x00, 0x00, 0x00, + }, + &TDigest{ + centroids: make([]*centroid, 0), + compression: 100, + countTotal: 0, + }, + )) + t.Run("one centroid", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F, + }, + &TDigest{ + centroids: []*centroid{ + ¢roid{ + count: 1, + mean: 1, + }, + }, + compression: 100, + countTotal: 1, + }, + )) + t.Run("two centroids", testcase( + []byte{ + 0x80, 0x0c, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, + }, + &TDigest{ + centroids: []*centroid{ + ¢roid{ + count: 1, + mean: 1, + }, + ¢roid{ + count: 1, + mean: 2, + }, + }, + compression: 100, + countTotal: 2, + }, + )) +} diff --git a/tdigest.go b/tdigest.go index ffdbd05..a2a8581 100644 --- a/tdigest.go +++ b/tdigest.go @@ -6,30 +6,27 @@ import ( "math/rand" ) -// A TDigest is an efficient data structure for computing streaming -// approximate quantiles of a dataset. It supports two methods: -// -// Add will add a value to the TDigest, updating all quantiles. A -// weight can be specified; use weight of 1 if you don't care about -// weighting your dataset. -// -// Quantile(q) will estimate the qth quantile value of the -// dataset. The input value of q should be in the range [0.0, 1.0]; if -// it is outside that range, it will be clipped into it automatically. -// -// Calling Quantile on a TDigest with no data should return NaN. -// -// MergeInto(other) will add all of the data within a TDigest into -// other, combining them into one larger TDigest. -type TDigest interface { - Add(val float64, weight int) - Quantile(q float64) (val float64) - MergeInto(other TDigest) +// centroid is a simple container for a mean,count pair. +type centroid struct { + mean float64 + count int64 +} + +func (c *centroid) String() string { + return fmt.Sprintf("c{%f x%d}", c.mean, c.count) +} + +// A TDigest is an efficient data structure for computing streaming approximate +// quantiles of a dataset. +type TDigest struct { + centroids []*centroid + compression float64 + countTotal int64 } // New produces a new TDigest using the default compression level of // 100. -func New() TDigest { +func New() *TDigest { return NewWithCompression(100) } @@ -43,30 +40,8 @@ func New() TDigest { // small (think like 1e-6 percentile points) errors at extreme points // in the distribution, and compression ratios of around 500 for large // data sets (1 millionish datapoints). -func NewWithCompression(compression float64) TDigest { - return newCentroidSet(compression) -} - -// centroid is a simple container for a mean,count pair. -type centroid struct { - mean float64 - count int -} - -func (c *centroid) String() string { - return fmt.Sprintf("c{%f x%d}", c.mean, c.count) -} - -type centroidSet struct { - centroids []*centroid - compression float64 - countTotal int - - reclusterAt int -} - -func newCentroidSet(compression float64) *centroidSet { - return ¢roidSet{ +func NewWithCompression(compression float64) *TDigest { + return &TDigest{ centroids: make([]*centroid, 0), compression: compression, countTotal: 0, @@ -77,14 +52,14 @@ func newCentroidSet(compression float64) *centroidSet { // input value. // // TODO: Use a better data structure to avoid this loop. -func (cs *centroidSet) nearest(val float64) []int { +func (d *TDigest) nearest(val float64) []int { var ( nearestDist float64 = math.Inf(+1) thisDist float64 delta float64 result []int = make([]int, 0) ) - for i, c := range cs.centroids { + for i, c := range d.centroids { thisDist = val - c.mean if thisDist < 0 { thisDist *= -1 @@ -101,7 +76,7 @@ func (cs *centroidSet) nearest(val float64) []int { // we have a tie result = append(result, i) default: - // Since cs.centroids is sorted by mean, this means we + // Since d.centroids is sorted by mean, this means we // have passed the best spot, so we may as well break break } @@ -110,27 +85,27 @@ func (cs *centroidSet) nearest(val float64) []int { } // returns the maximum weight that can be placed at specified index -func (cs *centroidSet) weightLimit(idx int) int { - ptile := cs.quantileOf(idx) - limit := int(4 * cs.compression * ptile * (1 - ptile) * float64(len(cs.centroids))) +func (d *TDigest) weightLimit(idx int) int64 { + ptile := d.quantileOf(idx) + limit := int64(4 * d.compression * ptile * (1 - ptile) * float64(len(d.centroids))) return limit } // checks whether the centroid has room for more weight -func (cs *centroidSet) centroidHasRoom(idx int) bool { - return cs.centroids[idx].count < cs.weightLimit(idx) +func (d *TDigest) centroidHasRoom(idx int) bool { + return d.centroids[idx].count < d.weightLimit(idx) } // find which centroid to add the value to (by index) -func (cs *centroidSet) findAddTarget(val float64) int { +func (d *TDigest) findAddTarget(val float64) int { var ( - nearest []int = cs.nearest(val) + nearest []int = d.nearest(val) eligible []int ) for _, c := range nearest { // if there is room for more weight at this centroid... - if cs.centroidHasRoom(c) { + if d.centroidHasRoom(c) { eligible = append(eligible, c) } } @@ -156,7 +131,7 @@ func (cs *centroidSet) findAddTarget(val float64) int { var anyLesser, anyGreater bool for _, c := range eligible { - m := cs.centroids[c].mean + m := d.centroids[c].mean if m < val { anyLesser = true } else if m > val { @@ -190,10 +165,10 @@ func (cs *centroidSet) findAddTarget(val float64) int { return eligible[rand.Intn(len(eligible))] } -func (cs *centroidSet) addNewCentroid(mean float64, weight int) { - var idx int = len(cs.centroids) +func (d *TDigest) addNewCentroid(mean float64, weight int64) { + var idx int = len(d.centroids) - for i, c := range cs.centroids { + for i, c := range d.centroids { // add in sorted order if mean < c.mean { idx = i @@ -201,24 +176,30 @@ func (cs *centroidSet) addNewCentroid(mean float64, weight int) { } } - cs.centroids = append(cs.centroids, nil) - copy(cs.centroids[idx+1:], cs.centroids[idx:]) - cs.centroids[idx] = ¢roid{mean, weight} + d.centroids = append(d.centroids, nil) + copy(d.centroids[idx+1:], d.centroids[idx:]) + d.centroids[idx] = ¢roid{mean, weight} } -// Add a value to the centroidSet. -func (cs *centroidSet) Add(val float64, weight int) { - cs.countTotal += weight - var idx = cs.findAddTarget(val) +// Add will add a value to the TDigest, updating all quantiles. A +// weight can be specified; use weight of 1 if you don't care about +// weighting your dataset. +func (d *TDigest) Add(val float64, weight int) { + d.add(val, int64(weight)) +} + +func (d *TDigest) add(val float64, weight int64) { + d.countTotal += weight + var idx = d.findAddTarget(val) if idx == -1 { - cs.addNewCentroid(val, weight) + d.addNewCentroid(val, weight) return } - c := cs.centroids[idx] + c := d.centroids[idx] - limit := cs.weightLimit(idx) + limit := d.weightLimit(idx) // how much weight will we be adding? // if adding this node to this centroid would put it over the // weight limit, just add the most we can and recur with the remainder @@ -233,33 +214,35 @@ func (cs *centroidSet) Add(val float64, weight int) { c.count += add c.mean = c.mean + float64(add)*(val-c.mean)/float64(c.count) - cs.Add(val, remainder) + d.add(val, remainder) } else { c.count += weight c.mean = c.mean + float64(weight)*(val-c.mean)/float64(c.count) } - } // returns the approximate quantile that a particular centroid // represents -func (cs *centroidSet) quantileOf(idx int) float64 { - total := 0 - for _, c := range cs.centroids[:idx] { +func (d *TDigest) quantileOf(idx int) float64 { + var total int64 + for _, c := range d.centroids[:idx] { total += c.count } - return (float64(cs.centroids[idx].count/2) + float64(total)) / float64(cs.countTotal) + return (float64(d.centroids[idx].count/2) + float64(total)) / float64(d.countTotal) } -// Quantile returns the approximate value at a quantile (eg the 99th -// percentile value would be centroidSet.quantileValue(0.99)) -func (cs *centroidSet) Quantile(q float64) float64 { - var n = len(cs.centroids) +// Quantile(q) will estimate the qth quantile value of the dataset. The input +// value of q should be in the range [0.0, 1.0]; if it is outside that range, it +// will be clipped into it automatically. +// +// Calling Quantile on a TDigest with no data will return NaN. +func (d *TDigest) Quantile(q float64) float64 { + var n = len(d.centroids) if n == 0 { return math.NaN() } if n == 1 { - return cs.centroids[0].mean + return d.centroids[0].mean } if q < 0 { @@ -269,22 +252,22 @@ func (cs *centroidSet) Quantile(q float64) float64 { } // rescale into count units instead of 0 to 1 units - q = float64(cs.countTotal) * q + q = float64(d.countTotal) * q // find the first centroid which straddles q var ( qTotal float64 = 0 i int ) - for i = 0; i < n && float64(cs.centroids[i].count)/2+qTotal < q; i++ { - qTotal += float64(cs.centroids[i].count) + for i = 0; i < n && float64(d.centroids[i].count)/2+qTotal < q; i++ { + qTotal += float64(d.centroids[i].count) } if i == 0 { // special case 1: the targeted quantile is before the // left-most centroid. extrapolate from the slope from // centroid0 to centroid1. - c0 := cs.centroids[0] - c1 := cs.centroids[1] + c0 := d.centroids[0] + c1 := d.centroids[1] slope := (c1.mean - c0.mean) / (float64(c1.count)/2 + float64(c0.count)/2) deltaQ := q - float64(c0.count)/2 // this is negative return c0.mean + slope*deltaQ @@ -293,42 +276,56 @@ func (cs *centroidSet) Quantile(q float64) float64 { // special case 2: the targeted quantile is from the // right-most centroid. extrapolate from the slope at the // right edge. - c0 := cs.centroids[n-2] - c1 := cs.centroids[n-1] + c0 := d.centroids[n-2] + c1 := d.centroids[n-1] slope := (c1.mean - c0.mean) / (float64(c1.count)/2 + float64(c0.count)/2) deltaQ := q - (qTotal - float64(c1.count)/2) return c1.mean + slope*deltaQ } // common case: targeted quantile is between 2 centroids - c0 := cs.centroids[i-1] - c1 := cs.centroids[i] + c0 := d.centroids[i-1] + c1 := d.centroids[i] slope := (c1.mean - c0.mean) / (float64(c1.count)/2 + float64(c0.count)/2) deltaQ := q - (float64(c1.count)/2 + qTotal) return c1.mean + slope*deltaQ } -func (cs *centroidSet) MergeInto(other TDigest) { - // Add each centroid in cs into other. They should be added in +// MergeInto(other) will add all of the data within a TDigest into other, +// combining them into one larger TDigest. +func (d *TDigest) MergeInto(other *TDigest) { + // Add each centroid in d into other. They should be added in // random order. - addOrder := rand.Perm(len(cs.centroids)) + addOrder := rand.Perm(len(d.centroids)) for _, idx := range addOrder { - c := cs.centroids[idx] + c := d.centroids[idx] // gradually write up the volume written so that the tdigest doesnt overload early - added := 0 - for i := 1; i < 10; i++ { + added := int64(0) + for i := int64(1); i < 10; i++ { toAdd := i * 2 if added+i > c.count { toAdd = c.count - added } - other.Add(c.mean, toAdd) + other.add(c.mean, toAdd) added += toAdd if added >= c.count { break } } if added < c.count { - other.Add(c.mean, c.count-added) + other.add(c.mean, c.count-added) } - other.Add(c.mean, c.count) + other.add(c.mean, c.count) } } + +// MarshalBinary serializes d as a sequence of bytes, suitable to be +// deserialized later with UnmarshalBinary. +func (d *TDigest) MarshalBinary() ([]byte, error) { + return marshalBinary(d) +} + +// UnmarshalBinary populates d with the parsed contents of p, which should have +// been created with a call to MarshalBinary. +func (d *TDigest) UnmarshalBinary(p []byte) error { + return unmarshalBinary(d, p) +} diff --git a/tdigest_test.go b/tdigest_test.go index 2015c9d..a867c83 100644 --- a/tdigest_test.go +++ b/tdigest_test.go @@ -26,15 +26,15 @@ func TestFindNearest(t *testing.T) { } for i, tc := range testcases { - cs := centroidSet{centroids: tc.centroids} - have := cs.nearest(tc.val) + d := TDigest{centroids: tc.centroids} + have := d.nearest(tc.val) if len(tc.want) == 0 { if len(have) != 0 { - t.Errorf("centroidSet.nearest wrong test=%d, have=%v, want=%v", i, have, tc.want) + t.Errorf("TDigest.nearest wrong test=%d, have=%v, want=%v", i, have, tc.want) } } else { if !reflect.DeepEqual(tc.want, have) { - t.Errorf("centroidSet.nearest wrong test=%d, have=%v, want=%v", i, have, tc.want) + t.Errorf("TDigest.nearest wrong test=%d, have=%v, want=%v", i, have, tc.want) } } } @@ -42,13 +42,13 @@ func TestFindNearest(t *testing.T) { func BenchmarkFindNearest(b *testing.B) { n := 500 - cset := simpleCentroidSet(n) + d := simpleTDigest(n) b.ResetTimer() var val float64 - for i := 0; i < b.N; i++ { - val = float64(i % cset.countTotal) - _ = cset.nearest(val) + for i := int64(0); i < int64(b.N); i++ { + val = float64(i % d.countTotal) + _ = d.nearest(val) } } @@ -63,10 +63,10 @@ func TestFindAddTarget(t *testing.T) { {[]*centroid{}, 1, -1}, } for i, tc := range testcases { - cs := centroidSet{centroids: tc.centroids, countTotal: len(tc.centroids)} - have := cs.findAddTarget(tc.val) + d := TDigest{centroids: tc.centroids, countTotal: int64(len(tc.centroids))} + have := d.findAddTarget(tc.val) if have != tc.want { - t.Errorf("centroidSet.findAddTarget wrong test=%d, have=%v, want=%v", i, have, tc.want) + t.Errorf("TDigest.findAddTarget wrong test=%d, have=%v, want=%v", i, have, tc.want) } } } @@ -88,21 +88,21 @@ func TestAddNewCentroid(t *testing.T) { } for i, tc := range testcases { - cset := csetFromMeans(tc.centroidVals) - cset.addNewCentroid(tc.add, 1) + d := tdFromMeans(tc.centroidVals) + d.addNewCentroid(tc.add, 1) - have := make([]float64, len(cset.centroids)) - for i, c := range cset.centroids { + have := make([]float64, len(d.centroids)) + for i, c := range d.centroids { have[i] = c.mean } if !reflect.DeepEqual(tc.want, have) { - t.Errorf("centroidSet.addNewCentroid wrong test=%d, have=%v, want=%v", i, have, tc.want) + t.Errorf("TDigest.addNewCentroid wrong test=%d, have=%v, want=%v", i, have, tc.want) } } } -func verifyCentroidOrder(t *testing.T, cs *centroidSet) { +func verifyCentroidOrder(t *testing.T, cs *TDigest) { if len(cs.centroids) < 2 { return } @@ -119,7 +119,7 @@ func TestQuantileOrder(t *testing.T) { // stumbled upon in real world application: adding a 1 to this // resulted in the 6th centroid getting incremented instead of the // 7th. - cset := ¢roidSet{ + d := &TDigest{ countTotal: 14182, compression: 100, centroids: []*centroid{ @@ -152,38 +152,38 @@ func TestQuantileOrder(t *testing.T) { ¢roid{1034640.000000, 1}, }, } - cset.Add(1.0, 1) - verifyCentroidOrder(t, cset) + d.Add(1.0, 1) + verifyCentroidOrder(t, d) } func TestQuantile(t *testing.T) { type testcase struct { - weights []int + weights []int64 idx int want float64 } testcases := []testcase{ - {[]int{1, 1, 1, 1}, 0, 0.0}, - {[]int{1, 1, 1, 1}, 1, 0.25}, - {[]int{1, 1, 1, 1}, 2, 0.5}, - {[]int{1, 1, 1, 1}, 3, 0.75}, - - {[]int{5, 1, 1, 1}, 0, 0.250}, - {[]int{5, 1, 1, 1}, 1, 0.625}, - {[]int{5, 1, 1, 1}, 2, 0.750}, - {[]int{5, 1, 1, 1}, 3, 0.875}, - - {[]int{1, 1, 1, 5}, 0, 0.0}, - {[]int{1, 1, 1, 5}, 1, 0.125}, - {[]int{1, 1, 1, 5}, 2, 0.250}, - {[]int{1, 1, 1, 5}, 3, 0.625}, + {[]int64{1, 1, 1, 1}, 0, 0.0}, + {[]int64{1, 1, 1, 1}, 1, 0.25}, + {[]int64{1, 1, 1, 1}, 2, 0.5}, + {[]int64{1, 1, 1, 1}, 3, 0.75}, + + {[]int64{5, 1, 1, 1}, 0, 0.250}, + {[]int64{5, 1, 1, 1}, 1, 0.625}, + {[]int64{5, 1, 1, 1}, 2, 0.750}, + {[]int64{5, 1, 1, 1}, 3, 0.875}, + + {[]int64{1, 1, 1, 5}, 0, 0.0}, + {[]int64{1, 1, 1, 5}, 1, 0.125}, + {[]int64{1, 1, 1, 5}, 2, 0.250}, + {[]int64{1, 1, 1, 5}, 3, 0.625}, } for i, tc := range testcases { - cset := csetFromWeights(tc.weights) - have := cset.quantileOf(tc.idx) + d := tdFromWeights(tc.weights) + have := d.quantileOf(tc.idx) if have != tc.want { - t.Errorf("centroidSet.quantile wrong test=%d, have=%.3f, want=%.3f", i, have, tc.want) + t.Errorf("TDigest.quantile wrong test=%d, have=%.3f, want=%.3f", i, have, tc.want) } } } @@ -203,19 +203,19 @@ func TestAddValue(t *testing.T) { {4.0, 1, []*centroid{{0, 1}, {1, 1}, {2.5, 2}, {4, 1}}}, } - cset := newCentroidSet(1) + d := NewWithCompression(1) for i, tc := range testcases { - cset.Add(tc.value, tc.weight) - if !reflect.DeepEqual(cset.centroids, tc.want) { - t.Fatalf("centroidSet.addValue unexpected state step=%d, have=%v, want=%v", i, cset.centroids, tc.want) + d.Add(tc.value, tc.weight) + if !reflect.DeepEqual(d.centroids, tc.want) { + t.Fatalf("TDigest.addValue unexpected state step=%d, have=%v, want=%v", i, d.centroids, tc.want) } } } func TestQuantileValue(t *testing.T) { - cset := newCentroidSet(1) - cset.countTotal = 8 - cset.centroids = []*centroid{{0.5, 3}, {1, 1}, {2, 2}, {3, 1}, {8, 1}} + d := NewWithCompression(1) + d.countTotal = 8 + d.centroids = []*centroid{{0.5, 3}, {1, 1}, {2, 2}, {3, 1}, {8, 1}} type testcase struct { q float64 @@ -240,9 +240,9 @@ func TestQuantileValue(t *testing.T) { var epsilon = 1e-8 for i, tc := range testcases { - have := cset.Quantile(tc.q) + have := d.Quantile(tc.q) if math.Abs(have-tc.want) > epsilon { - t.Errorf("centroidSet.Quantile wrong step=%d, have=%v, want=%v", + t.Errorf("TDigest.Quantile wrong step=%d, have=%v, want=%v", i, have, tc.want) } } @@ -250,47 +250,47 @@ func TestQuantileValue(t *testing.T) { func BenchmarkFindAddTarget(b *testing.B) { n := 500 - cset := simpleCentroidSet(n) + d := simpleTDigest(n) b.ResetTimer() var val float64 - for i := 0; i < b.N; i++ { - val = float64(i % cset.countTotal) - _ = cset.findAddTarget(val) + for i := int64(0); i < int64(b.N); i++ { + val = float64(i % d.countTotal) + _ = d.findAddTarget(val) } } // add the values [0,n) to a centroid set, equal weights -func simpleCentroidSet(n int) *centroidSet { - cset := newCentroidSet(1.0) +func simpleTDigest(n int) *TDigest { + d := NewWithCompression(1.0) for i := 0; i < n; i++ { - cset.Add(float64(i), 1) + d.Add(float64(i), 1) } - return cset + return d } -func csetFromMeans(means []float64) *centroidSet { +func tdFromMeans(means []float64) *TDigest { centroids := make([]*centroid, len(means)) for i, m := range means { centroids[i] = ¢roid{m, 1} } - cset := newCentroidSet(1.0) - cset.centroids = centroids - cset.countTotal = len(centroids) - return cset + d := NewWithCompression(1.0) + d.centroids = centroids + d.countTotal = int64(len(centroids)) + return d } -func csetFromWeights(weights []int) *centroidSet { +func tdFromWeights(weights []int64) *TDigest { centroids := make([]*centroid, len(weights)) - countTotal := 0 + countTotal := int64(0) for i, w := range weights { centroids[i] = ¢roid{float64(i), w} countTotal += w } - cset := newCentroidSet(1.0) - cset.centroids = centroids - cset.countTotal = countTotal - return cset + d := NewWithCompression(1.0) + d.centroids = centroids + d.countTotal = countTotal + return d } func ExampleTDigest() {