Skip to content

Commit

Permalink
feat: add db package (#38)
Browse files Browse the repository at this point in the history
* feat: add sqlx package

* chore: fix package naming

* test: add postgres dockertest

* chore: fix review comments

* test: add test for WithTx

* test: add test for migrations

* chore: use iofs driver for migration
  • Loading branch information
ravisuhag authored Aug 30, 2022
1 parent 2550467 commit afb9357
Show file tree
Hide file tree
Showing 9 changed files with 1,467 additions and 27 deletions.
14 changes: 14 additions & 0 deletions db/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package db

import (
"time"
)

type Config struct {
Driver string `yaml:"driver" mapstructure:"driver"`
URL string `yaml:"url" mapstructure:"url"`
MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" default:"10"`
MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" default:"10"`
ConnMaxLifeTime time.Duration `yaml:"conn_max_life_time" mapstructure:"conn_max_life_time" default:"10ms"`
MaxQueryTimeout time.Duration `yaml:"max_query_timeout" mapstructure:"max_query_timeout" default:"100ms"`
}
73 changes: 73 additions & 0 deletions db/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package db

import (
"context"
"database/sql"
"fmt"
"time"

"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)

type Client struct {
*sqlx.DB
queryTimeOut time.Duration
}

// NewClient creates a new sqlx database client
func New(cfg Config) (*Client, error) {
db, err := sqlx.Connect(cfg.Driver, cfg.URL)
if err != nil {
return nil, err
}

db.SetMaxIdleConns(cfg.MaxIdleConns)
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetConnMaxLifetime(cfg.ConnMaxLifeTime)

return &Client{DB: db, queryTimeOut: cfg.MaxQueryTimeout}, err
}

func (c Client) WithTimeout(ctx context.Context, op func(ctx context.Context) error) (err error) {
ctxWithTimeout, cancel := context.WithTimeout(ctx, c.queryTimeOut)
defer cancel()

return op(ctxWithTimeout)
}

func (c Client) WithTxn(ctx context.Context, txnOptions sql.TxOptions, txFunc func(*sqlx.Tx) error) (err error) {
txn, err := c.BeginTxx(ctx, &txnOptions)
if err != nil {
return err
}

defer func() {
if p := recover(); p != nil {
switch p := p.(type) {
case error:
err = p
default:
err = errors.Errorf("%s", p)
}
err = txn.Rollback()
panic(p)
} else if err != nil {
if rlbErr := txn.Rollback(); err != nil {
err = fmt.Errorf("rollback error: %s while executing: %w", rlbErr, err)
} else {
err = fmt.Errorf("rollback: %w", err)
}
} else {
err = txn.Commit()
}
}()

err = txFunc(txn)
return err
}

// Close closes the database connection
func (c *Client) Close() error {
return c.DB.Close()
}
176 changes: 176 additions & 0 deletions db/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package db_test

import (
"context"
"database/sql"
"fmt"
"log"
"os"
"testing"
"time"

"github.com/jmoiron/sqlx"
"github.com/odpf/salt/db"
"github.com/ory/dockertest"
"github.com/ory/dockertest/docker"
"github.com/stretchr/testify/assert"
)

const (
dialect = "postgres"
user = "root"
password = "pass"
database = "postgres"
host = "localhost"
port = "5432"
dsn = "postgres://%s:%s@localhost:%s/%s?sslmode=disable"
)

var (
createTableQuery = "CREATE TABLE IF NOT EXISTS users (id VARCHAR(36) PRIMARY KEY, name VARCHAR(50))"
dropTableQuery = "DROP TABLE IF EXISTS users"
checkTableQuery = "SELECT EXISTS(SELECT * FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users');"
)

var client *db.Client

func TestMain(m *testing.M) {
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}

opts := dockertest.RunOptions{
Repository: "postgres",
Tag: "14",
Env: []string{
"POSTGRES_USER=" + user,
"POSTGRES_PASSWORD=" + password,
"POSTGRES_DB=" + database,
},
ExposedPorts: []string{"5432"},
PortBindings: map[docker.Port][]docker.PortBinding{
"5432": {
{HostIP: "0.0.0.0", HostPort: port},
},
},
}

resource, err := pool.RunWithOptions(&opts, func(config *docker.HostConfig) {
config.AutoRemove = true
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
if err != nil {
log.Fatalf("Could not start resource: %s", err.Error())
}

fmt.Println(resource.GetPort("5432/tcp"))

if err := resource.Expire(120); err != nil {
log.Fatalf("Could not expire resource: %s", err.Error())
}

pool.MaxWait = 60 * time.Second

dsn := fmt.Sprintf(dsn, user, password, port, database)
var (
pgConfig = db.Config{
Driver: "postgres",
URL: dsn,
}
)

if err = pool.Retry(func() error {
client, err = db.New(pgConfig)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err.Error())
}

defer func() {
client.Close()
}()

code := m.Run()

if err := pool.Purge(resource); err != nil {
log.Fatalf("Could not purge resource: %s", err)
}

os.Exit(code)
}

func TestWithTxn(t *testing.T) {
if _, err := client.Exec(dropTableQuery); err != nil {
log.Fatalf("Could not cleanup: %s", err)
}
err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error {
if _, err := tx.Exec(createTableQuery); err != nil {
return err
}
if _, err := tx.Exec(dropTableQuery); err != nil {
return err
}

return nil
})
assert.NoError(t, err)

// Table should be dropped
var tableExist bool
result := client.QueryRow(checkTableQuery)
result.Scan(&tableExist)
assert.Equal(t, false, tableExist)
}

func TestWithTxnCommit(t *testing.T) {
if _, err := client.Exec(dropTableQuery); err != nil {
log.Fatalf("Could not cleanup: %s", err)
}
query2 := "SELECT 1"

err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error {
if _, err := tx.Exec(createTableQuery); err != nil {
return err
}
if _, err := tx.Exec(query2); err != nil {
return err
}

return nil
})
// WithTx should not return an error
assert.NoError(t, err)

// User table should exist
var tableExist bool
result := client.QueryRow(checkTableQuery)
result.Scan(&tableExist)
assert.Equal(t, true, tableExist)
}

func TestWithTxnRollback(t *testing.T) {
if _, err := client.Exec(dropTableQuery); err != nil {
log.Fatalf("Could not cleanup: %s", err)
}
query2 := "WRONG QUERY"

err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error {
if _, err := tx.Exec(createTableQuery); err != nil {
return err
}
if _, err := tx.Exec(query2); err != nil {
return err
}

return nil
})
// WithTx should return an error
assert.Error(t, err)

// Table should not be created
var tableExist bool
result := client.QueryRow(checkTableQuery)
result.Scan(&tableExist)
assert.Equal(t, false, tableExist)
}
49 changes: 49 additions & 0 deletions db/migrate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package db

import (
"fmt"
"io/fs"

"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database"
_ "github.com/golang-migrate/migrate/v4/database/mysql"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/golang-migrate/migrate/v4/source/iofs"
)

func RunMigrations(config Config, embeddedMigrations fs.FS, resourcePath string) error {
m, err := getMigrationInstance(config, embeddedMigrations, resourcePath)
if err != nil {
return err
}

err = m.Up()
if err == migrate.ErrNoChange || err == nil {
return nil
}

return err
}

func RunRollback(config Config, embeddedMigrations fs.FS, resourcePath string) error {
m, err := getMigrationInstance(config, embeddedMigrations, resourcePath)
if err != nil {
return err
}

err = m.Steps(-1)
if err == migrate.ErrNoChange || err == nil {
return nil
}

return err
}

func getMigrationInstance(config Config, embeddedMigrations fs.FS, resourcePath string) (*migrate.Migrate, error) {
src, err := iofs.New(embeddedMigrations, resourcePath)
if err != nil {
return nil, fmt.Errorf("db migrator: %v", err)
}
return migrate.NewWithSourceInstance("iofs", src, config.URL)
}
60 changes: 60 additions & 0 deletions db/migrate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package db_test

import (
"embed"
"fmt"
"log"
"testing"

"github.com/odpf/salt/db"
"github.com/stretchr/testify/assert"
)

//go:embed migrations/*.sql
var migrationFs embed.FS

func TestRunMigrations(t *testing.T) {
if _, err := client.Exec(dropTableQuery); err != nil {
log.Fatalf("Could not cleanup: %s", err)
}

dsn := fmt.Sprintf(dsn, user, password, port, database)
var (
pgConfig = db.Config{
Driver: "postgres",
URL: dsn,
}
)

err := db.RunMigrations(pgConfig, migrationFs, "migrations")
assert.NoError(t, err)

// User table should exist
var tableExist bool
result := client.QueryRow(checkTableQuery)
result.Scan(&tableExist)
assert.Equal(t, true, tableExist)
}

func TestRunRollback(t *testing.T) {
if _, err := client.Exec(dropTableQuery); err != nil {
log.Fatalf("Could not cleanup: %s", err)
}

dsn := fmt.Sprintf(dsn, user, password, port, database)
var (
pgConfig = db.Config{
Driver: "postgres",
URL: dsn,
}
)

err := db.RunRollback(pgConfig, migrationFs, "migrations")
assert.NoError(t, err)

// User table should not exist
var tableExist bool
result := client.QueryRow(checkTableQuery)
result.Scan(&tableExist)
assert.Equal(t, false, tableExist)
}
1 change: 1 addition & 0 deletions db/migrations/1481574547_create_users_table.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE IF EXISTS users
1 change: 1 addition & 0 deletions db/migrations/1481574547_create_users_table.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE IF NOT EXISTS users (id VARCHAR(36) PRIMARY KEY, name VARCHAR(50))
Loading

0 comments on commit afb9357

Please sign in to comment.