From 257d725722662225533c4edbd9b427207038b12e Mon Sep 17 00:00:00 2001 From: michal-laskowski <1753681+michal-laskowski@users.noreply.github.com> Date: Wed, 13 Nov 2024 00:42:01 +0100 Subject: [PATCH 1/4] fix: Embedded struct binding --- bind_test.go | 33 +++++++++++++++++++++++++++++++++ binder/mapping.go | 8 +++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/bind_test.go b/bind_test.go index aa00e191ca..72ce66571b 100644 --- a/bind_test.go +++ b/bind_test.go @@ -1091,6 +1091,39 @@ func Benchmark_Bind_Body_XML(b *testing.B) { require.Equal(b, "john", d.Name) } +// go test -run Test_Bind_Body_Form_Embedded +func Test_Bind_Body_Form_Embedded(b *testing.T) { + var err error + + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + type EmbededDemo struct { + EmbededStrings []string + EmbededString string + } + + type Demo struct { + SomeString string + SomeOtherString string + Strings []string + EmbededDemo + } + body := []byte("SomeString=john%2Clong&SomeOtherString=long%2Cjohn&Strings=long%2Cjohn&EmbededStrings=john%2Clong&EmbededString=johny%2Cwalker") + c.Request().SetBody(body) + c.Request().Header.SetContentType(MIMEApplicationForm) + c.Request().Header.SetContentLength(len(body)) + d := new(Demo) + + err = c.Bind().Body(d) + + require.NoError(b, err) + require.Equal(b, []string{"long", "john"}, d.Strings) + require.Equal(b, []string{"john", "long"}, d.EmbededStrings) + require.Equal(b, "johny,walker", d.EmbededString) + require.Equal(b, "john,long", d.SomeString) + require.Equal(b, "long,john", d.SomeOtherString) +} + // go test -v -run=^$ -bench=Benchmark_Bind_Body_Form -benchmem -count=4 func Benchmark_Bind_Body_Form(b *testing.B) { var err error diff --git a/binder/mapping.go b/binder/mapping.go index ea67ace200..b00feabf13 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -203,9 +203,15 @@ func equalFieldType(out any, kind reflect.Kind, key string) bool { // Does the field type equals input? if structFieldKind != kind { // Is the field an embedded struct? - if structFieldKind == reflect.Struct { + if structFieldKind == reflect.Struct && typeField.Anonymous { // Loop over embedded struct fields for j := 0; j < structField.NumField(); j++ { + fNm := utils.ToLower(structField.Type().Field(j).Name) + if fNm != key { + //this is not the field that we are looking for + continue + } + structFieldField := structField.Field(j) // Can this embedded field be changed? From 35d7ad11efdc77c50d3a269d077ca5bcb71c1e4d Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Tue, 12 Nov 2024 22:17:42 -0500 Subject: [PATCH 2/4] Update bind_test.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- bind_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/bind_test.go b/bind_test.go index 72ce66571b..94b2fd70cc 100644 --- a/bind_test.go +++ b/bind_test.go @@ -1097,16 +1097,16 @@ func Test_Bind_Body_Form_Embedded(b *testing.T) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) - type EmbededDemo struct { - EmbededStrings []string - EmbededString string + type EmbeddedDemo struct { + EmbeddedStrings []string `form:"embedded_strings"` + EmbeddedString string `form:"embedded_string"` } type Demo struct { - SomeString string - SomeOtherString string - Strings []string - EmbededDemo + SomeString string `form:"some_string"` + SomeOtherString string `form:"some_other_string"` + Strings []string `form:"strings"` + EmbeddedDemo } body := []byte("SomeString=john%2Clong&SomeOtherString=long%2Cjohn&Strings=long%2Cjohn&EmbededStrings=john%2Clong&EmbededString=johny%2Cwalker") c.Request().SetBody(body) From ee247ed4aec7c1eabef1cae46392e0eee558b098 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Tue, 12 Nov 2024 22:25:07 -0500 Subject: [PATCH 3/4] Fix typo --- bind_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bind_test.go b/bind_test.go index 94b2fd70cc..c6d863cb19 100644 --- a/bind_test.go +++ b/bind_test.go @@ -1118,8 +1118,8 @@ func Test_Bind_Body_Form_Embedded(b *testing.T) { require.NoError(b, err) require.Equal(b, []string{"long", "john"}, d.Strings) - require.Equal(b, []string{"john", "long"}, d.EmbededStrings) - require.Equal(b, "johny,walker", d.EmbededString) + require.Equal(b, []string{"john", "long"}, d.EmbeddedStrings) + require.Equal(b, "johny,walker", d.EmbeddedString) require.Equal(b, "john,long", d.SomeString) require.Equal(b, "long,john", d.SomeOtherString) } From cc0971cd060c829660e08ef899d6f4f42e24c029 Mon Sep 17 00:00:00 2001 From: michal-laskowski <1753681+michal-laskowski@users.noreply.github.com> Date: Thu, 14 Nov 2024 00:07:32 +0100 Subject: [PATCH 4/4] add: Support for other bindings fo embedded structs --- bind_test.go | 19 ++++++---- binder/cookie.go | 12 +----- binder/form.go | 10 +---- binder/header.go | 12 +----- binder/mapping.go | 86 ++++++++++++++++++++++++------------------ binder/mapping_test.go | 18 ++++----- binder/query.go | 10 +---- binder/resp_header.go | 12 +----- 8 files changed, 74 insertions(+), 105 deletions(-) diff --git a/bind_test.go b/bind_test.go index c6d863cb19..5cc59f6280 100644 --- a/bind_test.go +++ b/bind_test.go @@ -1098,17 +1098,18 @@ func Test_Bind_Body_Form_Embedded(b *testing.T) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) type EmbeddedDemo struct { - EmbeddedStrings []string `form:"embedded_strings"` EmbeddedString string `form:"embedded_string"` + EmbeddedStrings []string `form:"embedded_strings"` } type Demo struct { - SomeString string `form:"some_string"` - SomeOtherString string `form:"some_other_string"` - Strings []string `form:"strings"` + String string `form:"some_string"` + OtherString string `form:"some_other_string"` + Strings []string `form:"strings"` + OtherStrings []string `form:"other_strings"` EmbeddedDemo } - body := []byte("SomeString=john%2Clong&SomeOtherString=long%2Cjohn&Strings=long%2Cjohn&EmbededStrings=john%2Clong&EmbededString=johny%2Cwalker") + body := []byte("some_string=john%2Clong&some_other_string=long&some_other_string=long&strings=long%2Cjohn&embedded_strings=john%2Clongest&embedded_string=johny%2Cwalker&other_strings=long&other_strings=johny") c.Request().SetBody(body) c.Request().Header.SetContentType(MIMEApplicationForm) c.Request().Header.SetContentLength(len(body)) @@ -1117,11 +1118,13 @@ func Test_Bind_Body_Form_Embedded(b *testing.T) { err = c.Bind().Body(d) require.NoError(b, err) + require.Equal(b, "john,long", d.String) require.Equal(b, []string{"long", "john"}, d.Strings) - require.Equal(b, []string{"john", "long"}, d.EmbeddedStrings) + //! only one value is taken + require.Equal(b, "long", d.OtherString) + require.Equal(b, []string{"long", "johny"}, d.OtherStrings) require.Equal(b, "johny,walker", d.EmbeddedString) - require.Equal(b, "john,long", d.SomeString) - require.Equal(b, "long,john", d.SomeOtherString) + require.Equal(b, []string{"john", "longest"}, d.EmbeddedStrings) } // go test -v -run=^$ -bench=Benchmark_Bind_Body_Form -benchmem -count=4 diff --git a/binder/cookie.go b/binder/cookie.go index 0f5c650c33..2dd1dc864a 100644 --- a/binder/cookie.go +++ b/binder/cookie.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -26,14 +23,7 @@ func (b *cookieBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + appendValue(data, v, out, k, b.Name()) }) if err != nil { diff --git a/binder/form.go b/binder/form.go index f45407fe93..7df300e256 100644 --- a/binder/form.go +++ b/binder/form.go @@ -1,7 +1,6 @@ package binder import ( - "reflect" "strings" "github.com/gofiber/utils/v2" @@ -30,14 +29,7 @@ func (b *formBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k, err = parseParamSquareBrackets(k) } - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + appendValue(data, v, out, k, b.Name()) }) if err != nil { diff --git a/binder/header.go b/binder/header.go index 196163694d..3610408137 100644 --- a/binder/header.go +++ b/binder/header.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -20,14 +17,7 @@ func (b *headerBinding) Bind(req *fasthttp.Request, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + appendValue(data, v, out, k, b.Name()) }) return parse(b.Name(), out, data) diff --git a/binder/mapping.go b/binder/mapping.go index b00feabf13..e9420a1278 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -172,7 +172,18 @@ func parseParamSquareBrackets(k string) (string, error) { return bb.String(), nil } -func equalFieldType(out any, kind reflect.Kind, key string) bool { +func appendValue(to map[string][]string, rawValue string, out any, k string, bindingName string) { + if strings.Contains(rawValue, ",") && equalFieldType(out, reflect.Slice, k, bindingName) { + values := strings.Split(rawValue, ",") + for i := 0; i < len(values); i++ { + to[k] = append(to[k], values[i]) + } + } else { + to[k] = append(to[k], rawValue) + } +} + +func equalFieldType(out any, kind reflect.Kind, key string, bindingName string) bool { // Get type of interface outTyp := reflect.TypeOf(out).Elem() key = utils.ToLower(key) @@ -196,53 +207,54 @@ func equalFieldType(out any, kind reflect.Kind, key string) bool { if !structField.CanSet() { continue } + // Get field key data typeField := outTyp.Field(i) // Get type of field key structFieldKind := structField.Kind() - // Does the field type equals input? - if structFieldKind != kind { - // Is the field an embedded struct? - if structFieldKind == reflect.Struct && typeField.Anonymous { - // Loop over embedded struct fields - for j := 0; j < structField.NumField(); j++ { - fNm := utils.ToLower(structField.Type().Field(j).Name) - if fNm != key { - //this is not the field that we are looking for - continue - } - - structFieldField := structField.Field(j) - - // Can this embedded field be changed? - if !structFieldField.CanSet() { - continue - } - - // Is the embedded struct field type equal to the input? - if structFieldField.Kind() == kind { - return true - } - } - } - continue - } - // Get tag from field if exist - inputFieldName := typeField.Tag.Get(QueryBinder.Name()) - if inputFieldName == "" { - inputFieldName = typeField.Name - } else { - inputFieldName = strings.Split(inputFieldName, ",")[0] - } // Compare field/tag with provided key - if utils.ToLower(inputFieldName) == key { - return true + if getFieldKey(typeField, bindingName) == key { + return structFieldKind == kind + } + + // Is the field an embedded struct? + if typeField.Anonymous { + // Loop over embedded struct fields + for j := 0; j < structField.NumField(); j++ { + if getFieldKey(structField.Type().Field(j), bindingName) != key { + // this is not the field that we are looking for + continue + } + + structFieldField := structField.Field(j) + + // Can this embedded field be changed? + if !structFieldField.CanSet() { + continue + } + + // Is the embedded struct field type equal to the input? + return structFieldField.Kind() == kind + } } } return false } +// Get binding key for a field +func getFieldKey(typeField reflect.StructField, bindingName string) string { + // Get tag from field if exist + inputFieldName := typeField.Tag.Get(bindingName) + if inputFieldName == "" { + inputFieldName = typeField.Name + } else { + inputFieldName = strings.Split(inputFieldName, ",")[0] + } + // Compare field key + return utils.ToLower(inputFieldName) +} + // Get content type from content type header func FilterFlags(content string) string { for i, char := range content { diff --git a/binder/mapping_test.go b/binder/mapping_test.go index e6fc8146f7..1f74664cfe 100644 --- a/binder/mapping_test.go +++ b/binder/mapping_test.go @@ -10,25 +10,25 @@ import ( func Test_EqualFieldType(t *testing.T) { var out int - require.False(t, equalFieldType(&out, reflect.Int, "key")) + require.False(t, equalFieldType(&out, reflect.Int, "key", "query")) var dummy struct{ f string } - require.False(t, equalFieldType(&dummy, reflect.String, "key")) + require.False(t, equalFieldType(&dummy, reflect.String, "key", "query")) var dummy2 struct{ f string } - require.False(t, equalFieldType(&dummy2, reflect.String, "f")) + require.False(t, equalFieldType(&dummy2, reflect.String, "f", "query")) var user struct { Name string Address string `query:"address"` Age int `query:"AGE"` } - require.True(t, equalFieldType(&user, reflect.String, "name")) - require.True(t, equalFieldType(&user, reflect.String, "Name")) - require.True(t, equalFieldType(&user, reflect.String, "address")) - require.True(t, equalFieldType(&user, reflect.String, "Address")) - require.True(t, equalFieldType(&user, reflect.Int, "AGE")) - require.True(t, equalFieldType(&user, reflect.Int, "age")) + require.True(t, equalFieldType(&user, reflect.String, "name", "query")) + require.True(t, equalFieldType(&user, reflect.String, "Name", "query")) + require.True(t, equalFieldType(&user, reflect.String, "address", "query")) + require.True(t, equalFieldType(&user, reflect.String, "Address", "query")) + require.True(t, equalFieldType(&user, reflect.Int, "AGE", "query")) + require.True(t, equalFieldType(&user, reflect.Int, "age", "query")) } func Test_ParseParamSquareBrackets(t *testing.T) { diff --git a/binder/query.go b/binder/query.go index 25b69f5bc3..d35d92d22c 100644 --- a/binder/query.go +++ b/binder/query.go @@ -1,7 +1,6 @@ package binder import ( - "reflect" "strings" "github.com/gofiber/utils/v2" @@ -30,14 +29,7 @@ func (b *queryBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k, err = parseParamSquareBrackets(k) } - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + appendValue(data, v, out, k, b.Name()) }) if err != nil { diff --git a/binder/resp_header.go b/binder/resp_header.go index 0455185ba1..749e98b324 100644 --- a/binder/resp_header.go +++ b/binder/resp_header.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -20,14 +17,7 @@ func (b *respHeaderBinding) Bind(resp *fasthttp.Response, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + appendValue(data, v, out, k, b.Name()) }) return parse(b.Name(), out, data)