diff --git a/lib/format/pgsql8/oneeighty_test.go b/lib/format/pgsql8/oneeighty_test.go index a15bd49..cc92510 100644 --- a/lib/format/pgsql8/oneeighty_test.go +++ b/lib/format/pgsql8/oneeighty_test.go @@ -2,15 +2,12 @@ package pgsql8 import ( "context" - "fmt" "os" "testing" "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format" "github.com/dbsteward/dbsteward/lib/ir" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" ) @@ -26,11 +23,11 @@ import ( // * Data types are not normalized nor standardized nor anything like that func TestOneEighty(t *testing.T) { - c := initdb(t) + c := Initdb(t, "pg") if c == nil { t.SkipNow() } - defer teardowndb(t, c) + defer Teardowndb(t, c, "pg") role := os.Getenv("DB_USER") lib.GlobalDBSteward = lib.NewDBSteward(format.LookupMap{ ir.SqlFormatPgsql8: GlobalLookup, @@ -63,98 +60,3 @@ func TestOneEighty(t *testing.T) { } assert.Equal(t, ir.FullFeatureSchema(role), *reflection, "reflection does not match original") } - -func initdb(t *testing.T) *pgx.Conn { - if os.Getenv("DB_NAME") == "" { - return nil - } - conn, err := pgx.Connect(context.TODO(), adminDSNFromEnv()) - if err != nil { - t.Fatal(err) - return nil - } - defer conn.Close(context.TODO()) - _, err = conn.Exec(context.TODO(), fmt.Sprintf("DROP DATABASE IF EXISTS %s", os.Getenv("DB_NAME"))) - if err != nil { - t.Fatal(err) - return nil - } - _, err = conn.Exec(context.TODO(), fmt.Sprintf("CREATE DATABASE %s", os.Getenv("DB_NAME"))) - if err != nil { - t.Fatal(err) - return nil - } - _, err = conn.Exec(context.TODO(), fmt.Sprintf("CREATE ROLE %s", ir.AdditionalRole)) - if err != nil { - if (err.(*pgconn.PgError)).Code != "42710" { // Role exists - t.Fatal(err) - return nil - } - } - err = conn.Close(context.TODO()) - if err != nil { - t.Fatal(err) - return nil - } - conn, err = pgx.Connect(context.TODO(), userDSNFromEnv()) - if err != nil { - t.Fatal(err) - return nil - } - return conn -} - -func teardowndb(t *testing.T, c *pgx.Conn) { - err := c.Close(context.TODO()) - if err != nil { - t.Fatal(err) - return - } - conn, err := pgx.Connect(context.TODO(), adminDSNFromEnv()) - if err != nil { - t.Fatal(err) - return - } - defer conn.Close(context.TODO()) - _, err = conn.Exec(context.TODO(), fmt.Sprintf("DROP DATABASE IF EXISTS %s", os.Getenv("DB_NAME"))) - if err != nil { - t.Log(err) - } - _, err = conn.Exec(context.TODO(), fmt.Sprintf("DROP ROLE IF EXISTS %s", ir.AdditionalRole)) - if err != nil { - t.Log(err) - } -} - -func adminDSNFromEnv() string { - host := os.Getenv("DB_HOST") - user := os.Getenv("DB_SUPERUSER") - password := os.Getenv("DB_PASSWORD") - dbName := "postgres" - port := os.Getenv("DB_PORT") - return fmt.Sprintf( - "host=%s user=%s password=%s dbname=%s port=%s", - host, - user, - password, - dbName, - port, - ) -} - -func userDSNFromEnv() string { - host := os.Getenv("DB_HOST") - user := os.Getenv("DB_USER") - password := os.Getenv("DB_PASSWORD") - dbName := os.Getenv("DB_NAME") - port := os.Getenv("DB_PORT") - cs := fmt.Sprintf( - "host=%s user=%s password=%s dbname=%s port=%s", - host, - user, - password, - dbName, - port, - ) - return cs -} diff --git a/lib/format/pgsql8/operations.go b/lib/format/pgsql8/operations.go index f6d0ed8..123bd46 100644 --- a/lib/format/pgsql8/operations.go +++ b/lib/format/pgsql8/operations.go @@ -220,6 +220,34 @@ func (ops *Operations) BuildUpgrade( return nil } +func (ops *Operations) Upgrade(l *slog.Logger, oldDoc *ir.Definition, newDoc *ir.Definition) ([]output.DDLStatement, error) { + var err error + ops.differ.OldTableDependency, err = oldDoc.TableDependencyOrder() + if err != nil { + return nil, fmt.Errorf("old document: %w", err) + } + ops.differ.NewTableDependency, err = newDoc.TableDependencyOrder() + if err != nil { + return nil, fmt.Errorf("new document: %w", err) + } + lib.GlobalDBSteward.OldDatabase = oldDoc + lib.GlobalDBSteward.NewDatabase = newDoc + + stage1 := output.NewSegmenter(ops.GetQuoter()) + stage2 := output.NewSegmenter(ops.GetQuoter()) + stage3 := output.NewSegmenter(ops.GetQuoter()) + stage4 := output.NewSegmenter(ops.GetQuoter()) + err = ops.differ.DiffDocWork(stage1, stage2, stage3, stage4) + if err != nil { + return nil, err + } + stmts := stage1.AllStatements() + stmts = append(stmts, stage2.AllStatements()...) + stmts = append(stmts, stage3.AllStatements()...) + stmts = append(stmts, stage4.AllStatements()...) + return stmts, nil +} + func (ops *Operations) ExtractSchemaConn(ctx context.Context, c *pgx.Conn) (*ir.Definition, error) { conn := &liveConnection{c} return ops.extractSchema(ctx, conn) diff --git a/lib/format/pgsql8/testutils.go b/lib/format/pgsql8/testutils.go new file mode 100644 index 0000000..2942595 --- /dev/null +++ b/lib/format/pgsql8/testutils.go @@ -0,0 +1,116 @@ +package pgsql8 + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/dbsteward/dbsteward/lib/ir" + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +func Initdb(t *testing.T, dbSuffix string) *pgx.Conn { + if os.Getenv("DB_NAME") == "" { + return nil + } + conn, err := pgx.Connect(context.TODO(), adminDSNFromEnv()) + if err != nil { + t.Fatal(err) + return nil + } + defer conn.Close(context.TODO()) + _, err = conn.Exec(context.TODO(), fmt.Sprintf("DROP DATABASE IF EXISTS %s", os.Getenv("DB_NAME")+dbSuffix)) + if err != nil { + t.Fatal(err) + return nil + } + _, err = conn.Exec(context.TODO(), fmt.Sprintf("CREATE DATABASE %s", os.Getenv("DB_NAME")+dbSuffix)) + if err != nil { + t.Fatal(err) + return nil + } + err = CreateRoleIfNotExists(conn, ir.AdditionalRole) + if err != nil { + t.Fatal(err) + return nil + } + err = conn.Close(context.TODO()) + if err != nil { + t.Fatal(err) + return nil + } + conn, err = pgx.Connect(context.TODO(), userDSNFromEnv(dbSuffix)) + if err != nil { + t.Fatal(err) + return nil + } + return conn +} + +func CreateRoleIfNotExists(conn *pgx.Conn, name string) error { + _, err := conn.Exec(context.TODO(), fmt.Sprintf("CREATE ROLE %s", name)) + if err != nil { + code := (err.(*pgconn.PgError)).Code + if code != "42710" && code != "23505" { // Role exists + return err + } + } + return nil +} + +func Teardowndb(t *testing.T, c *pgx.Conn, dbSuffix string) { + err := c.Close(context.TODO()) + if err != nil { + t.Fatal(err) + return + } + conn, err := pgx.Connect(context.TODO(), adminDSNFromEnv()) + if err != nil { + t.Fatal(err) + return + } + defer conn.Close(context.TODO()) + _, err = conn.Exec(context.TODO(), fmt.Sprintf("DROP DATABASE IF EXISTS %s", os.Getenv("DB_NAME")+dbSuffix)) + if err != nil { + t.Log(err) + } + _, err = conn.Exec(context.TODO(), fmt.Sprintf("DROP ROLE IF EXISTS %s", ir.AdditionalRole)) + if err != nil { + t.Log(err) + } +} + +func adminDSNFromEnv() string { + host := os.Getenv("DB_HOST") + user := os.Getenv("DB_SUPERUSER") + password := os.Getenv("DB_PASSWORD") + dbName := "postgres" + port := os.Getenv("DB_PORT") + return fmt.Sprintf( + "host=%s user=%s password=%s dbname=%s port=%s", + host, + user, + password, + dbName, + port, + ) +} + +func userDSNFromEnv(suffix string) string { + host := os.Getenv("DB_HOST") + user := os.Getenv("DB_USER") + password := os.Getenv("DB_PASSWORD") + dbName := os.Getenv("DB_NAME") + suffix + port := os.Getenv("DB_PORT") + cs := fmt.Sprintf( + "host=%s user=%s password=%s dbname=%s port=%s", + host, + user, + password, + dbName, + port, + ) + return cs +} diff --git a/xmlpostgresintegration_test.go b/xmlpostgresintegration_test.go new file mode 100644 index 0000000..10c64ca --- /dev/null +++ b/xmlpostgresintegration_test.go @@ -0,0 +1,117 @@ +package main + +import ( + "context" + _ "embed" + "log/slog" + "strings" + "testing" + + "github.com/dbsteward/dbsteward/lib" + "github.com/dbsteward/dbsteward/lib/encoding/xml" + "github.com/dbsteward/dbsteward/lib/format" + "github.com/dbsteward/dbsteward/lib/format/pgsql8" + "github.com/dbsteward/dbsteward/lib/ir" +) + +//go:embed example/someapp_v1.xml +var v1 string + +//go:embed example/someapp_v2.xml +var v2 string + +// To run: +// DB_HOST=localhost DB_USER=postgres DB_SUPERUSER=postgres DB_NAME=test DB_PORT=5432 go test ./... + +// This test uses the definitions in the examples table to +// create a database and then create a set of upgrade commands +// and ensure those commands work. It is limited: see the comments +// at the end of the test for explanation. +func TestXMLPostgresIngegration(t *testing.T) { + c := pgsql8.Initdb(t, "tl") + if c == nil { + t.SkipNow() + } + defer pgsql8.Teardowndb(t, c, "tl") + def1, err := xml.ReadDef(strings.NewReader(v1)) + if err != nil { + t.Fatal(err) + } + lib.GlobalDBSteward = lib.NewDBSteward(format.LookupMap{ + ir.SqlFormatPgsql8: pgsql8.GlobalLookup, + }) + lib.GlobalDBSteward.SqlFormat = ir.SqlFormatPgsql8 + err = pgsql8.CreateRoleIfNotExists(c, def1.Database.Roles.Application) + if err != nil { + t.Fatal(err) + } + err = pgsql8.CreateRoleIfNotExists(c, def1.Database.Roles.Owner) + if err != nil { + t.Fatal(err) + } + err = pgsql8.CreateRoleIfNotExists(c, def1.Database.Roles.ReadOnly) + if err != nil { + t.Fatal(err) + } + err = pgsql8.CreateRoleIfNotExists(c, def1.Database.Roles.Replication) + if err != nil { + t.Fatal(err) + } + ops := pgsql8.NewOperations().(*pgsql8.Operations) + statements, err := ops.CreateStatements(*def1) + if err != nil { + t.Fatal(err) + } + tx, err := c.Begin(context.TODO()) + if err != nil { + t.Fatal(err) + } + defer tx.Rollback(context.TODO()) + for _, s := range statements { + t.Log(s.Statement) + _, err = tx.Exec(context.TODO(), s.Statement) + if err != nil { + t.Fatal(err.Error()) + } + } + err = tx.Commit(context.TODO()) + if err != nil { + t.Fatal(err) + } + def2, err := xml.ReadDef(strings.NewReader(v2)) + if err != nil { + t.Fatal(err) + } + ops = pgsql8.NewOperations().(*pgsql8.Operations) + statements, err = ops.Upgrade(slog.Default(), def1, def2) + if err != nil { + t.Fatal(err) + } + tx, err = c.Begin(context.TODO()) + if err != nil { + t.Fatal(err) + } + defer tx.Rollback(context.TODO()) + for _, s := range statements { + t.Log(s.Statement) + _, err = tx.Exec(context.TODO(), s.Statement) + if err != nil { + t.Fatal(err.Error()) + } + } + err = tx.Commit(context.TODO()) + if err != nil { + t.Fatal(err) + } + ops = pgsql8.NewOperations().(*pgsql8.Operations) + _, err = ops.ExtractSchemaConn(context.TODO(), c) + if err != nil { + t.Fatal(err) + } + // It's impractical to verify that the extraction is + // correct without massively rewriting the XML. Due to + // differences in object ordering and other things that + // produce unequal but functionally equivalent code. + // It's probably best to do that level of precision + // testing at a more unit level. +}