diff --git a/app/app.go b/app/app.go index 2e5dc121..ad446250 100644 --- a/app/app.go +++ b/app/app.go @@ -42,10 +42,16 @@ func (app *App) RunPlan(ctx context.Context, stateStorage storage.Storage, plan return errors.Wrap(err, "couldn't materialize the physical plan into an execution plan") } - stream, err := exec.Get(ctx, variables) + programID := &execution.StreamID{Id: ""} + + tx := stateStorage.BeginTransaction() + stream, err := exec.Get(storage.InjectStateTransaction(ctx, tx), variables, programID) if err != nil { return errors.Wrap(err, "couldn't get record stream from execution plan") } + if err := tx.Commit(); err != nil { + return errors.Wrap(err, "couldn't commit transaction to get record stream from execution plan") + } var rec *execution.Record for { diff --git a/execution/distinct.go b/execution/distinct.go index 41c48549..e07f030c 100644 --- a/execution/distinct.go +++ b/execution/distinct.go @@ -2,6 +2,7 @@ package execution import ( "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" "context" @@ -16,8 +17,14 @@ func NewDistinct(child Node) *Distinct { return &Distinct{child: child} } -func (node *Distinct) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - stream, err := node.child.Get(ctx, variables) +func (node *Distinct) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + tx := storage.GetStateTransactionFromContext(ctx) + sourceStreamID, err := GetSourceStreamID(tx.WithPrefix(streamID.AsPrefix()), octosql.MakePhantom()) + if err != nil { + return nil, errors.Wrap(err, "couldn't get source stream ID") + } + + stream, err := node.child.Get(ctx, variables, sourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get stream for child node in distinct") } diff --git a/execution/engine.go b/execution/engine.go index 78747996..c71cdb1d 100644 --- a/execution/engine.go +++ b/execution/engine.go @@ -61,19 +61,19 @@ type IntermediateRecordStore interface { type PullEngine struct { irs IntermediateRecordStore source RecordStream - sourceStoragePrefix []byte lastCommittedWatermark time.Time watermarkSource WatermarkSource storage storage.Storage + streamID *StreamID } -func NewPullEngine(irs IntermediateRecordStore, storage storage.Storage, source RecordStream, sourceStoragePrefix []byte, watermarkSource WatermarkSource) *PullEngine { +func NewPullEngine(irs IntermediateRecordStore, storage storage.Storage, source RecordStream, streamID *StreamID, watermarkSource WatermarkSource) *PullEngine { return &PullEngine{ - irs: irs, - storage: storage, - source: source, - sourceStoragePrefix: sourceStoragePrefix, - watermarkSource: watermarkSource, + irs: irs, + storage: storage, + source: source, + streamID: streamID, + watermarkSource: watermarkSource, } } @@ -127,20 +127,23 @@ func (engine *PullEngine) Run(ctx context.Context) { } else if err != nil { tx.Abort() log.Println(err) - return // TODO: Error propagation? Add this to the underlying queue as an ErrorElement? How to do this well? + return // TODO: Error propagation? Add this to the underlying queue as an ErrorElement? How to do this well? Send it to the underlying IRS like a watermark? } } } func (engine *PullEngine) loop(ctx context.Context, tx storage.StateTransaction) error { - sourcePrefixedTx := tx.WithPrefix(engine.sourceStoragePrefix) + // This is a transaction prefixed with the current node StreamID, + // which should be used for all storage operations of this node. 
+ // Source streams will get the raw, non-prefixed, transaction. + prefixedTx := tx.WithPrefix(engine.streamID.AsPrefix()) - watermark, err := engine.watermarkSource.GetWatermark(ctx, sourcePrefixedTx) + watermark, err := engine.watermarkSource.GetWatermark(ctx, tx) if err != nil { return errors.Wrap(err, "couldn't get current watermark from source") } if watermark.After(engine.lastCommittedWatermark) { - err := engine.irs.UpdateWatermark(ctx, tx, watermark) + err := engine.irs.UpdateWatermark(ctx, prefixedTx, watermark) if err != nil { return errors.Wrap(err, "couldn't update watermark in intermediate record store") } @@ -148,14 +151,14 @@ func (engine *PullEngine) loop(ctx context.Context, tx storage.StateTransaction) return nil } - record, err := engine.source.Next(storage.InjectStateTransaction(ctx, sourcePrefixedTx)) + record, err := engine.source.Next(storage.InjectStateTransaction(ctx, tx)) if err != nil { if err == ErrEndOfStream { - err := engine.irs.UpdateWatermark(ctx, tx, maxWatermark) + err := engine.irs.UpdateWatermark(ctx, prefixedTx, maxWatermark) if err != nil { return errors.Wrap(err, "couldn't mark end of stream max watermark in intermediate record store") } - err = engine.irs.MarkEndOfStream(ctx, tx) + err = engine.irs.MarkEndOfStream(ctx, prefixedTx) if err != nil { return errors.Wrap(err, "couldn't mark end of stream in intermediate record store") } @@ -163,7 +166,7 @@ func (engine *PullEngine) loop(ctx context.Context, tx storage.StateTransaction) } return errors.Wrap(err, "couldn't get next record") } - err = engine.irs.AddRecord(ctx, tx, 0, record) + err = engine.irs.AddRecord(ctx, prefixedTx, 0, record) if err != nil { return errors.Wrap(err, "couldn't add record to intermediate record store") } @@ -173,7 +176,9 @@ func (engine *PullEngine) loop(ctx context.Context, tx storage.StateTransaction) func (engine *PullEngine) Next(ctx context.Context) (*Record, error) { tx := storage.GetStateTransactionFromContext(ctx) - rec, err := engine.irs.Next(ctx, tx) + prefixedTx := tx.WithPrefix(engine.streamID.AsPrefix()) + + rec, err := engine.irs.Next(ctx, prefixedTx) if err != nil { if err == ErrEndOfStream { return nil, ErrEndOfStream @@ -184,7 +189,9 @@ func (engine *PullEngine) Next(ctx context.Context) (*Record, error) { } func (engine *PullEngine) GetWatermark(ctx context.Context, tx storage.StateTransaction) (time.Time, error) { - return engine.irs.GetWatermark(ctx, tx) + prefixedTx := tx.WithPrefix(engine.streamID.AsPrefix()) + + return engine.irs.GetWatermark(ctx, prefixedTx) } func (engine *PullEngine) Close() error { diff --git a/execution/execution.go b/execution/execution.go index a7cfc564..f2b34568 100644 --- a/execution/execution.go +++ b/execution/execution.go @@ -9,7 +9,7 @@ import ( ) type Node interface { - Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) + Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) } type Expression interface { @@ -71,7 +71,7 @@ func NewNodeExpression(node Node) *NodeExpression { } func (ne *NodeExpression) ExpressionValue(ctx context.Context, variables octosql.Variables) (octosql.Value, error) { - records, err := ne.node.Get(ctx, variables) + records, err := ne.node.Get(ctx, variables, GetRawStreamID()) // TODO: Think about this. 
if err != nil { return octosql.ZeroValue(), errors.Wrap(err, "couldn't get record stream") } diff --git a/execution/execution.pb.go b/execution/execution.pb.go new file mode 100644 index 00000000..20d1f09d --- /dev/null +++ b/execution/execution.pb.go @@ -0,0 +1,78 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: execution/execution.proto + +package execution + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type StreamID struct { + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *StreamID) Reset() { *m = StreamID{} } +func (m *StreamID) String() string { return proto.CompactTextString(m) } +func (*StreamID) ProtoMessage() {} +func (*StreamID) Descriptor() ([]byte, []int) { + return fileDescriptor_0a4329d6cc9a89db, []int{0} +} + +func (m *StreamID) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_StreamID.Unmarshal(m, b) +} +func (m *StreamID) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_StreamID.Marshal(b, m, deterministic) +} +func (m *StreamID) XXX_Merge(src proto.Message) { + xxx_messageInfo_StreamID.Merge(m, src) +} +func (m *StreamID) XXX_Size() int { + return xxx_messageInfo_StreamID.Size(m) +} +func (m *StreamID) XXX_DiscardUnknown() { + xxx_messageInfo_StreamID.DiscardUnknown(m) +} + +var xxx_messageInfo_StreamID proto.InternalMessageInfo + +func (m *StreamID) GetId() string { + if m != nil { + return m.Id + } + return "" +} + +func init() { + proto.RegisterType((*StreamID)(nil), "execution.StreamID") +} + +func init() { proto.RegisterFile("execution/execution.proto", fileDescriptor_0a4329d6cc9a89db) } + +var fileDescriptor_0a4329d6cc9a89db = []byte{ + // 114 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4c, 0xad, 0x48, 0x4d, + 0x2e, 0x2d, 0xc9, 0xcc, 0xcf, 0xd3, 0x87, 0xb3, 0xf4, 0x0a, 0x8a, 0xf2, 0x4b, 0xf2, 0x85, 0x38, + 0xe1, 0x02, 0x4a, 0x52, 0x5c, 0x1c, 0xc1, 0x25, 0x45, 0xa9, 0x89, 0xb9, 0x9e, 0x2e, 0x42, 0x7c, + 0x5c, 0x4c, 0x99, 0x29, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x4c, 0x99, 0x29, 0x4e, 0xea, + 0x51, 0xaa, 0xe9, 0x99, 0x25, 0x19, 0xa5, 0x49, 0x7a, 0xc9, 0xf9, 0xb9, 0xfa, 0xc9, 0xa5, 0x49, + 0xa9, 0x46, 0x46, 0x46, 0x46, 0xfa, 0xf9, 0xc9, 0x25, 0xf9, 0xc5, 0x85, 0x39, 0x08, 0x53, 0x93, + 0xd8, 0xc0, 0xc6, 0x1a, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, 0x8a, 0x05, 0x76, 0x31, 0x73, 0x00, + 0x00, 0x00, +} diff --git a/execution/execution.proto b/execution/execution.proto new file mode 100644 index 00000000..a017fccf --- /dev/null +++ b/execution/execution.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; +package execution; +option go_package = "github.com/cube2222/octosql/execution"; + +// StreamID is a unique identifier for a RecordStream node. +// This StreamID should prefix all state storage keys this node uses. 
+message StreamID { + string id = 1; +} diff --git a/execution/filter.go b/execution/filter.go index ec0ef53d..28cf5331 100644 --- a/execution/filter.go +++ b/execution/filter.go @@ -1,10 +1,11 @@ package execution import ( - "github.com/cube2222/octosql" - "context" + "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" + "github.com/pkg/errors" ) @@ -17,8 +18,14 @@ func NewFilter(formula Formula, child Node) *Filter { return &Filter{formula: formula, source: child} } -func (node *Filter) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - recordStream, err := node.source.Get(ctx, variables) +func (node *Filter) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + tx := storage.GetStateTransactionFromContext(ctx) + sourceStreamID, err := GetSourceStreamID(tx.WithPrefix(streamID.AsPrefix()), octosql.MakePhantom()) + if err != nil { + return nil, errors.Wrap(err, "couldn't get source stream ID") + } + + recordStream, err := node.source.Get(ctx, variables, sourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get record stream") } diff --git a/execution/group_by.go b/execution/group_by.go index 1458fdd0..e081ec2b 100644 --- a/execution/group_by.go +++ b/execution/group_by.go @@ -22,10 +22,9 @@ type Aggregate interface { type TriggerPrototype func(ctx context.Context, variables octosql.Variables) (Trigger, error) type GroupBy struct { - storage storage.Storage - source Node - sourceStoragePrefix []byte - key []Expression + storage storage.Storage + source Node + key []Expression fields []octosql.VariableName aggregatePrototypes []AggregatePrototype @@ -37,12 +36,18 @@ type GroupBy struct { triggerPrototype TriggerPrototype } -func NewGroupBy(storage storage.Storage, source Node, sourceStoragePrefix []byte, key []Expression, fields []octosql.VariableName, aggregatePrototypes []AggregatePrototype, eventTimeField octosql.VariableName, as []octosql.VariableName, outEventTimeField octosql.VariableName, triggerPrototype TriggerPrototype) *GroupBy { - return &GroupBy{storage: storage, source: source, sourceStoragePrefix: sourceStoragePrefix, key: key, fields: fields, aggregatePrototypes: aggregatePrototypes, eventTimeField: eventTimeField, as: as, outEventTimeField: outEventTimeField, triggerPrototype: triggerPrototype} +func NewGroupBy(storage storage.Storage, source Node, key []Expression, fields []octosql.VariableName, aggregatePrototypes []AggregatePrototype, eventTimeField octosql.VariableName, as []octosql.VariableName, outEventTimeField octosql.VariableName, triggerPrototype TriggerPrototype) *GroupBy { + return &GroupBy{storage: storage, source: source, key: key, fields: fields, aggregatePrototypes: aggregatePrototypes, eventTimeField: eventTimeField, as: as, outEventTimeField: outEventTimeField, triggerPrototype: triggerPrototype} } -func (node *GroupBy) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - source, err := node.source.Get(ctx, variables) +func (node *GroupBy) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + tx := storage.GetStateTransactionFromContext(ctx) + sourceStreamID, err := GetSourceStreamID(tx.WithPrefix(streamID.AsPrefix()), octosql.MakePhantom()) + if err != nil { + return nil, errors.Wrap(err, "couldn't get source stream ID") + } + + source, err := node.source.Get(ctx, variables, sourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get stream for source 
in group by") } @@ -85,14 +90,13 @@ func (node *GroupBy) Get(ctx context.Context, variables octosql.Variables) (Reco outputFieldNames: outputFieldNames, } processFunc := &ProcessByKey{ - stateStorage: node.storage, eventTimeField: node.eventTimeField, trigger: trigger, keyExpression: node.key, processFunction: groupBy, variables: variables, } - groupByPullEngine := NewPullEngine(processFunc, node.storage, source, node.sourceStoragePrefix, &ZeroWatermarkSource{}) + groupByPullEngine := NewPullEngine(processFunc, node.storage, source, streamID, &ZeroWatermarkSource{}) go groupByPullEngine.Run(ctx) // TODO: .Close() should kill this context and the goroutine. return groupByPullEngine, nil diff --git a/execution/group_by_test.go b/execution/group_by_test.go index 3c2d937b..f3b9db8c 100644 --- a/execution/group_by_test.go +++ b/execution/group_by_test.go @@ -9,6 +9,7 @@ import ( . "github.com/cube2222/octosql/execution" "github.com/cube2222/octosql/execution/trigger" "github.com/cube2222/octosql/streaming/aggregate" + "github.com/cube2222/octosql/streaming/storage" ) func TestGroupBy_SimpleBatch(t *testing.T) { @@ -27,7 +28,6 @@ func TestGroupBy_SimpleBatch(t *testing.T) { gb := NewGroupBy( stateStorage, source, - []byte{}, []Expression{NewVariable(octosql.NewVariableName("ownerid"))}, []octosql.VariableName{ octosql.NewVariableName("ownerid"), @@ -51,10 +51,14 @@ func TestGroupBy_SimpleBatch(t *testing.T) { }, ) - stream, err := gb.Get(ctx, octosql.NoVariables()) + tx := stateStorage.BeginTransaction() + stream, err := gb.Get(storage.InjectStateTransaction(context.Background(), tx), octosql.NoVariables(), GetRawStreamID()) if err != nil { t.Fatal(err) } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } outFields := []octosql.VariableName{"ownerid", "livesleft_avg", "livesleft_count"} expectedOutput := []*Record{ @@ -95,7 +99,6 @@ func TestGroupBy_BatchWithUndos(t *testing.T) { gb := NewGroupBy( stateStorage, source, - []byte{}, []Expression{NewVariable(octosql.NewVariableName("ownerid"))}, []octosql.VariableName{ octosql.NewVariableName("ownerid"), @@ -119,10 +122,14 @@ func TestGroupBy_BatchWithUndos(t *testing.T) { }, ) - stream, err := gb.Get(ctx, octosql.NoVariables()) + tx := stateStorage.BeginTransaction() + stream, err := gb.Get(storage.InjectStateTransaction(context.Background(), tx), octosql.NoVariables(), GetRawStreamID()) if err != nil { t.Fatal(err) } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } outFields := []octosql.VariableName{"ownerid", "livesleft_avg", "livesleft_count"} expectedOutput := []*Record{ @@ -163,7 +170,6 @@ func TestGroupBy_WithOutputUndos(t *testing.T) { gb := NewGroupBy( stateStorage, source, - []byte{}, []Expression{NewVariable(octosql.NewVariableName("ownerid"))}, []octosql.VariableName{ octosql.NewVariableName("ownerid"), @@ -187,10 +193,14 @@ func TestGroupBy_WithOutputUndos(t *testing.T) { }, ) - stream, err := gb.Get(ctx, octosql.NoVariables()) + tx := stateStorage.BeginTransaction() + stream, err := gb.Get(storage.InjectStateTransaction(context.Background(), tx), octosql.NoVariables(), GetRawStreamID()) if err != nil { t.Fatal(err) } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } outFields := []octosql.VariableName{"ownerid", "livesleft_avg", "livesleft_count"} expectedOutput := []*Record{ @@ -233,7 +243,6 @@ func TestGroupBy_newRecordsNoChanges(t *testing.T) { gb := NewGroupBy( stateStorage, source, - []byte{}, []Expression{NewVariable(octosql.NewVariableName("ownerid"))}, []octosql.VariableName{ 
octosql.NewVariableName("ownerid"), @@ -254,10 +263,14 @@ func TestGroupBy_newRecordsNoChanges(t *testing.T) { }, ) - stream, err := gb.Get(ctx, octosql.NoVariables()) + tx := stateStorage.BeginTransaction() + stream, err := gb.Get(storage.InjectStateTransaction(context.Background(), tx), octosql.NoVariables(), GetRawStreamID()) if err != nil { t.Fatal(err) } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } outFields := []octosql.VariableName{"ownerid", "livesleft_avg"} expectedOutput := []*Record{ @@ -300,7 +313,6 @@ func TestGroupBy_EventTimes(t *testing.T) { gb := NewGroupBy( stateStorage, source, - []byte{}, []Expression{ NewVariable(octosql.NewVariableName("ownerid")), NewVariable(octosql.NewVariableName("t")), @@ -330,10 +342,14 @@ func TestGroupBy_EventTimes(t *testing.T) { }, ) - stream, err := gb.Get(ctx, octosql.NoVariables()) + tx := stateStorage.BeginTransaction() + stream, err := gb.Get(storage.InjectStateTransaction(context.Background(), tx), octosql.NoVariables(), GetRawStreamID()) if err != nil { t.Fatal(err) } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } outFields := []octosql.VariableName{"renamed_t", "ownerid", "livesleft_avg", "livesleft_count"} expectedOutput := []*Record{ diff --git a/execution/innerjoin.go b/execution/innerjoin.go index 73ebe001..1f8a412c 100644 --- a/execution/innerjoin.go +++ b/execution/innerjoin.go @@ -19,8 +19,8 @@ func NewInnerJoin(prefetchCount int, source Node, joined Node) *InnerJoin { return &InnerJoin{prefetchCount: prefetchCount, source: source, joined: joined} } -func (node *InnerJoin) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - recordStream, err := node.source.Get(ctx, variables) +func (node *InnerJoin) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + recordStream, err := node.source.Get(ctx, variables, streamID) if err != nil { return nil, errors.Wrap(err, "couldn't get record stream") } diff --git a/execution/innerjoin_test.go b/execution/innerjoin_test.go index 012374e3..4f50e7d9 100644 --- a/execution/innerjoin_test.go +++ b/execution/innerjoin_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" ) func TestInnerJoinedStream_Next(t *testing.T) { @@ -123,7 +124,11 @@ func TestInnerJoinedStream_Next(t *testing.T) { tt.fields.joined, ), } - equal, err := AreStreamsEqual(context.Background(), stream, tt.want) + stateStorage := GetTestStorage(t) + tx := stateStorage.BeginTransaction() + ctx := storage.InjectStateTransaction(context.Background(), tx) + + equal, err := AreStreamsEqual(ctx, stream, tt.want) if (err != nil) != tt.wantErr { t.Errorf("InnerJoinedStream.Next() error = %v, wantErr %v", err, tt.wantErr) return @@ -131,6 +136,8 @@ func TestInnerJoinedStream_Next(t *testing.T) { if err == nil && !equal { t.Errorf("InnerJoinedStream.Next() streams not equal") } + + tx.Commit() }) } } diff --git a/execution/joiner.go b/execution/joiner.go index 6900f5b8..506be751 100644 --- a/execution/joiner.go +++ b/execution/joiner.go @@ -59,7 +59,7 @@ func (joiner *Joiner) fillPending(ctx context.Context) error { joiner.elements++ go func() { - joinedStream, err := joiner.joined.Get(ctx, variables) + joinedStream, err := joiner.joined.Get(ctx, variables, GetRawStreamID()) // TODO: This needs to be changed in the new joiner implementation. 
if err != nil { joiner.errors <- errors.Wrap(err, "couldn't get joined stream") } diff --git a/execution/leftjoin.go b/execution/leftjoin.go index 56bf5e7d..0f438815 100644 --- a/execution/leftjoin.go +++ b/execution/leftjoin.go @@ -19,8 +19,8 @@ func NewLeftJoin(prefetchCount int, source Node, joined Node) *LeftJoin { return &LeftJoin{prefetchCount: prefetchCount, source: source, joined: joined} } -func (node *LeftJoin) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - recordStream, err := node.source.Get(ctx, variables) +func (node *LeftJoin) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + recordStream, err := node.source.Get(ctx, variables, streamID) if err != nil { return nil, errors.Wrap(err, "couldn't get record stream") } diff --git a/execution/leftjoin_test.go b/execution/leftjoin_test.go index 5bfb0056..f68b470c 100644 --- a/execution/leftjoin_test.go +++ b/execution/leftjoin_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" ) func TestLeftJoinedStream_Next(t *testing.T) { @@ -127,7 +128,11 @@ func TestLeftJoinedStream_Next(t *testing.T) { tt.fields.joined, ), } - equal, err := AreStreamsEqual(context.Background(), stream, tt.want) + stateStorage := GetTestStorage(t) + tx := stateStorage.BeginTransaction() + ctx := storage.InjectStateTransaction(context.Background(), tx) + + equal, err := AreStreamsEqual(ctx, stream, tt.want) if (err != nil) != tt.wantErr { t.Errorf("LeftJoinedStream.Next() error = %v, wantErr %v", err, tt.wantErr) return @@ -135,6 +140,8 @@ func TestLeftJoinedStream_Next(t *testing.T) { if err == nil && !equal { t.Errorf("LeftJoinedStream.Next() streams not equal") } + + tx.Commit() }) } } diff --git a/execution/limit.go b/execution/limit.go index 4f0fd501..38655a54 100644 --- a/execution/limit.go +++ b/execution/limit.go @@ -2,6 +2,7 @@ package execution import ( "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" "context" @@ -17,8 +18,14 @@ func NewLimit(data Node, limit Expression) *Limit { return &Limit{data: data, limitExpr: limit} } -func (node *Limit) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - dataStream, err := node.data.Get(ctx, variables) +func (node *Limit) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + tx := storage.GetStateTransactionFromContext(ctx) + sourceStreamID, err := GetSourceStreamID(tx.WithPrefix(streamID.AsPrefix()), octosql.MakePhantom()) + if err != nil { + return nil, errors.Wrap(err, "couldn't get source stream ID") + } + + dataStream, err := node.data.Get(ctx, variables, sourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get data RecordStream") } diff --git a/execution/limit_test.go b/execution/limit_test.go index e8afd930..2f1e1037 100644 --- a/execution/limit_test.go +++ b/execution/limit_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" ) func TestLimit_Get(t *testing.T) { @@ -113,7 +114,11 @@ func TestLimit_Get(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rs, err := tt.node.Get(ctx, tt.vars) + stateStorage := GetTestStorage(t) + tx := stateStorage.BeginTransaction() + ctx := storage.InjectStateTransaction(ctx, tx) + + rs, err := tt.node.Get(ctx, tt.vars, GetRawStreamID()) if (err == nil) != (tt.wantError == NO_ERROR) { 
t.Errorf("exactly one of test.wantError, tt.node.Get() is not nil") @@ -134,6 +139,8 @@ func TestLimit_Get(t *testing.T) { if err != nil { t.Errorf("limitedStream comparison error: %v", err) } + + tx.Commit() }) } } diff --git a/execution/map.go b/execution/map.go index 2c20b422..d5e862c2 100644 --- a/execution/map.go +++ b/execution/map.go @@ -2,6 +2,7 @@ package execution import ( "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" "context" @@ -18,8 +19,14 @@ func NewMap(expressions []NamedExpression, child Node, keep bool) *Map { return &Map{expressions: expressions, source: child, keep: keep} } -func (node *Map) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - recordStream, err := node.source.Get(ctx, variables) +func (node *Map) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + tx := storage.GetStateTransactionFromContext(ctx) + sourceStreamID, err := GetSourceStreamID(tx.WithPrefix(streamID.AsPrefix()), octosql.MakePhantom()) + if err != nil { + return nil, errors.Wrap(err, "couldn't get source stream ID") + } + + recordStream, err := node.source.Get(ctx, variables, sourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get record stream") } diff --git a/execution/offset.go b/execution/offset.go index 9159ca29..68d8dbdb 100644 --- a/execution/offset.go +++ b/execution/offset.go @@ -2,6 +2,7 @@ package execution import ( "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" "context" @@ -17,8 +18,14 @@ func NewOffset(data Node, offsetExpr Expression) *Offset { return &Offset{data: data, offsetExpr: offsetExpr} } -func (node *Offset) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - dataStream, err := node.data.Get(ctx, variables) +func (node *Offset) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + tx := storage.GetStateTransactionFromContext(ctx) + sourceStreamID, err := GetSourceStreamID(tx.WithPrefix(streamID.AsPrefix()), octosql.MakePhantom()) + if err != nil { + return nil, errors.Wrap(err, "couldn't get source stream ID") + } + + dataStream, err := node.data.Get(ctx, variables, sourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get data record stream") } diff --git a/execution/offset_test.go b/execution/offset_test.go index 334f55ed..4d60b6c7 100644 --- a/execution/offset_test.go +++ b/execution/offset_test.go @@ -4,6 +4,8 @@ import ( "context" "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" + "testing" ) @@ -113,7 +115,11 @@ func TestOffset_Get(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rs, err := tt.node.Get(ctx, tt.vars) + stateStorage := GetTestStorage(t) + tx := stateStorage.BeginTransaction() + ctx := storage.InjectStateTransaction(ctx, tx) + + rs, err := tt.node.Get(ctx, tt.vars, GetRawStreamID()) if (err == nil) != (tt.wantError == NO_ERROR) { t.Errorf("exactly one of test.wantError, tt.node.Get() is not nil") @@ -134,6 +140,8 @@ func TestOffset_Get(t *testing.T) { if err != nil { t.Errorf("limitedStream comparison error: %v", err) } + + tx.Commit() }) } } diff --git a/execution/order_by.go b/execution/order_by.go index 99f88c98..52efa755 100644 --- a/execution/order_by.go +++ b/execution/order_by.go @@ -6,6 +6,8 @@ import ( "sort" "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" + "github.com/pkg/errors" ) @@ -49,8 +51,14 
@@ func isSorteable(x octosql.Value) bool { panic("unreachable") } -func (ob *OrderBy) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - sourceStream, err := ob.source.Get(ctx, variables) +func (ob *OrderBy) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + tx := storage.GetStateTransactionFromContext(ctx) + sourceStreamID, err := GetSourceStreamID(tx.WithPrefix(streamID.AsPrefix()), octosql.MakePhantom()) + if err != nil { + return nil, errors.Wrap(err, "couldn't get source stream ID") + } + + sourceStream, err := ob.source.Get(ctx, variables, sourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get underlying stream in order by") } diff --git a/execution/process.go b/execution/process.go index db86a138..090867a8 100644 --- a/execution/process.go +++ b/execution/process.go @@ -27,8 +27,6 @@ type Trigger interface { } type ProcessByKey struct { - stateStorage storage.Storage - trigger Trigger eventTimeField octosql.VariableName // Empty if not grouping by event time. keyExpression []Expression @@ -75,7 +73,7 @@ func (p *ProcessByKey) AddRecord(ctx context.Context, tx storage.StateTransactio } func (p *ProcessByKey) triggerKeys(ctx context.Context, tx storage.StateTransaction) error { - outputQueue := NewOutputQueue(p.stateStorage.WithPrefix(outputQueuePrefix), tx.WithPrefix(outputQueuePrefix)) + outputQueue := NewOutputQueue(tx.WithPrefix(outputQueuePrefix)) for key, err := p.trigger.PollKeyToFire(ctx, tx); err != ErrNoKeyToFire; key, err = p.trigger.PollKeyToFire(ctx, tx) { if err != nil { @@ -108,7 +106,7 @@ var endOfStreamPrefix = []byte("$end_of_stream$") var outputQueuePrefix = []byte("$output_queue$") func (p *ProcessByKey) Next(ctx context.Context, tx storage.StateTransaction) (*Record, error) { - outputQueue := NewOutputQueue(p.stateStorage.WithPrefix(outputQueuePrefix), tx.WithPrefix(outputQueuePrefix)) + outputQueue := NewOutputQueue(tx.WithPrefix(outputQueuePrefix)) endOfStreamState := storage.NewValueState(tx.WithPrefix(endOfStreamPrefix)) var eos octosql.Value @@ -152,7 +150,7 @@ func (p *ProcessByKey) Next(ctx context.Context, tx storage.StateTransaction) (* } func (p *ProcessByKey) UpdateWatermark(ctx context.Context, tx storage.StateTransaction, watermark time.Time) error { - outputQueue := NewOutputQueue(p.stateStorage.WithPrefix(outputQueuePrefix), tx.WithPrefix(outputQueuePrefix)) + outputQueue := NewOutputQueue(tx.WithPrefix(outputQueuePrefix)) err := p.trigger.UpdateWatermark(ctx, tx, watermark) if err != nil { @@ -191,7 +189,7 @@ func (p *ProcessByKey) GetWatermark(ctx context.Context, tx storage.StateTransac } func (p *ProcessByKey) MarkEndOfStream(ctx context.Context, tx storage.StateTransaction) error { - outputQueue := NewOutputQueue(p.stateStorage.WithPrefix(outputQueuePrefix), tx.WithPrefix(outputQueuePrefix)) + outputQueue := NewOutputQueue(tx.WithPrefix(outputQueuePrefix)) err := outputQueue.Push(ctx, &QueueElement{Type: &QueueElement_EndOfStream{EndOfStream: true}}) if err != nil { return errors.Wrap(err, "couldn't push item to output queue") @@ -204,14 +202,12 @@ func (p *ProcessByKey) Close() error { } type OutputQueue struct { - stateStorage storage.Storage - tx storage.StateTransaction + tx storage.StateTransaction } -func NewOutputQueue(stateStorage storage.Storage, tx storage.StateTransaction) *OutputQueue { +func NewOutputQueue(tx storage.StateTransaction) *OutputQueue { return &OutputQueue{ - stateStorage: stateStorage, - tx: tx, + tx: tx, } } @@ -233,9 
+229,9 @@ func (q *OutputQueue) Pop(ctx context.Context) (*QueueElement, error) { var element QueueElement err := queueElements.PopFront(&element) if err == storage.ErrNotFound { - subscription := q.stateStorage.Subscribe(ctx) + subscription := q.tx.GetUnderlyingStorage().Subscribe(ctx) - curTx := q.stateStorage.BeginTransaction() + curTx := q.tx.GetUnderlyingStorage().BeginTransaction() defer curTx.Abort() curQueueElements := storage.NewDeque(curTx.WithPrefix(queueElementsPrefix)) diff --git a/execution/process_test.go b/execution/process_test.go index e13e681d..fcc9f282 100644 --- a/execution/process_test.go +++ b/execution/process_test.go @@ -53,12 +53,12 @@ func TestOutputQueue_Ok(t *testing.T) { { tx := stateStorage.BeginTransaction() - queue := NewOutputQueue(stateStorage, tx) + queue := NewOutputQueue(tx) assert.Nil(t, queue.Push(ctx, recordElement)) assert.Nil(t, queue.Push(ctx, recordElement2)) assert.Nil(t, tx.Commit()) tx = stateStorage.BeginTransaction() - queue = NewOutputQueue(stateStorage, tx) + queue = NewOutputQueue(tx) assert.Nil(t, queue.Push(ctx, watermarkElement)) assert.Nil(t, queue.Push(ctx, recordElement)) assert.Nil(t, queue.Push(ctx, eosElement)) @@ -66,12 +66,12 @@ func TestOutputQueue_Ok(t *testing.T) { } { tx := stateStorage.BeginTransaction() - queue := NewOutputQueue(stateStorage, tx) + queue := NewOutputQueue(tx) assert.True(t, proto.Equal(recordElement, GetElementAssertNoError(t, ctx, queue))) assert.True(t, proto.Equal(recordElement2, GetElementAssertNoError(t, ctx, queue))) assert.Nil(t, tx.Commit()) tx = stateStorage.BeginTransaction() - queue = NewOutputQueue(stateStorage, tx) + queue = NewOutputQueue(tx) assert.True(t, proto.Equal(watermarkElement, GetElementAssertNoError(t, ctx, queue))) assert.True(t, proto.Equal(recordElement, GetElementAssertNoError(t, ctx, queue))) assert.True(t, proto.Equal(eosElement, GetElementAssertNoError(t, ctx, queue))) @@ -93,17 +93,17 @@ func TestOutputQueue_AbortTransaction(t *testing.T) { { tx := stateStorage.BeginTransaction() - queue := NewOutputQueue(stateStorage, tx) + queue := NewOutputQueue(tx) assert.Nil(t, queue.Push(ctx, recordElement)) assert.Nil(t, queue.Push(ctx, recordElement2)) tx.Abort() tx = stateStorage.BeginTransaction() - queue = NewOutputQueue(stateStorage, tx) + queue = NewOutputQueue(tx) assert.Nil(t, queue.Push(ctx, recordElement)) assert.Nil(t, queue.Push(ctx, recordElement2)) assert.Nil(t, tx.Commit()) tx = stateStorage.BeginTransaction() - queue = NewOutputQueue(stateStorage, tx) + queue = NewOutputQueue(tx) assert.Nil(t, queue.Push(ctx, watermarkElement)) assert.Nil(t, queue.Push(ctx, recordElement)) assert.Nil(t, queue.Push(ctx, eosElement)) @@ -111,18 +111,18 @@ func TestOutputQueue_AbortTransaction(t *testing.T) { } { tx := stateStorage.BeginTransaction() - queue := NewOutputQueue(stateStorage, tx) + queue := NewOutputQueue(tx) assert.True(t, proto.Equal(recordElement, GetElementAssertNoError(t, ctx, queue))) assert.True(t, proto.Equal(recordElement2, GetElementAssertNoError(t, ctx, queue))) assert.Nil(t, tx.Commit()) tx = stateStorage.BeginTransaction() - queue = NewOutputQueue(stateStorage, tx) + queue = NewOutputQueue(tx) assert.True(t, proto.Equal(watermarkElement, GetElementAssertNoError(t, ctx, queue))) assert.True(t, proto.Equal(recordElement, GetElementAssertNoError(t, ctx, queue))) assert.True(t, proto.Equal(eosElement, GetElementAssertNoError(t, ctx, queue))) tx.Abort() tx = stateStorage.BeginTransaction() - queue = NewOutputQueue(stateStorage, tx) + queue = NewOutputQueue(tx) 
assert.True(t, proto.Equal(watermarkElement, GetElementAssertNoError(t, ctx, queue))) assert.True(t, proto.Equal(recordElement, GetElementAssertNoError(t, ctx, queue))) assert.True(t, proto.Equal(eosElement, GetElementAssertNoError(t, ctx, queue))) @@ -144,11 +144,11 @@ func TestOutputQueue_NewTransactionRequired(t *testing.T) { { readTx := stateStorage.BeginTransaction() - readQueue := NewOutputQueue(stateStorage, readTx) + readQueue := NewOutputQueue(readTx) { writeTx := stateStorage.BeginTransaction() - writeQueue := NewOutputQueue(stateStorage, writeTx) + writeQueue := NewOutputQueue(writeTx) assert.Nil(t, writeQueue.Push(ctx, recordElement)) assert.Nil(t, writeTx.Commit()) @@ -159,7 +159,7 @@ func TestOutputQueue_NewTransactionRequired(t *testing.T) { readTx.Abort() readTx = stateStorage.BeginTransaction() - readQueue = NewOutputQueue(stateStorage, readTx) + readQueue = NewOutputQueue(readTx) assert.True(t, proto.Equal(recordElement, GetElementAssertNoError(t, ctx, readQueue))) assert.Nil(t, readTx.Commit()) } @@ -179,7 +179,7 @@ func TestOutputQueue_NewTransactionRequired(t *testing.T) { { readTx := stateStorage.BeginTransaction() - readQueue := NewOutputQueue(stateStorage, readTx) + readQueue := NewOutputQueue(readTx) _, err := readQueue.Pop(ctx) assert.IsType(t, err, &ErrWaitForChanges{}) readTx.Abort() @@ -195,7 +195,7 @@ func TestOutputQueue_NewTransactionRequired(t *testing.T) { { writeTx := stateStorage.BeginTransaction() - writeQueue := NewOutputQueue(stateStorage, writeTx) + writeQueue := NewOutputQueue(writeTx) assert.Nil(t, writeQueue.Push(ctx, recordElement)) assert.Nil(t, writeTx.Commit()) @@ -203,7 +203,7 @@ func TestOutputQueue_NewTransactionRequired(t *testing.T) { { tx := stateStorage.BeginTransaction() - queue := NewOutputQueue(stateStorage, tx) + queue := NewOutputQueue(tx) assert.True(t, proto.Equal(recordElement, GetElementAssertNoError(t, ctx, queue))) assert.Nil(t, tx.Commit()) } diff --git a/execution/requalifier.go b/execution/requalifier.go index 2d60a406..0344c107 100644 --- a/execution/requalifier.go +++ b/execution/requalifier.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" ) type Requalifier struct { @@ -21,8 +22,14 @@ func NewRequalifier(qualifier string, child Node) *Requalifier { return &Requalifier{qualifier: qualifier, source: child} } -func (node *Requalifier) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - recordStream, err := node.source.Get(ctx, variables) +func (node *Requalifier) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + tx := storage.GetStateTransactionFromContext(ctx) + sourceStreamID, err := GetSourceStreamID(tx.WithPrefix(streamID.AsPrefix()), octosql.MakePhantom()) + if err != nil { + return nil, errors.Wrap(err, "couldn't get source stream ID") + } + + recordStream, err := node.source.Get(ctx, variables, sourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get record stream") } diff --git a/execution/sourcestorage.go b/execution/sourcestorage.go new file mode 100644 index 00000000..5fd9c2f4 --- /dev/null +++ b/execution/sourcestorage.go @@ -0,0 +1,50 @@ +package execution + +import ( + "crypto/rand" + "time" + + "github.com/pkg/errors" + + "github.com/oklog/ulid" + + "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" +) + +// A RecordStream node should use its StreamID as a prefix to all storage operations. 
+// This is a helper function to make that easier. +func (id *StreamID) AsPrefix() []byte { + return []byte("$" + id.Id + "$") +} + +var inputStreamIDPrefix = []byte("$input$") + +// GetRawStreamID can be used to get a new StreamID without saving it. +func GetRawStreamID() *StreamID { + id := ulid.MustNew(ulid.Timestamp(time.Now()), rand.Reader) + return &StreamID{ + Id: id.String(), + } +} + +// GetSourceStreamID loads the StreamID of the given input stream in case it exists (from a previous run maybe?) +// Otherwise it allocates a new StreamID and saves it. +func GetSourceStreamID(tx storage.StateTransaction, inputName octosql.Value) (*StreamID, error) { + sourceStreamMap := storage.NewMap(tx.WithPrefix(inputStreamIDPrefix)) + + var streamID StreamID + err := sourceStreamMap.Get(&inputName, &streamID) + if err == storage.ErrNotFound { + streamID = *GetRawStreamID() + + err := sourceStreamMap.Set(&inputName, &streamID) + if err != nil { + return nil, errors.Wrap(err, "couldn't set new stream id") + } + } else if err != nil { + return nil, errors.Wrap(err, "couldn't get current value for stream id") + } + + return &streamID, nil +} diff --git a/execution/test_utils.go b/execution/test_utils.go index e925cd08..24789916 100644 --- a/execution/test_utils.go +++ b/execution/test_utils.go @@ -171,7 +171,7 @@ type DummyNode struct { data []*Record } -func (dn *DummyNode) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { +func (dn *DummyNode) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { if dn.data == nil { return NewInMemoryStream([]*Record{}), nil } diff --git a/execution/tvf/range.go b/execution/tvf/range.go index a08d7285..a64a55ea 100644 --- a/execution/tvf/range.go +++ b/execution/tvf/range.go @@ -31,7 +31,7 @@ func (r *Range) Document() docs.Documentation { ) } -func (r *Range) Get(ctx context.Context, variables octosql.Variables) (execution.RecordStream, error) { +func (r *Range) Get(ctx context.Context, variables octosql.Variables, streamID *execution.StreamID) (execution.RecordStream, error) { start, err := r.start.ExpressionValue(ctx, variables) if err != nil { return nil, errors.Wrap(err, "couldn't get range start point") diff --git a/execution/tvf/range_test.go b/execution/tvf/range_test.go index ecb38ffc..c42ca3e3 100644 --- a/execution/tvf/range_test.go +++ b/execution/tvf/range_test.go @@ -6,6 +6,7 @@ import ( "github.com/cube2222/octosql" "github.com/cube2222/octosql/execution" + "github.com/cube2222/octosql/streaming/storage" ) func TestRange_Get(t *testing.T) { @@ -149,7 +150,12 @@ func TestRange_Get(t *testing.T) { start: tt.fields.start, end: tt.fields.end, } - got, err := r.Get(ctx, tt.args.variables) + + stateStorage := execution.GetTestStorage(t) + tx := stateStorage.BeginTransaction() + ctx := storage.InjectStateTransaction(ctx, tx) + + got, err := r.Get(ctx, tt.args.variables, execution.GetRawStreamID()) if (err != nil) != tt.wantErr { t.Errorf("Range.Get() error = %v, wantErr %v", err, tt.wantErr) return @@ -161,6 +167,10 @@ func TestRange_Get(t *testing.T) { if !eq { t.Errorf("Range.Get() streams not equal") } + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } }) } } diff --git a/execution/tvf/tumble.go b/execution/tvf/tumble.go index 8eb78e9b..a309d666 100644 --- a/execution/tvf/tumble.go +++ b/execution/tvf/tumble.go @@ -10,6 +10,7 @@ import ( "github.com/cube2222/octosql" "github.com/cube2222/octosql/docs" "github.com/cube2222/octosql/execution" + 
"github.com/cube2222/octosql/streaming/storage" ) type Tumble struct { @@ -38,8 +39,14 @@ func (r *Tumble) Document() docs.Documentation { ) } -func (r *Tumble) Get(ctx context.Context, variables octosql.Variables) (execution.RecordStream, error) { - source, err := r.source.Get(ctx, variables) +func (r *Tumble) Get(ctx context.Context, variables octosql.Variables, streamID *execution.StreamID) (execution.RecordStream, error) { + tx := storage.GetStateTransactionFromContext(ctx) + sourceStreamID, err := execution.GetSourceStreamID(tx.WithPrefix(streamID.AsPrefix()), octosql.MakePhantom()) + if err != nil { + return nil, errors.Wrap(err, "couldn't get source stream ID") + } + + source, err := r.source.Get(ctx, variables, sourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get source") } diff --git a/execution/tvf/tumble_test.go b/execution/tvf/tumble_test.go index 55a50f58..f7cc80d8 100644 --- a/execution/tvf/tumble_test.go +++ b/execution/tvf/tumble_test.go @@ -7,6 +7,7 @@ import ( "github.com/cube2222/octosql" "github.com/cube2222/octosql/execution" + "github.com/cube2222/octosql/streaming/storage" ) func TestTumble_Get(t *testing.T) { @@ -148,7 +149,12 @@ func TestTumble_Get(t *testing.T) { windowLength: tt.fields.windowLength, offset: tt.fields.offset, } - got, err := r.Get(ctx, tt.args.variables) + + stateStorage := execution.GetTestStorage(t) + tx := stateStorage.BeginTransaction() + ctx := storage.InjectStateTransaction(ctx, tx) + + got, err := r.Get(ctx, tt.args.variables, execution.GetRawStreamID()) if (err != nil) != tt.wantErr { t.Errorf("Tumble.Get() error = %v, wantErr %v", err, tt.wantErr) return @@ -160,6 +166,10 @@ func TestTumble_Get(t *testing.T) { if !eq { t.Errorf("Tumble.Get() streams not equal") } + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } }) } } diff --git a/execution/union_all.go b/execution/union_all.go index 11db9770..eea40931 100644 --- a/execution/union_all.go +++ b/execution/union_all.go @@ -2,6 +2,7 @@ package execution import ( "github.com/cube2222/octosql" + "github.com/cube2222/octosql/streaming/storage" "context" @@ -16,12 +17,24 @@ func NewUnionAll(first, second Node) *UnionAll { return &UnionAll{first: first, second: second} } -func (node *UnionAll) Get(ctx context.Context, variables octosql.Variables) (RecordStream, error) { - firstRecordStream, err := node.first.Get(ctx, variables) +func (node *UnionAll) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) { + tx := storage.GetStateTransactionFromContext(ctx) + prefixedTx := tx.WithPrefix(streamID.AsPrefix()) + + firstSourceStreamID, err := GetSourceStreamID(prefixedTx, octosql.MakeInt(1)) + if err != nil { + return nil, errors.Wrap(err, "couldn't get first source stream ID") + } + secondSourceStreamID, err := GetSourceStreamID(prefixedTx, octosql.MakeInt(2)) + if err != nil { + return nil, errors.Wrap(err, "couldn't get first source stream ID") + } + + firstRecordStream, err := node.first.Get(ctx, variables, firstSourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get first record stream") } - secondRecordStream, err := node.second.Get(ctx, variables) + secondRecordStream, err := node.second.Get(ctx, variables, secondSourceStreamID) if err != nil { return nil, errors.Wrap(err, "couldn't get second record stream") } diff --git a/go.mod b/go.mod index 2c39d3cf..b4fcb94a 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/lib/pq v1.0.0 github.com/mattn/go-runewidth v0.0.4 // indirect 
github.com/mitchellh/hashstructure v1.0.0 + github.com/oklog/ulid v1.3.1 github.com/olekukonko/tablewriter v0.0.1 github.com/onsi/ginkgo v1.8.0 // indirect github.com/onsi/gomega v1.5.0 // indirect diff --git a/go.sum b/go.sum index 17de4e28..6ee5ea9e 100644 --- a/go.sum +++ b/go.sum @@ -69,6 +69,8 @@ github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQz github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/olekukonko/tablewriter v0.0.1 h1:b3iUnf1v+ppJiOfNX4yxxqfWKMQPZR5yoh8urCTFX88= github.com/olekukonko/tablewriter v0.0.1/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= diff --git a/physical/group_by.go b/physical/group_by.go index fe41aded..e1704f4a 100644 --- a/physical/group_by.go +++ b/physical/group_by.go @@ -179,9 +179,7 @@ func (node *GroupBy) Transform(ctx context.Context, transformers *Transformers) } func (node *GroupBy) Materialize(ctx context.Context, matCtx *MaterializationContext) (execution.Node, error) { - sourceMatCtx, sourceStoragePrefix := matCtx.WithStoragePrefix() - - source, err := node.Source.Materialize(ctx, sourceMatCtx) + source, err := node.Source.Materialize(ctx, matCtx) if err != nil { return nil, errors.Wrap(err, "couldn't materialize Source node") } @@ -235,7 +233,7 @@ func (node *GroupBy) Materialize(ctx context.Context, matCtx *MaterializationCon meta := node.Metadata() - return execution.NewGroupBy(matCtx.Storage, source, sourceStoragePrefix, key, node.Fields, aggregatePrototypes, eventTimeField, node.As, meta.EventTimeField(), triggerPrototype), nil + return execution.NewGroupBy(matCtx.Storage, source, key, node.Fields, aggregatePrototypes, eventTimeField, node.As, meta.EventTimeField(), triggerPrototype), nil } func (node *GroupBy) groupingByEventTime(sourceMetadata *metadata.NodeMetadata) bool { diff --git a/physical/physical.go b/physical/physical.go index 77e3f9e0..abb1481a 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -2,8 +2,6 @@ package physical import ( "context" - "crypto/rand" - "log" "github.com/pkg/errors" @@ -30,19 +28,6 @@ type MaterializationContext struct { Storage storage.Storage } -func (matCtx *MaterializationContext) WithStoragePrefix() (out *MaterializationContext, prefix []byte) { - prefix = make([]byte, 8) - _, err := rand.Read(prefix) - if err != nil { - log.Fatalf("couldn't generate random byte slice: %v", err) - } - - return &MaterializationContext{ - Config: matCtx.Config, - Storage: matCtx.Storage.WithPrefix(prefix), - }, prefix -} - func NewMaterializationContext(config *config.Config, storage storage.Storage) *MaterializationContext { return &MaterializationContext{ Config: config, diff --git a/storage/csv/datasource.go b/storage/csv/datasource.go index 9f1c1708..54f1f13e 100644 --- a/storage/csv/datasource.go +++ b/storage/csv/datasource.go @@ -70,7 +70,7 @@ func NewDataSourceBuilderFactoryFromConfig(dbConfig map[string]interface{}) (phy return NewDataSourceBuilderFactory(), nil } -func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables) 
(execution.RecordStream, error) { +func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables, streamID *execution.StreamID) (execution.RecordStream, error) { file, err := os.Open(ds.path) if err != nil { return nil, errors.Wrap(err, "couldn't open file") diff --git a/storage/csv/datasource_test.go b/storage/csv/datasource_test.go index 7309a50c..6084b765 100644 --- a/storage/csv/datasource_test.go +++ b/storage/csv/datasource_test.go @@ -87,7 +87,7 @@ func TestCSVDataSource_Get(t *testing.T) { } t.Run(tt.name, func(t *testing.T) { - _, err := ds.Get(ctx, octosql.NoVariables()) + _, err := ds.Get(ctx, octosql.NoVariables(), execution.GetRawStreamID()) if (err != nil) != tt.wantErr { t.Errorf("DataSource.Get() error is %v, want %v", err, tt.wantErr) } @@ -455,7 +455,7 @@ func TestCSVRecordStream_Next(t *testing.T) { if err != nil { t.Errorf("Error creating data source: %v", err) } - rs, err := ds.Get(ctx, octosql.NoVariables()) + rs, err := ds.Get(ctx, octosql.NoVariables(), execution.GetRawStreamID()) if err != nil { t.Errorf("DataSource.Get() error: %v", err) return diff --git a/storage/excel/datasource.go b/storage/excel/datasource.go index 7b739de0..2fe4f2e2 100644 --- a/storage/excel/datasource.go +++ b/storage/excel/datasource.go @@ -82,7 +82,7 @@ func NewDataSourceBuilderFactoryFromConfig(dbConfig map[string]interface{}) (phy return NewDataSourceBuilderFactory(), nil } -func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables) (execution.RecordStream, error) { +func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables, streamID *execution.StreamID) (execution.RecordStream, error) { file, err := excelize.OpenFile(ds.path) if err != nil { return nil, errors.Wrap(err, "couldn't open file") diff --git a/storage/excel/datasource_test.go b/storage/excel/datasource_test.go index 04941d57..5c9aad9d 100644 --- a/storage/excel/datasource_test.go +++ b/storage/excel/datasource_test.go @@ -361,7 +361,7 @@ func TestDataSource_Get(t *testing.T) { verticalOffset: tt.fields.verticalOffset, timeColumns: tt.fields.timeColumns, } - got, err := ds.Get(ctx, tt.args.variables) + got, err := ds.Get(ctx, tt.args.variables, execution.GetRawStreamID()) if (err != nil) != tt.wantErr { t.Errorf("DataSource.Get() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/storage/json/datasource.go b/storage/json/datasource.go index 56376c22..b320947b 100644 --- a/storage/json/datasource.go +++ b/storage/json/datasource.go @@ -56,7 +56,7 @@ func NewDataSourceBuilderFactoryFromConfig(dbConfig map[string]interface{}) (phy return NewDataSourceBuilderFactory(), nil } -func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables) (execution.RecordStream, error) { +func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables, streamID *execution.StreamID) (execution.RecordStream, error) { file, err := os.Open(ds.path) if err != nil { return nil, errors.Wrap(err, "couldn't open file") diff --git a/storage/json/datasource_test.go b/storage/json/datasource_test.go index 9a8be46d..1b10f75b 100644 --- a/storage/json/datasource_test.go +++ b/storage/json/datasource_test.go @@ -92,7 +92,7 @@ func TestJSONRecordStream_Get(t *testing.T) { t.Errorf("Error creating data source: %v", err) } - got, err := ds.Get(ctx, octosql.NoVariables()) + got, err := ds.Get(ctx, octosql.NoVariables(), execution.GetRawStreamID()) if err != nil { t.Errorf("DataSource.Get() error: %v", err) return diff --git a/storage/mysql/datasource.go b/storage/mysql/datasource.go 
index a9a39e53..395f65d7 100644 --- a/storage/mysql/datasource.go +++ b/storage/mysql/datasource.go @@ -121,7 +121,7 @@ func NewDataSourceBuilderFactoryFromConfig(dbConfig map[string]interface{}) (phy return NewDataSourceBuilderFactory(primaryKeys), nil } -func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables) (execution.RecordStream, error) { +func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables, streamID *execution.StreamID) (execution.RecordStream, error) { values := make([]interface{}, 0) for i := range ds.aliases { diff --git a/storage/mysql/datasource_test.go b/storage/mysql/datasource_test.go index 7f6326a9..fa21e1cf 100644 --- a/storage/mysql/datasource_test.go +++ b/storage/mysql/datasource_test.go @@ -314,7 +314,7 @@ func TestDataSource_Get(t *testing.T) { return } - stream, err := execNode.Get(ctx, args.variables) + stream, err := execNode.Get(ctx, args.variables, execution.GetRawStreamID()) if err != nil { t.Errorf("Couldn't get stream: %v", err) return diff --git a/storage/postgres/datasource.go b/storage/postgres/datasource.go index 0f94b9b3..8db86c3f 100644 --- a/storage/postgres/datasource.go +++ b/storage/postgres/datasource.go @@ -130,7 +130,7 @@ func NewDataSourceBuilderFactoryFromConfig(dbConfig map[string]interface{}) (phy return NewDataSourceBuilderFactory(primaryKeys), nil } -func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables) (execution.RecordStream, error) { +func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables, streamID *execution.StreamID) (execution.RecordStream, error) { values := make([]interface{}, 0) for i := 0; i < len(ds.aliases); i++ { diff --git a/storage/postgres/datasource_test.go b/storage/postgres/datasource_test.go index dcd2e21b..bce0e6b5 100644 --- a/storage/postgres/datasource_test.go +++ b/storage/postgres/datasource_test.go @@ -315,7 +315,7 @@ func TestDataSource_Get(t *testing.T) { return } - stream, err := execNode.Get(ctx, args.variables) + stream, err := execNode.Get(ctx, args.variables, execution.GetRawStreamID()) if err != nil { t.Errorf("Couldn't get stream: %v", err) return diff --git a/storage/redis/datasource.go b/storage/redis/datasource.go index 24012bfa..9d451762 100644 --- a/storage/redis/datasource.go +++ b/storage/redis/datasource.go @@ -85,7 +85,7 @@ func NewDataSourceBuilderFactoryFromConfig(dbConfig map[string]interface{}) (phy return NewDataSourceBuilderFactory(dbKey), nil } -func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables) (execution.RecordStream, error) { +func (ds *DataSource) Get(ctx context.Context, variables octosql.Variables, streamID *execution.StreamID) (execution.RecordStream, error) { keysWanted, err := ds.keyFormula.getAllKeys(ctx, variables) if err != nil { return nil, errors.Wrap(err, "couldn't get all keys from filter") diff --git a/storage/redis/datasource_test.go b/storage/redis/datasource_test.go index bc86e94b..304a27cd 100644 --- a/storage/redis/datasource_test.go +++ b/storage/redis/datasource_test.go @@ -798,7 +798,7 @@ func TestDataSource_Get(t *testing.T) { return } - stream, err := execNode.Get(ctx, tt.args.variables) + stream, err := execNode.Get(ctx, tt.args.variables, execution.GetRawStreamID()) if err != nil && !tt.wantErr { t.Errorf("Error in Get: %v", err) return diff --git a/streaming/storage/storage.go b/streaming/storage/storage.go index e4c3d1a6..10f8ac66 100644 --- a/streaming/storage/storage.go +++ b/streaming/storage/storage.go @@ -26,7 +26,7 @@ func NewBadgerStorage(db 
*badger.DB) *BadgerStorage { func (bs *BadgerStorage) BeginTransaction() StateTransaction { tx := bs.db.NewTransaction(true) - return &badgerTransaction{tx: tx, prefix: bs.prefix} + return &badgerTransaction{tx: tx, prefix: bs.prefix, storage: bs} } func (bs *BadgerStorage) DropAll(prefix []byte) error { diff --git a/streaming/storage/transaction.go b/streaming/storage/transaction.go index a5d06e84..0d293f9e 100644 --- a/streaming/storage/transaction.go +++ b/streaming/storage/transaction.go @@ -17,6 +17,7 @@ type StateTransaction interface { Iterator(opts ...IteratorOption) Iterator Commit() error Abort() + GetUnderlyingStorage() Storage } type stateTransactionKey struct{} @@ -36,8 +37,9 @@ func InjectStateTransaction(ctx context.Context, tx StateTransaction) context.Co } type badgerTransaction struct { - tx *badger.Txn - prefix []byte + tx *badger.Txn + prefix []byte + storage Storage } func (tx *badgerTransaction) getKeyWithPrefix(key []byte) []byte { @@ -78,8 +80,9 @@ func (tx *badgerTransaction) GetPrefixLength() int { func (tx *badgerTransaction) WithPrefix(prefix []byte) StateTransaction { return &badgerTransaction{ - tx: tx.tx, - prefix: tx.getKeyWithPrefix(prefix), + tx: tx.tx, + prefix: tx.getKeyWithPrefix(prefix), + storage: tx.storage.WithPrefix(prefix), } } @@ -111,6 +114,10 @@ func (tx *badgerTransaction) Abort() { tx.tx.Discard() } +func (tx *badgerTransaction) GetUnderlyingStorage() Storage { + return tx.storage +} + //IteratorOptions are a copy of badger.IteratorOptions //They are used so that there is no explicit badger dependency in StateTransaction type IteratorOptions struct {
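
The pattern this patch establishes for Node.Get is the same across Distinct, Filter, Map, Limit, Offset, Requalifier, OrderBy and Tumble: read the state transaction from the context, load (or allocate and persist) a StreamID for each source under this node's own prefix, and pass that StreamID down while the node keeps all of its state under streamID.AsPrefix(). The following is a minimal sketch of that contract for reviewers; it is not part of the patch, the PassThrough node is hypothetical, and it only uses helpers introduced or touched above (GetStateTransactionFromContext, GetSourceStreamID, StreamID.AsPrefix).

// Illustrative sketch only — not included in this diff.
package execution

import (
	"context"

	"github.com/pkg/errors"

	"github.com/cube2222/octosql"
	"github.com/cube2222/octosql/streaming/storage"
)

// PassThrough is a hypothetical node that forwards its child's records unchanged.
type PassThrough struct {
	source Node
}

func (node *PassThrough) Get(ctx context.Context, variables octosql.Variables, streamID *StreamID) (RecordStream, error) {
	// The state transaction travels in the context; this node keeps all of its
	// own state under its StreamID prefix.
	tx := storage.GetStateTransactionFromContext(ctx)

	// Load the child's StreamID if one was already allocated (e.g. on a previous
	// run), otherwise allocate a new one and persist it under this node's prefix.
	sourceStreamID, err := GetSourceStreamID(tx.WithPrefix(streamID.AsPrefix()), octosql.MakePhantom())
	if err != nil {
		return nil, errors.Wrap(err, "couldn't get source stream ID")
	}

	// The child receives the untouched context and its own StreamID;
	// it is responsible for prefixing its own storage operations.
	return node.source.Get(ctx, variables, sourceStreamID)
}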