From 9849c9a478a4246fea1fc1000a8d48059ed7d790 Mon Sep 17 00:00:00 2001 From: David Oduneye Date: Sat, 2 Mar 2024 00:30:38 -0500 Subject: [PATCH] fix test cleaner --- cli/commands/clean_tests.go | 112 ++++++++++++++++++------------------ cli/commands/test.go | 3 +- 2 files changed, 57 insertions(+), 58 deletions(-) diff --git a/cli/commands/clean_tests.go b/cli/commands/clean_tests.go index e3d4c6973..5dbd14f28 100644 --- a/cli/commands/clean_tests.go +++ b/cli/commands/clean_tests.go @@ -1,87 +1,85 @@ package commands import ( + "context" "database/sql" "fmt" "os/user" "sync" + "time" _ "github.com/lib/pq" "github.com/urfave/cli/v2" ) func ClearDBCommand() *cli.Command { - command := cli.Command{ - Name: "clean", - Category: "Database Operations", - Aliases: []string{"c"}, - Usage: "Remove databases used for testing", - Action: func(c *cli.Context) error { - if c.Args().Len() > 0 { - return cli.Exit("Invalid arguments", 1) - } - - err := CleanTestDBs() - if err != nil { - return cli.Exit(err.Error(), 1) - } - - return nil - }, - } - - return &command + command := cli.Command{ + Name: "clean", + Category: "Database Operations", + Aliases: []string{"c"}, + Usage: "Remove databases used for testing", + Action: func(c *cli.Context) error { + if c.Args().Len() > 0 { + return cli.Exit("Invalid arguments", 1) + } + + return CleanTestDBs(context.Background()) + }, + } + + return &command } -func CleanTestDBs() error { - fmt.Println("Cleaning test databases") +func CleanTestDBs(ctx context.Context) error { + fmt.Println("Cleaning test databases") - db, err := sql.Open("postgres", CONFIG.Database.WithDb()) - if err != nil { - return err - } + db, err := sql.Open("postgres", CONFIG.Database.WithDb()) + if err != nil { + return err + } - defer db.Close() + defer db.Close() - currentUser, err := user.Current() - if err != nil { - return fmt.Errorf("failed to get current user: %w", err) - } + currentUser, err := user.Current() + if err != nil { + return fmt.Errorf("failed to get current user: %w", err) + } - rows, err := db.Query("SELECT datname FROM pg_database WHERE datistemplate = false AND datname != 'postgres' AND datname != $1 AND datname != $2 AND datname LIKE 'sac_test_%';", currentUser.Username, CONFIG.Database.DatabaseName) - if err != nil { - return err - } + query := `SELECT datname FROM pg_database WHERE datistemplate = false AND datname != 'postgres' AND datname != $1 AND datname != $2 AND datname LIKE 'sac_test_%';` + rows, err := db.QueryContext(ctx, query, currentUser.Username, CONFIG.Database.DatabaseName) - defer rows.Close() + if err != nil { + return err + } - var wg sync.WaitGroup + defer rows.Close() - for rows.Next() { - var dbName string + var wg sync.WaitGroup - if err := rows.Scan(&dbName); err != nil { - return err - } + for rows.Next() { + var dbName string - wg.Add(1) + if err := rows.Scan(&dbName); err != nil { + return err + } - go func(dbName string) { - defer wg.Done() + wg.Add(1) + go func(dbName string) { + defer wg.Done() - fmt.Printf("Dropping database %s\n", dbName) + fmt.Printf("Dropping database %s\n", dbName) - if _, err := db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)); err != nil { - fmt.Printf("Failed to drop database %s: %v\n", dbName, err) - } - }(dbName) - } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // Set a timeout for each drop operation + defer cancel() - if err := rows.Err(); err != nil { - return err - } + _, err := db.ExecContext(ctx, fmt.Sprintf("DROP DATABASE %s", dbName)) + if err != nil { + fmt.Printf("Failed to drop database %s: %v\n", dbName, err) + } + }(dbName) + } - wg.Wait() + wg.Wait() - return nil -} + return nil +} \ No newline at end of file diff --git a/cli/commands/test.go b/cli/commands/test.go index be12543c4..7d25bead9 100644 --- a/cli/commands/test.go +++ b/cli/commands/test.go @@ -1,6 +1,7 @@ package commands import ( + "context" "fmt" "os/exec" "sync" @@ -105,7 +106,7 @@ func BackendTest() error { fmt.Println(string(out)) - err = CleanTestDBs() + err = CleanTestDBs(context.Background()) if err != nil { return cli.Exit(err.Error(), 1) }