Skip to content

Commit

Permalink
add support for reading config file using viper, fixes kubeflow#12
Browse files Browse the repository at this point in the history
  • Loading branch information
dhirajsb committed Oct 4, 2023
1 parent 4211667 commit cabad69
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 24 deletions.
12 changes: 12 additions & 0 deletions cmd/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package cmd

type Config struct {
DbFile string `mapstructure:"db-file" yaml:"db-file"`
Hostname string `mapstructure:"hostname" yaml:"hostname"`
Port int `mapstructure:"port" yaml:"port"`
LibraryDirs []string `mapstructure:"metadata-library-dir" yaml:"metadata-library-dir"`
}

var cfg = Config{DbFile: "metadata.sqlite.db", Hostname: "localhost", Port: 8080, LibraryDirs: []string(nil)}

const EnvPrefix = "MR"
12 changes: 5 additions & 7 deletions cmd/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ This command can create a new ml-metadata Sqlite DB, or migrate an existing DB
to the latest schema required by this server.`,
RunE: func(cmd *cobra.Command, args []string) error {
// connect to DB
dbConn, err := NewDatabaseConnection(dbFile)
dbConn, err := NewDatabaseConnection(cfg.DbFile)
defer func() {
// close DB connection on exit
db, err2 := dbConn.DB()
Expand Down Expand Up @@ -76,9 +76,9 @@ func migrateDatabase(dbConn *gorm.DB) error {
}

func loadLibraries(dbConn *gorm.DB) error {
libs, err := library.LoadLibraries(libraryDirs)
libs, err := library.LoadLibraries(cfg.LibraryDirs)
if err != nil {
return fmt.Errorf("failed to read library directories %s: %w", libraryDirs, err)
return fmt.Errorf("failed to read library directories %s: %w", cfg.LibraryDirs, err)
}
for path, lib := range libs {
grpcServer := grpc.NewGrpcServer(dbConn)
Expand Down Expand Up @@ -128,8 +128,6 @@ func ToProtoProperties(props map[string]library.PropertyType) map[string]proto.P
return result
}

var libraryDirs []string

func init() {
rootCmd.AddCommand(migrateCmd)

Expand All @@ -141,6 +139,6 @@ func init() {

// Cobra supports local flags which will only run when this command
// is called directly, e.g.:
migrateCmd.Flags().StringVarP(&dbFile, "db-file", "d", "metadata.sqlite.db", "Sqlite DB file")
migrateCmd.Flags().StringSliceVarP(&libraryDirs, "metadata-library-dir", "m", libraryDirs, "Built-in metadata types library directories containing yaml files")
migrateCmd.Flags().StringVarP(&cfg.DbFile, "db-file", "d", cfg.DbFile, "Sqlite DB file")
migrateCmd.Flags().StringSliceVarP(&cfg.LibraryDirs, "metadata-library-dir", "m", cfg.LibraryDirs, "Built-in metadata types library directories containing yaml files")
}
44 changes: 38 additions & 6 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package cmd

import (
"errors"
"flag"
"fmt"
"github.com/golang/glog"
"github.com/spf13/pflag"
"os"
"strings"

"github.com/spf13/cobra"
"github.com/spf13/viper"
Expand All @@ -28,6 +31,9 @@ custom metadata libraries, exposing a higher level GraphQL API, RBAC, etc.`,
// Uncomment the following line if your bare application
// has an action associated with it:
// Run: func(cmd *cobra.Command, args []string) { },
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
return initConfig(cmd)
},
}

// Execute adds all child commands to the root command and sets flags appropriately.
Expand All @@ -40,13 +46,13 @@ func Execute() {
}

func init() {
cobra.OnInitialize(initConfig)
//cobra.OnInitialize(initConfig)

// Here you will define your flags and configuration settings.
// Cobra supports persistent flags, which, if defined here,
// will be global for your application.

rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.model-registry.yaml)")
rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file (default is $HOME/.model-registry.yaml)")

// default to logging to stderr
_ = flag.Set("logtostderr", "true")
Expand All @@ -55,30 +61,56 @@ func init() {

// Cobra also supports local flags, which will only run
// when this action is called directly.
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
//rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")

}

// initConfig reads in config file and ENV variables if set.
func initConfig() {
func initConfig(cmd *cobra.Command) error {
if cfgFile != "" {
// Use config file from the flag.
viper.SetConfigFile(cfgFile)
} else {
// Find home directory.
home, err := os.UserHomeDir()
cobra.CheckErr(err)
if err != nil {
return err
}

// Search config in home directory with name ".model-registry" (without extension).
viper.AddConfigPath(home)
viper.SetConfigType("yaml")
viper.SetConfigName(".model-registry")
}

viper.SetEnvPrefix(EnvPrefix)
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
viper.AutomaticEnv() // read in environment variables that match

// If a config file is found, read it in.
if err := viper.ReadInConfig(); err == nil {
fmt.Fprintln(os.Stderr, "Using config file:", viper.ConfigFileUsed())
glog.Info("using config file: ", viper.ConfigFileUsed())
} else {
var configFileNotFoundError viper.ConfigFileNotFoundError
ok := errors.As(err, &configFileNotFoundError)
// ignore if it's a file not found error for default config file
if !(cfgFile == "" && ok) {
return fmt.Errorf("reading config %s: %v", viper.ConfigFileUsed(), err)
}
}

// bind flags to config
if err := viper.BindPFlags(cmd.Flags()); err != nil {
return err
}
var err error
cmd.Flags().VisitAll(func(f *pflag.Flag) {
name := f.Name
if err == nil && !f.Changed && viper.IsSet(name) {
value := viper.Get(name)
err = cmd.Flags().Set(name, fmt.Sprintf("%v", value))
}
})

return err
}
16 changes: 6 additions & 10 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ func InterceptorLogger(l *log.Logger) logging.Logger {
}

var (
dbFile string
host = "localhost"
port int = 8080

// serveCmd represents the serve command
serveCmd = &cobra.Command{
Use: "serve",
Expand All @@ -64,7 +60,7 @@ location of the database file and the hostname and port where it listens.'`,
)

func runServer(cmd *cobra.Command, args []string) error {
glog.Info("server started...")
glog.Infof("server started at %s:%v", cfg.Hostname, cfg.Port)

// Create a channel to receive signals
signalChannel := make(chan os.Signal, 1)
Expand All @@ -73,13 +69,13 @@ func runServer(cmd *cobra.Command, args []string) error {
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM)

// connect to the DB using Gorm
db, err := NewDatabaseConnection(dbFile)
db, err := NewDatabaseConnection(cfg.DbFile)
if err != nil {
log.Fatalf("db connection failed: %v", err)
}

// listen on host:port
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", host, port))
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port))
if err != nil {
log.Fatalf("server listen failed: %v", err)
}
Expand Down Expand Up @@ -189,7 +185,7 @@ func init() {

// Cobra supports local flags which will only run when this command
// is called directly, e.g.:
serveCmd.Flags().StringVarP(&dbFile, "db-file", "d", "metadata.sqlite.db", "Sqlite DB file")
serveCmd.Flags().StringVarP(&host, "hostname", "n", host, "Server listen hostname")
serveCmd.Flags().IntVarP(&port, "port", "p", port, "Server listen port")
serveCmd.Flags().StringVarP(&cfg.DbFile, "db-file", "d", cfg.DbFile, "Sqlite DB file")
serveCmd.Flags().StringVarP(&cfg.Hostname, "hostname", "n", cfg.Hostname, "Server listen hostname")
serveCmd.Flags().IntVarP(&cfg.Port, "port", "p", cfg.Port, "Server listen port")
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/searKing/golang/tools/go-enum v1.2.97
github.com/soheilhy/cmux v0.1.5
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.16.0
github.com/vektah/gqlparser/v2 v2.5.8
golang.org/x/sync v0.2.0
Expand Down Expand Up @@ -41,7 +42,6 @@ require (
github.com/spf13/afero v1.9.5 // indirect
github.com/spf13/cast v1.5.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.4.2 // indirect
github.com/urfave/cli/v2 v2.25.5 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
Expand Down

0 comments on commit cabad69

Please sign in to comment.