diff --git a/pkg/db/pgx.go b/pkg/db/pgx.go index dc78b711..1d36261b 100644 --- a/pkg/db/pgx.go +++ b/pkg/db/pgx.go @@ -3,19 +3,34 @@ package db import ( "context" "database/sql" + "errors" "fmt" + "regexp" "time" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/stdlib" "github.com/xmtp/xmtpd/pkg/migrations" ) -func newPGXDB( - ctx context.Context, - dsn string, - waitForDB, statementTimeout time.Duration, -) (*sql.DB, error) { +const MAX_NAMESPACE_LENGTH = 32 + +var allowedNamespaceRe = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +func waitUntilDBReady(ctx context.Context, db *pgxpool.Pool, waitTime time.Duration) error { + waitUntil := time.Now().Add(waitTime) + + err := db.Ping(ctx) + + for err != nil && time.Now().Before(waitUntil) { + time.Sleep(3 * time.Second) + err = db.Ping(ctx) + } + return err +} + +func parseConfig(dsn string, statementTimeout time.Duration) (*pgxpool.Config, error) { config, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, err @@ -24,23 +39,118 @@ func newPGXDB( config.ConnConfig.RuntimeParams["statement_timeout"] = fmt.Sprint( statementTimeout.Milliseconds(), ) + return config, nil +} +func newPGXDB( + ctx context.Context, + config *pgxpool.Config, + waitForDB time.Duration, +) (*sql.DB, error) { dbPool, err := pgxpool.NewWithConfig(ctx, config) if err != nil { return nil, err } + if err = waitUntilDBReady(ctx, dbPool, waitForDB); err != nil { + return nil, err + } + db := stdlib.OpenDBFromPool(dbPool) - waitUntil := time.Now().Add(waitForDB) + return db, nil +} - err = db.Ping() - for err != nil && time.Now().Before(waitUntil) { - time.Sleep(3 * time.Second) - err = db.Ping() +func isValidNamespace(namespace string) error { + if len(namespace) == 0 || len(namespace) > MAX_NAMESPACE_LENGTH { + return fmt.Errorf( + "namespace length must be between 1 and %d characters", + MAX_NAMESPACE_LENGTH, + ) + } + // PostgreSQL identifiers must start with a letter or underscore + if !allowedNamespaceRe.MatchString(namespace) { + return fmt.Errorf( + "namespace must start with a letter or underscore and contain only letters, numbers, and underscores", + ) + } + return nil +} + +// Creates a new database with the given namespace if it doesn't exist +func createNamespace( + ctx context.Context, + config *pgxpool.Config, + namespace string, + waitForDB time.Duration, +) error { + if err := isValidNamespace(namespace); err != nil { + return err + } + + // Make a copy of the config so we don't dirty it + config = config.Copy() + // Change the database to postgres so we are able to create new DBs + config.ConnConfig.Database = "postgres" + + // Create a temporary connection to the postgres DB + adminConn, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return fmt.Errorf("failed to connect to postgres: %w", err) } + defer adminConn.Close() - return db, err + if err = waitUntilDBReady(ctx, adminConn, waitForDB); err != nil { + return err + } + + // Create database if it doesn't exist + _, err = adminConn.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, namespace)) + if err != nil { + // Ignore error if database already exists + var pgErr *pgconn.PgError + // Error code 42P04 is for "duplicate database" + // https://www.postgresql.org/docs/current/errcodes-appendix.html + if errors.As(err, &pgErr) && pgErr.Code == "42P04" { + return nil + } + + return fmt.Errorf("failed to create database: %w", err) + } + + return nil +} + +// Creates a new database with the given namespace if it doesn't exist and returns the full DSN for the new database. +func NewNamespacedDB( + ctx context.Context, + dsn string, + namespace string, + waitForDB, statementTimeout time.Duration, +) (*sql.DB, error) { + // Parse the DSN to get the config + config, err := parseConfig(dsn, statementTimeout) + if err != nil { + return nil, fmt.Errorf("failed to parse DSN: %w", err) + } + + if err = createNamespace(ctx, config, namespace, waitForDB); err != nil { + return nil, err + } + + config.ConnConfig.Database = namespace + + db, err := newPGXDB(ctx, config, waitForDB) + if err != nil { + return nil, err + } + + err = migrations.Migrate(ctx, db) + if err != nil { + return nil, err + } + + return db, nil } func NewDB( @@ -48,7 +158,12 @@ func NewDB( dsn string, waitForDB, statementTimeout time.Duration, ) (*sql.DB, error) { - db, err := newPGXDB(ctx, dsn, waitForDB, statementTimeout) + config, err := parseConfig(dsn, statementTimeout) + if err != nil { + return nil, err + } + + db, err := newPGXDB(ctx, config, waitForDB) if err != nil { return nil, err } diff --git a/pkg/db/pgx_test.go b/pkg/db/pgx_test.go new file mode 100644 index 00000000..cb7c33b9 --- /dev/null +++ b/pkg/db/pgx_test.go @@ -0,0 +1,85 @@ +package db + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/testutils" +) + +func TestNamespacedDB(t *testing.T) { + startingDsn := testutils.LocalTestDBDSNPrefix + "/foo" + testutils.LocalTestDBDSNSuffix + newDBName := "xmtp_" + testutils.RandomString(24) + // Create namespaced DB + namespacedDB, err := NewNamespacedDB( + context.Background(), + startingDsn, + newDBName, + time.Second, + time.Second, + ) + t.Cleanup(func() { namespacedDB.Close() }) + require.NoError(t, err) + + result, err := namespacedDB.Query("SELECT current_database();") + require.NoError(t, err) + defer result.Close() + + require.True(t, result.Next()) + var dbName string + err = result.Scan(&dbName) + require.NoError(t, err) + require.Equal(t, newDBName, dbName) +} + +func TestNamespaceRepeat(t *testing.T) { + startingDsn := testutils.LocalTestDBDSNPrefix + "/foo" + testutils.LocalTestDBDSNSuffix + newDBName := "xmtp_" + testutils.RandomString(24) + // Create namespaced DB + db1, err := NewNamespacedDB( + context.Background(), + startingDsn, + newDBName, + time.Second, + time.Second, + ) + require.NoError(t, err) + require.NotNil(t, db1) + t.Cleanup(func() { db1.Close() }) + + // Create again with the same name + db2, err := NewNamespacedDB( + context.Background(), + startingDsn, + newDBName, + time.Second, + time.Second, + ) + require.NoError(t, err) + require.NotNil(t, db2) + t.Cleanup(func() { db2.Close() }) +} + +func TestNamespacedDBInvalidName(t *testing.T) { + _, err := NewNamespacedDB( + context.Background(), + testutils.LocalTestDBDSNPrefix+"/foo"+testutils.LocalTestDBDSNSuffix, + "invalid/name", + time.Second, + time.Second, + ) + require.Error(t, err) +} + +func TestNamespacedDBInvalidDSN(t *testing.T) { + _, err := NewNamespacedDB( + context.Background(), + "invalid-dsn", + "dbname", + time.Second, + time.Second, + ) + require.Error(t, err) +} diff --git a/pkg/testutils/store.go b/pkg/testutils/store.go index e9d1f6b5..44d108a0 100644 --- a/pkg/testutils/store.go +++ b/pkg/testutils/store.go @@ -13,8 +13,8 @@ import ( ) const ( - localTestDBDSNPrefix = "postgres://postgres:xmtp@localhost:8765" - localTestDBDSNSuffix = "?sslmode=disable" + LocalTestDBDSNPrefix = "postgres://postgres:xmtp@localhost:8765" + LocalTestDBDSNSuffix = "?sslmode=disable" ) func openDB(t testing.TB, dsn string) (*sql.DB, string, func()) { @@ -28,7 +28,7 @@ func openDB(t testing.TB, dsn string) (*sql.DB, string, func()) { } func newCtlDB(t testing.TB) (*sql.DB, string, func()) { - return openDB(t, localTestDBDSNPrefix+localTestDBDSNSuffix) + return openDB(t, LocalTestDBDSNPrefix+LocalTestDBDSNSuffix) } func newInstanceDB(t testing.TB, ctx context.Context, ctlDB *sql.DB) (*sql.DB, string, func()) { @@ -36,7 +36,7 @@ func newInstanceDB(t testing.TB, ctx context.Context, ctlDB *sql.DB) (*sql.DB, s _, err := ctlDB.Exec("CREATE DATABASE " + dbName) require.NoError(t, err) - db, dsn, cleanup := openDB(t, localTestDBDSNPrefix+"/"+dbName+localTestDBDSNSuffix) + db, dsn, cleanup := openDB(t, LocalTestDBDSNPrefix+"/"+dbName+LocalTestDBDSNSuffix) require.NoError(t, migrations.Migrate(ctx, db)) return db, dsn, func() {