Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add unit test for /arana/pkg/sequence/... package #765

Merged
merged 1 commit into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .licenserc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ header: # `header` section is configurations for source codes license header.
- 'pkg/resolver/mysql/encoding.go'
- 'pkg/resolver/mysql/sql_error.go'
- 'pkg/resolver/mysql/type.go'
- 'pkg/runtime/mock_runtime.go'
- 'VERSION'
- ".errcheck-exclude"
- ".golangci.yml"
Expand Down
132 changes: 132 additions & 0 deletions pkg/runtime/mock_runtime.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions pkg/runtime/optimize/dml/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err

if flag&_bypass != 0 {
if len(stmt.From) > 0 {
err := rewriteSelectStatement(ctx, stmt, o)
err := expandSelectStar(ctx, stmt, o)
dongzl marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -197,7 +197,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err
}

toSingle := func(db, tbl string) (proto.Plan, error) {
if err := rewriteSelectStatement(ctx, stmt, o); err != nil {
if err := expandSelectStar(ctx, stmt, o); err != nil {
return nil, err
}
ret := &dml.SimpleQueryPlan{
Expand Down Expand Up @@ -243,7 +243,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err
return toSingle(db, tbl)
}

if err = rewriteSelectStatement(ctx, stmt, o); err != nil {
if err = expandSelectStar(ctx, stmt, o); err != nil {
return nil, errors.WithStack(err)
}

Expand Down Expand Up @@ -570,7 +570,7 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
Stmt: selectStmt,
}
if _, ok = selectStmt.Select[0].(*ast.SelectElementAll); ok && len(selectStmt.Select) == 1 {
if err = rewriteSelectStatement(ctx, selectStmt, optimizer); err != nil {
if err = expandSelectStar(ctx, selectStmt, optimizer); err != nil {
return nil, err
}

Expand Down Expand Up @@ -624,7 +624,7 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
scanner = newSelectScanner(stmt, o.Args)
)

if err = rewriteSelectStatement(ctx, stmt, o); err != nil {
if err = expandSelectStar(ctx, stmt, o); err != nil {
return nil, errors.WithStack(err)
}

Expand Down Expand Up @@ -786,7 +786,7 @@ func overwriteLimit(stmt *ast.SelectStatement, args *[]proto.Value) (originOffse
return
}

func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, o *optimize.Optimizer) error {
func expandSelectStar(ctx context.Context, stmt *ast.SelectStatement, o *optimize.Optimizer) error {
// todo db 计算逻辑&tb shard 的计算逻辑
starExpand := false
if len(stmt.Select) == 1 {
Expand Down
16 changes: 16 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* limitations under the License.
*/

//go:generate mockgen -destination=./mock_runtime.go -package=runtime . Runtime
package runtime

import (
Expand Down Expand Up @@ -61,6 +62,8 @@ import (
var (
_ Runtime = (*defaultRuntime)(nil)
_ proto.VConn = (*defaultRuntime)(nil)

_runtimes sync.Map
)

var Tracer = otel.Tracer("Runtime")
Expand All @@ -77,8 +80,20 @@ type Runtime interface {
Begin(ctx context.Context, hooks ...TxHook) (proto.Tx, error)
}

// Register registers a Runtime.
func Register(tenant string, schema string, rt Runtime) error {
if _, loaded := _runtimes.LoadOrStore(tenant+":"+schema, rt); loaded {
return perrors.Errorf("cannot register conflict runtime: tenant=%s, name=%s", tenant, schema)
}
return nil
}

// Load loads a Runtime, here schema means logical database name.
func Load(tenant, schema string) (Runtime, error) {
if rt, ok := _runtimes.Load(tenant + ":" + schema); ok {
return rt.(Runtime), nil
}

var ns *namespace.Namespace
if ns = namespace.Load(tenant, schema); ns == nil {
return nil, perrors.Errorf("no such schema: tenant=%s, schema=%s", tenant, schema)
Expand All @@ -96,6 +111,7 @@ func Unload(tenant, schema string) error {

var _ proto.DB = (*AtomDB)(nil)

// AtomDB represents an atom physical database instance
type AtomDB struct {
mu sync.Mutex

Expand Down
42 changes: 19 additions & 23 deletions pkg/sequence/group/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,12 @@ type groupSequence struct {
tableName string
step int64

nextGroupStartVal int64
nextGroupMaxVal int64
currentGroupMaxVal int64
currentVal int64
}

// Start sequence and do some initialization operations
func (seq *groupSequence) Start(ctx context.Context, option proto.SequenceConfig) error {
func (seq *groupSequence) Start(ctx context.Context, conf proto.SequenceConfig) error {
rt := ctx.Value(proto.RuntimeCtxKey{}).(runtime.Runtime)
ctx = rcontext.WithRead(rcontext.WithDirect(ctx))

Expand All @@ -94,11 +92,11 @@ func (seq *groupSequence) Start(ctx context.Context, option proto.SequenceConfig
}

// init sequence
if err := seq.initStep(option); err != nil {
if err := seq.initStep(conf); err != nil {
return err
}

seq.tableName = option.Name
seq.tableName = conf.Name
return nil
}

Expand Down Expand Up @@ -131,12 +129,12 @@ func (seq *groupSequence) initTable(ctx context.Context, rt runtime.Runtime) err
return nil
}

func (seq *groupSequence) initStep(option proto.SequenceConfig) error {
func (seq *groupSequence) initStep(conf proto.SequenceConfig) error {
seq.mu.Lock()
defer seq.mu.Unlock()

var step int64
stepValue, ok := option.Option[_stepKey]
stepValue, ok := conf.Option[_stepKey]
if ok {
tempStep, err := strconv.Atoi(stepValue)
if err != nil {
Expand Down Expand Up @@ -169,8 +167,6 @@ func (seq *groupSequence) Acquire(ctx context.Context) (int64, error) {
if err != nil {
return 0, err
}
seq.currentVal = seq.nextGroupStartVal
seq.currentGroupMaxVal = seq.nextGroupMaxVal
} else {
seq.currentVal++
}
Expand All @@ -196,13 +192,14 @@ func (seq *groupSequence) acquireNextGroup(ctx context.Context, rt runtime.Runti
if err != nil {
return err
}
val := make([]proto.Value, 1)
vals := make([]proto.Value, 1)
row, err := ds.Next()
if err != nil {
// first time, init the start seq val for the table
if errors.Is(err, io.EOF) {
seq.nextGroupStartVal = _startSequence
seq.nextGroupMaxVal = _startSequence + seq.step - 1
rs, err := tx.Exec(ctx, "", _initGroupSequence, proto.NewValueInt64(seq.nextGroupMaxVal+1), proto.NewValueInt64(seq.step), proto.NewValueString(seq.tableName))
seq.currentVal = _startSequence
seq.currentGroupMaxVal = _startSequence + seq.step - 1
rs, err := tx.Exec(ctx, "", _initGroupSequence, proto.NewValueInt64(_startSequence), proto.NewValueInt64(seq.step), proto.NewValueString(seq.tableName))
if err != nil {
return err
}
Expand All @@ -216,20 +213,19 @@ func (seq *groupSequence) acquireNextGroup(ctx context.Context, rt runtime.Runti
}
return err
}
if err = row.Scan(val); err != nil {

if err = row.Scan(vals); err != nil {
return err
}
_, _ = ds.Next()

if val[0] != nil {
nextGroupStartVal, _ := val[0].Int64()
if nextGroupStartVal%seq.step != 0 {
// padding left
nextGroupStartVal = (nextGroupStartVal/seq.step + 1) * seq.step
}
seq.nextGroupStartVal = nextGroupStartVal
seq.nextGroupMaxVal = seq.nextGroupStartVal + seq.step - 1
rs, err := tx.Exec(ctx, "", _updateNextGroup, proto.NewValueInt64(seq.nextGroupMaxVal+1), proto.NewValueString(seq.tableName))
if vals[0] != nil {
lastGroupStartVal, _ := vals[0].Int64()
// padding left
currentGroupStartVal := (lastGroupStartVal/seq.step+1)*seq.step + 1
seq.currentVal = currentGroupStartVal
seq.currentGroupMaxVal = currentGroupStartVal + seq.step - 1
rs, err := tx.Exec(ctx, "", _updateNextGroup, proto.NewValueInt64(currentGroupStartVal), proto.NewValueString(seq.tableName))
if err != nil {
return err
}
Expand Down
Loading
Loading