diff --git a/pkg/db/pgx.go b/pkg/db/pgx.go index dc78b711..46549083 100644 --- a/pkg/db/pgx.go +++ b/pkg/db/pgx.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "strings" "time" "github.com/jackc/pgx/v5/pgxpool" @@ -11,11 +12,19 @@ import ( "github.com/xmtp/xmtpd/pkg/migrations" ) -func newPGXDB( - ctx context.Context, - dsn string, - waitForDB, statementTimeout time.Duration, -) (*sql.DB, error) { +func waitUntilDBReady(ctx context.Context, db *pgxpool.Pool, waitTime time.Duration) error { + waitUntil := time.Now().Add(waitTime) + + err := db.Ping(ctx) + + for err != nil && time.Now().Before(waitUntil) { + time.Sleep(3 * time.Second) + err = db.Ping(ctx) + } + return err +} + +func parseConfig(dsn string, statementTimeout time.Duration) (*pgxpool.Config, error) { config, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, err @@ -24,23 +33,93 @@ func newPGXDB( config.ConnConfig.RuntimeParams["statement_timeout"] = fmt.Sprint( statementTimeout.Milliseconds(), ) + return config, nil +} +func newPGXDB( + ctx context.Context, + config *pgxpool.Config, + waitForDB time.Duration, +) (*sql.DB, error) { dbPool, err := pgxpool.NewWithConfig(ctx, config) if err != nil { return nil, err } + if err = waitUntilDBReady(ctx, dbPool, waitForDB); err != nil { + return nil, err + } + db := stdlib.OpenDBFromPool(dbPool) - waitUntil := time.Now().Add(waitForDB) + return db, nil +} - err = db.Ping() - for err != nil && time.Now().Before(waitUntil) { - time.Sleep(3 * time.Second) - err = db.Ping() +// Creates a new database with the given namespace if it doesn't exist +func createNamespace( + ctx context.Context, + config *pgxpool.Config, + namespace string, + waitForDB time.Duration, +) error { + // Make a copy of the config so we don't dirty it + config = config.Copy() + // Change the database to postgres so we are able to create new DBs + config.ConnConfig.Database = "postgres" + + // Create a temporary connection to the postgres DB + adminConn, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return fmt.Errorf("failed to connect to postgres: %w", err) + } + defer adminConn.Close() + + if err = waitUntilDBReady(ctx, adminConn, waitForDB); err != nil { + return err + } + + // Create database if it doesn't exist + _, err = adminConn.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, namespace)) + if err != nil { + // Ignore error if database already exists + if !strings.Contains(err.Error(), "already exists") { + return fmt.Errorf("failed to create database: %w", err) + } } - return db, err + return nil +} + +// Creates a new database with the given namespace if it doesn't exist and returns the full DSN for the new database. +func NewNamespacedDB( + ctx context.Context, + dsn string, + namespace string, + waitForDB, statementTimeout time.Duration, +) (*sql.DB, error) { + // Parse the DSN to get the config + config, err := parseConfig(dsn, statementTimeout) + if err != nil { + return nil, fmt.Errorf("failed to parse DSN: %w", err) + } + + if err = createNamespace(ctx, config, namespace, waitForDB); err != nil { + return nil, err + } + + config.ConnConfig.Database = namespace + + db, err := newPGXDB(ctx, config, waitForDB) + if err != nil { + return nil, err + } + + err = migrations.Migrate(ctx, db) + if err != nil { + return nil, err + } + + return db, nil } func NewDB( @@ -48,7 +127,12 @@ func NewDB( dsn string, waitForDB, statementTimeout time.Duration, ) (*sql.DB, error) { - db, err := newPGXDB(ctx, dsn, waitForDB, statementTimeout) + config, err := parseConfig(dsn, statementTimeout) + if err != nil { + return nil, err + } + + db, err := newPGXDB(ctx, config, waitForDB) if err != nil { return nil, err } diff --git a/pkg/db/pgx_test.go b/pkg/db/pgx_test.go new file mode 100644 index 00000000..53b1bfb7 --- /dev/null +++ b/pkg/db/pgx_test.go @@ -0,0 +1,63 @@ +package db + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/testutils" +) + +func TestNamespacedDB(t *testing.T) { + startingDsn := testutils.LocalTestDBDSNPrefix + "/foo" + testutils.LocalTestDBDSNSuffix + newDBName := testutils.RandomString(32) + // Create namespaced DB + namespacedDB, err := NewNamespacedDB( + context.Background(), + startingDsn, + newDBName, + time.Second, + time.Second, + ) + t.Cleanup(func() { namespacedDB.Close() }) + require.NoError(t, err) + + result, err := namespacedDB.Query("SELECT current_database();") + require.NoError(t, err) + defer result.Close() + + require.True(t, result.Next()) + var dbName string + err = result.Scan(&dbName) + require.NoError(t, err) + require.Equal(t, newDBName, dbName) +} + +func TestNamespaceRepeat(t *testing.T) { + startingDsn := testutils.LocalTestDBDSNPrefix + "/foo" + testutils.LocalTestDBDSNSuffix + newDBName := testutils.RandomString(32) + // Create namespaced DB + db1, err := NewNamespacedDB( + context.Background(), + startingDsn, + newDBName, + time.Second, + time.Second, + ) + require.NoError(t, err) + require.NotNil(t, db1) + t.Cleanup(func() { db1.Close() }) + + // Create again with the same name + db2, err := NewNamespacedDB( + context.Background(), + startingDsn, + newDBName, + time.Second, + time.Second, + ) + require.NoError(t, err) + require.NotNil(t, db2) + t.Cleanup(func() { db2.Close() }) +} diff --git a/pkg/testutils/store.go b/pkg/testutils/store.go index e9d1f6b5..44d108a0 100644 --- a/pkg/testutils/store.go +++ b/pkg/testutils/store.go @@ -13,8 +13,8 @@ import ( ) const ( - localTestDBDSNPrefix = "postgres://postgres:xmtp@localhost:8765" - localTestDBDSNSuffix = "?sslmode=disable" + LocalTestDBDSNPrefix = "postgres://postgres:xmtp@localhost:8765" + LocalTestDBDSNSuffix = "?sslmode=disable" ) func openDB(t testing.TB, dsn string) (*sql.DB, string, func()) { @@ -28,7 +28,7 @@ func openDB(t testing.TB, dsn string) (*sql.DB, string, func()) { } func newCtlDB(t testing.TB) (*sql.DB, string, func()) { - return openDB(t, localTestDBDSNPrefix+localTestDBDSNSuffix) + return openDB(t, LocalTestDBDSNPrefix+LocalTestDBDSNSuffix) } func newInstanceDB(t testing.TB, ctx context.Context, ctlDB *sql.DB) (*sql.DB, string, func()) { @@ -36,7 +36,7 @@ func newInstanceDB(t testing.TB, ctx context.Context, ctlDB *sql.DB) (*sql.DB, s _, err := ctlDB.Exec("CREATE DATABASE " + dbName) require.NoError(t, err) - db, dsn, cleanup := openDB(t, localTestDBDSNPrefix+"/"+dbName+localTestDBDSNSuffix) + db, dsn, cleanup := openDB(t, LocalTestDBDSNPrefix+"/"+dbName+LocalTestDBDSNSuffix) require.NoError(t, migrations.Migrate(ctx, db)) return db, dsn, func() {