diff --git a/bind.go b/bind.go index 877c0f5c4..ed7ca3249 100644 --- a/bind.go +++ b/bind.go @@ -8,6 +8,7 @@ import ( "encoding/xml" "errors" "fmt" + "mime/multipart" "net/http" "reflect" "strconv" @@ -45,7 +46,7 @@ func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { for i, name := range names { params[name] = []string{values[i]} } - if err := b.bindData(i, params, "param"); err != nil { + if err := b.bindData(i, params, "param", nil); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil @@ -53,7 +54,7 @@ func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { // BindQueryParams binds query params to bindable object func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { - if err := b.bindData(i, c.QueryParams(), "query"); err != nil { + if err := b.bindData(i, c.QueryParams(), "query", nil); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil @@ -70,9 +71,12 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { return } - ctype := req.Header.Get(HeaderContentType) - switch { - case strings.HasPrefix(ctype, MIMEApplicationJSON): + // mediatype is found like `mime.ParseMediaType()` does it + base, _, _ := strings.Cut(req.Header.Get(HeaderContentType), ";") + mediatype := strings.TrimSpace(base) + + switch mediatype { + case MIMEApplicationJSON: if err = c.Echo().JSONSerializer.Deserialize(c, i); err != nil { switch err.(type) { case *HTTPError: @@ -81,7 +85,7 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } } - case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): + case MIMEApplicationXML, MIMETextXML: if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*xml.UnsupportedTypeError); ok { return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) @@ -90,12 +94,20 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { } return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): + case MIMEApplicationForm: params, err := c.FormParams() if err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - if err = b.bindData(i, params, "form"); err != nil { + if err = b.bindData(i, params, "form", nil); err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + } + case MIMEMultipartForm: + params, err := c.MultipartForm() + if err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + } + if err = b.bindData(i, params.Value, "form", params.File); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } default: @@ -106,7 +118,7 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { // BindHeaders binds HTTP headers to a bindable object func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { - if err := b.bindData(i, c.Request().Header, "header"); err != nil { + if err := b.bindData(i, c.Request().Header, "header", nil); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil @@ -132,10 +144,11 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { } // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag -func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error { - if destination == nil || len(data) == 0 { +func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error { + if destination == nil || (len(data) == 0 && len(dataFiles) == 0) { return nil } + hasFiles := len(dataFiles) > 0 typ := reflect.TypeOf(destination).Elem() val := reflect.ValueOf(destination).Elem() @@ -179,7 +192,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri return errors.New("binding element must be a struct") } - for i := 0; i < typ.NumField(); i++ { + for i := 0; i < typ.NumField(); i++ { // iterate over all destination fields typeField := typ.Field(i) structField := val.Field(i) if typeField.Anonymous { @@ -198,10 +211,10 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri } if inputFieldName == "" { - // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). + // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags). // structs that implement BindUnmarshaler are bound only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { - if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { + if err := b.bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil { return err } } @@ -209,10 +222,20 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri continue } + if hasFiles { + if ok, err := isFieldMultipartFile(structField.Type()); err != nil { + return err + } else if ok { + if ok := setMultipartFileHeaderTypes(structField, inputFieldName, dataFiles); ok { + continue + } + } + } + inputValue, exists := data[inputFieldName] if !exists { - // Go json.Unmarshal supports case insensitive binding. However the - // url params are bound case sensitive which is inconsistent. To + // Go json.Unmarshal supports case-insensitive binding. However the + // url params are bound case-sensitive which is inconsistent. To // fix this we must check all of the map values in a // case-insensitive search. for k, v := range data { @@ -394,3 +417,50 @@ func setFloatField(value string, bitSize int, field reflect.Value) error { } return err } + +var ( + // NOT supported by bind as you can NOT check easily empty struct being actual file or not + multipartFileHeaderType = reflect.TypeOf(multipart.FileHeader{}) + // supported by bind as you can check by nil value if file existed or not + multipartFileHeaderPointerType = reflect.TypeOf(&multipart.FileHeader{}) + multipartFileHeaderSliceType = reflect.TypeOf([]multipart.FileHeader(nil)) + multipartFileHeaderPointerSliceType = reflect.TypeOf([]*multipart.FileHeader(nil)) +) + +func isFieldMultipartFile(field reflect.Type) (bool, error) { + switch field { + case multipartFileHeaderPointerType, + multipartFileHeaderSliceType, + multipartFileHeaderPointerSliceType: + return true, nil + case multipartFileHeaderType: + return true, errors.New("binding to multipart.FileHeader struct is not supported, use pointer to struct") + default: + return false, nil + } +} + +func setMultipartFileHeaderTypes(structField reflect.Value, inputFieldName string, files map[string][]*multipart.FileHeader) bool { + fileHeaders := files[inputFieldName] + if len(fileHeaders) == 0 { + return false + } + + result := true + switch structField.Type() { + case multipartFileHeaderPointerSliceType: + structField.Set(reflect.ValueOf(fileHeaders)) + case multipartFileHeaderSliceType: + headers := make([]multipart.FileHeader, len(fileHeaders)) + for i, fileHeader := range fileHeaders { + headers[i] = *fileHeader + } + structField.Set(reflect.ValueOf(headers)) + case multipartFileHeaderPointerType: + structField.Set(reflect.ValueOf(fileHeaders[0])) + default: + result = false + } + + return result +} diff --git a/bind_test.go b/bind_test.go index fc0c00598..c79669c8c 100644 --- a/bind_test.go +++ b/bind_test.go @@ -446,7 +446,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]string", func(t *testing.T) { dest := map[string]string{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]string{ "multiple": "1", @@ -458,7 +458,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]string with nil map", func(t *testing.T) { var dest map[string]string - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]string{ "multiple": "1", @@ -470,7 +470,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string][]string", func(t *testing.T) { dest := map[string][]string{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]string{ "multiple": {"1", "2"}, @@ -482,7 +482,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) { var dest map[string][]string - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]string{ "multiple": {"1", "2"}, @@ -494,7 +494,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]interface", func(t *testing.T) { dest := map[string]interface{}{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]interface{}{ "multiple": "1", @@ -506,7 +506,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) { var dest map[string]interface{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]interface{}{ "multiple": "1", @@ -518,25 +518,25 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]int skips", func(t *testing.T) { dest := map[string]int{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]int{}, dest) }) t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) { var dest map[string]int - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]int(nil), dest) }) t.Run("ok, bind to map[string][]int skips", func(t *testing.T) { dest := map[string][]int{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]int{}, dest) }) t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) { var dest map[string][]int - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]int(nil), dest) }) } @@ -544,7 +544,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { func TestBindbindData(t *testing.T) { ts := new(bindTestStruct) b := new(DefaultBinder) - err := b.bindData(ts, values, "form") + err := b.bindData(ts, values, "form", nil) assert.NoError(t, err) assert.Equal(t, 0, ts.I) @@ -666,7 +666,7 @@ func BenchmarkBindbindDataWithTags(b *testing.B) { var err error b.ResetTimer() for i := 0; i < b.N; i++ { - err = binder.bindData(ts, values, "form") + err = binder.bindData(ts, values, "form", nil) } assert.NoError(b, err) assertBindTestStruct(b, (*bindTestStruct)(ts)) @@ -1420,3 +1420,119 @@ func TestBindInt8(t *testing.T) { assert.Equal(t, target{V: &[]int8{1, 2}}, p) }) } + +func TestBindMultipartFormFiles(t *testing.T) { + file1 := createTestFormFile("file", "file1.txt") + file11 := createTestFormFile("file", "file11.txt") + file2 := createTestFormFile("file2", "file2.txt") + filesA := createTestFormFile("files", "filesA.txt") + filesB := createTestFormFile("files", "filesB.txt") + + t.Run("nok, can not bind to multipart file struct", func(t *testing.T) { + var target struct { + File multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored + + assert.EqualError(t, err, "code=400, message=binding to multipart.FileHeader struct is not supported, use pointer to struct, internal=binding to multipart.FileHeader struct is not supported, use pointer to struct") + }) + + t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) { + var target struct { + File *multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored + + assert.NoError(t, err) + assertMultipartFileHeader(t, target.File, file1) + }) + + t.Run("ok, bind multiple multipart files to pointer to multipart file", func(t *testing.T) { + var target struct { + File *multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file11) + + assert.NoError(t, err) + assertMultipartFileHeader(t, target.File, file1) // should choose first one + }) + + t.Run("ok, bind multiple multipart files to slice of multipart file", func(t *testing.T) { + var target struct { + Files []multipart.FileHeader `form:"files"` + } + err := bindMultipartFiles(t, &target, filesA, filesB, file1) + + assert.NoError(t, err) + + assert.Len(t, target.Files, 2) + assertMultipartFileHeader(t, &target.Files[0], filesA) + assertMultipartFileHeader(t, &target.Files[1], filesB) + }) + + t.Run("ok, bind multiple multipart files to slice of pointer to multipart file", func(t *testing.T) { + var target struct { + Files []*multipart.FileHeader `form:"files"` + } + err := bindMultipartFiles(t, &target, filesA, filesB, file1) + + assert.NoError(t, err) + + assert.Len(t, target.Files, 2) + assertMultipartFileHeader(t, target.Files[0], filesA) + assertMultipartFileHeader(t, target.Files[1], filesB) + }) +} + +type testFormFile struct { + Fieldname string + Filename string + Content []byte +} + +func createTestFormFile(formFieldName string, filename string) testFormFile { + return testFormFile{ + Fieldname: formFieldName, + Filename: filename, + Content: []byte(strings.Repeat(filename, 10)), + } +} + +func bindMultipartFiles(t *testing.T, target any, files ...testFormFile) error { + var body bytes.Buffer + mw := multipart.NewWriter(&body) + + for _, file := range files { + fw, err := mw.CreateFormFile(file.Fieldname, file.Filename) + assert.NoError(t, err) + + n, err := fw.Write(file.Content) + assert.NoError(t, err) + assert.Equal(t, len(file.Content), n) + } + + err := mw.Close() + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/", &body) + assert.NoError(t, err) + req.Header.Set("Content-Type", mw.FormDataContentType()) + + rec := httptest.NewRecorder() + + e := New() + c := e.NewContext(req, rec) + return c.Bind(target) +} + +func assertMultipartFileHeader(t *testing.T, fh *multipart.FileHeader, file testFormFile) { + assert.Equal(t, file.Filename, fh.Filename) + assert.Equal(t, int64(len(file.Content)), fh.Size) + fl, err := fh.Open() + assert.NoError(t, err) + body, err := io.ReadAll(fl) + assert.NoError(t, err) + assert.Equal(t, string(file.Content), string(body)) + err = fl.Close() + assert.NoError(t, err) +}