diff --git a/README.md b/README.md index b790a76..57fd0dc 100644 --- a/README.md +++ b/README.md @@ -39,3 +39,10 @@ Using kv, errs := struct2env.StructToEnvVars(foo) txt := struct2env.ToShellWithPrefix("TST_", kv) ``` + +Type conversions: + +- Most primitive type to their string representation, single quote (') escaped. +- []byte are encoded as base64 +- time.Time are formatted as RFC3339 +- time.Duration are in (floating point) seconds. diff --git a/env.go b/env.go index 60aab0a..3975e40 100644 --- a/env.go +++ b/env.go @@ -1,30 +1,30 @@ // Package env provides conversion from structure to and from environment variables. // -// It supports converting struct fields to environment variables using field tags, -// handling different data types, and transforming strings between different case -// conventions, which is useful for generating or parsing environment variables, -// JSON tags, or command line flags. +// Supports converting struct fields to environment variables using field tags, +// handling most data types. Provides functions to serialize structs into slices +// of key-value pairs where the keys are derived from struct field names transformed +// to upper snake case by default, or specified explicitly via struct field tags. // -// The package also defines several case conversion functions that aid in manipulating -// strings to fit conventional casing for various programming and configuration contexts. -// Additionally, it provides functions to serialize structs into slices of key-value pairs -// where the keys are derived from struct field names transformed to upper snake case by default, -// or specified explicitly via struct field tags. -// -// It also includes functionality to deserialize environment variables back into +// Includes functionality to deserialize environment variables back into // struct fields, handling pointers and nested structs appropriately, as well as providing // shell-compatible output for environment variable definitions. // +// Incidentally the package also defines several case conversion functions that aid in manipulating +// which is useful for generating or parsing environment variables, +// JSON tags, or command line flags style of naming (camelCase, UPPER_SNAKE_CASE, lower-kebab-case ...) +// // The package leverages reflection to dynamically handle arbitrary struct types, // and has 0 dependencies. package struct2env import ( + "encoding/base64" "fmt" "os" "reflect" "strconv" "strings" + "time" "unicode" ) @@ -97,12 +97,16 @@ type KeyValue struct { } // Escape characters such as the result string can be embedded as a single argument in a shell fragment -// e.g for ENV_VAR= such as is safe (no $(cmd...) no ` etc`). -func ShellQuote(input string) string { +// e.g for ENV_VAR= such as is safe (no $(cmd...) no ` etc`). Will error out if NUL is found +// in the input (use []byte for that and it'll get base64 encoded/decoded). +func ShellQuote(input string) (string, error) { + if strings.ContainsRune(input, 0) { + return "", fmt.Errorf("String value %q should not contain NUL", input) + } // To emit a single quote in a single quote enclosed string you have to close the current ' then emit a quote (\'), // then reopen the single quote sequence to finish. Note that when the string ends with a quote there is an unnecessary // trailing ''. - return "'" + strings.ReplaceAll(input, "'", `'\''`) + "'" + return "'" + strings.ReplaceAll(input, "'", `'\''`) + "'", nil } func (kv KeyValue) String() string { @@ -129,16 +133,20 @@ func ToShellWithPrefix(prefix string, kvl []KeyValue) string { return sb.String() } -func SerializeValue(value interface{}) string { +func SerializeValue(value interface{}) (string, error) { switch v := value.(type) { case bool: res := "false" if v { res = "true" } - return res + return res, nil + case []byte: + return ShellQuote(base64.StdEncoding.EncodeToString(v)) case string: return ShellQuote(v) + case time.Duration: + return fmt.Sprintf("%g", v.Seconds()), nil default: return ShellQuote(fmt.Sprint(value)) } @@ -151,6 +159,7 @@ func SerializeValue(value interface{}) string { // If the field is exportable and the tag is missing we'll use the field name // converted to UPPER_SNAKE_CASE (using CamelCaseToUpperSnakeCase()) as the // environment variable name. +// []byte are encoded as base64, time.Time are formatted as RFC3339, time.Duration are in (floating point) seconds. func StructToEnvVars(s interface{}) ([]KeyValue, []error) { var allErrors []error var allKeyValVals []KeyValue @@ -186,24 +195,49 @@ func structToEnvVars(envVars []KeyValue, allErrors []error, prefix string, s int } fieldValue := v.Field(i) stringValue := "" + var err error + + if fieldValue.Type() == reflect.TypeOf(time.Time{}) { // other wise we hit the "struct" case below + timeField := fieldValue.Interface().(time.Time) + stringValue, err = SerializeValue(timeField.Format(time.RFC3339)) + if err != nil { + allErrors = append(allErrors, err) + } else { + envVars = append(envVars, KeyValue{Key: prefix + tag, QuotedValue: stringValue}) + } + continue // Continue to the next field + } + switch fieldValue.Kind() { //nolint: exhaustive // we have default: for the other cases case reflect.Ptr: if !fieldValue.IsNil() { fieldValue = fieldValue.Elem() - stringValue = SerializeValue(fieldValue.Interface()) + stringValue, err = SerializeValue(fieldValue.Interface()) } case reflect.Map, reflect.Array, reflect.Chan, reflect.Slice: - // log.LogVf("Skipping field %s of type %v, not supported", fieldType.Name, fieldType.Type) - continue + // From that list of other types, only support []byte + if fieldValue.Type().Elem().Kind() == reflect.Uint8 { + stringValue, err = SerializeValue(fieldValue.Interface()) + } else { + // log.LogVf("Skipping field %s of type %v, not supported", fieldType.Name, fieldType.Type) + continue + } case reflect.Struct: // Recurse with prefix envVars, allErrors = structToEnvVars(envVars, allErrors, tag+"_", fieldValue.Interface()) continue default: - value := fieldValue.Interface() - stringValue = SerializeValue(value) + if !fieldValue.CanInterface() { + err = fmt.Errorf("can't interface %s", fieldType.Name) + } else { + value := fieldValue.Interface() + stringValue, err = SerializeValue(value) + } } envVars = append(envVars, KeyValue{Key: prefix + tag, QuotedValue: stringValue}) + if err != nil { + allErrors = append(allErrors, err) + } } return envVars, allErrors } @@ -217,8 +251,8 @@ func setPointer(fieldValue reflect.Value) reflect.Value { return fieldValue.Elem() } -func checkEnv(envName, fieldName string, fieldValue reflect.Value) (*string, error) { - val, found := os.LookupEnv(envName) +func checkEnv(envLookup EnvLookup, envName, fieldName string, fieldValue reflect.Value) (*string, error) { + val, found := envLookup(envName) if !found { // log.LogVf("%q not set for %s", envName, fieldName) return nil, nil //nolint:nilnil @@ -231,11 +265,19 @@ func checkEnv(envName, fieldName string, fieldValue reflect.Value) (*string, err return &val, nil } +type EnvLookup func(key string) (string, bool) + +// Reverse of StructToEnvVars, assumes the same encoding. Using the current os environment variables as source. func SetFromEnv(prefix string, s interface{}) []error { - return setFromEnv(nil, prefix, s) + return SetFrom(os.LookupEnv, prefix, s) } -func setFromEnv(allErrors []error, prefix string, s interface{}) []error { +// Reverse of StructToEnvVars, assumes the same encoding. Using passed it lookup object that can lookup values by keys. +func SetFrom(envLookup EnvLookup, prefix string, s interface{}) []error { + return setFromEnv(nil, envLookup, prefix, s) +} + +func setFromEnv(allErrors []error, envLookup EnvLookup, prefix string, s interface{}) []error { // TODO: this is quite similar in structure to structToEnvVars() - can it be refactored with // passing setter vs getter function and share the same iteration (yet a little bit of copy is the go way too) v := reflect.ValueOf(s) @@ -263,17 +305,18 @@ func setFromEnv(allErrors []error, prefix string, s interface{}) []error { kind := fieldValue.Kind() - if kind == reflect.Struct { + // Handle time.Time separately a bit below after we get the value + if kind == reflect.Struct && fieldType.Type != reflect.TypeOf(time.Time{}) { // Recurse with prefix if fieldValue.CanAddr() { // Check if we can get the address - SetFromEnv(envName+"_", fieldValue.Addr().Interface()) + allErrors = setFromEnv(allErrors, envLookup, envName+"_", fieldValue.Addr().Interface()) } else { err := fmt.Errorf("cannot take the address of %s to recurse", fieldType.Name) allErrors = append(allErrors, err) } continue } - val, err := checkEnv(envName, fieldType.Name, fieldValue) + val, err := checkEnv(envLookup, envName, fieldType.Name, fieldValue) if err != nil { allErrors = append(allErrors, err) continue @@ -288,33 +331,72 @@ func setFromEnv(allErrors []error, prefix string, s interface{}) []error { kind = fieldValue.Type().Elem().Kind() fieldValue = setPointer(fieldValue) } - switch kind { //nolint: exhaustive // we have default: for the other cases - case reflect.String: - fieldValue.SetString(envVal) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - var ev int64 - ev, err = strconv.ParseInt(envVal, 10, fieldValue.Type().Bits()) + if fieldType.Type == reflect.TypeOf(time.Time{}) { + var timeField time.Time + timeField, err = time.Parse(time.RFC3339, envVal) if err == nil { - fieldValue.SetInt(ev) + fieldValue.Set(reflect.ValueOf(timeField)) + } else { + allErrors = append(allErrors, err) } - case reflect.Float32, reflect.Float64: + continue + } + allErrors = setValue(allErrors, fieldType, fieldValue, kind, envName, envVal) + } + return allErrors +} + +func setValue( + allErrors []error, + fieldType reflect.StructField, + fieldValue reflect.Value, + kind reflect.Kind, + envName, envVal string, +) []error { + var err error + switch kind { //nolint: exhaustive // we have default: for the other cases + case reflect.String: + fieldValue.SetString(envVal) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + // if it's a duration, parse it as a float seconds + if fieldType.Type == reflect.TypeOf(time.Duration(0)) { var ev float64 - ev, err = strconv.ParseFloat(envVal, fieldValue.Type().Bits()) + ev, err = strconv.ParseFloat(envVal, 64) if err == nil { - fieldValue.SetFloat(ev) + fieldValue.SetInt(int64(ev * float64(1*time.Second))) } - case reflect.Bool: - var ev bool - ev, err = strconv.ParseBool(envVal) + } else { + var ev int64 + ev, err = strconv.ParseInt(envVal, 10, fieldValue.Type().Bits()) if err == nil { - fieldValue.SetBool(ev) + fieldValue.SetInt(ev) } - default: - err = fmt.Errorf("unsupported type %v to set from %s=%q", kind, envName, envVal) } - if err != nil { - allErrors = append(allErrors, err) + case reflect.Float32, reflect.Float64: + var ev float64 + ev, err = strconv.ParseFloat(envVal, fieldValue.Type().Bits()) + if err == nil { + fieldValue.SetFloat(ev) } + case reflect.Bool: + var ev bool + ev, err = strconv.ParseBool(envVal) + if err == nil { + fieldValue.SetBool(ev) + } + case reflect.Slice: + if fieldValue.Type().Elem().Kind() != reflect.Uint8 { + err = fmt.Errorf("unsupported slice of %v to set from %s=%q", fieldValue.Type().Elem().Kind(), envName, envVal) + } else { + var data []byte + data, err = base64.StdEncoding.DecodeString(envVal) + fieldValue.SetBytes(data) + } + default: + err = fmt.Errorf("unsupported type %v to set from %s=%q", kind, envName, envVal) + } + if err != nil { + allErrors = append(allErrors, err) } return allErrors } diff --git a/env_test.go b/env_test.go index 4e53418..cc2402f 100644 --- a/env_test.go +++ b/env_test.go @@ -1,10 +1,10 @@ package struct2env import ( - "os" "reflect" "strings" "testing" + "time" ) func TestSplitByCase(t *testing.T) { @@ -118,12 +118,15 @@ type FooConfig struct { Embedded HiddenEmbedded `env:"-"` RecurseHere Embedded + SomeBinary []byte + Dur time.Duration + TS time.Time } func TestStructToEnvVars(t *testing.T) { intV := 199 foo := FooConfig{ - Foo: "a\nfoo with $X, `backticks`, \" quotes and \\ and ' in middle and end '", + Foo: "a newline:\nfoo with $X, `backticks`, \" quotes and \\ and ' in middle and end '", Bar: "42str", Blah: 42, ABool: true, @@ -135,6 +138,9 @@ func TestStructToEnvVars(t *testing.T) { InnerA: "rec a", InnerB: "rec b", }, + SomeBinary: []byte{0, 1, 2}, + Dur: 1*time.Hour + 100*time.Millisecond, + TS: time.Date(1998, time.November, 5, 14, 30, 0, 0, time.UTC), } foo.InnerA = "inner a" foo.InnerB = "inner b" @@ -149,12 +155,12 @@ func TestStructToEnvVars(t *testing.T) { if len(errors) != 0 { t.Errorf("expected no error, got %v", errors) } - if len(envVars) != 11 { - t.Errorf("expected 11 env vars, got %d: %+v", len(envVars), envVars) + if len(envVars) != 14 { + t.Errorf("expected 14 env vars, got %d: %+v", len(envVars), envVars) } str := ToShellWithPrefix("TST_", envVars) //nolint:lll - expected := `TST_FOO='a + expected := `TST_FOO='a newline: foo with $X, ` + "`backticks`" + `, " quotes and \ and '\'' in middle and end '\''' TST_BAR='42str' TST_A_SPECIAL_BLAH='42' @@ -166,37 +172,74 @@ TST_INNER_A='inner a' TST_INNER_B='inner b' TST_RECURSE_HERE_INNER_A='rec a' TST_RECURSE_HERE_INNER_B='rec b' -export TST_FOO TST_BAR TST_A_SPECIAL_BLAH TST_A_BOOL TST_HTTP_SERVER TST_INT_POINTER TST_FLOAT_POINTER TST_INNER_A TST_INNER_B TST_RECURSE_HERE_INNER_A TST_RECURSE_HERE_INNER_B +TST_SOME_BINARY='AAEC' +TST_DUR=3600.1 +TST_TS='1998-11-05T14:30:00Z' +export TST_FOO TST_BAR TST_A_SPECIAL_BLAH TST_A_BOOL TST_HTTP_SERVER TST_INT_POINTER TST_FLOAT_POINTER TST_INNER_A TST_INNER_B TST_RECURSE_HERE_INNER_A TST_RECURSE_HERE_INNER_B TST_SOME_BINARY TST_DUR TST_TS ` if str != expected { t.Errorf("\n---expected:---\n%s\n---got:---\n%s", expected, str) } + // NUL in string + type Cfg struct { + Foo string + } + cfg := Cfg{Foo: "ABC\x00DEF"} + envVars, errors = StructToEnvVars(&cfg) + if len(errors) != 1 { + t.Errorf("Should have had error with embedded NUL") + } + if envVars[0].Key != "FOO" { + t.Errorf("Expecting key to be present %v", envVars) + } + if envVars[0].QuotedValue != "" { + t.Errorf("Expecting value to be empty %v", envVars) + } } func TestSetFromEnv(t *testing.T) { foo := FooConfig{} - envs := []struct { - k string - v string - }{ - {"TST2_FOO", "another\nfoo"}, - {"TST2_BAR", "bar"}, - {"TST2_RECURSE_HERE_INNER_B", "in1"}, - {"TST2_A_SPECIAL_BLAH", "31"}, - {"TST2_A_BOOL", "1"}, - {"TST2_FLOAT_POINTER", "5.75"}, - {"TST2_INT_POINTER", "73"}, - } - for _, e := range envs { - os.Setenv(e.k, e.v) - } - errors := SetFromEnv("TST2_", &foo) + envs := map[string]string{ + "TST2_FOO": "another\nfoo", + "TST2_BAR": "bar", + "TST2_RECURSE_HERE_INNER_B": "in1", + "TST2_A_SPECIAL_BLAH": "31", + "TST2_A_BOOL": "1", + "TST2_FLOAT_POINTER": "5.75", + "TST2_INT_POINTER": "73", + "TST2_SOME_BINARY": "QUJDAERFRg==", + "TST2_DUR": "123.456789", + "TST2_TS": "1998-11-05T14:30:00Z", + } + lookup := func(key string) (string, bool) { + value, found := envs[key] + return value, found + } + errors := SetFrom(lookup, "TST2_", &foo) if len(errors) != 0 { t.Errorf("Unexpectedly got errors :%v", errors) } - if foo.Foo != "another\nfoo" || foo.Bar != "bar" || foo.RecurseHere.InnerB != "in1" || foo.Blah != 31 || foo.ABool != true || - foo.FloatPointer == nil || *foo.FloatPointer != 5.75 || - foo.IntPointer == nil || *foo.IntPointer != 73 { + if foo.Foo != "another\nfoo" || foo.Bar != "bar" || foo.RecurseHere.InnerB != "in1" || foo.Blah != 31 || foo.ABool != true { t.Errorf("Mismatch in object values, got: %+v", foo) } + if foo.IntPointer == nil || *foo.IntPointer != 73 { + t.Errorf("IntPointer not set correctly: %v %v", foo.IntPointer, *foo.IntPointer) + } + if foo.FloatPointer == nil || *foo.FloatPointer != 5.75 { + t.Errorf("FloatPointer not set correctly: %v %v", foo.FloatPointer, *foo.FloatPointer) + } + if string(foo.SomeBinary) != "ABC\x00DEF" { + t.Errorf("Base64 decoding not working for []byte field: %q", string(foo.SomeBinary)) + } + if foo.Dur != 123456789*time.Microsecond { + t.Errorf("Duration not set correctly: %v", foo.Dur) + } + if foo.TS != time.Date(1998, time.November, 5, 14, 30, 0, 0, time.UTC) { + t.Errorf("Time not set correctly: %v", foo.TS) + } + envs["TST2_TS"] = "not a rfc3339 time" + errors = SetFrom(lookup, "TST2_", &foo) + if len(errors) != 1 { + t.Errorf("Expected 1 error, got %v", errors) + } }