diff --git a/go/vt/vtgate/endtoend/main_test.go b/go/vt/vtgate/endtoend/main_test.go index 08aae25420e..b471786b78e 100644 --- a/go/vt/vtgate/endtoend/main_test.go +++ b/go/vt/vtgate/endtoend/main_test.go @@ -153,6 +153,19 @@ var ( Name: "hash", }}, }, + "oltp_test": { + ColumnVindexes: []*vschemapb.ColumnVindex{{ + Column: "id", + Name: "hash", + }}, + Columns: []*vschemapb.Column{{ + Name: "c", + Type: sqltypes.Char, + }, { + Name: "pad", + Type: sqltypes.Char, + }}, + }, }, } diff --git a/go/vt/vtgate/endtoend/oltp_test.go b/go/vt/vtgate/endtoend/oltp_test.go new file mode 100644 index 00000000000..f8ca646f8c7 --- /dev/null +++ b/go/vt/vtgate/endtoend/oltp_test.go @@ -0,0 +1,132 @@ +package endtoend + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "sync" + "testing" + + "vitess.io/vitess/go/mysql" +) + +// 10 groups, 119 characters +const cValueTemplate = "###########-###########-###########-" + + "###########-###########-###########-" + + "###########-###########-###########-" + + "###########" + +// 5 groups, 59 characters +const padValueTemplate = "###########-###########-###########-" + + "###########-###########" + +func sysbenchRandom(rng *rand.Rand, template string) []byte { + out := make([]byte, 0, len(template)) + for i := range template { + switch template[i] { + case '#': + out = append(out, '0'+byte(rng.Intn(10))) + default: + out = append(out, template[i]) + } + } + return out +} + +var oltpInitOnce sync.Once + +func BenchmarkOLTP(b *testing.B) { + const MaxRows = 10000 + const RangeSize = 100 + + rng := rand.New(rand.NewSource(1234)) + + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + + var query bytes.Buffer + + oltpInitOnce.Do(func() { + b.Logf("seeding database for benchmark...") + + var rows int = 1 + for i := 0; i < MaxRows/10; i++ { + query.Reset() + query.WriteString("insert into oltp_test(id, k, c, pad) values ") + for j := 0; j < 10; j++ { + if j > 0 { + query.WriteString(", ") + } + _, _ = fmt.Fprintf(&query, "(%d, %d, '%s', '%s')", rows, rng.Int31n(0xFFFF), sysbenchRandom(rng, cValueTemplate), sysbenchRandom(rng, padValueTemplate)) + rows++ + } + + _, err = conn.ExecuteFetch(query.String(), -1, false) + if err != nil { + b.Fatal(err) + } + } + b.Logf("finshed (inserted %d rows)", rows) + }) + + b.Run("SimpleRanges", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := rng.Intn(MaxRows) + + query.Reset() + _, _ = fmt.Fprintf(&query, "SELECT c FROM oltp_test WHERE id BETWEEN %d AND %d", id, id+rng.Intn(RangeSize)-1) + _, err := conn.ExecuteFetch(query.String(), 1000, false) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("SumRanges", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := rng.Intn(MaxRows) + + query.Reset() + _, _ = fmt.Fprintf(&query, "SELECT SUM(k) FROM oltp_test WHERE id BETWEEN %d AND %d", id, id+rng.Intn(RangeSize)-1) + _, err := conn.ExecuteFetch(query.String(), 1000, false) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("OrderRanges", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := rng.Intn(MaxRows) + + query.Reset() + _, _ = fmt.Fprintf(&query, "SELECT c FROM oltp_test WHERE id BETWEEN %d AND %d ORDER BY c", id, id+rng.Intn(RangeSize)-1) + _, err := conn.ExecuteFetch(query.String(), 1000, false) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("DistinctRanges", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := rng.Intn(MaxRows) + + query.Reset() + _, _ = fmt.Fprintf(&query, "SELECT DISTINCT c FROM oltp_test WHERE id BETWEEN %d AND %d ORDER BY c", id, id+rng.Intn(RangeSize)-1) + _, err := conn.ExecuteFetch(query.String(), 1000, false) + if err != nil { + b.Error(err) + } + } + }) +} diff --git a/go/vt/vtgate/endtoend/schema.sql b/go/vt/vtgate/endtoend/schema.sql index 5fb1f52224f..d543d130c14 100644 --- a/go/vt/vtgate/endtoend/schema.sql +++ b/go/vt/vtgate/endtoend/schema.sql @@ -72,3 +72,11 @@ create table t1_sharded( id2 bigint, primary key(id1) ) Engine=InnoDB; + +create table oltp_test( + id bigint not null auto_increment, + k bigint default 0 not null, + c char(120) default '' not null, + pad char(60) default '' not null, + primary key (id) +) Engine=InnoDB; \ No newline at end of file diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 1982328a8a6..61a3140aa27 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -114,6 +114,35 @@ func (oa *OrderedAggregate) TryExecute(ctx context.Context, vcursor VCursor, bin return qr.Truncate(oa.TruncateColumnCount), nil } +func (oa *OrderedAggregate) executeGroupBy(result *sqltypes.Result) (*sqltypes.Result, error) { + if len(result.Rows) < 1 { + return result, nil + } + + out := &sqltypes.Result{ + Fields: result.Fields, + Rows: result.Rows[:0], + } + + var currentKey []sqltypes.Value + var lastRow sqltypes.Row + var err error + for _, row := range result.Rows { + var nextGroup bool + + currentKey, nextGroup, err = oa.nextGroupBy(currentKey, row) + if err != nil { + return nil, err + } + if nextGroup { + out.Rows = append(out.Rows, lastRow) + } + lastRow = row + } + out.Rows = append(out.Rows, lastRow) + return out, nil +} + func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { result, err := vcursor.ExecutePrimitive( ctx, @@ -124,6 +153,10 @@ func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVa if err != nil { return nil, err } + if len(oa.Aggregates) == 0 { + return oa.executeGroupBy(result) + } + agg, fields, err := newAggregation(result.Fields, oa.Aggregates) if err != nil { return nil, err @@ -160,8 +193,63 @@ func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVa return out, nil } +func (oa *OrderedAggregate) executeStreamGroupBy(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { + cb := func(qr *sqltypes.Result) error { + return callback(qr.Truncate(oa.TruncateColumnCount)) + } + + var fields []*querypb.Field + var currentKey []sqltypes.Value + var lastRow sqltypes.Row + + visitor := func(qr *sqltypes.Result) error { + var err error + if fields == nil && len(qr.Fields) > 0 { + fields = qr.Fields + if err = cb(&sqltypes.Result{Fields: fields}); err != nil { + return err + } + } + for _, row := range qr.Rows { + var nextGroup bool + + currentKey, nextGroup, err = oa.nextGroupBy(currentKey, row) + if err != nil { + return err + } + + if nextGroup { + // this is a new grouping. let's yield the old one, and start a new + if err := cb(&sqltypes.Result{Rows: []sqltypes.Row{lastRow}}); err != nil { + return err + } + } + + lastRow = row + } + return nil + } + + /* we need the input fields types to correctly calculate the output types */ + err := vcursor.StreamExecutePrimitive(ctx, oa.Input, bindVars, true, visitor) + if err != nil { + return err + } + + if lastRow != nil { + if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{lastRow}}); err != nil { + return err + } + } + return nil +} + // TryStreamExecute is a Primitive function. func (oa *OrderedAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool, callback func(*sqltypes.Result) error) error { + if len(oa.Aggregates) == 0 { + return oa.executeStreamGroupBy(ctx, vcursor, bindVars, callback) + } + cb := func(qr *sqltypes.Result) error { return callback(qr.Truncate(oa.TruncateColumnCount)) } diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 312e04f98f2..88c934421fb 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "math/rand" + "slices" "sort" "strconv" "strings" @@ -428,10 +429,10 @@ func (route *Route) sort(in *sqltypes.Result) (*sqltypes.Result, error) { comparers := extractSlices(route.OrderBy) - sort.Slice(out.Rows, func(i, j int) bool { + slices.SortFunc(out.Rows, func(a, b sqltypes.Row) int { var cmp int if err != nil { - return true + return -1 } // If there are any errors below, the function sets // the external err and returns true. Once err is set, @@ -439,16 +440,15 @@ func (route *Route) sort(in *sqltypes.Result) (*sqltypes.Result, error) { // Slice think that all elements are in the correct // order and return more quickly. for _, c := range comparers { - cmp, err = c.compare(out.Rows[i], out.Rows[j]) + cmp, err = c.compare(a, b) if err != nil { - return true + return -1 } - if cmp == 0 { - continue + if cmp != 0 { + return cmp } - return cmp < 0 } - return true + return 0 }) return out.Truncate(route.TruncateColumnCount), err diff --git a/go/vt/vtgate/evalengine/api_compare.go b/go/vt/vtgate/evalengine/api_compare.go index 3c9e632e819..d05e86a12bb 100644 --- a/go/vt/vtgate/evalengine/api_compare.go +++ b/go/vt/vtgate/evalengine/api_compare.go @@ -52,12 +52,37 @@ func (err UnsupportedCollationError) Error() string { var UnsupportedCollationHashError = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "text type with an unknown/unsupported collation cannot be hashed") func compare(v1, v2 sqltypes.Value, collationID collations.ID) (int, error) { + v1t := v1.Type() + // We have a fast path here for the case where both values are // the same type, and it's one of the basic types we can compare // directly. This is a common case for equality checks. - if v1.Type() == v2.Type() { + if v1t == v2.Type() { switch { - case sqltypes.IsSigned(v1.Type()): + case sqltypes.IsText(v1t): + if collationID == collations.CollationBinaryID { + return bytes.Compare(v1.Raw(), v2.Raw()), nil + } + coll := colldata.Lookup(collationID) + if coll == nil { + return 0, UnsupportedCollationError{ID: collationID} + } + result := coll.Collate(v1.Raw(), v2.Raw(), false) + switch { + case result < 0: + return -1, nil + case result > 0: + return 1, nil + default: + return 0, nil + } + case sqltypes.IsBinary(v1t), v1t == sqltypes.Date, v1t == sqltypes.Datetime, v1t == sqltypes.Timestamp: + // We can't optimize for Time here, since Time is not sortable + // based on the raw bytes. This is because of cases like + // '24:00:00' and '101:00:00' which are both valid times and + // order wrong based on the raw bytes. + return bytes.Compare(v1.Raw(), v2.Raw()), nil + case sqltypes.IsSigned(v1t): i1, err := v1.ToInt64() if err != nil { return 0, err @@ -74,7 +99,7 @@ func compare(v1, v2 sqltypes.Value, collationID collations.ID) (int, error) { default: return 0, nil } - case sqltypes.IsUnsigned(v1.Type()): + case sqltypes.IsUnsigned(v1t): u1, err := v1.ToUint64() if err != nil { return 0, err @@ -91,30 +116,6 @@ func compare(v1, v2 sqltypes.Value, collationID collations.ID) (int, error) { default: return 0, nil } - case sqltypes.IsBinary(v1.Type()), v1.Type() == sqltypes.Date, - v1.Type() == sqltypes.Datetime, v1.Type() == sqltypes.Timestamp: - // We can't optimize for Time here, since Time is not sortable - // based on the raw bytes. This is because of cases like - // '24:00:00' and '101:00:00' which are both valid times and - // order wrong based on the raw bytes. - return bytes.Compare(v1.Raw(), v2.Raw()), nil - case sqltypes.IsText(v1.Type()): - if collationID == collations.CollationBinaryID { - return bytes.Compare(v1.Raw(), v2.Raw()), nil - } - coll := colldata.Lookup(collationID) - if coll == nil { - return 0, UnsupportedCollationError{ID: collationID} - } - result := coll.Collate(v1.Raw(), v2.Raw(), false) - switch { - case result < 0: - return -1, nil - case result > 0: - return 1, nil - default: - return 0, nil - } } } diff --git a/go/vt/vttest/local_cluster.go b/go/vt/vttest/local_cluster.go index 86b8079a9c8..9d84cb7fceb 100644 --- a/go/vt/vttest/local_cluster.go +++ b/go/vt/vttest/local_cluster.go @@ -673,7 +673,7 @@ func (db *LocalCluster) applyVschema(keyspace string, migration string) error { func (db *LocalCluster) reloadSchemaKeyspace(keyspace string) error { server := fmt.Sprintf("localhost:%v", db.vt.PortGrpc) args := []string{"ReloadSchemaKeyspace", "--include_primary=true", keyspace} - fmt.Printf("Reloading keyspace schema %v", args) + log.Infof("Reloading keyspace schema %v", args) err := vtctlclient.RunCommandAndWait(context.Background(), server, args, func(e *logutil.Event) { log.Info(e) diff --git a/go/vt/vttest/vtprocess.go b/go/vt/vttest/vtprocess.go index ec26acf41b8..2053973b766 100644 --- a/go/vt/vttest/vtprocess.go +++ b/go/vt/vttest/vtprocess.go @@ -26,6 +26,7 @@ import ( "path" "strings" "syscall" + "testing" "time" "google.golang.org/protobuf/encoding/prototext" @@ -141,8 +142,10 @@ func (vtp *VtProcess) WaitStart() (err error) { vtp.proc.Env = append(vtp.proc.Env, os.Environ()...) vtp.proc.Env = append(vtp.proc.Env, vtp.Env...) - vtp.proc.Stderr = os.Stderr - vtp.proc.Stdout = os.Stdout + if testing.Verbose() { + vtp.proc.Stderr = os.Stderr + vtp.proc.Stdout = os.Stdout + } log.Infof("%v %v", strings.Join(vtp.proc.Args, " ")) err = vtp.proc.Start()