From 4e5421b4fa453546550a2f167b479822983b5fc3 Mon Sep 17 00:00:00 2001 From: kPshi Date: Fri, 11 Aug 2017 21:26:50 +0200 Subject: [PATCH 1/4] show #42: Check on Ptr of Struct too strict --- select_test.go | 18 +++++++++ struct_test.go | 105 +++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 103 insertions(+), 20 deletions(-) diff --git a/select_test.go b/select_test.go index b4402e7..fdad127 100644 --- a/select_test.go +++ b/select_test.go @@ -120,6 +120,24 @@ func TestSelectQuery_Model(t *testing.T) { } } + { + // All without specifying FROM using []* + var customers []*CustomerPtr + err := db.Select().OrderBy("id").All(&customers) + if assert.Nil(t, err) { + assert.Equal(t, 3, len(customers)) + } + } + + { + // All without specifying FROM using []** + var customers []**CustomerPtr + err := db.Select().OrderBy("id").All(&customers) + if assert.Nil(t, err) { + assert.Equal(t, 3, len(customers)) + } + } + { // Model without specifying FROM var customer CustomerPtr diff --git a/struct_test.go b/struct_test.go index 3ec60b9..ec77fa5 100644 --- a/struct_test.go +++ b/struct_test.go @@ -65,6 +65,13 @@ func Test_indirect(t *testing.T) { if assert.NotNil(t, b) { assert.Equal(t, 0, *b) } + + var c1 Customer = Customer{} + + vc := indirect(reflect.ValueOf(&c1)) + assert.Equal(t, reflect.Struct, vc.Kind()) + assert.Equal(t, reflect.Struct, vc.Type().Kind()) + assert.True(t, vc.CanSet()) } func Test_structValue_columns(t *testing.T) { @@ -99,18 +106,18 @@ func TestIssue37(t *testing.T) { Status: 2, Email: "abc@example.com", } - ev := struct{ + ev := struct { Customer Status string - } {customer, "20"} + }{customer, "20"} sv := newStructValue(&ev, nil) cols := sv.columns([]string{"ID", "Status"}, nil) assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols) - ev2 := struct{ + ev2 := struct { Status string Customer - } {"20", customer} + }{"20", customer} sv = newStructValue(&ev2, nil) cols = sv.columns([]string{"ID", "Status"}, nil) assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols) @@ -118,30 +125,88 @@ func TestIssue37(t *testing.T) { type MyCustomer struct{} +type SomeTable struct{} + +func (*SomeTable) TableName() string { + return "strange_name" +} + func Test_getTableName(t *testing.T) { - var c1 Customer - assert.Equal(t, "customer", GetTableName(c1)) + { + var c Customer + assert.Equal(t, "customer", GetTableName(c)) + } - var c2 *Customer - assert.Equal(t, "customer", GetTableName(c2)) + { + var c *Customer + assert.Equal(t, "customer", GetTableName(c)) + } - var c3 MyCustomer - assert.Equal(t, "my_customer", GetTableName(c3)) + { + var c MyCustomer + assert.Equal(t, "my_customer", GetTableName(c)) + } - var c4 []Customer - assert.Equal(t, "customer", GetTableName(c4)) + { + var c []Customer + assert.Equal(t, "customer", GetTableName(c)) + } - var c5 *[]Customer - assert.Equal(t, "customer", GetTableName(c5)) + { + var c *[]Customer + assert.Equal(t, "customer", GetTableName(c)) + } - var c6 []MyCustomer - assert.Equal(t, "my_customer", GetTableName(c6)) + { + var c []*Customer + assert.Equal(t, "customer", GetTableName(c)) + } - var c7 []CustomerPtr - assert.Equal(t, "customer", GetTableName(c7)) + { + var c []MyCustomer + assert.Equal(t, "my_customer", GetTableName(c)) + } - var c8 **int - assert.Equal(t, "", GetTableName(c8)) + { + var c []CustomerPtr + assert.Equal(t, "customer", GetTableName(c)) + } + + { + var c **int + assert.Equal(t, "", GetTableName(c)) + } + + { + var c ***[]Customer + assert.Equal(t, "customer", GetTableName(c)) + } + + { + func(i interface{}) { + func(c interface{}) { + assert.Equal(t, "customer", GetTableName(c)) + }(&i) + }(&Customer{}) + } + + { + func(i interface{}) { + func(c interface{}) { + assert.Equal(t, "customer", GetTableName(&c)) + }(&i) + }(&Customer{}) + } + + { + var c *SomeTable + assert.Equal(t, "strange_name", GetTableName(c)) + } + + { + var c **SomeTable + assert.Equal(t, "strange_name", GetTableName(c)) + } } type FA struct { From c205edfdb1f3ff431c1041984a474c156c8ebb78 Mon Sep 17 00:00:00 2001 From: kPshi Date: Fri, 11 Aug 2017 21:28:24 +0200 Subject: [PATCH 2/4] fix #42: Check on Ptr of Struct too strict? --- rows.go | 130 ++++++++++++++++++++++++++---------------------------- select.go | 6 +-- struct.go | 95 +++++++++++++++++++++++++++++++++------ 3 files changed, 145 insertions(+), 86 deletions(-) diff --git a/rows.go b/rows.go index 84bc06a..edc5a98 100644 --- a/rows.go +++ b/rows.go @@ -6,6 +6,7 @@ package dbx import ( "database/sql" + "fmt" "reflect" ) @@ -63,14 +64,16 @@ func (r *Rows) ScanMap(a NullStringMap) error { // To change the default behavior, set DB.FieldMapper with your custom mapping function. // You may also set Query.FieldMapper to change the behavior for particular queries. func (r *Rows) ScanStruct(a interface{}) error { - rv := reflect.ValueOf(a) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return VarTypeError("must be a pointer") - } - rv = indirect(rv) + return r.scanStructV(reflect.ValueOf(a)) +} +func (r *Rows) scanStructV(rv reflect.Value) error { + rv = indirect0(rv, nil, true, true) if rv.Kind() != reflect.Struct { return VarTypeError("must be a pointer to a struct") } + if !rv.CanSet() { + return VarTypeError("not settable value") + } si := getStructInfo(rv.Type(), r.fieldMapFunc) @@ -93,56 +96,62 @@ func (r *Rows) ScanStruct(a interface{}) error { func (r *Rows) all(slice interface{}) error { defer r.Close() - v := reflect.ValueOf(slice) - if v.Kind() != reflect.Ptr || v.IsNil() { - return VarTypeError("must be a pointer") - } - v = indirect(v) + v := indirect(reflect.ValueOf(slice)) if v.Kind() != reflect.Slice { - return VarTypeError("must be a slice of struct or NullStringMap") + return VarTypeError("not a slice, must be a slice of struct or NullStringMap") } - - et := v.Type().Elem() - - if et.Kind() == reflect.Map { - for r.Next() { - ev, ok := reflect.MakeMap(et).Interface().(NullStringMap) - if !ok { - return VarTypeError("must be a slice of struct or NullStringMap") - } - if err := r.ScanMap(ev); err != nil { - return err - } - v.Set(reflect.Append(v, reflect.ValueOf(ev))) - } - return r.Close() + if !v.CanSet() { + return VarTypeError("slice not settable") } - if et.Kind() != reflect.Struct { - return VarTypeError("must be a slice of struct or NullStringMap") + // check for a valid element type + et := v.Type().Elem() + var si *structInfo + finV := indirect0(reflect.New(et), nil, false, false) + switch finV.Kind() { + default: + return VarTypeError(fmt.Sprintf("a slice of %s, must be a slice of struct or NullStringMap", finV.Kind())) + case reflect.Map: + case reflect.Struct: + si = getStructInfo(finV.Type(), r.fieldMapFunc) + } + cols, err := r.Columns() + if err != nil { + return err } - si := getStructInfo(et, r.fieldMapFunc) - - cols, _ := r.Columns() + // everything prepared, now scan the result for r.Next() { - ev := reflect.New(et).Elem() + ev, err := r.scanRow(et, si, cols) + if err != nil { + return err + } + newSliceV := reflect.Append(v, ev) + v.Set(newSliceV) + } + return nil +} +func (r *Rows) scanRow(et reflect.Type, si *structInfo, cols []string) (ev reflect.Value, err error) { + ev = reflect.New(et).Elem() + evi := indirect(ev) + if evi.Kind() == reflect.Map { + if evi.IsNil() { + evi.Set(reflect.MakeMap(evi.Type())) + } + err = r.ScanMap(evi.Interface().(NullStringMap)) + } else { refs := make([]interface{}, len(cols)) for i, col := range cols { if fi, ok := si.dbNameMap[col]; ok { - refs[i] = fi.getField(ev).Addr().Interface() + refs[i] = fi.getField(evi).Addr().Interface() } else { refs[i] = &sql.NullString{} } } - if err := r.Scan(refs...); err != nil { - return err - } - v.Set(reflect.Append(v, ev)) + err = r.Scan(refs...) } - - return r.Close() + return } // column populates the given slice with the first column of the query result. @@ -150,14 +159,13 @@ func (r *Rows) all(slice interface{}) error { func (r *Rows) column(slice interface{}) error { defer r.Close() - v := reflect.ValueOf(slice) - if v.Kind() != reflect.Ptr || v.IsNil() { - return VarTypeError("must be a pointer to a slice") + v := indirect(reflect.ValueOf(slice)) + if !v.CanSet() { + return VarTypeError("not settable value") } - v = indirect(v) if v.Kind() != reflect.Slice { - return VarTypeError("must be a pointer to a slice") + return VarTypeError("must be a (pointer to a) slice") } et := v.Type().Elem() @@ -194,37 +202,23 @@ func (r *Rows) one(a interface{}) error { return sql.ErrNoRows } - var err error + rv := indirect(reflect.ValueOf(a)) - rt := reflect.TypeOf(a) - if rt.Kind() == reflect.Ptr && rt.Elem().Kind() == reflect.Map { - // pointer to map - v := indirect(reflect.ValueOf(a)) - if v.IsNil() { - v.Set(reflect.MakeMap(v.Type())) + if rv.Kind() == reflect.Map { + if rv.IsNil() { + if !rv.CanSet() { + return VarTypeError("not settable value") + } + rv.Set(reflect.MakeMap(rv.Type())) } - a = v.Interface() - rt = reflect.TypeOf(a) - } - - if rt.Kind() == reflect.Map { - v, ok := a.(NullStringMap) + v, ok := rv.Interface().(NullStringMap) if !ok { return VarTypeError("must be a NullStringMap") } - if v == nil { - return VarTypeError("NullStringMap is nil") - } - err = r.ScanMap(v) + return r.ScanMap(v) } else { - err = r.ScanStruct(a) + return r.scanStructV(rv) } - - if err != nil { - return err - } - - return r.Close() } // row populates a single row of query result into a list of variables. diff --git a/select.go b/select.go index 3910ad6..034507a 100644 --- a/select.go +++ b/select.go @@ -286,10 +286,8 @@ func (s *SelectQuery) One(a interface{}) error { // to infer the name of the primary key column. Only simple primary key is supported. For composite primary keys, // please use Where() to specify the filtering condition. func (s *SelectQuery) Model(pk, model interface{}) error { - t := reflect.TypeOf(model) - if t.Kind() == reflect.Ptr { - t = t.Elem() - } + v := indirect(reflect.ValueOf(model)) + t := v.Type() if t.Kind() != reflect.Struct { return VarTypeError("must be a pointer to a struct") } diff --git a/struct.go b/struct.go index f29775e..a28e3cf 100644 --- a/struct.go +++ b/struct.go @@ -78,14 +78,14 @@ func getStructInfo(a reflect.Type, mapper FieldMapFunc) *structInfo { } func newStructValue(model interface{}, mapper FieldMapFunc) *structValue { - value := reflect.ValueOf(model) - if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Struct || value.IsNil() { + value := indirect0(reflect.ValueOf(model), nil, false, true) + if value.Kind() != reflect.Struct { return nil } return &structValue{ - structInfo: getStructInfo(reflect.TypeOf(model).Elem(), mapper), - value: value.Elem(), + structInfo: getStructInfo(value.Type(), mapper), + value: value, tableName: GetTableName(model), } } @@ -237,13 +237,7 @@ func concat(s1, s2 string) string { // indirect dereferences pointers and returns the actual value it points to. // If a pointer is nil, it will be initialized with a new value. func indirect(v reflect.Value) reflect.Value { - for v.Kind() == reflect.Ptr { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - v = v.Elem() - } - return v + return indirect0(v, nil, false, true) } // GetTableName returns the table name corresponding to the given model struct or slice of structs. @@ -257,12 +251,85 @@ func GetTableName(a interface{}) string { } return tm.TableName() } - t := reflect.TypeOf(a) - if t.Kind() == reflect.Ptr { + + tmt := reflect.TypeOf((*TableModel)(nil)).Elem() + v := indirect0(reflect.ValueOf(a), func(v reflect.Value) bool { + // stop once we found something that implements TableModel + return v.Type().Implements(tmt) + + }, false, false) + // may well be our early exit got us here + t := v.Type() + for t.Kind() == reflect.Ptr && !t.Implements(tmt) { t = t.Elem() } + if t.Implements(tmt) { + return GetTableName(reflect.New(t).Elem().Interface()) + } + // a slice may have elements of the type we're searching for if t.Kind() == reflect.Slice { - return GetTableName(reflect.Zero(t.Elem()).Interface()) + return GetTableName(reflect.New(t.Elem()).Elem().Interface()) + } + // or a struct where we derive the name from + if t.Kind() != reflect.Struct { + // otherwise we can't do anything with that value + return "" } return DefaultFieldMapFunc(t.Name()) } + +// resolves the given value through pointers and interfaces, returning the actual value found. +// When init == true it initializes empty pointers on the way down. Once the match function returns +// true when called with the current value and type handled, the process is stopped early. +func indirect0(v reflect.Value, match func(v reflect.Value) bool, lastMatching, init bool) reflect.Value { + // need to use Value.Elem() to get through interfaces + found := v + aborted := false +FOR: + for v.IsValid() { + // continue? + if match != nil && match(v) { + found = v + aborted = true + if !lastMatching { + break FOR + } + } + // inspect type + k := v.Kind() + switch { + case k == reflect.Ptr: + if v.IsNil() { + if init && v.CanSet() { + v.Set(reflect.New(v.Type().Elem())) + } else { + v = reflect.New(v.Type().Elem()) + } + } + v = v.Elem() + case k == reflect.Interface: + if v.IsNil() { + break FOR + } + v = v.Elem() + case k == reflect.Slice: + if v.IsNil() { + if init && v.CanSet() { + v.Set(reflect.MakeSlice(reflect.SliceOf(v.Type().Elem()), 0, 0)) + } else { + v = reflect.MakeSlice(reflect.SliceOf(v.Type().Elem()), 0, 0) + } + } + break FOR + default: + break FOR + } + } + + // done + if aborted { + return found + } else { + return v + } +} From 8daae1a99d75ea046461b230507a3b17893c5498 Mon Sep 17 00:00:00 2001 From: Tobias Sprute Date: Thu, 21 Jun 2018 18:30:38 +0200 Subject: [PATCH 3/4] no Goland files, no vendor files --- .gitignore | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index fb6c1ae..3f4e7fc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ +# Goland files .idea +/*.iml # Compiled Object files, Static and Dynamic libs (Shared Objects) *.o @@ -24,3 +26,7 @@ _testmain.go *.exe *.test *.prof + +# No vendoring +/vendor + From 90cec8f60e679b63ac4fc467ea46d33ed62b8331 Mon Sep 17 00:00:00 2001 From: Tobias Sprute Date: Sun, 22 Jul 2018 21:20:21 +0200 Subject: [PATCH 4/4] use proper package in example_test.go --- example_test.go | 80 ++++++++++++++++++++++++------------------------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/example_test.go b/example_test.go index 5b0f864..cf87357 100644 --- a/example_test.go +++ b/example_test.go @@ -1,14 +1,12 @@ -package dbx_test +package dbx import ( "fmt" - - "github.com/go-ozzo/ozzo-dbx" ) // This example shows how to populate DB data in different ways. func Example_dbQueries() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") // create a new query q := db.NewQuery("SELECT id, name FROM users LIMIT 10") @@ -26,7 +24,7 @@ func Example_dbQueries() { q.One(&user) // fetch a single row into a string map - data := dbx.NullStringMap{} + data := NullStringMap{} q.One(data) // fetch row by row @@ -40,13 +38,13 @@ func Example_dbQueries() { // This example shows how to use query builder to build DB queries. func Example_queryBuilder() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") // build a SELECT query // SELECT `id`, `name` FROM `users` WHERE `name` LIKE '%Charles%' ORDER BY `id` q := db.Select("id", "name"). From("users"). - Where(dbx.Like("name", "Charles")). + Where(Like("name", "Charles")). OrderBy("id") // fetch all rows into a struct array @@ -57,39 +55,39 @@ func Example_queryBuilder() { // build an INSERT query // INSERT INTO `users` (name) VALUES ('James') - db.Insert("users", dbx.Params{ + db.Insert("users", Params{ "name": "James", }).Execute() } // This example shows how to use query builder in transactions. func Example_transactions() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") - db.Transactional(func(tx *dbx.Tx) error { - _, err := tx.Insert("user", dbx.Params{ + db.Transactional(func(tx *Tx) error { + _, err := tx.Insert("user", Params{ "name": "user1", }).Execute() if err != nil { return err } - _, err = tx.Insert("user", dbx.Params{ + _, err = tx.Insert("user", Params{ "name": "user2", }).Execute() return err }) } -type Customer struct { +type TestCustomer struct { ID string Name string } // This example shows how to do CRUD operations. func Example_crudOperations() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") - var customer Customer + var customer TestCustomer // read a customer: SELECT * FROM customer WHERE id=100 db.Select().Model(100, &customer) @@ -105,18 +103,18 @@ func Example_crudOperations() { } func ExampleSchemaBuilder() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") - db.Insert("users", dbx.Params{ + db.Insert("users", Params{ "name": "James", "age": 30, }).Execute() } func ExampleRows_ScanMap() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") - user := dbx.NullStringMap{} + user := NullStringMap{} sql := "SELECT id, name FROM users LIMIT 10" rows, _ := db.NewQuery(sql).Rows() @@ -127,7 +125,7 @@ func ExampleRows_ScanMap() { } func ExampleRows_ScanStruct() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") var user struct { ID, Name string @@ -142,7 +140,7 @@ func ExampleRows_ScanStruct() { } func ExampleQuery_All() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") sql := "SELECT id, name FROM users LIMIT 10" // fetches data into a slice of struct @@ -152,7 +150,7 @@ func ExampleQuery_All() { db.NewQuery(sql).All(&users) // fetches data into a slice of NullStringMap - var users2 []dbx.NullStringMap + var users2 []NullStringMap db.NewQuery(sql).All(&users2) for _, user := range users2 { fmt.Println(user["id"].String, user["name"].String) @@ -160,7 +158,7 @@ func ExampleQuery_All() { } func ExampleQuery_One() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") sql := "SELECT id, name FROM users LIMIT 10" // fetches data into a struct @@ -170,13 +168,13 @@ func ExampleQuery_One() { db.NewQuery(sql).One(&user) // fetches data into a NullStringMap - var user2 dbx.NullStringMap + var user2 NullStringMap db.NewQuery(sql).All(user2) fmt.Println(user2["id"].String, user2["name"].String) } func ExampleQuery_Row() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") sql := "SELECT id, name FROM users LIMIT 10" // fetches data into a struct @@ -192,7 +190,7 @@ func ExampleQuery_Rows() { ID, Name string } - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") sql := "SELECT id, name FROM users LIMIT 10" rows, _ := db.NewQuery(sql).Rows() @@ -207,11 +205,11 @@ func ExampleQuery_Bind() { ID, Name string } - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") sql := "SELECT id, name FROM users WHERE age>{:age} AND status={:status}" q := db.NewQuery(sql) - q.Bind(dbx.Params{"age": 30, "status": 1}).One(&user) + q.Bind(Params{"age": 30, "status": 1}).One(&user) } func ExampleQuery_Prepare() { @@ -219,19 +217,19 @@ func ExampleQuery_Prepare() { ID, Name string } - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") sql := "SELECT id, name FROM users WHERE age>{:age} AND status={:status}" q := db.NewQuery(sql).Prepare() defer q.Close() - q.Bind(dbx.Params{"age": 30, "status": 1}).All(&users1) - q.Bind(dbx.Params{"age": 20, "status": 1}).All(&users2) - q.Bind(dbx.Params{"age": 10, "status": 1}).All(&users3) + q.Bind(Params{"age": 30, "status": 1}).All(&users1) + q.Bind(Params{"age": 20, "status": 1}).All(&users2) + q.Bind(Params{"age": 10, "status": 1}).All(&users3) } func ExampleDB() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") // queries data through a plain SQL var users []struct { @@ -240,36 +238,36 @@ func ExampleDB() { db.NewQuery("SELECT id, name FROM users WHERE age=30").All(&users) // queries data using query builder - db.Select("id", "name").From("users").Where(dbx.HashExp{"age": 30}).All(&users) + db.Select("id", "name").From("users").Where(HashExp{"age": 30}).All(&users) // executes a plain SQL - db.NewQuery("INSERT INTO users (name) SET ({:name})").Bind(dbx.Params{"name": "James"}).Execute() + db.NewQuery("INSERT INTO users (name) SET ({:name})").Bind(Params{"name": "James"}).Execute() // executes a SQL using query builder - db.Insert("users", dbx.Params{"name": "James"}).Execute() + db.Insert("users", Params{"name": "James"}).Execute() } func ExampleDB_Open() { - db, err := dbx.Open("mysql", "user:pass@/example") + db, err := Open("mysql", "user:pass@/example") if err != nil { panic(err) } - var users []dbx.NullStringMap + var users []NullStringMap if err := db.NewQuery("SELECT * FROM users LIMIT 10").All(&users); err != nil { panic(err) } } func ExampleDB_Begin() { - db, _ := dbx.Open("mysql", "user:pass@/example") + db, _ := Open("mysql", "user:pass@/example") tx, _ := db.Begin() - _, err1 := tx.Insert("user", dbx.Params{ + _, err1 := tx.Insert("user", Params{ "name": "user1", }).Execute() - _, err2 := tx.Insert("user", dbx.Params{ + _, err2 := tx.Insert("user", Params{ "name": "user2", }).Execute()