diff --git a/copier.go b/copier.go index bb2ad1b..9f3a6d0 100644 --- a/copier.go +++ b/copier.go @@ -24,6 +24,13 @@ const ( // Denotes that the value as been copied hasCopied + + // Some default converter types for a nicer syntax + String string = "" + Bool bool = false + Int int = 0 + Float32 float32 = 0 + Float64 float64 = 0 ) // Option sets copy options @@ -32,6 +39,18 @@ type Option struct { // struct having all it's fields set to their zero values respectively (see IsZero() in reflect/value.go) IgnoreEmpty bool DeepCopy bool + Converters []TypeConverter +} + +type TypeConverter struct { + SrcType interface{} + DstType interface{} + Fn func(src interface{}) (interface{}, error) +} + +type converterPair struct { + SrcType reflect.Type + DstType reflect.Type } // Tag Flags @@ -59,12 +78,27 @@ func CopyWithOption(toValue interface{}, fromValue interface{}, opt Option) (err func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) { var ( - isSlice bool - amount = 1 - from = indirect(reflect.ValueOf(fromValue)) - to = indirect(reflect.ValueOf(toValue)) + isSlice bool + amount = 1 + from = indirect(reflect.ValueOf(fromValue)) + to = indirect(reflect.ValueOf(toValue)) + converters map[converterPair]TypeConverter ) + // save convertes into map for faster lookup + for i := range opt.Converters { + if converters == nil { + converters = make(map[converterPair]TypeConverter) + } + + pair := converterPair{ + SrcType: reflect.TypeOf(opt.Converters[i].SrcType), + DstType: reflect.TypeOf(opt.Converters[i].DstType), + } + + converters[pair] = opt.Converters[i] + } + if !to.CanAddr() { return ErrInvalidCopyDestination } @@ -113,13 +147,13 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) for _, k := range from.MapKeys() { toKey := indirect(reflect.New(toType.Key())) - if !set(toKey, k, opt.DeepCopy) { + if !set(toKey, k, opt.DeepCopy, converters) { return fmt.Errorf("%w map, old key: %v, new key: %v", ErrNotSupported, k.Type(), toType.Key()) } elemType, _ := indirectType(toType.Elem()) toValue := indirect(reflect.New(elemType)) - if !set(toValue, from.MapIndex(k), opt.DeepCopy) { + if !set(toValue, from.MapIndex(k), opt.DeepCopy, converters) { if err = copier(toValue.Addr().Interface(), from.MapIndex(k).Interface(), opt); err != nil { return err } @@ -148,7 +182,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) to.Set(reflect.Append(to, reflect.New(to.Type().Elem()).Elem())) } - if !set(to.Index(i), from.Index(i), opt.DeepCopy) { + if !set(to.Index(i), from.Index(i), opt.DeepCopy, converters) { // ignore error while copy slice element err = copier(to.Index(i).Addr().Interface(), from.Index(i).Interface(), opt) if err != nil { @@ -251,7 +285,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) toField := dest.FieldByName(destFieldName) if toField.IsValid() { if toField.CanSet() { - if !set(toField, fromField, opt.DeepCopy) { + if !set(toField, fromField, opt.DeepCopy, converters) { if err := copier(toField.Addr().Interface(), fromField.Interface(), opt); err != nil { return err } @@ -293,7 +327,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) if toField := dest.FieldByName(destFieldName); toField.IsValid() && toField.CanSet() { values := fromMethod.Call([]reflect.Value{}) if len(values) >= 1 { - set(toField, values[0], opt.DeepCopy) + set(toField, values[0], opt.DeepCopy, converters) } } } @@ -305,7 +339,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) if to.Len() < i+1 { to.Set(reflect.Append(to, dest.Addr())) } else { - if !set(to.Index(i), dest.Addr(), opt.DeepCopy) { + if !set(to.Index(i), dest.Addr(), opt.DeepCopy, converters) { // ignore error while copy slice element err = copier(to.Index(i).Addr().Interface(), dest.Addr().Interface(), opt) if err != nil { @@ -317,7 +351,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) if to.Len() < i+1 { to.Set(reflect.Append(to, dest)) } else { - if !set(to.Index(i), dest, opt.DeepCopy) { + if !set(to.Index(i), dest, opt.DeepCopy, converters) { // ignore error while copy slice element err = copier(to.Index(i).Addr().Interface(), dest.Interface(), opt) if err != nil { @@ -401,8 +435,14 @@ func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) { return reflectType, isPtr } -func set(to, from reflect.Value, deepCopy bool) bool { +func set(to, from reflect.Value, deepCopy bool, converters map[converterPair]TypeConverter) bool { if from.IsValid() { + if ok, err := lookupAndCopyWithConverter(to, from, converters); err != nil { + return false + } else if ok { + return true + } + if to.Kind() == reflect.Ptr { // set `to` to nil if from is nil if from.Kind() == reflect.Ptr && from.IsNil() { @@ -480,7 +520,7 @@ func set(to, from reflect.Value, deepCopy bool) bool { to.Set(rv) } } else if from.Kind() == reflect.Ptr { - return set(to, from.Elem(), deepCopy) + return set(to, from.Elem(), deepCopy, converters) } else { return false } @@ -489,6 +529,33 @@ func set(to, from reflect.Value, deepCopy bool) bool { return true } +// lookupAndCopyWithConverter looks up the type pair, on success the TypeConverter Fn func is called to copy src to dst field. +func lookupAndCopyWithConverter(to, from reflect.Value, converters map[converterPair]TypeConverter) (copied bool, err error) { + pair := converterPair{ + SrcType: from.Type(), + DstType: to.Type(), + } + + if cnv, ok := converters[pair]; ok { + result, err := cnv.Fn(from.Interface()) + + if err != nil { + return false, err + } + + if result != nil { + to.Set(reflect.ValueOf(result)) + } else { + // in case we've got a nil value to copy + to.Set(reflect.Zero(to.Type())) + } + + return true, nil + } + + return false, nil +} + // parseTags Parses struct tags and returns uint8 bit flags. func parseTags(tag string) (flg uint8, name string, err error) { for _, t := range strings.Split(tag, ",") { diff --git a/copier_converter_test.go b/copier_converter_test.go new file mode 100644 index 0000000..08bd038 --- /dev/null +++ b/copier_converter_test.go @@ -0,0 +1,187 @@ +package copier_test + +import ( + "errors" + "strconv" + "testing" + "time" + + "github.com/jinzhu/copier" +) + +func TestCopyWithTypeConverters(t *testing.T) { + type SrcStruct struct { + Field1 time.Time + Field2 *time.Time + Field3 *time.Time + Field4 string + } + + type DestStruct struct { + Field1 string + Field2 string + Field3 string + Field4 int + } + + testTime := time.Date(2021, 3, 5, 1, 30, 0, 123000000, time.UTC) + + src := SrcStruct{ + Field1: testTime, + Field2: &testTime, + Field3: nil, + Field4: "9000", + } + + var dst DestStruct + + err := copier.CopyWithOption(&dst, &src, copier.Option{ + IgnoreEmpty: true, + DeepCopy: true, + Converters: []copier.TypeConverter{ + { + SrcType: time.Time{}, + DstType: copier.String, + Fn: func(src interface{}) (interface{}, error) { + s, ok := src.(time.Time) + + if !ok { + return nil, errors.New("src type not matching") + } + + return s.Format(time.RFC3339), nil + }, + }, + { + SrcType: copier.String, + DstType: copier.Int, + Fn: func(src interface{}) (interface{}, error) { + s, ok := src.(string) + + if !ok { + return nil, errors.New("src type not matching") + } + + return strconv.Atoi(s) + }, + }, + }, + }) + + if err != nil { + t.Fatalf(`Should be able to copy from src to dst object. %v`, err) + return + } + + dateStr := "2021-03-05T01:30:00Z" + + if dst.Field1 != dateStr { + t.Fatalf("got %q, wanted %q", dst.Field1, dateStr) + } + + if dst.Field2 != dateStr { + t.Fatalf("got %q, wanted %q", dst.Field2, dateStr) + } + + if dst.Field3 != "" { + t.Fatalf("got %q, wanted %q", dst.Field3, "") + } + + if dst.Field4 != 9000 { + t.Fatalf("got %q, wanted %q", dst.Field4, 9000) + } +} + +func TestCopyWithConverterAndAnnotation(t *testing.T) { + type SrcStruct struct { + Field1 string + } + + type DestStruct struct { + Field1 string + Field2 string `copier:"Field1"` + } + + src := SrcStruct{ + Field1: "test", + } + + var dst DestStruct + + err := copier.CopyWithOption(&dst, &src, copier.Option{ + IgnoreEmpty: true, + DeepCopy: true, + Converters: []copier.TypeConverter{ + { + SrcType: copier.String, + DstType: copier.String, + Fn: func(src interface{}) (interface{}, error) { + s, ok := src.(string) + + if !ok { + return nil, errors.New("src type not matching") + } + + return s + "2", nil + }, + }, + }, + }) + + if err != nil { + t.Fatalf(`Should be able to copy from src to dst object. %v`, err) + return + } + + if dst.Field2 != "test2" { + t.Fatalf("got %q, wanted %q", dst.Field2, "test2") + } +} + +func TestCopyWithConverterStrToStrPointer(t *testing.T) { + type SrcStruct struct { + Field1 string + } + + type DestStruct struct { + Field1 *string + } + + src := SrcStruct{ + Field1: "", + } + + var dst DestStruct + + ptrStrType := "" + + err := copier.CopyWithOption(&dst, &src, copier.Option{ + IgnoreEmpty: true, + DeepCopy: true, + Converters: []copier.TypeConverter{ + { + SrcType: copier.String, + DstType: &ptrStrType, + Fn: func(src interface{}) (interface{}, error) { + s, _ := src.(string) + + // return nil on empty string + if s == "" { + return nil, nil + } + + return &s, nil + }, + }, + }, + }) + + if err != nil { + t.Fatalf(`Should be able to copy from src to dst object. %v`, err) + return + } + + if dst.Field1 != nil { + t.Fatalf("got %q, wanted nil", *dst.Field1) + } +}