Skip to content

Commit

Permalink
Merge branch 'master' of github.com:runner-mei/GoBatis
Browse files Browse the repository at this point in the history
  • Loading branch information
hengwei-test committed Jun 11, 2024
2 parents d0c5607 + 12a2791 commit a12bbad
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 18 deletions.
2 changes: 1 addition & 1 deletion core/mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2223,7 +2223,7 @@ func TestMapperE(t *testing.T) {
})
}

func TestMapperF(t *testing.T) {
func TestMapperSimple(t *testing.T) {
tests.Run(t, func(_ testing.TB, factory *core.Session) {
ref := factory.SessionReference()
itest := tests.NewITest(ref)
Expand Down
161 changes: 144 additions & 17 deletions core/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,20 @@ func isScannable(mapper *Mapper, t reflect.Type) bool {
return false
}

type ColumnInfoGetter interface {
Columns() ([]string, error)
ColumnTypes() ([]*sql.ColumnType, error)
}

// colScanner is an interface used by MapScan and SliceScan
type colScanner interface {
Columns() ([]string, error)
ColumnInfoGetter
Scan(dest ...interface{}) error
Err() error
}

type rowsi interface {
Columns() ([]string, error)
ColumnInfoGetter
Err() error
Next() bool
Scan(...interface{}) error
Expand Down Expand Up @@ -262,26 +267,148 @@ func scanAll(dialect Dialect, mapper *Mapper, rows rowsi, dest interface{}, stru
return rows.Err()
}

func defaultScanValue() (ptrValue interface{}, valueGet func() interface{}) {
var value interface{}
return &value, func() interface{} {
return value
}
}

func makeColumnValue(name string, columnType *sql.ColumnType) func() (ptrValue interface{}, valueGet func() interface{}) {
if columnType == nil {
return defaultScanValue
}

switch strings.ToLower(columnType.DatabaseTypeName()) {
case "tinyint", "smallint", "mediumint", "int", "bigint", "integer", "biginteger", "smallserial", "serial", "bigserial", "int1", "int2", "int3", "int4", "int8":
return func() (ptrValue interface{}, valueGet func() interface{}) {
var value sql.NullInt64
return &value, func() interface{} {
if value.Valid {
return value.Int64
}
return nil
}
}
case "float", "float4", "float8", "double", "decimal", "numeric", "real", "double precision":
return func() (ptrValue interface{}, valueGet func() interface{}) {
var value sql.NullFloat64
return &value, func() interface{} {
if value.Valid {
return value.Float64
}
return nil
}
}
case "varchar", "char", "text", "tinytext", "longtext", "mediumtext", "character varying", "character":
//if nullable, ok := columnType.Nullable()

return func() (ptrValue interface{}, valueGet func() interface{}) {
var value sql.NullString
return &value, func() interface{} {
if value.Valid {
return value.String
}
return nil
}
}
case "boolean", "bool", "bit":
//if nullable, ok := columnType.Nullable()

return func() (ptrValue interface{}, valueGet func() interface{}) {
var value sql.NullBool
return &value, func() interface{} {
if value.Valid {
return value.Bool
}
return nil
}
}
}

t := columnType.ScanType()
if t != nil {
return func() (ptrValue interface{}, valueGet func() interface{}) {
var value = reflect.New(t)
var nullable Nullable
nullable.Value = value.Interface()
return &nullable, func() interface{} {
return value.Elem().Interface()
}
}
}

return defaultScanValue
}

func createNewRecord(rows ColumnInfoGetter, columns []string) (func() (ptrArray []interface{}, valueArray []func() interface{}), error) {
if len(columns) == 0 {
names, err := rows.Columns()
if err != nil {
return nil, err
}
columns = names
}

newRecord := func() ([]interface{}, []func() interface{}) {
values := make([]interface{}, len(columns))
ptrValues := make([]interface{}, len(columns))
valueArray := make([]func() interface{}, len(columns))
for i := range values {
ptrValues[i] = &values[i]
valueArray[i] = func(idx int) func() interface{} {
return func() interface{} {
return values[idx]
}
}(i)
}
return ptrValues, valueArray
}
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
if len(columnTypes) == len(columns) {
valueArray := make([]func() (ptrValue interface{}, valueGet func() interface{}), len(columns))

for i := range columns {
valueArray[i] = makeColumnValue(columns[i], columnTypes[i])
}

newRecord = func() ([]interface{}, []func() interface{}) {
ptrValues := make([]interface{}, len(valueArray))
valueGetArray := make([]func() interface{}, len(valueArray))
for i := range valueArray {
ptrValues[i], valueGetArray[i] = valueArray[i]()
}
return ptrValues, valueGetArray
}
}
return newRecord, nil
}

func scanMapSlice(dialect Dialect, rows rowsi, dest *[]map[string]interface{}) error {
columns, err := rows.Columns()
if err != nil {
return err
}

newRecord, err := createNewRecord(rows, columns)
if err != nil {
return err
}

for rows.Next() {
values := make([]interface{}, len(columns))
for i := range values {
values[i] = new(interface{})
}
ptrArray, valueArray := newRecord()

err = rows.Scan(values...)
err = rows.Scan(ptrArray...)
if err != nil {
return errors.New("Scan into Map(" + strings.Join(columns, ",") + ") error : " + err.Error())
}

one := map[string]interface{}{}
for i, column := range columns {
one[column] = *(values[i].(*interface{})) // nolint: forcetypeassert
one[column] = valueArray[i]()
}
*dest = append(*dest, one)
}
Expand Down Expand Up @@ -309,28 +436,28 @@ func StructScan(dialect Dialect, mapper *Mapper, rows rowsi, dest interface{}, i
// This will modify the map sent to it in place, so reuse the same map with
// care. Columns which occur more than once in the result will overwrite
// each other!
func MapScan(dialect Dialect, r colScanner, dest map[string]interface{}) error {
// ignore r.started, since we needn't use reflect for anything.
columns, err := r.Columns()
func MapScan(dialect Dialect, row colScanner, dest map[string]interface{}) error {
columns, err := row.Columns()
if err != nil {
return err
}

values := make([]interface{}, len(columns))
for i := range values {
values[i] = new(interface{})
newRecord, err := createNewRecord(row, columns)
if err != nil {
return err
}

err = r.Scan(values...)
ptrArray, valueArray := newRecord()
err = row.Scan(ptrArray...)
if err != nil {
return errors.New("Scan into Map(" + strings.Join(columns, ",") + ") error : " + err.Error())
}

for i, column := range columns {
dest[column] = *(values[i].(*interface{}))
dest[column] = valueArray[i]()
}

return r.Err()
return row.Err()
}

// reflect helpers
Expand Down
1 change: 1 addition & 0 deletions dialects/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ var (
newBlob: newBlob,
makeArrayValuer: makeArrayValuer,
makeArrayScanner: makeArrayScanner,
limitFunc: limitByLimitMN,
}
MSSql Dialect = &dialect{
name: "mssql",
Expand Down

0 comments on commit a12bbad

Please sign in to comment.