Skip to content

Commit

Permalink
test: add unit test for /arana/pkg/sequence/... package
Browse files Browse the repository at this point in the history
  • Loading branch information
Mulavar committed Sep 12, 2023
1 parent c80f220 commit d651a97
Show file tree
Hide file tree
Showing 7 changed files with 310 additions and 47 deletions.
12 changes: 6 additions & 6 deletions pkg/runtime/optimize/dml/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err

if flag&_bypass != 0 {
if len(stmt.From) > 0 {
err := rewriteSelectStatement(ctx, stmt, o)
err := expandSelectStar(ctx, stmt, o)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -197,7 +197,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err
}

toSingle := func(db, tbl string) (proto.Plan, error) {
if err := rewriteSelectStatement(ctx, stmt, o); err != nil {
if err := expandSelectStar(ctx, stmt, o); err != nil {
return nil, err
}
ret := &dml.SimpleQueryPlan{
Expand Down Expand Up @@ -243,7 +243,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err
return toSingle(db, tbl)
}

if err = rewriteSelectStatement(ctx, stmt, o); err != nil {
if err = expandSelectStar(ctx, stmt, o); err != nil {
return nil, errors.WithStack(err)
}

Expand Down Expand Up @@ -570,7 +570,7 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
Stmt: selectStmt,
}
if _, ok = selectStmt.Select[0].(*ast.SelectElementAll); ok && len(selectStmt.Select) == 1 {
if err = rewriteSelectStatement(ctx, selectStmt, optimizer); err != nil {
if err = expandSelectStar(ctx, selectStmt, optimizer); err != nil {
return nil, err
}

Expand Down Expand Up @@ -624,7 +624,7 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
scanner = newSelectScanner(stmt, o.Args)
)

if err = rewriteSelectStatement(ctx, stmt, o); err != nil {
if err = expandSelectStar(ctx, stmt, o); err != nil {
return nil, errors.WithStack(err)
}

Expand Down Expand Up @@ -786,7 +786,7 @@ func overwriteLimit(stmt *ast.SelectStatement, args *[]proto.Value) (originOffse
return
}

func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, o *optimize.Optimizer) error {
func expandSelectStar(ctx context.Context, stmt *ast.SelectStatement, o *optimize.Optimizer) error {
// todo db 计算逻辑&tb shard 的计算逻辑
starExpand := false
if len(stmt.Select) == 1 {
Expand Down
16 changes: 16 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* limitations under the License.
*/

//go:generate mockgen -destination=../../testdata/mock_rt_runtime.go -package=testdata . Runtime
package runtime

import (
Expand Down Expand Up @@ -61,6 +62,8 @@ import (
var (
_ Runtime = (*defaultRuntime)(nil)
_ proto.VConn = (*defaultRuntime)(nil)

_runtimes sync.Map
)

var Tracer = otel.Tracer("Runtime")
Expand All @@ -77,8 +80,20 @@ type Runtime interface {
Begin(ctx context.Context, hooks ...TxHook) (proto.Tx, error)
}

// Register registers a Runtime.
func Register(tenant string, schema string, rt Runtime) error {
if _, loaded := _runtimes.LoadOrStore(tenant+":"+schema, rt); loaded {
return perrors.Errorf("cannot register conflict runtime: tenant=%s, name=%s", tenant, schema)
}
return nil
}

// Load loads a Runtime, here schema means logical database name.
func Load(tenant, schema string) (Runtime, error) {
if rt, ok := _runtimes.Load(tenant + ":" + schema); ok {
return rt.(Runtime), nil
}

var ns *namespace.Namespace
if ns = namespace.Load(tenant, schema); ns == nil {
return nil, perrors.Errorf("no such schema: tenant=%s, schema=%s", tenant, schema)
Expand All @@ -96,6 +111,7 @@ func Unload(tenant, schema string) error {

var _ proto.DB = (*AtomDB)(nil)

//AtomDB represents an atom physical database instance
type AtomDB struct {
mu sync.Mutex

Expand Down
42 changes: 19 additions & 23 deletions pkg/sequence/group/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,12 @@ type groupSequence struct {
tableName string
step int64

nextGroupStartVal int64
nextGroupMaxVal int64
currentGroupMaxVal int64
currentVal int64
}

// Start sequence and do some initialization operations
func (seq *groupSequence) Start(ctx context.Context, option proto.SequenceConfig) error {
func (seq *groupSequence) Start(ctx context.Context, conf proto.SequenceConfig) error {
rt := ctx.Value(proto.RuntimeCtxKey{}).(runtime.Runtime)
ctx = rcontext.WithRead(rcontext.WithDirect(ctx))

Expand All @@ -94,11 +92,11 @@ func (seq *groupSequence) Start(ctx context.Context, option proto.SequenceConfig
}

// init sequence
if err := seq.initStep(option); err != nil {
if err := seq.initStep(conf); err != nil {
return err
}

seq.tableName = option.Name
seq.tableName = conf.Name
return nil
}

Expand Down Expand Up @@ -131,12 +129,12 @@ func (seq *groupSequence) initTable(ctx context.Context, rt runtime.Runtime) err
return nil
}

func (seq *groupSequence) initStep(option proto.SequenceConfig) error {
func (seq *groupSequence) initStep(conf proto.SequenceConfig) error {
seq.mu.Lock()
defer seq.mu.Unlock()

var step int64
stepValue, ok := option.Option[_stepKey]
stepValue, ok := conf.Option[_stepKey]
if ok {
tempStep, err := strconv.Atoi(stepValue)
if err != nil {
Expand Down Expand Up @@ -169,8 +167,6 @@ func (seq *groupSequence) Acquire(ctx context.Context) (int64, error) {
if err != nil {
return 0, err
}
seq.currentVal = seq.nextGroupStartVal
seq.currentGroupMaxVal = seq.nextGroupMaxVal
} else {
seq.currentVal++
}
Expand All @@ -196,13 +192,14 @@ func (seq *groupSequence) acquireNextGroup(ctx context.Context, rt runtime.Runti
if err != nil {
return err
}
val := make([]proto.Value, 1)
vals := make([]proto.Value, 1)
row, err := ds.Next()
if err != nil {
// first time, init the start seq val for the table
if errors.Is(err, io.EOF) {
seq.nextGroupStartVal = _startSequence
seq.nextGroupMaxVal = _startSequence + seq.step - 1
rs, err := tx.Exec(ctx, "", _initGroupSequence, proto.NewValueInt64(seq.nextGroupMaxVal+1), proto.NewValueInt64(seq.step), proto.NewValueString(seq.tableName))
seq.currentVal = _startSequence
seq.currentGroupMaxVal = _startSequence + seq.step - 1
rs, err := tx.Exec(ctx, "", _initGroupSequence, proto.NewValueInt64(_startSequence), proto.NewValueInt64(seq.step), proto.NewValueString(seq.tableName))
if err != nil {
return err
}
Expand All @@ -216,20 +213,19 @@ func (seq *groupSequence) acquireNextGroup(ctx context.Context, rt runtime.Runti
}
return err
}
if err = row.Scan(val); err != nil {

if err = row.Scan(vals); err != nil {
return err
}
_, _ = ds.Next()

if val[0] != nil {
nextGroupStartVal, _ := val[0].Int64()
if nextGroupStartVal%seq.step != 0 {
// padding left
nextGroupStartVal = (nextGroupStartVal/seq.step + 1) * seq.step
}
seq.nextGroupStartVal = nextGroupStartVal
seq.nextGroupMaxVal = seq.nextGroupStartVal + seq.step - 1
rs, err := tx.Exec(ctx, "", _updateNextGroup, proto.NewValueInt64(seq.nextGroupMaxVal+1), proto.NewValueString(seq.tableName))
if vals[0] != nil {
lastGroupStartVal, _ := vals[0].Int64()
// padding left
currentGroupStartVal := (lastGroupStartVal/seq.step+1)*seq.step + 1
seq.currentVal = currentGroupStartVal
seq.currentGroupMaxVal = currentGroupStartVal + seq.step - 1
rs, err := tx.Exec(ctx, "", _updateNextGroup, proto.NewValueInt64(currentGroupStartVal), proto.NewValueString(seq.tableName))
if err != nil {
return err
}
Expand Down
82 changes: 74 additions & 8 deletions pkg/sequence/group/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,95 @@ package group
import (
"context"
"fmt"
"io"
"sync"
"testing"
)

import (
"github.com/golang/mock/gomock"

"github.com/stretchr/testify/assert"
)

import (
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/runtime"
"github.com/arana-db/arana/testdata"
)

const (
tenant = "fakeTenant"
schema = "employees"
tableName = "mock_group_sequence"
)

func Test_groupSequence_Acquire(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx := context.WithValue(context.Background(), proto.ContextKeyTenant{}, tenant)
ctx = context.WithValue(ctx, proto.ContextKeySchema{}, schema)

// build mock row
mockRow := testdata.NewMockRow(ctrl)
mockRow.EXPECT().Scan(gomock.Any()).DoAndReturn(func(dest []proto.Value) error {
dest[0] = proto.NewValueInt64(1)
return nil
})

// build mock dataset
mockDataset := testdata.NewMockDataset(ctrl)
mockDataset.EXPECT().Next().Return(nil, io.EOF).Times(1)
mockDataset.EXPECT().Next().Return(mockRow, nil).Times(2)

// build mock result
mockRes := testdata.NewMockResult(ctrl)
mockRes.EXPECT().RowsAffected().Return(uint64(0), nil).AnyTimes()
mockRes.EXPECT().Dataset().Return(mockDataset, nil).AnyTimes()

// build mock transaction
mockTx := testdata.NewMockTx(ctrl)
mockTx.EXPECT().Exec(gomock.Any(), gomock.Any(), gomock.Eq(_initGroupSequenceTableSql)).Return(mockRes, nil).AnyTimes()
mockTx.EXPECT().Exec(gomock.Any(), gomock.Any(), gomock.Eq(_initGroupSequence), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockRes, nil).AnyTimes()
mockTx.EXPECT().Exec(gomock.Any(), gomock.Any(), gomock.Eq(_updateNextGroup), gomock.Any(), gomock.Any()).Return(mockRes, nil).AnyTimes()
mockTx.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Eq(_selectNextGroupWithXLock), gomock.Any()).Return(mockRes, nil).AnyTimes()
mockTx.EXPECT().Commit(gomock.Any()).Return(mockRes, uint16(0), nil).AnyTimes()
mockTx.EXPECT().Rollback(gomock.Any()).Return(mockRes, uint16(0), nil).AnyTimes()

// build mock runtime
mockRt := testdata.NewMockRuntime(ctrl)
mockRt.EXPECT().Begin(gomock.Any()).Return(mockTx, nil).AnyTimes()

runtime.Register(tenant, schema, mockRt)

ctx = context.WithValue(ctx, proto.RuntimeCtxKey{}, mockRt)

conf := proto.SequenceConfig{
Name: "group",
Type: "group",
Option: map[string]string{_stepKey: "100"},
}

validate(t, ctx, conf, 0, 0, 1)
validate(t, ctx, conf, 100, 100, 101)
}

func validate(t *testing.T, ctx context.Context, conf proto.SequenceConfig, currentVal, currentGroupMaxVal, expectVal int64) {
seq := &groupSequence{
mu: sync.Mutex{},
tableName: "mock_group_sequence",
step: 100,
currentGroupMaxVal: 99,
currentVal: 50,
tableName: tableName,
currentGroupMaxVal: currentGroupMaxVal,
currentVal: currentVal,
}

val, err := seq.Acquire(context.Background())
err := seq.Start(ctx, conf)
assert.NoError(t, err)

val, err := seq.Acquire(ctx)

assert.NoError(t, err, fmt.Sprintf("acquire err : %v", err))

curVal := seq.CurrentVal()

assert.Equal(t, int64(51), curVal, fmt.Sprintf("acquire val: %d, cur val: %d", val, curVal))
assert.Equal(t, val, curVal, fmt.Sprintf("acquire val: %d, cur val: %d", val, curVal))
assert.Equal(t, expectVal, curVal, fmt.Sprintf("acquire val: %d, cur val: %d", val, curVal))
}
8 changes: 4 additions & 4 deletions pkg/sequence/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ func (seq *snowflakeSequence) GetSequenceConfig() proto.SequenceConfig {
}
}

func (seq *snowflakeSequence) findWorkID(ctx context.Context, tx proto.Tx, seqName string) (int64, error) {
ret, err := tx.Query(ctx, "", _selectSelfWorkIdWithXLock, proto.NewValueString(seqName), proto.NewValueString(_nodeId))
func (seq *snowflakeSequence) findWorkID(ctx context.Context, tx proto.Tx, tableName string) (int64, error) {
ret, err := tx.Query(ctx, "", _selectSelfWorkIdWithXLock, proto.NewValueString(tableName), proto.NewValueString(_nodeId))
if err != nil {
return 0, err
}
Expand All @@ -271,7 +271,7 @@ func (seq *snowflakeSequence) findWorkID(ctx context.Context, tx proto.Tx, seqNa
}
}

ret, err = tx.Query(ctx, "", _selectMaxWorkIdWithXLock, proto.NewValueString(seqName))
ret, err = tx.Query(ctx, "", _selectMaxWorkIdWithXLock, proto.NewValueString(tableName))
if err != nil {
return 0, err
}
Expand All @@ -297,7 +297,7 @@ func (seq *snowflakeSequence) findWorkID(ctx context.Context, tx proto.Tx, seqNa
curId, _ := val[0].Int64()
curId++
if curId > workIdMax {
ret, err := tx.Query(ctx, "", _selectFreeWorkIdWithXLock, proto.NewValueString(seqName))
ret, err := tx.Query(ctx, "", _selectFreeWorkIdWithXLock, proto.NewValueString(tableName))
if err != nil {
return 0, err
}
Expand Down
Loading

0 comments on commit d651a97

Please sign in to comment.