Skip to content

Commit

Permalink
Use postgres DB while connecting to RDS during restore (#2282)
Browse files Browse the repository at this point in the history
* Use postgres DB while connecting to RDS during restore

Signed-off-by: Prasad Ghangal <[email protected]>

* Update pkg/function/export_rds_snapshot_location.go

* Refactor RDS unit tests

Signed-off-by: Prasad Ghangal <[email protected]>

---------

Signed-off-by: Prasad Ghangal <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
PrasadG193 and mergify[bot] authored Aug 22, 2023
1 parent f4aabaf commit 0c46969
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 30 deletions.
62 changes: 39 additions & 23 deletions pkg/function/export_rds_snapshot_location.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,48 +250,64 @@ func execDumpCommand(ctx context.Context, dbEngine RDSDBEngine, action RDSAction
return kubeTask(ctx, cli, namespace, image, command, injectPostgresSecrets(secretName))
}

func prepareCommand(ctx context.Context, dbEngine RDSDBEngine, action RDSAction, dbEndpoint, username, password string, dbList []string, backupPrefix, backupID string, profile *param.Profile, dbEngineVersion string) ([]string, string, error) {
func prepareCommand(
ctx context.Context,
dbEngine RDSDBEngine,
action RDSAction,
dbEndpoint,
username,
password string,
dbList []string,
backupPrefix,
backupID string,
profile *param.Profile,
dbEngineVersion string,
) ([]string, string, error) {
// Convert profile object into json
profileJson, err := json.Marshal(profile)
if err != nil {
return nil, "", err
}

// Find list of dbs
// For backup operation, if database arg is not set, we take backup of all databases
if dbList == nil {
// If no database is passed, we find list of all the existing databases
pg, err := postgres.NewClient(dbEndpoint, username, password, "postgres", "disable")
if err != nil {
return nil, "", errors.Wrap(err, "Error in creating postgres client")
}

// Test DB connection
if err := pg.PingDB(ctx); err != nil {
return nil, "", errors.Wrap(err, "Failed to ping postgres database")
}

dbList, err = pg.ListDatabases(ctx)
if err != nil {
return nil, "", errors.Wrap(err, "Error while listing databases")
}
dbList = filterRestrictedDB(dbList)
}

switch dbEngine {
case PostgrSQLEngine:
switch action {
case BackupAction:
// For backup operation, if database arg is not set, we take backup of all databases
if dbList == nil {
dbList, err = findDBList(ctx, dbEndpoint, username, password)
if err != nil {
return nil, "", err
}
}
command, err := postgresBackupCommand(dbEndpoint, username, password, dbList, backupPrefix, backupID, profileJson)
return command, postgresToolsImage, err
case RestoreAction:
command, err := postgresRestoreCommand(dbEndpoint, username, password, dbList, backupPrefix, backupID, profileJson, dbEngineVersion)
command, err := postgresRestoreCommand(dbEndpoint, username, password, backupPrefix, backupID, profileJson, dbEngineVersion)
return command, postgresToolsImage, err
}
}
return nil, "", errors.New("Invalid RDSDBEngine or RDSAction")
}

func findDBList(ctx context.Context, dbEndpoint, username, password string) ([]string, error) {
pg, err := postgres.NewClient(dbEndpoint, username, password, postgres.DefaultConnectDatabase, "disable")
if err != nil {
return nil, errors.Wrap(err, "Error in creating postgres client")
}

// Test DB connection
if err := pg.PingDB(ctx); err != nil {
return nil, errors.Wrap(err, "Failed to ping postgres database")
}

dbList, err := pg.ListDatabases(ctx)
if err != nil {
return nil, errors.Wrap(err, "Error while listing databases")
}
return filterRestrictedDB(dbList), nil
}

//nolint:unparam
func postgresBackupCommand(dbEndpoint, username, password string, dbList []string, backupPrefix, backupID string, profile []byte) ([]string, error) {
if len(dbList) == 0 {
Expand Down
5 changes: 3 additions & 2 deletions pkg/function/rds_functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"

"github.com/kanisterio/kanister/pkg/param"
"github.com/kanisterio/kanister/pkg/postgres"
. "gopkg.in/check.v1"
)

Expand Down Expand Up @@ -59,7 +60,7 @@ func (s *RDSFunctionsTest) TestPrepareCommand(c *C) {
fmt.Sprintf(`
export PGHOST=%s
kando location pull --profile '%s' --path "%s" - | gunzip -c -f | sed 's/"LOCALE"/"LC_COLLATE"/' | psql -q -U "${PGUSER}" %s
`, "db-endpoint", "null", fmt.Sprintf("%s/%s", "/backup/postgres-backup", "backup-id"), []string{"template1"}[0]),
`, "db-endpoint", "null", fmt.Sprintf("%s/%s", "/backup/postgres-backup", "backup-id"), postgres.DefaultConnectDatabase),
},
},
{
Expand All @@ -78,7 +79,7 @@ func (s *RDSFunctionsTest) TestPrepareCommand(c *C) {
fmt.Sprintf(`
export PGHOST=%s
kando location pull --profile '%s' --path "%s" - | gunzip -c -f | psql -q -U "${PGUSER}" %s
`, "db-endpoint", "null", fmt.Sprintf("%s/%s", "/backup/postgres-backup", "backup-id"), []string{"template1"}[0]),
`, "db-endpoint", "null", fmt.Sprintf("%s/%s", "/backup/postgres-backup", "backup-id"), postgres.DefaultConnectDatabase),
},
},
{
Expand Down
8 changes: 3 additions & 5 deletions pkg/function/restore_rds_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/kanisterio/kanister/pkg/field"
"github.com/kanisterio/kanister/pkg/log"
"github.com/kanisterio/kanister/pkg/param"
"github.com/kanisterio/kanister/pkg/postgres"
)

func init() {
Expand Down Expand Up @@ -206,11 +207,8 @@ func restoreRDSSnapshot(ctx context.Context, namespace, instanceID, subnetGroup,
}

//nolint:unparam
func postgresRestoreCommand(pgHost, username, password string, dbList []string, backupArtifactPrefix, backupID string, profile []byte, dbEngineVersion string) ([]string, error) {
func postgresRestoreCommand(pgHost, username, password string, backupArtifactPrefix, backupID string, profile []byte, dbEngineVersion string) ([]string, error) {
replaceCommand := ""
if len(dbList) == 0 {
return nil, errors.New("No database found. Atleast one db needed to connect")
}

// check if PostgresDB version < 13
v1, err := version.NewVersion(dbEngineVersion)
Expand All @@ -237,7 +235,7 @@ func postgresRestoreCommand(pgHost, username, password string, dbList []string,
fmt.Sprintf(`
export PGHOST=%s
kando location pull --profile '%s' --path "%s" - | gunzip -c -f |%s psql -q -U "${PGUSER}" %s
`, pgHost, profile, fmt.Sprintf("%s/%s", backupArtifactPrefix, backupID), replaceCommand, dbList[0]),
`, pgHost, profile, fmt.Sprintf("%s/%s", backupArtifactPrefix, backupID), replaceCommand, postgres.DefaultConnectDatabase),
}, nil
}

Expand Down
2 changes: 2 additions & 0 deletions pkg/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
_ "github.com/lib/pq"
)

const DefaultConnectDatabase = "postgres"

// Client is postgres client to access postgres instance
type Client struct {
*sql.DB
Expand Down

0 comments on commit 0c46969

Please sign in to comment.