diff --git a/copier.go b/copier.go index 6dc9600..5fba6b8 100644 --- a/copier.go +++ b/copier.go @@ -3,7 +3,6 @@ package copier import ( "database/sql" "database/sql/driver" - "errors" "fmt" "reflect" "strings" @@ -45,7 +44,7 @@ type Option struct { type TypeConverter struct { SrcType interface{} DstType interface{} - Fn func(src interface{}) (interface{}, error) + Fn func(src interface{}) (dst interface{}, err error) } type converterPair struct { @@ -392,11 +391,7 @@ func copyUnexportedStructFields(to, from reflect.Value) { } func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool { - if !ignoreEmpty { - return false - } - - return v.IsZero() + return ignoreEmpty && v.IsZero() } func deepFields(reflectType reflect.Type) []reflect.StructField { @@ -439,94 +434,95 @@ func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) { } func set(to, from reflect.Value, deepCopy bool, converters map[converterPair]TypeConverter) bool { - if from.IsValid() { - if ok, err := lookupAndCopyWithConverter(to, from, converters); err != nil { - return false - } else if ok { - return true - } + if !from.IsValid() { + return true + } + if ok, err := lookupAndCopyWithConverter(to, from, converters); err != nil { + return false + } else if ok { + return true + } - if to.Kind() == reflect.Ptr { - // set `to` to nil if from is nil - if from.Kind() == reflect.Ptr && from.IsNil() { - to.Set(reflect.Zero(to.Type())) - return true - } else if to.IsNil() { - // `from` -> `to` - // sql.NullString -> *string - if fromValuer, ok := driverValuer(from); ok { - v, err := fromValuer.Value() - if err != nil { - return false - } - // if `from` is not valid do nothing with `to` - if v == nil { - return true - } + if to.Kind() == reflect.Ptr { + // set `to` to nil if from is nil + if from.Kind() == reflect.Ptr && from.IsNil() { + to.Set(reflect.Zero(to.Type())) + return true + } else if to.IsNil() { + // `from` -> `to` + // sql.NullString -> *string + if fromValuer, ok := driverValuer(from); ok { + v, err := fromValuer.Value() + if err != nil { + return false + } + // if `from` is not valid do nothing with `to` + if v == nil { + return true } - // allocate new `to` variable with default value (eg. *string -> new(string)) - to.Set(reflect.New(to.Type().Elem())) } - // depointer `to` - to = to.Elem() + // allocate new `to` variable with default value (eg. *string -> new(string)) + to.Set(reflect.New(to.Type().Elem())) } + // depointer `to` + to = to.Elem() + } - if deepCopy { - toKind := to.Kind() - if toKind == reflect.Interface && to.IsNil() { - if reflect.TypeOf(from.Interface()) != nil { - to.Set(reflect.New(reflect.TypeOf(from.Interface())).Elem()) - toKind = reflect.TypeOf(to.Interface()).Kind() - } - } - if from.Kind() == reflect.Ptr && from.IsNil() { - return true - } - if toKind == reflect.Struct || toKind == reflect.Map || toKind == reflect.Slice { - return false + if deepCopy { + toKind := to.Kind() + if toKind == reflect.Interface && to.IsNil() { + if reflect.TypeOf(from.Interface()) != nil { + to.Set(reflect.New(reflect.TypeOf(from.Interface())).Elem()) + toKind = reflect.TypeOf(to.Interface()).Kind() } } + if from.Kind() == reflect.Ptr && from.IsNil() { + return true + } + if toKind == reflect.Struct || toKind == reflect.Map || toKind == reflect.Slice { + return false + } + } - if from.Type().ConvertibleTo(to.Type()) { - to.Set(from.Convert(to.Type())) - } else if toScanner, ok := to.Addr().Interface().(sql.Scanner); ok { - // `from` -> `to` - // *string -> sql.NullString - if from.Kind() == reflect.Ptr { - // if `from` is nil do nothing with `to` - if from.IsNil() { - return true - } - // depointer `from` - from = indirect(from) - } - // `from` -> `to` - // string -> sql.NullString - // set `to` by invoking method Scan(`from`) - err := toScanner.Scan(from.Interface()) - if err != nil { - return false - } - } else if fromValuer, ok := driverValuer(from); ok { - // `from` -> `to` - // sql.NullString -> string - v, err := fromValuer.Value() - if err != nil { - return false - } - // if `from` is not valid do nothing with `to` - if v == nil { + if from.Type().ConvertibleTo(to.Type()) { + to.Set(from.Convert(to.Type())) + } else if toScanner, ok := to.Addr().Interface().(sql.Scanner); ok { + // `from` -> `to` + // *string -> sql.NullString + if from.Kind() == reflect.Ptr { + // if `from` is nil do nothing with `to` + if from.IsNil() { return true } - rv := reflect.ValueOf(v) - if rv.Type().AssignableTo(to.Type()) { - to.Set(rv) - } - } else if from.Kind() == reflect.Ptr { - return set(to, from.Elem(), deepCopy, converters) - } else { + // depointer `from` + from = indirect(from) + } + // `from` -> `to` + // string -> sql.NullString + // set `to` by invoking method Scan(`from`) + err := toScanner.Scan(from.Interface()) + if err != nil { + return false + } + } else if fromValuer, ok := driverValuer(from); ok { + // `from` -> `to` + // sql.NullString -> string + v, err := fromValuer.Value() + if err != nil { return false } + // if `from` is not valid do nothing with `to` + if v == nil { + return true + } + rv := reflect.ValueOf(v) + if rv.Type().AssignableTo(to.Type()) { + to.Set(rv) + } + } else if from.Kind() == reflect.Ptr { + return set(to, from.Elem(), deepCopy, converters) + } else { + return false } return true @@ -574,7 +570,7 @@ func parseTags(tag string) (flg uint8, name string, err error) { if unicode.IsUpper([]rune(t)[0]) { name = strings.TrimSpace(t) } else { - err = errors.New("copier field name tag must be start upper case") + err = ErrFieldNameTagStartNotUpperCase } } } diff --git a/errors.go b/errors.go index cf7c5e7..f26db07 100644 --- a/errors.go +++ b/errors.go @@ -3,8 +3,9 @@ package copier import "errors" var ( - ErrInvalidCopyDestination = errors.New("copy destination is invalid") - ErrInvalidCopyFrom = errors.New("copy from is invalid") - ErrMapKeyNotMatch = errors.New("map's key type doesn't match") - ErrNotSupported = errors.New("not supported") + ErrInvalidCopyDestination = errors.New("copy destination is invalid") + ErrInvalidCopyFrom = errors.New("copy from is invalid") + ErrMapKeyNotMatch = errors.New("map's key type doesn't match") + ErrNotSupported = errors.New("not supported") + ErrFieldNameTagStartNotUpperCase = errors.New("copier field name tag must be start upper case") )