Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add db package #38

Merged
merged 8 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions db/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package db

import (
"time"
)

type Config struct {
Driver string `yaml:"driver" mapstructure:"driver"`
URL string `yaml:"url" mapstructure:"url"`
MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" default:"10"`
MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" default:"10"`
ConnMaxLifeTime time.Duration `yaml:"conn_max_life_time" mapstructure:"conn_max_life_time" default:"10ms"`
MaxQueryTimeout time.Duration `yaml:"max_query_timeout" mapstructure:"max_query_timeout" default:"100ms"`
}
73 changes: 73 additions & 0 deletions db/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package db

import (
"context"
"database/sql"
"fmt"
"time"

"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)

type Client struct {
*sqlx.DB
queryTimeOut time.Duration
}

// NewClient creates a new sqlx database client
func New(cfg Config) (*Client, error) {
db, err := sqlx.Connect(cfg.Driver, cfg.URL)
if err != nil {
return nil, err
}

db.SetMaxIdleConns(cfg.MaxIdleConns)
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetConnMaxLifetime(cfg.ConnMaxLifeTime)

return &Client{DB: db, queryTimeOut: cfg.MaxQueryTimeout}, err
}

func (c Client) WithTimeout(ctx context.Context, op func(ctx context.Context) error) (err error) {
ctxWithTimeout, cancel := context.WithTimeout(ctx, c.queryTimeOut)
defer cancel()

return op(ctxWithTimeout)
}

func (c Client) WithTxn(ctx context.Context, txnOptions sql.TxOptions, txFunc func(*sqlx.Tx) error) (err error) {
txn, err := c.BeginTxx(ctx, &txnOptions)
if err != nil {
return err
}

defer func() {
if p := recover(); p != nil {
switch p := p.(type) {
case error:
err = p
default:
err = errors.Errorf("%s", p)
}
err = txn.Rollback()
panic(p)
} else if err != nil {
if rlbErr := txn.Rollback(); err != nil {
err = fmt.Errorf("rollback error: %s while executing: %w", rlbErr, err)
} else {
err = fmt.Errorf("rollback: %w", err)
}
} else {
err = txn.Commit()
}
}()

err = txFunc(txn)
return err
}

// Close closes the database connection
func (c *Client) Close() error {
return c.DB.Close()
}
176 changes: 176 additions & 0 deletions db/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package db_test

import (
"context"
"database/sql"
"fmt"
"log"
"os"
"testing"
"time"

"github.com/jmoiron/sqlx"
"github.com/odpf/salt/db"
"github.com/ory/dockertest"
"github.com/ory/dockertest/docker"
"github.com/stretchr/testify/assert"
)

const (
dialect = "postgres"
user = "root"
password = "pass"
database = "postgres"
host = "localhost"
port = "5432"
dsn = "postgres://%s:%s@localhost:%s/%s?sslmode=disable"
)

var (
createTableQuery = "CREATE TABLE IF NOT EXISTS users (id VARCHAR(36) PRIMARY KEY, name VARCHAR(50))"
dropTableQuery = "DROP TABLE IF EXISTS users"
checkTableQuery = "SELECT EXISTS(SELECT * FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users');"
)

var client *db.Client

func TestMain(m *testing.M) {
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}

opts := dockertest.RunOptions{
Repository: "postgres",
Tag: "14",
Env: []string{
"POSTGRES_USER=" + user,
"POSTGRES_PASSWORD=" + password,
"POSTGRES_DB=" + database,
},
ExposedPorts: []string{"5432"},
PortBindings: map[docker.Port][]docker.PortBinding{
"5432": {
{HostIP: "0.0.0.0", HostPort: port},
},
},
}

resource, err := pool.RunWithOptions(&opts, func(config *docker.HostConfig) {
config.AutoRemove = true
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
if err != nil {
log.Fatalf("Could not start resource: %s", err.Error())
}

fmt.Println(resource.GetPort("5432/tcp"))

if err := resource.Expire(120); err != nil {
log.Fatalf("Could not expire resource: %s", err.Error())
}

pool.MaxWait = 60 * time.Second

dsn := fmt.Sprintf(dsn, user, password, port, database)
var (
pgConfig = db.Config{
Driver: "postgres",
URL: dsn,
}
)

if err = pool.Retry(func() error {
client, err = db.New(pgConfig)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err.Error())
}

defer func() {
client.Close()
}()

code := m.Run()

if err := pool.Purge(resource); err != nil {
log.Fatalf("Could not purge resource: %s", err)
}

os.Exit(code)
}

func TestWithTxn(t *testing.T) {
if _, err := client.Exec(dropTableQuery); err != nil {
log.Fatalf("Could not cleanup: %s", err)
}
err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error {
if _, err := tx.Exec(createTableQuery); err != nil {
return err
}
if _, err := tx.Exec(dropTableQuery); err != nil {
return err
}

return nil
})
assert.NoError(t, err)

// Table should be dropped
var tableExist bool
result := client.QueryRow(checkTableQuery)
result.Scan(&tableExist)
assert.Equal(t, false, tableExist)
}

func TestWithTxnCommit(t *testing.T) {
if _, err := client.Exec(dropTableQuery); err != nil {
log.Fatalf("Could not cleanup: %s", err)
}
query2 := "SELECT 1"

err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error {
if _, err := tx.Exec(createTableQuery); err != nil {
return err
}
if _, err := tx.Exec(query2); err != nil {
return err
}

return nil
})
// WithTx should not return an error
assert.NoError(t, err)

// User table should exist
var tableExist bool
result := client.QueryRow(checkTableQuery)
result.Scan(&tableExist)
assert.Equal(t, true, tableExist)
}

func TestWithTxnRollback(t *testing.T) {
if _, err := client.Exec(dropTableQuery); err != nil {
log.Fatalf("Could not cleanup: %s", err)
}
query2 := "WRONG QUERY"

err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error {
if _, err := tx.Exec(createTableQuery); err != nil {
return err
}
if _, err := tx.Exec(query2); err != nil {
return err
}

return nil
})
// WithTx should return an error
assert.Error(t, err)

// Table should not be created
var tableExist bool
result := client.QueryRow(checkTableQuery)
result.Scan(&tableExist)
assert.Equal(t, false, tableExist)
}
49 changes: 49 additions & 0 deletions db/migrate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package db

import (
"fmt"
"io/fs"

"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database"
_ "github.com/golang-migrate/migrate/v4/database/mysql"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/golang-migrate/migrate/v4/source/iofs"
)

func RunMigrations(config Config, embeddedMigrations fs.FS, resourcePath string) error {
m, err := getMigrationInstance(config, embeddedMigrations, resourcePath)
if err != nil {
return err
}

err = m.Up()
if err == migrate.ErrNoChange || err == nil {
return nil
}

return err
}

func RunRollback(config Config, embeddedMigrations fs.FS, resourcePath string) error {
m, err := getMigrationInstance(config, embeddedMigrations, resourcePath)
if err != nil {
return err
}

err = m.Steps(-1)
if err == migrate.ErrNoChange || err == nil {
return nil
}

return err
}

func getMigrationInstance(config Config, embeddedMigrations fs.FS, resourcePath string) (*migrate.Migrate, error) {
src, err := iofs.New(embeddedMigrations, resourcePath)
if err != nil {
return nil, fmt.Errorf("db migrator: %v", err)
}
return migrate.NewWithSourceInstance("iofs", src, config.URL)
}
60 changes: 60 additions & 0 deletions db/migrate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package db_test

import (
"embed"
"fmt"
"log"
"testing"

"github.com/odpf/salt/db"
"github.com/stretchr/testify/assert"
)

//go:embed migrations/*.sql
var migrationFs embed.FS

func TestRunMigrations(t *testing.T) {
if _, err := client.Exec(dropTableQuery); err != nil {
log.Fatalf("Could not cleanup: %s", err)
}

dsn := fmt.Sprintf(dsn, user, password, port, database)
var (
pgConfig = db.Config{
Driver: "postgres",
URL: dsn,
}
)

err := db.RunMigrations(pgConfig, migrationFs, "migrations")
assert.NoError(t, err)

// User table should exist
var tableExist bool
result := client.QueryRow(checkTableQuery)
result.Scan(&tableExist)
assert.Equal(t, true, tableExist)
}

func TestRunRollback(t *testing.T) {
if _, err := client.Exec(dropTableQuery); err != nil {
log.Fatalf("Could not cleanup: %s", err)
}

dsn := fmt.Sprintf(dsn, user, password, port, database)
var (
pgConfig = db.Config{
Driver: "postgres",
URL: dsn,
}
)

err := db.RunRollback(pgConfig, migrationFs, "migrations")
assert.NoError(t, err)

// User table should not exist
var tableExist bool
result := client.QueryRow(checkTableQuery)
result.Scan(&tableExist)
assert.Equal(t, false, tableExist)
}
1 change: 1 addition & 0 deletions db/migrations/1481574547_create_users_table.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE IF EXISTS users
1 change: 1 addition & 0 deletions db/migrations/1481574547_create_users_table.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE IF NOT EXISTS users (id VARCHAR(36) PRIMARY KEY, name VARCHAR(50))
Loading