From 339aacfe9eca16bb725824f470a8c4f2c2ee1c4e Mon Sep 17 00:00:00 2001 From: qiangxue Date: Wed, 3 Aug 2016 15:33:10 -0400 Subject: [PATCH] Fixes #19: exposes GetTableName() method --- model_query.go | 3 +-- select.go | 4 ++-- struct.go | 9 +++++---- struct_test.go | 16 ++++++++-------- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/model_query.go b/model_query.go index 80a89ca..b08f239 100644 --- a/model_query.go +++ b/model_query.go @@ -54,7 +54,6 @@ func (q *ModelQuery) Insert(attrs ...string) error { if q.lastError != nil { return q.lastError } - tableName := q.model.tableName cols := q.model.columns(attrs, q.exclude) pkName := "" for name, value := range q.model.pk() { @@ -65,7 +64,7 @@ func (q *ModelQuery) Insert(attrs ...string) error { } } - result, err := q.builder.Insert(tableName, Params(cols)).Execute() + result, err := q.builder.Insert(q.model.tableName, Params(cols)).Execute() if err == nil && pkName != "" { pkValue, err := result.LastInsertId() if err != nil { diff --git a/select.go b/select.go index 3fd3cac..94cc3d3 100644 --- a/select.go +++ b/select.go @@ -271,7 +271,7 @@ func (s *SelectQuery) Build() *Query { // Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned. func (s *SelectQuery) One(a interface{}) error { if len(s.from) == 0 { - if tableName := getTableName(a); tableName != "" { + if tableName := GetTableName(a); tableName != "" { s.from = []string{tableName} } } @@ -312,7 +312,7 @@ func (s *SelectQuery) Model(pk, model interface{}) error { // or the TableName() method if the slice element implements the TableModel interface. func (s *SelectQuery) All(slice interface{}) error { if len(s.from) == 0 { - if tableName := getTableName(slice); tableName != "" { + if tableName := GetTableName(slice); tableName != "" { s.from = []string{tableName} } } diff --git a/struct.go b/struct.go index 1aa9790..79969d3 100644 --- a/struct.go +++ b/struct.go @@ -86,7 +86,7 @@ func newStructValue(model interface{}, mapper FieldMapFunc) *structValue { return &structValue{ structInfo: getStructInfo(reflect.TypeOf(model).Elem(), mapper), value: value.Elem(), - tableName: getTableName(model), + tableName: GetTableName(model), } } @@ -243,8 +243,9 @@ func indirect(v reflect.Value) reflect.Value { return v } -// getTableName returns the table name corresponding to the given model struct or slice of structs. -func getTableName(a interface{}) string { +// GetTableName returns the table name corresponding to the given model struct or slice of structs. +// Do not call this method in the model's TableName() method, or it will cause infinite loop. +func GetTableName(a interface{}) string { if tm, ok := a.(TableModel); ok { v := reflect.ValueOf(a) if v.Kind() == reflect.Ptr && v.IsNil() { @@ -258,7 +259,7 @@ func getTableName(a interface{}) string { t = t.Elem() } if t.Kind() == reflect.Slice { - return getTableName(reflect.Zero(t.Elem()).Interface()) + return GetTableName(reflect.Zero(t.Elem()).Interface()) } return DefaultFieldMapFunc(t.Name()) } diff --git a/struct_test.go b/struct_test.go index 4a6a39b..4a839d1 100644 --- a/struct_test.go +++ b/struct_test.go @@ -96,28 +96,28 @@ type MyCustomer struct{} func Test_getTableName(t *testing.T) { var c1 Customer - assert.Equal(t, "customer", getTableName(c1)) + assert.Equal(t, "customer", GetTableName(c1)) var c2 *Customer - assert.Equal(t, "customer", getTableName(c2)) + assert.Equal(t, "customer", GetTableName(c2)) var c3 MyCustomer - assert.Equal(t, "my_customer", getTableName(c3)) + assert.Equal(t, "my_customer", GetTableName(c3)) var c4 []Customer - assert.Equal(t, "customer", getTableName(c4)) + assert.Equal(t, "customer", GetTableName(c4)) var c5 *[]Customer - assert.Equal(t, "customer", getTableName(c5)) + assert.Equal(t, "customer", GetTableName(c5)) var c6 []MyCustomer - assert.Equal(t, "my_customer", getTableName(c6)) + assert.Equal(t, "my_customer", GetTableName(c6)) var c7 []CustomerPtr - assert.Equal(t, "customer", getTableName(c7)) + assert.Equal(t, "customer", GetTableName(c7)) var c8 **int - assert.Equal(t, "", getTableName(c8)) + assert.Equal(t, "", GetTableName(c8)) } type FA struct {