From 16f66f3c4f64a43c6bb25deffb648be9c0d3d986 Mon Sep 17 00:00:00 2001 From: Bill Moran Date: Fri, 10 May 2024 16:37:16 -0400 Subject: [PATCH] Disentangle CLI logic from libraries --- lib/config.go | 42 ++ lib/constants.go | 18 - lib/dbsteward.go | 656 ---------------- lib/extensions.go | 6 +- lib/format/pgsql8/constraint.go | 10 +- lib/format/pgsql8/diff.go | 179 +++-- lib/format/pgsql8/diff_constraints.go | 40 +- lib/format/pgsql8/diff_constraints_test.go | 29 +- lib/format/pgsql8/diff_functions.go | 8 +- lib/format/pgsql8/diff_languages.go | 20 +- lib/format/pgsql8/diff_sequences.go | 4 +- lib/format/pgsql8/diff_tables.go | 82 +- .../pgsql8/diff_tables_escape_char_test.go | 7 +- lib/format/pgsql8/diff_tables_test.go | 50 +- lib/format/pgsql8/diff_types.go | 6 +- lib/format/pgsql8/diff_types_domains_test.go | 13 +- lib/format/pgsql8/diff_types_test.go | 8 +- lib/format/pgsql8/diff_views.go | 14 +- lib/format/pgsql8/diff_views_test.go | 6 +- lib/format/pgsql8/function.go | 8 +- lib/format/pgsql8/language.go | 4 +- lib/format/pgsql8/oneeighty_test.go | 6 +- lib/format/pgsql8/operations.go | 86 +-- .../operations_column_value_default_test.go | 8 +- .../pgsql8/operations_extract_schema_test.go | 17 +- lib/format/pgsql8/pgsql8.go | 35 + lib/format/pgsql8/pgsql8_main_test.go | 7 +- lib/format/pgsql8/schema.go | 10 +- lib/format/pgsql8/sequence.go | 10 +- lib/format/pgsql8/table.go | 14 +- lib/format/pgsql8/table_test.go | 3 +- lib/format/pgsql8/type.go | 4 +- lib/format/pgsql8/view.go | 8 +- lib/format/pgsql8/xml_parser_test.go | 3 +- lib/loghandler.go | 74 -- lib/slonik.go | 5 + main.go | 700 +++++++++++++++++- xmlpostgresintegration_test.go | 10 +- 38 files changed, 1107 insertions(+), 1103 deletions(-) create mode 100644 lib/config.go delete mode 100644 lib/dbsteward.go delete mode 100644 lib/loghandler.go diff --git a/lib/config.go b/lib/config.go new file mode 100644 index 0000000..70676b0 --- /dev/null +++ b/lib/config.go @@ -0,0 +1,42 @@ +package lib + +import ( + "log/slog" + + "github.com/dbsteward/dbsteward/lib/ir" +) + +// Config is a structure containing all configuration information +// for any execution of code. 
+type Config struct { + Logger *slog.Logger + SqlFormat ir.SqlFormat + CreateLanguages bool + RequireSlonyId bool + RequireSlonySetId bool + GenerateSlonik bool + SlonyIdStartValue uint + SlonyIdSetValue uint + OutputFileStatementLimit uint + IgnoreCustomRoles bool + IgnorePrimaryKeyErrors bool + RequireVerboseIntervalNotation bool + QuoteSchemaNames bool + QuoteObjectNames bool + QuoteTableNames bool + QuoteFunctionNames bool + QuoteColumnNames bool + QuoteAllNames bool + QuoteIllegalIdentifiers bool + QuoteReservedIdentifiers bool + OnlySchemaSql bool + OnlyDataSql bool + LimitToTables map[string][]string + SingleStageUpgrade bool + FileOutputDirectory string + FileOutputPrefix string + IgnoreOldNames bool + AlwaysRecreateViews bool + OldDatabase *ir.Definition + NewDatabase *ir.Definition +} diff --git a/lib/constants.go b/lib/constants.go index 1ca2722..77be077 100644 --- a/lib/constants.go +++ b/lib/constants.go @@ -2,22 +2,4 @@ package lib import "github.com/dbsteward/dbsteward/lib/ir" -type Mode uint - -const ( - ModeUnknown Mode = 0 - ModeXmlDataInsert Mode = 1 - ModeXmlSort Mode = 2 - ModeXmlConvert Mode = 4 - ModeBuild Mode = 8 - ModeDiff Mode = 16 - ModeExtract Mode = 32 - ModeDbDataDiff Mode = 64 - ModeXmlSlonyId Mode = 73 - ModeSqlDiff Mode = 128 - ModeSlonikConvert Mode = 256 - ModeSlonyCompare Mode = 512 - ModeSlonyDiff Mode = 1024 -) - const DefaultSqlFormat = ir.SqlFormatPgsql8 diff --git a/lib/dbsteward.go b/lib/dbsteward.go deleted file mode 100644 index a0407f7..0000000 --- a/lib/dbsteward.go +++ /dev/null @@ -1,656 +0,0 @@ -package lib - -import ( - "fmt" - "log" - "log/slog" - "os" - "path" - "strings" - - "github.com/dbsteward/dbsteward/lib/config" - "github.com/dbsteward/dbsteward/lib/encoding/xml" - "github.com/dbsteward/dbsteward/lib/ir" - "github.com/dbsteward/dbsteward/lib/util" - "github.com/hashicorp/go-multierror" - - "github.com/alexflint/go-arg" - "github.com/rs/zerolog" -) - -type SlonyOperations interface { - SlonyCompare(file string) - SlonyDiff(oldFile, newFile string) -} - -// NOTE: 2.0.0 is the intended golang release. 3.0.0 is the intended refactor/modernization -const Version = "2.0.0" - -// NOTE: we're attempting to maintain "api" compat with legacy dbsteward for now -const ApiVersion = "1.4" - -type DBSteward struct { - logger zerolog.Logger - slogLogger *slog.Logger - - SqlFormat ir.SqlFormat - - CreateLanguages bool - requireSlonyId bool - requireSlonySetId bool - GenerateSlonik bool - slonyIdStartValue uint - slonyIdSetValue uint - OutputFileStatementLimit uint - IgnoreCustomRoles bool - ignorePrimaryKeyErrors bool - RequireVerboseIntervalNotation bool - QuoteSchemaNames bool - QuoteObjectNames bool - QuoteTableNames bool - QuoteFunctionNames bool - QuoteColumnNames bool - QuoteAllNames bool - QuoteIllegalIdentifiers bool - QuoteReservedIdentifiers bool - OnlySchemaSql bool - OnlyDataSql bool - LimitToTables map[string][]string - SingleStageUpgrade bool - fileOutputDirectory string - fileOutputPrefix string - IgnoreOldNames bool - AlwaysRecreateViews bool - - // TODO(go,3) just pass these explicitly! 
- OldDatabase *ir.Definition - NewDatabase *ir.Definition -} - -func NewDBSteward() *DBSteward { - dbsteward := &DBSteward{ - logger: zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}).With().Timestamp().Logger(), - - SqlFormat: ir.SqlFormatUnknown, - - CreateLanguages: false, - requireSlonyId: false, - requireSlonySetId: false, - GenerateSlonik: false, - slonyIdStartValue: 1, - slonyIdSetValue: 1, - OutputFileStatementLimit: 900, - IgnoreCustomRoles: false, - ignorePrimaryKeyErrors: false, - RequireVerboseIntervalNotation: false, - QuoteSchemaNames: false, - QuoteObjectNames: false, - QuoteTableNames: false, - QuoteFunctionNames: false, - QuoteColumnNames: false, - QuoteAllNames: false, - QuoteIllegalIdentifiers: false, - QuoteReservedIdentifiers: false, - OnlySchemaSql: false, - OnlyDataSql: false, - LimitToTables: map[string][]string{}, - SingleStageUpgrade: false, - fileOutputDirectory: "", - fileOutputPrefix: "", - IgnoreOldNames: false, - AlwaysRecreateViews: true, - - OldDatabase: nil, - NewDatabase: nil, - } - - return dbsteward -} - -// correlates to dbsteward->arg_parse() -func (dbsteward *DBSteward) ArgParse() { - // TODO(go,nth): deck this out with better go-arg config - args := &config.Args{} - arg.MustParse(args) - - dbsteward.setVerbosity(args) - - // XML file parameter sanity checks - if len(args.XmlFiles) > 0 { - if len(args.OldXmlFiles) > 0 { - dbsteward.fatal("Parameter error: xml and oldxml options are not to be mixed. Did you mean newxml?") - } - if len(args.NewXmlFiles) > 0 { - dbsteward.fatal("Parameter error: xml and newxml options are not to be mixed. Did you mean oldxml?") - } - } - if len(args.OldXmlFiles) > 0 && len(args.NewXmlFiles) == 0 { - dbsteward.fatal("Parameter error: oldxml needs newxml specified for differencing to occur") - } - if len(args.NewXmlFiles) > 0 && len(args.OldXmlFiles) == 0 { - dbsteward.fatal("Parameter error: oldxml needs newxml specified for differencing to occur") - } - - // database connectivity values - // dbsteward.dbHost = args.DbHost - // dbsteward.dbPort = args.DbPort - // dbsteward.dbName = args.DbName - // dbsteward.dbUser = args.DbUser - // dbsteward.dbPass = args.DbPassword - - // SQL DDL DML DCL output flags - dbsteward.OnlySchemaSql = args.OnlySchemaSql - dbsteward.OnlyDataSql = args.OnlyDataSql - for _, onlyTable := range args.OnlyTables { - table := ParseQualifiedTableName(onlyTable) - dbsteward.LimitToTables[table.Schema] = append(dbsteward.LimitToTables[table.Schema], table.Table) - } - - // XML parsing switches - dbsteward.SingleStageUpgrade = args.SingleStageUpgrade - if dbsteward.SingleStageUpgrade { - // don't recreate views when in single stage upgrade mode - // TODO(feat) make view diffing smart enough that this doesn't need to be done - dbsteward.AlwaysRecreateViews = false - } - dbsteward.IgnoreOldNames = args.IgnoreOldNames - dbsteward.IgnoreCustomRoles = args.IgnoreCustomRoles - dbsteward.ignorePrimaryKeyErrors = args.IgnorePrimaryKeyErrors - dbsteward.requireSlonyId = args.RequireSlonyId - dbsteward.requireSlonySetId = args.RequireSlonySetId - dbsteward.GenerateSlonik = args.GenerateSlonik - dbsteward.slonyIdStartValue = args.SlonyIdStartValue - dbsteward.slonyIdSetValue = args.SlonyIdSetValue - - // determine operation and check arguments for each - mode := ModeUnknown - switch { - case len(args.XmlDataInsert) > 0: - mode = ModeXmlDataInsert - case len(args.XmlSort) > 0: - mode = ModeXmlSort - case len(args.XmlConvert) > 0: - mode = ModeXmlConvert - case len(args.XmlFiles) > 0: - mode = ModeBuild - case 
len(args.NewXmlFiles) > 0: - mode = ModeDiff - case args.DbSchemaDump: - mode = ModeExtract - case len(args.DbDataDiff) > 0: - mode = ModeDbDataDiff - case len(args.OldSql) > 0 || len(args.NewSql) > 0: - mode = ModeSqlDiff - case len(args.SlonikConvert) > 0: - mode = ModeSlonikConvert - case len(args.SlonyCompare) > 0: - mode = ModeSlonyCompare - case len(args.SlonyDiffOld) > 0: - mode = ModeSlonyDiff - case len(args.SlonyIdIn) > 0: - mode = ModeXmlSlonyId - } - - // validate mode parameters - if mode == ModeXmlDataInsert { - if len(args.XmlFiles) == 0 { - dbsteward.fatal("xmldatainsert needs xml parameter defined") - } else if len(args.XmlFiles) > 1 { - dbsteward.fatal("xmldatainsert only supports one xml file") - } - } - if mode == ModeExtract || mode == ModeDbDataDiff { - if len(args.DbHost) == 0 { - dbsteward.fatal("dbhost not specified") - } - if len(args.DbName) == 0 { - dbsteward.fatal("dbname not specified") - } - if len(args.DbUser) == 0 { - dbsteward.fatal("dbuser not specified") - } - if args.DbPassword == nil { - p, err := util.PromptPassword("[DBSteward] Enter password for postgres://%s@%s:%d/%s: ", args.DbUser, args.DbHost, args.DbPort, args.DbName) - dbsteward.fatalIfError(err, "Could not read password input") - args.DbPassword = &p - } - } - if mode == ModeExtract || mode == ModeSqlDiff { - if len(args.OutputFile) == 0 { - dbsteward.fatal("output file not specified") - } - } - if mode == ModeXmlSlonyId { - if len(args.SlonyIdOut) > 0 { - if args.SlonyIdIn[0] == args.SlonyIdOut { - // TODO(go,nth) resolve filepaths to do this correctly - // TODO(go,nth) check all SlonyIdIn elements - dbsteward.fatal("slonyidin and slonyidout file paths should not be the same") - } - } - } - - if len(args.OutputDir) > 0 { - if !util.IsDir(args.OutputDir) { - dbsteward.fatal("outputdir is not a directory, must be a writable directory") - } - dbsteward.fileOutputDirectory = args.OutputDir - } - dbsteward.fileOutputPrefix = args.OutputFilePrefix - - if args.XmlCollectDataAddendums > 0 { - if mode != ModeDbDataDiff { - dbsteward.fatal("--xmlcollectdataaddendums is only supported for fresh builds") - } - // dammit go - // invalid operation: args.XmlCollectDataAddendums > len(args.XmlFiles) (mismatched types uint and int) - if int(args.XmlCollectDataAddendums) > len(args.XmlFiles) { - dbsteward.fatal("Cannot collect more data addendums than files provided") - } - } - - dbsteward.Info("DBSteward Version %s", Version) - - // set the global sql format - dbsteward.SqlFormat = dbsteward.reconcileSqlFormat(ir.SqlFormatPgsql8, args.SqlFormat) - dbsteward.Info("Using sqlformat=%s", dbsteward.SqlFormat) - dbsteward.defineSqlFormatDefaultValues(dbsteward.SqlFormat, args) - - dbsteward.QuoteSchemaNames = args.QuoteSchemaNames - dbsteward.QuoteTableNames = args.QuoteTableNames - dbsteward.QuoteColumnNames = args.QuoteColumnNames - dbsteward.QuoteAllNames = args.QuoteAllNames - dbsteward.QuoteIllegalIdentifiers = args.QuoteIllegalNames - dbsteward.QuoteReservedIdentifiers = args.QuoteReservedNames - - // TODO(go,3) move all of these to separate subcommands - switch mode { - case ModeXmlDataInsert: - dbsteward.doXmlDataInsert(args.XmlFiles[0], args.XmlDataInsert) - case ModeXmlSort: - dbsteward.doXmlSort(args.XmlSort) - case ModeXmlConvert: - dbsteward.doXmlConvert(args.XmlConvert) - case ModeXmlSlonyId: - dbsteward.doXmlSlonyId(args.SlonyIdIn, args.SlonyIdOut) - case ModeBuild: - dbsteward.doBuild(args.XmlFiles, args.PgDataXml, args.XmlCollectDataAddendums) - case ModeDiff: - dbsteward.doDiff(args.OldXmlFiles, 
args.NewXmlFiles, args.PgDataXml) - case ModeExtract: - dbsteward.doExtract(args.DbHost, args.DbPort, args.DbName, args.DbUser, *args.DbPassword, args.OutputFile) - case ModeDbDataDiff: - dbsteward.doDbDataDiff(args.XmlFiles, args.PgDataXml, args.XmlCollectDataAddendums, args.DbHost, args.DbPort, args.DbName, args.DbUser, *args.DbPassword) - case ModeSqlDiff: - dbsteward.doSqlDiff(args.OldSql, args.NewSql, args.OutputFile) - case ModeSlonikConvert: - dbsteward.doSlonikConvert(args.SlonikConvert, args.OutputFile) - case ModeSlonyCompare: - dbsteward.doSlonyCompare(args.SlonyCompare) - case ModeSlonyDiff: - dbsteward.doSlonyDiff(args.SlonyDiffOld, args.SlonyDiffNew) - default: - dbsteward.fatal("No operation specified") - } -} - -// Logger returns an *slog.Logger pointed at the console -func (dbsteward *DBSteward) Logger() *slog.Logger { - if dbsteward == nil { - panic("dbsteward is nil") - } - if dbsteward.slogLogger == nil { - dbsteward.slogLogger = slog.New(newLogHandler(dbsteward)) - } - return dbsteward.slogLogger -} - -func (dbsteward *DBSteward) fatal(s string, args ...interface{}) { - dbsteward.logger.Fatal().Msgf(s, args...) -} -func (dbsteward *DBSteward) fatalIfError(err error, s string, args ...interface{}) { - if err != nil { - dbsteward.logger.Fatal().Err(err).Msgf(s, args...) - } -} - -func (dbsteward *DBSteward) warning(s string, args ...interface{}) { - dbsteward.logger.Warn().Msgf(s, args...) -} - -func (dbsteward *DBSteward) Info(s string, args ...interface{}) { - dbsteward.logger.Info().Msgf(s, args...) -} - -// dbsteward::set_verbosity($options) -func (dbsteward *DBSteward) setVerbosity(args *config.Args) { - // TODO(go,nth): differentiate between notice and info - - // remember, lower level is higher verbosity - // we're abusing the fact that zerolog.LogLevel is defined as an int8 - level := zerolog.InfoLevel - - if args.Debug { - level = zerolog.TraceLevel - } - - for _, v := range args.Verbose { - if v { - level -= 1 - } else { - level += 1 - } - } - for _, q := range args.Quiet { - if q { - level += 1 - } else { - level -= 1 - } - } - - // clamp it to valid values - if level > zerolog.PanicLevel { - level = zerolog.PanicLevel - } - if level < zerolog.TraceLevel { - level = zerolog.TraceLevel - } - - dbsteward.logger = dbsteward.logger.Level(level) -} - -func (dbsteward *DBSteward) reconcileSqlFormat(target, requested ir.SqlFormat) ir.SqlFormat { - if target != ir.SqlFormatUnknown { - if requested != ir.SqlFormatUnknown { - if target == requested { - return target - } - - dbsteward.warning("XML is targeted for %s but you are forcing %s. 
Things will probably break!", target, requested) - return requested - } - - dbsteward.Info("XML file(s) are targetd for sqlformat=%s", target) - return target - } - - if requested != ir.SqlFormatUnknown { - return requested - } - - return DefaultSqlFormat -} - -func (dbsteward *DBSteward) defineSqlFormatDefaultValues(SqlFormat ir.SqlFormat, args *config.Args) { - switch SqlFormat { - case ir.SqlFormatPgsql8: - dbsteward.CreateLanguages = true - dbsteward.QuoteSchemaNames = false - dbsteward.QuoteTableNames = false - dbsteward.QuoteColumnNames = false - if args.DbPort == 0 { - args.DbPort = 5432 - } - } - - if SqlFormat != ir.SqlFormatPgsql8 { - if len(args.PgDataXml) > 0 { - dbsteward.fatal("pgdataxml parameter is not supported by %s driver", SqlFormat) - } - } -} - -func (dbsteward *DBSteward) calculateFileOutputPrefix(files []string) string { - return path.Join( - dbsteward.calculateFileOutputDirectory(files[0]), - util.CoalesceStr(dbsteward.fileOutputPrefix, util.Basename(files[0], ".xml")), - ) -} -func (dbsteward *DBSteward) calculateFileOutputDirectory(file string) string { - return util.CoalesceStr(dbsteward.fileOutputDirectory, path.Dir(file)) -} - -// Append columns in a table's rows collection, based on a simplified XML definition of what to insert -func (dbsteward *DBSteward) doXmlDataInsert(defFile string, dataFile string) { - // TODO(go,xmlutil) verify this behavior is correct, add tests. need to change fatals to returns - dbsteward.Info("Automatic insert data into %s from %s", defFile, dataFile) - defDoc, err := xml.LoadDefintion(defFile) - dbsteward.fatalIfError(err, "Failed to load %s", defFile) - - dataDoc, err := xml.LoadDefintion(dataFile) - dbsteward.fatalIfError(err, "Failed to load %s", dataFile) - - for _, dataSchema := range dataDoc.Schemas { - defSchema, err := defDoc.GetSchemaNamed(dataSchema.Name) - dbsteward.fatalIfError(err, "while searching %s", defFile) - for _, dataTable := range dataSchema.Tables { - defTable, err := defSchema.GetTableNamed(dataTable.Name) - dbsteward.fatalIfError(err, "while searching %s", defFile) - - dataRows := dataTable.Rows - if dataRows == nil { - dbsteward.fatal("table %s in %s does not have a element", dataTable.Name, dataFile) - } - - if len(dataRows.Columns) == 0 { - dbsteward.fatal("Unexpected: no rows[columns] found in table %s in file %s", dataTable.Name, dataFile) - } - - if len(dataRows.Rows) > 1 { - dbsteward.fatal("Unexpected: more than one rows->row found in table %s in file %s", dataTable.Name, dataFile) - } - - if len(dataRows.Rows[0].Columns) != len(dataRows.Columns) { - dbsteward.fatal("Unexpected: Table %s in %s defines %d colums but has %d elements", - dataTable.Name, dataFile, len(dataRows.Columns), len(dataRows.Rows[0].Columns)) - } - - for i, newColumn := range dataRows.Columns { - dbsteward.Info("Adding rows column %s to definition table %s", newColumn, defTable.Name) - - if defTable.Rows == nil { - defTable.Rows = &ir.DataRows{} - } - err = defTable.Rows.AddColumn(newColumn, dataRows.Columns[i]) - dbsteward.fatalIfError(err, "Could not add column %s to %s in %s", newColumn, dataTable.Name, dataFile) - } - } - } - - defFileModified := defFile + ".xmldatainserted" - dbsteward.Info("Saving modified dbsteward definition as %s", defFileModified) - err = xml.SaveDefinition(dbsteward.Logger(), defFileModified, defDoc) - dbsteward.fatalIfError(err, "saving file") -} -func (dbsteward *DBSteward) doXmlSort(files []string) { - for _, file := range files { - sortedFileName := file + ".xmlsorted" - dbsteward.Info("Sorting 
XML definition file: %s", file) - dbsteward.Info("Sorted XML output file: %s", sortedFileName) - xml.FileSort(file, sortedFileName) - } -} -func (dbsteward *DBSteward) doXmlConvert(files []string) { - for _, file := range files { - convertedFileName := file + ".xmlconverted" - dbsteward.Info("Upconverting XML definition file: %s", file) - dbsteward.Info("Upconvert XML output file: %s", convertedFileName) - - doc, err := xml.LoadDefintion(file) - dbsteward.fatalIfError(err, "Could not load %s", file) - xml.SqlFormatConvert(doc) - convertedXml, err := xml.FormatXml(dbsteward.Logger(), doc) - dbsteward.fatalIfError(err, "formatting xml") - convertedXml = strings.Replace(convertedXml, "pgdbxml>", "dbsteward>", -1) - err = util.WriteFile(convertedXml, convertedFileName) - dbsteward.fatalIfError(err, "Could not write converted xml to %s", convertedFileName) - } -} -func (dbsteward *DBSteward) doXmlSlonyId(files []string, slonyOut string) { - dbsteward.Info("Compositing XML file for Slony ID processing") - dbDoc, err := xml.XmlComposite(dbsteward.Logger(), files) - dbsteward.fatalIfError(err, "compositing files: %v", files) - dbsteward.Info("Xml files %s composited", strings.Join(files, " ")) - - outputPrefix := dbsteward.calculateFileOutputPrefix(files) - compositeFile := outputPrefix + "_composite.xml" - dbsteward.Info("Saving composite as %s", compositeFile) - err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, dbDoc) - dbsteward.fatalIfError(err, "saving file") - - dbsteward.Info("Slony ID numbering any missing attributes") - dbsteward.Info("slonyidstartvalue = %d", dbsteward.slonyIdStartValue) - dbsteward.Info("slonyidsetvalue = %d", dbsteward.slonyIdSetValue) - slonyIdDoc := xml.SlonyIdNumber(dbDoc) - slonyIdNumberedFile := outputPrefix + "_slonyid_numbered.xml" - if len(slonyOut) > 0 { - slonyIdNumberedFile = slonyOut - } - dbsteward.Info("Saving Slony ID numbered XML as %s", slonyIdNumberedFile) - err = xml.SaveDefinition(dbsteward.Logger(), slonyIdNumberedFile, slonyIdDoc) - dbsteward.fatalIfError(err, "saving file") -} -func (dbsteward *DBSteward) doBuild(files []string, dataFiles []string, addendums uint) { - dbsteward.Info("Compositing XML files...") - if addendums > 0 { - dbsteward.Info("Collecting %d data addendums", addendums) - } - dbDoc, addendumsDoc, err := xml.XmlCompositeAddendums(dbsteward.Logger(), files, addendums) - if err != nil { - mErr, isMErr := err.(*multierror.Error) - if isMErr { - for _, e := range mErr.Errors { - log.Println(e.Error()) - } - } else { - log.Println(err.Error()) - } - os.Exit(1) - } - if len(dataFiles) > 0 { - dbsteward.Info("Compositing pgdata XML files on top of XML composite...") - xml.XmlCompositePgData(dbDoc, dataFiles) - dbsteward.Info("postgres data XML files [%s] composited", strings.Join(dataFiles, " ")) - } - - dbsteward.Info("XML files %s composited", strings.Join(files, " ")) - - outputPrefix := dbsteward.calculateFileOutputPrefix(files) - compositeFile := outputPrefix + "_composite.xml" - dbsteward.Info("Saving composite as %s", compositeFile) - err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, dbDoc) - dbsteward.fatalIfError(err, "saving file") - - if addendumsDoc != nil { - addendumsFile := outputPrefix + "_addendums.xml" - dbsteward.Info("Saving addendums as %s", addendumsFile) - err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, addendumsDoc) - dbsteward.fatalIfError(err, "saving file") - } - - ops, err := Format(DefaultSqlFormat) - dbsteward.fatalIfError(err, "loading default format") - err = 
ops(dbsteward).Build(outputPrefix, dbDoc) - dbsteward.fatalIfError(err, "building") -} -func (dbsteward *DBSteward) doDiff(oldFiles []string, newFiles []string, dataFiles []string) { - dbsteward.Info("Compositing old XML files...") - oldDbDoc, err := xml.XmlComposite(dbsteward.Logger(), oldFiles) - dbsteward.fatalIfError(err, "compositing") - dbsteward.Info("Old XML files %s composited", strings.Join(oldFiles, " ")) - - dbsteward.Info("Compositing new XML files...") - newDbDoc, err := xml.XmlComposite(dbsteward.Logger(), newFiles) - dbsteward.fatalIfError(err, "compositing") - if len(dataFiles) > 0 { - dbsteward.Info("Compositing pgdata XML files on top of new XML composite...") - xml.XmlCompositePgData(newDbDoc, dataFiles) - dbsteward.Info("postgres data XML files [%s] composited", strings.Join(dataFiles, " ")) - } - dbsteward.Info("New XML files %s composited", strings.Join(newFiles, " ")) - - oldOutputPrefix := dbsteward.calculateFileOutputPrefix(oldFiles) - oldCompositeFile := oldOutputPrefix + "_composite.xml" - dbsteward.Info("Saving composite as %s", oldCompositeFile) - err = xml.SaveDefinition(dbsteward.Logger(), oldCompositeFile, oldDbDoc) - dbsteward.fatalIfError(err, "saving file") - - newOutputPrefix := dbsteward.calculateFileOutputPrefix(newFiles) - newCompositeFile := newOutputPrefix + "_composite.xml" - dbsteward.Info("Saving composite as %s", newCompositeFile) - err = xml.SaveDefinition(dbsteward.Logger(), newCompositeFile, newDbDoc) - dbsteward.fatalIfError(err, "saving file") - - ops, err := Format(DefaultSqlFormat) - dbsteward.fatalIfError(err, "loading default format") - err = ops(dbsteward).BuildUpgrade( - oldOutputPrefix, oldCompositeFile, oldDbDoc, oldFiles, - newOutputPrefix, newCompositeFile, newDbDoc, newFiles, - ) - dbsteward.fatalIfError(err, "building upgrade") -} -func (dbsteward *DBSteward) doExtract(dbHost string, dbPort uint, dbName, dbUser, dbPass string, outputFile string) { - ops, err := Format(DefaultSqlFormat) - dbsteward.fatalIfError(err, "loading default format") - output, err := ops(dbsteward).ExtractSchema(dbHost, dbPort, dbName, dbUser, dbPass) - dbsteward.fatalIfError(err, "extracting") - dbsteward.Info("Saving extracted database schema to %s", outputFile) - err = xml.SaveDefinition(dbsteward.Logger(), outputFile, output) - dbsteward.fatalIfError(err, "saving file") -} -func (dbsteward *DBSteward) doDbDataDiff(files []string, dataFiles []string, addendums uint, dbHost string, dbPort uint, dbName, dbUser, dbPass string) { - dbsteward.Info("Compositing XML files...") - if addendums > 0 { - dbsteward.Info("Collecting %d data addendums", addendums) - } - // TODO(feat) can this just be XmlComposite(files)? why do we need addendums? 
- dbDoc, _, err := xml.XmlCompositeAddendums(dbsteward.Logger(), files, addendums) - dbsteward.fatalIfError(err, "compositing addendums") - - if len(dataFiles) > 0 { - dbsteward.Info("Compositing pgdata XML files on top of XML composite...") - xml.XmlCompositePgData(dbDoc, dataFiles) - dbsteward.Info("postgres data XML files [%s] composited", strings.Join(dataFiles, " ")) - } - - dbsteward.Info("XML files %s composited", strings.Join(files, " ")) - - outputPrefix := dbsteward.calculateFileOutputPrefix(files) - compositeFile := outputPrefix + "_composite.xml" - dbsteward.Info("Saving composite as %s", compositeFile) - err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, dbDoc) - dbsteward.fatalIfError(err, "saving file") - - ops, err := Format(DefaultSqlFormat) - dbsteward.fatalIfError(err, "loading default format") - output, err := ops(dbsteward).CompareDbData(dbDoc, dbHost, dbPort, dbName, dbUser, dbPass) - dbsteward.fatalIfError(err, "comparing data") - err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, output) - dbsteward.fatalIfError(err, "saving file") -} -func (dbsteward *DBSteward) doSqlDiff(oldSql, newSql []string, outputFile string) { - ops, err := Format(DefaultSqlFormat) - dbsteward.fatalIfError(err, "loading default format") - ops(dbsteward).SqlDiff(oldSql, newSql, outputFile) -} -func (dbsteward *DBSteward) doSlonikConvert(file string, outputFile string) { - // TODO(go,nth) is there a nicer way to handle this output idiom? - output := NewSlonik().Convert(file) - if len(outputFile) > 0 { - err := util.WriteFile(output, outputFile) - dbsteward.fatalIfError(err, "Failed to save slonikconvert output to %s", outputFile) - } else { - fmt.Println(output) - } -} -func (dbsteward *DBSteward) doSlonyCompare(file string) { - ops, err := Format(DefaultSqlFormat) - dbsteward.fatalIfError(err, "loading default format") - ops(dbsteward).(SlonyOperations).SlonyCompare(file) -} -func (dbsteward *DBSteward) doSlonyDiff(oldFile string, newFile string) { - ops, err := Format(DefaultSqlFormat) - dbsteward.fatalIfError(err, "loading default format") - ops(dbsteward).(SlonyOperations).SlonyDiff(oldFile, newFile) -} diff --git a/lib/extensions.go b/lib/extensions.go index 50443d9..f1d1f18 100644 --- a/lib/extensions.go +++ b/lib/extensions.go @@ -24,17 +24,17 @@ type Operations interface { type Encoding interface { } -var formats = make(map[ir.SqlFormat]func(*DBSteward) Operations) +var formats = make(map[ir.SqlFormat]func(Config) Operations) var formatMutex sync.Mutex -func RegisterFormat(id ir.SqlFormat, constructor func(*DBSteward) Operations) { +func RegisterFormat(id ir.SqlFormat, constructor func(Config) Operations) { formatMutex.Lock() defer formatMutex.Unlock() formats[id] = constructor } -func Format(id ir.SqlFormat) (func(*DBSteward) Operations, error) { +func Format(id ir.SqlFormat) (func(Config) Operations, error) { formatMutex.Lock() defer formatMutex.Unlock() constructor, exists := formats[id] diff --git a/lib/format/pgsql8/constraint.go b/lib/format/pgsql8/constraint.go index b721ab0..01e648f 100644 --- a/lib/format/pgsql8/constraint.go +++ b/lib/format/pgsql8/constraint.go @@ -268,8 +268,8 @@ func getTableContraintCreationSql(constraint *sql99.TableConstraint) []output.To return nil } -func constraintDependsOnRenamedTable(dbs *lib.DBSteward, doc *ir.Definition, constraint *sql99.TableConstraint) (bool, error) { - if dbs.IgnoreOldNames { +func constraintDependsOnRenamedTable(conf lib.Config, doc *ir.Definition, constraint *sql99.TableConstraint) (bool, error) { + if 
conf.IgnoreOldNames { return false, nil } @@ -294,16 +294,16 @@ func constraintDependsOnRenamedTable(dbs *lib.DBSteward, doc *ir.Definition, con if refTable == nil { return false, nil } - isRenamed := dbs.IgnoreOldNames + isRenamed := conf.IgnoreOldNames if !isRenamed { var err error - isRenamed, err = dbs.OldDatabase.IsRenamedTable(slog.Default(), refSchema, refTable) + isRenamed, err = conf.OldDatabase.IsRenamedTable(slog.Default(), refSchema, refTable) if err != nil { return false, fmt.Errorf("while checking if constraint depends on renamed table: %w", err) } } if isRenamed { - dbs.Logger().Info(fmt.Sprintf("Constraint %s.%s.%s references renamed table %s.%s", constraint.Schema.Name, constraint.Table.Name, constraint.Name, refSchema.Name, refTable.Name)) + conf.Logger.Info(fmt.Sprintf("Constraint %s.%s.%s references renamed table %s.%s", constraint.Schema.Name, constraint.Table.Name, constraint.Name, refSchema.Name, refTable.Name)) return true, nil } return false, nil diff --git a/lib/format/pgsql8/diff.go b/lib/format/pgsql8/diff.go index 81b082d..f56eaf8 100644 --- a/lib/format/pgsql8/diff.go +++ b/lib/format/pgsql8/diff.go @@ -63,26 +63,26 @@ func (d *diff) DiffDocWork(stage1, stage2, stage3, stage4 output.OutputFileSegme // this shouldn't be called if we're not generating slonik, it looks for // a slony element in which most likely won't be there if // we're not interested in slony replication - if d.ops.dbsteward.GenerateSlonik { + if d.ops.config.GenerateSlonik { // TODO(go,slony) } // stage 1 and 3 should not be in a transaction as they will be submitted via slonik EXECUTE SCRIPT - if !d.ops.dbsteward.GenerateSlonik { + if !d.ops.config.GenerateSlonik { stage1.AppendHeader(output.NewRawSQL("\nBEGIN;\n\n")) stage1.AppendFooter(output.NewRawSQL("\nCOMMIT;\n")) } else { stage1.AppendHeader(sql.NewComment("generateslonik specified: pgsql8 STAGE1 upgrade omitting BEGIN. 
slonik EXECUTE SCRIPT will wrap stage 1 DDL and DCL in a transaction")) } - if !d.ops.dbsteward.SingleStageUpgrade { + if !d.ops.config.SingleStageUpgrade { stage2.AppendHeader(output.NewRawSQL("\nBEGIN;\n\n")) stage2.AppendFooter(output.NewRawSQL("\nCOMMIT;\n")) stage4.AppendHeader(output.NewRawSQL("\nBEGIN;\n\n")) stage4.AppendFooter(output.NewRawSQL("\nCOMMIT;\n")) // stage 1 and 3 should not be in a transaction as they will be submitted via slonik EXECUTE SCRIPT - if !d.ops.dbsteward.GenerateSlonik { + if !d.ops.config.GenerateSlonik { stage3.AppendHeader(output.NewRawSQL("\nBEGIN;\n\n")) stage3.AppendFooter(output.NewRawSQL("\nCOMMIT;\n")) } else { @@ -91,13 +91,13 @@ func (d *diff) DiffDocWork(stage1, stage2, stage3, stage4 output.OutputFileSegme } // start with pre-upgrade sql statements that prepare the database to take on its changes - buildStagedSql(d.ops.dbsteward.NewDatabase, stage1, "STAGE1BEFORE") - buildStagedSql(d.ops.dbsteward.NewDatabase, stage2, "STAGE2BEFORE") + buildStagedSql(d.ops.config.NewDatabase, stage1, "STAGE1BEFORE") + buildStagedSql(d.ops.config.NewDatabase, stage2, "STAGE2BEFORE") - d.ops.dbsteward.Logger().Info("Drop Old Schemas") + d.ops.config.Logger.Info("Drop Old Schemas") d.DropOldSchemas(stage3) - d.ops.dbsteward.Logger().Info("Create New Schemas") + d.ops.config.Logger.Info("Create New Schemas") err := d.CreateNewSchemas(stage1) if err != nil { return err @@ -108,16 +108,16 @@ func (d *diff) DiffDocWork(stage1, stage2, stage3, stage4 output.OutputFileSegme return err } - d.ops.dbsteward.Logger().Info("Update Permissions") + d.ops.config.Logger.Info("Update Permissions") err = d.updatePermissions(stage1, stage3) if err != nil { return err } - d.UpdateDatabaseConfigParameters(stage1, d.ops.dbsteward.NewDatabase, d.ops.dbsteward.OldDatabase) + d.UpdateDatabaseConfigParameters(stage1, d.ops.config.NewDatabase, d.ops.config.OldDatabase) - d.ops.dbsteward.Logger().Info("Update data") - if d.ops.dbsteward.GenerateSlonik { + d.ops.config.Logger.Info("Update data") + if d.ops.config.GenerateSlonik { // TODO(go,slony) format::set_context_replica_set_to_natural_first(dbsteward::$new_database); } err = d.updateData(stage2, true) @@ -131,14 +131,14 @@ func (d *diff) DiffDocWork(stage1, stage2, stage3, stage4 output.OutputFileSegme // append any literal sql in new not in old at the end of data stage 1 // TODO(feat) this relies on exact string match - is there a better way? 
- for _, newSql := range d.ops.dbsteward.NewDatabase.Sql { + for _, newSql := range d.ops.config.NewDatabase.Sql { // ignore upgrade staged sql elements if newSql.Stage != "" { continue } found := false - for _, oldSql := range d.ops.dbsteward.OldDatabase.Sql { + for _, oldSql := range d.ops.config.OldDatabase.Sql { // ignore upgrade staged sql elements if oldSql.Stage != "" { continue @@ -156,14 +156,14 @@ func (d *diff) DiffDocWork(stage1, stage2, stage3, stage4 output.OutputFileSegme } // append stage defined sql statements to appropriate stage file - if d.ops.dbsteward.GenerateSlonik { + if d.ops.config.GenerateSlonik { // TODO(go,slony) format::set_context_replica_set_to_natural_first(dbsteward::$new_database); } - buildStagedSql(d.ops.dbsteward.NewDatabase, stage1, "STAGE1") - buildStagedSql(d.ops.dbsteward.NewDatabase, stage2, "STAGE2") - buildStagedSql(d.ops.dbsteward.NewDatabase, stage3, "STAGE3") - buildStagedSql(d.ops.dbsteward.NewDatabase, stage4, "STAGE4") + buildStagedSql(d.ops.config.NewDatabase, stage1, "STAGE1") + buildStagedSql(d.ops.config.NewDatabase, stage2, "STAGE2") + buildStagedSql(d.ops.config.NewDatabase, stage3, "STAGE3") + buildStagedSql(d.ops.config.NewDatabase, stage4, "STAGE4") return nil } @@ -172,54 +172,50 @@ func (d *diff) DiffSql(old, new []string, upgradePrefix string) { } func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output.OutputFileSegmenter) error { - logger := d.ops.dbsteward.Logger() + logger := d.ops.config.Logger logger.Info("Update Structure") - err := diffLanguages(d.ops.dbsteward, stage1) + err := diffLanguages(d.ops.config, stage1) if err != nil { return err } // drop all views in all schemas, regardless whether dependency order is known or not // TODO(go,4) would be so cool if we could parse the view def and only recreate what's required - dropViewsOrdered(stage1, d.ops.dbsteward.OldDatabase, d.ops.dbsteward.NewDatabase) + dropViewsOrdered(stage1, d.ops.config.OldDatabase, d.ops.config.NewDatabase) // TODO(go,3) should we just always use table deps? 
if len(d.NewTableDependency) == 0 { logger.Debug("not using table dependencies") - for _, newSchema := range d.ops.dbsteward.NewDatabase.Schemas { - l := logger.With(slog.String("new schema", newSchema.Name)) - oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) - if oldSchema != nil { - l = l.With(slog.String("old schema", oldSchema.Name)) - } - err := diffTypes(d.ops.dbsteward, d, stage1, oldSchema, newSchema) + for _, newSchema := range d.ops.config.NewDatabase.Schemas { + oldSchema := d.ops.config.OldDatabase.TryGetSchemaNamed(newSchema.Name) + err := diffTypes(d.ops.config, d, stage1, oldSchema, newSchema) if err != nil { return err } - err = diffFunctions(d.ops.dbsteward, stage1, stage3, oldSchema, newSchema) + err = diffFunctions(d.ops.config, stage1, stage3, oldSchema, newSchema) if err != nil { return err } - err = diffSequences(d.ops.dbsteward, stage1, oldSchema, newSchema) + err = diffSequences(d.ops.config, stage1, oldSchema, newSchema) if err != nil { return fmt.Errorf("while diffing sequences: %w", err) } // remove old constraints before table constraints, so the sql statements succeed - err = dropConstraints(d.ops.dbsteward, stage1, oldSchema, newSchema, sql99.ConstraintTypeConstraint) + err = dropConstraints(d.ops.config, stage1, oldSchema, newSchema, sql99.ConstraintTypeConstraint) if err != nil { return err } - err = dropConstraints(d.ops.dbsteward, stage1, oldSchema, newSchema, sql99.ConstraintTypePrimaryKey) + err = dropConstraints(d.ops.config, stage1, oldSchema, newSchema, sql99.ConstraintTypePrimaryKey) if err != nil { return err } - dropTables(d.ops.dbsteward, stage1, oldSchema, newSchema) - err = createTables(d.ops.dbsteward, stage1, oldSchema, newSchema) + dropTables(d.ops.config, stage1, oldSchema, newSchema) + err = createTables(d.ops.config, stage1, oldSchema, newSchema) if err != nil { return fmt.Errorf("while creating tables: %w", err) } - err = diffTables(d.ops.dbsteward, stage1, stage3, oldSchema, newSchema) + err = diffTables(d.ops.config, stage1, stage3, oldSchema, newSchema) if err != nil { return fmt.Errorf("while diffing tables: %w", err) } @@ -228,7 +224,7 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. return err } diffClusters(stage1, oldSchema, newSchema) - createConstraints(d.ops.dbsteward, stage1, oldSchema, newSchema, sql99.ConstraintTypePrimaryKey) + createConstraints(d.ops.config, stage1, oldSchema, newSchema, sql99.ConstraintTypePrimaryKey) err = diffTriggers(stage1, oldSchema, newSchema) if err != nil { return err @@ -236,9 +232,9 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. } // non-primary key constraints may be inter-schema dependant, and dependant on other's primary keys // and therefore should be done after object creation sections - for _, newSchema := range d.ops.dbsteward.NewDatabase.Schemas { - oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) - createConstraints(d.ops.dbsteward, stage1, oldSchema, newSchema, sql99.ConstraintTypeConstraint) + for _, newSchema := range d.ops.config.NewDatabase.Schemas { + oldSchema := d.ops.config.OldDatabase.TryGetSchemaNamed(newSchema.Name) + createConstraints(d.ops.config, stage1, oldSchema, newSchema, sql99.ConstraintTypeConstraint) } } else { logger.Debug("using table dependencies") @@ -247,14 +243,14 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. 
processedSchemas := map[string]bool{} for _, newEntry := range d.NewTableDependency { newSchema := newEntry.Schema - oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) + oldSchema := d.ops.config.OldDatabase.TryGetSchemaNamed(newSchema.Name) if !processedSchemas[newSchema.Name] { - err := diffTypes(d.ops.dbsteward, d, stage1, oldSchema, newSchema) + err := diffTypes(d.ops.config, d, stage1, oldSchema, newSchema) if err != nil { return err } - err = diffFunctions(d.ops.dbsteward, stage1, stage3, oldSchema, newSchema) + err = diffFunctions(d.ops.config, stage1, stage3, oldSchema, newSchema) if err != nil { return err } @@ -268,7 +264,7 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. oldSchema := oldEntry.Schema oldTable := oldEntry.Table - newSchema := d.ops.dbsteward.NewDatabase.TryGetSchemaNamed(oldSchema.Name) + newSchema := d.ops.config.NewDatabase.TryGetSchemaNamed(oldSchema.Name) var newTable *ir.Table if newSchema != nil { newTable = newSchema.TryGetTableNamed(oldTable.Name) @@ -276,11 +272,11 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. // NOTE: when dropping constraints, GlobalDBX.RenamedTableCheckPointer() is not called for oldTable // as GlobalDiffConstraints.DiffConstraintsTable() will do rename checking when recreating constraints for renamed tables - err := dropConstraintsTable(d.ops.dbsteward, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeConstraint) + err := dropConstraintsTable(d.ops.config, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeConstraint) if err != nil { return err } - err = dropConstraintsTable(d.ops.dbsteward, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypePrimaryKey) + err = dropConstraintsTable(d.ops.config, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypePrimaryKey) if err != nil { return err } @@ -289,13 +285,13 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. processedSchemas = map[string]bool{} for _, newEntry := range d.NewTableDependency { newSchema := newEntry.Schema - oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) + oldSchema := d.ops.config.OldDatabase.TryGetSchemaNamed(newSchema.Name) // schema level stuff should only be done once, keep track of which ones we have done // see above for pre table creation stuff // see below for post table creation stuff if !processedSchemas[newSchema.Name] { - err := diffSequences(d.ops.dbsteward, stage1, oldSchema, newSchema) + err := diffSequences(d.ops.config, stage1, oldSchema, newSchema) if err != nil { return fmt.Errorf("while diffing sequences: %w", err) } @@ -313,15 +309,15 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. 
// when a table has an oldTableName oldSchemaName specified, // GlobalDBX.RenamedTableCheckPointer() will modify these pointers to be the old table var err error - oldSchema, oldTable, err = d.ops.dbsteward.OldDatabase.NewTableName(oldSchema, oldTable, newSchema, newTable) + oldSchema, oldTable, err = d.ops.config.OldDatabase.NewTableName(oldSchema, oldTable, newSchema, newTable) if err != nil { return fmt.Errorf("getting new table name: %w", err) } - err = createTable(d.ops.dbsteward, stage1, oldSchema, newSchema, newTable) + err = createTable(d.ops.config, stage1, oldSchema, newSchema, newTable) if err != nil { return fmt.Errorf("while creating table %s.%s: %w", newSchema.Name, newTable.Name, err) } - err = diffTable(d.ops.dbsteward, stage1, stage3, oldSchema, oldTable, newSchema, newTable) + err = diffTable(d.ops.config, stage1, stage3, oldSchema, oldTable, newSchema, newTable) if err != nil { return fmt.Errorf("while diffing table %s.%s: %w", newSchema.Name, newTable.Name, err) } @@ -330,7 +326,7 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. return err } diffClustersTable(stage1, oldTable, newSchema, newTable) - err = createConstraintsTable(d.ops.dbsteward, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypePrimaryKey) + err = createConstraintsTable(d.ops.config, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypePrimaryKey) if err != nil { return err } @@ -341,7 +337,7 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. // HACK: For now, we'll generate foreign key constraints in stage 4 in updateData below // https://github.com/dbsteward/dbsteward/issues/142 - err = createConstraintsTable(d.ops.dbsteward, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeConstraint&^sql99.ConstraintTypeForeign) + err = createConstraintsTable(d.ops.config, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeConstraint&^sql99.ConstraintTypeForeign) if err != nil { return err } @@ -353,30 +349,25 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. oldSchema := oldEntry.Schema oldTable := oldEntry.Table - newSchema := d.ops.dbsteward.NewDatabase.TryGetSchemaNamed(oldSchema.Name) - dropTable(d.ops.dbsteward, stage3, oldSchema, oldTable, newSchema) + newSchema := d.ops.config.NewDatabase.TryGetSchemaNamed(oldSchema.Name) + dropTable(d.ops.config, stage3, oldSchema, oldTable, newSchema) } } - return createViewsOrdered(d.ops.dbsteward, stage3, d.ops.dbsteward.OldDatabase, d.ops.dbsteward.NewDatabase) + return createViewsOrdered(d.ops.config, stage3, d.ops.config.OldDatabase, d.ops.config.NewDatabase) } func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 output.OutputFileSegmenter) error { // TODO(feat) what if readonly user changed? we need to rebuild those grants // TODO(feat) what about removed permissions, shouldn't we REVOKE those? 
- newDoc := d.ops.dbsteward.NewDatabase - oldDoc := d.ops.dbsteward.OldDatabase - logger := d.ops.dbsteward.Logger() + newDoc := d.ops.config.NewDatabase + oldDoc := d.ops.config.OldDatabase for _, newSchema := range newDoc.Schemas { - l := logger.With(slog.String("new schema", newSchema.Name)) oldSchema := oldDoc.TryGetSchemaNamed(newSchema.Name) - if oldSchema != nil { - l = l.With(slog.String("old schema", oldSchema.Name)) - } for _, newGrant := range newSchema.Grants { if oldSchema == nil || !ir.HasPermissionsOf(oldSchema, newGrant, ir.SqlFormatPgsql8) { - s, err := commonSchema.GetGrantSql(d.ops.dbsteward, newDoc, newSchema, newGrant) + s, err := commonSchema.GetGrantSql(d.ops.config, newDoc, newSchema, newGrant) if err != nil { return err } @@ -386,7 +377,7 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu for _, newTable := range newSchema.Tables { oldTable := oldSchema.TryGetTableNamed(newTable.Name) - isRenamed, err := d.ops.dbsteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) + isRenamed, err := d.ops.config.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) if err != nil { return fmt.Errorf("while updating permissions: %w", err) } @@ -397,7 +388,7 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu } for _, newGrant := range newTable.Grants { if oldTable == nil || !ir.HasPermissionsOf(oldTable, newGrant, ir.SqlFormatPgsql8) { - s, err := getTableGrantSql(d.ops.dbsteward, newSchema, newTable, newGrant) + s, err := getTableGrantSql(d.ops.config, newSchema, newTable, newGrant) if err != nil { return err } @@ -410,7 +401,7 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu oldSeq := oldSchema.TryGetSequenceNamed(newSeq.Name) for _, newGrant := range newSeq.Grants { if oldSeq == nil || !ir.HasPermissionsOf(oldSeq, newGrant, ir.SqlFormatPgsql8) { - s, err := getSequenceGrantSql(d.ops.dbsteward, newSchema, newSeq, newGrant) + s, err := getSequenceGrantSql(d.ops.config, newSchema, newSeq, newGrant) if err != nil { return err } @@ -423,7 +414,7 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu oldFunc := oldSchema.TryGetFunctionMatching(newFunc) for _, newGrant := range newFunc.Grants { if oldFunc == nil || !ir.HasPermissionsOf(oldFunc, newGrant, ir.SqlFormatPgsql8) { - grants, err := getFunctionGrantSql(d.ops.dbsteward, newSchema, newFunc, newGrant) + grants, err := getFunctionGrantSql(d.ops.config, newSchema, newFunc, newGrant) if err != nil { return err } @@ -435,8 +426,8 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu for _, newView := range newSchema.Views { oldView := oldSchema.TryGetViewNamed(newView.Name) for _, newGrant := range newView.Grants { - if d.ops.dbsteward.AlwaysRecreateViews || oldView == nil || !ir.HasPermissionsOf(oldView, newGrant, ir.SqlFormatPgsql8) || !oldView.Equals(newView, ir.SqlFormatPgsql8) { - s, err := getViewGrantSql(d.ops.dbsteward, newDoc, newSchema, newView, newGrant) + if d.ops.config.AlwaysRecreateViews || oldView == nil || !ir.HasPermissionsOf(oldView, newGrant, ir.SqlFormatPgsql8) || !oldView.Equals(newView, ir.SqlFormatPgsql8) { + s, err := getViewGrantSql(d.ops.config, newDoc, newSchema, newView, newGrant) if err != nil { return err } @@ -456,20 +447,20 @@ func (d *diff) updateData(ofs output.OutputFileSegmenter, deleteMode bool) error if deleteMode { item = d.NewTableDependency[len(d.NewTableDependency)-1-i] } - l := 
d.ops.dbsteward.Logger().With(slog.String("table", item.String())) + l := d.ops.config.Logger.With(slog.String("table", item.String())) newSchema := item.Schema newTable := item.Table - oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) + oldSchema := d.ops.config.OldDatabase.TryGetSchemaNamed(newSchema.Name) oldTable := oldSchema.TryGetTableNamed(newTable.Name) - isRenamed, err := d.ops.dbsteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) + isRenamed, err := d.ops.config.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) if err != nil { return fmt.Errorf("while updatign data: %w", err) } if isRenamed { l.Info(fmt.Sprintf("%s.%s used to be called %s - will diff data against that definition", newSchema.Name, newTable.Name, newTable.OldTableName)) - oldSchema = d.ops.dbsteward.OldDatabase.GetOldTableSchema(newSchema, newTable) - oldTable = d.ops.dbsteward.OldDatabase.GetOldTable(newSchema, newTable) + oldSchema = d.ops.config.OldDatabase.GetOldTableSchema(newSchema, newTable) + oldTable = d.ops.config.OldDatabase.GetOldTable(newSchema, newTable) } if deleteMode { @@ -489,7 +480,7 @@ func (d *diff) updateData(ofs output.OutputFileSegmenter, deleteMode bool) error // HACK: For now, we'll generate foreign key constraints in stage 4 after inserting data // https://github.com/dbsteward/dbsteward/issues/142 - err = createConstraintsTable(d.ops.dbsteward, ofs, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeForeign) + err = createConstraintsTable(d.ops.config, ofs, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeForeign) if err != nil { return err } @@ -498,8 +489,8 @@ func (d *diff) updateData(ofs output.OutputFileSegmenter, deleteMode bool) error } else { // dependency order unknown, hit them in natural order // TODO(feat) the above switches on deleteMode, this does not. we never delete data if table dep order is unknown? 
- for _, newSchema := range d.ops.dbsteward.NewDatabase.Schemas { - oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) + for _, newSchema := range d.ops.config.NewDatabase.Schemas { + oldSchema := d.ops.config.OldDatabase.TryGetSchemaNamed(newSchema.Name) return diffData(d.ops, ofs, oldSchema, newSchema) } } @@ -515,7 +506,7 @@ func (d *diff) DropSchemaSQL(s *ir.Schema) ([]output.ToSql, error) { // CreateSchemaSQL this implementation is a bit hacky as it's a // transitional step as I factor away global variables func (d *diff) CreateSchemaSQL(s *ir.Schema) ([]output.ToSql, error) { - return commonSchema.GetCreationSql(d.ops.dbsteward, s) + return commonSchema.GetCreationSql(d.ops.config, s) } func (diff *diff) DiffDoc(oldFile, newFile string, oldDoc, newDoc *ir.Definition, upgradePrefix string) error { @@ -524,46 +515,46 @@ func (diff *diff) DiffDoc(oldFile, newFile string, oldDoc, newDoc *ir.Definition var stage1, stage2, stage3, stage4 output.OutputFileSegmenter quoter := diff.Quoter() - logger := diff.ops.dbsteward.Logger() - if diff.ops.dbsteward.SingleStageUpgrade { + logger := diff.ops.config.Logger + if diff.ops.config.SingleStageUpgrade { fileName := upgradePrefix + "_single_stage.sql" file, err := os.Create(fileName) if err != nil { return fmt.Errorf("failed to open %s for write: %w", fileName, err) } - stage1 = output.NewOutputFileSegmenterToFile(logger, quoter, fileName, 1, file, fileName, diff.ops.dbsteward.OutputFileStatementLimit) + stage1 = output.NewOutputFileSegmenterToFile(logger, quoter, fileName, 1, file, fileName, diff.ops.config.OutputFileStatementLimit) stage1.SetHeader(sql.NewComment("DBsteward single stage upgrade changes - generated %s\n%s", timestamp, oldSetNewSet)) defer stage1.Close() stage2 = stage1 stage3 = stage1 stage4 = stage1 } else { - stage1 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage1_schema", 1, diff.ops.dbsteward.OutputFileStatementLimit) + stage1 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage1_schema", 1, diff.ops.config.OutputFileStatementLimit) stage1.SetHeader(sql.NewComment("DBSteward stage 1 structure additions and modifications - generated %s\n%s", timestamp, oldSetNewSet)) defer stage1.Close() - stage2 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage2_data", 1, diff.ops.dbsteward.OutputFileStatementLimit) + stage2 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage2_data", 1, diff.ops.config.OutputFileStatementLimit) stage2.SetHeader(sql.NewComment("DBSteward stage 2 data definitions removed - generated %s\n%s", timestamp, oldSetNewSet)) defer stage2.Close() - stage3 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage3_schema", 1, diff.ops.dbsteward.OutputFileStatementLimit) + stage3 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage3_schema", 1, diff.ops.config.OutputFileStatementLimit) stage3.SetHeader(sql.NewComment("DBSteward stage 3 structure changes, constraints, and removals - generated %s\n%s", timestamp, oldSetNewSet)) defer stage3.Close() - stage4 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage4_data", 1, diff.ops.dbsteward.OutputFileStatementLimit) + stage4 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage4_data", 1, diff.ops.config.OutputFileStatementLimit) stage4.SetHeader(sql.NewComment("DBSteward stage 4 data definition changes and additions - generated %s\n%s", timestamp, oldSetNewSet)) defer stage4.Close() } - 
diff.ops.dbsteward.OldDatabase = oldDoc
-	diff.ops.dbsteward.NewDatabase = newDoc
+	diff.ops.config.OldDatabase = oldDoc
+	diff.ops.config.NewDatabase = newDoc
 	return diff.DiffDocWork(stage1, stage2, stage3, stage4)
 }
 
 func (diff *diff) DropOldSchemas(ofs output.OutputFileSegmenter) {
 	// TODO(feat) support oldname following?
-	for _, oldSchema := range diff.ops.dbsteward.OldDatabase.Schemas {
-		if diff.ops.dbsteward.NewDatabase.TryGetSchemaNamed(oldSchema.Name) == nil {
-			diff.ops.dbsteward.Logger().Info(fmt.Sprintf("Drop old schema: %s", oldSchema.Name))
+	for _, oldSchema := range diff.ops.config.OldDatabase.Schemas {
+		if diff.ops.config.NewDatabase.TryGetSchemaNamed(oldSchema.Name) == nil {
+			diff.ops.config.Logger.Info(fmt.Sprintf("Drop old schema: %s", oldSchema.Name))
 			ofs.MustWriteSql(diff.DropSchemaSQL(oldSchema))
 		}
 	}
@@ -571,9 +562,9 @@ func (diff *diff) DropOldSchemas(ofs output.OutputFileSegmenter) {
 
 func (diff *diff) CreateNewSchemas(ofs output.OutputFileSegmenter) error {
 	// TODO(feat) support oldname following?
-	for _, newSchema := range diff.ops.dbsteward.NewDatabase.Schemas {
-		if diff.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) == nil {
-			diff.ops.dbsteward.Logger().Info(fmt.Sprintf("Create new schema: %s", newSchema.Name))
+	for _, newSchema := range diff.ops.config.NewDatabase.Schemas {
+		if diff.ops.config.OldDatabase.TryGetSchemaNamed(newSchema.Name) == nil {
+			diff.ops.config.Logger.Info(fmt.Sprintf("Create new schema: %s", newSchema.Name))
 			ofs.MustWriteSql(diff.CreateSchemaSQL(newSchema))
 		}
 	}
diff --git a/lib/format/pgsql8/diff_constraints.go b/lib/format/pgsql8/diff_constraints.go
index eebd4ff..591bda6 100644
--- a/lib/format/pgsql8/diff_constraints.go
+++ b/lib/format/pgsql8/diff_constraints.go
@@ -10,25 +10,25 @@ import (
 	"github.com/dbsteward/dbsteward/lib/output"
 )
 
-func createConstraints(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, constraintType sql99.ConstraintType) {
+func createConstraints(conf lib.Config, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, constraintType sql99.ConstraintType) {
 	for _, newTable := range newSchema.Tables {
 		var oldTable *ir.Table
 		if oldSchema != nil {
 			// TODO(feat) what about renames?
 			oldTable = oldSchema.TryGetTableNamed(newTable.Name)
 		}
-		createConstraintsTable(dbs, ofs, oldSchema, oldTable, newSchema, newTable, constraintType)
+		createConstraintsTable(conf, ofs, oldSchema, oldTable, newSchema, newTable, constraintType)
 	}
 }
 
-func createConstraintsTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) error {
-	isRenamed, err := dbs.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable)
+func createConstraintsTable(conf lib.Config, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) error {
+	isRenamed, err := conf.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable)
 	if err != nil {
 		return fmt.Errorf("while checking if table was renamed: %w", err)
 	}
 	if isRenamed {
 		// remove all constraints and recreate with new table name conventions
-		constraints, err := getTableConstraints(dbs.OldDatabase, oldSchema, oldTable, constraintType)
+		constraints, err := getTableConstraints(conf.OldDatabase, oldSchema, oldTable, constraintType)
 		if err != nil {
 			return err
 		}
@@ -42,7 +42,7 @@ func createConstraintsTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter,
 	}
 	// add all still-defined constraints back and any new ones to the table
-	constraints, err = getTableConstraints(dbs.NewDatabase, newSchema, newTable, constraintType)
+	constraints, err = getTableConstraints(conf.NewDatabase, newSchema, newTable, constraintType)
 	if err != nil {
 		return err
 	}
@@ -52,7 +52,7 @@ func createConstraintsTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter,
 		return nil
 	}
 
-	constraints, err := getNewConstraints(dbs, oldSchema, oldTable, newSchema, newTable, constraintType)
+	constraints, err := getNewConstraints(conf, oldSchema, oldTable, newSchema, newTable, constraintType)
 	if err != nil {
 		return err
 	}
@@ -62,14 +62,14 @@ func createConstraintsTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter,
 	return nil
 }
 
-func dropConstraints(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, constraintType sql99.ConstraintType) error {
+func dropConstraints(conf lib.Config, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, constraintType sql99.ConstraintType) error {
 	for _, newTable := range newSchema.Tables {
 		var oldTable *ir.Table
 		if oldSchema != nil {
 			// TODO(feat) what about renames?
 			oldTable = oldSchema.TryGetTableNamed(newTable.Name)
 		}
-		err := dropConstraintsTable(dbs, ofs, oldSchema, oldTable, newSchema, newTable, constraintType)
+		err := dropConstraintsTable(conf, ofs, oldSchema, oldTable, newSchema, newTable, constraintType)
 		if err != nil {
 			return err
 		}
@@ -77,8 +77,8 @@ func dropConstraints(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSche
 	return nil
 }
 
-func dropConstraintsTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) error {
-	constraints, err := getOldConstraints(dbs, oldSchema, oldTable, newSchema, newTable, constraintType)
+func dropConstraintsTable(conf lib.Config, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) error {
+	constraints, err := getOldConstraints(conf, oldSchema, oldTable, newSchema, newTable, constraintType)
 	if err != nil {
 		return err
 	}
@@ -88,11 +88,11 @@ func dropConstraintsTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, ol
 	return nil
 }
 
-func getOldConstraints(dbs *lib.DBSteward, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) ([]*sql99.TableConstraint, error) {
+func getOldConstraints(conf lib.Config, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) ([]*sql99.TableConstraint, error) {
 	out := []*sql99.TableConstraint{}
 	if newTable != nil && oldTable != nil {
-		oldDb := dbs.OldDatabase
-		newDb := dbs.NewDatabase
+		oldDb := conf.OldDatabase
+		newDb := conf.NewDatabase
 		constraints, err := getTableConstraints(oldDb, oldSchema, oldTable, constraintType)
 		if err != nil {
 			return nil, err
 		}
@@ -106,11 +106,11 @@ func getOldConstraints(dbs *lib.DBSteward, oldSchema *ir.Schema, oldTable *ir.Ta
 				out = append(out, oldConstraint)
 				continue
 			}
-			oldConstraintWithRenamedTable, err := constraintDependsOnRenamedTable(dbs, newDb, oldConstraint)
+			oldConstraintWithRenamedTable, err := constraintDependsOnRenamedTable(conf, newDb, oldConstraint)
 			if err != nil {
 				return nil, err
 			}
-			newConstraintWithRenamedTable, err := constraintDependsOnRenamedTable(dbs, newDb, newConstraint)
+			newConstraintWithRenamedTable, err := constraintDependsOnRenamedTable(conf, newDb, newConstraint)
 			if err != nil {
 				return nil, err
 			}
@@ -122,11 +122,11 @@ func getOldConstraints(dbs *lib.DBSteward, oldSchema *ir.Schema, oldTable *ir.Ta
 	return out, nil
 }
 
-func getNewConstraints(dbs *lib.DBSteward, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) ([]*sql99.TableConstraint, error) {
+func getNewConstraints(conf lib.Config, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) ([]*sql99.TableConstraint, error) {
 	out := []*sql99.TableConstraint{}
 	if newTable != nil {
-		oldDb := dbs.OldDatabase
-		newDb := dbs.NewDatabase
+		oldDb := conf.OldDatabase
+		newDb := conf.NewDatabase
 		newConstraints, err := getTableConstraints(newDb, newSchema, newTable, constraintType)
 		if err != nil {
 			return nil, err
 		}
@@ -136,7 +136,7 @@ func getNewConstraints(dbs *lib.DBSteward, oldSchema *ir.Schema, oldTable *ir.Ta
 			if err != nil {
 				return nil, err
 			}
-			renamedTable, err := constraintDependsOnRenamedTable(dbs, newDb, newConstraint)
+			renamedTable, err := constraintDependsOnRenamedTable(conf, newDb, newConstraint)
 			if err != nil {
 				return nil, err
 			}
diff --git a/lib/format/pgsql8/diff_constraints_test.go b/lib/format/pgsql8/diff_constraints_test.go
index fa784d2..a3d1a9f 100644
--- a/lib/format/pgsql8/diff_constraints_test.go
+++ b/lib/format/pgsql8/diff_constraints_test.go
@@ -3,7 +3,6 @@ package pgsql8
 import (
 	"testing"
 
-	"github.com/dbsteward/dbsteward/lib"
 	"github.com/dbsteward/dbsteward/lib/format/pgsql8/sql"
 	"github.com/dbsteward/dbsteward/lib/format/sql99"
@@ -102,7 +101,7 @@ func TestDiffConstraints_DropCreate_ChangePrimaryKeyNameAndTable(t *testing.T) {
 	oldSchema := &ir.Schema{
 		Name: "public",
 		Tables: []*ir.Table{
-			&ir.Table{
+			{
 				Name:       "test",
 				PrimaryKey: []string{"pka"},
 				Columns: []*ir.Column{
@@ -114,7 +113,7 @@ func TestDiffConstraints_DropCreate_ChangePrimaryKeyNameAndTable(t *testing.T) {
 	newSchema := &ir.Schema{
 		Name: "public",
 		Tables: []*ir.Table{
-			&ir.Table{
+			{
 				Name:          "newtable",
 				PrimaryKey:    []string{"pkb"},
 				OldSchemaName: "public",
@@ -135,15 +134,15 @@ func TestDiffConstraints_DropCreate_ChangePrimaryKeyNameAndTable(t *testing.T) {
 	newDoc := &ir.Definition{
 		Schemas: []*ir.Schema{newSchema},
 	}
-	dbs := lib.NewDBSteward()
-	ofs := output.NewSegmenter(defaultQuoter(dbs))
-	differ := newDiff(NewOperations(dbs).(*Operations), defaultQuoter(dbs))
-	setOldNewDocs(dbs, differ, oldDoc, newDoc)
-	err := dropConstraintsTable(dbs, ofs, oldSchema, oldSchema.Tables[0], newSchema, nil, sql99.ConstraintTypePrimaryKey)
+	config := DefaultConfig
+	ofs := output.NewSegmenter(defaultQuoter(config))
+	differ := newDiff(NewOperations(config).(*Operations), defaultQuoter(config))
+	config = setOldNewDocs(config, differ, oldDoc, newDoc)
+	err := dropConstraintsTable(config, ofs, oldSchema, oldSchema.Tables[0], newSchema, nil, sql99.ConstraintTypePrimaryKey)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = createConstraintsTable(dbs, ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], sql99.ConstraintTypePrimaryKey)
+	err = createConstraintsTable(config, ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], sql99.ConstraintTypePrimaryKey)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -382,15 +381,15 @@ func diffConstraintsTableCommon(t *testing.T, oldSchema, newSchema *ir.Schema, c
 	newDoc := &ir.Definition{
 		Schemas: []*ir.Schema{newSchema},
 	}
-	dbs := lib.NewDBSteward()
-	ofs := output.NewSegmenter(defaultQuoter(dbs))
-	differ := newDiff(NewOperations(dbs).(*Operations), defaultQuoter(dbs))
-	setOldNewDocs(dbs, differ, oldDoc, newDoc)
-	err := dropConstraintsTable(dbs, ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], ctype)
+	config := DefaultConfig
+	ofs := output.NewSegmenter(defaultQuoter(config))
+	differ := newDiff(NewOperations(config).(*Operations), defaultQuoter(config))
+	config = setOldNewDocs(config, differ, oldDoc, newDoc)
+	err := dropConstraintsTable(config, ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], ctype)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = createConstraintsTable(dbs, ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], ctype)
+	err = createConstraintsTable(config, ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], ctype)
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/lib/format/pgsql8/diff_functions.go b/lib/format/pgsql8/diff_functions.go
index f86329d..4596fb1 100644
--- a/lib/format/pgsql8/diff_functions.go
+++ b/lib/format/pgsql8/diff_functions.go
@@ -7,7 +7,7 @@ import (
 	"github.com/dbsteward/dbsteward/lib/output"
 )
 
-func diffFunctions(dbs *lib.DBSteward, stage1 output.OutputFileSegmenter, stage3 output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error {
+func diffFunctions(conf lib.Config, stage1 output.OutputFileSegmenter, stage3 output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error {
 	// drop functions that no longer exist in stage 3
 	if oldSchema != nil {
 		for _, oldFunction := range oldSchema.Functions {
@@ -21,14 +21,14 @@ func diffFunctions(dbs *lib.DBSteward, stage1 output.OutputFileSegmenter, stage3
 	for _, newFunction := range newSchema.Functions {
 		oldFunction := oldSchema.TryGetFunctionMatching(newFunction)
 		if oldFunction == nil || !oldFunction.Equals(newFunction, ir.SqlFormatPgsql8) {
-			create, err := getFunctionCreationSql(dbs, newSchema, newFunction)
+			create, err := getFunctionCreationSql(conf, newSchema, newFunction)
 			if err != nil {
 				return nil
 			}
 			stage1.WriteSql(create...)
 		} else if newFunction.ForceRedefine {
 			stage1.WriteSql(sql.NewComment("Function %s.%s has forceRedefine set to true", newSchema.Name, newFunction.Name))
-			create, err := getFunctionCreationSql(dbs, newSchema, newFunction)
+			create, err := getFunctionCreationSql(conf, newSchema, newFunction)
 			if err != nil {
 				return nil
 			}
@@ -38,7 +38,7 @@ func diffFunctions(dbs *lib.DBSteward, stage1 output.OutputFileSegmenter, stage3
 			newReturnType := newSchema.TryGetTypeNamed(newFunction.Returns)
 			if oldReturnType != nil && newReturnType != nil && !oldReturnType.Equals(newReturnType) {
 				stage1.WriteSql(sql.NewComment("Function %s.%s return type %s has changed", newSchema.Name, newFunction.Name, newReturnType.Name))
-				create, err := getFunctionCreationSql(dbs, newSchema, newFunction)
+				create, err := getFunctionCreationSql(conf, newSchema, newFunction)
 				if err != nil {
 					return nil
 				}
diff --git a/lib/format/pgsql8/diff_languages.go b/lib/format/pgsql8/diff_languages.go
index 4678112..071c7b7 100644
--- a/lib/format/pgsql8/diff_languages.go
+++ b/lib/format/pgsql8/diff_languages.go
@@ -5,18 +5,18 @@ import (
 	"github.com/dbsteward/dbsteward/lib/output"
 )
 
-func diffLanguages(dbs *lib.DBSteward, ofs output.OutputFileSegmenter) error {
+func diffLanguages(conf lib.Config, ofs output.OutputFileSegmenter) error {
 	// TODO(go,pgsql) this is a different flow than old dbsteward:
 	// we do equality comparison inside these two methods, instead of a separate loop
 	// need to validate that this behavior is still correct
-	dropLanguages(dbs, ofs)
-	return createLanguages(dbs, ofs)
+	dropLanguages(conf, ofs)
+	return createLanguages(conf, ofs)
 }
 
-func dropLanguages(dbs *lib.DBSteward, ofs output.OutputFileSegmenter) {
-	newDoc := dbs.NewDatabase
-	oldDoc := dbs.OldDatabase
+func dropLanguages(conf lib.Config, ofs output.OutputFileSegmenter) {
+	newDoc := conf.NewDatabase
+	oldDoc := conf.OldDatabase
 
 	// drop languages that either do not exist in the new schema or have changed
 	if oldDoc != nil {
@@ -29,15 +29,15 @@ func dropLanguages(dbs *lib.DBSteward, ofs output.OutputFileSegmenter) {
 	}
 }
 
-func createLanguages(dbs *lib.DBSteward, ofs output.OutputFileSegmenter) error {
-	newDoc := dbs.NewDatabase
-	oldDoc := dbs.OldDatabase
+func createLanguages(conf lib.Config, ofs output.OutputFileSegmenter) error {
+	newDoc := conf.NewDatabase
+	oldDoc := conf.OldDatabase
 
 	// create languages that either do not exist in the old schema or have changed
 	for _, newLang := range newDoc.Languages {
 		oldLang := oldDoc.TryGetLanguageNamed(newLang.Name)
 		if oldLang == nil || !oldLang.Equals(newLang) {
-			s, err := getCreateLanguageSql(dbs, newLang)
+			s, err := getCreateLanguageSql(conf,
newLang) if err != nil { return err } diff --git a/lib/format/pgsql8/diff_sequences.go b/lib/format/pgsql8/diff_sequences.go index 4539ef3..e691966 100644 --- a/lib/format/pgsql8/diff_sequences.go +++ b/lib/format/pgsql8/diff_sequences.go @@ -7,7 +7,7 @@ import ( "github.com/dbsteward/dbsteward/lib/output" ) -func diffSequences(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error { +func diffSequences(conf lib.Config, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error { // drop old sequences if oldSchema != nil { for _, oldSeq := range oldSchema.Sequences { @@ -20,7 +20,7 @@ func diffSequences(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema for _, newSeq := range newSchema.Sequences { oldSeq := oldSchema.TryGetSequenceNamed(newSeq.Name) if oldSeq == nil { - sql, err := getCreateSequenceSql(dbs, newSchema, newSeq) + sql, err := getCreateSequenceSql(conf, newSchema, newSeq) if err != nil { return err } diff --git a/lib/format/pgsql8/diff_tables.go b/lib/format/pgsql8/diff_tables.go index 031550c..b0e4310 100644 --- a/lib/format/pgsql8/diff_tables.go +++ b/lib/format/pgsql8/diff_tables.go @@ -16,7 +16,7 @@ import ( // TODO(go,core) lift much of this up to sql99 // applies transformations to tables that exist in both old and new -func diffTables(dbs *lib.DBSteward, stage1, stage3 output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error { +func diffTables(conf lib.Config, stage1, stage3 output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error { // note: old dbsteward called create_tables here, but because we split out DiffTable, we can't call it both places, // so callers were updated to call createTables or CreateTable just before calling DiffTables or DiffTable, respectively @@ -26,11 +26,11 @@ func diffTables(dbs *lib.DBSteward, stage1, stage3 output.OutputFileSegmenter, o for _, newTable := range newSchema.Tables { oldTable := oldSchema.TryGetTableNamed(newTable.Name) var err error - oldSchema, oldTable, err = dbs.OldDatabase.NewTableName(oldSchema, oldTable, newSchema, newTable) + oldSchema, oldTable, err = conf.OldDatabase.NewTableName(oldSchema, oldTable, newSchema, newTable) if err != nil { return err } - err = diffTable(dbs, stage1, stage3, oldSchema, oldTable, newSchema, newTable) + err = diffTable(conf, stage1, stage3, oldSchema, oldTable, newSchema, newTable) if err != nil { return errors.Wrapf(err, "while diffing table %s.%s", newSchema.Name, newTable.Name) } @@ -38,17 +38,17 @@ func diffTables(dbs *lib.DBSteward, stage1, stage3 output.OutputFileSegmenter, o return nil } -func diffTable(dbs *lib.DBSteward, stage1, stage3 output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { +func diffTable(conf lib.Config, stage1, stage3 output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { if oldTable == nil || newTable == nil { // create and drop are handled elsewhere return nil } - err := updateTableOptions(dbs.Logger(), stage1, oldTable, newSchema, newTable) + err := updateTableOptions(conf.Logger, stage1, oldTable, newSchema, newTable) if err != nil { return errors.Wrap(err, "while diffing table options") } - err = updateTableColumns(dbs, stage1, stage3, oldTable, newSchema, newTable) + err = updateTableColumns(conf, stage1, stage3, oldTable, newSchema, newTable) if err != nil { return errors.Wrap(err, "while diffing table columns") } @@ -163,22 
+163,22 @@ type updateTableColumnsAgg struct { after3 []output.ToSql } -func updateTableColumns(dbs *lib.DBSteward, stage1, stage3 output.OutputFileSegmenter, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { +func updateTableColumns(conf lib.Config, stage1, stage3 output.OutputFileSegmenter, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { agg := &updateTableColumnsAgg{} // TODO(go,pgsql) old dbsteward interleaved commands into a single list, and output in the same order // meaning that a BEFORE3 could be output before a BEFORE1 in a single-stage upgrade. in this implementation, // _all_ BEFORE1s are printed before BEFORE3s. Double check that this doesn't break anything. - err := addDropTableColumns(dbs, agg, oldTable, newTable) + err := addDropTableColumns(conf, agg, oldTable, newTable) if err != nil { return err } - err = addCreateTableColumns(dbs, agg, oldTable, newSchema, newTable) + err = addCreateTableColumns(conf, agg, oldTable, newSchema, newTable) if err != nil { return err } - err = addModifyTableColumns(dbs, agg, oldTable, newSchema, newTable) + err = addModifyTableColumns(conf, agg, oldTable, newSchema, newTable) if err != nil { return err } @@ -212,7 +212,7 @@ func updateTableColumns(dbs *lib.DBSteward, stage1, stage3 output.OutputFileSegm return nil } -func addDropTableColumns(dbs *lib.DBSteward, agg *updateTableColumnsAgg, oldTable, newTable *ir.Table) error { +func addDropTableColumns(conf lib.Config, agg *updateTableColumnsAgg, oldTable, newTable *ir.Table) error { for _, oldColumn := range oldTable.Columns { if newTable.TryGetColumnNamed(oldColumn.Name) != nil { // new column exists, not dropping it @@ -220,7 +220,7 @@ func addDropTableColumns(dbs *lib.DBSteward, agg *updateTableColumnsAgg, oldTabl } renamedColumn := newTable.TryGetColumnOldNamed(oldColumn.Name) - if !dbs.IgnoreOldNames && renamedColumn != nil { + if !conf.IgnoreOldNames && renamedColumn != nil { agg.after3 = append(agg.after3, sql.NewComment( "%s DROP COLUMN %s omitted: new column %s indicates it is the replacement for %s", oldTable.Name, oldColumn.Name, renamedColumn.Name, oldColumn.Name, @@ -232,10 +232,10 @@ func addDropTableColumns(dbs *lib.DBSteward, agg *updateTableColumnsAgg, oldTabl return nil } -func addCreateTableColumns(dbs *lib.DBSteward, agg *updateTableColumnsAgg, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { +func addCreateTableColumns(conf lib.Config, agg *updateTableColumnsAgg, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { // note that postgres treats identifiers as case-sensitive when quoted // TODO(go,3) find a way to generalize/streamline this - caseSensitive := dbs.QuoteAllNames || dbs.QuoteColumnNames + caseSensitive := conf.QuoteAllNames || conf.QuoteColumnNames for _, newColumn := range newTable.Columns { if oldTable.TryGetColumnNamedCase(newColumn.Name, caseSensitive) != nil { @@ -243,7 +243,7 @@ func addCreateTableColumns(dbs *lib.DBSteward, agg *updateTableColumnsAgg, oldTa continue } - isRenamed, err := isRenamedColumn(dbs, oldTable, newTable, newColumn) + isRenamed, err := isRenamedColumn(conf, oldTable, newTable, newColumn) if err != nil { return errors.Wrapf(err, "while adding new table columns") } @@ -260,7 +260,7 @@ func addCreateTableColumns(dbs *lib.DBSteward, agg *updateTableColumnsAgg, oldTa // notice $include_null_definition is false // this is because ADD COLUMNs with NOT NULL will fail when there are existing rows - colDef, err := getFullColumnDefinition(dbs.Logger(), 
dbs.NewDatabase, newSchema, newTable, newColumn, false, true) + colDef, err := getFullColumnDefinition(conf.Logger, conf.NewDatabase, newSchema, newTable, newColumn, false, true) if err != nil { return err } @@ -334,10 +334,10 @@ func addCreateTableColumns(dbs *lib.DBSteward, agg *updateTableColumnsAgg, oldTa return nil } -func addModifyTableColumns(dbsteward *lib.DBSteward, agg *updateTableColumnsAgg, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { +func addModifyTableColumns(conf lib.Config, agg *updateTableColumnsAgg, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { // note that postgres treats identifiers as case-sensitive when quoted // TODO(go,3) find a way to generalize/streamline this - caseSensitive := dbsteward.QuoteAllNames || dbsteward.QuoteColumnNames + caseSensitive := conf.QuoteAllNames || conf.QuoteColumnNames for _, newColumn := range newTable.Columns { oldColumn := oldTable.TryGetColumnNamedCase(newColumn.Name, caseSensitive) @@ -345,7 +345,7 @@ func addModifyTableColumns(dbsteward *lib.DBSteward, agg *updateTableColumnsAgg, // old table does not contain column, CREATE handled by addCreateTableColumns continue } - isRenamed, err := isRenamedColumn(dbsteward, oldTable, newTable, newColumn) + isRenamed, err := isRenamedColumn(conf, oldTable, newTable, newColumn) if err != nil { return errors.Wrapf(err, "while diffing table columns") } @@ -356,11 +356,11 @@ func addModifyTableColumns(dbsteward *lib.DBSteward, agg *updateTableColumnsAgg, } // TODO(go,pgsql) orig code calls (oldDB, *newSchema*, oldTable, oldColumn) but that seems wrong, need to validate this - oldType, err := getColumnType(dbsteward.Logger(), dbsteward.OldDatabase, newSchema, oldTable, oldColumn) + oldType, err := getColumnType(conf.Logger, conf.OldDatabase, newSchema, oldTable, oldColumn) if err != nil { return err } - newType, err := getColumnType(dbsteward.Logger(), dbsteward.NewDatabase, newSchema, newTable, newColumn) + newType, err := getColumnType(conf.Logger, conf.NewDatabase, newSchema, newTable, newColumn) if err != nil { return err } @@ -488,13 +488,13 @@ func addAlterStatistics(stage1 output.OutputFileSegmenter, oldTable *ir.Table, n return nil } -func isRenamedColumn(dbsteward *lib.DBSteward, oldTable, newTable *ir.Table, newColumn *ir.Column) (bool, error) { - if dbsteward.IgnoreOldNames { +func isRenamedColumn(conf lib.Config, oldTable, newTable *ir.Table, newColumn *ir.Column) (bool, error) { + if conf.IgnoreOldNames { return false, nil } caseSensitive := false - if dbsteward.QuoteColumnNames || dbsteward.QuoteAllNames || dbsteward.SqlFormat.Equals(ir.SqlFormatMysql5) { + if conf.QuoteColumnNames || conf.QuoteAllNames || conf.SqlFormat.Equals(ir.SqlFormatMysql5) { for _, oldColumn := range oldTable.Columns { if strings.EqualFold(oldColumn.Name, newColumn.Name) { if oldColumn.Name != newColumn.Name && newColumn.OldColumnName == "" { @@ -523,19 +523,19 @@ func isRenamedColumn(dbsteward *lib.DBSteward, oldTable, newTable *ir.Table, new // newColumn.OldColumnName exists in old schema // newColumn.OldColumnName does not exist in new schema if oldTable.TryGetColumnNamedCase(newColumn.OldColumnName, caseSensitive) != nil && newTable.TryGetColumnNamedCase(newColumn.OldColumnName, caseSensitive) == nil { - dbsteward.Logger().Info(fmt.Sprintf("Column %s.%s used to be called %s", newTable.Name, newColumn.Name, newColumn.OldColumnName)) + conf.Logger.Info(fmt.Sprintf("Column %s.%s used to be called %s", newTable.Name, newColumn.Name, newColumn.OldColumnName)) 
 		return true, nil
 	}
 	return false, nil
 }
 
-func createTables(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error {
+func createTables(conf lib.Config, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error {
 	if newSchema == nil {
 		// if the new schema is nil, there's no tables to create
 		return nil
 	}
 	for _, newTable := range newSchema.Tables {
-		err := createTable(dbs, ofs, oldSchema, newSchema, newTable)
+		err := createTable(conf, ofs, oldSchema, newSchema, newTable)
 		if err != nil {
 			return err
 		}
@@ -543,8 +543,8 @@ func createTables(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema,
 	return nil
 }
 
-func createTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, newTable *ir.Table) error {
-	l := dbs.Logger().With(
+func createTable(conf lib.Config, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, newTable *ir.Table) error {
+	l := conf.Logger.With(
 		slog.String("function", "createTable()"),
 		slog.String("old schema", oldSchema.Name),
 		slog.String("new schema", newSchema.Name),
@@ -561,15 +561,15 @@ func createTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema,
 		return nil
 	}
 
-	isRenamed, err := dbs.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable)
+	isRenamed, err := conf.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable)
 	if err != nil {
 		return err
 	}
 	if isRenamed {
 		l.Debug("table renamed")
 		// this is a renamed table, so rename it instead of creating a new one
-		oldTableSchema := dbs.OldDatabase.GetOldTableSchema(newSchema, newTable)
-		oldTable := dbs.OldDatabase.GetOldTable(newSchema, newTable)
+		oldTableSchema := conf.OldDatabase.GetOldTableSchema(newSchema, newTable)
+		oldTable := conf.OldDatabase.GetOldTable(newSchema, newTable)
 		// ALTER TABLE ... RENAME TO does not accept schema qualifiers ...
oldRef := sql.TableRef{Schema: oldTableSchema.Name, Table: oldTable.Name} @@ -592,7 +592,7 @@ func createTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, } } else { l.Debug("table not renamed") - createTableSQL, err := getCreateTableSql(dbs, newSchema, newTable) + createTableSQL, err := getCreateTableSql(conf, newSchema, newTable) if err != nil { return err } @@ -608,23 +608,23 @@ func createTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, return nil } -func dropTables(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) { +func dropTables(conf lib.Config, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) { // if newSchema is nil, we'll have already dropped all the tables in it if oldSchema != nil && newSchema != nil { for _, oldTable := range oldSchema.Tables { - dropTable(dbs, ofs, oldSchema, oldTable, newSchema) + dropTable(conf, ofs, oldSchema, oldTable, newSchema) } } } -func dropTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema) { +func dropTable(conf lib.Config, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema) { newTable := newSchema.TryGetTableNamed(oldTable.Name) if newTable != nil { // table exists, nothing to do return } - if !dbs.IgnoreOldNames { - renamedRef := dbs.NewDatabase.TryGetTableFormerlyKnownAs(oldSchema, oldTable) + if !conf.IgnoreOldNames { + renamedRef := conf.NewDatabase.TryGetTableFormerlyKnownAs(oldSchema, oldTable) if renamedRef != nil { ofs.WriteSql(sql.NewComment("DROP TABLE %s.%s omitted: new table %s indicates it is her replacement", oldSchema.Name, oldTable.Name, renamedRef)) return @@ -652,14 +652,14 @@ func diffClustersTable(ofs output.OutputFileSegmenter, oldTable *ir.Table, newSc func diffData(ops *Operations, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error { for _, newTable := range newSchema.Tables { - isRenamed, err := ops.dbsteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) + isRenamed, err := ops.config.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) if err != nil { return fmt.Errorf("while diffing data: %w", err) } if isRenamed { // if the table was renamed, get old definition pointers, diff that - oldSchema := ops.dbsteward.OldDatabase.GetOldTableSchema(newSchema, newTable) - oldTable := ops.dbsteward.OldDatabase.GetOldTable(newSchema, newTable) + oldSchema := ops.config.OldDatabase.GetOldTableSchema(newSchema, newTable) + oldTable := ops.config.OldDatabase.GetOldTable(newSchema, newTable) s, err := getCreateDataSql(ops, oldSchema, oldTable, newSchema, newTable) if err != nil { return err diff --git a/lib/format/pgsql8/diff_tables_escape_char_test.go b/lib/format/pgsql8/diff_tables_escape_char_test.go index ae4d6cc..f599425 100644 --- a/lib/format/pgsql8/diff_tables_escape_char_test.go +++ b/lib/format/pgsql8/diff_tables_escape_char_test.go @@ -3,7 +3,6 @@ package pgsql8 import ( "testing" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/output" @@ -40,11 +39,11 @@ func TestDiffTables_GetDataSql_EscapeCharacters(t *testing.T) { }, }, } - dbs := lib.NewDBSteward() - dbs.NewDatabase = &ir.Definition{ + conf := DefaultConfig + conf.NewDatabase = &ir.Definition{ Schemas: []*ir.Schema{schema}, } - ops := NewOperations(dbs).(*Operations) + ops := NewOperations(conf).(*Operations) ddl, err := 
getCreateDataSql(ops, nil, nil, schema, schema.Tables[0]) if err != nil { t.Fatal(err) diff --git a/lib/format/pgsql8/diff_tables_test.go b/lib/format/pgsql8/diff_tables_test.go index d359460..d5d33f8 100644 --- a/lib/format/pgsql8/diff_tables_test.go +++ b/lib/format/pgsql8/diff_tables_test.go @@ -7,7 +7,6 @@ import ( "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/stretchr/testify/assert" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" ) @@ -49,20 +48,21 @@ func TestDiffTables_DiffTables_ColumnCaseChange(t *testing.T) { }, } - dbs := lib.NewDBSteward() - dbs.IgnoreOldNames = false - ops := NewOperations(dbs).(*Operations) + conf := DefaultConfig + conf.IgnoreOldNames = false // when quoting is off, a change in case is a no-op - dbs.QuoteAllNames = false - dbs.QuoteColumnNames = false + conf.QuoteAllNames = false + conf.QuoteColumnNames = false + ops := NewOperations(conf).(*Operations) ddl1, ddl3 := diffTablesCommon(t, ops, lower, upperWithoutOldName) assert.Empty(t, ddl1) assert.Empty(t, ddl3) // when quoting is on, a change in case results in a rename, if there's an oldname - dbs.QuoteAllNames = true - dbs.QuoteColumnNames = true + conf.QuoteAllNames = true + conf.QuoteColumnNames = true + ops = NewOperations(conf).(*Operations) ddl1, ddl3 = diffTablesCommon(t, ops, lower, upperWithOldName) assert.Equal(t, []output.ToSql{ &sql.ColumnRename{ @@ -96,12 +96,12 @@ func TestDiffTables_DiffTables_TableOptions_NoChange(t *testing.T) { }, }, } - dbs := lib.NewDBSteward() - ops := NewOperations(dbs).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) ddl1, ddl3 := diffTablesCommon(t, ops, schema, schema) assert.Empty(t, ddl1) assert.Empty(t, ddl3) } + func TestDiffTables_DiffTables_TableOptions_AddWith(t *testing.T) { oldSchema := &ir.Schema{ Name: "public", @@ -129,8 +129,7 @@ func TestDiffTables_DiffTables_TableOptions_AddWith(t *testing.T) { }, }, } - dbs := lib.NewDBSteward() - ops := NewOperations(dbs).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) ddl1, ddl3 := diffTablesCommon(t, ops, oldSchema, newSchema) assert.Equal(t, []output.ToSql{ sql.NewTableAlter( @@ -143,6 +142,7 @@ func TestDiffTables_DiffTables_TableOptions_AddWith(t *testing.T) { }, ddl1) assert.Empty(t, ddl3) } + func TestDiffTables_DiffTables_TableOptions_AlterWith(t *testing.T) { oldSchema := &ir.Schema{ Name: "public", @@ -178,8 +178,7 @@ func TestDiffTables_DiffTables_TableOptions_AlterWith(t *testing.T) { }, }, } - dbs := lib.NewDBSteward() - ops := NewOperations(dbs).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) ddl1, ddl3 := diffTablesCommon(t, ops, oldSchema, newSchema) assert.Equal(t, []output.ToSql{ sql.NewTableAlter( @@ -192,6 +191,7 @@ func TestDiffTables_DiffTables_TableOptions_AlterWith(t *testing.T) { }, ddl1) assert.Empty(t, ddl3) } + func TestDiffTables_DiffTables_TableOptions_AddTablespaceAlterWith(t *testing.T) { oldSchema := &ir.Schema{ Name: "public", @@ -232,8 +232,7 @@ func TestDiffTables_DiffTables_TableOptions_AddTablespaceAlterWith(t *testing.T) }, }, } - dbs := lib.NewDBSteward() - ops := NewOperations(dbs).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) ddl1, ddl3 := diffTablesCommon(t, ops, oldSchema, newSchema) assert.Equal(t, []output.ToSql{ &sql.TableMoveTablespaceIndexes{ @@ -251,6 +250,7 @@ func TestDiffTables_DiffTables_TableOptions_AddTablespaceAlterWith(t *testing.T) }, ddl1) assert.Empty(t, ddl3) } + func 
TestDiffTables_DiffTables_TableOptions_DropTablespace(t *testing.T) { oldSchema := &ir.Schema{ Name: "public", @@ -291,8 +291,7 @@ func TestDiffTables_DiffTables_TableOptions_DropTablespace(t *testing.T) { }, }, } - dbs := lib.NewDBSteward() - ops := NewOperations(dbs).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) ddl1, ddl3 := diffTablesCommon(t, ops, oldSchema, newSchema) assert.Equal(t, []output.ToSql{ &sql.TableResetTablespace{ @@ -344,8 +343,7 @@ func TestDiffTables_GetDeleteCreateDataSql_AddSerialColumn(t *testing.T) { }, }, } - dbs := lib.NewDBSteward() - ops := NewOperations(dbs).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) delddl, err := getDeleteDataSql(ops, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0]) if err != nil { t.Fatal(err) @@ -381,18 +379,18 @@ func diffTablesCommonErr(ops *Operations, oldSchema, newSchema *ir.Schema) ([]ou newDoc := &ir.Definition{ Schemas: []*ir.Schema{newSchema}, } - differ := newDiff(ops, defaultQuoter(ops.dbsteward)) - setOldNewDocs(ops.dbsteward, differ, oldDoc, newDoc) - ofs1 := output.NewAnnotationStrippingSegmenter(defaultQuoter(ops.dbsteward)) - ofs3 := output.NewAnnotationStrippingSegmenter(defaultQuoter(ops.dbsteward)) + differ := newDiff(ops, defaultQuoter(ops.config)) + ops.config = setOldNewDocs(ops.config, differ, oldDoc, newDoc) + ofs1 := output.NewAnnotationStrippingSegmenter(defaultQuoter(ops.config)) + ofs3 := output.NewAnnotationStrippingSegmenter(defaultQuoter(ops.config)) // note: v1 only used DiffTables, v2 split into CreateTables+DiffTables - err := createTables(ops.dbsteward, ofs1, oldSchema, newSchema) + err := createTables(ops.config, ofs1, oldSchema, newSchema) if err != nil { return ofs1.Body, ofs3.Body, err } - err = diffTables(ops.dbsteward, ofs1, ofs3, oldSchema, newSchema) + err = diffTables(ops.config, ofs1, ofs3, oldSchema, newSchema) if err != nil { return ofs1.Body, ofs3.Body, err } diff --git a/lib/format/pgsql8/diff_types.go b/lib/format/pgsql8/diff_types.go index db5ed8b..5dc60b5 100644 --- a/lib/format/pgsql8/diff_types.go +++ b/lib/format/pgsql8/diff_types.go @@ -10,7 +10,7 @@ import ( "github.com/dbsteward/dbsteward/lib/output" ) -func diffTypes(dbs *lib.DBSteward, differ *diff, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error { +func diffTypes(conf lib.Config, differ *diff, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error { dropTypes(ofs, oldSchema, newSchema) err := createTypes(ofs, oldSchema, newSchema) if err != nil { @@ -41,7 +41,7 @@ func diffTypes(dbs *lib.DBSteward, differ *diff, ofs output.OutputFileSegmenter, ofs.WriteSql(getFunctionDropSql(oldSchema, oldFunc)...) 
} - columns, sql, err := alterColumnTypePlaceholder(dbs, differ, oldType) + columns, sql, err := alterColumnTypePlaceholder(conf, differ, oldType) if err != nil { return err } @@ -63,7 +63,7 @@ func diffTypes(dbs *lib.DBSteward, differ *diff, ofs output.OutputFileSegmenter, // functions are only recreated if they changed elsewise, so need to create them here for _, newFunc := range commonSchema.GetFunctionsDependingOnType(newSchema, newType) { - s, err := getFunctionCreationSql(dbs, newSchema, newFunc) + s, err := getFunctionCreationSql(conf, newSchema, newFunc) if err != nil { return err } diff --git a/lib/format/pgsql8/diff_types_domains_test.go b/lib/format/pgsql8/diff_types_domains_test.go index c130ba5..6a3756b 100644 --- a/lib/format/pgsql8/diff_types_domains_test.go +++ b/lib/format/pgsql8/diff_types_domains_test.go @@ -5,7 +5,6 @@ import ( "github.com/stretchr/testify/assert" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" @@ -323,12 +322,12 @@ func diffTypesForTest(t *testing.T, oldSchema, newSchema *ir.Schema) []output.To newDoc := &ir.Definition{ Schemas: []*ir.Schema{newSchema}, } - dbs := lib.NewDBSteward() - ops := NewOperations(dbs).(*Operations) - differ := newDiff(ops, defaultQuoter(dbs)) - setOldNewDocs(dbs, differ, oldDoc, newDoc) - ofs := output.NewAnnotationStrippingSegmenter(defaultQuoter(dbs)) - err := diffTypes(dbs, differ, ofs, oldSchema, newSchema) + config := DefaultConfig + ops := NewOperations(config).(*Operations) + differ := newDiff(ops, defaultQuoter(config)) + config = setOldNewDocs(config, differ, oldDoc, newDoc) + ofs := output.NewAnnotationStrippingSegmenter(defaultQuoter(config)) + err := diffTypes(config, differ, ofs, oldSchema, newSchema) if err != nil { t.Fatal(err) } diff --git a/lib/format/pgsql8/diff_types_test.go b/lib/format/pgsql8/diff_types_test.go index 02f13d0..d8a4000 100644 --- a/lib/format/pgsql8/diff_types_test.go +++ b/lib/format/pgsql8/diff_types_test.go @@ -3,7 +3,6 @@ package pgsql8 import ( "testing" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" @@ -94,10 +93,9 @@ func TestDiffTypes_DiffTypes_RecreateDependentFunctions(t *testing.T) { }, } - dbs := lib.NewDBSteward() - ops := NewOperations(dbs).(*Operations) - ofs := output.NewAnnotationStrippingSegmenter(defaultQuoter(dbs)) - err := diffTypes(dbs, newDiff(ops, defaultQuoter(dbs)), ofs, oldSchema, newSchema) + ops := NewOperations(DefaultConfig).(*Operations) + ofs := output.NewAnnotationStrippingSegmenter(defaultQuoter(DefaultConfig)) + err := diffTypes(DefaultConfig, newDiff(ops, defaultQuoter(DefaultConfig)), ofs, oldSchema, newSchema) if err != nil { t.Fatal(err) } diff --git a/lib/format/pgsql8/diff_views.go b/lib/format/pgsql8/diff_views.go index 508d2a5..5091e39 100644 --- a/lib/format/pgsql8/diff_views.go +++ b/lib/format/pgsql8/diff_views.go @@ -10,8 +10,8 @@ import ( // TODO(go,core) lift some of these to sql99 -func createViewsOrdered(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldDoc *ir.Definition, newDoc *ir.Definition) error { - l := dbs.Logger() +func createViewsOrdered(conf lib.Config, ofs output.OutputFileSegmenter, oldDoc *ir.Definition, newDoc *ir.Definition) error { + l := conf.Logger return forEachViewInDepOrder(newDoc, func(newRef ir.ViewRef) error { ll := l.With(slog.String("view", 
newRef.String())) ll.Debug("consider creating") @@ -24,11 +24,11 @@ func createViewsOrdered(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldD if oldView != nil { ll = ll.With(slog.String("old view", oldView.Name)) } - if shouldCreateView(dbs, oldView, newRef.View) { + if shouldCreateView(conf, oldView, newRef.View) { ll.Debug("shouldCreateView returned true") - s, err := getCreateViewSql(dbs, newRef.Schema, newRef.View) + s, err := getCreateViewSql(conf, newRef.Schema, newRef.View) for _, s1 := range s { - ll.Debug(s1.ToSql(defaultQuoter(dbs))) + ll.Debug(s1.ToSql(defaultQuoter(conf))) } if err != nil { return err @@ -41,8 +41,8 @@ func createViewsOrdered(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldD }) } -func shouldCreateView(dbs *lib.DBSteward, oldView, newView *ir.View) bool { - return oldView == nil || dbs.AlwaysRecreateViews || !oldView.Equals(newView, ir.SqlFormatPgsql8) +func shouldCreateView(conf lib.Config, oldView, newView *ir.View) bool { + return oldView == nil || conf.AlwaysRecreateViews || !oldView.Equals(newView, ir.SqlFormatPgsql8) } func dropViewsOrdered(ofs output.OutputFileSegmenter, oldDoc *ir.Definition, newDoc *ir.Definition) error { diff --git a/lib/format/pgsql8/diff_views_test.go b/lib/format/pgsql8/diff_views_test.go index fcdcb0c..e06e8f2 100644 --- a/lib/format/pgsql8/diff_views_test.go +++ b/lib/format/pgsql8/diff_views_test.go @@ -3,7 +3,6 @@ package pgsql8 import ( "testing" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" @@ -53,10 +52,9 @@ func newSingleView() *ir.Definition { } func TestCreateViewsOrdered(t *testing.T) { - dbs := lib.NewDBSteward() - q := defaultQuoter(dbs) + q := defaultQuoter(DefaultConfig) ofs := output.NewAnnotationStrippingSegmenter(q) - err := createViewsOrdered(dbs, ofs, oldSingleView(), newSingleView()) + err := createViewsOrdered(DefaultConfig, ofs, oldSingleView(), newSingleView()) if err != nil { t.Fatal(err) } diff --git a/lib/format/pgsql8/function.go b/lib/format/pgsql8/function.go index d0dac8c..4394000 100644 --- a/lib/format/pgsql8/function.go +++ b/lib/format/pgsql8/function.go @@ -34,7 +34,7 @@ func functionDefinitionReferencesTable(definition *ir.FunctionDefinition) *lib.Q return &parsed } -func getFunctionCreationSql(dbs *lib.DBSteward, schema *ir.Schema, function *ir.Function) ([]output.ToSql, error) { +func getFunctionCreationSql(conf lib.Config, schema *ir.Schema, function *ir.Function) ([]output.ToSql, error) { ref := sql.FunctionRef{Schema: schema.Name, Function: function.Name, Params: function.ParamSigs()} def := function.TryGetDefinition(ir.SqlFormatPgsql8) out := []output.ToSql{ @@ -49,7 +49,7 @@ func getFunctionCreationSql(dbs *lib.DBSteward, schema *ir.Schema, function *ir. 
} if function.Owner != "" { - role, err := roleEnum(dbs.Logger(), dbs.NewDatabase, function.Owner, dbs.IgnoreCustomRoles) + role, err := roleEnum(conf.Logger, conf.NewDatabase, function.Owner, conf.IgnoreCustomRoles) if err != nil { return nil, err } @@ -93,11 +93,11 @@ func normalizeFunctionParameterType(paramType string) string { return paramType } -func getFunctionGrantSql(dbs *lib.DBSteward, schema *ir.Schema, fn *ir.Function, grant *ir.Grant) ([]output.ToSql, error) { +func getFunctionGrantSql(conf lib.Config, schema *ir.Schema, fn *ir.Function, grant *ir.Grant) ([]output.ToSql, error) { roles := make([]string, len(grant.Roles)) var err error for i, role := range grant.Roles { - roles[i], err = roleEnum(dbs.Logger(), dbs.NewDatabase, role, dbs.IgnoreCustomRoles) + roles[i], err = roleEnum(conf.Logger, conf.NewDatabase, role, conf.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/language.go b/lib/format/pgsql8/language.go index fe2ff4b..7a66860 100644 --- a/lib/format/pgsql8/language.go +++ b/lib/format/pgsql8/language.go @@ -7,7 +7,7 @@ import ( "github.com/dbsteward/dbsteward/lib/output" ) -func getCreateLanguageSql(dbsteward *lib.DBSteward, lang *ir.Language) ([]output.ToSql, error) { +func getCreateLanguageSql(conf lib.Config, lang *ir.Language) ([]output.ToSql, error) { out := []output.ToSql{ &sql.LanguageCreate{ Language: lang.Name, @@ -19,7 +19,7 @@ func getCreateLanguageSql(dbsteward *lib.DBSteward, lang *ir.Language) ([]output } if lang.Owner != "" { - role, err := roleEnum(dbsteward.Logger(), dbsteward.NewDatabase, lang.Owner, dbsteward.IgnoreCustomRoles) + role, err := roleEnum(conf.Logger, conf.NewDatabase, lang.Owner, conf.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/oneeighty_test.go b/lib/format/pgsql8/oneeighty_test.go index 2397b50..1b6db64 100644 --- a/lib/format/pgsql8/oneeighty_test.go +++ b/lib/format/pgsql8/oneeighty_test.go @@ -5,7 +5,6 @@ import ( "os" "testing" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/ir" "github.com/stretchr/testify/assert" ) @@ -28,9 +27,8 @@ func TestOneEighty(t *testing.T) { } defer Teardowndb(t, c, "pg") role := os.Getenv("DB_USER") - dbs := lib.NewDBSteward() - dbs.SqlFormat = ir.SqlFormatPgsql8 - ops := NewOperations(dbs).(*Operations) + conf := DefaultConfig + ops := NewOperations(conf).(*Operations) statements, err := ops.CreateStatements(ir.FullFeatureSchema(role)) if err != nil { t.Fatal(err) diff --git a/lib/format/pgsql8/operations.go b/lib/format/pgsql8/operations.go index 3d87b70..afc6fde 100644 --- a/lib/format/pgsql8/operations.go +++ b/lib/format/pgsql8/operations.go @@ -22,32 +22,32 @@ import ( ) type Operations struct { - logger *slog.Logger - dbsteward *lib.DBSteward - differ *diff + logger *slog.Logger + config lib.Config + differ *diff } var quoter output.Quoter -func defaultQuoter(dbsteward *lib.DBSteward) output.Quoter { +func defaultQuoter(c lib.Config) output.Quoter { return &sql.Quoter{ - Logger: dbsteward.Logger(), - ShouldQuoteSchemaNames: dbsteward.QuoteAllNames || dbsteward.QuoteSchemaNames, - ShouldQuoteTableNames: dbsteward.QuoteAllNames || dbsteward.QuoteTableNames, - ShouldQuoteColumnNames: dbsteward.QuoteAllNames || dbsteward.QuoteColumnNames, - ShouldQuoteObjectNames: dbsteward.QuoteAllNames || dbsteward.QuoteObjectNames, - ShouldQuoteIllegalIdentifiers: dbsteward.QuoteIllegalIdentifiers, - ShouldQuoteReservedIdentifiers: dbsteward.QuoteReservedIdentifiers, + Logger: c.Logger, + 
ShouldQuoteSchemaNames: c.QuoteAllNames || c.QuoteSchemaNames, + ShouldQuoteTableNames: c.QuoteAllNames || c.QuoteTableNames, + ShouldQuoteColumnNames: c.QuoteAllNames || c.QuoteColumnNames, + ShouldQuoteObjectNames: c.QuoteAllNames || c.QuoteObjectNames, + ShouldQuoteIllegalIdentifiers: c.QuoteIllegalIdentifiers, + ShouldQuoteReservedIdentifiers: c.QuoteReservedIdentifiers, ShouldEEscape: false, - RequireVerboseIntervalNotation: dbsteward.RequireVerboseIntervalNotation, + RequireVerboseIntervalNotation: c.RequireVerboseIntervalNotation, } } -func NewOperations(dbs *lib.DBSteward) lib.Operations { - quoter = defaultQuoter(dbs) +func NewOperations(c lib.Config) lib.Operations { + quoter = defaultQuoter(c) ops := &Operations{ - logger: dbs.Logger(), - dbsteward: dbs, + logger: c.Logger, + config: c, } ops.differ = newDiff(ops, quoter) return ops @@ -76,7 +76,7 @@ func (ops *Operations) Build(outputPrefix string, dbDoc *ir.Definition) error { return fmt.Errorf("failed to open file %s for output: %w", buildFileName, err) } - buildFileOfs := output.NewOutputFileSegmenterToFile(ops.logger, ops.GetQuoter(), buildFileName, 1, buildFile, buildFileName, ops.dbsteward.OutputFileStatementLimit) + buildFileOfs := output.NewOutputFileSegmenterToFile(ops.logger, ops.GetQuoter(), buildFileName, 1, buildFile, buildFileName, ops.config.OutputFileStatementLimit) err = ops.build(buildFileOfs, dbDoc) if err != nil { return err @@ -87,10 +87,10 @@ func (ops *Operations) Build(outputPrefix string, dbDoc *ir.Definition) error { func (ops *Operations) build(buildFileOfs output.OutputFileSegmenter, dbDoc *ir.Definition) error { // TODO(go,4) can we just consider a build(def) to be diff(null, def)? - if len(ops.dbsteward.LimitToTables) == 0 { + if len(ops.config.LimitToTables) == 0 { buildFileOfs.WriteSql(sql.NewComment("full database definition file generated %s\n", time.Now().Format(time.RFC1123Z))) } - if !ops.dbsteward.GenerateSlonik { + if !ops.config.GenerateSlonik { buildFileOfs.WriteSql(output.NewRawSQL("BEGIN;\n\n")) } @@ -101,12 +101,12 @@ func (ops *Operations) build(buildFileOfs output.OutputFileSegmenter, dbDoc *ir. 
} // database-specific implementation code refers to dbsteward::$new_database when looking up roles/values/conflicts etc - ops.dbsteward.NewDatabase = dbDoc + ops.config.NewDatabase = dbDoc // language definitions - if ops.dbsteward.CreateLanguages { + if ops.config.CreateLanguages { for _, language := range dbDoc.Languages { - s, err := getCreateLanguageSql(ops.dbsteward, language) + s, err := getCreateLanguageSql(ops.config, language) if err != nil { return err } @@ -159,23 +159,23 @@ outer: ops.logger.Info(setCheckFunctionBodiesInfo) } - if ops.dbsteward.OnlySchemaSql || !ops.dbsteward.OnlyDataSql { + if ops.config.OnlySchemaSql || !ops.config.OnlyDataSql { ops.logger.Info("Defining structure") err := ops.buildSchema(dbDoc, buildFileOfs, tableDependency) if err != nil { return err } } - if !ops.dbsteward.OnlySchemaSql || ops.dbsteward.OnlyDataSql { + if !ops.config.OnlySchemaSql || ops.config.OnlyDataSql { ops.logger.Info("Defining data inserts") err = ops.buildData(ops.logger, dbDoc, buildFileOfs, tableDependency) if err != nil { return err } } - ops.dbsteward.NewDatabase = nil + ops.config.NewDatabase = nil - if !ops.dbsteward.GenerateSlonik { + if !ops.config.GenerateSlonik { buildFileOfs.WriteSql(output.NewRawSQL("COMMIT;\n\n")) } @@ -223,8 +223,8 @@ func (ops *Operations) Upgrade(l *slog.Logger, oldDoc *ir.Definition, newDoc *ir if err != nil { return nil, fmt.Errorf("new document: %w", err) } - ops.dbsteward.OldDatabase = oldDoc - ops.dbsteward.NewDatabase = newDoc + ops.config.OldDatabase = oldDoc + ops.config.NewDatabase = newDoc stage1 := output.NewSegmenter(ops.GetQuoter()) stage2 := output.NewSegmenter(ops.GetQuoter()) @@ -957,7 +957,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // TODO(go,3) roll this into diffing nil -> doc // schema creation for _, schema := range doc.Schemas { - s, err := commonSchema.GetCreationSql(ops.dbsteward, schema) + s, err := commonSchema.GetCreationSql(ops.config, schema) if err != nil { return err } @@ -965,7 +965,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // schema grants for _, grant := range schema.Grants { - s, err := commonSchema.GetGrantSql(ops.dbsteward, doc, schema, grant) + s, err := commonSchema.GetGrantSql(ops.config, doc, schema, grant) if err != nil { return err } @@ -990,7 +990,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm includeColumnDefaultNextvalInCreateSql = false for _, table := range schema.Tables { // table definition - s, err := getCreateTableSql(ops.dbsteward, schema, table) + s, err := getCreateTableSql(ops.config, schema, table) if err != nil { return err } @@ -1004,7 +1004,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // table grants for _, grant := range table.Grants { - s, err := getTableGrantSql(ops.dbsteward, schema, table, grant) + s, err := getTableGrantSql(ops.config, schema, table, grant) if err != nil { return err } @@ -1016,7 +1016,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // sequences contained in the schema for _, sequence := range schema.Sequences { if sequence.OwnedByColumn == "" { - sql, err := getCreateSequenceSql(ops.dbsteward, schema, sequence) + sql, err := getCreateSequenceSql(ops.config, schema, sequence) if err != nil { return err } @@ -1029,7 +1029,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // sequence permission grants for _, grant := range 
sequence.Grants { - s, err := getSequenceGrantSql(ops.dbsteward, schema, sequence, grant) + s, err := getSequenceGrantSql(ops.config, schema, sequence, grant) if err != nil { return err } @@ -1049,7 +1049,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm for _, schema := range doc.Schemas { for _, function := range schema.Functions { if function.HasDefinition(ir.SqlFormatPgsql8) { - s, err := getFunctionCreationSql(ops.dbsteward, schema, function) + s, err := getFunctionCreationSql(ops.config, schema, function) if err != nil { return err } @@ -1058,7 +1058,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // they are not included in pg_function::get_creation_sql() for _, grant := range function.Grants { - grant, err := getFunctionGrantSql(ops.dbsteward, schema, function, grant) + grant, err := getFunctionGrantSql(ops.config, schema, function, grant) if err != nil { return err } @@ -1079,7 +1079,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // define table primary keys before foreign keys so unique requirements are always met for FOREIGN KEY constraints for _, schema := range doc.Schemas { for _, table := range schema.Tables { - err := createConstraintsTable(ops.dbsteward, ofs, nil, nil, schema, table, sql99.ConstraintTypePrimaryKey) + err := createConstraintsTable(ops.config, ofs, nil, nil, schema, table, sql99.ConstraintTypePrimaryKey) if err != nil { return err } @@ -1090,7 +1090,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // use the dependency order to specify foreign keys in an order that will satisfy nested foreign keys and etc // TODO(feat) shouldn't this consider GlobalDBSteward.LimitToTables like BuildData does? 
for _, entry := range tableDep { - err := createConstraintsTable(ops.dbsteward, ofs, nil, nil, entry.Schema, entry.Table, sql99.ConstraintTypeConstraint) + err := createConstraintsTable(ops.config, ofs, nil, nil, entry.Schema, entry.Table, sql99.ConstraintTypeConstraint) if err != nil { return err } @@ -1109,7 +1109,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm } } - err := createViewsOrdered(ops.dbsteward, ofs, nil, doc) + err := createViewsOrdered(ops.config, ofs, nil, doc) if err != nil { return err } @@ -1118,7 +1118,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm for _, schema := range doc.Schemas { for _, view := range schema.Views { for _, grant := range view.Grants { - s, err := getViewGrantSql(ops.dbsteward, doc, schema, view, grant) + s, err := getViewGrantSql(ops.config, doc, schema, view, grant) if err != nil { return err } @@ -1132,7 +1132,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm } func (ops *Operations) buildData(_ *slog.Logger, doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep []*ir.TableRef) error { - limitToTables := ops.dbsteward.LimitToTables + limitToTables := ops.config.LimitToTables // use the dependency order to then write out the actual data inserts into the data sql file for _, entry := range tableDep { @@ -1222,7 +1222,7 @@ func (ops *Operations) columnValueDefault(l *slog.Logger, schema *ir.Schema, tab } } - col, err := ops.dbsteward.NewDatabase.TryInheritanceGetColumn(schema, table, columnName) + col, err := ops.config.NewDatabase.TryInheritanceGetColumn(schema, table, columnName) if err != nil { return nil, fmt.Errorf("TryInheritanceGetColumn %w", err) } @@ -1244,7 +1244,7 @@ func (ops *Operations) columnValueDefault(l *slog.Logger, schema *ir.Schema, tab return sql.RawSql(col.Default), nil } - colType, err := getColumnType(l, ops.dbsteward.NewDatabase, schema, table, col) + colType, err := getColumnType(l, ops.config.NewDatabase, schema, table, col) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/operations_column_value_default_test.go b/lib/format/pgsql8/operations_column_value_default_test.go index d42a8fd..a91d901 100644 --- a/lib/format/pgsql8/operations_column_value_default_test.go +++ b/lib/format/pgsql8/operations_column_value_default_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/assert" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/ir" ) @@ -127,12 +126,11 @@ func getColumnValueDefault(def *ir.Column, data *ir.DataCol) (string, error) { }, }, } - dbs := lib.NewDBSteward() - dbs.NewDatabase = doc schema := doc.Schemas[0] table := schema.Tables[0] - - ops := NewOperations(dbs).(*Operations) + conf := DefaultConfig + conf.NewDatabase = doc + ops := NewOperations(conf).(*Operations) // TODO(go,nth) can we do this without also testing GetValueSql? 
toVal, err := ops.columnValueDefault(slog.Default(), schema, table, def.Name, data) diff --git a/lib/format/pgsql8/operations_extract_schema_test.go b/lib/format/pgsql8/operations_extract_schema_test.go index 1f0392e..480818e 100644 --- a/lib/format/pgsql8/operations_extract_schema_test.go +++ b/lib/format/pgsql8/operations_extract_schema_test.go @@ -5,7 +5,6 @@ import ( "strings" "testing" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/util" "github.com/jackc/pgtype" @@ -107,7 +106,7 @@ func TestOperations_ExtractSchema_Indexes(t *testing.T) { }, }, } - ops := NewOperations(lib.NewDBSteward()).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -163,7 +162,7 @@ func TestOperations_ExtractSchema_CompoundUniqueConstraint(t *testing.T) { }, }, } - ops := NewOperations(lib.NewDBSteward()).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -217,7 +216,7 @@ func TestOperations_ExtractSchema_TableComments(t *testing.T) { }, }, } - ops := NewOperations(lib.NewDBSteward()).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -259,7 +258,7 @@ END; }, }, } - ops := NewOperations(lib.NewDBSteward()).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -343,7 +342,7 @@ func TestOperations_ExtractSchema_FunctionArgs(t *testing.T) { }, }, } - ops := NewOperations(lib.NewDBSteward()).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -386,7 +385,7 @@ func TestOperations_ExtractSchema_TableArrayType(t *testing.T) { }, }, } - ops := NewOperations(lib.NewDBSteward()).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -446,7 +445,7 @@ func TestOperations_ExtractSchema_FKReferentialConstraints(t *testing.T) { }, }, } - ops := NewOperations(lib.NewDBSteward()).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -505,7 +504,7 @@ func TestOperations_ExtractSchema_Sequences(t *testing.T) { {Schema: "public", Table: "user", Name: "user_pkey", Type: "p", Columns: []string{"user_id"}}, }, } - ops := NewOperations(lib.NewDBSteward()).(*Operations) + ops := NewOperations(DefaultConfig).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) diff --git a/lib/format/pgsql8/pgsql8.go b/lib/format/pgsql8/pgsql8.go index 12e1153..cd28a12 100644 --- a/lib/format/pgsql8/pgsql8.go +++ b/lib/format/pgsql8/pgsql8.go @@ -1,6 +1,8 @@ package pgsql8 import ( + "log/slog" + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/ir" ) @@ -10,3 +12,36 @@ var commonSchema = NewSchema() func init() { lib.RegisterFormat(ir.SqlFormatPgsql8, NewOperations) } + +var DefaultConfig = lib.Config{ + Logger: slog.Default(), + SqlFormat: ir.SqlFormatPgsql8, + CreateLanguages: false, + RequireSlonyId: false, + RequireSlonySetId: false, + GenerateSlonik: false, + SlonyIdStartValue: 1, + 
SlonyIdSetValue: 1, + OutputFileStatementLimit: 999999, + IgnoreCustomRoles: false, + IgnorePrimaryKeyErrors: false, + RequireVerboseIntervalNotation: false, + QuoteSchemaNames: false, + QuoteObjectNames: false, + QuoteTableNames: false, + QuoteFunctionNames: false, + QuoteColumnNames: false, + QuoteAllNames: false, + QuoteIllegalIdentifiers: true, + QuoteReservedIdentifiers: true, + OnlySchemaSql: false, + OnlyDataSql: false, + LimitToTables: map[string][]string{}, + SingleStageUpgrade: false, + FileOutputDirectory: "", + FileOutputPrefix: "", + IgnoreOldNames: false, + AlwaysRecreateViews: true, + OldDatabase: nil, + NewDatabase: nil, +} diff --git a/lib/format/pgsql8/pgsql8_main_test.go b/lib/format/pgsql8/pgsql8_main_test.go index e5e20de..fe4ec09 100644 --- a/lib/format/pgsql8/pgsql8_main_test.go +++ b/lib/format/pgsql8/pgsql8_main_test.go @@ -5,9 +5,9 @@ import ( "github.com/dbsteward/dbsteward/lib/ir" ) -func setOldNewDocs(dbs *lib.DBSteward, differ *diff, old, new *ir.Definition) { - dbs.OldDatabase = old - dbs.NewDatabase = new +func setOldNewDocs(conf lib.Config, differ *diff, old, new *ir.Definition) lib.Config { + conf.OldDatabase = old + conf.NewDatabase = new var err error if old != nil { differ.OldTableDependency, err = old.TableDependencyOrder() @@ -21,4 +21,5 @@ func setOldNewDocs(dbs *lib.DBSteward, differ *diff, old, new *ir.Definition) { panic(err) } } + return conf } diff --git a/lib/format/pgsql8/schema.go b/lib/format/pgsql8/schema.go index 5490e0c..a324d0d 100644 --- a/lib/format/pgsql8/schema.go +++ b/lib/format/pgsql8/schema.go @@ -19,7 +19,7 @@ func NewSchema() *Schema { return &Schema{} } -func (s *Schema) GetCreationSql(dbs *lib.DBSteward, schema *ir.Schema) ([]output.ToSql, error) { +func (s *Schema) GetCreationSql(conf lib.Config, schema *ir.Schema) ([]output.ToSql, error) { // don't create the public schema if strings.EqualFold(schema.Name, "public") { return nil, nil @@ -30,7 +30,7 @@ func (s *Schema) GetCreationSql(dbs *lib.DBSteward, schema *ir.Schema) ([]output } if schema.Owner != "" { - owner, err := roleEnum(dbs.Logger(), dbs.NewDatabase, schema.Owner, dbs.IgnoreCustomRoles) + owner, err := roleEnum(conf.Logger, conf.NewDatabase, schema.Owner, conf.IgnoreCustomRoles) if err != nil { return nil, err } @@ -53,11 +53,11 @@ func (s *Schema) GetDropSql(schema *ir.Schema) []output.ToSql { } } -func (s *Schema) GetGrantSql(dbs *lib.DBSteward, doc *ir.Definition, schema *ir.Schema, grant *ir.Grant) ([]output.ToSql, error) { +func (s *Schema) GetGrantSql(conf lib.Config, doc *ir.Definition, schema *ir.Schema, grant *ir.Grant) ([]output.ToSql, error) { roles := make([]string, len(grant.Roles)) var err error for i, role := range grant.Roles { - roles[i], err = roleEnum(dbs.Logger(), dbs.NewDatabase, role, dbs.IgnoreCustomRoles) + roles[i], err = roleEnum(conf.Logger, conf.NewDatabase, role, conf.IgnoreCustomRoles) if err != nil { return nil, err } @@ -81,7 +81,7 @@ func (s *Schema) GetGrantSql(dbs *lib.DBSteward, doc *ir.Definition, schema *ir. 
// SCHEMA IMPLICIT GRANTS // READYONLY USER PROVISION: grant usage on the schema for the readonly user // TODO(go,3) move this out of here, let this create just a single grant - roRole, err := roleEnum(dbs.Logger(), dbs.NewDatabase, ir.RoleReadOnly, dbs.IgnoreCustomRoles) + roRole, err := roleEnum(conf.Logger, conf.NewDatabase, ir.RoleReadOnly, conf.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/sequence.go b/lib/format/pgsql8/sequence.go index 2f67179..d22dd69 100644 --- a/lib/format/pgsql8/sequence.go +++ b/lib/format/pgsql8/sequence.go @@ -10,7 +10,7 @@ import ( "github.com/dbsteward/dbsteward/lib/util" ) -func getCreateSequenceSql(dbs *lib.DBSteward, schema *ir.Schema, sequence *ir.Sequence) ([]output.ToSql, error) { +func getCreateSequenceSql(conf lib.Config, schema *ir.Schema, sequence *ir.Sequence) ([]output.ToSql, error) { // TODO(go,3) put validation elsewhere cache, cacheValueSet := sequence.Cache.Maybe() if !cacheValueSet { @@ -38,7 +38,7 @@ func getCreateSequenceSql(dbs *lib.DBSteward, schema *ir.Schema, sequence *ir.Se if sequence.Owner != "" { // NOTE: Old dbsteward uses ALTER TABLE for this, which is valid according to docs, however // ALTER SEQUENCE also works in pgsql 8, and that's more correct - role, err := roleEnum(dbs.Logger(), dbs.NewDatabase, sequence.Owner, dbs.IgnoreCustomRoles) + role, err := roleEnum(conf.Logger, conf.NewDatabase, sequence.Owner, conf.IgnoreCustomRoles) if err != nil { return nil, err } @@ -66,11 +66,11 @@ func getDropSequenceSql(schema *ir.Schema, sequence *ir.Sequence) []output.ToSql } } -func getSequenceGrantSql(dbs *lib.DBSteward, schema *ir.Schema, seq *ir.Sequence, grant *ir.Grant) ([]output.ToSql, error) { +func getSequenceGrantSql(conf lib.Config, schema *ir.Schema, seq *ir.Sequence, grant *ir.Grant) ([]output.ToSql, error) { roles := make([]string, len(grant.Roles)) var err error for i, role := range grant.Roles { - roles[i], err = roleEnum(dbs.Logger(), dbs.NewDatabase, role, dbs.IgnoreCustomRoles) + roles[i], err = roleEnum(conf.Logger, conf.NewDatabase, role, conf.IgnoreCustomRoles) if err != nil { return nil, err } @@ -99,7 +99,7 @@ func getSequenceGrantSql(dbs *lib.DBSteward, schema *ir.Schema, seq *ir.Sequence // SEQUENCE IMPLICIT GRANTS // READYONLY USER PROVISION: generate a SELECT on the sequence for the readonly user // TODO(go,3) move this out of here, let this create just a single grant - roRole, err := roleEnum(dbs.Logger(), dbs.NewDatabase, ir.RoleReadOnly, dbs.IgnoreCustomRoles) + roRole, err := roleEnum(conf.Logger, conf.NewDatabase, ir.RoleReadOnly, conf.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/table.go b/lib/format/pgsql8/table.go index 032b0d2..43ac5ee 100644 --- a/lib/format/pgsql8/table.go +++ b/lib/format/pgsql8/table.go @@ -15,8 +15,8 @@ import ( var includeColumnDefaultNextvalInCreateSql bool -func getCreateTableSql(dbs *lib.DBSteward, schema *ir.Schema, table *ir.Table) ([]output.ToSql, error) { - l := dbs.Logger().With( +func getCreateTableSql(conf lib.Config, schema *ir.Schema, table *ir.Table) ([]output.ToSql, error) { + l := conf.Logger.With( slog.String("table", table.Name), slog.String("schema", schema.Name), ) @@ -24,7 +24,7 @@ func getCreateTableSql(dbs *lib.DBSteward, schema *ir.Schema, table *ir.Table) ( colSetup := []output.ToSql{} for _, col := range table.Columns { ll := l.With(slog.String("column", col.Name)) - newCol, err := getReducedColumnDefinition(ll, dbs.NewDatabase, schema, table, col) + newCol, err := 
getReducedColumnDefinition(ll, conf.NewDatabase, schema, table, col) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func getCreateTableSql(dbs *lib.DBSteward, schema *ir.Schema, table *ir.Table) ( ddl = append(ddl, colSetup...) if table.Owner != "" { - role, err := roleEnum(l, dbs.NewDatabase, table.Owner, dbs.IgnoreCustomRoles) + role, err := roleEnum(l, conf.NewDatabase, table.Owner, conf.IgnoreCustomRoles) if err != nil { return nil, err } @@ -125,11 +125,11 @@ func defineTableColumnDefaults(l *slog.Logger, schema *ir.Schema, table *ir.Tabl return out } -func getTableGrantSql(dbs *lib.DBSteward, schema *ir.Schema, table *ir.Table, grant *ir.Grant) ([]output.ToSql, error) { +func getTableGrantSql(conf lib.Config, schema *ir.Schema, table *ir.Table, grant *ir.Grant) ([]output.ToSql, error) { roles := make([]string, len(grant.Roles)) var err error for i, role := range grant.Roles { - roles[i], err = roleEnum(dbs.Logger(), dbs.NewDatabase, role, dbs.IgnoreCustomRoles) + roles[i], err = roleEnum(conf.Logger, conf.NewDatabase, role, conf.IgnoreCustomRoles) if err != nil { return nil, err } @@ -155,7 +155,7 @@ func getTableGrantSql(dbs *lib.DBSteward, schema *ir.Schema, table *ir.Table, gr // TABLE IMPLICIT GRANTS // READYONLY USER PROVISION: grant select on the table for the readonly user // TODO(go,3) move this out of here, let this create just a single grant - roRole, err := roleEnum(dbs.Logger(), dbs.NewDatabase, ir.RoleReadOnly, dbs.IgnoreCustomRoles) + roRole, err := roleEnum(conf.Logger, conf.NewDatabase, ir.RoleReadOnly, conf.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/table_test.go b/lib/format/pgsql8/table_test.go index 0cc013e..3433cf0 100644 --- a/lib/format/pgsql8/table_test.go +++ b/lib/format/pgsql8/table_test.go @@ -3,7 +3,6 @@ package pgsql8 import ( "testing" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" @@ -43,7 +42,7 @@ func TestTable_GetCreationSql_TableOptions(t *testing.T) { }, } - ddl, err := getCreateTableSql(lib.NewDBSteward(), schema, schema.Tables[0]) + ddl, err := getCreateTableSql(DefaultConfig, schema, schema.Tables[0]) if err != nil { t.Fatal(err) } diff --git a/lib/format/pgsql8/type.go b/lib/format/pgsql8/type.go index ad5ae21..92805fe 100644 --- a/lib/format/pgsql8/type.go +++ b/lib/format/pgsql8/type.go @@ -119,12 +119,12 @@ func isIntType(spec string) bool { } // Change all table columns that are the given datatype to a placeholder type -func alterColumnTypePlaceholder(dbs *lib.DBSteward, differ *diff, datatype *ir.TypeDef) ([]*ir.ColumnRef, []output.ToSql, error) { +func alterColumnTypePlaceholder(conf lib.Config, differ *diff, datatype *ir.TypeDef) ([]*ir.ColumnRef, []output.ToSql, error) { ddl := []output.ToSql{} cols := []*ir.ColumnRef{} for _, newTableRef := range differ.NewTableDependency { for _, newColumn := range newTableRef.Table.Columns { - columnType, err := getColumnType(dbs.Logger(), dbs.NewDatabase, newTableRef.Schema, newTableRef.Table, newColumn) + columnType, err := getColumnType(conf.Logger, conf.NewDatabase, newTableRef.Schema, newTableRef.Table, newColumn) if err != nil { return nil, nil, err } diff --git a/lib/format/pgsql8/view.go b/lib/format/pgsql8/view.go index 9be0416..b115a52 100644 --- a/lib/format/pgsql8/view.go +++ b/lib/format/pgsql8/view.go @@ -11,7 +11,7 @@ import ( "github.com/dbsteward/dbsteward/lib/util" ) -func getCreateViewSql(dbs 
*lib.DBSteward, schema *ir.Schema, view *ir.View) ([]output.ToSql, error) { +func getCreateViewSql(conf lib.Config, schema *ir.Schema, view *ir.View) ([]output.ToSql, error) { ref := sql.ViewRef{Schema: schema.Name, View: view.Name} query := view.TryGetViewQuery(ir.SqlFormatPgsql8) util.Assert(query != nil, "Calling View.GetCreationSql for a view not defined for this sqlformat") @@ -30,7 +30,7 @@ func getCreateViewSql(dbs *lib.DBSteward, schema *ir.Schema, view *ir.View) ([]o }) } if view.Owner != "" { - role, err := roleEnum(dbs.Logger(), dbs.NewDatabase, view.Owner, dbs.IgnoreCustomRoles) + role, err := roleEnum(conf.Logger, conf.NewDatabase, view.Owner, conf.IgnoreCustomRoles) if err != nil { return nil, err } @@ -51,12 +51,12 @@ func getDropViewSql(schema *ir.Schema, view *ir.View) []output.ToSql { } } -func getViewGrantSql(dbs *lib.DBSteward, doc *ir.Definition, schema *ir.Schema, view *ir.View, grant *ir.Grant) ([]output.ToSql, error) { +func getViewGrantSql(conf lib.Config, doc *ir.Definition, schema *ir.Schema, view *ir.View, grant *ir.Grant) ([]output.ToSql, error) { // NOTE: pgsql views use table grants! roles := make([]string, len(grant.Roles)) var err error for i, role := range grant.Roles { - roles[i], err = roleEnum(dbs.Logger(), dbs.NewDatabase, role, dbs.IgnoreCustomRoles) + roles[i], err = roleEnum(conf.Logger, conf.NewDatabase, role, conf.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/xml_parser_test.go b/lib/format/pgsql8/xml_parser_test.go index 635859d..df631d5 100644 --- a/lib/format/pgsql8/xml_parser_test.go +++ b/lib/format/pgsql8/xml_parser_test.go @@ -5,7 +5,6 @@ import ( "log/slog" "testing" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/ir" "github.com/stretchr/testify/assert" ) @@ -163,7 +162,7 @@ END;`, } // note that Process mutates the document in place - xmlParser := NewXmlParser(defaultQuoter(lib.NewDBSteward())) + xmlParser := NewXmlParser(defaultQuoter(DefaultConfig)) err := xmlParser.Process(slog.Default(), doc) if err != nil { t.Fatal(err) diff --git a/lib/loghandler.go b/lib/loghandler.go deleted file mode 100644 index 4d4d0d4..0000000 --- a/lib/loghandler.go +++ /dev/null @@ -1,74 +0,0 @@ -package lib - -import ( - "bytes" - "context" - "log/slog" - "strings" -) - -func newLogHandler(dbs *DBSteward) slog.Handler { - buf := bytes.Buffer{} - f := slog.NewTextHandler(&buf, nil) - return &logHandler{ - dbsteward: dbs, - formatter: f, - output: &buf, - } -} - -// logHandler is an intermediate step to support both slog logging -// and the old method of dbsteward logging -type logHandler struct { - dbsteward *DBSteward - formatter slog.Handler - output *bytes.Buffer -} - -// Enabled always returns true and let zerolog decide -func (h *logHandler) Enabled(_ context.Context, level slog.Level) bool { - return true -} - -func (h *logHandler) WithAttrs(attrs []slog.Attr) slog.Handler { - return &logHandler{ - dbsteward: h.dbsteward, - output: h.output, - formatter: h.formatter.WithAttrs(attrs), - } -} - -func (h *logHandler) WithGroup(name string) slog.Handler { - return &logHandler{ - dbsteward: h.dbsteward, - output: h.output, - formatter: h.formatter.WithGroup(name), - } -} - -// Handle is a bit of a hack. 
Just using TextFormatter to do the actual -// handling and and then extracting the result from the byte buffer to -// send it to zerolog as an intermediate step that maintains nearly -// the same behavior as previous while still supporting all of slog's -// features -func (h *logHandler) Handle(ctx context.Context, r slog.Record) error { - h.formatter.Handle(ctx, r) - msg := strings.TrimSpace(h.output.String()) - if msg == "" { - msg = "<>" - } - switch r.Level { - case slog.LevelDebug: - h.dbsteward.logger.Debug().Msgf(msg) - case slog.LevelInfo: - h.dbsteward.logger.Info().Msgf(msg) - case slog.LevelWarn: - h.dbsteward.logger.Warn().Msgf(msg) - default: - // Should be Error, but in case other levels get define at - // least nothing gets lost - h.dbsteward.logger.Error().Msgf(msg) - } - h.output.Reset() - return nil -} diff --git a/lib/slonik.go b/lib/slonik.go index 8cb9402..2142057 100644 --- a/lib/slonik.go +++ b/lib/slonik.go @@ -1,5 +1,10 @@ package lib +type SlonyOperations interface { + SlonyCompare(file string) + SlonyDiff(oldFile, newFile string) +} + type Slonik struct{} func NewSlonik() *Slonik { diff --git a/main.go b/main.go index 942ecd7..798db77 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,710 @@ package main import ( + "bytes" + "context" + "fmt" + "log" + "log/slog" + "os" + "path" + "strings" + + "github.com/alexflint/go-arg" "github.com/dbsteward/dbsteward/lib" + "github.com/dbsteward/dbsteward/lib/config" + "github.com/dbsteward/dbsteward/lib/encoding/xml" _ "github.com/dbsteward/dbsteward/lib/format/pgsql8" + "github.com/dbsteward/dbsteward/lib/ir" + "github.com/dbsteward/dbsteward/lib/util" + "github.com/hashicorp/go-multierror" + "github.com/rs/zerolog" ) func main() { - dbsteward := lib.NewDBSteward() + dbsteward := NewDBSteward() dbsteward.ArgParse() dbsteward.Info("Done") } + +// NOTE: 2.0.0 is the intended golang release. 
3.0.0 is the intended refactor/modernization +const Version = "2.0.0" + +// NOTE: we're attempting to maintain "api" compat with legacy dbsteward for now +const ApiVersion = "1.4" + +type Mode uint + +const ( + ModeUnknown Mode = 0 + ModeXmlDataInsert Mode = 1 + ModeXmlSort Mode = 2 + ModeXmlConvert Mode = 4 + ModeBuild Mode = 8 + ModeDiff Mode = 16 + ModeExtract Mode = 32 + ModeDbDataDiff Mode = 64 + ModeXmlSlonyId Mode = 73 + ModeSqlDiff Mode = 128 + ModeSlonikConvert Mode = 256 + ModeSlonyCompare Mode = 512 + ModeSlonyDiff Mode = 1024 +) + +type DBSteward struct { + logger zerolog.Logger + config lib.Config +} + +func NewDBSteward() *DBSteward { + dbsteward := &DBSteward{ + logger: zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}).With().Timestamp().Logger(), + config: lib.Config{ + SqlFormat: lib.DefaultSqlFormat, + CreateLanguages: false, + RequireSlonyId: false, + RequireSlonySetId: false, + GenerateSlonik: false, + SlonyIdStartValue: 1, + SlonyIdSetValue: 1, + OutputFileStatementLimit: 900, + IgnoreCustomRoles: false, + IgnorePrimaryKeyErrors: false, + RequireVerboseIntervalNotation: false, + QuoteSchemaNames: false, + QuoteObjectNames: false, + QuoteTableNames: false, + QuoteFunctionNames: false, + QuoteColumnNames: false, + QuoteAllNames: false, + QuoteIllegalIdentifiers: false, + QuoteReservedIdentifiers: false, + OnlySchemaSql: false, + OnlyDataSql: false, + LimitToTables: map[string][]string{}, + SingleStageUpgrade: false, + FileOutputDirectory: "", + FileOutputPrefix: "", + IgnoreOldNames: false, + AlwaysRecreateViews: true, + OldDatabase: nil, + NewDatabase: nil, + }, + } + + return dbsteward +} + +// correlates to dbsteward->arg_parse() +func (dbsteward *DBSteward) ArgParse() { + // TODO(go,nth): deck this out with better go-arg config + args := &config.Args{} + arg.MustParse(args) + + dbsteward.setVerbosity(args) + + // XML file parameter sanity checks + if len(args.XmlFiles) > 0 { + if len(args.OldXmlFiles) > 0 { + dbsteward.fatal("Parameter error: xml and oldxml options are not to be mixed. Did you mean newxml?") + } + if len(args.NewXmlFiles) > 0 { + dbsteward.fatal("Parameter error: xml and newxml options are not to be mixed. 
Did you mean oldxml?")
+		}
+	}
+	if len(args.OldXmlFiles) > 0 && len(args.NewXmlFiles) == 0 {
+		dbsteward.fatal("Parameter error: oldxml needs newxml specified for differencing to occur")
+	}
+	if len(args.NewXmlFiles) > 0 && len(args.OldXmlFiles) == 0 {
+		dbsteward.fatal("Parameter error: newxml needs oldxml specified for differencing to occur")
+	}
+	dbsteward.config.Logger = slog.New(newLogHandler(dbsteward))
+	// database connectivity values
+	// dbsteward.dbHost = args.DbHost
+	// dbsteward.dbPort = args.DbPort
+	// dbsteward.dbName = args.DbName
+	// dbsteward.dbUser = args.DbUser
+	// dbsteward.dbPass = args.DbPassword
+
+	// SQL DDL DML DCL output flags
+	dbsteward.config.OnlySchemaSql = args.OnlySchemaSql
+	dbsteward.config.OnlyDataSql = args.OnlyDataSql
+	for _, onlyTable := range args.OnlyTables {
+		table := lib.ParseQualifiedTableName(onlyTable)
+		dbsteward.config.LimitToTables[table.Schema] = append(dbsteward.config.LimitToTables[table.Schema], table.Table)
+	}
+
+	// XML parsing switches
+	dbsteward.config.SingleStageUpgrade = args.SingleStageUpgrade
+	if dbsteward.config.SingleStageUpgrade {
+		// don't recreate views when in single stage upgrade mode
+		// TODO(feat) make view diffing smart enough that this doesn't need to be done
+		dbsteward.config.AlwaysRecreateViews = false
+	}
+	dbsteward.config.IgnoreOldNames = args.IgnoreOldNames
+	dbsteward.config.IgnoreCustomRoles = args.IgnoreCustomRoles
+	dbsteward.config.IgnorePrimaryKeyErrors = args.IgnorePrimaryKeyErrors
+	dbsteward.config.RequireSlonyId = args.RequireSlonyId
+	dbsteward.config.RequireSlonySetId = args.RequireSlonySetId
+	dbsteward.config.GenerateSlonik = args.GenerateSlonik
+	dbsteward.config.SlonyIdStartValue = args.SlonyIdStartValue
+	dbsteward.config.SlonyIdSetValue = args.SlonyIdSetValue
+
+	// determine operation and check arguments for each
+	mode := ModeUnknown
+	switch {
+	case len(args.XmlDataInsert) > 0:
+		mode = ModeXmlDataInsert
+	case len(args.XmlSort) > 0:
+		mode = ModeXmlSort
+	case len(args.XmlConvert) > 0:
+		mode = ModeXmlConvert
+	case len(args.XmlFiles) > 0:
+		mode = ModeBuild
+	case len(args.NewXmlFiles) > 0:
+		mode = ModeDiff
+	case args.DbSchemaDump:
+		mode = ModeExtract
+	case len(args.DbDataDiff) > 0:
+		mode = ModeDbDataDiff
+	case len(args.OldSql) > 0 || len(args.NewSql) > 0:
+		mode = ModeSqlDiff
+	case len(args.SlonikConvert) > 0:
+		mode = ModeSlonikConvert
+	case len(args.SlonyCompare) > 0:
+		mode = ModeSlonyCompare
+	case len(args.SlonyDiffOld) > 0:
+		mode = ModeSlonyDiff
+	case len(args.SlonyIdIn) > 0:
+		mode = ModeXmlSlonyId
+	}
+
+	// validate mode parameters
+	if mode == ModeXmlDataInsert {
+		if len(args.XmlFiles) == 0 {
+			dbsteward.fatal("xmldatainsert needs xml parameter defined")
+		} else if len(args.XmlFiles) > 1 {
+			dbsteward.fatal("xmldatainsert only supports one xml file")
+		}
+	}
+	if mode == ModeExtract || mode == ModeDbDataDiff {
+		if len(args.DbHost) == 0 {
+			dbsteward.fatal("dbhost not specified")
+		}
+		if len(args.DbName) == 0 {
+			dbsteward.fatal("dbname not specified")
+		}
+		if len(args.DbUser) == 0 {
+			dbsteward.fatal("dbuser not specified")
+		}
+		if args.DbPassword == nil {
+			p, err := util.PromptPassword("[DBSteward] Enter password for postgres://%s@%s:%d/%s: ", args.DbUser, args.DbHost, args.DbPort, args.DbName)
+			dbsteward.fatalIfError(err, "Could not read password input")
+			args.DbPassword = &p
+		}
+	}
+	if mode == ModeExtract || mode == ModeSqlDiff {
+		if len(args.OutputFile) == 0 {
+			dbsteward.fatal("output file not specified")
+		}
+	}
+	if mode ==
ModeXmlSlonyId { + if len(args.SlonyIdOut) > 0 { + if args.SlonyIdIn[0] == args.SlonyIdOut { + // TODO(go,nth) resolve filepaths to do this correctly + // TODO(go,nth) check all SlonyIdIn elements + dbsteward.fatal("slonyidin and slonyidout file paths should not be the same") + } + } + } + + if len(args.OutputDir) > 0 { + if !util.IsDir(args.OutputDir) { + dbsteward.fatal("outputdir is not a directory, must be a writable directory") + } + dbsteward.config.FileOutputDirectory = args.OutputDir + } + dbsteward.config.FileOutputPrefix = args.OutputFilePrefix + + if args.XmlCollectDataAddendums > 0 { + if mode != ModeDbDataDiff { + dbsteward.fatal("--xmlcollectdataaddendums is only supported for fresh builds") + } + // dammit go + // invalid operation: args.XmlCollectDataAddendums > len(args.XmlFiles) (mismatched types uint and int) + if int(args.XmlCollectDataAddendums) > len(args.XmlFiles) { + dbsteward.fatal("Cannot collect more data addendums than files provided") + } + } + + dbsteward.Info("DBSteward Version %s", Version) + + // set the global sql format + dbsteward.config.SqlFormat = dbsteward.reconcileSqlFormat(ir.SqlFormatPgsql8, args.SqlFormat) + dbsteward.Info("Using sqlformat=%s", dbsteward.config.SqlFormat) + dbsteward.defineSqlFormatDefaultValues(dbsteward.config.SqlFormat, args) + + dbsteward.config.QuoteSchemaNames = args.QuoteSchemaNames + dbsteward.config.QuoteTableNames = args.QuoteTableNames + dbsteward.config.QuoteColumnNames = args.QuoteColumnNames + dbsteward.config.QuoteAllNames = args.QuoteAllNames + dbsteward.config.QuoteIllegalIdentifiers = args.QuoteIllegalNames + dbsteward.config.QuoteReservedIdentifiers = args.QuoteReservedNames + + // TODO(go,3) move all of these to separate subcommands + switch mode { + case ModeXmlDataInsert: + dbsteward.doXmlDataInsert(args.XmlFiles[0], args.XmlDataInsert) + case ModeXmlSort: + dbsteward.doXmlSort(args.XmlSort) + case ModeXmlConvert: + dbsteward.doXmlConvert(args.XmlConvert) + case ModeXmlSlonyId: + dbsteward.doXmlSlonyId(args.SlonyIdIn, args.SlonyIdOut) + case ModeBuild: + dbsteward.doBuild(args.XmlFiles, args.PgDataXml, args.XmlCollectDataAddendums) + case ModeDiff: + dbsteward.doDiff(args.OldXmlFiles, args.NewXmlFiles, args.PgDataXml) + case ModeExtract: + dbsteward.doExtract(args.DbHost, args.DbPort, args.DbName, args.DbUser, *args.DbPassword, args.OutputFile) + case ModeDbDataDiff: + dbsteward.doDbDataDiff(args.XmlFiles, args.PgDataXml, args.XmlCollectDataAddendums, args.DbHost, args.DbPort, args.DbName, args.DbUser, *args.DbPassword) + case ModeSqlDiff: + dbsteward.doSqlDiff(args.OldSql, args.NewSql, args.OutputFile) + case ModeSlonikConvert: + dbsteward.doSlonikConvert(args.SlonikConvert, args.OutputFile) + case ModeSlonyCompare: + dbsteward.doSlonyCompare(args.SlonyCompare) + case ModeSlonyDiff: + dbsteward.doSlonyDiff(args.SlonyDiffOld, args.SlonyDiffNew) + default: + dbsteward.fatal("No operation specified") + } +} + +// Logger returns an *slog.Logger pointed at the console +func (dbsteward *DBSteward) Logger() *slog.Logger { + if dbsteward == nil { + panic("dbsteward is nil") + } + if dbsteward.config.Logger == nil { + dbsteward.config.Logger = slog.New(newLogHandler(dbsteward)) + } + return dbsteward.config.Logger +} + +func (dbsteward *DBSteward) fatal(s string, args ...interface{}) { + dbsteward.logger.Fatal().Msgf(s, args...) +} +func (dbsteward *DBSteward) fatalIfError(err error, s string, args ...interface{}) { + if err != nil { + dbsteward.logger.Fatal().Err(err).Msgf(s, args...) 
+	}
+}
+
+func (dbsteward *DBSteward) warning(s string, args ...interface{}) {
+	dbsteward.logger.Warn().Msgf(s, args...)
+}
+
+func (dbsteward *DBSteward) Info(s string, args ...interface{}) {
+	dbsteward.logger.Info().Msgf(s, args...)
+}
+
+// dbsteward::set_verbosity($options)
+func (dbsteward *DBSteward) setVerbosity(args *config.Args) {
+	// TODO(go,nth): differentiate between notice and info
+
+	// remember, lower level is higher verbosity
+	// we're abusing the fact that zerolog.Level is defined as an int8
+	level := zerolog.InfoLevel
+
+	if args.Debug {
+		level = zerolog.TraceLevel
+	}
+
+	for _, v := range args.Verbose {
+		if v {
+			level -= 1
+		} else {
+			level += 1
+		}
+	}
+	for _, q := range args.Quiet {
+		if q {
+			level += 1
+		} else {
+			level -= 1
+		}
+	}
+
+	// clamp it to valid values
+	if level > zerolog.PanicLevel {
+		level = zerolog.PanicLevel
+	}
+	if level < zerolog.TraceLevel {
+		level = zerolog.TraceLevel
+	}
+
+	dbsteward.logger = dbsteward.logger.Level(level)
+}
+
+func (dbsteward *DBSteward) reconcileSqlFormat(target, requested ir.SqlFormat) ir.SqlFormat {
+	if target != ir.SqlFormatUnknown {
+		if requested != ir.SqlFormatUnknown {
+			if target == requested {
+				return target
+			}
+
+			dbsteward.warning("XML is targeted for %s but you are forcing %s. Things will probably break!", target, requested)
+			return requested
+		}
+
+		dbsteward.Info("XML file(s) are targeted for sqlformat=%s", target)
+		return target
+	}
+
+	if requested != ir.SqlFormatUnknown {
+		return requested
+	}
+
+	return lib.DefaultSqlFormat
+}
+
+func (dbsteward *DBSteward) defineSqlFormatDefaultValues(SqlFormat ir.SqlFormat, args *config.Args) {
+	switch SqlFormat {
+	case ir.SqlFormatPgsql8:
+		dbsteward.config.CreateLanguages = true
+		dbsteward.config.QuoteSchemaNames = false
+		dbsteward.config.QuoteTableNames = false
+		dbsteward.config.QuoteColumnNames = false
+		if args.DbPort == 0 {
+			args.DbPort = 5432
+		}
+	}
+
+	if SqlFormat != ir.SqlFormatPgsql8 {
+		if len(args.PgDataXml) > 0 {
+			dbsteward.fatal("pgdataxml parameter is not supported by %s driver", SqlFormat)
+		}
+	}
+}
+
+func (dbsteward *DBSteward) calculateFileOutputPrefix(files []string) string {
+	return path.Join(
+		dbsteward.calculateFileOutputDirectory(files[0]),
+		util.CoalesceStr(dbsteward.config.FileOutputPrefix, util.Basename(files[0], ".xml")),
+	)
+}
+func (dbsteward *DBSteward) calculateFileOutputDirectory(file string) string {
+	return util.CoalesceStr(dbsteward.config.FileOutputDirectory, path.Dir(file))
+}
+
+// Append columns in a table's rows collection, based on a simplified XML definition of what to insert
+func (dbsteward *DBSteward) doXmlDataInsert(defFile string, dataFile string) {
+	// TODO(go,xmlutil) verify this behavior is correct, add tests. need to change fatals to returns
+	dbsteward.Info("Automatic insert data into %s from %s", defFile, dataFile)
+	defDoc, err := xml.LoadDefintion(defFile)
+	dbsteward.fatalIfError(err, "Failed to load %s", defFile)
+
+	dataDoc, err := xml.LoadDefintion(dataFile)
+	dbsteward.fatalIfError(err, "Failed to load %s", dataFile)
+
+	for _, dataSchema := range dataDoc.Schemas {
+		defSchema, err := defDoc.GetSchemaNamed(dataSchema.Name)
+		dbsteward.fatalIfError(err, "while searching %s", defFile)
+		for _, dataTable := range dataSchema.Tables {
+			defTable, err := defSchema.GetTableNamed(dataTable.Name)
+			dbsteward.fatalIfError(err, "while searching %s", defFile)
+
+			dataRows := dataTable.Rows
+			if dataRows == nil {
+				dbsteward.fatal("table %s in %s does not have a <rows> element", dataTable.Name, dataFile)
+			}
+
+			if len(dataRows.Columns) == 0 {
+				dbsteward.fatal("Unexpected: no rows[columns] found in table %s in file %s", dataTable.Name, dataFile)
+			}
+
+			if len(dataRows.Rows) > 1 {
+				dbsteward.fatal("Unexpected: more than one rows->row found in table %s in file %s", dataTable.Name, dataFile)
+			}
+
+			if len(dataRows.Rows[0].Columns) != len(dataRows.Columns) {
+				dbsteward.fatal("Unexpected: Table %s in %s defines %d columns but has %d <col> elements",
+					dataTable.Name, dataFile, len(dataRows.Columns), len(dataRows.Rows[0].Columns))
+			}
+
+			for i, newColumn := range dataRows.Columns {
+				dbsteward.Info("Adding rows column %s to definition table %s", newColumn, defTable.Name)
+
+				if defTable.Rows == nil {
+					defTable.Rows = &ir.DataRows{}
+				}
+				err = defTable.Rows.AddColumn(newColumn, dataRows.Columns[i])
+				dbsteward.fatalIfError(err, "Could not add column %s to %s in %s", newColumn, dataTable.Name, dataFile)
+			}
+		}
+	}
+
+	defFileModified := defFile + ".xmldatainserted"
+	dbsteward.Info("Saving modified dbsteward definition as %s", defFileModified)
+	err = xml.SaveDefinition(dbsteward.Logger(), defFileModified, defDoc)
+	dbsteward.fatalIfError(err, "saving file")
+}
+func (dbsteward *DBSteward) doXmlSort(files []string) {
+	for _, file := range files {
+		sortedFileName := file + ".xmlsorted"
+		dbsteward.Info("Sorting XML definition file: %s", file)
+		dbsteward.Info("Sorted XML output file: %s", sortedFileName)
+		xml.FileSort(file, sortedFileName)
+	}
+}
+func (dbsteward *DBSteward) doXmlConvert(files []string) {
+	for _, file := range files {
+		convertedFileName := file + ".xmlconverted"
+		dbsteward.Info("Upconverting XML definition file: %s", file)
+		dbsteward.Info("Upconvert XML output file: %s", convertedFileName)
+
+		doc, err := xml.LoadDefintion(file)
+		dbsteward.fatalIfError(err, "Could not load %s", file)
+		xml.SqlFormatConvert(doc)
+		convertedXml, err := xml.FormatXml(dbsteward.Logger(), doc)
+		dbsteward.fatalIfError(err, "formatting xml")
+		convertedXml = strings.Replace(convertedXml, "pgdbxml>", "dbsteward>", -1)
+		err = util.WriteFile(convertedXml, convertedFileName)
+		dbsteward.fatalIfError(err, "Could not write converted xml to %s", convertedFileName)
+	}
+}
+func (dbsteward *DBSteward) doXmlSlonyId(files []string, slonyOut string) {
+	dbsteward.Info("Compositing XML file for Slony ID processing")
+	dbDoc, err := xml.XmlComposite(dbsteward.Logger(), files)
+	dbsteward.fatalIfError(err, "compositing files: %v", files)
+	dbsteward.Info("Xml files %s composited", strings.Join(files, " "))
+
+	outputPrefix := dbsteward.calculateFileOutputPrefix(files)
+	compositeFile := outputPrefix + "_composite.xml"
+	dbsteward.Info("Saving composite as %s", compositeFile)
+	err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, dbDoc)
+	dbsteward.fatalIfError(err, "saving file")
+
+	dbsteward.Info("Slony ID numbering any missing attributes")
+	dbsteward.Info("slonyidstartvalue = %d", dbsteward.config.SlonyIdStartValue)
+	dbsteward.Info("slonyidsetvalue = %d", dbsteward.config.SlonyIdSetValue)
+	slonyIdDoc := xml.SlonyIdNumber(dbDoc)
+	slonyIdNumberedFile := outputPrefix + "_slonyid_numbered.xml"
+	if len(slonyOut) > 0 {
+		slonyIdNumberedFile = slonyOut
+	}
+	dbsteward.Info("Saving Slony ID numbered XML as %s", slonyIdNumberedFile)
+	err = xml.SaveDefinition(dbsteward.Logger(), slonyIdNumberedFile, slonyIdDoc)
+	dbsteward.fatalIfError(err, "saving file")
+}
+func (dbsteward *DBSteward) doBuild(files []string, dataFiles []string, addendums uint) {
+	dbsteward.Info("Compositing XML files...")
+	if addendums > 0 {
+		dbsteward.Info("Collecting %d data addendums", addendums)
+	}
+	dbDoc, addendumsDoc, err := xml.XmlCompositeAddendums(dbsteward.Logger(), files, addendums)
+	if err != nil {
+		mErr, isMErr := err.(*multierror.Error)
+		if isMErr {
+			for _, e := range mErr.Errors {
+				log.Println(e.Error())
+			}
+		} else {
+			log.Println(err.Error())
+		}
+		os.Exit(1)
+	}
+	if len(dataFiles) > 0 {
+		dbsteward.Info("Compositing pgdata XML files on top of XML composite...")
+		xml.XmlCompositePgData(dbDoc, dataFiles)
+		dbsteward.Info("postgres data XML files [%s] composited", strings.Join(dataFiles, " "))
+	}
+
+	dbsteward.Info("XML files %s composited", strings.Join(files, " "))
+
+	outputPrefix := dbsteward.calculateFileOutputPrefix(files)
+	compositeFile := outputPrefix + "_composite.xml"
+	dbsteward.Info("Saving composite as %s", compositeFile)
+	err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, dbDoc)
+	dbsteward.fatalIfError(err, "saving file")
+
+	if addendumsDoc != nil {
+		addendumsFile := outputPrefix + "_addendums.xml"
+		dbsteward.Info("Saving addendums as %s", addendumsFile)
+		err = xml.SaveDefinition(dbsteward.Logger(), addendumsFile, addendumsDoc)
+		dbsteward.fatalIfError(err, "saving file")
+	}
+
+	ops, err := lib.Format(lib.DefaultSqlFormat)
+	dbsteward.fatalIfError(err, "loading default format")
+	err = ops(dbsteward.config).Build(outputPrefix, dbDoc)
+	dbsteward.fatalIfError(err, "building")
+}
+func (dbsteward *DBSteward) doDiff(oldFiles []string, newFiles []string, dataFiles []string) {
+	dbsteward.Info("Compositing old XML files...")
+	oldDbDoc, err := xml.XmlComposite(dbsteward.Logger(), oldFiles)
+	dbsteward.fatalIfError(err, "compositing")
+	dbsteward.Info("Old XML files %s composited", strings.Join(oldFiles, " "))
+
+	dbsteward.Info("Compositing new XML files...")
+	newDbDoc, err := xml.XmlComposite(dbsteward.Logger(), newFiles)
+	dbsteward.fatalIfError(err, "compositing")
+	if len(dataFiles) > 0 {
+		dbsteward.Info("Compositing pgdata XML files on top of new XML composite...")
+		xml.XmlCompositePgData(newDbDoc, dataFiles)
+		dbsteward.Info("postgres data XML files [%s] composited", strings.Join(dataFiles, " "))
+	}
+	dbsteward.Info("New XML files %s composited", strings.Join(newFiles, " "))
+
+	oldOutputPrefix := dbsteward.calculateFileOutputPrefix(oldFiles)
+	oldCompositeFile := oldOutputPrefix + "_composite.xml"
+	dbsteward.Info("Saving composite as %s", oldCompositeFile)
+	err = xml.SaveDefinition(dbsteward.Logger(), oldCompositeFile, oldDbDoc)
+	dbsteward.fatalIfError(err, "saving file")
+
+	newOutputPrefix := dbsteward.calculateFileOutputPrefix(newFiles)
+	newCompositeFile := newOutputPrefix + "_composite.xml"
+	dbsteward.Info("Saving composite as %s", newCompositeFile)
+	err =
xml.SaveDefinition(dbsteward.Logger(), newCompositeFile, newDbDoc) + dbsteward.fatalIfError(err, "saving file") + + ops, err := lib.Format(lib.DefaultSqlFormat) + dbsteward.fatalIfError(err, "loading default format") + err = ops(dbsteward.config).BuildUpgrade( + oldOutputPrefix, oldCompositeFile, oldDbDoc, oldFiles, + newOutputPrefix, newCompositeFile, newDbDoc, newFiles, + ) + dbsteward.fatalIfError(err, "building upgrade") +} +func (dbsteward *DBSteward) doExtract(dbHost string, dbPort uint, dbName, dbUser, dbPass string, outputFile string) { + ops, err := lib.Format(lib.DefaultSqlFormat) + dbsteward.fatalIfError(err, "loading default format") + output, err := ops(dbsteward.config).ExtractSchema(dbHost, dbPort, dbName, dbUser, dbPass) + dbsteward.fatalIfError(err, "extracting") + dbsteward.Info("Saving extracted database schema to %s", outputFile) + err = xml.SaveDefinition(dbsteward.Logger(), outputFile, output) + dbsteward.fatalIfError(err, "saving file") +} +func (dbsteward *DBSteward) doDbDataDiff(files []string, dataFiles []string, addendums uint, dbHost string, dbPort uint, dbName, dbUser, dbPass string) { + dbsteward.Info("Compositing XML files...") + if addendums > 0 { + dbsteward.Info("Collecting %d data addendums", addendums) + } + // TODO(feat) can this just be XmlComposite(files)? why do we need addendums? + dbDoc, _, err := xml.XmlCompositeAddendums(dbsteward.Logger(), files, addendums) + dbsteward.fatalIfError(err, "compositing addendums") + + if len(dataFiles) > 0 { + dbsteward.Info("Compositing pgdata XML files on top of XML composite...") + xml.XmlCompositePgData(dbDoc, dataFiles) + dbsteward.Info("postgres data XML files [%s] composited", strings.Join(dataFiles, " ")) + } + + dbsteward.Info("XML files %s composited", strings.Join(files, " ")) + + outputPrefix := dbsteward.calculateFileOutputPrefix(files) + compositeFile := outputPrefix + "_composite.xml" + dbsteward.Info("Saving composite as %s", compositeFile) + err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, dbDoc) + dbsteward.fatalIfError(err, "saving file") + + ops, err := lib.Format(lib.DefaultSqlFormat) + dbsteward.fatalIfError(err, "loading default format") + output, err := ops(dbsteward.config).CompareDbData(dbDoc, dbHost, dbPort, dbName, dbUser, dbPass) + dbsteward.fatalIfError(err, "comparing data") + err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, output) + dbsteward.fatalIfError(err, "saving file") +} +func (dbsteward *DBSteward) doSqlDiff(oldSql, newSql []string, outputFile string) { + ops, err := lib.Format(lib.DefaultSqlFormat) + dbsteward.fatalIfError(err, "loading default format") + ops(dbsteward.config).SqlDiff(oldSql, newSql, outputFile) +} +func (dbsteward *DBSteward) doSlonikConvert(file string, outputFile string) { + // TODO(go,nth) is there a nicer way to handle this output idiom? 
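+	// Write the converted output to the requested file when one was given, otherwise print it to stdout.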
+ output := lib.NewSlonik().Convert(file) + if len(outputFile) > 0 { + err := util.WriteFile(output, outputFile) + dbsteward.fatalIfError(err, "Failed to save slonikconvert output to %s", outputFile) + } else { + fmt.Println(output) + } +} +func (dbsteward *DBSteward) doSlonyCompare(file string) { + ops, err := lib.Format(lib.DefaultSqlFormat) + dbsteward.fatalIfError(err, "loading default format") + ops(dbsteward.config).(lib.SlonyOperations).SlonyCompare(file) +} +func (dbsteward *DBSteward) doSlonyDiff(oldFile string, newFile string) { + ops, err := lib.Format(lib.DefaultSqlFormat) + dbsteward.fatalIfError(err, "loading default format") + ops(dbsteward.config).(lib.SlonyOperations).SlonyDiff(oldFile, newFile) +} + +func newLogHandler(dbs *DBSteward) slog.Handler { + buf := bytes.Buffer{} + f := slog.NewTextHandler(&buf, nil) + return &logHandler{ + dbsteward: dbs, + formatter: f, + output: &buf, + } +} + +// logHandler is an intermediate step to support both slog logging +// and the old method of dbsteward logging +type logHandler struct { + dbsteward *DBSteward + formatter slog.Handler + output *bytes.Buffer +} + +// Enabled always returns true and let zerolog decide +func (h *logHandler) Enabled(_ context.Context, level slog.Level) bool { + return true +} + +func (h *logHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &logHandler{ + dbsteward: h.dbsteward, + output: h.output, + formatter: h.formatter.WithAttrs(attrs), + } +} + +func (h *logHandler) WithGroup(name string) slog.Handler { + return &logHandler{ + dbsteward: h.dbsteward, + output: h.output, + formatter: h.formatter.WithGroup(name), + } +} + +// Handle is a bit of a hack. Just using TextFormatter to do the actual +// handling and and then extracting the result from the byte buffer to +// send it to zerolog as an intermediate step that maintains nearly +// the same behavior as previous while still supporting all of slog's +// features +func (h *logHandler) Handle(ctx context.Context, r slog.Record) error { + h.formatter.Handle(ctx, r) + msg := strings.TrimSpace(h.output.String()) + if msg == "" { + msg = "<>" + } + switch r.Level { + case slog.LevelDebug: + h.dbsteward.logger.Debug().Msgf(msg) + case slog.LevelInfo: + h.dbsteward.logger.Info().Msgf(msg) + case slog.LevelWarn: + h.dbsteward.logger.Warn().Msgf(msg) + default: + // Should be Error, but in case other levels get define at + // least nothing gets lost + h.dbsteward.logger.Error().Msgf(msg) + } + h.output.Reset() + return nil +} diff --git a/xmlpostgresintegration_test.go b/xmlpostgresintegration_test.go index bf3c146..393a98c 100644 --- a/xmlpostgresintegration_test.go +++ b/xmlpostgresintegration_test.go @@ -7,10 +7,8 @@ import ( "strings" "testing" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/encoding/xml" "github.com/dbsteward/dbsteward/lib/format/pgsql8" - "github.com/dbsteward/dbsteward/lib/ir" ) //go:embed example/someapp_v1.xml @@ -36,8 +34,6 @@ func TestXMLPostgresIngegration(t *testing.T) { if err != nil { t.Fatal(err) } - dbs := lib.NewDBSteward() - dbs.SqlFormat = ir.SqlFormatPgsql8 err = pgsql8.CreateRoleIfNotExists(c, def1.Database.Roles.Application) if err != nil { t.Fatal(err) @@ -54,7 +50,7 @@ func TestXMLPostgresIngegration(t *testing.T) { if err != nil { t.Fatal(err) } - ops := pgsql8.NewOperations(dbs).(*pgsql8.Operations) + ops := pgsql8.NewOperations(pgsql8.DefaultConfig).(*pgsql8.Operations) statements, err := ops.CreateStatements(*def1) if err != nil { t.Fatal(err) @@ -79,7 +75,7 @@ func 
TestXMLPostgresIngegration(t *testing.T) { if err != nil { t.Fatal(err) } - ops = pgsql8.NewOperations(dbs).(*pgsql8.Operations) + ops = pgsql8.NewOperations(pgsql8.DefaultConfig).(*pgsql8.Operations) statements, err = ops.Upgrade(slog.Default(), def1, def2) if err != nil { t.Fatal(err) @@ -100,7 +96,7 @@ func TestXMLPostgresIngegration(t *testing.T) { if err != nil { t.Fatal(err) } - ops = pgsql8.NewOperations(dbs).(*pgsql8.Operations) + ops = pgsql8.NewOperations(pgsql8.DefaultConfig).(*pgsql8.Operations) _, err = ops.ExtractSchemaConn(context.TODO(), c) if err != nil { t.Fatal(err)