From 0c4696977dbcf548060b4de07c87bab3d00bf90a Mon Sep 17 00:00:00 2001 From: Prasad Ghangal Date: Tue, 22 Aug 2023 17:03:08 +0530 Subject: [PATCH] Use postgres DB while connecting to RDS during restore (#2282) * Use postgres DB while connecting to RDS during restore Signed-off-by: Prasad Ghangal * Update pkg/function/export_rds_snapshot_location.go * Refactor RDS unit tests Signed-off-by: Prasad Ghangal --------- Signed-off-by: Prasad Ghangal Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- pkg/function/export_rds_snapshot_location.go | 62 ++++++++++++-------- pkg/function/rds_functions_test.go | 5 +- pkg/function/restore_rds_snapshot.go | 8 +-- pkg/postgres/client.go | 2 + 4 files changed, 47 insertions(+), 30 deletions(-) diff --git a/pkg/function/export_rds_snapshot_location.go b/pkg/function/export_rds_snapshot_location.go index 826c643d19..480221539e 100644 --- a/pkg/function/export_rds_snapshot_location.go +++ b/pkg/function/export_rds_snapshot_location.go @@ -250,48 +250,64 @@ func execDumpCommand(ctx context.Context, dbEngine RDSDBEngine, action RDSAction return kubeTask(ctx, cli, namespace, image, command, injectPostgresSecrets(secretName)) } -func prepareCommand(ctx context.Context, dbEngine RDSDBEngine, action RDSAction, dbEndpoint, username, password string, dbList []string, backupPrefix, backupID string, profile *param.Profile, dbEngineVersion string) ([]string, string, error) { +func prepareCommand( + ctx context.Context, + dbEngine RDSDBEngine, + action RDSAction, + dbEndpoint, + username, + password string, + dbList []string, + backupPrefix, + backupID string, + profile *param.Profile, + dbEngineVersion string, +) ([]string, string, error) { // Convert profile object into json profileJson, err := json.Marshal(profile) if err != nil { return nil, "", err } - // Find list of dbs - // For backup operation, if database arg is not set, we take backup of all databases - if dbList == nil { - // If no database is passed, we find list of all the existing databases - pg, err := postgres.NewClient(dbEndpoint, username, password, "postgres", "disable") - if err != nil { - return nil, "", errors.Wrap(err, "Error in creating postgres client") - } - - // Test DB connection - if err := pg.PingDB(ctx); err != nil { - return nil, "", errors.Wrap(err, "Failed to ping postgres database") - } - - dbList, err = pg.ListDatabases(ctx) - if err != nil { - return nil, "", errors.Wrap(err, "Error while listing databases") - } - dbList = filterRestrictedDB(dbList) - } - switch dbEngine { case PostgrSQLEngine: switch action { case BackupAction: + // For backup operation, if database arg is not set, we take backup of all databases + if dbList == nil { + dbList, err = findDBList(ctx, dbEndpoint, username, password) + if err != nil { + return nil, "", err + } + } command, err := postgresBackupCommand(dbEndpoint, username, password, dbList, backupPrefix, backupID, profileJson) return command, postgresToolsImage, err case RestoreAction: - command, err := postgresRestoreCommand(dbEndpoint, username, password, dbList, backupPrefix, backupID, profileJson, dbEngineVersion) + command, err := postgresRestoreCommand(dbEndpoint, username, password, backupPrefix, backupID, profileJson, dbEngineVersion) return command, postgresToolsImage, err } } return nil, "", errors.New("Invalid RDSDBEngine or RDSAction") } +func findDBList(ctx context.Context, dbEndpoint, username, password string) ([]string, error) { + pg, err := postgres.NewClient(dbEndpoint, username, password, postgres.DefaultConnectDatabase, "disable") + if err != nil { + return nil, errors.Wrap(err, "Error in creating postgres client") + } + + // Test DB connection + if err := pg.PingDB(ctx); err != nil { + return nil, errors.Wrap(err, "Failed to ping postgres database") + } + + dbList, err := pg.ListDatabases(ctx) + if err != nil { + return nil, errors.Wrap(err, "Error while listing databases") + } + return filterRestrictedDB(dbList), nil +} + //nolint:unparam func postgresBackupCommand(dbEndpoint, username, password string, dbList []string, backupPrefix, backupID string, profile []byte) ([]string, error) { if len(dbList) == 0 { diff --git a/pkg/function/rds_functions_test.go b/pkg/function/rds_functions_test.go index ac54c2e672..2884e5ca5e 100644 --- a/pkg/function/rds_functions_test.go +++ b/pkg/function/rds_functions_test.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/kanisterio/kanister/pkg/param" + "github.com/kanisterio/kanister/pkg/postgres" . "gopkg.in/check.v1" ) @@ -59,7 +60,7 @@ func (s *RDSFunctionsTest) TestPrepareCommand(c *C) { fmt.Sprintf(` export PGHOST=%s kando location pull --profile '%s' --path "%s" - | gunzip -c -f | sed 's/"LOCALE"/"LC_COLLATE"/' | psql -q -U "${PGUSER}" %s - `, "db-endpoint", "null", fmt.Sprintf("%s/%s", "/backup/postgres-backup", "backup-id"), []string{"template1"}[0]), + `, "db-endpoint", "null", fmt.Sprintf("%s/%s", "/backup/postgres-backup", "backup-id"), postgres.DefaultConnectDatabase), }, }, { @@ -78,7 +79,7 @@ func (s *RDSFunctionsTest) TestPrepareCommand(c *C) { fmt.Sprintf(` export PGHOST=%s kando location pull --profile '%s' --path "%s" - | gunzip -c -f | psql -q -U "${PGUSER}" %s - `, "db-endpoint", "null", fmt.Sprintf("%s/%s", "/backup/postgres-backup", "backup-id"), []string{"template1"}[0]), + `, "db-endpoint", "null", fmt.Sprintf("%s/%s", "/backup/postgres-backup", "backup-id"), postgres.DefaultConnectDatabase), }, }, { diff --git a/pkg/function/restore_rds_snapshot.go b/pkg/function/restore_rds_snapshot.go index d57926ae97..6d0afb6fb8 100644 --- a/pkg/function/restore_rds_snapshot.go +++ b/pkg/function/restore_rds_snapshot.go @@ -29,6 +29,7 @@ import ( "github.com/kanisterio/kanister/pkg/field" "github.com/kanisterio/kanister/pkg/log" "github.com/kanisterio/kanister/pkg/param" + "github.com/kanisterio/kanister/pkg/postgres" ) func init() { @@ -206,11 +207,8 @@ func restoreRDSSnapshot(ctx context.Context, namespace, instanceID, subnetGroup, } //nolint:unparam -func postgresRestoreCommand(pgHost, username, password string, dbList []string, backupArtifactPrefix, backupID string, profile []byte, dbEngineVersion string) ([]string, error) { +func postgresRestoreCommand(pgHost, username, password string, backupArtifactPrefix, backupID string, profile []byte, dbEngineVersion string) ([]string, error) { replaceCommand := "" - if len(dbList) == 0 { - return nil, errors.New("No database found. Atleast one db needed to connect") - } // check if PostgresDB version < 13 v1, err := version.NewVersion(dbEngineVersion) @@ -237,7 +235,7 @@ func postgresRestoreCommand(pgHost, username, password string, dbList []string, fmt.Sprintf(` export PGHOST=%s kando location pull --profile '%s' --path "%s" - | gunzip -c -f |%s psql -q -U "${PGUSER}" %s - `, pgHost, profile, fmt.Sprintf("%s/%s", backupArtifactPrefix, backupID), replaceCommand, dbList[0]), + `, pgHost, profile, fmt.Sprintf("%s/%s", backupArtifactPrefix, backupID), replaceCommand, postgres.DefaultConnectDatabase), }, nil } diff --git a/pkg/postgres/client.go b/pkg/postgres/client.go index 7ef5f22cbb..3c13f87238 100644 --- a/pkg/postgres/client.go +++ b/pkg/postgres/client.go @@ -23,6 +23,8 @@ import ( _ "github.com/lib/pq" ) +const DefaultConnectDatabase = "postgres" + // Client is postgres client to access postgres instance type Client struct { *sql.DB