Skip to content

Commit

Permalink
Recursively construct toml key (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
sonmezonur authored Dec 6, 2019
1 parent 43633a3 commit ce10418
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 7 deletions.
33 changes: 26 additions & 7 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func Load(filepath string, dst interface{}) error {
return err
}

return bindFlags(dst, metadata)
return bindFlags(dst, metadata, "")
}

// bindEnvVariables will bind CLI flags to their respective elements in dst, defined by the struct-tag "env".
Expand Down Expand Up @@ -65,7 +65,7 @@ func bindEnvVariables(dst interface{}) error {
}

// bindFlags will bind CLI flags to their respective elements in dst, defined by the struct-tag "flag".
func bindFlags(dst interface{}, metadata toml.MetaData) error {
func bindFlags(dst interface{}, metadata toml.MetaData, fieldPath string) error {
fields := structs.Fields(dst)
for _, field := range fields {
tag := field.Tag(flagTag)
Expand All @@ -75,7 +75,18 @@ func bindFlags(dst interface{}, metadata toml.MetaData) error {
continue
}

if err := bindFlags(dstElem.Addr().Interface(), metadata); err != nil {
var path string
if fieldPath != "" {
path = fmt.Sprintf("%s.", fieldPath)
}

if field.Tag(tomlTag) != "" {
path += field.Tag(tomlTag)
} else {
path += field.Name()
}

if err := bindFlags(dstElem.Addr().Interface(), metadata, path); err != nil {
return err
}

Expand All @@ -91,7 +102,15 @@ func bindFlags(dst interface{}, metadata toml.MetaData) error {
useFlagDefaultValue := false
if !isFlagSet(tag) {
_, envHasKey := os.LookupEnv(field.Tag(envTag))
if envHasKey || tomlHasKey(metadata, field.Tag(tomlTag)) {

var tomlKey string
if fieldPath == "" {
tomlKey = field.Tag(tomlTag)
} else {
tomlKey = fmt.Sprintf("%s.%s", fieldPath, field.Tag(tomlTag))
}

if envHasKey || tomlHasKey(metadata, tomlKey) {
continue
} else {
useFlagDefaultValue = true
Expand Down Expand Up @@ -190,10 +209,10 @@ func isFlagSet(tag string) bool {
return flagSet
}

// tomlHasKey will check if the tag presents in toml metadata
func tomlHasKey(metadata toml.MetaData, tag string) bool {
// tomlHasKey will check if the toml key presents in toml metadata
func tomlHasKey(metadata toml.MetaData, tomlKey string) bool {
for _, key := range metadata.Keys() {
if strings.ToLower(key.String()) == strings.ToLower(tag) {
if strings.EqualFold(key.String(), tomlKey) {
return true
}
}
Expand Down
64 changes: 64 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,70 @@ LogLevel = "debug"
}
}

func TestLoad_TomlNested_FlagSetAndNotGiven(t *testing.T) {
var cfg struct {
DB struct {
Account string `toml:"account" flag:"db-account"`
Username string `toml:"username" flag:"db-user"`
Credentials struct {
Secret string `toml:"secret" flag:"db-secret"`
Password string `toml:"password" flag:"db-password"`
} `toml:"credentials"`
Options *struct {
Port int `toml:"port" flag:"db-port"`
}
} `toml:"database"`
}
tmp, _ := ioutil.TempFile("", "")
defer os.Remove(tmp.Name())

_, err := tmp.WriteString(`
[database]
account = "test_account"
username = "test_user"
[database.credentials]
secret = "wowowow"
password = "12345"
[database.options]
port = 3306
`)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

fs := flag.NewFlagSet("tmp", flag.ExitOnError)
_ = fs.String("db-account", "default", "")
_ = fs.String("db-user", "default", "")
_ = fs.String("db-secret", "default", "")
_ = fs.String("db-password", "default", "")
_ = fs.Int("db-port", 0, "")
flag.CommandLine = fs

if err := Load(tmp.Name(), &cfg); err != nil {
t.Fatalf("unexpected error %v", err)
}

if cfg.DB.Account != "test_account" {
t.Errorf("got: %v, expected: %v", cfg.DB.Account, "test_account")
}

if cfg.DB.Username != "test_user" {
t.Errorf("got: %v, expected: %v", cfg.DB.Username, "test_user")
}

if cfg.DB.Credentials.Secret != "wowowow" {
t.Errorf("got: %v, expected: %v", cfg.DB.Credentials.Secret, "wowowow")
}

if cfg.DB.Credentials.Password != "12345" {
t.Errorf("got: %v, expected: %v", cfg.DB.Credentials.Password, "12345")
}

if cfg.DB.Options.Port != 3306 {
t.Errorf("got: %v, expected: %v", cfg.DB.Options.Port, 3306)
}
}

func TestLoad_EnvGivenWithNested(t *testing.T) {
os.Clearenv()
var cfg struct {
Expand Down

0 comments on commit ce10418

Please sign in to comment.