Skip to content

Commit

Permalink
feat: 为 upsert 增加 unique 支持
Browse files Browse the repository at this point in the history
  • Loading branch information
hengwei-test committed May 11, 2024
1 parent 7cdec6b commit 17289ca
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 26 deletions.
35 changes: 30 additions & 5 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,12 +493,37 @@ func GenerateUpsertSQL(dbType Dialect, mapper *Mapper, rType reflect.Type, keyNa
keyFields = incrFields
}
} else {
for idx := range keyNames {
fi, _, err := toFieldName(structType, keyNames[idx], nil)
if err != nil {
return "", errors.New("upsert isnot generate, " + err.Error())
uniqueNameOk := false
if len(keyNames) == 1 {
fmt.Println("=======1", keyNames[0])
for _, field := range structType.Index {
if _, ok := field.Options["autoincr"]; ok {
continue
}
// if _, ok := field.Options["pk"]; ok {
// continue
// }

key, ok := field.Options["unique"];
fmt.Println("=======2", field.Name, key, ok, field.Options)
if ok {
fmt.Println("=======", keyNames[0], key, field.Name)
if strings.EqualFold(keyNames[0], key) {
keyFields = append(keyFields, field)
uniqueNameOk = true
}
}
}
}

if !uniqueNameOk {
for idx := range keyNames {
fi, _, err := toFieldName(structType, keyNames[idx], nil)
if err != nil {
return "", errors.New("upsert isnot generate, " + err.Error())
}
keyFields = append(keyFields, fi)
}
keyFields = append(keyFields, fi)
}
}

Expand Down
16 changes: 16 additions & 0 deletions cmd/gobatis/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"os/exec"
"path/filepath"
"reflect"
"runtime/debug"
"strings"
"text/template"

Expand Down Expand Up @@ -599,6 +600,12 @@ var funcs = template.FuncMap{
"pluralize": Pluralize,
"camelizeDownFirst": CamelizeDownFirst,
"isType": func(typ goparser2.Type, excepted string, or ...string) bool {
defer func() {
if o := recover(); o != nil {
fmt.Println(o)
debug.PrintStack()
}
}()
return typ.IsExceptedType(excepted, or...)
},
// "isStructType": goparser2.IsStructType,
Expand Down Expand Up @@ -767,6 +774,9 @@ func initInsertFunc() {
{{- /* if eq $.var_param_length 1 */}}
{{- /* $upsertKeys = $.method.ReadFieldNames "On" */}}
{{- /* end */}}
{{- if not $upsertKeys }}
{{- $upsertKeys = $.method.ReadByNameForUpsert }}
{{- end }}
{{- range $idx, $paramName := $upsertKeys}}
"{{$paramName}}",
Expand All @@ -781,6 +791,9 @@ func initInsertFunc() {
{{- /* if eq $.var_param_length 1 */}}
{{- /* $upsertKeys = $.method.ReadFieldNames "On" */}}
{{- /* end */}}
{{- if not $upsertKeys }}
{{- $upsertKeys = $.method.ReadByNameForUpsert }}
{{- end }}
{{- $exists := false}}
{{- range $idx, $a := $upsertKeys }}
Expand All @@ -805,6 +818,9 @@ func initInsertFunc() {
{{- /* if eq $.var_param_length 1 */}}
{{- /* $upsertKeys = $.method.ReadFieldNames "On" */}}
{{- /* end */}}
{{- if not $upsertKeys }}
{{- $upsertKeys = $.method.ReadByNameForUpsert }}
{{- end }}
{{- $exists := false}}
{{- range $idx, $a := $upsertKeys }}
Expand Down
11 changes: 10 additions & 1 deletion cmd/gobatis/goparser2/astutil/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -891,9 +891,18 @@ func (v *funcDeclVisitor) Visit(n ast.Node) ast.Visitor {

var name string
if star, ok := rn.Recv.List[0].Type.(*ast.StarExpr); ok {
name = star.X.(*ast.Ident).Name
switch x := star.X.(type) {
case *ast.Ident:
name = x.Name
case *ast.IndexExpr:
name = x.X.(*ast.Ident).Name
default:
log.Fatalln(fmt.Errorf("func.recv is unknown type - %T(%s)", rn.Recv.List[0].Type, rn.Recv.List[0].Type))
}
} else if ident, ok := rn.Recv.List[0].Type.(*ast.Ident); ok {
name = ident.Name
} else if ident, ok := rn.Recv.List[0].Type.(*ast.IndexExpr); ok {
name = ident.X.(*ast.Ident).Name
} else {
log.Fatalln(fmt.Errorf("func.recv is unknown type - %T", rn.Recv.List[0].Type))
}
Expand Down
14 changes: 14 additions & 0 deletions cmd/gobatis/goparser2/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ func (m *Method) ReadFieldNames(sep string) []string {
return params
}

func (m *Method) ReadByNameForUpsert() []string {
sep := "By"
pos := strings.Index(m.Name, sep)
if pos < 0 {
return nil
}
keyStr := m.Name[pos+len(sep):]
if keyStr == "" {
return nil
}

return strings.Split(keyStr, sep)
}

func (m *Method) IsOneParam() bool {
count := 0

Expand Down
3 changes: 0 additions & 3 deletions cmd/gobatis/goparser2/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,13 @@ func Parse(ctx *ParseContext, filename string) (*File, error) {
return nil, errors.New("load document of " + astFile.TypeList[idx].Name + " fail: namespace invalid syntex")
}
customNamespace = strings.TrimSpace(customNamespace)
fmt.Println("customNamespace", customNamespace)
break
}

if strings.HasPrefix(commentText, "@gobatis.namespace ") {
useNamespace = true
customNamespace = strings.TrimPrefix(commentText, "@gobatis.namespace ")
customNamespace = strings.TrimSpace(customNamespace)
fmt.Println("customNamespace", customNamespace)
break
}

Expand All @@ -208,7 +206,6 @@ func Parse(ctx *ParseContext, filename string) (*File, error) {
useNamespace = true
customNamespace = strings.TrimPrefix(commentText, "@gobatis.namespace\t")
customNamespace = strings.TrimSpace(customNamespace)
fmt.Println("customNamespace", customNamespace)
break
}
}
Expand Down
39 changes: 22 additions & 17 deletions core/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,7 @@ var xormkeyTags = map[string]struct{}{
"version": {},
"default": {},
"json": {},
"jsonb": {},
"bit": {},
"tinyint": {},
"smallint": {},
Expand All @@ -1109,43 +1110,47 @@ var xormkeyTags = map[string]struct{}{
"decimal": {},
"numeric": {},
"tinyblob": {},
"clob": {},
"blob": {},
"mediumblob": {},
"longblob": {},
"bytea": {},
"bool": {},
"serial": {},
"bigserial": {},
}

func TagSplitForXORM(s string, fieldName string) []string {
parts := strings.Fields(s)
if len(parts) == 0 {
return parts
}
name := parts[0]
idx := strings.IndexByte(name, '(')
if idx >= 0 {
name = name[:idx]
}

if _, ok := xormkeyTags[name]; !ok {
return parts
}

for i := 1; i < len(parts); i++ {
// name := parts[0]
// idx := strings.IndexByte(name, '(')
// if idx >= 0 {
// name = name[:idx]
// }

// if _, ok := xormkeyTags[name]; !ok {
// return parts
// }

hasFieldName := false
for i := 0; i < len(parts); i++ {
name := parts[i]

idx := strings.IndexByte(name, '(')
if idx >= 0 {
name = name[:idx]
// unique(xxxx) 改成 unique=xxxx
parts[i] = name[:idx] + "=" + strings.TrimSuffix(name[idx+1:], ")")
}

if _, ok := xormkeyTags[name]; !ok {
tmp := parts[i]
parts[i] = parts[0]
parts[0] = tmp
return parts
if !hasFieldName {
fieldName = parts[i]
}
// parts[i] = parts[0]
// parts[0] = tmp
// return parts
}
}

Expand Down

0 comments on commit 17289ca

Please sign in to comment.