Skip to content

Commit

Permalink
Support Cloud SQL connections (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhutchinson authored Sep 21, 2023
1 parent 48dc351 commit dd6e1bb
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 13 deletions.
67 changes: 56 additions & 11 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import (
"context"
"database/sql"
"flag"
"fmt"
"net"
"net/http"
"os"

"cloud.google.com/go/cloudsqlconn"
"github.com/golang/glog"
"github.com/gorilla/mux"
"github.com/transparency-dev/distributor/cmd/internal/distributor"
Expand All @@ -36,12 +38,14 @@ import (

_ "embed"

"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql"
)

var (
addr = flag.String("listen", ":8080", "Address to listen on")
mysqlURI = flag.String("mysql_uri", "", "URI for MySQL DB")
addr = flag.String("listen", ":8080", "Address to listen on")
useCloudSql = flag.Bool("use_cloud_sql", false, "Set to true to set up the DB connection using cloudsql connection. This will ignore mysql_uri and generate it from env variables.")
mysqlURI = flag.String("mysql_uri", "", "URI for MySQL DB")

witnessConfigFile = flag.String("witness_config_file", "", "Path to a file containing the public keys of allowed witnesses")

Expand All @@ -65,15 +69,7 @@ func main() {

ws := getWitnessesOrDie()
ls := getLogsOrDie()

if len(*mysqlURI) == 0 {
glog.Exitf("mysql_uri is required")
}
glog.Infof("Connecting to DB at %q", *mysqlURI)
db, err := sql.Open("mysql", *mysqlURI)
if err != nil {
glog.Exitf("Failed to connect to DB: %v", err)
}
db := getDatabaseOrDie()

d, err := distributor.NewDistributor(ws, ls, db)
if err != nil {
Expand Down Expand Up @@ -106,6 +102,55 @@ func main() {
}
}

func getDatabaseOrDie() *sql.DB {
if *useCloudSql {
return getCloudSqlOrDie()
}
if len(*mysqlURI) == 0 {
glog.Exitf("mysql_uri is required")
}
glog.Infof("Connecting to DB at %q", *mysqlURI)
db, err := sql.Open("mysql", *mysqlURI)
if err != nil {
glog.Exitf("Failed to connect to DB: %v", err)
}
return db
}

func getCloudSqlOrDie() *sql.DB {
mustGetenv := func(k string) string {
v := os.Getenv(k)
if v == "" {
glog.Exitf("Failed precondition: %s environment variable not set.", k)
}
return v
}
var (
dbUser = mustGetenv("DB_USER") // e.g. 'my-db-user'
dbPwd = mustGetenv("DB_PASS") // e.g. 'my-db-password'
dbName = mustGetenv("DB_NAME") // e.g. 'my-database'
instanceConnectionName = mustGetenv("INSTANCE_CONNECTION_NAME") // e.g. 'project:region:instance'
)

d, err := cloudsqlconn.NewDialer(context.Background())
if err != nil {
glog.Exitf("cloudsqlconn.NewDialer: %w", err)
}
var opts []cloudsqlconn.DialOption
mysql.RegisterDialContext("cloudsqlconn",
func(ctx context.Context, addr string) (net.Conn, error) {
return d.Dial(ctx, instanceConnectionName, opts...)
})

dbURI := fmt.Sprintf("%s:%s@cloudsqlconn(localhost:3306)/%s", dbUser, dbPwd, dbName)

dbPool, err := sql.Open("mysql", dbURI)
if err != nil {
glog.Exitf("sql.Open: %w", err)
}
return dbPool
}

func getLogsOrDie() map[string]distributor.LogInfo {
logsCfg := logsConfig{}
if err := yaml.Unmarshal(configLogs, &logsCfg); err != nil {
Expand Down
20 changes: 18 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ require (
)

require (
cloud.google.com/go/cloudsqlconn v1.4.4 // indirect
cloud.google.com/go/compute v1.23.0 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/Microsoft/go-winio v0.6.0 // indirect
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
Expand All @@ -29,8 +32,13 @@ require (
github.com/docker/go-connections v0.4.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/google/uuid v1.3.1 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.5 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
github.com/imdario/mergo v0.3.15 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/moby/term v0.0.0-20221205130635-1aeaba878587 // indirect
Expand All @@ -42,9 +50,17 @@ require (
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/xeipuuv/gojsonschema v1.2.0 // indirect
golang.org/x/sys v0.10.0 // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect
golang.org/x/oauth2 v0.12.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.8.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect
google.golang.org/api v0.140.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230911183012-2d3300fd4832 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
Loading

0 comments on commit dd6e1bb

Please sign in to comment.