Skip to content

Commit

Permalink
feat: database driver
Browse files Browse the repository at this point in the history
  • Loading branch information
ystyle committed May 31, 2023
1 parent a9315a6 commit b4b58b9
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 90 deletions.
42 changes: 41 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,14 @@ a library for manager sql with markdown or constant. and you can custom sql stor
// },
// })
sm.Load()
sql, err := sm.RenderTPL("GetStudentByID", 1)
sql, err := sm.RenderTPL("test/GetStudentByID", 1)
if err != nil {
panic(err)
}
fmt.Println(sql)
// select * from student where id = 1

sql, err = sm.RenderTPL("test2/GetStudentByID2", 1)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -94,6 +101,39 @@ a library for manager sql with markdown or constant. and you can custom sql stor
sm.load()
sql, _ := sm.RenderTPL("GetStudentByID", 1)
```
4. Database Driver: store in database
>you can custom the table name, and delete row should set deleted = 1, deleted_at = now(). you can build a interface to manage your sqls in your product.
>field: name, deleted,deleted_at,description,sql is required.
>load sql only on call `sm.load()`, it mean you should run `sm.load()` after changes
```sql
create table sql_manager
(
id int unsigned auto_increment primary key,
name varchar(255) null,
deleted int default 0 not null,
deleted_at datetime null,
description varchar(255) null,
`sql` text null,
constraint sql_manager_deleted_name_uindex
unique (deleted, name)
);
INSERT INTO sql_manager (name, deleted, deleted_at, description, `sql`) VALUES ('GetStudentByID', 0, null, 'get student by id, required id', 'select * from student where id = {{.}}');
```
```go
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
)
func main() {
db, _ := sql.Open("mysql", "root:root@tcp(127.0.0.1:3306)/test1")
sm = sqlmanager.New()
sm.Use(sqlmanager.NewDatabaseDriver(db, "sql_manager"))
sm.load()
sql, _ := sm.RenderTPL("test/GetStudentByID", 1)
}
```


### custom puglin
> implement sqlmanager.Driver
Expand Down
38 changes: 38 additions & 0 deletions database_driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package sqlmanager

import (
"database/sql"
"fmt"
)

const sqlRaw = "select name, description, `sql` from %s where deleted_at is null"

type DatabaseDriver struct {
db *sql.DB
tablename string
sqls []SqlTemple
}

func NewDatabaseDriver(db *sql.DB, tablename string) *DatabaseDriver {
return &DatabaseDriver{db: db, tablename: tablename}
}

func (dbd *DatabaseDriver) DriverName() string {
return "database"
}

func (dbd *DatabaseDriver) Load() ([]SqlTemple, error) {
rows, err := dbd.db.Query(fmt.Sprintf(sqlRaw, dbd.tablename))
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var sql SqlTemple
if err := rows.Scan(&sql.Name, &sql.Description, &sql.Sql); err != nil {
return nil, err
}
dbd.sqls = append(dbd.sqls, sql)
}
return dbd.sqls, nil
}
4 changes: 2 additions & 2 deletions dynamic_driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ func TestDynamicLoad(t *testing.T) {
sm := New()
store := NewDynamicDriver()
store.Register("rest1", `select * from table where id = {{. }} or {{ block "rest" . }} {{ end }}`)
store.Register("rest", `select * from table where id = "{{ test . }}"`)
store.Register("rest", `select * from table where id = "{{ upper . }}"`)
sm.Use(store)
sm.RegisterFunc(template.FuncMap{
"test": func(v string) string {
"upper": func(v string) string {
return strings.ToUpper(v)
},
})
Expand Down
41 changes: 3 additions & 38 deletions embed_markdown_driver.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
package sqlmanager

import (
"bytes"
"embed"
"errors"
"fmt"
"github.com/gomarkdown/markdown"
"github.com/gomarkdown/markdown/parser"
"io/fs"
"log"
"path"
Expand Down Expand Up @@ -57,48 +52,18 @@ func (mdd *EmbedMarkdownDriver) Load() ([]SqlTemple, error) {
}

func (mdd *EmbedMarkdownDriver) parseMarkdown(filename string) ([]SqlTemple, error) {
var sqls []SqlTemple
buf, err := mdd.fs.ReadFile(filename)
if err != nil {
log.Printf("sqlmanager - ERROR: %s loading failed...\n", filename)
return nil, err
}
if bytes.ContainsRune(buf, '\r') {
buf = markdown.NormalizeNewlines(buf)
}
psr := parser.New()
node := markdown.Parse(buf, psr)
list := getAll(node)
i := 0
for {
// 1. text, code
// 2. text, text, code
if i >= len(list) {
break
}
var tpl SqlTemple
if list[i].Type == "text" && list[i+1].Type == "code" {
tpl.Name = mdd.getName(filename, list[i].Content)
tpl.Sql = list[i+1].Content
sqls = append(sqls, tpl)
i += 2
} else if list[i].Type == "text" && list[i+1].Type == "text" && list[i+2].Type == "code" {
tpl.Name = mdd.getName(filename, list[i].Content)
tpl.Description = list[i+1].Content
tpl.Sql = list[i+2].Content
sqls = append(sqls, tpl)
i += 3
} else {
return nil, errors.New(fmt.Sprintf("ERROR: parse markdown failed: %s", filename))
}
}
return sqls, nil
return parseMarkdown(buf, mdd.getName(filename))
}

func (mdd *EmbedMarkdownDriver) getName(filename, code string) string {
func (mdd *EmbedMarkdownDriver) getName(filename string) string {
ext := path.Ext(filename)
base := strings.TrimSuffix(filename, ext)
base = strings.TrimPrefix(base, mdd.dir)
base = strings.TrimPrefix(base, "/")
return path.Join(strings.TrimSuffix(base, ext), code)
return base
}
50 changes: 12 additions & 38 deletions file_driver.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
package sqlmanager

import (
"bytes"
"errors"
"fmt"
"github.com/gomarkdown/markdown"
"github.com/gomarkdown/markdown/parser"
"log"
"os"
"path"
"path/filepath"
"strings"
)

type MarkdownDriver struct {
Expand Down Expand Up @@ -44,7 +40,7 @@ func (mdd *MarkdownDriver) Load() ([]SqlTemple, error) {
}
ext := path.Ext(subpath)
if ext == ".md" || ext == ".markdown" {
s, err := parseMarkdown(subpath)
s, err := mdd.parseMarkdown(subpath)
if err != nil {
return err
}
Expand All @@ -60,41 +56,19 @@ func (mdd *MarkdownDriver) Load() ([]SqlTemple, error) {
return sqls, nil
}

func parseMarkdown(filename string) ([]SqlTemple, error) {
var sqls []SqlTemple
func (mdd *MarkdownDriver) parseMarkdown(filename string) ([]SqlTemple, error) {
buf, err := os.ReadFile(filename)
if err != nil {
log.Printf("sqlmanager - ERROR: %s loading failed...\n", filename)
return nil, err
}
if bytes.ContainsRune(buf, '\r') {
buf = markdown.NormalizeNewlines(buf)
}
psr := parser.New()
node := markdown.Parse(bytes.ReplaceAll(buf, []byte("\r"), nil), psr)
list := getAll(node)
i := 0
for {
// 1. text, code
// 2. text, text, code
if i >= len(list) {
break
}
var tpl SqlTemple
if list[i].Type == "text" && list[i+1].Type == "code" {
tpl.Name = list[i].Content
tpl.Sql = list[i+1].Content
sqls = append(sqls, tpl)
i += 2
} else if list[i].Type == "text" && list[i+1].Type == "text" && list[i+2].Type == "code" {
tpl.Name = list[i].Content
tpl.Description = list[i+1].Content
tpl.Sql = list[i+2].Content
sqls = append(sqls, tpl)
i += 3
} else {
return nil, errors.New(fmt.Sprintf("ERROR: parse markdown failed: %s", filename))
}
}
return sqls, nil
return parseMarkdown(buf, mdd.getName(filename))
}

func (mdd *MarkdownDriver) getName(filename string) string {
ext := path.Ext(filename)
base := strings.TrimSuffix(filename, ext)
base = strings.TrimPrefix(base, mdd.dir)
base = strings.TrimPrefix(base, "/")
return base
}
4 changes: 2 additions & 2 deletions file_driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (

func TestLoad(t *testing.T) {
sm := New()
sm.Use(NewMarkdownDriverWithDir("./test-sql"))
sm.Use(NewMarkdownDriverWithDir("test-sql"))
sm.Load()
sql, err := sm.RenderTPL("GetStudentByID2", 1)
sql, err := sm.RenderTPL("test/GetStudentByID2", 1)
if err != nil {
panic(err)
}
Expand Down
16 changes: 7 additions & 9 deletions manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ func (sm *SqlManager) Use(plugin Driver) {
}

func (sm *SqlManager) Load() {
sm.tpl = nil
sm.tpl = template.New("_root_")
sm.tpl.Funcs(sm.funcs)
var sqltpls []SqlTemple
for _, driver := range sm.drivers {
sqls, err := driver.Load()
if err != nil {
Expand All @@ -41,22 +43,18 @@ func (sm *SqlManager) Load() {
for _, sql := range sqls {
d, has := sm.findTpl(sql.Name)
if has {
log.Printf("sqlmanager - WARN: %s Has duplicate sql: It will be cover [%s] with [ %s ]\n", sql.Name, strings.ReplaceAll(d.Sql, "\n", ""), strings.ReplaceAll(sql.Sql, "\n", ""))
}
sm.sqlTemples = append(sm.sqlTemples, sql)
if sm.tpl == nil {
sm.tpl = template.New(sql.Name)
sm.tpl.Funcs(sm.funcs)
} else {
sm.tpl = sm.tpl.New(sql.Name)
log.Printf("sqlmanager - WARN: %s Has duplicate sql: It will be cover [%s-%s] with [ %s ]\n", driver.DriverName(), sql.Name, strings.ReplaceAll(d.Sql, "\n", ""), strings.ReplaceAll(sql.Sql, "\n", ""))
}
sqltpls = append(sqltpls, sql)
sm.tpl = sm.tpl.New(sql.Name)
sm.tpl, err = sm.tpl.Parse(sql.Sql)
if err != nil {
panic(err)
}
}
log.Printf("sqlmanager - INFO: %s loaded %d sqls.\n", driver.DriverName(), len(sqls))
}
sm.sqlTemples = sqltpls
}

func (sm *SqlManager) findTpl(name string) (*SqlTemple, bool) {
Expand Down
39 changes: 39 additions & 0 deletions markdown.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,48 @@
package sqlmanager

import (
"bytes"
"errors"
"github.com/gomarkdown/markdown"
"github.com/gomarkdown/markdown/ast"
"github.com/gomarkdown/markdown/parser"
"path/filepath"
)

func parseMarkdown(buf []byte, prefix string) ([]SqlTemple, error) {
var sqls []SqlTemple
if bytes.ContainsRune(buf, '\r') {
buf = markdown.NormalizeNewlines(buf)
}
psr := parser.New()
node := markdown.Parse(buf, psr)
list := getAll(node)
i := 0
for {
// 1. text, code
// 2. text, text, code
if i >= len(list) {
break
}
var tpl SqlTemple
if list[i].Type == "text" && list[i+1].Type == "code" {
tpl.Name = filepath.Join(prefix, list[i].Content)
tpl.Sql = list[i+1].Content
sqls = append(sqls, tpl)
i += 2
} else if list[i].Type == "text" && list[i+1].Type == "text" && list[i+2].Type == "code" {
tpl.Name = filepath.Join(prefix, list[i].Content)
tpl.Description = list[i+1].Content
tpl.Sql = list[i+2].Content
sqls = append(sqls, tpl)
i += 3
} else {
return nil, errors.New("parse markdown failed")
}
}
return sqls, nil
}

func getAll(node ast.Node) []item {
var list []item
ast.WalkFunc(node, func(node ast.Node, entering bool) ast.WalkStatus {
Expand Down

0 comments on commit b4b58b9

Please sign in to comment.