Skip to content

Commit

Permalink
Fixes #19: exposes GetTableName() method
Browse files Browse the repository at this point in the history
  • Loading branch information
qiangxue committed Aug 3, 2016
1 parent f6ece0f commit 339aacf
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 16 deletions.
3 changes: 1 addition & 2 deletions model_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ func (q *ModelQuery) Insert(attrs ...string) error {
if q.lastError != nil {
return q.lastError
}
tableName := q.model.tableName
cols := q.model.columns(attrs, q.exclude)
pkName := ""
for name, value := range q.model.pk() {
Expand All @@ -65,7 +64,7 @@ func (q *ModelQuery) Insert(attrs ...string) error {
}
}

result, err := q.builder.Insert(tableName, Params(cols)).Execute()
result, err := q.builder.Insert(q.model.tableName, Params(cols)).Execute()
if err == nil && pkName != "" {
pkValue, err := result.LastInsertId()
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func (s *SelectQuery) Build() *Query {
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
func (s *SelectQuery) One(a interface{}) error {
if len(s.from) == 0 {
if tableName := getTableName(a); tableName != "" {
if tableName := GetTableName(a); tableName != "" {
s.from = []string{tableName}
}
}
Expand Down Expand Up @@ -312,7 +312,7 @@ func (s *SelectQuery) Model(pk, model interface{}) error {
// or the TableName() method if the slice element implements the TableModel interface.
func (s *SelectQuery) All(slice interface{}) error {
if len(s.from) == 0 {
if tableName := getTableName(slice); tableName != "" {
if tableName := GetTableName(slice); tableName != "" {
s.from = []string{tableName}
}
}
Expand Down
9 changes: 5 additions & 4 deletions struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func newStructValue(model interface{}, mapper FieldMapFunc) *structValue {
return &structValue{
structInfo: getStructInfo(reflect.TypeOf(model).Elem(), mapper),
value: value.Elem(),
tableName: getTableName(model),
tableName: GetTableName(model),
}
}

Expand Down Expand Up @@ -243,8 +243,9 @@ func indirect(v reflect.Value) reflect.Value {
return v
}

// getTableName returns the table name corresponding to the given model struct or slice of structs.
func getTableName(a interface{}) string {
// GetTableName returns the table name corresponding to the given model struct or slice of structs.
// Do not call this method in the model's TableName() method, or it will cause infinite loop.
func GetTableName(a interface{}) string {
if tm, ok := a.(TableModel); ok {
v := reflect.ValueOf(a)
if v.Kind() == reflect.Ptr && v.IsNil() {
Expand All @@ -258,7 +259,7 @@ func getTableName(a interface{}) string {
t = t.Elem()
}
if t.Kind() == reflect.Slice {
return getTableName(reflect.Zero(t.Elem()).Interface())
return GetTableName(reflect.Zero(t.Elem()).Interface())
}
return DefaultFieldMapFunc(t.Name())
}
16 changes: 8 additions & 8 deletions struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,28 +96,28 @@ type MyCustomer struct{}

func Test_getTableName(t *testing.T) {
var c1 Customer
assert.Equal(t, "customer", getTableName(c1))
assert.Equal(t, "customer", GetTableName(c1))

var c2 *Customer
assert.Equal(t, "customer", getTableName(c2))
assert.Equal(t, "customer", GetTableName(c2))

var c3 MyCustomer
assert.Equal(t, "my_customer", getTableName(c3))
assert.Equal(t, "my_customer", GetTableName(c3))

var c4 []Customer
assert.Equal(t, "customer", getTableName(c4))
assert.Equal(t, "customer", GetTableName(c4))

var c5 *[]Customer
assert.Equal(t, "customer", getTableName(c5))
assert.Equal(t, "customer", GetTableName(c5))

var c6 []MyCustomer
assert.Equal(t, "my_customer", getTableName(c6))
assert.Equal(t, "my_customer", GetTableName(c6))

var c7 []CustomerPtr
assert.Equal(t, "customer", getTableName(c7))
assert.Equal(t, "customer", GetTableName(c7))

var c8 **int
assert.Equal(t, "", getTableName(c8))
assert.Equal(t, "", GetTableName(c8))
}

type FA struct {
Expand Down

0 comments on commit 339aacf

Please sign in to comment.