From 6789ba775944966bedf95493f8eb238deb2931e7 Mon Sep 17 00:00:00 2001 From: Wilson Wang <3913185+wilsonwang371@users.noreply.github.com> Date: Sat, 1 Oct 2022 12:03:15 -0700 Subject: [PATCH] merge support (#100) * merge support is ready, but never tested * update merge, not tested --- cmd/convert.go | 6 +++- cmd/convert_test.go | 5 +++- cmd/merge.go | 46 +++++++++++++++++++++++++----- cmd/merge_test.go | 40 +++++++++++++++++++++++++++ pkg/core/datafeed.go | 6 +++- pkg/core/strategy.go | 8 +++++- pkg/db/merge.go | 66 ++++++++++++++++++++++++++++++++++++++++++++ pkg/db/model.go | 29 ++++++++++++++++--- pkg/db/model_test.go | 10 +++++-- 9 files changed, 199 insertions(+), 17 deletions(-) create mode 100644 cmd/merge_test.go create mode 100644 pkg/db/merge.go diff --git a/cmd/convert.go b/cmd/convert.go index 1cfcd4f..d3ddd69 100644 --- a/cmd/convert.go +++ b/cmd/convert.go @@ -54,7 +54,11 @@ func ConvertFunction(cmd *cobra.Command, args []string) { } dbsource := convert.NewDBSource(convertDataSource, convertFileType) - dboutput := db.NewSQLiteDataBase(convertOutput, true) + dboutput, err := db.NewSQLiteDataBase(convertOutput, true) + if err != nil { + logger.Logger.Error("failed to create output database", zap.Error(err)) + os.Exit(1) + } if err := rt.Convert(dbsource, dboutput); err != nil { logger.Logger.Error("failed to convert data", zap.Error(err)) os.Exit(1) diff --git a/cmd/convert_test.go b/cmd/convert_test.go index ba70c55..5d55359 100644 --- a/cmd/convert_test.go +++ b/cmd/convert_test.go @@ -25,7 +25,10 @@ func TestConvertSimple(t *testing.T) { } dbsource := convert.NewDBSource("../samples/data/strategy_data.sqlite", "sqlite") - dboutput := db.NewSQLiteDataBase("../stategy_data.db", true) + dboutput, err := db.NewSQLiteDataBase("../stategy_data.db", true) + if err != nil { + t.Fatal("failed to create output db") + } if err := rt.Convert(dbsource, dboutput); err != nil { t.Fatal("failed to convert data") } diff --git a/cmd/merge.go b/cmd/merge.go index 563ba40..86c6350 100644 --- a/cmd/merge.go +++ b/cmd/merge.go @@ -1,22 +1,27 @@ package cmd import ( + "os" + "strings" + + "goat/pkg/db" + "goat/pkg/logger" "goat/pkg/notify" "goat/pkg/util" "github.com/spf13/cobra" + "go.uber.org/zap" ) var ( - mergeDataSource1 string - mergeDataSource2 string - mergeOutput string + mergeDataSourceList string + mergeOutput string ) var mergeCmd = &cobra.Command{ Use: "merge", - Short: "merge command merges two data inputs into one", - Long: `merge command merges two data inputs into one. + Short: "merge command merges two dumpdb data inputs into one", + Long: `merge command merges two dumpdb data inputs into one. `, Run: MergeFunction, } @@ -25,10 +30,37 @@ func MergeFunction(cmd *cobra.Command, args []string) { // handle panic defer util.PanicHandler(notify.NewEmailNotifier(&cfg)) - // TODO: implement merge command + var sources []*db.DB + sourceNames := strings.Split(mergeDataSourceList, ",") + for _, name := range sourceNames { + tmp, err := db.NewSQLiteDataBase(name, false) + if err != nil { + logger.Logger.Error("failed to open database", zap.Error(err)) + os.Exit(1) + } + sources = append(sources, tmp) + } + + output, err := db.NewSQLiteDataBase(mergeOutput, true) + if err != nil { + logger.Logger.Error("failed to create output database", zap.Error(err)) + os.Exit(1) + } + + if err := db.MergeDBs(output, sources); err != nil { + logger.Logger.Error("failed to merge databases", zap.Error(err)) + os.Exit(1) + } } func init() { - // TODO: add flags + mergeCmd.PersistentFlags().StringVarP(&mergeDataSourceList, + "datasources", "s", "", "data source list, separated by comma") + mergeCmd.MarkPersistentFlagRequired("datasources") + + mergeCmd.PersistentFlags().StringVarP(&mergeOutput, "output-file", "o", "", + "output file path") + mergeCmd.MarkPersistentFlagRequired("output-file") + rootCmd.AddCommand(mergeCmd) } diff --git a/cmd/merge_test.go b/cmd/merge_test.go new file mode 100644 index 0000000..e9d15d0 --- /dev/null +++ b/cmd/merge_test.go @@ -0,0 +1,40 @@ +package cmd + +import ( + "os" + "testing" + + "goat/pkg/db" + "goat/pkg/logger" + + "go.uber.org/zap" +) + +func TestMergeSimple(t *testing.T) { + var sources []*db.DB + + sourceNames := []string{ + "../samples/data/strategy_data.dumpdb", + "../samples/data/strategy_data.dumpdb", + } + for _, name := range sourceNames { + tmp, err := db.NewSQLiteDataBase(name, false) + if err != nil { + logger.Logger.Error("failed to open database", zap.Error(err)) + t.Fatal("failed to open database") + } + sources = append(sources, tmp) + } + + defer os.Remove("tempoutput.dumpdb") + output, err := db.NewSQLiteDataBase("tempoutput.dumpdb", true) + if err != nil { + logger.Logger.Error("failed to create output database", zap.Error(err)) + t.Fatal("failed to create output database") + } + + if err := db.MergeDBs(output, sources); err != nil { + logger.Logger.Error("failed to merge databases", zap.Error(err)) + t.Fatal("failed to merge databases") + } +} diff --git a/pkg/core/datafeed.go b/pkg/core/datafeed.go index b133d11..5887b3e 100644 --- a/pkg/core/datafeed.go +++ b/pkg/core/datafeed.go @@ -374,9 +374,13 @@ func NewGenericDataFeed(ctx context.Context, cfg *config.Config, fg FeedGenerato ) DataFeed { var recDB *db.DB var recCount int64 + var err error if recoveryDB != "" { logger.Logger.Debug("recovery mode is enabled", zap.String("db", recoveryDB)) - recDB = db.NewSQLiteDataBase(recoveryDB, false) + recDB, err = db.NewSQLiteDataBase(recoveryDB, false) + if err != nil { + panic(err) + } recCount = recDB.FetchAll(true) } if hooksCtrl == nil { diff --git a/pkg/core/strategy.go b/pkg/core/strategy.go index c81e6d9..8f3cf72 100644 --- a/pkg/core/strategy.go +++ b/pkg/core/strategy.go @@ -3,6 +3,7 @@ package core import ( "context" "fmt" + "os" "sync" "time" @@ -266,9 +267,14 @@ func NewStrategyController(ctx context.Context, cfg *config.Config, strategyEven dumpWg: sync.WaitGroup{}, } + var err error if cfg.Dump.BarDumpDB != "" { - controller.dumpDB = db.NewSQLiteDataBase(cfg.Dump.BarDumpDB, + controller.dumpDB, err = db.NewSQLiteDataBase(cfg.Dump.BarDumpDB, cfg.Dump.RemoveOldBars) + if err != nil { + logger.Logger.Fatal("failed to create dump db", zap.Error(err)) + os.Exit(1) + } } controller.dispatcher.AddSubject(controller.broker) diff --git a/pkg/db/merge.go b/pkg/db/merge.go new file mode 100644 index 0000000..cb7b0aa --- /dev/null +++ b/pkg/db/merge.go @@ -0,0 +1,66 @@ +package db + +import ( + "fmt" + "os" + + "goat/pkg/logger" + + "go.uber.org/zap" +) + +func MergeDBs(output *DB, sources []*DB) error { + if output == nil { + return fmt.Errorf("output db is nil") + } + if len(sources) == 0 { + return fmt.Errorf("no input db") + } + + var totalCount int64 + for _, source := range sources { + count := source.FetchAll(true) + totalCount += count + } + logger.Logger.Info("total count", zap.Int64("count", totalCount)) + +loopNext: + for { + var nextData *BarData + nextIdx := -1 + for idx, oneSource := range sources { + cmpNextData, err := oneSource.Peek() + if err != nil { + logger.Logger.Error("failed to peek data", zap.Error(err)) + os.Exit(1) + } + if cmpNextData == nil { + if len(sources) == 1 { + return nil + } + sources = append(sources[:idx], sources[idx+1:]...) + continue loopNext + } + if nextData == nil || cmpNextData.DateTime < nextData.DateTime || + (cmpNextData.Frequency < nextData.Frequency && cmpNextData.Frequency >= 0) { + nextData = cmpNextData + nextIdx = idx + } + } + bar := &BarData{ + Symbol: nextData.Symbol, + DateTime: nextData.DateTime, + Open: nextData.Open, + High: nextData.High, + Low: nextData.Low, + Close: nextData.Close, + Volume: nextData.Volume, + AdjClose: nextData.AdjClose, + Frequency: nextData.Frequency, + Note: nextData.Note, + } + output.Create(bar) + sources[nextIdx].Next() + } + // we should never reach here +} diff --git a/pkg/db/model.go b/pkg/db/model.go index ca58b0e..e44d0af 100644 --- a/pkg/db/model.go +++ b/pkg/db/model.go @@ -29,10 +29,11 @@ type BarData struct { type DB struct { *gorm.DB dataChan chan *BarData + peekData *BarData err error } -func NewSQLiteDataBase(dbpath string, removeOldData bool) *DB { +func NewSQLiteDataBase(dbpath string, removeOldData bool) (*DB, error) { if _, err := os.Stat(dbpath); err != nil && os.IsNotExist(err) { // file does not exist logger.Logger.Info("using new database file", zap.String("dbpath", dbpath)) @@ -43,14 +44,14 @@ func NewSQLiteDataBase(dbpath string, removeOldData bool) *DB { err = os.Remove(dbpath) if err != nil { logger.Logger.Fatal("failed to remove db file", zap.Error(err)) - panic(err) + return nil, err } } } db, err := gorm.Open(sqlite.Open(dbpath), &gorm.Config{}) if err != nil { logger.Logger.Error("failed to connect database", zap.Error(err)) - panic(err) + return nil, err } db.AutoMigrate(&BarData{}) @@ -58,7 +59,8 @@ func NewSQLiteDataBase(dbpath string, removeOldData bool) *DB { db, make(chan *BarData, dataBatchSize), nil, - } + nil, + }, nil } func (db *DB) fetchAll() { @@ -98,9 +100,28 @@ func (db *DB) FetchAll(bg bool) int64 { } func (db *DB) Next() (*BarData, error) { + if db.peekData != nil { + defer func() { + db.peekData = nil + }() + return db.peekData, nil + } data, ok := <-db.dataChan if !ok { return nil, db.err } return data, nil } + +func (db *DB) Peek() (*BarData, error) { + if db.peekData == nil { + data, err := db.Next() + if err != nil { + return nil, err + } + db.peekData = data + return data, nil + } else { + return db.peekData, nil + } +} diff --git a/pkg/db/model_test.go b/pkg/db/model_test.go index 7605098..a730bc3 100644 --- a/pkg/db/model_test.go +++ b/pkg/db/model_test.go @@ -17,10 +17,16 @@ func TestDBOpen(t *testing.T) { defer os.Remove(file.Name()) file.Close() - db := NewSQLiteDataBase(file.Name(), true) + db, err := NewSQLiteDataBase(file.Name(), true) + if err != nil { + t.Fatal(err) + } assert.NotNil(t, db) os.Remove("/tmp/test.999.db") - db2 := NewSQLiteDataBase("/tmp/test.999.db", false) + db2, err := NewSQLiteDataBase("/tmp/test.999.db", false) + if err != nil { + t.Fatal(err) + } assert.NotNil(t, db2) }