-
Notifications
You must be signed in to change notification settings - Fork 1
/
manager.go
91 lines (81 loc) · 2.13 KB
/
manager.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
package sqlmanager
import (
	"log"
	"reflect"
	"runtime"
	"sort"
	"strings"
	"text/template"
)
// SqlManager collects SQL templates from registered drivers and parses them
// into a text/template set, using the functions registered via RegisterFunc.
type SqlManager struct {
	sqlTemples []SqlTemple        // templates gathered by the most recent Load
	drivers    map[string]Driver  // registered drivers, keyed by DriverName()
	funcs      template.FuncMap   // template functions added via RegisterFunc
	tpl        *template.Template // templates parsed by Load
}
// New returns an empty SqlManager with no drivers and no template functions.
// Register drivers with Use and functions with RegisterFunc, then call Load.
func New() *SqlManager {
	return &SqlManager{
		funcs:   make(template.FuncMap),
		drivers: map[string]Driver{},
	}
}
// Use registers plugin under its DriverName so that Load will read templates
// from it. A driver registered under a name that is already taken replaces
// the previous one, and a warning is logged.
func (sm *SqlManager) Use(plugin Driver) {
	name := plugin.DriverName()
	if _, taken := sm.drivers[name]; taken {
		log.Printf("sqlmanager - WARN: %s already used\n", name)
	}
	sm.drivers[name] = plugin
}
// Load asks every registered driver for its SQL templates and parses them
// into a fresh template set rooted at "_root_". The functions registered via
// RegisterFunc are available to every template. Functions registered after
// Load returns are not seen by the parsed templates.
//
// Drivers are visited in ascending DriverName order. When two templates share
// a name, the later one wins deterministically: it replaces the earlier one
// both in the parsed template set and in the stored template list, and a
// warning is logged.
//
// Load panics if a driver returns an error or if any SQL fails to parse.
// sm.tpl and the stored template list are only replaced after every driver
// has loaded successfully.
func (sm *SqlManager) Load() {
	root := template.New("_root_")
	root.Funcs(sm.funcs)

	// Map iteration order is random; sort the names so that duplicate
	// resolution does not change from run to run.
	names := make([]string, 0, len(sm.drivers))
	for name := range sm.drivers {
		names = append(names, name)
	}
	sort.Strings(names)

	var sqltpls []SqlTemple
	// Maps a template name to its position in sqltpls. Duplicates are checked
	// against this Load's results, not against a previous Load.
	pos := make(map[string]int)
	for _, name := range names {
		driver := sm.drivers[name]
		sqls, err := driver.Load()
		if err != nil {
			log.Printf("sqlmanager - ERROR: %s load failed: ", driver.DriverName())
			log.Panicln(err)
		}
		for _, sql := range sqls {
			if i, dup := pos[sql.Name]; dup {
				log.Printf("sqlmanager - WARN: %s Has duplicate sql: It will be cover [%s-%s] with [ %s ]\n", driver.DriverName(), sql.Name, strings.ReplaceAll(sqltpls[i].Sql, "\n", ""), strings.ReplaceAll(sql.Sql, "\n", ""))
				// Keep the list consistent with the template set, where the
				// later definition overrides the earlier one.
				sqltpls[i] = sql
			} else {
				pos[sql.Name] = len(sqltpls)
				sqltpls = append(sqltpls, sql)
			}
			// Parsing under an existing name redefines that template in the set.
			if _, err := root.New(sql.Name).Parse(sql.Sql); err != nil {
				log.Printf("sqlmanager - ERROR: %s failed to parse sql %s\n", driver.DriverName(), sql.Name)
				panic(err)
			}
		}
		log.Printf("sqlmanager - INFO: %s loaded %d sqls.\n", driver.DriverName(), len(sqls))
	}
	sm.tpl = root
	sm.sqlTemples = sqltpls
}
// findTpl looks up the first stored template whose Name equals name.
// The returned pointer refers to a copy of the stored entry, so changes made
// through it do not affect sm.sqlTemples. It returns (nil, false) when no
// entry matches.
func (sm *SqlManager) findTpl(name string) (*SqlTemple, bool) {
	for i := range sm.sqlTemples {
		if sm.sqlTemples[i].Name != name {
			continue
		}
		found := sm.sqlTemples[i]
		return &found, true
	}
	return nil, false
}
// RegisterFunc merges funcs into the manager's function map, which Load
// installs on the template set. Functions must be registered before Load so
// templates can reference them. An entry whose name is already registered
// replaces the earlier function, and a warning naming both is logged.
func (sm *SqlManager) RegisterFunc(funcs template.FuncMap) {
	for name, fn := range funcs {
		prev, exists := sm.funcs[name]
		if exists {
			log.Printf("sqlmanager - WARN: %s Has duplicate func: It will be cover [%s] with [%s]\n", name, getFunctionName(prev), getFunctionName(fn))
		}
		sm.funcs[name] = fn
	}
}
// getFunctionName returns the fully qualified name of the function held in i,
// for use in log messages (for example "main.getFunctionName").
//
// It never panics:
//   - a nil interface yields "<nil>";
//   - a non-function value yields its type name;
//   - a nil function value yields "".
//
// The original panicked in reflect.Value.Pointer for the first two cases.
func getFunctionName(i interface{}) string {
	v := reflect.ValueOf(i)
	if !v.IsValid() {
		return "<nil>"
	}
	if v.Kind() != reflect.Func {
		return v.Type().String()
	}
	// FuncForPC returns nil for an unknown PC (e.g. a nil func); (*Func).Name
	// handles a nil receiver by returning "".
	return runtime.FuncForPC(v.Pointer()).Name()
}
// Driver is a source of SQL templates for SqlManager.
type Driver interface {
	// DriverName identifies the driver. SqlManager.Use keys registered
	// drivers by this value, and it appears in log messages.
	DriverName() string
	// Load returns the SQL templates this driver provides. A non-nil error
	// makes SqlManager.Load panic.
	Load() ([]SqlTemple, error)
}
// SqlTemple is a single named SQL template supplied by a Driver.
type SqlTemple struct {
	Name        string // name; SqlManager.Load registers the template under this name
	Description string // description
	Sql         string // SQL text, parsed as a text/template by SqlManager.Load
}