diff --git a/.gitignore b/.gitignore index ae93b1b..d9335e2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,6 @@ .DS_Store -tmp \ No newline at end of file +.idea +.vscode +tmp +*.db +*.sum \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5763570 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Dai Jie + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 5dee09d..296bff7 100644 --- a/README.md +++ b/README.md @@ -1,219 +1,145 @@ -# 7天用Go从零实现Web框架Gee - -![Gee](doc/gee/gee.jpg) - -Gee 的设计与实现参考了Gin,这个教程可以快速入门:[Go Gin简明教程](https://geektutu.com/post/quick-go-gin.html)。 - -## [教程目录](https://geektutu.com/post/gee.html) - -- [第一天:前置知识(http.Handler接口)](https://geektutu.com/post/gee-day1.html),[Code - Github](day1-http-base) -- [第二天:上下文设计(Context)](https://geektutu.com/post/gee-day2.html),[Code - Github](day2-context) -- [第三天:Tire树路由(Router)](https://geektutu.com/post/gee-day3.html),[Code - Github](day3-router) -- [第四天:分组控制(Group)](https://geektutu.com/post/gee-day4.html),[Code - Github](day4-group) -- [第五天:中间件(Middleware)](https://geektutu.com/post/gee-day5.html),[Code - Github](day5-middleware) -- [第六天:HTML模板(Template)](https://geektutu.com/post/gee-day6.html),[Code - Github](day6-template) -- 第七天:错误恢复(Panic Recover),[Code - Github](day7-panic-recover) - - -## Day 1 - Static Route - -```go -func main() { - r := gee.New() - r.GET("/", func(w http.ResponseWriter, req *http.Request) { - fmt.Fprintf(w, "URL.Path = %q\n", req.URL.Path) - }) - - r.GET("/hello", func(w http.ResponseWriter, req *http.Request) { - for k, v := range req.Header { - fmt.Fprintf(w, "Header[%q] = %q\n", k, v) - } - }) - - r.Run(":9999") -} -``` - -## Day 2 - Context Design - -```go -func main() { - r := gee.New() - r.GET("/", func(c *gee.Context) { - c.HTML(http.StatusOK, "

Hello Gee

") - }) - r.GET("/hello", func(c *gee.Context) { - // expect /hello?name=geektutu - c.String(http.StatusOK, "hello %s, you're at %s\n", c.Query("name"), c.Path) - }) - - r.POST("/login", func(c *gee.Context) { - c.JSON(http.StatusOK, &map[string]string{ - "username": c.PostForm("username"), - "password": c.PostForm("password"), - }) - }) - - r.Run(":9999") -} -``` - -## Day 3 - Dynamic Route - -```go -func main() { - r := gee.New() - r.GET("/", func(c *gee.Context) { - c.HTML(http.StatusOK, "

Hello Gee

") - }) - - r.GET("/hello", func(c *gee.Context) { - // expect /hello?name=geektutu - c.String(http.StatusOK, "hello %s, you're at %s\n", c.Query("name"), c.Path) - }) - - r.GET("/hello/:name", func(c *gee.Context) { - // expect /hello/geektutu - c.String(http.StatusOK, "hello %s, you're at %s\n", c.Param("name"), c.Path) - }) - - r.GET("/assets/*filepath", func(c *gee.Context) { - c.JSON(http.StatusOK, gee.H{"filepath": c.Param("filepath")}) - }) - - r.Run(":9999") -} -``` - -## Day 4 - Nesting Group Control - -```go -func main() { - r := gee.New() - v1 := r.Group("/v1") - { - v1.GET("/", func(c *gee.Context) { - c.HTML(http.StatusOK, "

Hello Gee

") - }) - - v1.GET("/hello", func(c *gee.Context) { - // expect /hello?name=geektutu - c.String(http.StatusOK, "hello %s, you're at %s\n", c.Query("name"), c.Path) - }) - } - v2 := r.Group("/v2") - { - v2.GET("/hello/:name", func(c *gee.Context) { - // expect /hello/geektutu - c.String(http.StatusOK, "hello %s, you're at %s\n", c.Param("name"), c.Path) - }) - v2.POST("/login", func(c *gee.Context) { - c.JSON(http.StatusOK, &map[string]string{ - "username": c.PostForm("username"), - "password": c.PostForm("password"), - }) - }) - - } - - r.Run(":9999") -} -``` - -## Day 5 - Middleware - -```go -func onlyForV2() gee.HandlerFunc { - return func(c *gee.Context) { - // Start timer - t := time.Now() - // if a server error occurred - c.Fail(500, "Internal Server Error") - // Calculate resolution time - log.Printf("[%d] %s in %v for group v2", c.StatusCode, c.Req.RequestURI, time.Since(t)) - } -} - -func main() { - r := gee.New() - r.Use(gee.Logger()) // global midlleware - r.GET("/", func(c *gee.Context) { - c.HTML(http.StatusOK, "

Hello Gee

") - }) - - v2 := r.Group("/v2") - v2.Use(onlyForV2()) // v2 group middleware - { - v2.GET("/hello/:name", func(c *gee.Context) { - // expect /hello/geektutu - c.String(http.StatusOK, "hello %s, you're at %s\n", c.Param("name"), c.Path) - }) - } - - r.Run(":9999") -} -``` - -## Day 6 - HTML Template - -```go -type student struct { - Name string - Age int8 -} - -func formatAsDate(t time.Time) string { - year, month, day := t.Date() - return fmt.Sprintf("%d-%02d-%02d", year, month, day) -} - -func main() { - r := gee.New() - r.Use(gee.Logger()) - r.SetFuncMap(template.FuncMap{ - "formatAsDate": formatAsDate, - }) - r.LoadHTMLGlob("templates/*") - r.Static("/assets", "./static") - - stu1 := &student{Name: "Geektutu", Age: 20} - stu2 := &student{Name: "Jack", Age: 22} - r.GET("/", func(c *gee.Context) { - c.HTML(http.StatusOK, "css.tmpl", nil) - }) - r.GET("/students", func(c *gee.Context) { - c.HTML(http.StatusOK, "arr.tmpl", gee.H{ - "title": "gee", - "stuArr": [2]*student{stu1, stu2}, - }) - }) - - r.GET("/date", func(c *gee.Context) { - c.HTML(http.StatusOK, "custom_func.tmpl", gee.H{ - "title": "gee", - "now": time.Date(2019, 8, 17, 0, 0, 0, 0, time.UTC), - }) - }) - - r.Run(":9999") -} -``` - -## Day 7 - Panic Recover - -```go -func main() { - r := gee.Default() - r.GET("/", func(c *gee.Context) { - c.String(http.StatusOK, "Hello Geektutu\n") - }) - // index out of range for testing Recovery() - r.GET("/panic", func(c *gee.Context) { - names := []string{"geektutu"} - c.String(http.StatusOK, names[100]) - }) - - r.Run(":9999") -} - -``` \ No newline at end of file +# 7 days golang programs from scratch + +[![CodeSize](https://img.shields.io/github/languages/code-size/geektutu/7days-golang)](https://github.com/geektutu/7days-golang) +[![LICENSE](https://img.shields.io/badge/license-MIT-green)](https://mit-license.org/) + +
+README 中文版本 +
+ +## 7天用Go从零实现系列 + +7天能写什么呢?类似 gin 的 web 框架?类似 groupcache 的分布式缓存?或者一个简单的 Python 解释器?希望这个仓库能给你答案。 + +推荐先阅读 **[Go 语言简明教程](https://geektutu.com/post/quick-golang.html)**,一篇文章了解Go的基本语法、并发编程,依赖管理等内容。 + +推荐 **[Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html)**,加深对 Go 语言的理解。 + +推荐 **[Go 语言高性能编程](https://geektutu.com/post/high-performance-go.html)**([项目地址](https://github.com/geektutu/high-performance-go)),写出高性能的 Go 代码。 + +期待关注我的「[知乎专栏](https://zhuanlan.zhihu.com/geekgo)」和「[微博](http://weibo.com/geektutu)」,查看最近的文章和动态。 + +### 7天用Go从零实现Web框架 - Gee + +[Gee](https://geektutu.com/post/gee.html) 是一个模仿 [gin](https://github.com/gin-gonic/gin) 实现的 Web 框架,[Go Gin简明教程](https://geektutu.com/post/quick-go-gin.html)可以快速入门。 + +- 第一天:[前置知识(http.Handler接口)](https://geektutu.com/post/gee-day1.html) | [Code](gee-web/day1-http-base) +- 第二天:[上下文设计(Context)](https://geektutu.com/post/gee-day2.html) | [Code](gee-web/day2-context) +- 第三天:[Trie树路由(Router)](https://geektutu.com/post/gee-day3.html) | [Code](gee-web/day3-router) +- 第四天:[分组控制(Group)](https://geektutu.com/post/gee-day4.html) | [Code](gee-web/day4-group) +- 第五天:[中间件(Middleware)](https://geektutu.com/post/gee-day5.html) | [Code](gee-web/day5-middleware) +- 第六天:[HTML模板(Template)](https://geektutu.com/post/gee-day6.html) | [Code](gee-web/day6-template) +- 第七天:[错误恢复(Panic Recover)](https://geektutu.com/post/gee-day7.html) | [Code](gee-web/day7-panic-recover) + +### 7天用Go从零实现分布式缓存 GeeCache + +[GeeCache](https://geektutu.com/post/geecache.html) 是一个模仿 [groupcache](https://github.com/golang/groupcache) 实现的分布式缓存系统 + +- 第一天:[LRU 缓存淘汰策略](https://geektutu.com/post/geecache-day1.html) | [Code](gee-cache/day1-lru) +- 第二天:[单机并发缓存](https://geektutu.com/post/geecache-day2.html) | [Code](gee-cache/day2-single-node) +- 第三天:[HTTP 服务端](https://geektutu.com/post/geecache-day3.html) | [Code](gee-cache/day3-http-server) +- 第四天:[一致性哈希(Hash)](https://geektutu.com/post/geecache-day4.html) | [Code](gee-cache/day4-consistent-hash) +- 第五天:[分布式节点](https://geektutu.com/post/geecache-day5.html) | [Code](gee-cache/day5-multi-nodes) +- 第六天:[防止缓存击穿](https://geektutu.com/post/geecache-day6.html) | [Code](gee-cache/day6-single-flight) +- 第七天:[使用 Protobuf 通信](https://geektutu.com/post/geecache-day7.html) | [Code](gee-cache/day7-proto-buf) + +### 7天用Go从零实现ORM框架 GeeORM + +[GeeORM](https://geektutu.com/post/geeorm.html) 是一个模仿 [gorm](https://github.com/jinzhu/gorm) 和 [xorm](https://github.com/go-xorm/xorm) 的 ORM 框架 + +gorm 准备推出完全重写的 v2 版本(目前还在开发中),相对 gorm-v1 来说,xorm 的设计更容易理解,所以 geeorm 接口设计上主要参考了 xorm,一些细节实现上参考了 gorm。 + +- 第一天:[database/sql 基础](https://geektutu.com/post/geeorm-day1.html) | [Code](gee-orm/day1-database-sql) +- 第二天:[对象表结构映射](https://geektutu.com/post/geeorm-day2.html) | [Code](gee-orm/day2-reflect-schema) +- 第三天:[记录新增和查询](https://geektutu.com/post/geeorm-day3.html) | [Code](gee-orm/day3-save-query) +- 第四天:[链式操作与更新删除](https://geektutu.com/post/geeorm-day4.html) | [Code](gee-orm/day4-chain-operation) +- 第五天:[实现钩子(Hooks)](https://geektutu.com/post/geeorm-day5.html) | [Code](gee-orm/day5-hooks) +- 第六天:[支持事务(Transaction)](https://geektutu.com/post/geeorm-day6.html) | [Code](gee-orm/day6-transaction) +- 第七天:[数据库迁移(Migrate)](https://geektutu.com/post/geeorm-day7.html) | [Code](gee-orm/day7-migrate) + + +### 7天用Go从零实现RPC框架 GeeRPC + +[GeeRPC](https://geektutu.com/post/geerpc.html) 是一个基于 [net/rpc](https://github.com/golang/go/tree/master/src/net/rpc) 开发的 RPC 框架 +GeeRPC 是基于 Go 语言标准库 `net/rpc` 实现的,添加了协议交换、服务注册与发现、负载均衡等功能,代码约 1k。 + +- 第一天 - [服务端与消息编码](https://geektutu.com/post/geerpc-day1.html) | [Code](gee-rpc/day1-codec) +- 第二天 - [支持并发与异步的客户端](https://geektutu.com/post/geerpc-day2.html) | [Code](gee-rpc/day2-client) +- 第三天 - [服务注册(service register)](https://geektutu.com/post/geerpc-day3.html) | [Code](gee-rpc/day3-service ) +- 第四天 - [超时处理(timeout)](https://geektutu.com/post/geerpc-day4.html) | [Code](gee-rpc/day4-timeout ) +- 第五天 - [支持HTTP协议](https://geektutu.com/post/geerpc-day5.html) | [Code](gee-rpc/day5-http-debug) +- 第六天 - [负载均衡(load balance)](https://geektutu.com/post/geerpc-day6.html) | [Code](gee-rpc/day6-load-balance) +- 第七天 - [服务发现与注册中心(registry)](https://geektutu.com/post/geerpc-day7.html) | [Code](gee-rpc/day7-registry) + +### WebAssembly 使用示例 + +具体的实践过程记录在 [Go WebAssembly 简明教程](https://geektutu.com/post/quick-go-wasm.html)。 + +- 示例一:Hello World | [Code](demo-wasm/hello-world) +- 示例二:注册函数 | [Code](demo-wasm/register-functions) +- 示例三:操作 DOM | [Code](demo-wasm/manipulate-dom) +- 示例四:回调函数 | [Code](demo-wasm/callback) + +
+
+ +What can be accomplished in 7 days? A gin-like web framework? A distributed cache like groupcache? Or a simple Python interpreter? Hope this repo can give you the answer. + +## Web Framework - Gee + +[Gee](https://geektutu.com/post/gee.html) is a [gin](https://github.com/gin-gonic/gin)-like framework + +- Day 1 - http.Handler Interface Basic [Code](gee-web/day1-http-base) +- Day 2 - Design a Flexiable Context [Code](gee-web/day2-context) +- Day 3 - Router with Trie-Tree Algorithm [Code](gee-web/day3-router) +- Day 4 - Group Control [Code](gee-web/day4-group) +- Day 5 - Middleware Mechanism [Code](gee-web/day5-middleware) +- Day 6 - Embeded Template Support [Code](gee-web/day6-template) +- Day 7 - Panic Recover & Make it Robust [Code](gee-web/day7-panic-recover) + +## Distributed Cache - GeeCache + +[GeeCache](https://geektutu.com/post/geecache.html) is a [groupcache](https://github.com/golang/groupcache)-like distributed cache + +- Day 1 - LRU (Least Recently Used) Caching Strategy [Code](gee-cache/day1-lru) +- Day 2 - Single Machine Concurrent Cache [Code](gee-cache/day2-single-node) +- Day 3 - Launch a HTTP Server [Code](gee-cache/day3-http-server) +- Day 4 - Consistent Hash Algorithm [Code](gee-cache/day4-consistent-hash) +- Day 5 - Communication between Distributed Nodes [Code](gee-cache/day5-multi-nodes) +- Day 6 - Cache Breakdown & Single Flight | [Code](gee-cache/day6-single-flight) +- Day 7 - Use Protobuf as RPC Data Exchange Type | [Code](gee-cache/day7-proto-buf) + +## Object Relational Mapping - GeeORM + +[GeeORM](https://geektutu.com/post/geeorm.html) is a [gorm](https://github.com/jinzhu/gorm)-like and [xorm](https://github.com/go-xorm/xorm)-like object relational mapping library + +Xorm's desgin is easier to understand than gorm-v1, so the main designs references xorm and some detailed implementions references gorm-v1. + +- Day 1 - database/sql Basic | [Code](gee-orm/day1-database-sql) +- Day 2 - Object Schame Mapping | [Code](gee-orm/day2-reflect-schema) +- Day 3 - Insert and Query | [Code](gee-orm/day3-save-query) +- Day 4 - Chain, Delete and Update | [Code](gee-orm/day4-chain-operation) +- Day 5 - Support Hooks | [Code](gee-orm/day5-hooks) +- Day 6 - Support Transaction | [Code](gee-orm/day6-transaction) +- Day 7 - Migrate Database | [Code](gee-orm/day7-migrate) + +## RPC Framework - GeeRPC + +[GeeRPC](https://geektutu.com/post/geerpc.html) is a [net/rpc](https://github.com/golang/go/tree/master/src/net/rpc)-like RPC framework + +Based on golang standard library `net/rpc`, GeeRPC implements more features. eg, protocol exchange, service registration and discovery, load balance, etc. + +- Day 1 - Server Message Codec | [Code](gee-rpc/day1-codec) +- Day 2 - Concurrent Client | [Code](gee-rpc/day2-client) +- Day 3 - Service Register | [Code](gee-rpc/day3-service ) +- Day 4 - Timeout Processing | [Code](gee-rpc/day4-timeout ) +- Day 5 - Support HTTP Protocol | [Code](gee-rpc/day5-http-debug) +- Day 6 - Load Balance | [Code](gee-rpc/day6-load-balance) +- Day 7 - Discovery and Registry | [Code](gee-rpc/day7-registry) + +## Golang WebAssembly Demo + +- Demo 1 - Hello World [Code](demo-wasm/hello-world) +- Demo 2 - Register Functions [Code](demo-wasm/register-functions) +- Demo 3 - Manipulate DOM [Code](demo-wasm/manipulate-dom) +- Demo 4 - Callback [Code](demo-wasm/callback) diff --git a/day7-panic-recover/main.go b/day7-panic-recover/main.go deleted file mode 100644 index 5021aa6..0000000 --- a/day7-panic-recover/main.go +++ /dev/null @@ -1,48 +0,0 @@ -package main - -/* -$ curl "http://localhost:9999" -Hello Geektutu -$ curl "http://localhost:9999/panic" -{"message":"Internal Server Error"} - ->>> log -2019/08/18 17:55:57 [200] / in 4.533µs -2019/08/18 17:55:58 runtime error: index out of range -Traceback: - /usr/local/Cellar/go/1.12.5/libexec/src/runtime/panic.go:523 - /usr/local/Cellar/go/1.12.5/libexec/src/runtime/panic.go:44 - /Users/geektutu/7days-golang/day7-panic-recover/main.go:20 - /Users/geektutu/7days-golang/day7-panic-recover/gee/context.go:41 - /Users/geektutu/7days-golang/day7-panic-recover/gee/recovery.go:37 - /Users/geektutu/7days-golang/day7-panic-recover/gee/context.go:41 - /Users/geektutu/7days-golang/day7-panic-recover/gee/logger.go:15 - /Users/geektutu/7days-golang/day7-panic-recover/gee/context.go:41 - /Users/geektutu/7days-golang/day7-panic-recover/gee/router.go:99 - /Users/geektutu/7days-golang/day7-panic-recover/gee/gee.go:129 - /usr/local/Cellar/go/1.12.5/libexec/src/net/http/server.go:2775 - /usr/local/Cellar/go/1.12.5/libexec/src/net/http/server.go:1879 - /usr/local/Cellar/go/1.12.5/libexec/src/runtime/asm_amd64.s:1338 - -2019/08/18 17:55:58 [500] /panic in 143.086µs -*/ - -import ( - "net/http" - - "./gee" -) - -func main() { - r := gee.Default() - r.GET("/", func(c *gee.Context) { - c.String(http.StatusOK, "Hello Geektutu\n") - }) - // index out of range for testing Recovery() - r.GET("/panic", func(c *gee.Context) { - names := []string{"geektutu"} - c.String(http.StatusOK, names[100]) - }) - - r.Run(":9999") -} diff --git a/demo-wasm/.gitignore b/demo-wasm/.gitignore new file mode 100644 index 0000000..b4d6ef2 --- /dev/null +++ b/demo-wasm/.gitignore @@ -0,0 +1,2 @@ +*.wasm +static \ No newline at end of file diff --git a/demo-wasm/callback/Makefile b/demo-wasm/callback/Makefile new file mode 100644 index 0000000..234573b --- /dev/null +++ b/demo-wasm/callback/Makefile @@ -0,0 +1,11 @@ +all: static/main.wasm static/wasm_exec.js +ifeq (, $(shell which goexec)) + go get -u github.com/shurcooL/goexec +endif + goexec 'http.ListenAndServe(`:9999`, http.FileServer(http.Dir(`.`)))' + +static/wasm_exec.js: + cp "$(shell go env GOROOT)/misc/wasm/wasm_exec.js" static + +static/main.wasm: main.go + GO111MODULE=auto GOOS=js GOARCH=wasm go build -o static/main.wasm . \ No newline at end of file diff --git a/demo-wasm/callback/index.html b/demo-wasm/callback/index.html new file mode 100644 index 0000000..da37fb4 --- /dev/null +++ b/demo-wasm/callback/index.html @@ -0,0 +1,18 @@ + + + + + + + + + + +

+ + + \ No newline at end of file diff --git a/demo-wasm/callback/main.go b/demo-wasm/callback/main.go new file mode 100644 index 0000000..e9c137e --- /dev/null +++ b/demo-wasm/callback/main.go @@ -0,0 +1,32 @@ +// main.go +package main + +import ( + "syscall/js" + "time" +) + +func fib(i int) int { + if i == 0 || i == 1 { + return 1 + } + return fib(i-1) + fib(i-2) +} + +func fibFunc(this js.Value, args []js.Value) interface{} { + callback := args[len(args)-1] + go func() { + time.Sleep(3 * time.Second) + v := fib(args[0].Int()) + callback.Invoke(v) + }() + + js.Global().Get("ans").Set("innerHTML", "Waiting 3s...") + return nil +} + +func main() { + done := make(chan int, 0) + js.Global().Set("fibFunc", js.FuncOf(fibFunc)) + <-done +} diff --git a/demo-wasm/hello-world/Makefile b/demo-wasm/hello-world/Makefile new file mode 100644 index 0000000..234573b --- /dev/null +++ b/demo-wasm/hello-world/Makefile @@ -0,0 +1,11 @@ +all: static/main.wasm static/wasm_exec.js +ifeq (, $(shell which goexec)) + go get -u github.com/shurcooL/goexec +endif + goexec 'http.ListenAndServe(`:9999`, http.FileServer(http.Dir(`.`)))' + +static/wasm_exec.js: + cp "$(shell go env GOROOT)/misc/wasm/wasm_exec.js" static + +static/main.wasm: main.go + GO111MODULE=auto GOOS=js GOARCH=wasm go build -o static/main.wasm . \ No newline at end of file diff --git a/demo-wasm/hello-world/index.html b/demo-wasm/hello-world/index.html new file mode 100644 index 0000000..c75710c --- /dev/null +++ b/demo-wasm/hello-world/index.html @@ -0,0 +1,12 @@ + + + + + + + + \ No newline at end of file diff --git a/demo-wasm/hello-world/main.go b/demo-wasm/hello-world/main.go new file mode 100644 index 0000000..8bcd95b --- /dev/null +++ b/demo-wasm/hello-world/main.go @@ -0,0 +1,9 @@ +// main.go +package main + +import "syscall/js" + +func main() { + alert := js.Global().Get("alert") + alert.Invoke("Hello World!") +} \ No newline at end of file diff --git a/demo-wasm/manipulate-dom/Makefile b/demo-wasm/manipulate-dom/Makefile new file mode 100644 index 0000000..234573b --- /dev/null +++ b/demo-wasm/manipulate-dom/Makefile @@ -0,0 +1,11 @@ +all: static/main.wasm static/wasm_exec.js +ifeq (, $(shell which goexec)) + go get -u github.com/shurcooL/goexec +endif + goexec 'http.ListenAndServe(`:9999`, http.FileServer(http.Dir(`.`)))' + +static/wasm_exec.js: + cp "$(shell go env GOROOT)/misc/wasm/wasm_exec.js" static + +static/main.wasm: main.go + GO111MODULE=auto GOOS=js GOARCH=wasm go build -o static/main.wasm . \ No newline at end of file diff --git a/demo-wasm/manipulate-dom/index.html b/demo-wasm/manipulate-dom/index.html new file mode 100644 index 0000000..60f2e5a --- /dev/null +++ b/demo-wasm/manipulate-dom/index.html @@ -0,0 +1,18 @@ + + + + + + + + + + +

1

+ + + \ No newline at end of file diff --git a/demo-wasm/manipulate-dom/main.go b/demo-wasm/manipulate-dom/main.go new file mode 100644 index 0000000..618cd85 --- /dev/null +++ b/demo-wasm/manipulate-dom/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "strconv" + "syscall/js" +) + +func fib(i int) int { + if i == 0 || i == 1 { + return 1 + } + return fib(i-1) + fib(i-2) +} + +var ( + document = js.Global().Get("document") + numEle = document.Call("getElementById", "num") + ansEle = document.Call("getElementById", "ans") + btnEle = js.Global().Get("btn") +) + +func fibFunc(this js.Value, args []js.Value) interface{} { + v := numEle.Get("value") + if num, err := strconv.Atoi(v.String()); err == nil { + ansEle.Set("innerHTML", js.ValueOf(fib(num))) + } + return nil +} + +func main() { + done := make(chan int, 0) + btnEle.Call("addEventListener", "click", js.FuncOf(fibFunc)) + <-done +} diff --git a/demo-wasm/register-functions/Makefile b/demo-wasm/register-functions/Makefile new file mode 100644 index 0000000..234573b --- /dev/null +++ b/demo-wasm/register-functions/Makefile @@ -0,0 +1,11 @@ +all: static/main.wasm static/wasm_exec.js +ifeq (, $(shell which goexec)) + go get -u github.com/shurcooL/goexec +endif + goexec 'http.ListenAndServe(`:9999`, http.FileServer(http.Dir(`.`)))' + +static/wasm_exec.js: + cp "$(shell go env GOROOT)/misc/wasm/wasm_exec.js" static + +static/main.wasm: main.go + GO111MODULE=auto GOOS=js GOARCH=wasm go build -o static/main.wasm . \ No newline at end of file diff --git a/demo-wasm/register-functions/index.html b/demo-wasm/register-functions/index.html new file mode 100644 index 0000000..9cc865c --- /dev/null +++ b/demo-wasm/register-functions/index.html @@ -0,0 +1,18 @@ + + + + + + + + + + +

1

+ + + \ No newline at end of file diff --git a/demo-wasm/register-functions/main.go b/demo-wasm/register-functions/main.go new file mode 100644 index 0000000..2e22e78 --- /dev/null +++ b/demo-wasm/register-functions/main.go @@ -0,0 +1,21 @@ +// main.go +package main + +import "syscall/js" + +func fib(i int) int { + if i == 0 || i == 1 { + return 1 + } + return fib(i-1) + fib(i-2) +} + +func fibFunc(this js.Value, args []js.Value) interface{} { + return js.ValueOf(fib(args[0].Int())) +} + +func main() { + done := make(chan int, 0) + js.Global().Set("fibFunc", js.FuncOf(fibFunc)) + <-done +} diff --git a/gee-bolt/day1-pages/go.mod b/gee-bolt/day1-pages/go.mod new file mode 100644 index 0000000..17b5990 --- /dev/null +++ b/gee-bolt/day1-pages/go.mod @@ -0,0 +1,3 @@ +module geebolt + +go 1.13 diff --git a/gee-bolt/day1-pages/meta.go b/gee-bolt/day1-pages/meta.go new file mode 100644 index 0000000..4e9cdb1 --- /dev/null +++ b/gee-bolt/day1-pages/meta.go @@ -0,0 +1,33 @@ +package geebolt + +import ( + "errors" + "hash/fnv" + "unsafe" +) + +// Represent a marker value to indicate that a file is a gee-bolt DB +const magic uint32 = 0xED0CDAED + +type meta struct { + magic uint32 + pageSize uint32 + pgid uint64 + checksum uint64 +} + +func (m *meta) sum64() uint64 { + var h = fnv.New64a() + _, _ = h.Write((*[unsafe.Offsetof(meta{}.checksum)]byte)(unsafe.Pointer(m))[:]) + return h.Sum64() +} + +func (m *meta) validate() error { + if m.magic != magic { + return errors.New("invalid magic number") + } + if m.checksum != m.sum64() { + return errors.New("invalid checksum") + } + return nil +} diff --git a/gee-bolt/day1-pages/page.go b/gee-bolt/day1-pages/page.go new file mode 100644 index 0000000..18fc492 --- /dev/null +++ b/gee-bolt/day1-pages/page.go @@ -0,0 +1,88 @@ +package geebolt + +import ( + "fmt" + "reflect" + "unsafe" +) + +const pageHeaderSize = unsafe.Sizeof(page{}) +const branchPageElementSize = unsafe.Sizeof(branchPageElement{}) +const leafPageElementSize = unsafe.Sizeof(leafPageElement{}) +const maxKeysPerPage = 1024 + +const ( + branchPageFlag uint16 = iota + leafPageFlag + metaPageFlag + freelistPageFlag +) + +type page struct { + id uint64 + flags uint16 + count uint16 + overflow uint32 +} + +type leafPageElement struct { + pos uint32 + ksize uint32 + vsize uint32 +} + +type branchPageElement struct { + pos uint32 + ksize uint32 + pgid uint64 +} + +func (p *page) typ() string { + switch p.flags { + case branchPageFlag: + return "branch" + case leafPageFlag: + return "leaf" + case metaPageFlag: + return "meta" + case freelistPageFlag: + return "freelist" + } + return fmt.Sprintf("unknown<%02x>", p.flags) +} + +func (p *page) meta() *meta { + return (*meta)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + pageHeaderSize)) +} + +func (p *page) dataPtr() unsafe.Pointer { + return unsafe.Pointer(&reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(p)) + pageHeaderSize, + Len: int(p.count), + Cap: int(p.count), + }) +} + +func (p *page) leafPageElement(index uint16) *leafPageElement { + off := pageHeaderSize + uintptr(index)*leafPageElementSize + return (*leafPageElement)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + off)) +} + +func (p *page) leafPageElements() []leafPageElement { + if p.count == 0 { + return nil + } + return *(*[]leafPageElement)(p.dataPtr()) +} + +func (p *page) branchPageElement(index uint16) *branchPageElement { + off := pageHeaderSize + uintptr(index)*branchPageElementSize + return (*branchPageElement)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + off)) +} + +func (p *page) branchPageElements() []branchPageElement { + if p.count == 0 { + return nil + } + return *(*[]branchPageElement)(p.dataPtr()) +} diff --git a/gee-bolt/day2-mmap/db.go b/gee-bolt/day2-mmap/db.go new file mode 100755 index 0000000..9c644d1 --- /dev/null +++ b/gee-bolt/day2-mmap/db.go @@ -0,0 +1,18 @@ +package geebolt + +import "os" + +type DB struct { + data []byte + file *os.File +} + +const maxMapSize = 1 << 31 + +func (db *DB) mmap(sz int) error { + b, err := syscall.Mmap() +} + +func Open(path string) { + +} diff --git a/gee-bolt/day2-mmap/go.mod b/gee-bolt/day2-mmap/go.mod new file mode 100755 index 0000000..17b5990 --- /dev/null +++ b/gee-bolt/day2-mmap/go.mod @@ -0,0 +1,3 @@ +module geebolt + +go 1.13 diff --git a/gee-bolt/day3-tree/go.mod b/gee-bolt/day3-tree/go.mod new file mode 100755 index 0000000..17b5990 --- /dev/null +++ b/gee-bolt/day3-tree/go.mod @@ -0,0 +1,3 @@ +module geebolt + +go 1.13 diff --git a/gee-bolt/day3-tree/meta.go b/gee-bolt/day3-tree/meta.go new file mode 100644 index 0000000..4e9cdb1 --- /dev/null +++ b/gee-bolt/day3-tree/meta.go @@ -0,0 +1,33 @@ +package geebolt + +import ( + "errors" + "hash/fnv" + "unsafe" +) + +// Represent a marker value to indicate that a file is a gee-bolt DB +const magic uint32 = 0xED0CDAED + +type meta struct { + magic uint32 + pageSize uint32 + pgid uint64 + checksum uint64 +} + +func (m *meta) sum64() uint64 { + var h = fnv.New64a() + _, _ = h.Write((*[unsafe.Offsetof(meta{}.checksum)]byte)(unsafe.Pointer(m))[:]) + return h.Sum64() +} + +func (m *meta) validate() error { + if m.magic != magic { + return errors.New("invalid magic number") + } + if m.checksum != m.sum64() { + return errors.New("invalid checksum") + } + return nil +} diff --git a/gee-bolt/day3-tree/node.go b/gee-bolt/day3-tree/node.go new file mode 100755 index 0000000..a40c51a --- /dev/null +++ b/gee-bolt/day3-tree/node.go @@ -0,0 +1,53 @@ +package geebolt + +import ( + "bytes" + "sort" +) + +type kv struct { + key []byte + value []byte +} + +type node struct { + isLeaf bool + key []byte + parent *node + children []*node + kvs []kv +} + +func (n *node) root() *node { + if n.parent == nil { + return n + } + return n.parent.root() +} + +func (n *node) index(key []byte) (index int, exact bool) { + index = sort.Search(len(n.kvs), func(i int) bool { + return bytes.Compare(n.kvs[i].key, key) != -1 + }) + exact = len(n.kvs) > 0 && index < len(n.kvs) && bytes.Equal(n.kvs[index].key, key) + return +} + +func (n *node) put(oldKey, newKey, value []byte) { + index, exact := n.index(oldKey) + if !exact { + n.kvs = append(n.kvs, kv{}) + copy(n.kvs[index+1:], n.kvs[index:]) + } + kv := &n.kvs[index] + kv.key = newKey + kv.value = value +} + +func (n *node) del(key []byte) { + index, exact := n.index(key) + if exact { + n.kvs = append(n.kvs[:index], n.kvs[index+1:]...) + } +} + diff --git a/gee-bolt/day3-tree/page.go b/gee-bolt/day3-tree/page.go new file mode 100644 index 0000000..18fc492 --- /dev/null +++ b/gee-bolt/day3-tree/page.go @@ -0,0 +1,88 @@ +package geebolt + +import ( + "fmt" + "reflect" + "unsafe" +) + +const pageHeaderSize = unsafe.Sizeof(page{}) +const branchPageElementSize = unsafe.Sizeof(branchPageElement{}) +const leafPageElementSize = unsafe.Sizeof(leafPageElement{}) +const maxKeysPerPage = 1024 + +const ( + branchPageFlag uint16 = iota + leafPageFlag + metaPageFlag + freelistPageFlag +) + +type page struct { + id uint64 + flags uint16 + count uint16 + overflow uint32 +} + +type leafPageElement struct { + pos uint32 + ksize uint32 + vsize uint32 +} + +type branchPageElement struct { + pos uint32 + ksize uint32 + pgid uint64 +} + +func (p *page) typ() string { + switch p.flags { + case branchPageFlag: + return "branch" + case leafPageFlag: + return "leaf" + case metaPageFlag: + return "meta" + case freelistPageFlag: + return "freelist" + } + return fmt.Sprintf("unknown<%02x>", p.flags) +} + +func (p *page) meta() *meta { + return (*meta)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + pageHeaderSize)) +} + +func (p *page) dataPtr() unsafe.Pointer { + return unsafe.Pointer(&reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(p)) + pageHeaderSize, + Len: int(p.count), + Cap: int(p.count), + }) +} + +func (p *page) leafPageElement(index uint16) *leafPageElement { + off := pageHeaderSize + uintptr(index)*leafPageElementSize + return (*leafPageElement)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + off)) +} + +func (p *page) leafPageElements() []leafPageElement { + if p.count == 0 { + return nil + } + return *(*[]leafPageElement)(p.dataPtr()) +} + +func (p *page) branchPageElement(index uint16) *branchPageElement { + off := pageHeaderSize + uintptr(index)*branchPageElementSize + return (*branchPageElement)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + off)) +} + +func (p *page) branchPageElements() []branchPageElement { + if p.count == 0 { + return nil + } + return *(*[]branchPageElement)(p.dataPtr()) +} diff --git a/gee-cache/day1-lru/geecache/go.mod b/gee-cache/day1-lru/geecache/go.mod new file mode 100644 index 0000000..f9d454e --- /dev/null +++ b/gee-cache/day1-lru/geecache/go.mod @@ -0,0 +1,3 @@ +module geecache + +go 1.13 diff --git a/gee-cache/day1-lru/geecache/lru/lru.go b/gee-cache/day1-lru/geecache/lru/lru.go new file mode 100644 index 0000000..c4f88b9 --- /dev/null +++ b/gee-cache/day1-lru/geecache/lru/lru.go @@ -0,0 +1,93 @@ +package lru + +import "container/list" + +// Cache is a LRU cache. It is not safe for concurrent access. +type Cache struct { + maxBytes int64 // 是允许使用的最大内存 + nbytes int64 // nbytes 是当前已使用的内存 + ll *list.List // 在这里我们直接使用 Go 语言标准库实现的双向链表list.List + cache map[string]*list.Element // 键是字符串,值是双向链表中对应节点的指针 + // optional and executed when an entry is purged. + OnEvicted func(key string, value Value) // 是某条记录被移除时的回调函数,可以为 nil +} + +// 键值对 entry 是双向链表节点的数据类型,在链表中仍保存每个值对应的 key 的好处在于,淘汰队首节点时,需要用 key 从字典中删除对应的映射 +type entry struct { + key string + value Value +} + +// Value use Len to count how many bytes it takes +// 为了通用性,我们允许值是实现了 Value 接口的任意类型,该接口只包含了一个方法 Len() int,用于返回值所占用的内存大小。 +type Value interface { + Len() int +} + +// New is the Constructor of Cache +func New(maxBytes int64, onEvicted func(string, Value)) *Cache { + return &Cache{ + maxBytes: maxBytes, + ll: list.New(), + cache: make(map[string]*list.Element), + OnEvicted: onEvicted, + } +} + +// Add adds a value to the cache. +func (c *Cache) Add(key string, value Value) { + // 如果键对应的链表节点存在,则将对应节点移动到队尾,并返回查找到的值 + // 如果键存在,则更新对应节点的值,并将该节点移到队尾 + if ele, ok := c.cache[key]; ok { + // c.ll.MoveToFront(ele),即将链表中的节点 ele 移动到队尾(双向链表作为队列,队首队尾是相对的,在这里约定 front 为队尾) + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + c.nbytes += int64(value.Len()) - int64(kv.value.Len()) + kv.value = value + } else { + // 不存在则是新增场景,首先队尾添加新节点 &entry{key, value}, 并字典中添加 key 和节点的映射关系。 + ele := c.ll.PushFront(&entry{key, value}) + c.cache[key] = ele + c.nbytes += int64(len(key)) + int64(value.Len()) + } + // 更新 c.nbytes,如果超过了设定的最大值 c.maxBytes,则移除最少访问的节点 + for c.maxBytes != 0 && c.maxBytes < c.nbytes { + c.RemoveOldest() + } +} + +// Get look ups a key's value +// 查找主要有 2 个步骤,第一步是从字典中找到对应的双向链表的节点,第二步,将该节点移动到队尾 +func (c *Cache) Get(key string) (value Value, ok bool) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + return kv.value, true + } + return +} + +// RemoveOldest removes the oldest item +// 这里的删除,实际上是缓存淘汰。即移除最近最少访问的节点(队首) +func (c *Cache) RemoveOldest() { + // c.ll.Back() 取到队首节点,从链表中删除 + ele := c.ll.Back() + if ele != nil { + c.ll.Remove(ele) + kv := ele.Value.(*entry) + // delete(c.cache, kv.key),从字典中 c.cache 删除该节点的映射关系 + delete(c.cache, kv.key) + // 更新当前所用的内存 c.nbytes + c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len()) + // 如果回调函数 OnEvicted 不为 nil,则调用回调函数 + if c.OnEvicted != nil { + c.OnEvicted(kv.key, kv.value) + } + } +} + +// Len the number of cache entries +// 最后,为了方便测试,我们实现 Len() 用来获取添加了多少条数据 +func (c *Cache) Len() int { + return c.ll.Len() +} diff --git a/gee-cache/day1-lru/geecache/lru/lru_test.go b/gee-cache/day1-lru/geecache/lru/lru_test.go new file mode 100644 index 0000000..f2d3470 --- /dev/null +++ b/gee-cache/day1-lru/geecache/lru/lru_test.go @@ -0,0 +1,65 @@ +package lru + +import ( + "reflect" + "testing" +) + +type String string + +func (d String) Len() int { + return len(d) +} + +func TestGet(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key1", String("1234")) + if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" { + t.Fatalf("cache hit key1=1234 failed") + } + if _, ok := lru.Get("key2"); ok { + t.Fatalf("cache miss key2 failed") + } +} + +func TestRemoveoldest(t *testing.T) { + k1, k2, k3 := "key1", "key2", "k3" + v1, v2, v3 := "value1", "value2", "v3" + cap := len(k1 + k2 + v1 + v2) + lru := New(int64(cap), nil) + lru.Add(k1, String(v1)) + lru.Add(k2, String(v2)) + lru.Add(k3, String(v3)) + + if _, ok := lru.Get("key1"); ok || lru.Len() != 2 { + t.Fatalf("Removeoldest key1 failed") + } +} + +func TestOnEvicted(t *testing.T) { + keys := make([]string, 0) + callback := func(key string, value Value) { + keys = append(keys, key) + } + lru := New(int64(10), callback) + lru.Add("key1", String("123456")) + lru.Add("k2", String("k2")) + lru.Add("k3", String("k3")) + lru.Add("k4", String("k4")) + + expect := []string{"key1", "k2"} + + if !reflect.DeepEqual(expect, keys) { + t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect) + } +} + +func TestAdd(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key", String("1")) + lru.Add("key", String("111")) + + if lru.nbytes != int64(len("key")+len("111")) { + t.Fatal("expected 6 but got", lru.nbytes) + } +} diff --git a/gee-cache/day2-single-node/geecache/byteview.go b/gee-cache/day2-single-node/geecache/byteview.go new file mode 100644 index 0000000..62a793a --- /dev/null +++ b/gee-cache/day2-single-node/geecache/byteview.go @@ -0,0 +1,30 @@ +package geecache + +// A ByteView holds an immutable view of bytes. +// ByteView 只有一个数据成员,b []byte,b 将会存储真实的缓存值。选择 byte 类型是为了能够支持任意的数据类型的存储,例如字符串、图片等。 +type ByteView struct { + b []byte +} + +// Len returns the view's length +// 实现 Len() int 方法,我们在 lru.Cache 的实现中,要求被缓存对象必须实现 Value 接口,即 Len() int 方法,返回其所占的内存大小。 +func (v ByteView) Len() int { + return len(v.b) +} + +// ByteSlice returns a copy of the data as a byte slice. +// b 是只读的,使用 ByteSlice() 方法返回一个拷贝,防止缓存值被外部程序修改。 +func (v ByteView) ByteSlice() []byte { + return cloneBytes(v.b) +} + +// String returns the data as a string, making a copy if necessary. +func (v ByteView) String() string { + return string(v.b) +} + +func cloneBytes(b []byte) []byte { + c := make([]byte, len(b)) + copy(c, b) + return c +} diff --git a/gee-cache/day2-single-node/geecache/cache.go b/gee-cache/day2-single-node/geecache/cache.go new file mode 100644 index 0000000..6f56ab8 --- /dev/null +++ b/gee-cache/day2-single-node/geecache/cache.go @@ -0,0 +1,38 @@ +package geecache + +import ( + "geecache/lru" + "sync" +) + +// cache.go 的实现非常简单,实例化 lru,封装 get 和 add 方法,并添加互斥锁 mu。 +type cache struct { + mu sync.Mutex + lru *lru.Cache + cacheBytes int64 +} + +// 在 add 方法中,判断了 c.lru 是否为 nil,如果等于 nil 再创建实例。这种方法称之为延迟初始化(Lazy Initialization), +// 一个对象的延迟初始化意味着该对象的创建将会延迟至第一次使用该对象时。主要用于提高性能,并减少程序内存要求。 +func (c *cache) add(key string, value ByteView) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + c.lru = lru.New(c.cacheBytes, nil) + } + c.lru.Add(key, value) +} + +func (c *cache) get(key string) (value ByteView, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + return + } + + if v, ok := c.lru.Get(key); ok { + return v.(ByteView), ok + } + + return +} diff --git a/gee-cache/day2-single-node/geecache/geecache.go b/gee-cache/day2-single-node/geecache/geecache.go new file mode 100644 index 0000000..a3ca247 --- /dev/null +++ b/gee-cache/day2-single-node/geecache/geecache.go @@ -0,0 +1,107 @@ +package geecache + +import ( + "fmt" + "log" + "sync" +) + +// A Group is a cache namespace and associated data loaded spread over +// 一个 Group 可以认为是一个缓存的命名空间 +type Group struct { + // 每个 Group 拥有一个唯一的名称 name。比如可以创建三个 Group, + // 缓存学生的成绩命名为 scores,缓存学生信息的命名为 info,缓存学生课程的命名为 courses。 + name string + // 第二个属性是 getter Getter,即缓存未命中时获取源数据的回调(callback)。 + getter Getter + // 第三个属性是 mainCache cache,即一开始实现的并发缓存 + mainCache cache +} + +// A Getter loads data for a key. +// 定义接口 Getter 和 回调函数 Get(key string)([]byte, error),参数是 key,返回值是 []byte。 +type Getter interface { + Get(key string) ([]byte, error) +} + +// A GetterFunc implements Getter with a function. +// 定义函数类型 GetterFunc,并实现 Getter 接口的 Get 方法。 +type GetterFunc func(key string) ([]byte, error) + +// Get implements Getter interface function +// 函数类型实现某一个接口,称之为接口型函数,方便使用者在调用时既能够传入函数作为参数,也能够传入实现了该接口的结构体作为参数。 +// 补充: +// 这里呢,定义了一个接口 Getter,只包含一个方法 Get(key string) ([]byte, error),紧接着定义了一个函数类型 GetterFunc, +// GetterFunc 参数和返回值与 Getter 中 Get 方法是一致的。而且 GetterFunc 还定义了 Get 方式,并在 Get 方法中调用自己, +// 这样就实现了接口 Getter。所以 GetterFunc 是一个实现了接口的函数类型,简称为接口型函数。 +func (f GetterFunc) Get(key string) ([]byte, error) { + return f(key) +} + +var ( + mu sync.RWMutex + groups = make(map[string]*Group) +) + +// NewGroup create a new instance of Group +// 构建函数 NewGroup 用来实例化 Group,并且将 group 存储在全局变量 groups 中。 +func NewGroup(name string, cacheBytes int64, getter Getter) *Group { + if getter == nil { + panic("nil Getter") + } + mu.Lock() + defer mu.Unlock() + g := &Group{ + name: name, + getter: getter, + mainCache: cache{cacheBytes: cacheBytes}, + } + groups[name] = g + return g +} + +// GetGroup returns the named group previously created with NewGroup, or +// nil if there's no such group. +func GetGroup(name string) *Group { + mu.RLock() + g := groups[name] + // GetGroup 用来特定名称的 Group,这里使用了只读锁 RLock(),因为不涉及任何冲突变量的写操作 + mu.RUnlock() + return g +} + +// Get value for a key from cache +// Get 方法实现了上述所说的流程 ⑴ 和 ⑶。 +func (g *Group) Get(key string) (ByteView, error) { + if key == "" { + return ByteView{}, fmt.Errorf("key is required") + } + // 流程 ⑴ :从 mainCache 中查找缓存,如果存在则返回缓存值 + if v, ok := g.mainCache.get(key); ok { + log.Println("[GeeCache] hit") + return v, nil + } + // 流程 ⑶ :缓存不存在,则调用 load 方法 + return g.load(key) +} + +// load 调用 getLocally(分布式场景下会调用 getFromPeer 从其他节点获取), +func (g *Group) load(key string) (value ByteView, err error) { + return g.getLocally(key) +} + +// getLocally 调用用户回调函数 g.getter.Get() 获取源数据,并且将源数据添加到缓存 mainCache 中(通过 populateCache 方法) +func (g *Group) getLocally(key string) (ByteView, error) { + bytes, err := g.getter.Get(key) + if err != nil { + return ByteView{}, err + + } + value := ByteView{b: cloneBytes(bytes)} + g.populateCache(key, value) + return value, nil +} + +func (g *Group) populateCache(key string, value ByteView) { + g.mainCache.add(key, value) +} diff --git a/gee-cache/day2-single-node/geecache/geecache_test.go b/gee-cache/day2-single-node/geecache/geecache_test.go new file mode 100644 index 0000000..d507b9a --- /dev/null +++ b/gee-cache/day2-single-node/geecache/geecache_test.go @@ -0,0 +1,72 @@ +package geecache + +import ( + "fmt" + "log" + "reflect" + "testing" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +// 在这个测试用例中,我们借助 GetterFunc 的类型转换,将一个匿名回调函数转换成了接口 f Getter。 +func TestGetter(t *testing.T) { + var f Getter = GetterFunc(func(key string) ([]byte, error) { + return []byte(key), nil + }) + expect := []byte("key") + // 调用该接口的方法 f.Get(key string),实际上就是在调用匿名回调函数。 + if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) { + t.Fatal("callback failed") + } +} + +func TestGet(t *testing.T) { + loadCounts := make(map[string]int, len(db)) + gee := NewGroup("scores", 2<<10, GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + // 如果存在db中 + if v, ok := db[key]; ok { + // 如果该key没有调用回调函数,就初始化一下 + if _, ok := loadCounts[key]; !ok { + loadCounts[key] = 0 + } + // count++ + loadCounts[key]++ + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) + + for k, v := range db { + if view, err := gee.Get(k); err != nil || view.String() != v { + t.Fatal("failed to get value of Tom") + } + // 在缓存已经存在的情况下,是否直接从缓存中获取,为了实现这一点,使用 loadCounts 统计某个键调用回调函数的次数,如果次数大于1,则表示调用了多次回调函数,没有缓存。 + if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 { + t.Fatalf("cache %s miss", k) + } + } + + if view, err := gee.Get("unknown"); err == nil { + t.Fatalf("the value of unknow should be empty, but %s got", view) + } +} + +func TestGetGroup(t *testing.T) { + groupName := "scores" + NewGroup(groupName, 2<<10, GetterFunc( + func(key string) (bytes []byte, err error) { return })) + if group := GetGroup(groupName); group == nil || group.name != groupName { + t.Fatalf("group %s not exist", groupName) + } + + if group := GetGroup(groupName + "111"); group != nil { + t.Fatalf("expect nil, but %s got", group.name) + } +} diff --git a/gee-cache/day2-single-node/geecache/go.mod b/gee-cache/day2-single-node/geecache/go.mod new file mode 100644 index 0000000..f9d454e --- /dev/null +++ b/gee-cache/day2-single-node/geecache/go.mod @@ -0,0 +1,3 @@ +module geecache + +go 1.13 diff --git a/gee-cache/day2-single-node/geecache/lru/lru.go b/gee-cache/day2-single-node/geecache/lru/lru.go new file mode 100644 index 0000000..dc1a317 --- /dev/null +++ b/gee-cache/day2-single-node/geecache/lru/lru.go @@ -0,0 +1,79 @@ +package lru + +import "container/list" + +// Cache is a LRU cache. It is not safe for concurrent access. +type Cache struct { + maxBytes int64 + nbytes int64 + ll *list.List + cache map[string]*list.Element + // optional and executed when an entry is purged. + OnEvicted func(key string, value Value) +} + +type entry struct { + key string + value Value +} + +// Value use Len to count how many bytes it takes +type Value interface { + Len() int +} + +// New is the Constructor of Cache +func New(maxBytes int64, onEvicted func(string, Value)) *Cache { + return &Cache{ + maxBytes: maxBytes, + ll: list.New(), + cache: make(map[string]*list.Element), + OnEvicted: onEvicted, + } +} + +// Add adds a value to the cache. +func (c *Cache) Add(key string, value Value) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + c.nbytes += int64(value.Len()) - int64(kv.value.Len()) + kv.value = value + } else { + ele := c.ll.PushFront(&entry{key, value}) + c.cache[key] = ele + c.nbytes += int64(len(key)) + int64(value.Len()) + } + for c.maxBytes != 0 && c.maxBytes < c.nbytes { + c.RemoveOldest() + } +} + +// Get look ups a key's value +func (c *Cache) Get(key string) (value Value, ok bool) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + return kv.value, true + } + return +} + +// RemoveOldest removes the oldest item +func (c *Cache) RemoveOldest() { + ele := c.ll.Back() + if ele != nil { + c.ll.Remove(ele) + kv := ele.Value.(*entry) + delete(c.cache, kv.key) + c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len()) + if c.OnEvicted != nil { + c.OnEvicted(kv.key, kv.value) + } + } +} + +// Len the number of cache entries +func (c *Cache) Len() int { + return c.ll.Len() +} diff --git a/gee-cache/day2-single-node/geecache/lru/lru_test.go b/gee-cache/day2-single-node/geecache/lru/lru_test.go new file mode 100644 index 0000000..f2d3470 --- /dev/null +++ b/gee-cache/day2-single-node/geecache/lru/lru_test.go @@ -0,0 +1,65 @@ +package lru + +import ( + "reflect" + "testing" +) + +type String string + +func (d String) Len() int { + return len(d) +} + +func TestGet(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key1", String("1234")) + if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" { + t.Fatalf("cache hit key1=1234 failed") + } + if _, ok := lru.Get("key2"); ok { + t.Fatalf("cache miss key2 failed") + } +} + +func TestRemoveoldest(t *testing.T) { + k1, k2, k3 := "key1", "key2", "k3" + v1, v2, v3 := "value1", "value2", "v3" + cap := len(k1 + k2 + v1 + v2) + lru := New(int64(cap), nil) + lru.Add(k1, String(v1)) + lru.Add(k2, String(v2)) + lru.Add(k3, String(v3)) + + if _, ok := lru.Get("key1"); ok || lru.Len() != 2 { + t.Fatalf("Removeoldest key1 failed") + } +} + +func TestOnEvicted(t *testing.T) { + keys := make([]string, 0) + callback := func(key string, value Value) { + keys = append(keys, key) + } + lru := New(int64(10), callback) + lru.Add("key1", String("123456")) + lru.Add("k2", String("k2")) + lru.Add("k3", String("k3")) + lru.Add("k4", String("k4")) + + expect := []string{"key1", "k2"} + + if !reflect.DeepEqual(expect, keys) { + t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect) + } +} + +func TestAdd(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key", String("1")) + lru.Add("key", String("111")) + + if lru.nbytes != int64(len("key")+len("111")) { + t.Fatal("expected 6 but got", lru.nbytes) + } +} diff --git a/gee-cache/day3-http-server/geecache/byteview.go b/gee-cache/day3-http-server/geecache/byteview.go new file mode 100644 index 0000000..3ee1022 --- /dev/null +++ b/gee-cache/day3-http-server/geecache/byteview.go @@ -0,0 +1,27 @@ +package geecache + +// A ByteView holds an immutable view of bytes. +type ByteView struct { + b []byte +} + +// Len returns the view's length +func (v ByteView) Len() int { + return len(v.b) +} + +// ByteSlice returns a copy of the data as a byte slice. +func (v ByteView) ByteSlice() []byte { + return cloneBytes(v.b) +} + +// String returns the data as a string, making a copy if necessary. +func (v ByteView) String() string { + return string(v.b) +} + +func cloneBytes(b []byte) []byte { + c := make([]byte, len(b)) + copy(c, b) + return c +} diff --git a/gee-cache/day3-http-server/geecache/cache.go b/gee-cache/day3-http-server/geecache/cache.go new file mode 100644 index 0000000..665c3f3 --- /dev/null +++ b/gee-cache/day3-http-server/geecache/cache.go @@ -0,0 +1,35 @@ +package geecache + +import ( + "geecache/lru" + "sync" +) + +type cache struct { + mu sync.Mutex + lru *lru.Cache + cacheBytes int64 +} + +func (c *cache) add(key string, value ByteView) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + c.lru = lru.New(c.cacheBytes, nil) + } + c.lru.Add(key, value) +} + +func (c *cache) get(key string) (value ByteView, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + return + } + + if v, ok := c.lru.Get(key); ok { + return v.(ByteView), ok + } + + return +} diff --git a/gee-cache/day3-http-server/geecache/geecache.go b/gee-cache/day3-http-server/geecache/geecache.go new file mode 100644 index 0000000..e289018 --- /dev/null +++ b/gee-cache/day3-http-server/geecache/geecache.go @@ -0,0 +1,90 @@ +package geecache + +import ( + "fmt" + "log" + "sync" +) + +// A Group is a cache namespace and associated data loaded spread over +type Group struct { + name string + getter Getter + mainCache cache +} + +// A Getter loads data for a key. +type Getter interface { + Get(key string) ([]byte, error) +} + +// A GetterFunc implements Getter with a function. +type GetterFunc func(key string) ([]byte, error) + +// Get implements Getter interface function +func (f GetterFunc) Get(key string) ([]byte, error) { + return f(key) +} + +var ( + mu sync.RWMutex + groups = make(map[string]*Group) +) + +// NewGroup create a new instance of Group +func NewGroup(name string, cacheBytes int64, getter Getter) *Group { + if getter == nil { + panic("nil Getter") + } + mu.Lock() + defer mu.Unlock() + g := &Group{ + name: name, + getter: getter, + mainCache: cache{cacheBytes: cacheBytes}, + } + groups[name] = g + return g +} + +// GetGroup returns the named group previously created with NewGroup, or +// nil if there's no such group. +func GetGroup(name string) *Group { + mu.RLock() + g := groups[name] + mu.RUnlock() + return g +} + +// Get value for a key from cache +func (g *Group) Get(key string) (ByteView, error) { + if key == "" { + return ByteView{}, fmt.Errorf("key is required") + } + + if v, ok := g.mainCache.get(key); ok { + log.Println("[GeeCache] hit") + return v, nil + } + + return g.load(key) +} + +func (g *Group) load(key string) (value ByteView, err error) { + return g.getLocally(key) +} + +func (g *Group) getLocally(key string) (ByteView, error) { + bytes, err := g.getter.Get(key) + if err != nil { + return ByteView{}, err + + } + value := ByteView{b: cloneBytes(bytes)} + g.populateCache(key, value) + return value, nil +} + +func (g *Group) populateCache(key string, value ByteView) { + g.mainCache.add(key, value) +} diff --git a/gee-cache/day3-http-server/geecache/geecache_test.go b/gee-cache/day3-http-server/geecache/geecache_test.go new file mode 100644 index 0000000..7ef9f4f --- /dev/null +++ b/gee-cache/day3-http-server/geecache/geecache_test.go @@ -0,0 +1,67 @@ +package geecache + +import ( + "fmt" + "log" + "reflect" + "testing" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func TestGetter(t *testing.T) { + var f Getter = GetterFunc(func(key string) ([]byte, error) { + return []byte(key), nil + }) + + expect := []byte("key") + if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) { + t.Fatal("callback failed") + } +} + +func TestGet(t *testing.T) { + loadCounts := make(map[string]int, len(db)) + gee := NewGroup("scores", 2<<10, GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + if _, ok := loadCounts[key]; !ok { + loadCounts[key] = 0 + } + loadCounts[key]++ + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) + + for k, v := range db { + if view, err := gee.Get(k); err != nil || view.String() != v { + t.Fatal("failed to get value of Tom") + } + if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 { + t.Fatalf("cache %s miss", k) + } + } + + if view, err := gee.Get("unknown"); err == nil { + t.Fatalf("the value of unknow should be empty, but %s got", view) + } +} + +func TestGetGroup(t *testing.T) { + groupName := "scores" + NewGroup(groupName, 2<<10, GetterFunc( + func(key string) (bytes []byte, err error) { return })) + if group := GetGroup(groupName); group == nil || group.name != groupName { + t.Fatalf("group %s not exist", groupName) + } + + if group := GetGroup(groupName + "111"); group != nil { + t.Fatalf("expect nil, but %s got", group.name) + } +} diff --git a/gee-cache/day3-http-server/geecache/go.mod b/gee-cache/day3-http-server/geecache/go.mod new file mode 100644 index 0000000..f9d454e --- /dev/null +++ b/gee-cache/day3-http-server/geecache/go.mod @@ -0,0 +1,3 @@ +module geecache + +go 1.13 diff --git a/gee-cache/day3-http-server/geecache/http.go b/gee-cache/day3-http-server/geecache/http.go new file mode 100644 index 0000000..468ef5e --- /dev/null +++ b/gee-cache/day3-http-server/geecache/http.go @@ -0,0 +1,63 @@ +package geecache + +import ( + "fmt" + "log" + "net/http" + "strings" +) + +const defaultBasePath = "/_geecache/" + +// HTTPPool implements PeerPicker for a pool of HTTP peers. +type HTTPPool struct { + // this peer's base URL, e.g. "https://example.net:8000" + self string + basePath string +} + +// NewHTTPPool initializes an HTTP pool of peers. +func NewHTTPPool(self string) *HTTPPool { + return &HTTPPool{ + self: self, + basePath: defaultBasePath, + } +} + +// Log info with server name +func (p *HTTPPool) Log(format string, v ...interface{}) { + log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...)) +} + +// ServeHTTP handle all http requests +func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // ServeHTTP 的实现逻辑是比较简单的,首先判断访问路径的前缀是否是 basePath,不是返回错误。 + if !strings.HasPrefix(r.URL.Path, p.basePath) { + panic("HTTPPool serving unexpected path: " + r.URL.Path) + } + p.Log("%s %s", r.Method, r.URL.Path) + // /// required + // 我们约定访问路径格式为 /// + parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2) + if len(parts) != 2 { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + groupName := parts[0] + key := parts[1] + // 通过 groupname 得到 group 实例 + group := GetGroup(groupName) + if group == nil { + http.Error(w, "no such group: "+groupName, http.StatusNotFound) + return + } + // 再使用 group.Get(key) 获取缓存数据 + view, err := group.Get(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + // 最终使用 w.Write() 将缓存值作为 httpResponse 的 body 返回 + w.Write(view.ByteSlice()) +} diff --git a/gee-cache/day3-http-server/geecache/lru/lru.go b/gee-cache/day3-http-server/geecache/lru/lru.go new file mode 100644 index 0000000..dc1a317 --- /dev/null +++ b/gee-cache/day3-http-server/geecache/lru/lru.go @@ -0,0 +1,79 @@ +package lru + +import "container/list" + +// Cache is a LRU cache. It is not safe for concurrent access. +type Cache struct { + maxBytes int64 + nbytes int64 + ll *list.List + cache map[string]*list.Element + // optional and executed when an entry is purged. + OnEvicted func(key string, value Value) +} + +type entry struct { + key string + value Value +} + +// Value use Len to count how many bytes it takes +type Value interface { + Len() int +} + +// New is the Constructor of Cache +func New(maxBytes int64, onEvicted func(string, Value)) *Cache { + return &Cache{ + maxBytes: maxBytes, + ll: list.New(), + cache: make(map[string]*list.Element), + OnEvicted: onEvicted, + } +} + +// Add adds a value to the cache. +func (c *Cache) Add(key string, value Value) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + c.nbytes += int64(value.Len()) - int64(kv.value.Len()) + kv.value = value + } else { + ele := c.ll.PushFront(&entry{key, value}) + c.cache[key] = ele + c.nbytes += int64(len(key)) + int64(value.Len()) + } + for c.maxBytes != 0 && c.maxBytes < c.nbytes { + c.RemoveOldest() + } +} + +// Get look ups a key's value +func (c *Cache) Get(key string) (value Value, ok bool) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + return kv.value, true + } + return +} + +// RemoveOldest removes the oldest item +func (c *Cache) RemoveOldest() { + ele := c.ll.Back() + if ele != nil { + c.ll.Remove(ele) + kv := ele.Value.(*entry) + delete(c.cache, kv.key) + c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len()) + if c.OnEvicted != nil { + c.OnEvicted(kv.key, kv.value) + } + } +} + +// Len the number of cache entries +func (c *Cache) Len() int { + return c.ll.Len() +} diff --git a/gee-cache/day3-http-server/geecache/lru/lru_test.go b/gee-cache/day3-http-server/geecache/lru/lru_test.go new file mode 100644 index 0000000..f2d3470 --- /dev/null +++ b/gee-cache/day3-http-server/geecache/lru/lru_test.go @@ -0,0 +1,65 @@ +package lru + +import ( + "reflect" + "testing" +) + +type String string + +func (d String) Len() int { + return len(d) +} + +func TestGet(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key1", String("1234")) + if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" { + t.Fatalf("cache hit key1=1234 failed") + } + if _, ok := lru.Get("key2"); ok { + t.Fatalf("cache miss key2 failed") + } +} + +func TestRemoveoldest(t *testing.T) { + k1, k2, k3 := "key1", "key2", "k3" + v1, v2, v3 := "value1", "value2", "v3" + cap := len(k1 + k2 + v1 + v2) + lru := New(int64(cap), nil) + lru.Add(k1, String(v1)) + lru.Add(k2, String(v2)) + lru.Add(k3, String(v3)) + + if _, ok := lru.Get("key1"); ok || lru.Len() != 2 { + t.Fatalf("Removeoldest key1 failed") + } +} + +func TestOnEvicted(t *testing.T) { + keys := make([]string, 0) + callback := func(key string, value Value) { + keys = append(keys, key) + } + lru := New(int64(10), callback) + lru.Add("key1", String("123456")) + lru.Add("k2", String("k2")) + lru.Add("k3", String("k3")) + lru.Add("k4", String("k4")) + + expect := []string{"key1", "k2"} + + if !reflect.DeepEqual(expect, keys) { + t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect) + } +} + +func TestAdd(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key", String("1")) + lru.Add("key", String("111")) + + if lru.nbytes != int64(len("key")+len("111")) { + t.Fatal("expected 6 but got", lru.nbytes) + } +} diff --git a/gee-cache/day3-http-server/go.mod b/gee-cache/day3-http-server/go.mod new file mode 100644 index 0000000..d0fd3ba --- /dev/null +++ b/gee-cache/day3-http-server/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require geecache v0.0.0 + +replace geecache => ./geecache diff --git a/gee-cache/day3-http-server/main.go b/gee-cache/day3-http-server/main.go new file mode 100644 index 0000000..dff9003 --- /dev/null +++ b/gee-cache/day3-http-server/main.go @@ -0,0 +1,39 @@ +package main + +/* +$ curl http://localhost:9999/_geecache/scores/Tom +630 + +$ curl http://localhost:9999/_geecache/scores/kkk +kkk not exist +*/ + +import ( + "fmt" + "geecache" + "log" + "net/http" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func main() { + // 创建一个名为 scores 的 Group,若缓存为空,回调函数会从 db 中获取数据并返回。 + geecache.NewGroup("scores", 2<<10, geecache.GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) + // 使用 http.ListenAndServe 在 9999 端口启动了 HTTP 服务。 + addr := "localhost:9999" + peers := geecache.NewHTTPPool(addr) + log.Println("geecache is running at", addr) + log.Fatal(http.ListenAndServe(addr, peers)) +} diff --git a/gee-cache/day4-consistent-hash/geecache/byteview.go b/gee-cache/day4-consistent-hash/geecache/byteview.go new file mode 100644 index 0000000..3ee1022 --- /dev/null +++ b/gee-cache/day4-consistent-hash/geecache/byteview.go @@ -0,0 +1,27 @@ +package geecache + +// A ByteView holds an immutable view of bytes. +type ByteView struct { + b []byte +} + +// Len returns the view's length +func (v ByteView) Len() int { + return len(v.b) +} + +// ByteSlice returns a copy of the data as a byte slice. +func (v ByteView) ByteSlice() []byte { + return cloneBytes(v.b) +} + +// String returns the data as a string, making a copy if necessary. +func (v ByteView) String() string { + return string(v.b) +} + +func cloneBytes(b []byte) []byte { + c := make([]byte, len(b)) + copy(c, b) + return c +} diff --git a/gee-cache/day4-consistent-hash/geecache/cache.go b/gee-cache/day4-consistent-hash/geecache/cache.go new file mode 100644 index 0000000..665c3f3 --- /dev/null +++ b/gee-cache/day4-consistent-hash/geecache/cache.go @@ -0,0 +1,35 @@ +package geecache + +import ( + "geecache/lru" + "sync" +) + +type cache struct { + mu sync.Mutex + lru *lru.Cache + cacheBytes int64 +} + +func (c *cache) add(key string, value ByteView) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + c.lru = lru.New(c.cacheBytes, nil) + } + c.lru.Add(key, value) +} + +func (c *cache) get(key string) (value ByteView, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + return + } + + if v, ok := c.lru.Get(key); ok { + return v.(ByteView), ok + } + + return +} diff --git a/gee-cache/day4-consistent-hash/geecache/consistenthash/consistenthash.go b/gee-cache/day4-consistent-hash/geecache/consistenthash/consistenthash.go new file mode 100644 index 0000000..3227290 --- /dev/null +++ b/gee-cache/day4-consistent-hash/geecache/consistenthash/consistenthash.go @@ -0,0 +1,66 @@ +package consistenthash + +import ( + "hash/crc32" + "sort" + "strconv" +) + +// Hash maps bytes to uint32 +type Hash func(data []byte) uint32 + +// Map constains all hashed keys +type Map struct { + hash Hash // 定义了函数类型 Hash, + replicas int // 虚拟节点倍数 + keys []int // Sorted 哈希环 + hashMap map[int]string // 虚拟节点与真实节点的映射表 hashMap,键是虚拟节点的哈希值,值是真实节点的名称。 +} + +// New creates a Map instance 构造函数 New() 允许自定义虚拟节点倍数和 Hash 函数。 +func New(replicas int, fn Hash) *Map { + m := &Map{ + replicas: replicas, + hash: fn, + hashMap: make(map[int]string), + } + if m.hash == nil { + m.hash = crc32.ChecksumIEEE // 采取依赖注入的方式,允许用于替换成自定义的 Hash 函数,也方便测试时替换,默认为 crc32.ChecksumIEEE 算法。 + } + return m +} + +// Add adds some keys to the hash. +func (m *Map) Add(keys ...string) { + for _, key := range keys { + for i := 0; i < m.replicas; i++ { + // 对每一个真实节点 key,对应创建 m.replicas 个虚拟节点,虚拟节点的名称是:strconv.Itoa(i) + key,即通过添加编号的方式区分不同虚拟节点。 + hash := int(m.hash([]byte(strconv.Itoa(i) + key))) + // 使用 m.hash() 计算虚拟节点的哈希值,使用 append(m.keys, hash) 添加到环上。 + m.keys = append(m.keys, hash) + // 在 hashMap 中增加虚拟节点和真实节点的映射关系 + m.hashMap[hash] = key + } + } + // 最后一步,环上的哈希值排序。 + sort.Ints(m.keys) +} + +// Get gets the closest item in the hash to the provided key. +func (m *Map) Get(key string) string { + if len(m.keys) == 0 { + return "" + } + // 第一步,计算 key 的哈希值。 + hash := int(m.hash([]byte(key))) + // Binary search for appropriate replica. + // 第二步,顺时针找到第一个匹配的虚拟节点的下标 idx,从 m.keys 中获取到对应的哈希值。如果 idx == len(m.keys), + // 说明应选择 m.keys[0], + idx := sort.Search(len(m.keys), func(i int) bool { + // 寻找到第一个大于这个hash值的keys[i]的坐标idx + return m.keys[i] >= hash + }) + // 第三步,通过 hashMap 映射得到真实的节点。 + // 因为 m.keys 是一个环状结构,所以用取余数的方式来处理这种情况。 + return m.hashMap[m.keys[idx%len(m.keys)]] +} diff --git a/gee-cache/day4-consistent-hash/geecache/consistenthash/consistenthash_test.go b/gee-cache/day4-consistent-hash/geecache/consistenthash/consistenthash_test.go new file mode 100644 index 0000000..5faab00 --- /dev/null +++ b/gee-cache/day4-consistent-hash/geecache/consistenthash/consistenthash_test.go @@ -0,0 +1,45 @@ +package consistenthash + +import ( + "strconv" + "testing" +) + +func TestHashing(t *testing.T) { + // 自定义的 Hash 算法只处理数字,传入字符串表示的数字,返回对应的数字即可。 + hash := New(3, func(key []byte) uint32 { + i, _ := strconv.Atoi(string(key)) + return uint32(i) + }) + + // Given the above hash function, this will give replicas with "hashes": + // 2, 4, 6, 12, 14, 16, 22, 24, 26 + hash.Add("6", "4", "2") // 一开始,有 2/4/6 三个真实节点,对应的虚拟节点的哈希值是 02/12/22、04/14/24、06/16/26。 + //那么用例 2/11/23/27 选择的虚拟节点分别是 02/12/24/02,也就是真实节点 2/2/4/2。 + testCases := map[string]string{ + "2": "2", + "11": "2", + "23": "4", + "27": "2", + "3": "4", + } + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + + // Adds 8, 18, 28 + hash.Add("8") + + // 27 should now map to 8. + testCases["27"] = "8" + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + +} diff --git a/gee-cache/day4-consistent-hash/geecache/geecache.go b/gee-cache/day4-consistent-hash/geecache/geecache.go new file mode 100644 index 0000000..e289018 --- /dev/null +++ b/gee-cache/day4-consistent-hash/geecache/geecache.go @@ -0,0 +1,90 @@ +package geecache + +import ( + "fmt" + "log" + "sync" +) + +// A Group is a cache namespace and associated data loaded spread over +type Group struct { + name string + getter Getter + mainCache cache +} + +// A Getter loads data for a key. +type Getter interface { + Get(key string) ([]byte, error) +} + +// A GetterFunc implements Getter with a function. +type GetterFunc func(key string) ([]byte, error) + +// Get implements Getter interface function +func (f GetterFunc) Get(key string) ([]byte, error) { + return f(key) +} + +var ( + mu sync.RWMutex + groups = make(map[string]*Group) +) + +// NewGroup create a new instance of Group +func NewGroup(name string, cacheBytes int64, getter Getter) *Group { + if getter == nil { + panic("nil Getter") + } + mu.Lock() + defer mu.Unlock() + g := &Group{ + name: name, + getter: getter, + mainCache: cache{cacheBytes: cacheBytes}, + } + groups[name] = g + return g +} + +// GetGroup returns the named group previously created with NewGroup, or +// nil if there's no such group. +func GetGroup(name string) *Group { + mu.RLock() + g := groups[name] + mu.RUnlock() + return g +} + +// Get value for a key from cache +func (g *Group) Get(key string) (ByteView, error) { + if key == "" { + return ByteView{}, fmt.Errorf("key is required") + } + + if v, ok := g.mainCache.get(key); ok { + log.Println("[GeeCache] hit") + return v, nil + } + + return g.load(key) +} + +func (g *Group) load(key string) (value ByteView, err error) { + return g.getLocally(key) +} + +func (g *Group) getLocally(key string) (ByteView, error) { + bytes, err := g.getter.Get(key) + if err != nil { + return ByteView{}, err + + } + value := ByteView{b: cloneBytes(bytes)} + g.populateCache(key, value) + return value, nil +} + +func (g *Group) populateCache(key string, value ByteView) { + g.mainCache.add(key, value) +} diff --git a/gee-cache/day4-consistent-hash/geecache/geecache_test.go b/gee-cache/day4-consistent-hash/geecache/geecache_test.go new file mode 100644 index 0000000..7ef9f4f --- /dev/null +++ b/gee-cache/day4-consistent-hash/geecache/geecache_test.go @@ -0,0 +1,67 @@ +package geecache + +import ( + "fmt" + "log" + "reflect" + "testing" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func TestGetter(t *testing.T) { + var f Getter = GetterFunc(func(key string) ([]byte, error) { + return []byte(key), nil + }) + + expect := []byte("key") + if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) { + t.Fatal("callback failed") + } +} + +func TestGet(t *testing.T) { + loadCounts := make(map[string]int, len(db)) + gee := NewGroup("scores", 2<<10, GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + if _, ok := loadCounts[key]; !ok { + loadCounts[key] = 0 + } + loadCounts[key]++ + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) + + for k, v := range db { + if view, err := gee.Get(k); err != nil || view.String() != v { + t.Fatal("failed to get value of Tom") + } + if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 { + t.Fatalf("cache %s miss", k) + } + } + + if view, err := gee.Get("unknown"); err == nil { + t.Fatalf("the value of unknow should be empty, but %s got", view) + } +} + +func TestGetGroup(t *testing.T) { + groupName := "scores" + NewGroup(groupName, 2<<10, GetterFunc( + func(key string) (bytes []byte, err error) { return })) + if group := GetGroup(groupName); group == nil || group.name != groupName { + t.Fatalf("group %s not exist", groupName) + } + + if group := GetGroup(groupName + "111"); group != nil { + t.Fatalf("expect nil, but %s got", group.name) + } +} diff --git a/gee-cache/day4-consistent-hash/geecache/go.mod b/gee-cache/day4-consistent-hash/geecache/go.mod new file mode 100644 index 0000000..f9d454e --- /dev/null +++ b/gee-cache/day4-consistent-hash/geecache/go.mod @@ -0,0 +1,3 @@ +module geecache + +go 1.13 diff --git a/gee-cache/day4-consistent-hash/geecache/http.go b/gee-cache/day4-consistent-hash/geecache/http.go new file mode 100644 index 0000000..b9b994e --- /dev/null +++ b/gee-cache/day4-consistent-hash/geecache/http.go @@ -0,0 +1,62 @@ +package geecache + +import ( + "fmt" + "log" + "net/http" + "strings" +) + +const defaultBasePath = "/_geecache/" + +// HTTPPool implements PeerPicker for a pool of HTTP peers. +type HTTPPool struct { + // this peer's base URL, e.g. "https://example.net:8000" + self string + basePath string +} + +// NewHTTPPool initializes an HTTP pool of peers. +func NewHTTPPool(self string) *HTTPPool { + return &HTTPPool{ + self: self, + basePath: defaultBasePath, + } +} + +// Log info with server name +func (p *HTTPPool) Log(format string, v ...interface{}) { + log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...)) +} + +// ServeHTTP handle all http requests +func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, p.basePath) { + panic("HTTPPool serving unexpected path: " + r.URL.Path) + } + p.Log("%s %s", r.Method, r.URL.Path) + // /// required + parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2) + if len(parts) != 2 { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + groupName := parts[0] + key := parts[1] + + group := GetGroup(groupName) + if group == nil { + http.Error(w, "no such group: "+groupName, http.StatusNotFound) + return + } + + view, err := group.Get(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(view.ByteSlice()) +} diff --git a/gee-cache/day4-consistent-hash/geecache/lru/lru.go b/gee-cache/day4-consistent-hash/geecache/lru/lru.go new file mode 100644 index 0000000..dc1a317 --- /dev/null +++ b/gee-cache/day4-consistent-hash/geecache/lru/lru.go @@ -0,0 +1,79 @@ +package lru + +import "container/list" + +// Cache is a LRU cache. It is not safe for concurrent access. +type Cache struct { + maxBytes int64 + nbytes int64 + ll *list.List + cache map[string]*list.Element + // optional and executed when an entry is purged. + OnEvicted func(key string, value Value) +} + +type entry struct { + key string + value Value +} + +// Value use Len to count how many bytes it takes +type Value interface { + Len() int +} + +// New is the Constructor of Cache +func New(maxBytes int64, onEvicted func(string, Value)) *Cache { + return &Cache{ + maxBytes: maxBytes, + ll: list.New(), + cache: make(map[string]*list.Element), + OnEvicted: onEvicted, + } +} + +// Add adds a value to the cache. +func (c *Cache) Add(key string, value Value) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + c.nbytes += int64(value.Len()) - int64(kv.value.Len()) + kv.value = value + } else { + ele := c.ll.PushFront(&entry{key, value}) + c.cache[key] = ele + c.nbytes += int64(len(key)) + int64(value.Len()) + } + for c.maxBytes != 0 && c.maxBytes < c.nbytes { + c.RemoveOldest() + } +} + +// Get look ups a key's value +func (c *Cache) Get(key string) (value Value, ok bool) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + return kv.value, true + } + return +} + +// RemoveOldest removes the oldest item +func (c *Cache) RemoveOldest() { + ele := c.ll.Back() + if ele != nil { + c.ll.Remove(ele) + kv := ele.Value.(*entry) + delete(c.cache, kv.key) + c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len()) + if c.OnEvicted != nil { + c.OnEvicted(kv.key, kv.value) + } + } +} + +// Len the number of cache entries +func (c *Cache) Len() int { + return c.ll.Len() +} diff --git a/gee-cache/day4-consistent-hash/geecache/lru/lru_test.go b/gee-cache/day4-consistent-hash/geecache/lru/lru_test.go new file mode 100644 index 0000000..f2d3470 --- /dev/null +++ b/gee-cache/day4-consistent-hash/geecache/lru/lru_test.go @@ -0,0 +1,65 @@ +package lru + +import ( + "reflect" + "testing" +) + +type String string + +func (d String) Len() int { + return len(d) +} + +func TestGet(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key1", String("1234")) + if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" { + t.Fatalf("cache hit key1=1234 failed") + } + if _, ok := lru.Get("key2"); ok { + t.Fatalf("cache miss key2 failed") + } +} + +func TestRemoveoldest(t *testing.T) { + k1, k2, k3 := "key1", "key2", "k3" + v1, v2, v3 := "value1", "value2", "v3" + cap := len(k1 + k2 + v1 + v2) + lru := New(int64(cap), nil) + lru.Add(k1, String(v1)) + lru.Add(k2, String(v2)) + lru.Add(k3, String(v3)) + + if _, ok := lru.Get("key1"); ok || lru.Len() != 2 { + t.Fatalf("Removeoldest key1 failed") + } +} + +func TestOnEvicted(t *testing.T) { + keys := make([]string, 0) + callback := func(key string, value Value) { + keys = append(keys, key) + } + lru := New(int64(10), callback) + lru.Add("key1", String("123456")) + lru.Add("k2", String("k2")) + lru.Add("k3", String("k3")) + lru.Add("k4", String("k4")) + + expect := []string{"key1", "k2"} + + if !reflect.DeepEqual(expect, keys) { + t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect) + } +} + +func TestAdd(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key", String("1")) + lru.Add("key", String("111")) + + if lru.nbytes != int64(len("key")+len("111")) { + t.Fatal("expected 6 but got", lru.nbytes) + } +} diff --git a/gee-cache/day4-consistent-hash/go.mod b/gee-cache/day4-consistent-hash/go.mod new file mode 100644 index 0000000..d0fd3ba --- /dev/null +++ b/gee-cache/day4-consistent-hash/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require geecache v0.0.0 + +replace geecache => ./geecache diff --git a/gee-cache/day4-consistent-hash/main.go b/gee-cache/day4-consistent-hash/main.go new file mode 100644 index 0000000..5442dd7 --- /dev/null +++ b/gee-cache/day4-consistent-hash/main.go @@ -0,0 +1,38 @@ +package main + +/* +$ curl http://localhost:9999/_geecache/scores/Tom +630 + +$ curl http://localhost:9999/_geecache/scores/kkk +kkk not exist +*/ + +import ( + "fmt" + "geecache" + "log" + "net/http" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func main() { + geecache.NewGroup("scores", 2<<10, geecache.GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) + + addr := "localhost:9999" + peers := geecache.NewHTTPPool(addr) + log.Println("geecache is running at", addr) + log.Fatal(http.ListenAndServe(addr, peers)) +} diff --git a/gee-cache/day5-multi-nodes/geecache/byteview.go b/gee-cache/day5-multi-nodes/geecache/byteview.go new file mode 100644 index 0000000..3ee1022 --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/byteview.go @@ -0,0 +1,27 @@ +package geecache + +// A ByteView holds an immutable view of bytes. +type ByteView struct { + b []byte +} + +// Len returns the view's length +func (v ByteView) Len() int { + return len(v.b) +} + +// ByteSlice returns a copy of the data as a byte slice. +func (v ByteView) ByteSlice() []byte { + return cloneBytes(v.b) +} + +// String returns the data as a string, making a copy if necessary. +func (v ByteView) String() string { + return string(v.b) +} + +func cloneBytes(b []byte) []byte { + c := make([]byte, len(b)) + copy(c, b) + return c +} diff --git a/gee-cache/day5-multi-nodes/geecache/cache.go b/gee-cache/day5-multi-nodes/geecache/cache.go new file mode 100644 index 0000000..665c3f3 --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/cache.go @@ -0,0 +1,35 @@ +package geecache + +import ( + "geecache/lru" + "sync" +) + +type cache struct { + mu sync.Mutex + lru *lru.Cache + cacheBytes int64 +} + +func (c *cache) add(key string, value ByteView) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + c.lru = lru.New(c.cacheBytes, nil) + } + c.lru.Add(key, value) +} + +func (c *cache) get(key string) (value ByteView, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + return + } + + if v, ok := c.lru.Get(key); ok { + return v.(ByteView), ok + } + + return +} diff --git a/gee-cache/day5-multi-nodes/geecache/consistenthash/consistenthash.go b/gee-cache/day5-multi-nodes/geecache/consistenthash/consistenthash.go new file mode 100644 index 0000000..c8c9082 --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/consistenthash/consistenthash.go @@ -0,0 +1,58 @@ +package consistenthash + +import ( + "hash/crc32" + "sort" + "strconv" +) + +// Hash maps bytes to uint32 +type Hash func(data []byte) uint32 + +// Map constains all hashed keys +type Map struct { + hash Hash + replicas int + keys []int // Sorted + hashMap map[int]string +} + +// New creates a Map instance +func New(replicas int, fn Hash) *Map { + m := &Map{ + replicas: replicas, + hash: fn, + hashMap: make(map[int]string), + } + if m.hash == nil { + m.hash = crc32.ChecksumIEEE + } + return m +} + +// Add adds some keys to the hash. +func (m *Map) Add(keys ...string) { + for _, key := range keys { + for i := 0; i < m.replicas; i++ { + hash := int(m.hash([]byte(strconv.Itoa(i) + key))) + m.keys = append(m.keys, hash) + m.hashMap[hash] = key + } + } + sort.Ints(m.keys) +} + +// Get gets the closest item in the hash to the provided key. +func (m *Map) Get(key string) string { + if len(m.keys) == 0 { + return "" + } + + hash := int(m.hash([]byte(key))) + // Binary search for appropriate replica. + idx := sort.Search(len(m.keys), func(i int) bool { + return m.keys[i] >= hash + }) + + return m.hashMap[m.keys[idx%len(m.keys)]] +} diff --git a/gee-cache/day5-multi-nodes/geecache/consistenthash/consistenthash_test.go b/gee-cache/day5-multi-nodes/geecache/consistenthash/consistenthash_test.go new file mode 100644 index 0000000..34e1275 --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/consistenthash/consistenthash_test.go @@ -0,0 +1,43 @@ +package consistenthash + +import ( + "strconv" + "testing" +) + +func TestHashing(t *testing.T) { + hash := New(3, func(key []byte) uint32 { + i, _ := strconv.Atoi(string(key)) + return uint32(i) + }) + + // Given the above hash function, this will give replicas with "hashes": + // 2, 4, 6, 12, 14, 16, 22, 24, 26 + hash.Add("6", "4", "2") + + testCases := map[string]string{ + "2": "2", + "11": "2", + "23": "4", + "27": "2", + } + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + + // Adds 8, 18, 28 + hash.Add("8") + + // 27 should now map to 8. + testCases["27"] = "8" + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + +} diff --git a/gee-cache/day5-multi-nodes/geecache/geecache.go b/gee-cache/day5-multi-nodes/geecache/geecache.go new file mode 100644 index 0000000..5372a4a --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/geecache.go @@ -0,0 +1,116 @@ +package geecache + +import ( + "fmt" + "log" + "sync" +) + +// A Group is a cache namespace and associated data loaded spread over +type Group struct { + name string + getter Getter + mainCache cache + peers PeerPicker +} + +// A Getter loads data for a key. +type Getter interface { + Get(key string) ([]byte, error) +} + +// A GetterFunc implements Getter with a function. +type GetterFunc func(key string) ([]byte, error) + +// Get implements Getter interface function +func (f GetterFunc) Get(key string) ([]byte, error) { + return f(key) +} + +var ( + mu sync.RWMutex + groups = make(map[string]*Group) +) + +// NewGroup create a new instance of Group +func NewGroup(name string, cacheBytes int64, getter Getter) *Group { + if getter == nil { + panic("nil Getter") + } + mu.Lock() + defer mu.Unlock() + g := &Group{ + name: name, + getter: getter, + mainCache: cache{cacheBytes: cacheBytes}, + } + groups[name] = g + return g +} + +// GetGroup returns the named group previously created with NewGroup, or +// nil if there's no such group. +func GetGroup(name string) *Group { + mu.RLock() + g := groups[name] + mu.RUnlock() + return g +} + +// Get value for a key from cache +func (g *Group) Get(key string) (ByteView, error) { + if key == "" { + return ByteView{}, fmt.Errorf("key is required") + } + + if v, ok := g.mainCache.get(key); ok { + log.Println("[GeeCache] hit") + return v, nil + } + + return g.load(key) +} + +// RegisterPeers registers a PeerPicker for choosing remote peer +func (g *Group) RegisterPeers(peers PeerPicker) { + if g.peers != nil { + panic("RegisterPeerPicker called more than once") + } + g.peers = peers +} + +func (g *Group) load(key string) (value ByteView, err error) { + if g.peers != nil { + if peer, ok := g.peers.PickPeer(key); ok { + if value, err = g.getFromPeer(peer, key); err == nil { + return value, nil + } + log.Println("[GeeCache] Failed to get from peer", err) + } + } + + return g.getLocally(key) +} + +func (g *Group) populateCache(key string, value ByteView) { + g.mainCache.add(key, value) +} + +func (g *Group) getLocally(key string) (ByteView, error) { + bytes, err := g.getter.Get(key) + if err != nil { + return ByteView{}, err + + } + value := ByteView{b: cloneBytes(bytes)} + g.populateCache(key, value) + return value, nil +} + +func (g *Group) getFromPeer(peer PeerGetter, key string) (ByteView, error) { + bytes, err := peer.Get(g.name, key) + if err != nil { + return ByteView{}, err + } + return ByteView{b: bytes}, nil +} diff --git a/gee-cache/day5-multi-nodes/geecache/geecache_test.go b/gee-cache/day5-multi-nodes/geecache/geecache_test.go new file mode 100644 index 0000000..7ef9f4f --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/geecache_test.go @@ -0,0 +1,67 @@ +package geecache + +import ( + "fmt" + "log" + "reflect" + "testing" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func TestGetter(t *testing.T) { + var f Getter = GetterFunc(func(key string) ([]byte, error) { + return []byte(key), nil + }) + + expect := []byte("key") + if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) { + t.Fatal("callback failed") + } +} + +func TestGet(t *testing.T) { + loadCounts := make(map[string]int, len(db)) + gee := NewGroup("scores", 2<<10, GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + if _, ok := loadCounts[key]; !ok { + loadCounts[key] = 0 + } + loadCounts[key]++ + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) + + for k, v := range db { + if view, err := gee.Get(k); err != nil || view.String() != v { + t.Fatal("failed to get value of Tom") + } + if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 { + t.Fatalf("cache %s miss", k) + } + } + + if view, err := gee.Get("unknown"); err == nil { + t.Fatalf("the value of unknow should be empty, but %s got", view) + } +} + +func TestGetGroup(t *testing.T) { + groupName := "scores" + NewGroup(groupName, 2<<10, GetterFunc( + func(key string) (bytes []byte, err error) { return })) + if group := GetGroup(groupName); group == nil || group.name != groupName { + t.Fatalf("group %s not exist", groupName) + } + + if group := GetGroup(groupName + "111"); group != nil { + t.Fatalf("expect nil, but %s got", group.name) + } +} diff --git a/gee-cache/day5-multi-nodes/geecache/go.mod b/gee-cache/day5-multi-nodes/geecache/go.mod new file mode 100644 index 0000000..f9d454e --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/go.mod @@ -0,0 +1,3 @@ +module geecache + +go 1.13 diff --git a/gee-cache/day5-multi-nodes/geecache/http.go b/gee-cache/day5-multi-nodes/geecache/http.go new file mode 100644 index 0000000..815591f --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/http.go @@ -0,0 +1,128 @@ +package geecache + +import ( + "fmt" + "geecache/consistenthash" + "io/ioutil" + "log" + "net/http" + "net/url" + "strings" + "sync" +) + +const ( + defaultBasePath = "/_geecache/" + defaultReplicas = 50 +) + +// HTTPPool implements PeerPicker for a pool of HTTP peers. +type HTTPPool struct { + // this peer's base URL, e.g. "https://example.net:8000" + self string + basePath string + mu sync.Mutex // guards peers and httpGetters + peers *consistenthash.Map + httpGetters map[string]*httpGetter // keyed by e.g. "http://10.0.0.2:8008" +} + +// NewHTTPPool initializes an HTTP pool of peers. +func NewHTTPPool(self string) *HTTPPool { + return &HTTPPool{ + self: self, + basePath: defaultBasePath, + } +} + +// Log info with server name +func (p *HTTPPool) Log(format string, v ...interface{}) { + log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...)) +} + +// ServeHTTP handle all http requests +func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, p.basePath) { + panic("HTTPPool serving unexpected path: " + r.URL.Path) + } + p.Log("%s %s", r.Method, r.URL.Path) + // /// required + parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2) + if len(parts) != 2 { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + groupName := parts[0] + key := parts[1] + + group := GetGroup(groupName) + if group == nil { + http.Error(w, "no such group: "+groupName, http.StatusNotFound) + return + } + + view, err := group.Get(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(view.ByteSlice()) +} + +// Set updates the pool's list of peers. +func (p *HTTPPool) Set(peers ...string) { + p.mu.Lock() + defer p.mu.Unlock() + p.peers = consistenthash.New(defaultReplicas, nil) + p.peers.Add(peers...) + p.httpGetters = make(map[string]*httpGetter, len(peers)) + for _, peer := range peers { + p.httpGetters[peer] = &httpGetter{baseURL: peer + p.basePath} + } +} + +// PickPeer picks a peer according to key +func (p *HTTPPool) PickPeer(key string) (PeerGetter, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if peer := p.peers.Get(key); peer != "" && peer != p.self { + p.Log("Pick peer %s", peer) + return p.httpGetters[peer], true + } + return nil, false +} + +var _ PeerPicker = (*HTTPPool)(nil) + +type httpGetter struct { + baseURL string +} + +func (h *httpGetter) Get(group string, key string) ([]byte, error) { + u := fmt.Sprintf( + "%v%v/%v", + h.baseURL, + url.QueryEscape(group), + url.QueryEscape(key), + ) + res, err := http.Get(u) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("server returned: %v", res.Status) + } + + bytes, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("reading response body: %v", err) + } + + return bytes, nil +} + +var _ PeerGetter = (*httpGetter)(nil) diff --git a/gee-cache/day5-multi-nodes/geecache/lru/lru.go b/gee-cache/day5-multi-nodes/geecache/lru/lru.go new file mode 100644 index 0000000..dc1a317 --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/lru/lru.go @@ -0,0 +1,79 @@ +package lru + +import "container/list" + +// Cache is a LRU cache. It is not safe for concurrent access. +type Cache struct { + maxBytes int64 + nbytes int64 + ll *list.List + cache map[string]*list.Element + // optional and executed when an entry is purged. + OnEvicted func(key string, value Value) +} + +type entry struct { + key string + value Value +} + +// Value use Len to count how many bytes it takes +type Value interface { + Len() int +} + +// New is the Constructor of Cache +func New(maxBytes int64, onEvicted func(string, Value)) *Cache { + return &Cache{ + maxBytes: maxBytes, + ll: list.New(), + cache: make(map[string]*list.Element), + OnEvicted: onEvicted, + } +} + +// Add adds a value to the cache. +func (c *Cache) Add(key string, value Value) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + c.nbytes += int64(value.Len()) - int64(kv.value.Len()) + kv.value = value + } else { + ele := c.ll.PushFront(&entry{key, value}) + c.cache[key] = ele + c.nbytes += int64(len(key)) + int64(value.Len()) + } + for c.maxBytes != 0 && c.maxBytes < c.nbytes { + c.RemoveOldest() + } +} + +// Get look ups a key's value +func (c *Cache) Get(key string) (value Value, ok bool) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + return kv.value, true + } + return +} + +// RemoveOldest removes the oldest item +func (c *Cache) RemoveOldest() { + ele := c.ll.Back() + if ele != nil { + c.ll.Remove(ele) + kv := ele.Value.(*entry) + delete(c.cache, kv.key) + c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len()) + if c.OnEvicted != nil { + c.OnEvicted(kv.key, kv.value) + } + } +} + +// Len the number of cache entries +func (c *Cache) Len() int { + return c.ll.Len() +} diff --git a/gee-cache/day5-multi-nodes/geecache/lru/lru_test.go b/gee-cache/day5-multi-nodes/geecache/lru/lru_test.go new file mode 100644 index 0000000..f2d3470 --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/lru/lru_test.go @@ -0,0 +1,65 @@ +package lru + +import ( + "reflect" + "testing" +) + +type String string + +func (d String) Len() int { + return len(d) +} + +func TestGet(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key1", String("1234")) + if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" { + t.Fatalf("cache hit key1=1234 failed") + } + if _, ok := lru.Get("key2"); ok { + t.Fatalf("cache miss key2 failed") + } +} + +func TestRemoveoldest(t *testing.T) { + k1, k2, k3 := "key1", "key2", "k3" + v1, v2, v3 := "value1", "value2", "v3" + cap := len(k1 + k2 + v1 + v2) + lru := New(int64(cap), nil) + lru.Add(k1, String(v1)) + lru.Add(k2, String(v2)) + lru.Add(k3, String(v3)) + + if _, ok := lru.Get("key1"); ok || lru.Len() != 2 { + t.Fatalf("Removeoldest key1 failed") + } +} + +func TestOnEvicted(t *testing.T) { + keys := make([]string, 0) + callback := func(key string, value Value) { + keys = append(keys, key) + } + lru := New(int64(10), callback) + lru.Add("key1", String("123456")) + lru.Add("k2", String("k2")) + lru.Add("k3", String("k3")) + lru.Add("k4", String("k4")) + + expect := []string{"key1", "k2"} + + if !reflect.DeepEqual(expect, keys) { + t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect) + } +} + +func TestAdd(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key", String("1")) + lru.Add("key", String("111")) + + if lru.nbytes != int64(len("key")+len("111")) { + t.Fatal("expected 6 but got", lru.nbytes) + } +} diff --git a/gee-cache/day5-multi-nodes/geecache/peers.go b/gee-cache/day5-multi-nodes/geecache/peers.go new file mode 100644 index 0000000..6246267 --- /dev/null +++ b/gee-cache/day5-multi-nodes/geecache/peers.go @@ -0,0 +1,14 @@ +package geecache + +// PeerPicker is the interface that must be implemented to locate +// the peer that owns a specific key. +// PeerPicker 的 PickPeer() 方法用于根据传入的 key 选择相应节点 PeerGetter。 +type PeerPicker interface { + PickPeer(key string) (peer PeerGetter, ok bool) +} + +// PeerGetter is the interface that must be implemented by a peer. +// 接口 PeerGetter 的 Get() 方法用于从对应 group 查找缓存值。PeerGetter 就对应于上述流程中的 HTTP 客户端 +type PeerGetter interface { + Get(group string, key string) ([]byte, error) +} diff --git a/gee-cache/day5-multi-nodes/go.mod b/gee-cache/day5-multi-nodes/go.mod new file mode 100644 index 0000000..d0fd3ba --- /dev/null +++ b/gee-cache/day5-multi-nodes/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require geecache v0.0.0 + +replace geecache => ./geecache diff --git a/gee-cache/day5-multi-nodes/main.go b/gee-cache/day5-multi-nodes/main.go new file mode 100644 index 0000000..56abc7e --- /dev/null +++ b/gee-cache/day5-multi-nodes/main.go @@ -0,0 +1,86 @@ +package main + +/* +$ curl "http://localhost:9999/api?key=Tom" +630 + +$ curl "http://localhost:9999/api?key=kkk" +kkk not exist +*/ + +import ( + "flag" + "fmt" + "geecache" + "log" + "net/http" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func createGroup() *geecache.Group { + return geecache.NewGroup("scores", 2<<10, geecache.GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) +} + +func startCacheServer(addr string, addrs []string, gee *geecache.Group) { + peers := geecache.NewHTTPPool(addr) + peers.Set(addrs...) + gee.RegisterPeers(peers) + log.Println("geecache is running at", addr) + log.Fatal(http.ListenAndServe(addr[7:], peers)) +} + +func startAPIServer(apiAddr string, gee *geecache.Group) { + http.Handle("/api", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + key := r.URL.Query().Get("key") + view, err := gee.Get(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(view.ByteSlice()) + + })) + log.Println("fontend server is running at", apiAddr) + log.Fatal(http.ListenAndServe(apiAddr[7:], nil)) + +} + +func main() { + var port int + var api bool + flag.IntVar(&port, "port", 8001, "Geecache server port") + flag.BoolVar(&api, "api", false, "Start a api server?") + flag.Parse() + + apiAddr := "http://localhost:9999" + addrMap := map[int]string{ + 8001: "http://localhost:8001", + 8002: "http://localhost:8002", + 8003: "http://localhost:8003", + } + + var addrs []string + for _, v := range addrMap { + addrs = append(addrs, v) + } + + gee := createGroup() + if api { + go startAPIServer(apiAddr, gee) + } + startCacheServer(addrMap[port], addrs, gee) +} diff --git a/gee-cache/day5-multi-nodes/run.sh b/gee-cache/day5-multi-nodes/run.sh new file mode 100755 index 0000000..066979d --- /dev/null +++ b/gee-cache/day5-multi-nodes/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash +trap "rm server;kill 0" EXIT + +go build -o server +./server -port=8001 & +./server -port=8002 & +./server -port=8003 -api=1 & + +sleep 2 +echo ">>> start test" +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & + +wait \ No newline at end of file diff --git a/gee-cache/day6-single-flight/geecache/byteview.go b/gee-cache/day6-single-flight/geecache/byteview.go new file mode 100644 index 0000000..3ee1022 --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/byteview.go @@ -0,0 +1,27 @@ +package geecache + +// A ByteView holds an immutable view of bytes. +type ByteView struct { + b []byte +} + +// Len returns the view's length +func (v ByteView) Len() int { + return len(v.b) +} + +// ByteSlice returns a copy of the data as a byte slice. +func (v ByteView) ByteSlice() []byte { + return cloneBytes(v.b) +} + +// String returns the data as a string, making a copy if necessary. +func (v ByteView) String() string { + return string(v.b) +} + +func cloneBytes(b []byte) []byte { + c := make([]byte, len(b)) + copy(c, b) + return c +} diff --git a/gee-cache/day6-single-flight/geecache/cache.go b/gee-cache/day6-single-flight/geecache/cache.go new file mode 100644 index 0000000..665c3f3 --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/cache.go @@ -0,0 +1,35 @@ +package geecache + +import ( + "geecache/lru" + "sync" +) + +type cache struct { + mu sync.Mutex + lru *lru.Cache + cacheBytes int64 +} + +func (c *cache) add(key string, value ByteView) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + c.lru = lru.New(c.cacheBytes, nil) + } + c.lru.Add(key, value) +} + +func (c *cache) get(key string) (value ByteView, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + return + } + + if v, ok := c.lru.Get(key); ok { + return v.(ByteView), ok + } + + return +} diff --git a/gee-cache/day6-single-flight/geecache/consistenthash/consistenthash.go b/gee-cache/day6-single-flight/geecache/consistenthash/consistenthash.go new file mode 100644 index 0000000..c8c9082 --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/consistenthash/consistenthash.go @@ -0,0 +1,58 @@ +package consistenthash + +import ( + "hash/crc32" + "sort" + "strconv" +) + +// Hash maps bytes to uint32 +type Hash func(data []byte) uint32 + +// Map constains all hashed keys +type Map struct { + hash Hash + replicas int + keys []int // Sorted + hashMap map[int]string +} + +// New creates a Map instance +func New(replicas int, fn Hash) *Map { + m := &Map{ + replicas: replicas, + hash: fn, + hashMap: make(map[int]string), + } + if m.hash == nil { + m.hash = crc32.ChecksumIEEE + } + return m +} + +// Add adds some keys to the hash. +func (m *Map) Add(keys ...string) { + for _, key := range keys { + for i := 0; i < m.replicas; i++ { + hash := int(m.hash([]byte(strconv.Itoa(i) + key))) + m.keys = append(m.keys, hash) + m.hashMap[hash] = key + } + } + sort.Ints(m.keys) +} + +// Get gets the closest item in the hash to the provided key. +func (m *Map) Get(key string) string { + if len(m.keys) == 0 { + return "" + } + + hash := int(m.hash([]byte(key))) + // Binary search for appropriate replica. + idx := sort.Search(len(m.keys), func(i int) bool { + return m.keys[i] >= hash + }) + + return m.hashMap[m.keys[idx%len(m.keys)]] +} diff --git a/gee-cache/day6-single-flight/geecache/consistenthash/consistenthash_test.go b/gee-cache/day6-single-flight/geecache/consistenthash/consistenthash_test.go new file mode 100644 index 0000000..34e1275 --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/consistenthash/consistenthash_test.go @@ -0,0 +1,43 @@ +package consistenthash + +import ( + "strconv" + "testing" +) + +func TestHashing(t *testing.T) { + hash := New(3, func(key []byte) uint32 { + i, _ := strconv.Atoi(string(key)) + return uint32(i) + }) + + // Given the above hash function, this will give replicas with "hashes": + // 2, 4, 6, 12, 14, 16, 22, 24, 26 + hash.Add("6", "4", "2") + + testCases := map[string]string{ + "2": "2", + "11": "2", + "23": "4", + "27": "2", + } + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + + // Adds 8, 18, 28 + hash.Add("8") + + // 27 should now map to 8. + testCases["27"] = "8" + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + +} diff --git a/gee-cache/day6-single-flight/geecache/geecache.go b/gee-cache/day6-single-flight/geecache/geecache.go new file mode 100644 index 0000000..69004ac --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/geecache.go @@ -0,0 +1,130 @@ +package geecache + +import ( + "fmt" + "geecache/singleflight" + "log" + "sync" +) + +// A Group is a cache namespace and associated data loaded spread over +type Group struct { + name string + getter Getter + mainCache cache + peers PeerPicker + // use singleflight.Group to make sure that + // each key is only fetched once + loader *singleflight.Group +} + +// A Getter loads data for a key. +type Getter interface { + Get(key string) ([]byte, error) +} + +// A GetterFunc implements Getter with a function. +type GetterFunc func(key string) ([]byte, error) + +// Get implements Getter interface function +func (f GetterFunc) Get(key string) ([]byte, error) { + return f(key) +} + +var ( + mu sync.RWMutex + groups = make(map[string]*Group) +) + +// NewGroup create a new instance of Group +func NewGroup(name string, cacheBytes int64, getter Getter) *Group { + if getter == nil { + panic("nil Getter") + } + mu.Lock() + defer mu.Unlock() + g := &Group{ + name: name, + getter: getter, + mainCache: cache{cacheBytes: cacheBytes}, + loader: &singleflight.Group{}, + } + groups[name] = g + return g +} + +// GetGroup returns the named group previously created with NewGroup, or +// nil if there's no such group. +func GetGroup(name string) *Group { + mu.RLock() + g := groups[name] + mu.RUnlock() + return g +} + +// Get value for a key from cache +func (g *Group) Get(key string) (ByteView, error) { + if key == "" { + return ByteView{}, fmt.Errorf("key is required") + } + + if v, ok := g.mainCache.get(key); ok { + log.Println("[GeeCache] hit") + return v, nil + } + + return g.load(key) +} + +// RegisterPeers registers a PeerPicker for choosing remote peer +func (g *Group) RegisterPeers(peers PeerPicker) { + if g.peers != nil { + panic("RegisterPeerPicker called more than once") + } + g.peers = peers +} + +func (g *Group) load(key string) (value ByteView, err error) { + // each key is only fetched once (either locally or remotely) + // regardless of the number of concurrent callers. + viewi, err := g.loader.Do(key, func() (interface{}, error) { + if g.peers != nil { + if peer, ok := g.peers.PickPeer(key); ok { + if value, err = g.getFromPeer(peer, key); err == nil { + return value, nil + } + log.Println("[GeeCache] Failed to get from peer", err) + } + } + + return g.getLocally(key) + }) + + if err == nil { + return viewi.(ByteView), nil + } + return +} + +func (g *Group) populateCache(key string, value ByteView) { + g.mainCache.add(key, value) +} + +func (g *Group) getLocally(key string) (ByteView, error) { + bytes, err := g.getter.Get(key) + if err != nil { + return ByteView{}, err + + } + value := ByteView{b: cloneBytes(bytes)} + g.populateCache(key, value) + return value, nil +} + +func (g *Group) getFromPeer(peer PeerGetter, key string) (ByteView, error) { + bytes, err := peer.Get(g.name, key) + if err != nil { + return ByteView{}, err + } + return ByteView{b: bytes}, nil +} diff --git a/gee-cache/day6-single-flight/geecache/geecache_test.go b/gee-cache/day6-single-flight/geecache/geecache_test.go new file mode 100644 index 0000000..7ef9f4f --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/geecache_test.go @@ -0,0 +1,67 @@ +package geecache + +import ( + "fmt" + "log" + "reflect" + "testing" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func TestGetter(t *testing.T) { + var f Getter = GetterFunc(func(key string) ([]byte, error) { + return []byte(key), nil + }) + + expect := []byte("key") + if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) { + t.Fatal("callback failed") + } +} + +func TestGet(t *testing.T) { + loadCounts := make(map[string]int, len(db)) + gee := NewGroup("scores", 2<<10, GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + if _, ok := loadCounts[key]; !ok { + loadCounts[key] = 0 + } + loadCounts[key]++ + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) + + for k, v := range db { + if view, err := gee.Get(k); err != nil || view.String() != v { + t.Fatal("failed to get value of Tom") + } + if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 { + t.Fatalf("cache %s miss", k) + } + } + + if view, err := gee.Get("unknown"); err == nil { + t.Fatalf("the value of unknow should be empty, but %s got", view) + } +} + +func TestGetGroup(t *testing.T) { + groupName := "scores" + NewGroup(groupName, 2<<10, GetterFunc( + func(key string) (bytes []byte, err error) { return })) + if group := GetGroup(groupName); group == nil || group.name != groupName { + t.Fatalf("group %s not exist", groupName) + } + + if group := GetGroup(groupName + "111"); group != nil { + t.Fatalf("expect nil, but %s got", group.name) + } +} diff --git a/gee-cache/day6-single-flight/geecache/go.mod b/gee-cache/day6-single-flight/geecache/go.mod new file mode 100644 index 0000000..f9d454e --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/go.mod @@ -0,0 +1,3 @@ +module geecache + +go 1.13 diff --git a/gee-cache/day6-single-flight/geecache/http.go b/gee-cache/day6-single-flight/geecache/http.go new file mode 100644 index 0000000..815591f --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/http.go @@ -0,0 +1,128 @@ +package geecache + +import ( + "fmt" + "geecache/consistenthash" + "io/ioutil" + "log" + "net/http" + "net/url" + "strings" + "sync" +) + +const ( + defaultBasePath = "/_geecache/" + defaultReplicas = 50 +) + +// HTTPPool implements PeerPicker for a pool of HTTP peers. +type HTTPPool struct { + // this peer's base URL, e.g. "https://example.net:8000" + self string + basePath string + mu sync.Mutex // guards peers and httpGetters + peers *consistenthash.Map + httpGetters map[string]*httpGetter // keyed by e.g. "http://10.0.0.2:8008" +} + +// NewHTTPPool initializes an HTTP pool of peers. +func NewHTTPPool(self string) *HTTPPool { + return &HTTPPool{ + self: self, + basePath: defaultBasePath, + } +} + +// Log info with server name +func (p *HTTPPool) Log(format string, v ...interface{}) { + log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...)) +} + +// ServeHTTP handle all http requests +func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, p.basePath) { + panic("HTTPPool serving unexpected path: " + r.URL.Path) + } + p.Log("%s %s", r.Method, r.URL.Path) + // /// required + parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2) + if len(parts) != 2 { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + groupName := parts[0] + key := parts[1] + + group := GetGroup(groupName) + if group == nil { + http.Error(w, "no such group: "+groupName, http.StatusNotFound) + return + } + + view, err := group.Get(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(view.ByteSlice()) +} + +// Set updates the pool's list of peers. +func (p *HTTPPool) Set(peers ...string) { + p.mu.Lock() + defer p.mu.Unlock() + p.peers = consistenthash.New(defaultReplicas, nil) + p.peers.Add(peers...) + p.httpGetters = make(map[string]*httpGetter, len(peers)) + for _, peer := range peers { + p.httpGetters[peer] = &httpGetter{baseURL: peer + p.basePath} + } +} + +// PickPeer picks a peer according to key +func (p *HTTPPool) PickPeer(key string) (PeerGetter, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if peer := p.peers.Get(key); peer != "" && peer != p.self { + p.Log("Pick peer %s", peer) + return p.httpGetters[peer], true + } + return nil, false +} + +var _ PeerPicker = (*HTTPPool)(nil) + +type httpGetter struct { + baseURL string +} + +func (h *httpGetter) Get(group string, key string) ([]byte, error) { + u := fmt.Sprintf( + "%v%v/%v", + h.baseURL, + url.QueryEscape(group), + url.QueryEscape(key), + ) + res, err := http.Get(u) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("server returned: %v", res.Status) + } + + bytes, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("reading response body: %v", err) + } + + return bytes, nil +} + +var _ PeerGetter = (*httpGetter)(nil) diff --git a/gee-cache/day6-single-flight/geecache/lru/lru.go b/gee-cache/day6-single-flight/geecache/lru/lru.go new file mode 100644 index 0000000..dc1a317 --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/lru/lru.go @@ -0,0 +1,79 @@ +package lru + +import "container/list" + +// Cache is a LRU cache. It is not safe for concurrent access. +type Cache struct { + maxBytes int64 + nbytes int64 + ll *list.List + cache map[string]*list.Element + // optional and executed when an entry is purged. + OnEvicted func(key string, value Value) +} + +type entry struct { + key string + value Value +} + +// Value use Len to count how many bytes it takes +type Value interface { + Len() int +} + +// New is the Constructor of Cache +func New(maxBytes int64, onEvicted func(string, Value)) *Cache { + return &Cache{ + maxBytes: maxBytes, + ll: list.New(), + cache: make(map[string]*list.Element), + OnEvicted: onEvicted, + } +} + +// Add adds a value to the cache. +func (c *Cache) Add(key string, value Value) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + c.nbytes += int64(value.Len()) - int64(kv.value.Len()) + kv.value = value + } else { + ele := c.ll.PushFront(&entry{key, value}) + c.cache[key] = ele + c.nbytes += int64(len(key)) + int64(value.Len()) + } + for c.maxBytes != 0 && c.maxBytes < c.nbytes { + c.RemoveOldest() + } +} + +// Get look ups a key's value +func (c *Cache) Get(key string) (value Value, ok bool) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + return kv.value, true + } + return +} + +// RemoveOldest removes the oldest item +func (c *Cache) RemoveOldest() { + ele := c.ll.Back() + if ele != nil { + c.ll.Remove(ele) + kv := ele.Value.(*entry) + delete(c.cache, kv.key) + c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len()) + if c.OnEvicted != nil { + c.OnEvicted(kv.key, kv.value) + } + } +} + +// Len the number of cache entries +func (c *Cache) Len() int { + return c.ll.Len() +} diff --git a/gee-cache/day6-single-flight/geecache/lru/lru_test.go b/gee-cache/day6-single-flight/geecache/lru/lru_test.go new file mode 100644 index 0000000..f2d3470 --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/lru/lru_test.go @@ -0,0 +1,65 @@ +package lru + +import ( + "reflect" + "testing" +) + +type String string + +func (d String) Len() int { + return len(d) +} + +func TestGet(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key1", String("1234")) + if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" { + t.Fatalf("cache hit key1=1234 failed") + } + if _, ok := lru.Get("key2"); ok { + t.Fatalf("cache miss key2 failed") + } +} + +func TestRemoveoldest(t *testing.T) { + k1, k2, k3 := "key1", "key2", "k3" + v1, v2, v3 := "value1", "value2", "v3" + cap := len(k1 + k2 + v1 + v2) + lru := New(int64(cap), nil) + lru.Add(k1, String(v1)) + lru.Add(k2, String(v2)) + lru.Add(k3, String(v3)) + + if _, ok := lru.Get("key1"); ok || lru.Len() != 2 { + t.Fatalf("Removeoldest key1 failed") + } +} + +func TestOnEvicted(t *testing.T) { + keys := make([]string, 0) + callback := func(key string, value Value) { + keys = append(keys, key) + } + lru := New(int64(10), callback) + lru.Add("key1", String("123456")) + lru.Add("k2", String("k2")) + lru.Add("k3", String("k3")) + lru.Add("k4", String("k4")) + + expect := []string{"key1", "k2"} + + if !reflect.DeepEqual(expect, keys) { + t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect) + } +} + +func TestAdd(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key", String("1")) + lru.Add("key", String("111")) + + if lru.nbytes != int64(len("key")+len("111")) { + t.Fatal("expected 6 but got", lru.nbytes) + } +} diff --git a/gee-cache/day6-single-flight/geecache/peers.go b/gee-cache/day6-single-flight/geecache/peers.go new file mode 100644 index 0000000..8d010e2 --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/peers.go @@ -0,0 +1,12 @@ +package geecache + +// PeerPicker is the interface that must be implemented to locate +// the peer that owns a specific key. +type PeerPicker interface { + PickPeer(key string) (peer PeerGetter, ok bool) +} + +// PeerGetter is the interface that must be implemented by a peer. +type PeerGetter interface { + Get(group string, key string) ([]byte, error) +} diff --git a/gee-cache/day6-single-flight/geecache/singleflight/singleflight.go b/gee-cache/day6-single-flight/geecache/singleflight/singleflight.go new file mode 100644 index 0000000..85bd0dd --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/singleflight/singleflight.go @@ -0,0 +1,46 @@ +package singleflight + +import "sync" + +// call is an in-flight or completed Do call +type call struct { + wg sync.WaitGroup + val interface{} + err error +} + +// Group represents a class of work and forms a namespace in which +// units of work can be executed with duplicate suppression. +type Group struct { + mu sync.Mutex // protects m + m map[string]*call // lazily initialized +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + g.mu.Unlock() + c.wg.Wait() + return c.val, c.err + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + c.val, c.err = fn() + c.wg.Done() + + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() + + return c.val, c.err +} diff --git a/gee-cache/day6-single-flight/geecache/singleflight/singleflight_test.go b/gee-cache/day6-single-flight/geecache/singleflight/singleflight_test.go new file mode 100644 index 0000000..450951a --- /dev/null +++ b/gee-cache/day6-single-flight/geecache/singleflight/singleflight_test.go @@ -0,0 +1,16 @@ +package singleflight + +import ( + "testing" +) + +func TestDo(t *testing.T) { + var g Group + v, err := g.Do("key", func() (interface{}, error) { + return "bar", nil + }) + + if v != "bar" || err != nil { + t.Errorf("Do v = %v, error = %v", v, err) + } +} diff --git a/gee-cache/day6-single-flight/go.mod b/gee-cache/day6-single-flight/go.mod new file mode 100644 index 0000000..d0fd3ba --- /dev/null +++ b/gee-cache/day6-single-flight/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require geecache v0.0.0 + +replace geecache => ./geecache diff --git a/gee-cache/day6-single-flight/main.go b/gee-cache/day6-single-flight/main.go new file mode 100644 index 0000000..56abc7e --- /dev/null +++ b/gee-cache/day6-single-flight/main.go @@ -0,0 +1,86 @@ +package main + +/* +$ curl "http://localhost:9999/api?key=Tom" +630 + +$ curl "http://localhost:9999/api?key=kkk" +kkk not exist +*/ + +import ( + "flag" + "fmt" + "geecache" + "log" + "net/http" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func createGroup() *geecache.Group { + return geecache.NewGroup("scores", 2<<10, geecache.GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) +} + +func startCacheServer(addr string, addrs []string, gee *geecache.Group) { + peers := geecache.NewHTTPPool(addr) + peers.Set(addrs...) + gee.RegisterPeers(peers) + log.Println("geecache is running at", addr) + log.Fatal(http.ListenAndServe(addr[7:], peers)) +} + +func startAPIServer(apiAddr string, gee *geecache.Group) { + http.Handle("/api", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + key := r.URL.Query().Get("key") + view, err := gee.Get(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(view.ByteSlice()) + + })) + log.Println("fontend server is running at", apiAddr) + log.Fatal(http.ListenAndServe(apiAddr[7:], nil)) + +} + +func main() { + var port int + var api bool + flag.IntVar(&port, "port", 8001, "Geecache server port") + flag.BoolVar(&api, "api", false, "Start a api server?") + flag.Parse() + + apiAddr := "http://localhost:9999" + addrMap := map[int]string{ + 8001: "http://localhost:8001", + 8002: "http://localhost:8002", + 8003: "http://localhost:8003", + } + + var addrs []string + for _, v := range addrMap { + addrs = append(addrs, v) + } + + gee := createGroup() + if api { + go startAPIServer(apiAddr, gee) + } + startCacheServer(addrMap[port], addrs, gee) +} diff --git a/gee-cache/day6-single-flight/run.sh b/gee-cache/day6-single-flight/run.sh new file mode 100755 index 0000000..28b421c --- /dev/null +++ b/gee-cache/day6-single-flight/run.sh @@ -0,0 +1,19 @@ +#!/bin/bash +trap "rm server;kill 0" EXIT + +go build -o server +./server -port=8001 & +./server -port=8002 & +./server -port=8003 -api=1 & + +sleep 2 +echo ">>> start test" +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & + +wait \ No newline at end of file diff --git a/gee-cache/day7-proto-buf/geecache/byteview.go b/gee-cache/day7-proto-buf/geecache/byteview.go new file mode 100644 index 0000000..3ee1022 --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/byteview.go @@ -0,0 +1,27 @@ +package geecache + +// A ByteView holds an immutable view of bytes. +type ByteView struct { + b []byte +} + +// Len returns the view's length +func (v ByteView) Len() int { + return len(v.b) +} + +// ByteSlice returns a copy of the data as a byte slice. +func (v ByteView) ByteSlice() []byte { + return cloneBytes(v.b) +} + +// String returns the data as a string, making a copy if necessary. +func (v ByteView) String() string { + return string(v.b) +} + +func cloneBytes(b []byte) []byte { + c := make([]byte, len(b)) + copy(c, b) + return c +} diff --git a/gee-cache/day7-proto-buf/geecache/cache.go b/gee-cache/day7-proto-buf/geecache/cache.go new file mode 100644 index 0000000..665c3f3 --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/cache.go @@ -0,0 +1,35 @@ +package geecache + +import ( + "geecache/lru" + "sync" +) + +type cache struct { + mu sync.Mutex + lru *lru.Cache + cacheBytes int64 +} + +func (c *cache) add(key string, value ByteView) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + c.lru = lru.New(c.cacheBytes, nil) + } + c.lru.Add(key, value) +} + +func (c *cache) get(key string) (value ByteView, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + return + } + + if v, ok := c.lru.Get(key); ok { + return v.(ByteView), ok + } + + return +} diff --git a/gee-cache/day7-proto-buf/geecache/consistenthash/consistenthash.go b/gee-cache/day7-proto-buf/geecache/consistenthash/consistenthash.go new file mode 100644 index 0000000..c8c9082 --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/consistenthash/consistenthash.go @@ -0,0 +1,58 @@ +package consistenthash + +import ( + "hash/crc32" + "sort" + "strconv" +) + +// Hash maps bytes to uint32 +type Hash func(data []byte) uint32 + +// Map constains all hashed keys +type Map struct { + hash Hash + replicas int + keys []int // Sorted + hashMap map[int]string +} + +// New creates a Map instance +func New(replicas int, fn Hash) *Map { + m := &Map{ + replicas: replicas, + hash: fn, + hashMap: make(map[int]string), + } + if m.hash == nil { + m.hash = crc32.ChecksumIEEE + } + return m +} + +// Add adds some keys to the hash. +func (m *Map) Add(keys ...string) { + for _, key := range keys { + for i := 0; i < m.replicas; i++ { + hash := int(m.hash([]byte(strconv.Itoa(i) + key))) + m.keys = append(m.keys, hash) + m.hashMap[hash] = key + } + } + sort.Ints(m.keys) +} + +// Get gets the closest item in the hash to the provided key. +func (m *Map) Get(key string) string { + if len(m.keys) == 0 { + return "" + } + + hash := int(m.hash([]byte(key))) + // Binary search for appropriate replica. + idx := sort.Search(len(m.keys), func(i int) bool { + return m.keys[i] >= hash + }) + + return m.hashMap[m.keys[idx%len(m.keys)]] +} diff --git a/gee-cache/day7-proto-buf/geecache/consistenthash/consistenthash_test.go b/gee-cache/day7-proto-buf/geecache/consistenthash/consistenthash_test.go new file mode 100644 index 0000000..34e1275 --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/consistenthash/consistenthash_test.go @@ -0,0 +1,43 @@ +package consistenthash + +import ( + "strconv" + "testing" +) + +func TestHashing(t *testing.T) { + hash := New(3, func(key []byte) uint32 { + i, _ := strconv.Atoi(string(key)) + return uint32(i) + }) + + // Given the above hash function, this will give replicas with "hashes": + // 2, 4, 6, 12, 14, 16, 22, 24, 26 + hash.Add("6", "4", "2") + + testCases := map[string]string{ + "2": "2", + "11": "2", + "23": "4", + "27": "2", + } + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + + // Adds 8, 18, 28 + hash.Add("8") + + // 27 should now map to 8. + testCases["27"] = "8" + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + +} diff --git a/gee-cache/day7-proto-buf/geecache/geecache.go b/gee-cache/day7-proto-buf/geecache/geecache.go new file mode 100644 index 0000000..cbce7b9 --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/geecache.go @@ -0,0 +1,136 @@ +package geecache + +import ( + "fmt" + pb "geecache/geecachepb" + "geecache/singleflight" + "log" + "sync" +) + +// A Group is a cache namespace and associated data loaded spread over +type Group struct { + name string + getter Getter + mainCache cache + peers PeerPicker + // use singleflight.Group to make sure that + // each key is only fetched once + loader *singleflight.Group +} + +// A Getter loads data for a key. +type Getter interface { + Get(key string) ([]byte, error) +} + +// A GetterFunc implements Getter with a function. +type GetterFunc func(key string) ([]byte, error) + +// Get implements Getter interface function +func (f GetterFunc) Get(key string) ([]byte, error) { + return f(key) +} + +var ( + mu sync.RWMutex + groups = make(map[string]*Group) +) + +// NewGroup create a new instance of Group +func NewGroup(name string, cacheBytes int64, getter Getter) *Group { + if getter == nil { + panic("nil Getter") + } + mu.Lock() + defer mu.Unlock() + g := &Group{ + name: name, + getter: getter, + mainCache: cache{cacheBytes: cacheBytes}, + loader: &singleflight.Group{}, + } + groups[name] = g + return g +} + +// GetGroup returns the named group previously created with NewGroup, or +// nil if there's no such group. +func GetGroup(name string) *Group { + mu.RLock() + g := groups[name] + mu.RUnlock() + return g +} + +// Get value for a key from cache +func (g *Group) Get(key string) (ByteView, error) { + if key == "" { + return ByteView{}, fmt.Errorf("key is required") + } + + if v, ok := g.mainCache.get(key); ok { + log.Println("[GeeCache] hit") + return v, nil + } + + return g.load(key) +} + +// RegisterPeers registers a PeerPicker for choosing remote peer +func (g *Group) RegisterPeers(peers PeerPicker) { + if g.peers != nil { + panic("RegisterPeerPicker called more than once") + } + g.peers = peers +} + +func (g *Group) load(key string) (value ByteView, err error) { + // each key is only fetched once (either locally or remotely) + // regardless of the number of concurrent callers. + viewi, err := g.loader.Do(key, func() (interface{}, error) { + if g.peers != nil { + if peer, ok := g.peers.PickPeer(key); ok { + if value, err = g.getFromPeer(peer, key); err == nil { + return value, nil + } + log.Println("[GeeCache] Failed to get from peer", err) + } + } + + return g.getLocally(key) + }) + + if err == nil { + return viewi.(ByteView), nil + } + return +} + +func (g *Group) populateCache(key string, value ByteView) { + g.mainCache.add(key, value) +} + +func (g *Group) getLocally(key string) (ByteView, error) { + bytes, err := g.getter.Get(key) + if err != nil { + return ByteView{}, err + + } + value := ByteView{b: cloneBytes(bytes)} + g.populateCache(key, value) + return value, nil +} + +func (g *Group) getFromPeer(peer PeerGetter, key string) (ByteView, error) { + req := &pb.Request{ + Group: g.name, + Key: key, + } + res := &pb.Response{} + err := peer.Get(req, res) + if err != nil { + return ByteView{}, err + } + return ByteView{b: res.Value}, nil +} diff --git a/gee-cache/day7-proto-buf/geecache/geecache_test.go b/gee-cache/day7-proto-buf/geecache/geecache_test.go new file mode 100644 index 0000000..7ef9f4f --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/geecache_test.go @@ -0,0 +1,67 @@ +package geecache + +import ( + "fmt" + "log" + "reflect" + "testing" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func TestGetter(t *testing.T) { + var f Getter = GetterFunc(func(key string) ([]byte, error) { + return []byte(key), nil + }) + + expect := []byte("key") + if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) { + t.Fatal("callback failed") + } +} + +func TestGet(t *testing.T) { + loadCounts := make(map[string]int, len(db)) + gee := NewGroup("scores", 2<<10, GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + if _, ok := loadCounts[key]; !ok { + loadCounts[key] = 0 + } + loadCounts[key]++ + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) + + for k, v := range db { + if view, err := gee.Get(k); err != nil || view.String() != v { + t.Fatal("failed to get value of Tom") + } + if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 { + t.Fatalf("cache %s miss", k) + } + } + + if view, err := gee.Get("unknown"); err == nil { + t.Fatalf("the value of unknow should be empty, but %s got", view) + } +} + +func TestGetGroup(t *testing.T) { + groupName := "scores" + NewGroup(groupName, 2<<10, GetterFunc( + func(key string) (bytes []byte, err error) { return })) + if group := GetGroup(groupName); group == nil || group.name != groupName { + t.Fatalf("group %s not exist", groupName) + } + + if group := GetGroup(groupName + "111"); group != nil { + t.Fatalf("expect nil, but %s got", group.name) + } +} diff --git a/gee-cache/day7-proto-buf/geecache/geecachepb/geecachepb.pb.go b/gee-cache/day7-proto-buf/geecache/geecachepb/geecachepb.pb.go new file mode 100644 index 0000000..d89521d --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/geecachepb/geecachepb.pb.go @@ -0,0 +1,128 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: geecachepb.proto + +package geecachepb + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type Request struct { + Group string `protobuf:"bytes,1,opt,name=group,proto3" json:"group,omitempty"` + Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Request) Reset() { *m = Request{} } +func (m *Request) String() string { return proto.CompactTextString(m) } +func (*Request) ProtoMessage() {} +func (*Request) Descriptor() ([]byte, []int) { + return fileDescriptor_889d0a4ad37a0d42, []int{0} +} + +func (m *Request) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Request.Unmarshal(m, b) +} +func (m *Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Request.Marshal(b, m, deterministic) +} +func (m *Request) XXX_Merge(src proto.Message) { + xxx_messageInfo_Request.Merge(m, src) +} +func (m *Request) XXX_Size() int { + return xxx_messageInfo_Request.Size(m) +} +func (m *Request) XXX_DiscardUnknown() { + xxx_messageInfo_Request.DiscardUnknown(m) +} + +var xxx_messageInfo_Request proto.InternalMessageInfo + +func (m *Request) GetGroup() string { + if m != nil { + return m.Group + } + return "" +} + +func (m *Request) GetKey() string { + if m != nil { + return m.Key + } + return "" +} + +type Response struct { + Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Response) Reset() { *m = Response{} } +func (m *Response) String() string { return proto.CompactTextString(m) } +func (*Response) ProtoMessage() {} +func (*Response) Descriptor() ([]byte, []int) { + return fileDescriptor_889d0a4ad37a0d42, []int{1} +} + +func (m *Response) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Response.Unmarshal(m, b) +} +func (m *Response) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Response.Marshal(b, m, deterministic) +} +func (m *Response) XXX_Merge(src proto.Message) { + xxx_messageInfo_Response.Merge(m, src) +} +func (m *Response) XXX_Size() int { + return xxx_messageInfo_Response.Size(m) +} +func (m *Response) XXX_DiscardUnknown() { + xxx_messageInfo_Response.DiscardUnknown(m) +} + +var xxx_messageInfo_Response proto.InternalMessageInfo + +func (m *Response) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +func init() { + proto.RegisterType((*Request)(nil), "geecachepb.Request") + proto.RegisterType((*Response)(nil), "geecachepb.Response") +} + +func init() { proto.RegisterFile("geecachepb.proto", fileDescriptor_889d0a4ad37a0d42) } + +var fileDescriptor_889d0a4ad37a0d42 = []byte{ + // 148 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x48, 0x4f, 0x4d, 0x4d, + 0x4e, 0x4c, 0xce, 0x48, 0x2d, 0x48, 0xd2, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x42, 0x88, + 0x28, 0x19, 0x72, 0xb1, 0x07, 0xa5, 0x16, 0x96, 0xa6, 0x16, 0x97, 0x08, 0x89, 0x70, 0xb1, 0xa6, + 0x17, 0xe5, 0x97, 0x16, 0x48, 0x30, 0x2a, 0x30, 0x6a, 0x70, 0x06, 0x41, 0x38, 0x42, 0x02, 0x5c, + 0xcc, 0xd9, 0xa9, 0x95, 0x12, 0x4c, 0x60, 0x31, 0x10, 0x53, 0x49, 0x81, 0x8b, 0x23, 0x28, 0xb5, + 0xb8, 0x20, 0x3f, 0xaf, 0x38, 0x15, 0xa4, 0xa7, 0x2c, 0x31, 0xa7, 0x34, 0x15, 0xac, 0x87, 0x27, + 0x08, 0xc2, 0x31, 0xb2, 0xe3, 0xe2, 0x72, 0x07, 0x69, 0x76, 0x06, 0x59, 0x22, 0x64, 0xc0, 0xc5, + 0xec, 0x9e, 0x5a, 0x22, 0x24, 0xac, 0x87, 0xe4, 0x10, 0xa8, 0x9d, 0x52, 0x22, 0xa8, 0x82, 0x10, + 0x53, 0x93, 0xd8, 0xc0, 0xee, 0x34, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x5c, 0xd5, 0xdd, 0x09, + 0xbb, 0x00, 0x00, 0x00, +} diff --git a/gee-cache/day7-proto-buf/geecache/geecachepb/geecachepb.proto b/gee-cache/day7-proto-buf/geecache/geecachepb/geecachepb.proto new file mode 100644 index 0000000..3f5b313 --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/geecachepb/geecachepb.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package geecachepb; + +message Request { + string group = 1; + string key = 2; +} + +message Response { + bytes value = 1; +} + +service GroupCache { + rpc Get(Request) returns (Response); +} diff --git a/gee-cache/day7-proto-buf/geecache/go.mod b/gee-cache/day7-proto-buf/geecache/go.mod new file mode 100644 index 0000000..2ad7119 --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/go.mod @@ -0,0 +1,5 @@ +module geecache + +go 1.13 + +require github.com/golang/protobuf v1.3.3 diff --git a/gee-cache/day7-proto-buf/geecache/go.sum b/gee-cache/day7-proto-buf/geecache/go.sum new file mode 100644 index 0000000..b1efb8b --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/go.sum @@ -0,0 +1,2 @@ +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= diff --git a/gee-cache/day7-proto-buf/geecache/http.go b/gee-cache/day7-proto-buf/geecache/http.go new file mode 100644 index 0000000..be8a44e --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/http.go @@ -0,0 +1,142 @@ +package geecache + +import ( + "fmt" + "geecache/consistenthash" + pb "geecache/geecachepb" + "io/ioutil" + "log" + "net/http" + "net/url" + "strings" + "sync" + + "github.com/golang/protobuf/proto" +) + +const ( + defaultBasePath = "/_geecache/" + defaultReplicas = 50 +) + +// HTTPPool implements PeerPicker for a pool of HTTP peers. +type HTTPPool struct { + // this peer's base URL, e.g. "https://example.net:8000" + self string + basePath string + mu sync.Mutex // guards peers and httpGetters + peers *consistenthash.Map + httpGetters map[string]*httpGetter // keyed by e.g. "http://10.0.0.2:8008" +} + +// NewHTTPPool initializes an HTTP pool of peers. +func NewHTTPPool(self string) *HTTPPool { + return &HTTPPool{ + self: self, + basePath: defaultBasePath, + } +} + +// Log info with server name +func (p *HTTPPool) Log(format string, v ...interface{}) { + log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...)) +} + +// ServeHTTP handle all http requests +func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, p.basePath) { + panic("HTTPPool serving unexpected path: " + r.URL.Path) + } + p.Log("%s %s", r.Method, r.URL.Path) + // /// required + parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2) + if len(parts) != 2 { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + groupName := parts[0] + key := parts[1] + + group := GetGroup(groupName) + if group == nil { + http.Error(w, "no such group: "+groupName, http.StatusNotFound) + return + } + + view, err := group.Get(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Write the value to the response body as a proto message. + body, err := proto.Marshal(&pb.Response{Value: view.ByteSlice()}) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(body) +} + +// Set updates the pool's list of peers. +func (p *HTTPPool) Set(peers ...string) { + p.mu.Lock() + defer p.mu.Unlock() + p.peers = consistenthash.New(defaultReplicas, nil) + p.peers.Add(peers...) + p.httpGetters = make(map[string]*httpGetter, len(peers)) + for _, peer := range peers { + p.httpGetters[peer] = &httpGetter{baseURL: peer + p.basePath} + } +} + +// PickPeer picks a peer according to key +func (p *HTTPPool) PickPeer(key string) (PeerGetter, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if peer := p.peers.Get(key); peer != "" && peer != p.self { + p.Log("Pick peer %s", peer) + return p.httpGetters[peer], true + } + return nil, false +} + +var _ PeerPicker = (*HTTPPool)(nil) + +type httpGetter struct { + baseURL string +} + +func (h *httpGetter) Get(in *pb.Request, out *pb.Response) error { + u := fmt.Sprintf( + "%v%v/%v", + h.baseURL, + url.QueryEscape(in.GetGroup()), + url.QueryEscape(in.GetKey()), + ) + res, err := http.Get(u) + if err != nil { + return err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return fmt.Errorf("server returned: %v", res.Status) + } + + bytes, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("reading response body: %v", err) + } + + if err = proto.Unmarshal(bytes, out); err != nil { + return fmt.Errorf("decoding response body: %v", err) + } + + return nil +} + +var _ PeerGetter = (*httpGetter)(nil) diff --git a/gee-cache/day7-proto-buf/geecache/lru/lru.go b/gee-cache/day7-proto-buf/geecache/lru/lru.go new file mode 100644 index 0000000..dc1a317 --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/lru/lru.go @@ -0,0 +1,79 @@ +package lru + +import "container/list" + +// Cache is a LRU cache. It is not safe for concurrent access. +type Cache struct { + maxBytes int64 + nbytes int64 + ll *list.List + cache map[string]*list.Element + // optional and executed when an entry is purged. + OnEvicted func(key string, value Value) +} + +type entry struct { + key string + value Value +} + +// Value use Len to count how many bytes it takes +type Value interface { + Len() int +} + +// New is the Constructor of Cache +func New(maxBytes int64, onEvicted func(string, Value)) *Cache { + return &Cache{ + maxBytes: maxBytes, + ll: list.New(), + cache: make(map[string]*list.Element), + OnEvicted: onEvicted, + } +} + +// Add adds a value to the cache. +func (c *Cache) Add(key string, value Value) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + c.nbytes += int64(value.Len()) - int64(kv.value.Len()) + kv.value = value + } else { + ele := c.ll.PushFront(&entry{key, value}) + c.cache[key] = ele + c.nbytes += int64(len(key)) + int64(value.Len()) + } + for c.maxBytes != 0 && c.maxBytes < c.nbytes { + c.RemoveOldest() + } +} + +// Get look ups a key's value +func (c *Cache) Get(key string) (value Value, ok bool) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + return kv.value, true + } + return +} + +// RemoveOldest removes the oldest item +func (c *Cache) RemoveOldest() { + ele := c.ll.Back() + if ele != nil { + c.ll.Remove(ele) + kv := ele.Value.(*entry) + delete(c.cache, kv.key) + c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len()) + if c.OnEvicted != nil { + c.OnEvicted(kv.key, kv.value) + } + } +} + +// Len the number of cache entries +func (c *Cache) Len() int { + return c.ll.Len() +} diff --git a/gee-cache/day7-proto-buf/geecache/lru/lru_test.go b/gee-cache/day7-proto-buf/geecache/lru/lru_test.go new file mode 100644 index 0000000..f2d3470 --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/lru/lru_test.go @@ -0,0 +1,65 @@ +package lru + +import ( + "reflect" + "testing" +) + +type String string + +func (d String) Len() int { + return len(d) +} + +func TestGet(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key1", String("1234")) + if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" { + t.Fatalf("cache hit key1=1234 failed") + } + if _, ok := lru.Get("key2"); ok { + t.Fatalf("cache miss key2 failed") + } +} + +func TestRemoveoldest(t *testing.T) { + k1, k2, k3 := "key1", "key2", "k3" + v1, v2, v3 := "value1", "value2", "v3" + cap := len(k1 + k2 + v1 + v2) + lru := New(int64(cap), nil) + lru.Add(k1, String(v1)) + lru.Add(k2, String(v2)) + lru.Add(k3, String(v3)) + + if _, ok := lru.Get("key1"); ok || lru.Len() != 2 { + t.Fatalf("Removeoldest key1 failed") + } +} + +func TestOnEvicted(t *testing.T) { + keys := make([]string, 0) + callback := func(key string, value Value) { + keys = append(keys, key) + } + lru := New(int64(10), callback) + lru.Add("key1", String("123456")) + lru.Add("k2", String("k2")) + lru.Add("k3", String("k3")) + lru.Add("k4", String("k4")) + + expect := []string{"key1", "k2"} + + if !reflect.DeepEqual(expect, keys) { + t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect) + } +} + +func TestAdd(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key", String("1")) + lru.Add("key", String("111")) + + if lru.nbytes != int64(len("key")+len("111")) { + t.Fatal("expected 6 but got", lru.nbytes) + } +} diff --git a/gee-cache/day7-proto-buf/geecache/peers.go b/gee-cache/day7-proto-buf/geecache/peers.go new file mode 100644 index 0000000..9324577 --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/peers.go @@ -0,0 +1,14 @@ +package geecache + +import pb "geecache/geecachepb" + +// PeerPicker is the interface that must be implemented to locate +// the peer that owns a specific key. +type PeerPicker interface { + PickPeer(key string) (peer PeerGetter, ok bool) +} + +// PeerGetter is the interface that must be implemented by a peer. +type PeerGetter interface { + Get(in *pb.Request, out *pb.Response) error +} diff --git a/gee-cache/day7-proto-buf/geecache/singleflight/singleflight.go b/gee-cache/day7-proto-buf/geecache/singleflight/singleflight.go new file mode 100644 index 0000000..85bd0dd --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/singleflight/singleflight.go @@ -0,0 +1,46 @@ +package singleflight + +import "sync" + +// call is an in-flight or completed Do call +type call struct { + wg sync.WaitGroup + val interface{} + err error +} + +// Group represents a class of work and forms a namespace in which +// units of work can be executed with duplicate suppression. +type Group struct { + mu sync.Mutex // protects m + m map[string]*call // lazily initialized +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + g.mu.Unlock() + c.wg.Wait() + return c.val, c.err + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + c.val, c.err = fn() + c.wg.Done() + + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() + + return c.val, c.err +} diff --git a/gee-cache/day7-proto-buf/geecache/singleflight/singleflight_test.go b/gee-cache/day7-proto-buf/geecache/singleflight/singleflight_test.go new file mode 100644 index 0000000..450951a --- /dev/null +++ b/gee-cache/day7-proto-buf/geecache/singleflight/singleflight_test.go @@ -0,0 +1,16 @@ +package singleflight + +import ( + "testing" +) + +func TestDo(t *testing.T) { + var g Group + v, err := g.Do("key", func() (interface{}, error) { + return "bar", nil + }) + + if v != "bar" || err != nil { + t.Errorf("Do v = %v, error = %v", v, err) + } +} diff --git a/gee-cache/day7-proto-buf/go.mod b/gee-cache/day7-proto-buf/go.mod new file mode 100644 index 0000000..d0fd3ba --- /dev/null +++ b/gee-cache/day7-proto-buf/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require geecache v0.0.0 + +replace geecache => ./geecache diff --git a/gee-cache/day7-proto-buf/go.sum b/gee-cache/day7-proto-buf/go.sum new file mode 100644 index 0000000..b1efb8b --- /dev/null +++ b/gee-cache/day7-proto-buf/go.sum @@ -0,0 +1,2 @@ +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= diff --git a/gee-cache/day7-proto-buf/main.go b/gee-cache/day7-proto-buf/main.go new file mode 100644 index 0000000..56abc7e --- /dev/null +++ b/gee-cache/day7-proto-buf/main.go @@ -0,0 +1,86 @@ +package main + +/* +$ curl "http://localhost:9999/api?key=Tom" +630 + +$ curl "http://localhost:9999/api?key=kkk" +kkk not exist +*/ + +import ( + "flag" + "fmt" + "geecache" + "log" + "net/http" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func createGroup() *geecache.Group { + return geecache.NewGroup("scores", 2<<10, geecache.GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) +} + +func startCacheServer(addr string, addrs []string, gee *geecache.Group) { + peers := geecache.NewHTTPPool(addr) + peers.Set(addrs...) + gee.RegisterPeers(peers) + log.Println("geecache is running at", addr) + log.Fatal(http.ListenAndServe(addr[7:], peers)) +} + +func startAPIServer(apiAddr string, gee *geecache.Group) { + http.Handle("/api", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + key := r.URL.Query().Get("key") + view, err := gee.Get(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(view.ByteSlice()) + + })) + log.Println("fontend server is running at", apiAddr) + log.Fatal(http.ListenAndServe(apiAddr[7:], nil)) + +} + +func main() { + var port int + var api bool + flag.IntVar(&port, "port", 8001, "Geecache server port") + flag.BoolVar(&api, "api", false, "Start a api server?") + flag.Parse() + + apiAddr := "http://localhost:9999" + addrMap := map[int]string{ + 8001: "http://localhost:8001", + 8002: "http://localhost:8002", + 8003: "http://localhost:8003", + } + + var addrs []string + for _, v := range addrMap { + addrs = append(addrs, v) + } + + gee := createGroup() + if api { + go startAPIServer(apiAddr, gee) + } + startCacheServer(addrMap[port], addrs, gee) +} diff --git a/gee-cache/day7-proto-buf/run.sh b/gee-cache/day7-proto-buf/run.sh new file mode 100755 index 0000000..066979d --- /dev/null +++ b/gee-cache/day7-proto-buf/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash +trap "rm server;kill 0" EXIT + +go build -o server +./server -port=8001 & +./server -port=8002 & +./server -port=8003 -api=1 & + +sleep 2 +echo ">>> start test" +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & + +wait \ No newline at end of file diff --git a/gee-cache/doc/geecache-day1.md b/gee-cache/doc/geecache-day1.md new file mode 100644 index 0000000..0bdcf75 --- /dev/null +++ b/gee-cache/doc/geecache-day1.md @@ -0,0 +1,253 @@ +--- +title: 动手写分布式缓存 - GeeCache第一天 LRU 缓存淘汰策略 +date: 2020-02-11 22:00:00 +description: 7天用 Go语言/golang 从零实现分布式缓存 GeeCache 教程(7 days implement golang distributed cache from scratch tutorial),动手写分布式缓存,参照 groupcache 的实现。本文介绍了常用的三种缓存淘汰(失效)算法:先进先出(FIFO),最少使用(LFU) 和 最近最少使用(LRU),并实现 LRU 算法和相应的测试代码。 +tags: +- Go +nav: 从零实现 +categories: +- 分布式缓存 - GeeCache +keywords: +- Go语言 +- 从零实现 +- 分布式缓存 +- LRU +- 缓存失效 +image: post/geecache-day1/lru_logo.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day1 LRU 缓存淘汰策略 +--- + + +本文是[7天用Go从零实现分布式缓存GeeCache教程系列](https://geektutu.com/post/geecache.html)的第一篇。 + +- 介绍常用的三种缓存淘汰(失效)算法:FIFO,LFU 和 LRU +- 实现 LRU 缓存淘汰算法,**代码约80行** + +## 1 FIFO/LFU/LRU 算法简介 + +GeeCache 的缓存全部存储在内存中,内存是有限的,因此不可能无限制地添加数据。假定我们设置缓存能够使用的内存大小为 N,那么在某一个时间点,添加了某一条缓存记录之后,占用内存超过了 N,这个时候就需要从缓存中移除一条或多条数据了。那移除谁呢?我们肯定希望尽可能移除“没用”的数据,那如何判定数据“有用”还是“没用”呢? + +### 1.1 FIFO(First In First Out) + +先进先出,也就是淘汰缓存中最老(最早添加)的记录。FIFO 认为,最早添加的记录,其不再被使用的可能性比刚添加的可能性大。这种算法的实现也非常简单,创建一个队列,新增记录添加到队尾,每次内存不够时,淘汰队首。但是很多场景下,部分记录虽然是最早添加但也最常被访问,而不得不因为呆的时间太长而被淘汰。这类数据会被频繁地添加进缓存,又被淘汰出去,导致缓存命中率降低。 + +### 1.2 LFU(Least Frequently Used) + +最少使用,也就是淘汰缓存中访问频率最低的记录。LFU 认为,如果数据过去被访问多次,那么将来被访问的频率也更高。LFU 的实现需要维护一个按照访问次数排序的队列,每次访问,访问次数加1,队列重新排序,淘汰时选择访问次数最少的即可。LFU 算法的命中率是比较高的,但缺点也非常明显,维护每个记录的访问次数,对内存的消耗是很高的;另外,如果数据的访问模式发生变化,LFU 需要较长的时间去适应,也就是说 LFU 算法受历史数据的影响比较大。例如某个数据历史上访问次数奇高,但在某个时间点之后几乎不再被访问,但因为历史访问次数过高,而迟迟不能被淘汰。 + +### 1.3 LRU(Least Recently Used) + +最近最少使用,相对于仅考虑时间因素的 FIFO 和仅考虑访问频率的 LFU,LRU 算法可以认为是相对平衡的一种淘汰算法。LRU 认为,如果数据最近被访问过,那么将来被访问的概率也会更高。LRU 算法的实现非常简单,维护一个队列,如果某条记录被访问了,则移动到队尾,那么队首则是最近最少访问的数据,淘汰该条记录即可。 + +## 2 LRU 算法实现 + +### 2.1 核心数据结构 + +![implement lru algorithm with golang](geecache-day1/lru.jpg) + +这张图很好地表示了 LRU 算法最核心的 2 个数据结构 + +- 绿色的是字典(map),存储键和值的映射关系。这样根据某个键(key)查找对应的值(value)的复杂是`O(1)`,在字典中插入一条记录的复杂度也是`O(1)`。 +- 红色的是双向链表(double linked list)实现的队列。将所有的值放到双向链表中,这样,当访问到某个值时,将其移动到队尾的复杂度是`O(1)`,在队尾新增一条记录以及删除一条记录的复杂度均为`O(1)`。 + +接下来我们创建一个包含字典和双向链表的结构体类型 Cache,方便实现后续的增删查改操作。 + +[day1-lru/geecache/lru/lru.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day1-lru/geecache/lru) + +```go +package lru + +import "container/list" + +// Cache is a LRU cache. It is not safe for concurrent access. +type Cache struct { + maxBytes int64 + nbytes int64 + ll *list.List + cache map[string]*list.Element + // optional and executed when an entry is purged. + OnEvicted func(key string, value Value) +} + +type entry struct { + key string + value Value +} + +// Value use Len to count how many bytes it takes +type Value interface { + Len() int +} +``` + +- 在这里我们直接使用 Go 语言标准库实现的双向链表`list.List`。 +- 字典的定义是 `map[string]*list.Element`,键是字符串,值是双向链表中对应节点的指针。 +- `maxBytes` 是允许使用的最大内存,`nbytes` 是当前已使用的内存,`OnEvicted` 是某条记录被移除时的回调函数,可以为 nil。 +- 键值对 `entry` 是双向链表节点的数据类型,在链表中仍保存每个值对应的 key 的好处在于,淘汰队首节点时,需要用 key 从字典中删除对应的映射。 +- 为了通用性,我们允许值是实现了 `Value` 接口的任意类型,该接口只包含了一个方法 `Len() int`,用于返回值所占用的内存大小。 + + +方便实例化 `Cache`,实现 `New()` 函数: + +```go +// New is the Constructor of Cache +func New(maxBytes int64, onEvicted func(string, Value)) *Cache { + return &Cache{ + maxBytes: maxBytes, + ll: list.New(), + cache: make(map[string]*list.Element), + OnEvicted: onEvicted, + } +} +``` + +### 2.2 查找功能 + +查找主要有 2 个步骤,第一步是从字典中找到对应的双向链表的节点,第二步,将该节点移动到队尾。 + +```go +// Get look ups a key's value +func (c *Cache) Get(key string) (value Value, ok bool) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + return kv.value, true + } + return +} +``` + +- 如果键对应的链表节点存在,则将对应节点移动到队尾,并返回查找到的值。 +- `c.ll.MoveToFront(ele)`,即将链表中的节点 `ele` 移动到队尾(双向链表作为队列,队首队尾是相对的,在这里约定 front 为队尾) + +### 2.3 删除 + +这里的删除,实际上是缓存淘汰。即移除最近最少访问的节点(队首) + +```go +// RemoveOldest removes the oldest item +func (c *Cache) RemoveOldest() { + ele := c.ll.Back() + if ele != nil { + c.ll.Remove(ele) + kv := ele.Value.(*entry) + delete(c.cache, kv.key) + c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len()) + if c.OnEvicted != nil { + c.OnEvicted(kv.key, kv.value) + } + } +} +``` + +- `c.ll.Back()` 取到队首节点,从链表中删除。 +- `delete(c.cache, kv.key)`,从字典中 `c.cache` 删除该节点的映射关系。 +- 更新当前所用的内存 `c.nbytes`。 +- 如果回调函数 `OnEvicted` 不为 nil,则调用回调函数。 + +### 2.4 新增/修改 + +```go +// Add adds a value to the cache. +func (c *Cache) Add(key string, value Value) { + if ele, ok := c.cache[key]; ok { + c.ll.MoveToFront(ele) + kv := ele.Value.(*entry) + c.nbytes += int64(value.Len()) - int64(kv.value.Len()) + kv.value = value + } else { + ele := c.ll.PushFront(&entry{key, value}) + c.cache[key] = ele + c.nbytes += int64(len(key)) + int64(value.Len()) + } + for c.maxBytes != 0 && c.maxBytes < c.nbytes { + c.RemoveOldest() + } +} +``` + +- 如果键存在,则更新对应节点的值,并将该节点移到队尾。 +- 不存在则是新增场景,首先队尾添加新节点 `&entry{key, value}`, 并字典中添加 key 和节点的映射关系。 +- 更新 `c.nbytes`,如果超过了设定的最大值 `c.maxBytes`,则移除最少访问的节点。 + +最后,为了方便测试,我们实现 `Len()` 用来获取添加了多少条数据。 + +```go +// Len the number of cache entries +func (c *Cache) Len() int { + return c.ll.Len() +} +``` + +## 3 测试 + +例如,我们可以尝试添加几条数据,测试 `Get` 方法 + +[day1-lru/geecache/lru/lru_test.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day1-lru/geecache/lru) + +```go +type String string + +func (d String) Len() int { + return len(d) +} + +func TestGet(t *testing.T) { + lru := New(int64(0), nil) + lru.Add("key1", String("1234")) + if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" { + t.Fatalf("cache hit key1=1234 failed") + } + if _, ok := lru.Get("key2"); ok { + t.Fatalf("cache miss key2 failed") + } +} +``` + +测试,当使用内存超过了设定值时,是否会触发“无用”节点的移除: + +```go +func TestRemoveoldest(t *testing.T) { + k1, k2, k3 := "key1", "key2", "k3" + v1, v2, v3 := "value1", "value2", "v3" + cap := len(k1 + k2 + v1 + v2) + lru := New(int64(cap), nil) + lru.Add(k1, String(v1)) + lru.Add(k2, String(v2)) + lru.Add(k3, String(v3)) + + if _, ok := lru.Get("key1"); ok || lru.Len() != 2 { + t.Fatalf("Removeoldest key1 failed") + } +} +``` + +测试回调函数能否被调用: + +```go +func TestOnEvicted(t *testing.T) { + keys := make([]string, 0) + callback := func(key string, value Value) { + keys = append(keys, key) + } + lru := New(int64(10), callback) + lru.Add("key1", String("123456")) + lru.Add("k2", String("k2")) + lru.Add("k3", String("k3")) + lru.Add("k4", String("k4")) + + expect := []string{"key1", "k2"} + + if !reflect.DeepEqual(expect, keys) { + t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect) + } +} +``` + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [list 官方文档 - golang.org](https://golang.org/pkg/container/list/) \ No newline at end of file diff --git a/gee-cache/doc/geecache-day1/lru.jpg b/gee-cache/doc/geecache-day1/lru.jpg new file mode 100644 index 0000000..db90cd8 Binary files /dev/null and b/gee-cache/doc/geecache-day1/lru.jpg differ diff --git a/gee-cache/doc/geecache-day1/lru_logo.jpg b/gee-cache/doc/geecache-day1/lru_logo.jpg new file mode 100644 index 0000000..306b014 Binary files /dev/null and b/gee-cache/doc/geecache-day1/lru_logo.jpg differ diff --git a/gee-cache/doc/geecache-day2.md b/gee-cache/doc/geecache-day2.md new file mode 100644 index 0000000..583408b --- /dev/null +++ b/gee-cache/doc/geecache-day2.md @@ -0,0 +1,435 @@ +--- +title: 动手写分布式缓存 - GeeCache第二天 单机并发缓存 +date: 2020-02-12 22:00:00 +description: 7天用 Go语言/golang 从零实现分布式缓存 GeeCache 教程(7 days implement golang distributed cache from scratch tutorial),动手写分布式缓存,参照 groupcache 的实现。本文介绍了 sync.Mutex 互斥锁的使用,并发控制 LRU 缓存。实现 GeeCache 核心数据结构 Group,缓存不存在时,调用回调函数(callback)获取源数据。 +tags: +- Go +nav: 从零实现 +categories: +- 分布式缓存 - GeeCache +keywords: +- Go语言 +- 从零实现 +- 分布式缓存 +- 互斥锁 +- sync.Mutex +image: post/geecache-day2/concurrent_cache_logo.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day2 单机并发缓存 +--- + +![geecache concurrent cache](geecache-day2/concurrent_cache.jpg) + +本文是[7天用Go从零实现分布式缓存GeeCache](https://geektutu.com/post/geecache.html)的第二篇。 + +- 介绍 sync.Mutex 互斥锁的使用,并实现 LRU 缓存的并发控制。 +- 实现 GeeCache 核心数据结构 Group,缓存不存在时,调用回调函数获取源数据,**代码约150行** + +## 1 sync.Mutex + +多个协程(goroutine)同时读写同一个变量,在并发度较高的情况下,会发生冲突。确保一次只有一个协程(goroutine)可以访问该变量以避免冲突,这称之为`互斥`,互斥锁可以解决这个问题。 + +> sync.Mutex 是一个互斥锁,可以由不同的协程加锁和解锁。 + +`sync.Mutex` 是 Go 语言标准库提供的一个互斥锁,当一个协程(goroutine)获得了这个锁的拥有权后,其它请求锁的协程(goroutine) 就会阻塞在 `Lock()` 方法的调用上,直到调用 `Unlock()` 锁被释放。 + +接下来举一个简单的例子,假设有10个并发的协程打印了同一个数字`100`,为了避免重复打印,实现了`printOnce(num int)` 函数,使用集合 set 记录已打印过的数字,如果数字已打印过,则不再打印。 + +```go +var set = make(map[int]bool, 0) + +func printOnce(num int) { + if _, exist := set[num]; !exist { + fmt.Println(num) + } + set[num] = true +} + +func main() { + for i := 0; i < 10; i++ { + go printOnce(100) + } + time.Sleep(time.Second) +} +``` + +我们运行 `go run .` 会发生什么情况呢? + +```bash +$ go run . +100 +100 +``` + +有时候打印 2 次,有时候打印 4 次,有时候还会触发 panic,因为对同一个数据结构`set`的访问冲突了。接下来用互斥锁的`Lock()`和`Unlock()` 方法将冲突的部分包裹起来: + +```go +var m sync.Mutex +var set = make(map[int]bool, 0) + +func printOnce(num int) { + m.Lock() + if _, exist := set[num]; !exist { + fmt.Println(num) + } + set[num] = true + m.Unlock() +} + +func main() { + for i := 0; i < 10; i++ { + go printOnce(100) + } + time.Sleep(time.Second) +} +``` + +```bash +$ go run . +100 +``` + +相同的数字只会被打印一次。当一个协程调用了 `Lock()` 方法时,其他协程被阻塞了,直到`Unlock()`调用将锁释放。因此被包裹部分的代码就能够避免冲突,实现互斥。 + +`Unlock()`释放锁还有另外一种写法: + +```go +func printOnce(num int) { + m.Lock() + defer m.Unlock() + if _, exist := set[num]; !exist { + fmt.Println(num) + } + set[num] = true +} +``` + +## 2 支持并发读写 + +上一篇文章 [GeeCache 第一天](https://geektutu.com/post/geecache-day1.html) 实现了 LRU 缓存淘汰策略。接下来我们使用 `sync.Mutex` 封装 LRU 的几个方法,使之支持并发的读写。在这之前,我们抽象了一个只读数据结构 `ByteView` 用来表示缓存值,是 GeeCache 主要的数据结构之一。 + +[day2-single-node/geecache/byteview.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day2-single-node/geecache) + +```go +package geecache + +// A ByteView holds an immutable view of bytes. +type ByteView struct { + b []byte +} + +// Len returns the view's length +func (v ByteView) Len() int { + return len(v.b) +} + +// ByteSlice returns a copy of the data as a byte slice. +func (v ByteView) ByteSlice() []byte { + return cloneBytes(v.b) +} + +// String returns the data as a string, making a copy if necessary. +func (v ByteView) String() string { + return string(v.b) +} + +func cloneBytes(b []byte) []byte { + c := make([]byte, len(b)) + copy(c, b) + return c +} +``` + +- ByteView 只有一个数据成员,`b []byte`,b 将会存储真实的缓存值。选择 byte 类型是为了能够支持任意的数据类型的存储,例如字符串、图片等。 +- 实现 `Len() int` 方法,我们在 lru.Cache 的实现中,要求被缓存对象必须实现 Value 接口,即 `Len() int` 方法,返回其所占的内存大小。 +- `b` 是只读的,使用 `ByteSlice()` 方法返回一个拷贝,防止缓存值被外部程序修改。 + +接下来就可以为 lru.Cache 添加并发特性了。 + +[day2-single-node/geecache/cache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day2-single-node/geecache) + +```go +package geecache + +import ( + "geecache/lru" + "sync" +) + +type cache struct { + mu sync.Mutex + lru *lru.Cache + cacheBytes int64 +} + +func (c *cache) add(key string, value ByteView) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + c.lru = lru.New(c.cacheBytes, nil) + } + c.lru.Add(key, value) +} + +func (c *cache) get(key string) (value ByteView, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + return + } + + if v, ok := c.lru.Get(key); ok { + return v.(ByteView), ok + } + + return +} +``` + +- `cache.go` 的实现非常简单,实例化 lru,封装 get 和 add 方法,并添加互斥锁 mu。 +- 在 `add` 方法中,判断了 `c.lru` 是否为 nil,如果等于 nil 再创建实例。这种方法称之为延迟初始化(Lazy Initialization),一个对象的延迟初始化意味着该对象的创建将会延迟至第一次使用该对象时。主要用于提高性能,并减少程序内存要求。 + +## 3 主体结构 Group + +Group 是 GeeCache 最核心的数据结构,负责与用户的交互,并且控制缓存值存储和获取的流程。 + +```bash + 是 +接收 key --> 检查是否被缓存 -----> 返回缓存值 ⑴ + | 否 是 + |-----> 是否应当从远程节点获取 -----> 与远程节点交互 --> 返回缓存值 ⑵ + | 否 + |-----> 调用`回调函数`,获取值并添加到缓存 --> 返回缓存值 ⑶ +``` + +我们将在 `geecache.go` 中实现主体结构 Group,那么 GeeCache 的代码结构的雏形已经形成了。 + +```bash +geecache/ + |--lru/ + |--lru.go // lru 缓存淘汰策略 + |--byteview.go // 缓存值的抽象与封装 + |--cache.go // 并发控制 + |--geecache.go // 负责与外部交互,控制缓存存储和获取的主流程 +``` + +接下来我们将实现流程 ⑴ 和 ⑶,远程交互的部分后续再实现。 + + +### 3.1 回调 Getter + +我们思考一下,如果缓存不存在,应从数据源(文件,数据库等)获取数据并添加到缓存中。GeeCache 是否应该支持多种数据源的配置呢?不应该,一是数据源的种类太多,没办法一一实现;二是扩展性不好。如何从源头获取数据,应该是用户决定的事情,我们就把这件事交给用户好了。因此,我们设计了一个回调函数(callback),在缓存不存在时,调用这个函数,得到源数据。 + +[day2-single-node/geecache/geecache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day2-single-node/geecache) + +```go +// A Getter loads data for a key. +type Getter interface { + Get(key string) ([]byte, error) +} + +// A GetterFunc implements Getter with a function. +type GetterFunc func(key string) ([]byte, error) + +// Get implements Getter interface function +func (f GetterFunc) Get(key string) ([]byte, error) { + return f(key) +} +``` + +- 定义接口 Getter 和 回调函数 `Get(key string)([]byte, error)`,参数是 key,返回值是 []byte。 +- 定义函数类型 GetterFunc,并实现 Getter 接口的 `Get` 方法。 +- 函数类型实现某一个接口,称之为接口型函数,方便使用者在调用时既能够传入函数作为参数,也能够传入实现了该接口的结构体作为参数。 + +> 了解接口型函数的使用场景,可以参考 [Go 接口型函数的使用场景 - 7days-golang Q & A](https://geektutu.com/post/7days-golang-q1.html) + +我们可以写一个测试用例来保证回调函数能够正常工作。 + +```go +func TestGetter(t *testing.T) { + var f Getter = GetterFunc(func(key string) ([]byte, error) { + return []byte(key), nil + }) + + expect := []byte("key") + if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) { + t.Errorf("callback failed") + } +} +``` + +- 在这个测试用例中,我们借助 GetterFunc 的类型转换,将一个匿名回调函数转换成了接口 `f Getter`。 +- 调用该接口的方法 `f.Get(key string)`,实际上就是在调用匿名回调函数。 + +> 定义一个函数类型 F,并且实现接口 A 的方法,然后在这个方法中调用自己。这是 Go 语言中将其他函数(参数返回值定义与 F 一致)转换为接口 A 的常用技巧。 + +### 3.2 Group 的定义 + +接下来是最核心数据结构 Group 的定义: + +[day2-single-node/geecache/geecache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day2-single-node/geecache) + +```go +// A Group is a cache namespace and associated data loaded spread over +type Group struct { + name string + getter Getter + mainCache cache +} + +var ( + mu sync.RWMutex + groups = make(map[string]*Group) +) + +// NewGroup create a new instance of Group +func NewGroup(name string, cacheBytes int64, getter Getter) *Group { + if getter == nil { + panic("nil Getter") + } + mu.Lock() + defer mu.Unlock() + g := &Group{ + name: name, + getter: getter, + mainCache: cache{cacheBytes: cacheBytes}, + } + groups[name] = g + return g +} + +// GetGroup returns the named group previously created with NewGroup, or +// nil if there's no such group. +func GetGroup(name string) *Group { + mu.RLock() + g := groups[name] + mu.RUnlock() + return g +} +``` + +- 一个 Group 可以认为是一个缓存的命名空间,每个 Group 拥有一个唯一的名称 `name`。比如可以创建三个 Group,缓存学生的成绩命名为 scores,缓存学生信息的命名为 info,缓存学生课程的命名为 courses。 +- 第二个属性是 `getter Getter`,即缓存未命中时获取源数据的回调(callback)。 +- 第三个属性是 `mainCache cache`,即一开始实现的并发缓存。 +- 构建函数 `NewGroup` 用来实例化 Group,并且将 group 存储在全局变量 `groups` 中。 +- `GetGroup` 用来特定名称的 Group,这里使用了只读锁 `RLock()`,因为不涉及任何冲突变量的写操作。 + +### 3.3 Group 的 Get 方法 + +接下来是 GeeCache 最为核心的方法 `Get`: + +```go +// Get value for a key from cache +func (g *Group) Get(key string) (ByteView, error) { + if key == "" { + return ByteView{}, fmt.Errorf("key is required") + } + + if v, ok := g.mainCache.get(key); ok { + log.Println("[GeeCache] hit") + return v, nil + } + + return g.load(key) +} + +func (g *Group) load(key string) (value ByteView, err error) { + return g.getLocally(key) +} + +func (g *Group) getLocally(key string) (ByteView, error) { + bytes, err := g.getter.Get(key) + if err != nil { + return ByteView{}, err + + } + value := ByteView{b: cloneBytes(bytes)} + g.populateCache(key, value) + return value, nil +} + +func (g *Group) populateCache(key string, value ByteView) { + g.mainCache.add(key, value) +} +``` + +- Get 方法实现了上述所说的流程 ⑴ 和 ⑶。 +- 流程 ⑴ :从 mainCache 中查找缓存,如果存在则返回缓存值。 +- 流程 ⑶ :缓存不存在,则调用 load 方法,load 调用 getLocally(分布式场景下会调用 getFromPeer 从其他节点获取),getLocally 调用用户回调函数 `g.getter.Get()` 获取源数据,并且将源数据添加到缓存 mainCache 中(通过 populateCache 方法) + +至此,这一章节的单机并发缓存就已经完成了。 + +## 4 测试 + +可以写测试用例,也可以写 main 函数来测试这一章节实现的功能。那我们通过测试用例来看一下,如何使用我们实现的单机并发缓存吧。 + +首先,用一个 map 模拟耗时的数据库。 + +```go +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} +``` + +创建 group 实例,并测试 `Get` 方法 + +```go +func TestGet(t *testing.T) { + loadCounts := make(map[string]int, len(db)) + gee := NewGroup("scores", 2<<10, GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + if _, ok := loadCounts[key]; !ok { + loadCounts[key] = 0 + } + loadCounts[key] += 1 + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) + + for k, v := range db { + if view, err := gee.Get(k); err != nil || view.String() != v { + t.Fatal("failed to get value of Tom") + } // load from callback function + if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 { + t.Fatalf("cache %s miss", k) + } // cache hit + } + + if view, err := gee.Get("unknown"); err == nil { + t.Fatalf("the value of unknow should be empty, but %s got", view) + } +} +``` + +- 在这个测试用例中,我们主要测试了 2 种情况 +- 1)在缓存为空的情况下,能够通过回调函数获取到源数据。 +- 2)在缓存已经存在的情况下,是否直接从缓存中获取,为了实现这一点,使用 `loadCounts` 统计某个键调用回调函数的次数,如果次数大于1,则表示调用了多次回调函数,没有缓存。 + +测试结果如下: + +```bash +$ go test -run TestGet +2020/02/11 22:07:31 [SlowDB] search key Sam +2020/02/11 22:07:31 [GeeCache] hit +2020/02/11 22:07:31 [SlowDB] search key Tom +2020/02/11 22:07:31 [GeeCache] hit +2020/02/11 22:07:31 [SlowDB] search key Jack +2020/02/11 22:07:31 [GeeCache] hit +2020/02/11 22:07:31 [SlowDB] search key unknown +PASS +ok geecache 0.008s +``` + +可以很清晰地看到,缓存为空时,调用了回调函数,第二次访问时,则直接从缓存中读取。 + +## 附 推荐阅读 + +- [Go 语言简明教程 - 并发编程](https://geektutu.com/post/quick-golang.html#7-并发编程-goroutine) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [sync 官方文档 - golang.org](https://golang.org/pkg/sync/) diff --git a/gee-cache/doc/geecache-day2/concurrent_cache.jpg b/gee-cache/doc/geecache-day2/concurrent_cache.jpg new file mode 100644 index 0000000..4c2f171 Binary files /dev/null and b/gee-cache/doc/geecache-day2/concurrent_cache.jpg differ diff --git a/gee-cache/doc/geecache-day2/concurrent_cache_logo.jpg b/gee-cache/doc/geecache-day2/concurrent_cache_logo.jpg new file mode 100644 index 0000000..1a6317c Binary files /dev/null and b/gee-cache/doc/geecache-day2/concurrent_cache_logo.jpg differ diff --git a/gee-cache/doc/geecache-day3.md b/gee-cache/doc/geecache-day3.md new file mode 100644 index 0000000..a251dcc --- /dev/null +++ b/gee-cache/doc/geecache-day3.md @@ -0,0 +1,256 @@ +--- +title: 动手写分布式缓存 - GeeCache第三天 HTTP 服务端 +date: 2020-02-12 23:00:00 +description: 7天用 Go语言/golang 从零实现分布式缓存 GeeCache 教程(7 days implement golang distributed cache from scratch tutorial),动手写分布式缓存,参照 groupcache 的实现。本文介绍了如何使用标准库 http 搭建 HTTP Server,为 GeeCache 单机节点搭建 HTTP 服务,并进行相关的测试。 +tags: +- Go +nav: 从零实现 +categories: +- 分布式缓存 - GeeCache +keywords: +- Go语言 +- 从零实现 +- 分布式缓存 +- HTTP Server +image: post/geecache-day3/http_logo.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day3 HTTP 服务端 +--- + +![geecache http server](geecache-day3/http.jpg) + +本文是[7天用Go从零实现分布式缓存GeeCache](https://geektutu.com/post/geecache.html)的第三篇。 + +- 介绍如何使用 Go 语言标准库 `http` 搭建 HTTP Server +- 并实现 main 函数启动 HTTP Server 测试 API,**代码约60行** + +## 1 http 标准库 + +Go 语言提供了 `http` 标准库,可以非常方便地搭建 HTTP 服务端和客户端。比如我们可以实现一个服务端,无论接收到什么请求,都返回字符串 "Hello World!" + +```go +package main + +import ( + "log" + "net/http" +) + +type server int + +func (h *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + log.Println(r.URL.Path) + w.Write([]byte("Hello World!")) +} + +func main() { + var s server + http.ListenAndServe("localhost:9999", &s) +} +``` + +- 创建任意类型 server,并实现 `ServeHTTP` 方法。 +- 调用 `http.ListenAndServe` 在 9999 端口启动 http 服务,处理请求的对象为 `s server`。 + +接下来我们执行 `go run .` 启动服务,借助 curl 来测试效果: + +```bash +$ curl http://localhost:9999 +Hello World! +$ curl http://localhost:9999/abc +Hello World! +``` + +Go 程序日志输出 + +```bash +2020/02/11 22:56:32 / +2020/02/11 22:56:34 /abc +``` + +> `http.ListenAndServe` 接收 2 个参数,第一个参数是服务启动的地址,第二个参数是 Handler,任何实现了 `ServeHTTP` 方法的对象都可以作为 HTTP 的 Handler。 + +在标准库中,http.Handler 接口的定义如下: + +```go +package http + +type Handler interface { + ServeHTTP(w ResponseWriter, r *Request) +} +``` + +## 2 GeeCache HTTP 服务端 + +分布式缓存需要实现节点间通信,建立基于 HTTP 的通信机制是比较常见和简单的做法。如果一个节点启动了 HTTP 服务,那么这个节点就可以被其他节点访问。今天我们就为单机节点搭建 HTTP Server。 + +不与其他部分耦合,我们将这部分代码放在新的 `http.go` 文件中,当前的代码结构如下: + +```bash +geecache/ + |--lru/ + |--lru.go // lru 缓存淘汰策略 + |--byteview.go // 缓存值的抽象与封装 + |--cache.go // 并发控制 + |--geecache.go // 负责与外部交互,控制缓存存储和获取的主流程 + |--http.go // 提供被其他节点访问的能力(基于http) +``` + +首先我们创建一个结构体 `HTTPPool`,作为承载节点间 HTTP 通信的核心数据结构(包括服务端和客户端,今天只实现服务端)。 + +[day3-http-server/geecache/http.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day3-http-server/geecache) + +```go +package geecache + +import ( + "fmt" + "log" + "net/http" + "strings" +) + +const defaultBasePath = "/_geecache/" + +// HTTPPool implements PeerPicker for a pool of HTTP peers. +type HTTPPool struct { + // this peer's base URL, e.g. "https://example.net:8000" + self string + basePath string +} + +// NewHTTPPool initializes an HTTP pool of peers. +func NewHTTPPool(self string) *HTTPPool { + return &HTTPPool{ + self: self, + basePath: defaultBasePath, + } +} +``` + +- `HTTPPool` 只有 2 个参数,一个是 self,用来记录自己的地址,包括主机名/IP 和端口。 +- 另一个是 basePath,作为节点间通讯地址的前缀,默认是 `/_geecache/`,那么 http://example.com/_geecache/ 开头的请求,就用于节点间的访问。因为一个主机上还可能承载其他的服务,加一段 Path 是一个好习惯。比如,大部分网站的 API 接口,一般以 `/api` 作为前缀。 + +接下来,实现最为核心的 `ServeHTTP` 方法。 + +```go +// Log info with server name +func (p *HTTPPool) Log(format string, v ...interface{}) { + log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...)) +} + +// ServeHTTP handle all http requests +func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, p.basePath) { + panic("HTTPPool serving unexpected path: " + r.URL.Path) + } + p.Log("%s %s", r.Method, r.URL.Path) + // /// required + parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2) + if len(parts) != 2 { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + groupName := parts[0] + key := parts[1] + + group := GetGroup(groupName) + if group == nil { + http.Error(w, "no such group: "+groupName, http.StatusNotFound) + return + } + + view, err := group.Get(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(view.ByteSlice()) +} +``` + +- ServeHTTP 的实现逻辑是比较简单的,首先判断访问路径的前缀是否是 `basePath`,不是返回错误。 +- 我们约定访问路径格式为 `///`,通过 groupname 得到 group 实例,再使用 `group.Get(key)` 获取缓存数据。 +- 最终使用 `w.Write()` 将缓存值作为 httpResponse 的 body 返回。 + +到这里,HTTP 服务端已经完整地实现了。接下来,我们将在单机上启动 HTTP 服务,使用 curl 进行测试。 + +## 3 测试 + +实现 main 函数,实例化 group,并启动 HTTP 服务。 + +[day3-http-server/main.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day3-http-server) + +```go +package main + +import ( + "fmt" + "geecache" + "log" + "net/http" +) + +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func main() { + geecache.NewGroup("scores", 2<<10, geecache.GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) + + addr := "localhost:9999" + peers := geecache.NewHTTPPool(addr) + log.Println("geecache is running at", addr) + log.Fatal(http.ListenAndServe(addr, peers)) +} +``` + +- 同样地,我们使用 map 模拟了数据源 db。 +- 创建一个名为 scores 的 Group,若缓存为空,回调函数会从 db 中获取数据并返回。 +- 使用 http.ListenAndServe 在 9999 端口启动了 HTTP 服务。 + +> 需要注意的点: +> main.go 和 geecache/ 在同级目录,但 go modules 不再支持 import <相对路径>,相对路径需要在 go.mod 中声明: +> require geecache v0.0.0 +> replace geecache => ./geecache + +接下来,运行 main 函数,使用 curl 做一些简单测试: + +```bash +$ curl http://localhost:9999/_geecache/scores/Tom +630 +$ curl http://localhost:9999/_geecache/scores/kkk +kkk not exist +``` + +GeeCache 的日志输出如下: + +```bash +2020/02/11 23:28:39 geecache is running at localhost:9999 +2020/02/11 23:29:08 [Server localhost:9999] GET /_geecache/scores/Tom +2020/02/11 23:29:08 [SlowDB] search key Tom +2020/02/11 23:29:16 [Server localhost:9999] GET /_geecache/scores/kkk +2020/02/11 23:29:16 [SlowDB] search key kkk +``` + +节点间的相互通信不仅需要 HTTP 服务端,还需要 HTTP 客户端,这就是我们下一步需要做的事情。 + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [Go http.Handler 基础](https://geektutu.com/post/gee-day1.html) +- [http 官方文档 - golang.org](https://golang.org/pkg/http) \ No newline at end of file diff --git a/gee-cache/doc/geecache-day3/http.jpg b/gee-cache/doc/geecache-day3/http.jpg new file mode 100755 index 0000000..20bde20 Binary files /dev/null and b/gee-cache/doc/geecache-day3/http.jpg differ diff --git a/gee-cache/doc/geecache-day3/http_logo.jpg b/gee-cache/doc/geecache-day3/http_logo.jpg new file mode 100755 index 0000000..f681ee6 Binary files /dev/null and b/gee-cache/doc/geecache-day3/http_logo.jpg differ diff --git a/gee-cache/doc/geecache-day4.md b/gee-cache/doc/geecache-day4.md new file mode 100644 index 0000000..02db009 --- /dev/null +++ b/gee-cache/doc/geecache-day4.md @@ -0,0 +1,234 @@ +--- +title: 动手写分布式缓存 - GeeCache第四天 一致性哈希(hash) +date: 2020-02-16 20:00:00 +description: 7天用 Go语言/golang 从零实现分布式缓存 GeeCache 教程(7 days implement golang distributed cache from scratch tutorial),动手写分布式缓存,参照 groupcache 的实现。本文介绍了一致性哈希(consistent hashing)的原理、实现以及相关测试用例,一致性哈希为什么能避免缓存雪崩,虚拟节点为什么能解决数据倾斜的问题。 +tags: +- Go +nav: 从零实现 +categories: +- 分布式缓存 - GeeCache +keywords: +- Go语言 +- 从零实现 +- 一致性hash +- consistent hash +image: post/geecache-day4/hash_logo.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day4 一致性哈希 +--- + +![一致性哈希 consistent hashing](geecache-day4/hash.jpg) + +本文是[7天用Go从零实现分布式缓存GeeCache](https://geektutu.com/post/geecache.html)的第四篇。 + +- 一致性哈希(consistent hashing)的原理以及为什么要使用一致性哈希。 +- 实现一致性哈希代码,添加相应的测试用例,**代码约60行** + +## 1 为什么使用一致性哈希 + +今天我们要实现的是一致性哈希算法,一致性哈希算法是 GeeCache 从单节点走向分布式节点的一个重要的环节。那你可能要问了, + +> 童鞋,一致性哈希算法是啥?为什么要使用一致性哈希算法?这和分布式有什么关系? + +### 1.1 我该访问谁? + +对于分布式缓存来说,当一个节点接收到请求,如果该节点并没有存储缓存值,那么它面临的难题是,从谁那获取数据?自己,还是节点1, 2, 3, 4... 。假设包括自己在内一共有 10 个节点,当一个节点接收到请求时,随机选择一个节点,由该节点从数据源获取数据。 + +假设第一次随机选取了节点 1 ,节点 1 从数据源获取到数据的同时缓存该数据;那第二次,只有 1/10 的可能性再次选择节点 1, 有 9/10 的概率选择了其他节点,如果选择了其他节点,就意味着需要再一次从数据源获取数据,一般来说,这个操作是很耗时的。这样做,一是缓存效率低,二是各个节点上存储着相同的数据,浪费了大量的存储空间。 + +那有什么办法,对于给定的 key,每一次都选择同一个节点呢?使用 hash 算法也能够做到这一点。那把 key 的每一个字符的 ASCII 码加起来,再除以 10 取余数可以吗?当然可以,这可以认为是自定义的 hash 算法。 + +![hash select peer](geecache-day4/hash_select.jpg) + +从上面的图可以看到,任意一个节点任意时刻请求查找键 `Tom` 对应的值,都会分配给节点 2,有效地解决了上述的问题。 + +### 1.2 节点数量变化了怎么办? + +简单求取 Hash 值解决了缓存性能的问题,但是没有考虑节点数量变化的场景。假设,移除了其中一台节点,只剩下 9 个,那么之前 `hash(key) % 10` 变成了 `hash(key) % 9`,也就意味着几乎缓存值对应的节点都发生了改变。即几乎所有的缓存值都失效了。节点在接收到对应的请求时,均需要重新去数据源获取数据,容易引起 `缓存雪崩`。 + +> 缓存雪崩:缓存在同一时刻全部失效,造成瞬时DB请求量大、压力骤增,引起雪崩。常因为缓存服务器宕机,或缓存设置了相同的过期时间引起。 + +那如何解决这个问题呢?一致性哈希算法可以。 + +## 2 算法原理 + +### 2.1 步骤 + +一致性哈希算法将 key 映射到 2^32 的空间中,将这个数字首尾相连,形成一个环。 + +- 计算节点/机器(通常使用节点的名称、编号和 IP 地址)的哈希值,放置在环上。 +- 计算 key 的哈希值,放置在环上,顺时针寻找到的第一个节点,就是应选取的节点/机器。 + +![一致性哈希添加节点 consistent hashing add peer](geecache-day4/add_peer.jpg) + +环上有 peer2,peer4,peer6 三个节点,`key11`,`key2`,`key27` 均映射到 peer2,`key23` 映射到 peer4。此时,如果新增节点/机器 peer8,假设它新增位置如图所示,那么只有 `key27` 从 peer2 调整到 peer8,其余的映射均没有发生改变。 + +也就是说,一致性哈希算法,在新增/删除节点时,只需要重新定位该节点附近的一小部分数据,而不需要重新定位所有的节点,这就解决了上述的问题。 + +### 2.2 数据倾斜问题 + +如果服务器的节点过少,容易引起 key 的倾斜。例如上面例子中的 peer2,peer4,peer6 分布在环的上半部分,下半部分是空的。那么映射到环下半部分的 key 都会被分配给 peer2,key 过度向 peer2 倾斜,缓存节点间负载不均。 + +为了解决这个问题,引入了虚拟节点的概念,一个真实节点对应多个虚拟节点。 + +假设 1 个真实节点对应 3 个虚拟节点,那么 peer1 对应的虚拟节点是 peer1-1、 peer1-2、 peer1-3(通常以添加编号的方式实现),其余节点也以相同的方式操作。 + +- 第一步,计算虚拟节点的 Hash 值,放置在环上。 +- 第二步,计算 key 的 Hash 值,在环上顺时针寻找到应选取的虚拟节点,例如是 peer2-1,那么就对应真实节点 peer2。 + +虚拟节点扩充了节点的数量,解决了节点较少的情况下数据容易倾斜的问题。而且代价非常小,只需要增加一个字典(map)维护真实节点与虚拟节点的映射关系即可。 + +## 3 Go语言实现 + +我们在 geecache 目录下新建 package `consistenthash`,用来实现一致性哈希算法。 + +[day4-consistent-hash/geecache/consistenthash/consistenthash.go](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day4-consistent-hash/geecache/consistenthash) + +```go +package consistenthash + +import ( + "hash/crc32" + "sort" + "strconv" +) + +// Hash maps bytes to uint32 +type Hash func(data []byte) uint32 + +// Map constains all hashed keys +type Map struct { + hash Hash + replicas int + keys []int // Sorted + hashMap map[int]string +} + +// New creates a Map instance +func New(replicas int, fn Hash) *Map { + m := &Map{ + replicas: replicas, + hash: fn, + hashMap: make(map[int]string), + } + if m.hash == nil { + m.hash = crc32.ChecksumIEEE + } + return m +} +``` + +- 定义了函数类型 `Hash`,采取依赖注入的方式,允许用于替换成自定义的 Hash 函数,也方便测试时替换,默认为 `crc32.ChecksumIEEE` 算法。 +- `Map` 是一致性哈希算法的主数据结构,包含 4 个成员变量:Hash 函数 `hash`;虚拟节点倍数 `replicas`;哈希环 `keys`;虚拟节点与真实节点的映射表 `hashMap`,键是虚拟节点的哈希值,值是真实节点的名称。 +- 构造函数 `New()` 允许自定义虚拟节点倍数和 Hash 函数。 + +接下来,实现添加真实节点/机器的 `Add()` 方法。 + +```go +// Add adds some keys to the hash. +func (m *Map) Add(keys ...string) { + for _, key := range keys { + for i := 0; i < m.replicas; i++ { + hash := int(m.hash([]byte(strconv.Itoa(i) + key))) + m.keys = append(m.keys, hash) + m.hashMap[hash] = key + } + } + sort.Ints(m.keys) +} +``` + +- `Add` 函数允许传入 0 或 多个真实节点的名称。 +- 对每一个真实节点 `key`,对应创建 `m.replicas` 个虚拟节点,虚拟节点的名称是:`strconv.Itoa(i) + key`,即通过添加编号的方式区分不同虚拟节点。 +- 使用 `m.hash()` 计算虚拟节点的哈希值,使用 `append(m.keys, hash)` 添加到环上。 +- 在 `hashMap` 中增加虚拟节点和真实节点的映射关系。 +- 最后一步,环上的哈希值排序。 + +最后一步,实现选择节点的 `Get()` 方法。 + +```go +// Get gets the closest item in the hash to the provided key. +func (m *Map) Get(key string) string { + if len(m.keys) == 0 { + return "" + } + + hash := int(m.hash([]byte(key))) + // Binary search for appropriate replica. + idx := sort.Search(len(m.keys), func(i int) bool { + return m.keys[i] >= hash + }) + + return m.hashMap[m.keys[idx%len(m.keys)]] +} +``` + +- 选择节点就非常简单了,第一步,计算 key 的哈希值。 +- 第二步,顺时针找到第一个匹配的虚拟节点的下标 `idx`,从 m.keys 中获取到对应的哈希值。如果 `idx == len(m.keys)`,说明应选择 `m.keys[0]`,因为 `m.keys` 是一个环状结构,所以用取余数的方式来处理这种情况。 +- 第三步,通过 `hashMap` 映射得到真实的节点。 + +至此,整个一致性哈希算法就实现完成了。 + +## 4 测试 + +最后呢,需要测试用例来验证我们的实现是否有问题。 + +[day4-consistent-hash/geecache/consistenthash/consistenthash_test.go](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day4-consistent-hash/geecache/consistenthash) + +```go +package consistenthash + +import ( + "strconv" + "testing" +) + +func TestHashing(t *testing.T) { + hash := New(3, func(key []byte) uint32 { + i, _ := strconv.Atoi(string(key)) + return uint32(i) + }) + + // Given the above hash function, this will give replicas with "hashes": + // 2, 4, 6, 12, 14, 16, 22, 24, 26 + hash.Add("6", "4", "2") + + testCases := map[string]string{ + "2": "2", + "11": "2", + "23": "4", + "27": "2", + } + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + + // Adds 8, 18, 28 + hash.Add("8") + + // 27 should now map to 8. + testCases["27"] = "8" + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + +} +``` + +如果要进行测试,那么我们需要明确地知道每一个传入的 key 的哈希值,那使用默认的 `crc32.ChecksumIEEE` 算法显然达不到目的。所以在这里使用了自定义的 Hash 算法。自定义的 Hash 算法只处理数字,传入字符串表示的数字,返回对应的数字即可。 + +- 一开始,有 2/4/6 三个真实节点,对应的虚拟节点的哈希值是 02/12/22、04/14/24、06/16/26。 +- 那么用例 2/11/23/27 选择的虚拟节点分别是 02/12/24/02,也就是真实节点 2/2/4/2。 +- 添加一个真实节点 8,对应虚拟节点的哈希值是 08/18/28,此时,用例 27 对应的虚拟节点从 `02` 变更为 `28`,即真实节点 8。 + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) \ No newline at end of file diff --git a/gee-cache/doc/geecache-day4/add_peer.jpg b/gee-cache/doc/geecache-day4/add_peer.jpg new file mode 100755 index 0000000..edf56f5 Binary files /dev/null and b/gee-cache/doc/geecache-day4/add_peer.jpg differ diff --git a/gee-cache/doc/geecache-day4/hash.jpg b/gee-cache/doc/geecache-day4/hash.jpg new file mode 100644 index 0000000..5134f55 Binary files /dev/null and b/gee-cache/doc/geecache-day4/hash.jpg differ diff --git a/gee-cache/doc/geecache-day4/hash_logo.jpg b/gee-cache/doc/geecache-day4/hash_logo.jpg new file mode 100644 index 0000000..c9cab99 Binary files /dev/null and b/gee-cache/doc/geecache-day4/hash_logo.jpg differ diff --git a/gee-cache/doc/geecache-day4/hash_select.jpg b/gee-cache/doc/geecache-day4/hash_select.jpg new file mode 100755 index 0000000..7914cfa Binary files /dev/null and b/gee-cache/doc/geecache-day4/hash_select.jpg differ diff --git a/gee-cache/doc/geecache-day5.md b/gee-cache/doc/geecache-day5.md new file mode 100644 index 0000000..9f995ba --- /dev/null +++ b/gee-cache/doc/geecache-day5.md @@ -0,0 +1,356 @@ +--- +title: 动手写分布式缓存 - GeeCache第五天 分布式节点 +date: 2020-02-16 21:30:00 +description: 7天用 Go语言/golang 从零实现分布式缓存 GeeCache 教程(7 days implement golang distributed cache from scratch tutorial),动手写分布式缓存,参照 groupcache 的实现。本文介绍了为 GeeCache 添加了注册节点与选择节点的功能,并实现了 HTTP 客户端,与远程节点的服务端通信。 +tags: +- Go +nav: 从零实现 +categories: +- 分布式缓存 - GeeCache +keywords: +- Go语言 +- 从零实现 +- HTTP客户端 +- 分布式节点 +image: post/geecache-day5/dist_nodes_logo.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day5 分布式节点 +--- + +![分布式缓存节点](geecache-day5/dist_nodes.jpg) + +本文是[7天用Go从零实现分布式缓存GeeCache](https://geektutu.com/post/geecache.html)的第五篇。 + +- 注册节点(Register Peers),借助一致性哈希算法选择节点。 +- 实现 HTTP 客户端,与远程节点的服务端通信,**代码约90行** + +## 1 流程回顾 + +```bash + 是 +接收 key --> 检查是否被缓存 -----> 返回缓存值 ⑴ + | 否 是 + |-----> 是否应当从远程节点获取 -----> 与远程节点交互 --> 返回缓存值 ⑵ + | 否 + |-----> 调用`回调函数`,获取值并添加到缓存 --> 返回缓存值 ⑶ +``` + +我们在[GeeCache 第二天](https://geektutu.com/post/geecache-day2.html) 中描述了 geecache 的流程。在这之前已经实现了流程 ⑴ 和 ⑶,今天实现流程 ⑵,从远程节点获取缓存值。 + +我们进一步细化流程 ⑵: + +```bash +使用一致性哈希选择节点 是 是 + |-----> 是否是远程节点 -----> HTTP 客户端访问远程节点 --> 成功?-----> 服务端返回返回值 + | 否 ↓ 否 + |----------------------------> 回退到本地节点处理。 +``` + +## 2 抽象 PeerPicker + +[day5-multi-nodes/geecache/peers.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day5-multi-nodes/geecache) + + +```go +package geecache + +// PeerPicker is the interface that must be implemented to locate +// the peer that owns a specific key. +type PeerPicker interface { + PickPeer(key string) (peer PeerGetter, ok bool) +} + +// PeerGetter is the interface that must be implemented by a peer. +type PeerGetter interface { + Get(group string, key string) ([]byte, error) +} +``` + +- 在这里,抽象出 2 个接口,PeerPicker 的 `PickPeer()` 方法用于根据传入的 key 选择相应节点 PeerGetter。 +- 接口 PeerGetter 的 `Get()` 方法用于从对应 group 查找缓存值。PeerGetter 就对应于上述流程中的 HTTP 客户端。 + +## 3 节点选择与 HTTP 客户端 + + +在 [GeeCache 第三天](https://geektutu.com/post/geecache-day3.html) 中我们为 `HTTPPool` 实现了服务端功能,通信不仅需要服务端还需要客户端,因此,我们接下来要为 `HTTPPool` 实现客户端的功能。 + +首先创建具体的 HTTP 客户端类 `httpGetter`,实现 PeerGetter 接口。 + +[day5-multi-nodes/geecache/http.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day5-multi-nodes/geecache) + +```go +type httpGetter struct { + baseURL string +} + +func (h *httpGetter) Get(group string, key string) ([]byte, error) { + u := fmt.Sprintf( + "%v%v/%v", + h.baseURL, + url.QueryEscape(group), + url.QueryEscape(key), + ) + res, err := http.Get(u) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("server returned: %v", res.Status) + } + + bytes, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("reading response body: %v", err) + } + + return bytes, nil +} + +var _ PeerGetter = (*httpGetter)(nil) +``` + +- baseURL 表示将要访问的远程节点的地址,例如 `http://example.com/_geecache/`。 +- 使用 `http.Get()` 方式获取返回值,并转换为 `[]bytes` 类型。 + +第二步,为 HTTPPool 添加节点选择的功能。 + +```go +const ( + defaultBasePath = "/_geecache/" + defaultReplicas = 50 +) +// HTTPPool implements PeerPicker for a pool of HTTP peers. +type HTTPPool struct { + // this peer's base URL, e.g. "https://example.net:8000" + self string + basePath string + mu sync.Mutex // guards peers and httpGetters + peers *consistenthash.Map + httpGetters map[string]*httpGetter // keyed by e.g. "http://10.0.0.2:8008" +} +``` + +- 新增成员变量 `peers`,类型是一致性哈希算法的 `Map`,用来根据具体的 key 选择节点。 +- 新增成员变量 `httpGetters`,映射远程节点与对应的 httpGetter。每一个远程节点对应一个 httpGetter,因为 httpGetter 与远程节点的地址 `baseURL` 有关。 + +第三步,实现 PeerPicker 接口。 + +```go +// Set updates the pool's list of peers. +func (p *HTTPPool) Set(peers ...string) { + p.mu.Lock() + defer p.mu.Unlock() + p.peers = consistenthash.New(defaultReplicas, nil) + p.peers.Add(peers...) + p.httpGetters = make(map[string]*httpGetter, len(peers)) + for _, peer := range peers { + p.httpGetters[peer] = &httpGetter{baseURL: peer + p.basePath} + } +} + +// PickPeer picks a peer according to key +func (p *HTTPPool) PickPeer(key string) (PeerGetter, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if peer := p.peers.Get(key); peer != "" && peer != p.self { + p.Log("Pick peer %s", peer) + return p.httpGetters[peer], true + } + return nil, false +} + +var _ PeerPicker = (*HTTPPool)(nil) +``` + +- `Set()` 方法实例化了一致性哈希算法,并且添加了传入的节点。 +- 并为每一个节点创建了一个 HTTP 客户端 `httpGetter`。 +- `PickerPeer()` 包装了一致性哈希算法的 `Get()` 方法,根据具体的 key,选择节点,返回节点对应的 HTTP 客户端。 + +至此,HTTPPool 既具备了提供 HTTP 服务的能力,也具备了根据具体的 key,创建 HTTP 客户端从远程节点获取缓存值的能力。 + +## 4 实现主流程 + +最后,我们需要将上述新增的功能集成在主流程(geecache.go)中。 + +[day5-multi-nodes/geecache/geecache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day5-multi-nodes/geecache) + +```go +// A Group is a cache namespace and associated data loaded spread over +type Group struct { + name string + getter Getter + mainCache cache + peers PeerPicker +} + +// RegisterPeers registers a PeerPicker for choosing remote peer +func (g *Group) RegisterPeers(peers PeerPicker) { + if g.peers != nil { + panic("RegisterPeerPicker called more than once") + } + g.peers = peers +} + +func (g *Group) load(key string) (value ByteView, err error) { + if g.peers != nil { + if peer, ok := g.peers.PickPeer(key); ok { + if value, err = g.getFromPeer(peer, key); err == nil { + return value, nil + } + log.Println("[GeeCache] Failed to get from peer", err) + } + } + + return g.getLocally(key) +} + +func (g *Group) getFromPeer(peer PeerGetter, key string) (ByteView, error) { + bytes, err := peer.Get(g.name, key) + if err != nil { + return ByteView{}, err + } + return ByteView{b: bytes}, nil +} +``` + +- 新增 `RegisterPeers()` 方法,将 实现了 PeerPicker 接口的 HTTPPool 注入到 Group 中。 +- 新增 `getFromPeer()` 方法,使用实现了 PeerGetter 接口的 httpGetter 从访问远程节点,获取缓存值。 +- 修改 load 方法,使用 `PickPeer()` 方法选择节点,若非本机节点,则调用 `getFromPeer()` 从远程获取。若是本机节点或失败,则回退到 `getLocally()`。 + +## 5 main 函数测试。 + +[day5-multi-nodes/main.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day5-multi-nodes) + +```go +var db = map[string]string{ + "Tom": "630", + "Jack": "589", + "Sam": "567", +} + +func createGroup() *geecache.Group { + return geecache.NewGroup("scores", 2<<10, geecache.GetterFunc( + func(key string) ([]byte, error) { + log.Println("[SlowDB] search key", key) + if v, ok := db[key]; ok { + return []byte(v), nil + } + return nil, fmt.Errorf("%s not exist", key) + })) +} + +func startCacheServer(addr string, addrs []string, gee *geecache.Group) { + peers := geecache.NewHTTPPool(addr) + peers.Set(addrs...) + gee.RegisterPeers(peers) + log.Println("geecache is running at", addr) + log.Fatal(http.ListenAndServe(addr[7:], peers)) +} + +func startAPIServer(apiAddr string, gee *geecache.Group) { + http.Handle("/api", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + key := r.URL.Query().Get("key") + view, err := gee.Get(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(view.ByteSlice()) + + })) + log.Println("fontend server is running at", apiAddr) + log.Fatal(http.ListenAndServe(apiAddr[7:], nil)) + +} + +func main() { + var port int + var api bool + flag.IntVar(&port, "port", 8001, "Geecache server port") + flag.BoolVar(&api, "api", false, "Start a api server?") + flag.Parse() + + apiAddr := "http://localhost:9999" + addrMap := map[int]string{ + 8001: "http://localhost:8001", + 8002: "http://localhost:8002", + 8003: "http://localhost:8003", + } + + var addrs []string + for _, v := range addrMap { + addrs = append(addrs, v) + } + + gee := createGroup() + if api { + go startAPIServer(apiAddr, gee) + } + startCacheServer(addrMap[port], []string(addrs), gee) +} +``` + +main 函数的代码比较多,但是逻辑是非常简单的。 + +- `startCacheServer()` 用来启动缓存服务器:创建 HTTPPool,添加节点信息,注册到 gee 中,启动 HTTP 服务(共3个端口,8001/8002/8003),用户不感知。 +- `startAPIServer()` 用来启动一个 API 服务(端口 9999),与用户进行交互,用户感知。 +- `main()` 函数需要命令行传入 `port` 和 `api` 2 个参数,用来在指定端口启动 HTTP 服务。 + +为了方便,我们将启动的命令封装为一个 `shell` 脚本: + +```bash +#!/bin/bash +trap "rm server;kill 0" EXIT + +go build -o server +./server -port=8001 & +./server -port=8002 & +./server -port=8003 -api=1 & + +sleep 2 +echo ">>> start test" +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & +curl "http://localhost:9999/api?key=Tom" & + +wait +``` + +- `trap` 命令用于在 shell 脚本退出时,删掉临时文件,结束子进程。 + +```bash +$ ./run.sh +2020/02/16 21:17:43 geecache is running at http://localhost:8001 +2020/02/16 21:17:43 geecache is running at http://localhost:8002 +2020/02/16 21:17:43 geecache is running at http://localhost:8003 +2020/02/16 21:17:43 fontend server is running at http://localhost:9999 +>>> start test +2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001 +2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001 +2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001 +... +630630630 +``` + +此时,我们可以打开一个新的 shell,进行测试: + +```bash +$ curl "http://localhost:9999/api?key=Tom" +630 +$ curl "http://localhost:9999/api?key=kkk" +kkk not exist +``` + +测试的时候,我们并发了 3 个请求 `?key=Tom`,从日志中可以看到,三次均选择了节点 `8001`,这是一致性哈希算法的功劳。但是有一个问题在于,同时向 `8001` 发起了 3 次请求。试想,假如有 10 万个在并发请求该数据呢?那就会向 `8001` 同时发起 10 万次请求,如果 `8001` 又同时向数据库发起 10 万次查询请求,很容易导致缓存被击穿。 + +三次请求的结果是一致的,对于相同的 key,能不能只向 `8001` 发起一次请求?这个问题下一次解决。 + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) \ No newline at end of file diff --git a/gee-cache/doc/geecache-day5/dist_nodes.jpg b/gee-cache/doc/geecache-day5/dist_nodes.jpg new file mode 100644 index 0000000..1f9865e Binary files /dev/null and b/gee-cache/doc/geecache-day5/dist_nodes.jpg differ diff --git a/gee-cache/doc/geecache-day5/dist_nodes_logo.jpg b/gee-cache/doc/geecache-day5/dist_nodes_logo.jpg new file mode 100644 index 0000000..eb129ea Binary files /dev/null and b/gee-cache/doc/geecache-day5/dist_nodes_logo.jpg differ diff --git a/gee-cache/doc/geecache-day6.md b/gee-cache/doc/geecache-day6.md new file mode 100644 index 0000000..73ad7dd --- /dev/null +++ b/gee-cache/doc/geecache-day6.md @@ -0,0 +1,205 @@ +--- +title: 动手写分布式缓存 - GeeCache第六天 防止缓存击穿 +date: 2020-02-16 23:00:00 +description: 7天用 Go语言/golang 从零实现分布式缓存 GeeCache 教程(7 days implement golang distributed cache from scratch tutorial),动手写分布式缓存,参照 groupcache 的实现。本文介绍了缓存雪崩、缓存击穿与缓存穿透的概念,使用 singleflight 防止缓存击穿,实现与测试。 +tags: +- Go +nav: 从零实现 +categories: +- 分布式缓存 - GeeCache +keywords: +- Go语言 +- 从零实现 +- HTTP客户端 +- 分布式节点 +image: post/geecache-day6/singleflight_logo.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day6 防止缓存击穿 +--- + +![geecache single flight](geecache-day6/singleflight.jpg) + +本文是[7天用Go从零实现分布式缓存GeeCache](https://geektutu.com/post/geecache.html)的第六篇。 + +- 缓存雪崩、缓存击穿与缓存穿透的概念简介。 +- 使用 singleflight 防止缓存击穿,实现与测试。**代码约70行** + +## 1 缓存雪崩、缓存击穿与缓存穿透 + +[GeeCache 第五天](https://geektutu.com/post/geecache-day5.html) 提到了缓存雪崩和缓存击穿,在这里做下总结: + +> **缓存雪崩**:缓存在同一时刻全部失效,造成瞬时DB请求量大、压力骤增,引起雪崩。缓存雪崩通常因为缓存服务器宕机、缓存的 key 设置了相同的过期时间等引起。 + +> **缓存击穿**:一个存在的key,在缓存过期的一刻,同时有大量的请求,这些请求都会击穿到 DB ,造成瞬时DB请求量大、压力骤增。 + +> **缓存穿透**:查询一个不存在的数据,因为不存在则不会写到缓存中,所以每次都会去请求 DB,如果瞬间流量过大,穿透到 DB,导致宕机。 + +## 2 singleflight 的实现 + +还记得 [GeeCache 第五天](https://geektutu.com/post/geecache-day5.html) 最后的测试结果吗? + +```bash +2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001 +2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001 +2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001 +``` + +我们并发了 N 个请求 `?key=Tom`,8003 节点向 8001 同时发起了 N 次请求。假设对数据库的访问没有做任何限制的,很可能向数据库也发起 N 次请求,容易导致缓存击穿和穿透。即使对数据库做了防护,HTTP 请求是非常耗费资源的操作,针对相同的 key,8003 节点向 8001 发起三次请求也是没有必要的。那这种情况下,我们如何做到只向远端节点发起一次请求呢? + +geecache 实现了一个名为 singleflight 的 package 来解决这个问题。 + +[day6-single-flight/geecache/singleflight/singleflight.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day6-single-flight/geecache/singleflight) + +首先创建 `call` 和 `Group` 类型。 + +```go +package singleflight + +import "sync" + +type call struct { + wg sync.WaitGroup + val interface{} + err error +} + +type Group struct { + mu sync.Mutex // protects m + m map[string]*call +} +``` + +- `call` 代表正在进行中,或已经结束的请求。使用 `sync.WaitGroup` 锁避免重入。 +- `Group` 是 singleflight 的主数据结构,管理不同 key 的请求(call)。 + +实现 `Do` 方法 + +```go +func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + g.mu.Unlock() + c.wg.Wait() + return c.val, c.err + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + c.val, c.err = fn() + c.wg.Done() + + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() + + return c.val, c.err +} +``` + +- Do 方法,接收 2 个参数,第一个参数是 `key`,第二个参数是一个函数 `fn`。Do 的作用就是,针对相同的 key,无论 Do 被调用多少次,函数 `fn` 都只会被调用一次,等待 fn 调用结束了,返回返回值或错误。 + +`g.mu` 是保护 Group 的成员变量 `m` 不被并发读写而加上的锁。为了便于理解 `Do` 函数,我们将 `g.mu` 暂时去掉。并且把 `g.m` 延迟初始化的部分去掉,延迟初始化的目的很简单,提高内存使用效率。 + +剩下的逻辑就很清晰了: + +```go +func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) { + if c, ok := g.m[key]; ok { + c.wg.Wait() // 如果请求正在进行中,则等待 + return c.val, c.err // 请求结束,返回结果 + } + c := new(call) + c.wg.Add(1) // 发起请求前加锁 + g.m[key] = c // 添加到 g.m,表明 key 已经有对应的请求在处理 + + c.val, c.err = fn() // 调用 fn,发起请求 + c.wg.Done() // 请求结束 + + delete(g.m, key) // 更新 g.m + + return c.val, c.err // 返回结果 +} +``` + +并发协程之间不需要消息传递,非常适合 `sync.WaitGroup`。 + +- wg.Add(1) 锁加1。 +- wg.Wait() 阻塞,直到锁被释放。 +- wg.Done() 锁减1。 + +## 3 singleflight 的使用 + +[day6-single-flight/geecache/geecache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day6-single-flight/geecache) + +```go +type Group struct { + name string + getter Getter + mainCache cache + peers PeerPicker + // use singleflight.Group to make sure that + // each key is only fetched once + loader *singleflight.Group +} + +func NewGroup(name string, cacheBytes int64, getter Getter) *Group { + // ... + g := &Group{ + // ... + loader: &singleflight.Group{}, + } + return g +} + +func (g *Group) load(key string) (value ByteView, err error) { + // each key is only fetched once (either locally or remotely) + // regardless of the number of concurrent callers. + viewi, err := g.loader.Do(key, func() (interface{}, error) { + if g.peers != nil { + if peer, ok := g.peers.PickPeer(key); ok { + if value, err = g.getFromPeer(peer, key); err == nil { + return value, nil + } + log.Println("[GeeCache] Failed to get from peer", err) + } + } + + return g.getLocally(key) + }) + + if err == nil { + return viewi.(ByteView), nil + } + return +} +``` + +- 修改 `geecache.go` 中的 `Group`,添加成员变量 loader,并更新构建函数 `NewGroup`。 +- 修改 `load` 函数,将原来的 load 的逻辑,使用 `g.loader.Do` 包裹起来即可,这样确保了并发场景下针对相同的 key,`load` 过程只会调用一次。 + +## 4 测试 + +执行 `run.sh` 就可以看到效果了。 + +```bash +$ ./run.sh +2020/02/16 22:36:00 [Server http://localhost:8003] Pick peer http://localhost:8001 +2020/02/16 22:36:00 [Server http://localhost:8001] GET /_geecache/scores/Tom +2020/02/16 22:36:00 [SlowDB] search key Tom +630630630 +``` + +可以看到,向 API 发起了三次并发请求,但8003 只向 8001 发起了一次请求,就搞定了。 + +如果并发度不够高,可能仍会看到向 8001 请求三次的场景。这种情况下三次请求是串行执行的,并没有触发 `singleflight` 的锁机制工作,可以加大并发数量再测试。即,将 `run.sh` 中的 `curl` 命令复制 N 次。 + +## 附 推荐 + +- [Go 语言简明教程#并发编程](https://geektutu.com/post/quick-golang.html#7-%E5%B9%B6%E5%8F%91%E7%BC%96%E7%A8%8B-goroutine) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) \ No newline at end of file diff --git a/gee-cache/doc/geecache-day6/singleflight.jpg b/gee-cache/doc/geecache-day6/singleflight.jpg new file mode 100644 index 0000000..5a0d1a5 Binary files /dev/null and b/gee-cache/doc/geecache-day6/singleflight.jpg differ diff --git a/gee-cache/doc/geecache-day6/singleflight_logo.jpg b/gee-cache/doc/geecache-day6/singleflight_logo.jpg new file mode 100644 index 0000000..4da1fac Binary files /dev/null and b/gee-cache/doc/geecache-day6/singleflight_logo.jpg differ diff --git a/gee-cache/doc/geecache-day7.md b/gee-cache/doc/geecache-day7.md new file mode 100644 index 0000000..2f26210 --- /dev/null +++ b/gee-cache/doc/geecache-day7.md @@ -0,0 +1,175 @@ +--- +title: 动手写分布式缓存 - GeeCache第七天 使用 Protobuf 通信 +date: 2020-02-17 00:30:00 +description: 7天用 Go语言/golang 从零实现分布式缓存 GeeCache 教程(7 days implement golang distributed cache from scratch tutorial),动手写分布式缓存,参照 groupcache 的实现。本文介绍了使用 protobuf(protocol buffer) 进行节点间通信,编码报文,提高效率 +tags: +- Go +nav: 从零实现 +categories: +- 分布式缓存 - GeeCache +keywords: +- Go语言 +- 从零实现 +- HTTP客户端 +- 分布式节点 +image: post/geecache-day7/protobuf_logo.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day7 使用 Protobuf 通信 +--- + +![geecache protobuf](geecache-day7/protobuf.jpg) + +本文是[7天用Go从零实现分布式缓存GeeCache](https://geektutu.com/post/geecache.html)的第七篇。 + +- 为什么要使用 protobuf? +- 使用 protobuf 进行节点间通信,编码报文,提高效率。**代码约50行** + +## 1 为什么要使用 protobuf + +> protobuf 即 Protocol Buffers,Google 开发的一种数据描述语言,是一种轻便高效的结构化数据存储格式,与语言、平台无关,可扩展可序列化。protobuf 以二进制方式存储,占用空间小。 + +protobuf 的安装和使用教程请移步 [Go Protobuf 简明教程](https://geektutu.com/post/quick-go-protobuf.html),这篇文章就不再赘述了。protobuf 广泛地应用于远程过程调用(RPC) 的二进制传输,使用 protobuf 的目的非常简单,为了获得更高的性能。传输前使用 protobuf 编码,接收方再进行解码,可以显著地降低二进制传输的大小。另外一方面,protobuf 可非常适合传输结构化数据,便于通信字段的扩展。 + +使用 protobuf 一般分为以下 2 步: + +- 按照 protobuf 的语法,在 `.proto` 文件中定义数据结构,并使用 `protoc` 生成 Go 代码(`.proto` 文件是跨平台的,还可以生成 C、Java 等其他源码文件)。 +- 在项目代码中引用生成的 Go 代码。 + +## 2 使用 protobuf 通信 + +新建 package `geecachepb`,定义 `geecachepb.proto` + +[day7-proto-buf/geecache/geecachepb/geecachepb.proto - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day7-proto-buf/geecache/geecachepb) + +```go +syntax = "proto3"; + +package geecachepb; + +message Request { + string group = 1; + string key = 2; +} + +message Response { + bytes value = 1; +} + +service GroupCache { + rpc Get(Request) returns (Response); +} +``` + +- `Request` 包含 2 个字段, group 和 cache,这与我们之前定义的接口 `/_geecache//` 所需的参数吻合。 +- `Response` 包含 1 个字段,bytes,类型为 byte 数组,与之前吻合。 + +生成 `geecache.pb.go` + +```bash +$ protoc --go_out=. *.proto +$ ls +geecachepb.pb.go geecachepb.proto +``` + +可以看到 `geecachepb.pb.go` 中有如下数据类型: + +```go +type Request struct { + Group string `protobuf:"bytes,1,opt,name=group,proto3" json:"group,omitempty"` + Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"` + ... +} +type Response struct { + Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` +} +``` + +接下来,修改 `peers.go` 中的 `PeerGetter` 接口,参数使用 `geecachepb.pb.go` 中的数据类型。 + +[day7-proto-buf/geecache/peers.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day7-proto-buf/geecache) + +```go +import pb "geecache/geecachepb" + +type PeerGetter interface { + Get(in *pb.Request, out *pb.Response) error +} +``` + +最后,修改 `geecache.go` 和 `http.go` 中使用了 `PeerGetter` 接口的地方。 + +[day7-proto-buf/geecache/geecache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day7-proto-buf/geecache) + +```go +import ( + // ... + pb "geecache/geecachepb" +) + +func (g *Group) getFromPeer(peer PeerGetter, key string) (ByteView, error) { + req := &pb.Request{ + Group: g.name, + Key: key, + } + res := &pb.Response{} + err := peer.Get(req, res) + if err != nil { + return ByteView{}, err + } + return ByteView{b: res.Value}, nil +} +``` + +[day7-proto-buf/geecache/http.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day7-proto-buf/geecache) + +```go +import ( + // ... + pb "geecache/geecachepb" + "github.com/golang/protobuf/proto" +) + +func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // ... + // Write the value to the response body as a proto message. + body, err := proto.Marshal(&pb.Response{Value: view.ByteSlice()}) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(body) +} + +func (h *httpGetter) Get(in *pb.Request, out *pb.Response) error { + u := fmt.Sprintf( + "%v%v/%v", + h.baseURL, + url.QueryEscape(in.GetGroup()), + url.QueryEscape(in.GetKey()), + ) + res, err := http.Get(u) + // ... + if err = proto.Unmarshal(bytes, out); err != nil { + return fmt.Errorf("decoding response body: %v", err) + } + + return nil +} +``` + +- `ServeHTTP()` 中使用 `proto.Marshal()` 编码 HTTP 响应。 +- `Get()` 中使用 `proto.Unmarshal()` 解码 HTTP 响应。 + +至此,我们已经将 HTTP 通信的中间载体替换成了 protobuf。运行 `run.sh` 即可以测试 GeeCache 能否正常工作。 + +## 总结 + +到这一篇为止,7 天用 Go 动手写/从零实现分布式缓存 GeeCache 这个系列就完成了。简单回顾下。第一天,为了解决资源限制的问题,实现了 LRU 缓存淘汰算法;第二天实现了单机并发,并给用户提供了自定义数据源的回调函数;第三天实现了 HTTP 服务端;第四天实现了一致性哈希算法,解决远程节点的挑选问题;第五天创建 HTTP 客户端,实现了多节点间的通信;第六天实现了 singleflight 解决缓存击穿的问题;第七天,使用 protobuf 库,优化了节点间通信的性能。如果看到这里,还没有动手写的话呢,赶紧动手写起来吧。一天差不多只需要实现 100 行代码呢。 + +## 附 推荐 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [Go Protobuf 简明教程](https://geektutu.com/post/quick-go-protobuf.html) \ No newline at end of file diff --git a/gee-cache/doc/geecache-day7/protobuf.jpg b/gee-cache/doc/geecache-day7/protobuf.jpg new file mode 100644 index 0000000..85bc336 Binary files /dev/null and b/gee-cache/doc/geecache-day7/protobuf.jpg differ diff --git a/gee-cache/doc/geecache-day7/protobuf_logo.jpg b/gee-cache/doc/geecache-day7/protobuf_logo.jpg new file mode 100644 index 0000000..e97f835 Binary files /dev/null and b/gee-cache/doc/geecache-day7/protobuf_logo.jpg differ diff --git a/gee-cache/doc/geecache.md b/gee-cache/doc/geecache.md new file mode 100644 index 0000000..dccc53d --- /dev/null +++ b/gee-cache/doc/geecache.md @@ -0,0 +1,75 @@ +--- +title: 7天用Go从零实现分布式缓存GeeCache +date: 2020-02-08 01:00:00 +description: 7天用 Go语言/golang 从零实现分布式缓存 GeeCache 教程(7 days implement golang distributed cache from scratch tutorial),动手写分布式缓存,参照 groupcache 的实现。功能包括单机/分布式缓存,LRU (Least Recently Used) 缓存策略,防止缓存击穿、一致性哈希(Consistent Hash),protobuf 通信等。 +tags: +- Go +nav: 从零实现 +categories: +- 分布式缓存 - GeeCache +keywords: +- Go语言 +- 从零实现分布式缓存 +- 动手写分布式缓存 +image: post/geecache/geecache_sm.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day0 序言 +--- + +![分布式缓存geecache](geecache/geecache.jpg) + +## 1 谈谈分布式缓存 + +第一次请求时将一些耗时操作的结果暂存,以后遇到相同的请求,直接返回暂存的数据。我想这是大部分童鞋对于缓存的理解。在计算机系统中,缓存无处不在,比如我们访问一个网页,网页和引用的 JS/CSS 等静态文件,根据不同的策略,会缓存在浏览器本地或是 CDN 服务器,那在第二次访问的时候,就会觉得网页加载的速度快了不少;比如微博的点赞的数量,不可能每个人每次访问,都从数据库中查找所有点赞的记录再统计,数据库的操作是很耗时的,很难支持那么大的流量,所以一般点赞这类数据是缓存在 Redis 服务集群中的。 + +> 商业世界里,现金为王;架构世界里,缓存为王。 + +缓存中最简单的莫过于存储在内存中的键值对缓存了。说到键值对,很容易想到的是字典(dict)类型,Go 语言中称之为 map。那直接创建一个 map,每次有新数据就往 map 中插入不就好了,这不就是键值对缓存么?这样做有什么问题呢? + +1)内存不够了怎么办? + +那就随机删掉几条数据好了。随机删掉好呢?还是按照时间顺序好呢?或者是有没有其他更好的淘汰策略呢?不同数据的访问频率是不一样的,优先删除访问频率低的数据是不是更好呢?数据的访问频率可能随着时间变化,那优先删除最近最少访问的数据可能是一个更好的选择。我们需要实现一个合理的淘汰策略。 + +2)并发写入冲突了怎么办? + +对缓存的访问,一般不可能是串行的。map 是没有并发保护的,应对并发的场景,修改操作(包括新增,更新和删除)需要加锁。 + +3)单机性能不够怎么办? + +单台计算机的资源是有限的,计算、存储等都是有限的。随着业务量和访问量的增加,单台机器很容易遇到瓶颈。如果利用多台计算机的资源,并行处理提高性能就要缓存应用能够支持分布式,这称为水平扩展(scale horizontally)。与水平扩展相对应的是垂直扩展(scale vertically),即通过增加单个节点的计算、存储、带宽等,来提高系统的性能,硬件的成本和性能并非呈线性关系,大部分情况下,分布式系统是一个更优的选择。 + +4)... + +## 2 关于 GeeCache + +设计一个分布式缓存系统,需要考虑资源控制、淘汰策略、并发、分布式节点通信等各个方面的问题。而且,针对不同的应用场景,还需要在不同的特性之间权衡,例如,是否需要支持缓存更新?还是假定缓存在淘汰之前是不允许改变的。不同的权衡对应着不同的实现。 + +[groupcache](https://github.com/golang/groupcache) 是 Go 语言版的 memcached,目的是在某些特定场合替代 memcached。groupcache 的作者也是 memcached 的作者。无论是了解单机缓存还是分布式缓存,深入学习这个库的实现都是非常有意义的。 + +`GeeCache` 基本上模仿了 [groupcache](https://github.com/golang/groupcache) 的实现,为了将代码量限制在 500 行左右(groupcache 约 3000 行),裁剪了部分功能。但总体实现上,还是与 groupcache 非常接近的。支持特性有: + +- 单机缓存和基于 HTTP 的分布式缓存 +- 最近最少访问(Least Recently Used, LRU) 缓存策略 +- 使用 Go 锁机制防止缓存击穿 +- 使用一致性哈希选择节点,实现负载均衡 +- 使用 protobuf 优化节点间二进制通信 +- ... + +`GeeCache` 分7天实现,每天完成的部分都是可以独立运行和测试的,就像搭积木一样,每天实现的特性组合在一起就是最终的分布式缓存系统。每天的代码在 100 行左右。 + +## 3 目录 + +- 第一天:[LRU 缓存淘汰策略](https://geektutu.com/post/geecache-day1.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day1-lru) +- 第二天:[单机并发缓存](https://geektutu.com/post/geecache-day2.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day2-single-node) +- 第三天:[HTTP 服务端](https://geektutu.com/post/geecache-day3.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day3-http-server) +- 第四天:[一致性哈希(Hash)](https://geektutu.com/post/geecache-day4.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day4-consistent-hash) +- 第五天:[分布式节点](https://geektutu.com/post/geecache-day5.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day5-multi-nodes) +- 第六天:[防止缓存击穿](https://geektutu.com/post/geecache-day6.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day6-single-flight) +- 第七天:[使用 Protobuf 通信](https://geektutu.com/post/geecache-day7.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day7-proto-buf) + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [Go Protobuf 简明教程](https://geektutu.com/post/quick-go-protobuf.html) \ No newline at end of file diff --git a/gee-cache/doc/geecache/geecache.jpg b/gee-cache/doc/geecache/geecache.jpg new file mode 100644 index 0000000..d19e7fe Binary files /dev/null and b/gee-cache/doc/geecache/geecache.jpg differ diff --git a/gee-cache/doc/geecache/geecache_sm.jpg b/gee-cache/doc/geecache/geecache_sm.jpg new file mode 100644 index 0000000..a3832b0 Binary files /dev/null and b/gee-cache/doc/geecache/geecache_sm.jpg differ diff --git a/gee-orm/day1-database-sql/cmd_test/main.go b/gee-orm/day1-database-sql/cmd_test/main.go new file mode 100755 index 0000000..1048cf4 --- /dev/null +++ b/gee-orm/day1-database-sql/cmd_test/main.go @@ -0,0 +1,20 @@ +package main + +import ( + "fmt" + "geeorm" + _ "github.com/mattn/go-sqlite3" +) + +func main() { + engine, _ := geeorm.NewEngine("sqlite3", "gee.db") + defer engine.Close() + s := engine.NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + count, _ := result.RowsAffected() + fmt.Printf("Exec success, %d affected\n", count) + +} diff --git a/gee-orm/day1-database-sql/geeorm.go b/gee-orm/day1-database-sql/geeorm.go new file mode 100644 index 0000000..3611b94 --- /dev/null +++ b/gee-orm/day1-database-sql/geeorm.go @@ -0,0 +1,44 @@ +package geeorm + +import ( + "database/sql" + + "geeorm/log" + "geeorm/session" +) + +// Engine is the main struct of geeorm, manages all db sessions and transactions. +type Engine struct { + db *sql.DB +} + +// NewEngine create a instance of Engine +// connect database and ping it to test whether it's alive +func NewEngine(driver, source string) (e *Engine, err error) { + db, err := sql.Open(driver, source) + if err != nil { + log.Error(err) + return + } + // Send a ping to make sure the database connection is alive. + if err = db.Ping(); err != nil { + log.Error(err) + return + } + e = &Engine{db: db} + log.Info("Connect database success") + return +} + +// Close database connection +func (engine *Engine) Close() { + if err := engine.db.Close(); err != nil { + log.Error("Failed to close database") + } + log.Info("Close database success") +} + +// NewSession creates a new session for next operations +func (engine *Engine) NewSession() *session.Session { + return session.New(engine.db) +} diff --git a/gee-orm/day1-database-sql/geeorm_test.go b/gee-orm/day1-database-sql/geeorm_test.go new file mode 100644 index 0000000..c6da191 --- /dev/null +++ b/gee-orm/day1-database-sql/geeorm_test.go @@ -0,0 +1,20 @@ +package geeorm + +import ( + _ "github.com/mattn/go-sqlite3" + "testing" +) + +func OpenDB(t *testing.T) *Engine { + t.Helper() + engine, err := NewEngine("sqlite3", "gee.db") + if err != nil { + t.Fatal("failed to connect", err) + } + return engine +} + +func TestNewEngine(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() +} diff --git a/gee-orm/day1-database-sql/go.mod b/gee-orm/day1-database-sql/go.mod new file mode 100644 index 0000000..043b1c6 --- /dev/null +++ b/gee-orm/day1-database-sql/go.mod @@ -0,0 +1,5 @@ +module geeorm + +go 1.13 + +require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/gee-orm/day1-database-sql/log/log.go b/gee-orm/day1-database-sql/log/log.go new file mode 100644 index 0000000..eacc0c6 --- /dev/null +++ b/gee-orm/day1-database-sql/log/log.go @@ -0,0 +1,47 @@ +package log + +import ( + "io/ioutil" + "log" + "os" + "sync" +) + +var ( + errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile) + infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile) + loggers = []*log.Logger{errorLog, infoLog} + mu sync.Mutex +) + +// log methods +var ( + Error = errorLog.Println + Errorf = errorLog.Printf + Info = infoLog.Println + Infof = infoLog.Printf +) + +// log levels +const ( + InfoLevel = iota + ErrorLevel + Disabled +) + +// SetLevel controls log level +func SetLevel(level int) { + mu.Lock() + defer mu.Unlock() + + for _, logger := range loggers { + logger.SetOutput(os.Stdout) + } + + if ErrorLevel < level { + errorLog.SetOutput(ioutil.Discard) + } + if InfoLevel < level { + infoLog.SetOutput(ioutil.Discard) + } +} diff --git a/gee-orm/day1-database-sql/log/log_test.go b/gee-orm/day1-database-sql/log/log_test.go new file mode 100644 index 0000000..8cd403c --- /dev/null +++ b/gee-orm/day1-database-sql/log/log_test.go @@ -0,0 +1,17 @@ +package log + +import ( + "os" + "testing" +) + +func TestSetLevel(t *testing.T) { + SetLevel(ErrorLevel) + if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout { + t.Fatal("failed to set log level") + } + SetLevel(Disabled) + if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout { + t.Fatal("failed to set log level") + } +} \ No newline at end of file diff --git a/gee-orm/day1-database-sql/session/raw.go b/gee-orm/day1-database-sql/session/raw.go new file mode 100644 index 0000000..f9f4f87 --- /dev/null +++ b/gee-orm/day1-database-sql/session/raw.go @@ -0,0 +1,66 @@ +package session + +import ( + "database/sql" + "geeorm/log" + "strings" +) + +// Session keep a pointer to sql.DB and provides all execution of all +// kind of database operations. +type Session struct { + db *sql.DB + sql strings.Builder + sqlVars []interface{} +} + +// New creates a instance of Session +func New(db *sql.DB) *Session { + return &Session{db: db} +} + +// Clear initialize the state of a session +func (s *Session) Clear() { + s.sql.Reset() + s.sqlVars = nil +} + +// DB returns *sql.DB +func (s *Session) DB() *sql.DB { + return s.db +} + +// Exec raw sql with sqlVars +func (s *Session) Exec() (result sql.Result, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// QueryRow gets a record from db +func (s *Session) QueryRow() *sql.Row { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + return s.DB().QueryRow(s.sql.String(), s.sqlVars...) +} + +// QueryRows gets a list of records from db +func (s *Session) QueryRows() (rows *sql.Rows, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// Raw appends sql and sqlVars +func (s *Session) Raw(sql string, values ...interface{}) *Session { + s.sql.WriteString(sql) + s.sql.WriteString(" ") + s.sqlVars = append(s.sqlVars, values...) + return s +} diff --git a/gee-orm/day1-database-sql/session/raw_test.go b/gee-orm/day1-database-sql/session/raw_test.go new file mode 100644 index 0000000..36e9678 --- /dev/null +++ b/gee-orm/day1-database-sql/session/raw_test.go @@ -0,0 +1,43 @@ +package session + +import ( + "database/sql" + "os" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +var TestDB *sql.DB + +func TestMain(m *testing.M) { + TestDB, _ = sql.Open("sqlite3", "../gee.db") + code := m.Run() + _ = TestDB.Close() + os.Exit(code) +} + +func NewSession() *Session { + return New(TestDB) +} + +func TestSession_Exec(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + if count, err := result.RowsAffected(); err != nil || count != 2 { + t.Fatal("expect 2, but got", count) + } +} + +func TestSession_QueryRows(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + row := s.Raw("SELECT count(*) FROM User").QueryRow() + var count int + if err := row.Scan(&count); err != nil || count != 0 { + t.Fatal("failed to query db", err) + } +} diff --git a/gee-orm/day2-reflect-schema/dialect/dialect.go b/gee-orm/day2-reflect-schema/dialect/dialect.go new file mode 100644 index 0000000..4696314 --- /dev/null +++ b/gee-orm/day2-reflect-schema/dialect/dialect.go @@ -0,0 +1,22 @@ +package dialect + +import "reflect" + +var dialectsMap = map[string]Dialect{} + +// Dialect is an interface contains methods that a dialect has to implement +type Dialect interface { + DataTypeOf(typ reflect.Value) string + TableExistSQL(tableName string) (string, []interface{}) +} + +// RegisterDialect register a dialect to the global variable +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect +} + +// Get the dialect from global variable if it exists +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} diff --git a/gee-orm/day2-reflect-schema/dialect/sqlite3.go b/gee-orm/day2-reflect-schema/dialect/sqlite3.go new file mode 100644 index 0000000..f3c3897 --- /dev/null +++ b/gee-orm/day2-reflect-schema/dialect/sqlite3.go @@ -0,0 +1,45 @@ +package dialect + +import ( + "fmt" + "reflect" + "time" +) + +type sqlite3 struct{} + +var _ Dialect = (*sqlite3)(nil) + +func init() { + RegisterDialect("sqlite3", &sqlite3{}) +} + +// Get Data Type for sqlite3 Dialect +func (s *sqlite3) DataTypeOf(typ reflect.Value) string { + switch typ.Kind() { + case reflect.Bool: + return "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "integer" + case reflect.Int64, reflect.Uint64: + return "bigint" + case reflect.Float32, reflect.Float64: + return "real" + case reflect.String: + return "text" + case reflect.Array, reflect.Slice: + return "blob" + case reflect.Struct: + if _, ok := typ.Interface().(time.Time); ok { + return "datetime" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind())) +} + +// TableExistSQL returns SQL that judge whether the table exists in database +func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) { + args := []interface{}{tableName} + return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args +} diff --git a/gee-orm/day2-reflect-schema/dialect/sqlite3_test.go b/gee-orm/day2-reflect-schema/dialect/sqlite3_test.go new file mode 100644 index 0000000..3df5f07 --- /dev/null +++ b/gee-orm/day2-reflect-schema/dialect/sqlite3_test.go @@ -0,0 +1,25 @@ +package dialect + +import ( + "reflect" + "testing" +) + +func TestDataTypeOf(t *testing.T) { + dial := &sqlite3{} + cases := []struct { + Value interface{} + Type string + }{ + {"Tom", "text"}, + {123, "integer"}, + {1.2, "real"}, + {[]int{1, 2, 3}, "blob"}, + } + + for _, c := range cases { + if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type { + t.Fatalf("expect %s, but got %s", c.Type, typ) + } + } +} diff --git a/gee-orm/day2-reflect-schema/geeorm.go b/gee-orm/day2-reflect-schema/geeorm.go new file mode 100644 index 0000000..b1881ce --- /dev/null +++ b/gee-orm/day2-reflect-schema/geeorm.go @@ -0,0 +1,51 @@ +package geeorm + +import ( + "database/sql" + "geeorm/dialect" + "geeorm/log" + "geeorm/session" +) + +// Engine is the main struct of geeorm, manages all db sessions and transactions. +type Engine struct { + db *sql.DB + dialect dialect.Dialect +} + +// NewEngine create a instance of Engine +// connect database and ping it to test whether it's alive +func NewEngine(driver, source string) (e *Engine, err error) { + db, err := sql.Open(driver, source) + if err != nil { + log.Error(err) + return + } + // Send a ping to make sure the database connection is alive. + if err = db.Ping(); err != nil { + log.Error(err) + return + } + // make sure the specific dialect exists + dial, ok := dialect.GetDialect(driver) + if !ok { + log.Errorf("dialect %s Not Found", driver) + return + } + e = &Engine{db: db, dialect: dial} + log.Info("Connect database success") + return +} + +// Close database connection +func (engine *Engine) Close() { + if err := engine.db.Close(); err != nil { + log.Error("Failed to close database") + } + log.Info("Close database success") +} + +// NewSession creates a new session for next operations +func (engine *Engine) NewSession() *session.Session { + return session.New(engine.db, engine.dialect) +} diff --git a/gee-orm/day2-reflect-schema/geeorm_test.go b/gee-orm/day2-reflect-schema/geeorm_test.go new file mode 100644 index 0000000..c6da191 --- /dev/null +++ b/gee-orm/day2-reflect-schema/geeorm_test.go @@ -0,0 +1,20 @@ +package geeorm + +import ( + _ "github.com/mattn/go-sqlite3" + "testing" +) + +func OpenDB(t *testing.T) *Engine { + t.Helper() + engine, err := NewEngine("sqlite3", "gee.db") + if err != nil { + t.Fatal("failed to connect", err) + } + return engine +} + +func TestNewEngine(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() +} diff --git a/gee-orm/day2-reflect-schema/go.mod b/gee-orm/day2-reflect-schema/go.mod new file mode 100644 index 0000000..043b1c6 --- /dev/null +++ b/gee-orm/day2-reflect-schema/go.mod @@ -0,0 +1,5 @@ +module geeorm + +go 1.13 + +require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/gee-orm/day2-reflect-schema/log/log.go b/gee-orm/day2-reflect-schema/log/log.go new file mode 100644 index 0000000..eacc0c6 --- /dev/null +++ b/gee-orm/day2-reflect-schema/log/log.go @@ -0,0 +1,47 @@ +package log + +import ( + "io/ioutil" + "log" + "os" + "sync" +) + +var ( + errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile) + infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile) + loggers = []*log.Logger{errorLog, infoLog} + mu sync.Mutex +) + +// log methods +var ( + Error = errorLog.Println + Errorf = errorLog.Printf + Info = infoLog.Println + Infof = infoLog.Printf +) + +// log levels +const ( + InfoLevel = iota + ErrorLevel + Disabled +) + +// SetLevel controls log level +func SetLevel(level int) { + mu.Lock() + defer mu.Unlock() + + for _, logger := range loggers { + logger.SetOutput(os.Stdout) + } + + if ErrorLevel < level { + errorLog.SetOutput(ioutil.Discard) + } + if InfoLevel < level { + infoLog.SetOutput(ioutil.Discard) + } +} diff --git a/gee-orm/day2-reflect-schema/log/log_test.go b/gee-orm/day2-reflect-schema/log/log_test.go new file mode 100644 index 0000000..8cd403c --- /dev/null +++ b/gee-orm/day2-reflect-schema/log/log_test.go @@ -0,0 +1,17 @@ +package log + +import ( + "os" + "testing" +) + +func TestSetLevel(t *testing.T) { + SetLevel(ErrorLevel) + if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout { + t.Fatal("failed to set log level") + } + SetLevel(Disabled) + if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout { + t.Fatal("failed to set log level") + } +} \ No newline at end of file diff --git a/gee-orm/day2-reflect-schema/schema/schema.go b/gee-orm/day2-reflect-schema/schema/schema.go new file mode 100644 index 0000000..93d36da --- /dev/null +++ b/gee-orm/day2-reflect-schema/schema/schema.go @@ -0,0 +1,76 @@ +package schema + +import ( + "geeorm/dialect" + "go/ast" + "reflect" +) + +// Field represents a column of database +type Field struct { + Name string + Type string + Tag string +} + +// Schema represents a table of database +type Schema struct { + Model interface{} + Name string + Fields []*Field + FieldNames []string + fieldMap map[string]*Field +} + +// GetField returns field by name +func (schema *Schema) GetField(name string) *Field { + return schema.fieldMap[name] +} + +// Values return the values of dest's member variables +func (schema *Schema) RecordValues(dest interface{}) []interface{} { + destValue := reflect.Indirect(reflect.ValueOf(dest)) + var fieldValues []interface{} + for _, field := range schema.Fields { + fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface()) + } + return fieldValues +} + +type ITableName interface { + TableName() string +} + +// Parse a struct to a Schema instance +func Parse(dest interface{}, d dialect.Dialect) *Schema { + modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() + var tableName string + t, ok := dest.(ITableName) + if !ok { + tableName = modelType.Name() + } else { + tableName = t.TableName() + } + schema := &Schema{ + Model: dest, + Name: tableName, + fieldMap: make(map[string]*Field), + } + + for i := 0; i < modelType.NumField(); i++ { + p := modelType.Field(i) + if !p.Anonymous && ast.IsExported(p.Name) { + field := &Field{ + Name: p.Name, + Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))), + } + if v, ok := p.Tag.Lookup("geeorm"); ok { + field.Tag = v + } + schema.Fields = append(schema.Fields, field) + schema.FieldNames = append(schema.FieldNames, p.Name) + schema.fieldMap[p.Name] = field + } + } + return schema +} diff --git a/gee-orm/day2-reflect-schema/schema/schema_test.go b/gee-orm/day2-reflect-schema/schema/schema_test.go new file mode 100644 index 0000000..8f625cb --- /dev/null +++ b/gee-orm/day2-reflect-schema/schema/schema_test.go @@ -0,0 +1,51 @@ +package schema + +import ( + "geeorm/dialect" + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +var TestDial, _ = dialect.GetDialect("sqlite3") + +func TestParse(t *testing.T) { + schema := Parse(&User{}, TestDial) + if schema.Name != "User" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } + if schema.GetField("Name").Tag != "PRIMARY KEY" { + t.Fatal("failed to parse primary key") + } +} + +func TestSchema_RecordValues(t *testing.T) { + schema := Parse(&User{}, TestDial) + values := schema.RecordValues(&User{"Tom", 18}) + + name := values[0].(string) + age := values[1].(int) + + if name != "Tom" || age != 18 { + t.Fatal("failed to get values") + } +} + +type UserTest struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func (u *UserTest) TableName() string { + return "ns_user_test" +} + +func TestSchema_TableName(t *testing.T) { + schema := Parse(&UserTest{}, TestDial) + if schema.Name != "ns_user_test" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } +} diff --git a/gee-orm/day2-reflect-schema/session/raw.go b/gee-orm/day2-reflect-schema/session/raw.go new file mode 100644 index 0000000..862c501 --- /dev/null +++ b/gee-orm/day2-reflect-schema/session/raw.go @@ -0,0 +1,73 @@ +package session + +import ( + "database/sql" + "geeorm/dialect" + "geeorm/log" + "geeorm/schema" + "strings" +) + +// Session keep a pointer to sql.DB and provides all execution of all +// kind of database operations. +type Session struct { + db *sql.DB + dialect dialect.Dialect + refTable *schema.Schema + sql strings.Builder + sqlVars []interface{} +} + +// New creates a instance of Session +func New(db *sql.DB, dialect dialect.Dialect) *Session { + return &Session{ + db: db, + dialect: dialect, + } +} + +// Clear initialize the state of a session +func (s *Session) Clear() { + s.sql.Reset() + s.sqlVars = nil +} + +// DB returns *sql.DB +func (s *Session) DB() *sql.DB { + return s.db +} + +// Exec raw sql with sqlVars +func (s *Session) Exec() (result sql.Result, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// QueryRow gets a record from db +func (s *Session) QueryRow() *sql.Row { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + return s.DB().QueryRow(s.sql.String(), s.sqlVars...) +} + +// QueryRows gets a list of records from db +func (s *Session) QueryRows() (rows *sql.Rows, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// Raw appends sql and sqlVars +func (s *Session) Raw(sql string, values ...interface{}) *Session { + s.sql.WriteString(sql) + s.sql.WriteString(" ") + s.sqlVars = append(s.sqlVars, values...) + return s +} diff --git a/gee-orm/day2-reflect-schema/session/raw_test.go b/gee-orm/day2-reflect-schema/session/raw_test.go new file mode 100644 index 0000000..404bb6e --- /dev/null +++ b/gee-orm/day2-reflect-schema/session/raw_test.go @@ -0,0 +1,48 @@ +package session + +import ( + "database/sql" + "os" + "testing" + + "geeorm/dialect" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + TestDB *sql.DB + TestDial, _ = dialect.GetDialect("sqlite3") +) + +func TestMain(m *testing.M) { + TestDB, _ = sql.Open("sqlite3", "../gee.db") + code := m.Run() + _ = TestDB.Close() + os.Exit(code) +} + +func NewSession() *Session { + return New(TestDB, TestDial) +} + +func TestSession_Exec(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + if count, err := result.RowsAffected(); err != nil || count != 2 { + t.Fatal("expect 2, but got", count) + } +} + +func TestSession_QueryRows(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + row := s.Raw("SELECT count(*) FROM User").QueryRow() + var count int + if err := row.Scan(&count); err != nil || count != 0 { + t.Fatal("failed to query db", err) + } +} diff --git a/gee-orm/day2-reflect-schema/session/table.go b/gee-orm/day2-reflect-schema/session/table.go new file mode 100644 index 0000000..58e7b0f --- /dev/null +++ b/gee-orm/day2-reflect-schema/session/table.go @@ -0,0 +1,54 @@ +package session + +import ( + "fmt" + "geeorm/log" + "reflect" + "strings" + + "geeorm/schema" +) + +// Model assigns refTable +func (s *Session) Model(value interface{}) *Session { + // nil or different model, update refTable + if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) { + s.refTable = schema.Parse(value, s.dialect) + } + return s +} + +// RefTable returns a Schema instance that contains all parsed fields +func (s *Session) RefTable() *schema.Schema { + if s.refTable == nil { + log.Error("Model is not set") + } + return s.refTable +} + +// CreateTable create a table in database with a model +func (s *Session) CreateTable() error { + table := s.RefTable() + var columns []string + for _, field := range table.Fields { + columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag)) + } + desc := strings.Join(columns, ",") + _, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec() + return err +} + +// DropTable drops a table with the name of model +func (s *Session) DropTable() error { + _, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec() + return err +} + +// HasTable returns true of the table exists +func (s *Session) HasTable() bool { + sql, values := s.dialect.TableExistSQL(s.RefTable().Name) + row := s.Raw(sql, values...).QueryRow() + var tmp string + _ = row.Scan(&tmp) + return tmp == s.RefTable().Name +} diff --git a/gee-orm/day2-reflect-schema/session/table_test.go b/gee-orm/day2-reflect-schema/session/table_test.go new file mode 100644 index 0000000..c070c7b --- /dev/null +++ b/gee-orm/day2-reflect-schema/session/table_test.go @@ -0,0 +1,27 @@ +package session + +import ( + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} +func TestSession_CreateTable(t *testing.T) { + s := NewSession().Model(&User{}) + _ = s.DropTable() + _ = s.CreateTable() + if !s.HasTable() { + t.Fatal("Failed to create table User") + } +} + +func TestSession_Model(t *testing.T) { + s := NewSession().Model(&User{}) + table := s.RefTable() + s.Model(&Session{}) + if table.Name != "User" || s.RefTable().Name != "Session" { + t.Fatal("Failed to change model") + } +} diff --git a/gee-orm/day3-save-query/clause/clause.go b/gee-orm/day3-save-query/clause/clause.go new file mode 100644 index 0000000..daa930d --- /dev/null +++ b/gee-orm/day3-save-query/clause/clause.go @@ -0,0 +1,48 @@ +package clause + +import ( + "strings" +) + +// Clause contains SQL conditions +type Clause struct { + sql map[Type]string + sqlVars map[Type][]interface{} +} + +// Type is the type of Clause +type Type int + +// Support types for Clause +const ( + INSERT Type = iota + VALUES + SELECT + LIMIT + WHERE + ORDERBY +) + +// Set adds a sub clause of specific type +func (c *Clause) Set(name Type, vars ...interface{}) { + if c.sql == nil { + c.sql = make(map[Type]string) + c.sqlVars = make(map[Type][]interface{}) + } + sql, vars := generators[name](vars...) + c.sql[name] = sql + c.sqlVars[name] = vars +} + +// Build generate the final SQL and SQLVars +func (c *Clause) Build(orders ...Type) (string, []interface{}) { + var sqls []string + var vars []interface{} + for _, order := range orders { + if sql, ok := c.sql[order]; ok { + sqls = append(sqls, sql) + vars = append(vars, c.sqlVars[order]...) + } + } + return strings.Join(sqls, " "), vars +} diff --git a/gee-orm/day3-save-query/clause/clause_test.go b/gee-orm/day3-save-query/clause/clause_test.go new file mode 100644 index 0000000..f5267ca --- /dev/null +++ b/gee-orm/day3-save-query/clause/clause_test.go @@ -0,0 +1,39 @@ +package clause + +import ( + "reflect" + "testing" +) + +func TestClause_Set(t *testing.T) { + var clause Clause + clause.Set(INSERT, "User", []string{"Name", "Age"}) + sql := clause.sql[INSERT] + vars := clause.sqlVars[INSERT] + t.Log(sql, vars) + if sql != "INSERT INTO User (Name,Age)" || len(vars) != 0 { + t.Fatal("failed to get clause") + } +} + +func testSelect(t *testing.T) { + var clause Clause + clause.Set(LIMIT, 3) + clause.Set(SELECT, "User", []string{"*"}) + clause.Set(WHERE, "Name = ?", "Tom") + clause.Set(ORDERBY, "Age ASC") + sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT) + t.Log(sql, vars) + if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) { + t.Fatal("failed to build SQLVars") + } +} + +func TestClause_Build(t *testing.T) { + t.Run("select", func(t *testing.T) { + testSelect(t) + }) +} diff --git a/gee-orm/day3-save-query/clause/generator.go b/gee-orm/day3-save-query/clause/generator.go new file mode 100644 index 0000000..9ad8e24 --- /dev/null +++ b/gee-orm/day3-save-query/clause/generator.go @@ -0,0 +1,78 @@ +package clause + +import ( + "fmt" + "strings" +) + +type generator func(values ...interface{}) (string, []interface{}) + +var generators map[Type]generator + +func init() { + generators = make(map[Type]generator) + generators[INSERT] = _insert + generators[VALUES] = _values + generators[SELECT] = _select + generators[LIMIT] = _limit + generators[WHERE] = _where + generators[ORDERBY] = _orderBy +} + +func genBindVars(num int) string { + var vars []string + for i := 0; i < num; i++ { + vars = append(vars, "?") + } + return strings.Join(vars, ", ") +} + +func _insert(values ...interface{}) (string, []interface{}) { + // INSERT INTO $tableName ($fields) + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{} +} + +func _values(values ...interface{}) (string, []interface{}) { + // VALUES ($v1), ($v2), ... + var bindStr string + var sql strings.Builder + var vars []interface{} + sql.WriteString("VALUES ") + for i, value := range values { + v := value.([]interface{}) + if bindStr == "" { + bindStr = genBindVars(len(v)) + } + sql.WriteString(fmt.Sprintf("(%v)", bindStr)) + if i+1 != len(values) { + sql.WriteString(", ") + } + vars = append(vars, v...) + } + return sql.String(), vars + +} + +func _select(values ...interface{}) (string, []interface{}) { + // SELECT $fields FROM $tableName + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{} +} + +func _limit(values ...interface{}) (string, []interface{}) { + // LIMIT $num + return "LIMIT ?", values +} + +func _where(values ...interface{}) (string, []interface{}) { + // WHERE $desc + desc, vars := values[0], values[1:] + return fmt.Sprintf("WHERE %s", desc), vars +} + +func _orderBy(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{} +} diff --git a/gee-orm/day3-save-query/dialect/dialect.go b/gee-orm/day3-save-query/dialect/dialect.go new file mode 100644 index 0000000..4696314 --- /dev/null +++ b/gee-orm/day3-save-query/dialect/dialect.go @@ -0,0 +1,22 @@ +package dialect + +import "reflect" + +var dialectsMap = map[string]Dialect{} + +// Dialect is an interface contains methods that a dialect has to implement +type Dialect interface { + DataTypeOf(typ reflect.Value) string + TableExistSQL(tableName string) (string, []interface{}) +} + +// RegisterDialect register a dialect to the global variable +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect +} + +// Get the dialect from global variable if it exists +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} diff --git a/gee-orm/day3-save-query/dialect/sqlite3.go b/gee-orm/day3-save-query/dialect/sqlite3.go new file mode 100644 index 0000000..f3c3897 --- /dev/null +++ b/gee-orm/day3-save-query/dialect/sqlite3.go @@ -0,0 +1,45 @@ +package dialect + +import ( + "fmt" + "reflect" + "time" +) + +type sqlite3 struct{} + +var _ Dialect = (*sqlite3)(nil) + +func init() { + RegisterDialect("sqlite3", &sqlite3{}) +} + +// Get Data Type for sqlite3 Dialect +func (s *sqlite3) DataTypeOf(typ reflect.Value) string { + switch typ.Kind() { + case reflect.Bool: + return "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "integer" + case reflect.Int64, reflect.Uint64: + return "bigint" + case reflect.Float32, reflect.Float64: + return "real" + case reflect.String: + return "text" + case reflect.Array, reflect.Slice: + return "blob" + case reflect.Struct: + if _, ok := typ.Interface().(time.Time); ok { + return "datetime" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind())) +} + +// TableExistSQL returns SQL that judge whether the table exists in database +func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) { + args := []interface{}{tableName} + return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args +} diff --git a/gee-orm/day3-save-query/dialect/sqlite3_test.go b/gee-orm/day3-save-query/dialect/sqlite3_test.go new file mode 100644 index 0000000..3df5f07 --- /dev/null +++ b/gee-orm/day3-save-query/dialect/sqlite3_test.go @@ -0,0 +1,25 @@ +package dialect + +import ( + "reflect" + "testing" +) + +func TestDataTypeOf(t *testing.T) { + dial := &sqlite3{} + cases := []struct { + Value interface{} + Type string + }{ + {"Tom", "text"}, + {123, "integer"}, + {1.2, "real"}, + {[]int{1, 2, 3}, "blob"}, + } + + for _, c := range cases { + if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type { + t.Fatalf("expect %s, but got %s", c.Type, typ) + } + } +} diff --git a/gee-orm/day3-save-query/geeorm.go b/gee-orm/day3-save-query/geeorm.go new file mode 100644 index 0000000..b1881ce --- /dev/null +++ b/gee-orm/day3-save-query/geeorm.go @@ -0,0 +1,51 @@ +package geeorm + +import ( + "database/sql" + "geeorm/dialect" + "geeorm/log" + "geeorm/session" +) + +// Engine is the main struct of geeorm, manages all db sessions and transactions. +type Engine struct { + db *sql.DB + dialect dialect.Dialect +} + +// NewEngine create a instance of Engine +// connect database and ping it to test whether it's alive +func NewEngine(driver, source string) (e *Engine, err error) { + db, err := sql.Open(driver, source) + if err != nil { + log.Error(err) + return + } + // Send a ping to make sure the database connection is alive. + if err = db.Ping(); err != nil { + log.Error(err) + return + } + // make sure the specific dialect exists + dial, ok := dialect.GetDialect(driver) + if !ok { + log.Errorf("dialect %s Not Found", driver) + return + } + e = &Engine{db: db, dialect: dial} + log.Info("Connect database success") + return +} + +// Close database connection +func (engine *Engine) Close() { + if err := engine.db.Close(); err != nil { + log.Error("Failed to close database") + } + log.Info("Close database success") +} + +// NewSession creates a new session for next operations +func (engine *Engine) NewSession() *session.Session { + return session.New(engine.db, engine.dialect) +} diff --git a/gee-orm/day3-save-query/geeorm_test.go b/gee-orm/day3-save-query/geeorm_test.go new file mode 100644 index 0000000..c6da191 --- /dev/null +++ b/gee-orm/day3-save-query/geeorm_test.go @@ -0,0 +1,20 @@ +package geeorm + +import ( + _ "github.com/mattn/go-sqlite3" + "testing" +) + +func OpenDB(t *testing.T) *Engine { + t.Helper() + engine, err := NewEngine("sqlite3", "gee.db") + if err != nil { + t.Fatal("failed to connect", err) + } + return engine +} + +func TestNewEngine(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() +} diff --git a/gee-orm/day3-save-query/go.mod b/gee-orm/day3-save-query/go.mod new file mode 100644 index 0000000..043b1c6 --- /dev/null +++ b/gee-orm/day3-save-query/go.mod @@ -0,0 +1,5 @@ +module geeorm + +go 1.13 + +require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/gee-orm/day3-save-query/log/log.go b/gee-orm/day3-save-query/log/log.go new file mode 100644 index 0000000..eacc0c6 --- /dev/null +++ b/gee-orm/day3-save-query/log/log.go @@ -0,0 +1,47 @@ +package log + +import ( + "io/ioutil" + "log" + "os" + "sync" +) + +var ( + errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile) + infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile) + loggers = []*log.Logger{errorLog, infoLog} + mu sync.Mutex +) + +// log methods +var ( + Error = errorLog.Println + Errorf = errorLog.Printf + Info = infoLog.Println + Infof = infoLog.Printf +) + +// log levels +const ( + InfoLevel = iota + ErrorLevel + Disabled +) + +// SetLevel controls log level +func SetLevel(level int) { + mu.Lock() + defer mu.Unlock() + + for _, logger := range loggers { + logger.SetOutput(os.Stdout) + } + + if ErrorLevel < level { + errorLog.SetOutput(ioutil.Discard) + } + if InfoLevel < level { + infoLog.SetOutput(ioutil.Discard) + } +} diff --git a/gee-orm/day3-save-query/log/log_test.go b/gee-orm/day3-save-query/log/log_test.go new file mode 100644 index 0000000..8cd403c --- /dev/null +++ b/gee-orm/day3-save-query/log/log_test.go @@ -0,0 +1,17 @@ +package log + +import ( + "os" + "testing" +) + +func TestSetLevel(t *testing.T) { + SetLevel(ErrorLevel) + if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout { + t.Fatal("failed to set log level") + } + SetLevel(Disabled) + if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout { + t.Fatal("failed to set log level") + } +} \ No newline at end of file diff --git a/gee-orm/day3-save-query/schema/schema.go b/gee-orm/day3-save-query/schema/schema.go new file mode 100644 index 0000000..2c9b927 --- /dev/null +++ b/gee-orm/day3-save-query/schema/schema.go @@ -0,0 +1,75 @@ +package schema + +import ( + "geeorm/dialect" + "go/ast" + "reflect" +) + +// Field represents a column of database +type Field struct { + Name string + Type string + Tag string +} + +// Schema represents a table of database +type Schema struct { + Model interface{} + Name string + Fields []*Field + FieldNames []string + fieldMap map[string]*Field +} + +// GetField returns field by name +func (schema *Schema) GetField(name string) *Field { + return schema.fieldMap[name] +} + +// Values return the values of dest's member variables +func (schema *Schema) RecordValues(dest interface{}) []interface{} { + destValue := reflect.Indirect(reflect.ValueOf(dest)) + var fieldValues []interface{} + for _, field := range schema.Fields { + fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface()) + } + return fieldValues +} + +type ITableName interface { + TableName() string +} + +// Parse a struct to a Schema instance +func Parse(dest interface{}, d dialect.Dialect) *Schema { + modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() + var tableName string + t, ok := dest.(ITableName) + if !ok { + tableName = modelType.Name() + } else { + tableName = t.TableName() + } + schema := &Schema{ + Model: dest, + Name: tableName, + fieldMap: make(map[string]*Field), + } + for i := 0; i < modelType.NumField(); i++ { + p := modelType.Field(i) + if !p.Anonymous && ast.IsExported(p.Name) { + field := &Field{ + Name: p.Name, + Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))), + } + if v, ok := p.Tag.Lookup("geeorm"); ok { + field.Tag = v + } + schema.Fields = append(schema.Fields, field) + schema.FieldNames = append(schema.FieldNames, p.Name) + schema.fieldMap[p.Name] = field + } + } + return schema +} diff --git a/gee-orm/day3-save-query/schema/schema_test.go b/gee-orm/day3-save-query/schema/schema_test.go new file mode 100644 index 0000000..8f625cb --- /dev/null +++ b/gee-orm/day3-save-query/schema/schema_test.go @@ -0,0 +1,51 @@ +package schema + +import ( + "geeorm/dialect" + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +var TestDial, _ = dialect.GetDialect("sqlite3") + +func TestParse(t *testing.T) { + schema := Parse(&User{}, TestDial) + if schema.Name != "User" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } + if schema.GetField("Name").Tag != "PRIMARY KEY" { + t.Fatal("failed to parse primary key") + } +} + +func TestSchema_RecordValues(t *testing.T) { + schema := Parse(&User{}, TestDial) + values := schema.RecordValues(&User{"Tom", 18}) + + name := values[0].(string) + age := values[1].(int) + + if name != "Tom" || age != 18 { + t.Fatal("failed to get values") + } +} + +type UserTest struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func (u *UserTest) TableName() string { + return "ns_user_test" +} + +func TestSchema_TableName(t *testing.T) { + schema := Parse(&UserTest{}, TestDial) + if schema.Name != "ns_user_test" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } +} diff --git a/gee-orm/day3-save-query/session/raw.go b/gee-orm/day3-save-query/session/raw.go new file mode 100644 index 0000000..161fcb4 --- /dev/null +++ b/gee-orm/day3-save-query/session/raw.go @@ -0,0 +1,76 @@ +package session + +import ( + "database/sql" + "geeorm/clause" + "geeorm/dialect" + "geeorm/log" + "geeorm/schema" + "strings" +) + +// Session keep a pointer to sql.DB and provides all execution of all +// kind of database operations. +type Session struct { + db *sql.DB + dialect dialect.Dialect + refTable *schema.Schema + clause clause.Clause + sql strings.Builder + sqlVars []interface{} +} + +// New creates a instance of Session +func New(db *sql.DB, dialect dialect.Dialect) *Session { + return &Session{ + db: db, + dialect: dialect, + } +} + +// Clear initialize the state of a session +func (s *Session) Clear() { + s.sql.Reset() + s.sqlVars = nil + s.clause = clause.Clause{} +} + +// DB returns *sql.DB +func (s *Session) DB() *sql.DB { + return s.db +} + +// Exec raw sql with sqlVars +func (s *Session) Exec() (result sql.Result, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// QueryRow gets a record from db +func (s *Session) QueryRow() *sql.Row { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + return s.DB().QueryRow(s.sql.String(), s.sqlVars...) +} + +// QueryRows gets a list of records from db +func (s *Session) QueryRows() (rows *sql.Rows, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// Raw appends sql and sqlVars +func (s *Session) Raw(sql string, values ...interface{}) *Session { + s.sql.WriteString(sql) + s.sql.WriteString(" ") + s.sqlVars = append(s.sqlVars, values...) + return s +} \ No newline at end of file diff --git a/gee-orm/day3-save-query/session/raw_test.go b/gee-orm/day3-save-query/session/raw_test.go new file mode 100644 index 0000000..d2212fe --- /dev/null +++ b/gee-orm/day3-save-query/session/raw_test.go @@ -0,0 +1,47 @@ +package session + +import ( + "database/sql" + "os" + "testing" + + "geeorm/dialect" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + TestDB *sql.DB + TestDial, _ = dialect.GetDialect("sqlite3") +) + +func TestMain(m *testing.M) { + TestDB, _ = sql.Open("sqlite3", "../gee.db") + code := m.Run() + _ = TestDB.Close() + os.Exit(code) +} + +func NewSession() *Session { + return New(TestDB, TestDial) +} +func TestSession_Exec(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + if count, err := result.RowsAffected(); err != nil || count != 2 { + t.Fatal("expect 2, but got", count) + } +} + +func TestSession_QueryRows(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + row := s.Raw("SELECT count(*) FROM User").QueryRow() + var count int + if err := row.Scan(&count); err != nil || count != 0 { + t.Fatal("failed to query db", err) + } +} diff --git a/gee-orm/day3-save-query/session/record.go b/gee-orm/day3-save-query/session/record.go new file mode 100644 index 0000000..5f033e0 --- /dev/null +++ b/gee-orm/day3-save-query/session/record.go @@ -0,0 +1,52 @@ +package session + +import ( + "geeorm/clause" + "reflect" +) + +// Insert one or more records in database +func (s *Session) Insert(values ...interface{}) (int64, error) { + recordValues := make([]interface{}, 0) + for _, value := range values { + table := s.Model(value).RefTable() + s.clause.Set(clause.INSERT, table.Name, table.FieldNames) + recordValues = append(recordValues, table.RecordValues(value)) + } + + s.clause.Set(clause.VALUES, recordValues...) + sql, vars := s.clause.Build(clause.INSERT, clause.VALUES) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + + return result.RowsAffected() +} + +// Find gets all eligible records +func (s *Session) Find(values interface{}) error { + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + table := s.Model(reflect.New(destType).Elem().Interface()).RefTable() + + s.clause.Set(clause.SELECT, table.Name, table.FieldNames) + sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT) + rows, err := s.Raw(sql, vars...).QueryRows() + if err != nil { + return err + } + + for rows.Next() { + dest := reflect.New(destType).Elem() + var values []interface{} + for _, name := range table.FieldNames { + values = append(values, dest.FieldByName(name).Addr().Interface()) + } + if err := rows.Scan(values...); err != nil { + return err + } + destSlice.Set(reflect.Append(destSlice, dest)) + } + return rows.Close() +} diff --git a/gee-orm/day3-save-query/session/record_test.go b/gee-orm/day3-save-query/session/record_test.go new file mode 100644 index 0000000..67bfb2a --- /dev/null +++ b/gee-orm/day3-save-query/session/record_test.go @@ -0,0 +1,37 @@ +package session + +import "testing" + +var ( + user1 = &User{"Tom", 18} + user2 = &User{"Sam", 25} + user3 = &User{"Jack", 25} +) + +func testRecordInit(t *testing.T) *Session { + t.Helper() + s := NewSession().Model(&User{}) + err1 := s.DropTable() + err2 := s.CreateTable() + _, err3 := s.Insert(user1, user2) + if err1 != nil || err2 != nil || err3 != nil { + t.Fatal("failed init test records") + } + return s +} + +func TestSession_Insert(t *testing.T) { + s := testRecordInit(t) + affected, err := s.Insert(user3) + if err != nil || affected != 1 { + t.Fatal("failed to create record") + } +} + +func TestSession_Find(t *testing.T) { + s := testRecordInit(t) + var users []User + if err := s.Find(&users); err != nil || len(users) != 2 { + t.Fatal("failed to query all") + } +} diff --git a/gee-orm/day3-save-query/session/table.go b/gee-orm/day3-save-query/session/table.go new file mode 100644 index 0000000..58e7b0f --- /dev/null +++ b/gee-orm/day3-save-query/session/table.go @@ -0,0 +1,54 @@ +package session + +import ( + "fmt" + "geeorm/log" + "reflect" + "strings" + + "geeorm/schema" +) + +// Model assigns refTable +func (s *Session) Model(value interface{}) *Session { + // nil or different model, update refTable + if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) { + s.refTable = schema.Parse(value, s.dialect) + } + return s +} + +// RefTable returns a Schema instance that contains all parsed fields +func (s *Session) RefTable() *schema.Schema { + if s.refTable == nil { + log.Error("Model is not set") + } + return s.refTable +} + +// CreateTable create a table in database with a model +func (s *Session) CreateTable() error { + table := s.RefTable() + var columns []string + for _, field := range table.Fields { + columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag)) + } + desc := strings.Join(columns, ",") + _, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec() + return err +} + +// DropTable drops a table with the name of model +func (s *Session) DropTable() error { + _, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec() + return err +} + +// HasTable returns true of the table exists +func (s *Session) HasTable() bool { + sql, values := s.dialect.TableExistSQL(s.RefTable().Name) + row := s.Raw(sql, values...).QueryRow() + var tmp string + _ = row.Scan(&tmp) + return tmp == s.RefTable().Name +} diff --git a/gee-orm/day3-save-query/session/table_test.go b/gee-orm/day3-save-query/session/table_test.go new file mode 100644 index 0000000..3bb7554 --- /dev/null +++ b/gee-orm/day3-save-query/session/table_test.go @@ -0,0 +1,28 @@ +package session + +import ( + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func TestSession_CreateTable(t *testing.T) { + s := NewSession().Model(&User{}) + _ = s.DropTable() + _ = s.CreateTable() + if !s.HasTable() { + t.Fatal("Failed to create table User") + } +} + +func TestSession_Model(t *testing.T) { + s := NewSession().Model(&User{}) + table := s.RefTable() + s.Model(&Session{}) + if table.Name != "User" || s.RefTable().Name != "Session" { + t.Fatal("Failed to change model") + } +} diff --git a/gee-orm/day4-chain-operation/clause/clause.go b/gee-orm/day4-chain-operation/clause/clause.go new file mode 100644 index 0000000..02fcf93 --- /dev/null +++ b/gee-orm/day4-chain-operation/clause/clause.go @@ -0,0 +1,51 @@ +package clause + +import ( + "strings" +) + +// Clause contains SQL conditions +type Clause struct { + sql map[Type]string + sqlVars map[Type][]interface{} +} + +// Type is the type of Clause +type Type int + +// Support types for Clause +const ( + INSERT Type = iota + VALUES + SELECT + LIMIT + WHERE + ORDERBY + UPDATE + DELETE + COUNT +) + +// Set adds a sub clause of specific type +func (c *Clause) Set(name Type, vars ...interface{}) { + if c.sql == nil { + c.sql = make(map[Type]string) + c.sqlVars = make(map[Type][]interface{}) + } + sql, vars := generators[name](vars...) + c.sql[name] = sql + c.sqlVars[name] = vars +} + +// Build generate the final SQL and SQLVars +func (c *Clause) Build(orders ...Type) (string, []interface{}) { + var sqls []string + var vars []interface{} + for _, order := range orders { + if sql, ok := c.sql[order]; ok { + sqls = append(sqls, sql) + vars = append(vars, c.sqlVars[order]...) + } + } + return strings.Join(sqls, " "), vars +} diff --git a/gee-orm/day4-chain-operation/clause/clause_test.go b/gee-orm/day4-chain-operation/clause/clause_test.go new file mode 100644 index 0000000..62e0ccb --- /dev/null +++ b/gee-orm/day4-chain-operation/clause/clause_test.go @@ -0,0 +1,74 @@ +package clause + +import ( + "reflect" + "testing" +) + +func TestClause_Set(t *testing.T) { + var clause Clause + clause.Set(INSERT, "User", []string{"Name", "Age"}) + sql := clause.sql[INSERT] + vars := clause.sqlVars[INSERT] + t.Log(sql, vars) + if sql != "INSERT INTO User (Name,Age)" || len(vars) != 0 { + t.Fatal("failed to get clause") + } +} + +func testSelect(t *testing.T) { + var clause Clause + clause.Set(LIMIT, 3) + clause.Set(SELECT, "User", []string{"*"}) + clause.Set(WHERE, "Name = ?", "Tom") + clause.Set(ORDERBY, "Age ASC") + sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT) + t.Log(sql, vars) + if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) { + t.Fatal("failed to build SQLVars") + } +} + +func testUpdate(t *testing.T) { + var clause Clause + clause.Set(UPDATE, "User", map[string]interface{}{"Age": 30}) + clause.Set(WHERE, "Name = ?", "Tom") + sql, vars := clause.Build(UPDATE, WHERE) + t.Log(sql, vars) + if sql != "UPDATE User SET Age = ? WHERE Name = ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{30, "Tom"}) { + t.Fatal("failed to build SQLVars") + } +} + +func testDelete(t *testing.T) { + var clause Clause + clause.Set(DELETE, "User") + clause.Set(WHERE, "Name = ?", "Tom") + + sql, vars := clause.Build(DELETE, WHERE) + t.Log(sql, vars) + if sql != "DELETE FROM User WHERE Name = ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{"Tom"}) { + t.Fatal("failed to build SQLVars") + } +} + +func TestClause_Build(t *testing.T) { + t.Run("select", func(t *testing.T) { + testSelect(t) + }) + t.Run("update", func(t *testing.T) { + testUpdate(t) + }) + t.Run("delete", func(t *testing.T) { + testDelete(t) + }) +} diff --git a/gee-orm/day4-chain-operation/clause/generator.go b/gee-orm/day4-chain-operation/clause/generator.go new file mode 100644 index 0000000..23635ba --- /dev/null +++ b/gee-orm/day4-chain-operation/clause/generator.go @@ -0,0 +1,101 @@ +package clause + +import ( + "fmt" + "strings" +) + +type generator func(values ...interface{}) (string, []interface{}) + +var generators map[Type]generator + +func init() { + generators = make(map[Type]generator) + generators[INSERT] = _insert + generators[VALUES] = _values + generators[SELECT] = _select + generators[LIMIT] = _limit + generators[WHERE] = _where + generators[ORDERBY] = _orderBy + generators[UPDATE] = _update + generators[DELETE] = _delete + generators[COUNT] = _count +} + +func genBindVars(num int) string { + var vars []string + for i := 0; i < num; i++ { + vars = append(vars, "?") + } + return strings.Join(vars, ", ") +} + +func _insert(values ...interface{}) (string, []interface{}) { + // INSERT INTO $tableName ($fields) + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{} +} + +func _values(values ...interface{}) (string, []interface{}) { + // VALUES ($v1), ($v2), ... + var bindStr string + var sql strings.Builder + var vars []interface{} + sql.WriteString("VALUES ") + for i, value := range values { + v := value.([]interface{}) + if bindStr == "" { + bindStr = genBindVars(len(v)) + } + sql.WriteString(fmt.Sprintf("(%v)", bindStr)) + if i+1 != len(values) { + sql.WriteString(", ") + } + vars = append(vars, v...) + } + return sql.String(), vars + +} + +func _select(values ...interface{}) (string, []interface{}) { + // SELECT $fields FROM $tableName + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{} +} + +func _limit(values ...interface{}) (string, []interface{}) { + // LIMIT $num + return "LIMIT ?", values +} + +func _where(values ...interface{}) (string, []interface{}) { + // WHERE $desc + desc, vars := values[0], values[1:] + return fmt.Sprintf("WHERE %s", desc), vars +} + +func _orderBy(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{} +} + +func _update(values ...interface{}) (string, []interface{}) { + tableName := values[0] + m := values[1].(map[string]interface{}) + var keys []string + var vars []interface{} + for k, v := range m { + keys = append(keys, k+" = ?") + vars = append(vars, v) + } + return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars +} + +func _delete(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{} +} + +func _count(values ...interface{}) (string, []interface{}) { + return _select(values[0], []string{"count(*)"}) +} diff --git a/gee-orm/day4-chain-operation/dialect/dialect.go b/gee-orm/day4-chain-operation/dialect/dialect.go new file mode 100644 index 0000000..4696314 --- /dev/null +++ b/gee-orm/day4-chain-operation/dialect/dialect.go @@ -0,0 +1,22 @@ +package dialect + +import "reflect" + +var dialectsMap = map[string]Dialect{} + +// Dialect is an interface contains methods that a dialect has to implement +type Dialect interface { + DataTypeOf(typ reflect.Value) string + TableExistSQL(tableName string) (string, []interface{}) +} + +// RegisterDialect register a dialect to the global variable +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect +} + +// Get the dialect from global variable if it exists +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} diff --git a/gee-orm/day4-chain-operation/dialect/sqlite3.go b/gee-orm/day4-chain-operation/dialect/sqlite3.go new file mode 100644 index 0000000..f3c3897 --- /dev/null +++ b/gee-orm/day4-chain-operation/dialect/sqlite3.go @@ -0,0 +1,45 @@ +package dialect + +import ( + "fmt" + "reflect" + "time" +) + +type sqlite3 struct{} + +var _ Dialect = (*sqlite3)(nil) + +func init() { + RegisterDialect("sqlite3", &sqlite3{}) +} + +// Get Data Type for sqlite3 Dialect +func (s *sqlite3) DataTypeOf(typ reflect.Value) string { + switch typ.Kind() { + case reflect.Bool: + return "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "integer" + case reflect.Int64, reflect.Uint64: + return "bigint" + case reflect.Float32, reflect.Float64: + return "real" + case reflect.String: + return "text" + case reflect.Array, reflect.Slice: + return "blob" + case reflect.Struct: + if _, ok := typ.Interface().(time.Time); ok { + return "datetime" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind())) +} + +// TableExistSQL returns SQL that judge whether the table exists in database +func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) { + args := []interface{}{tableName} + return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args +} diff --git a/gee-orm/day4-chain-operation/dialect/sqlite3_test.go b/gee-orm/day4-chain-operation/dialect/sqlite3_test.go new file mode 100644 index 0000000..3df5f07 --- /dev/null +++ b/gee-orm/day4-chain-operation/dialect/sqlite3_test.go @@ -0,0 +1,25 @@ +package dialect + +import ( + "reflect" + "testing" +) + +func TestDataTypeOf(t *testing.T) { + dial := &sqlite3{} + cases := []struct { + Value interface{} + Type string + }{ + {"Tom", "text"}, + {123, "integer"}, + {1.2, "real"}, + {[]int{1, 2, 3}, "blob"}, + } + + for _, c := range cases { + if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type { + t.Fatalf("expect %s, but got %s", c.Type, typ) + } + } +} diff --git a/gee-orm/day4-chain-operation/geeorm.go b/gee-orm/day4-chain-operation/geeorm.go new file mode 100644 index 0000000..b1881ce --- /dev/null +++ b/gee-orm/day4-chain-operation/geeorm.go @@ -0,0 +1,51 @@ +package geeorm + +import ( + "database/sql" + "geeorm/dialect" + "geeorm/log" + "geeorm/session" +) + +// Engine is the main struct of geeorm, manages all db sessions and transactions. +type Engine struct { + db *sql.DB + dialect dialect.Dialect +} + +// NewEngine create a instance of Engine +// connect database and ping it to test whether it's alive +func NewEngine(driver, source string) (e *Engine, err error) { + db, err := sql.Open(driver, source) + if err != nil { + log.Error(err) + return + } + // Send a ping to make sure the database connection is alive. + if err = db.Ping(); err != nil { + log.Error(err) + return + } + // make sure the specific dialect exists + dial, ok := dialect.GetDialect(driver) + if !ok { + log.Errorf("dialect %s Not Found", driver) + return + } + e = &Engine{db: db, dialect: dial} + log.Info("Connect database success") + return +} + +// Close database connection +func (engine *Engine) Close() { + if err := engine.db.Close(); err != nil { + log.Error("Failed to close database") + } + log.Info("Close database success") +} + +// NewSession creates a new session for next operations +func (engine *Engine) NewSession() *session.Session { + return session.New(engine.db, engine.dialect) +} diff --git a/gee-orm/day4-chain-operation/geeorm_test.go b/gee-orm/day4-chain-operation/geeorm_test.go new file mode 100644 index 0000000..c6da191 --- /dev/null +++ b/gee-orm/day4-chain-operation/geeorm_test.go @@ -0,0 +1,20 @@ +package geeorm + +import ( + _ "github.com/mattn/go-sqlite3" + "testing" +) + +func OpenDB(t *testing.T) *Engine { + t.Helper() + engine, err := NewEngine("sqlite3", "gee.db") + if err != nil { + t.Fatal("failed to connect", err) + } + return engine +} + +func TestNewEngine(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() +} diff --git a/gee-orm/day4-chain-operation/go.mod b/gee-orm/day4-chain-operation/go.mod new file mode 100644 index 0000000..043b1c6 --- /dev/null +++ b/gee-orm/day4-chain-operation/go.mod @@ -0,0 +1,5 @@ +module geeorm + +go 1.13 + +require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/gee-orm/day4-chain-operation/log/log.go b/gee-orm/day4-chain-operation/log/log.go new file mode 100644 index 0000000..eacc0c6 --- /dev/null +++ b/gee-orm/day4-chain-operation/log/log.go @@ -0,0 +1,47 @@ +package log + +import ( + "io/ioutil" + "log" + "os" + "sync" +) + +var ( + errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile) + infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile) + loggers = []*log.Logger{errorLog, infoLog} + mu sync.Mutex +) + +// log methods +var ( + Error = errorLog.Println + Errorf = errorLog.Printf + Info = infoLog.Println + Infof = infoLog.Printf +) + +// log levels +const ( + InfoLevel = iota + ErrorLevel + Disabled +) + +// SetLevel controls log level +func SetLevel(level int) { + mu.Lock() + defer mu.Unlock() + + for _, logger := range loggers { + logger.SetOutput(os.Stdout) + } + + if ErrorLevel < level { + errorLog.SetOutput(ioutil.Discard) + } + if InfoLevel < level { + infoLog.SetOutput(ioutil.Discard) + } +} diff --git a/gee-orm/day4-chain-operation/log/log_test.go b/gee-orm/day4-chain-operation/log/log_test.go new file mode 100644 index 0000000..8cd403c --- /dev/null +++ b/gee-orm/day4-chain-operation/log/log_test.go @@ -0,0 +1,17 @@ +package log + +import ( + "os" + "testing" +) + +func TestSetLevel(t *testing.T) { + SetLevel(ErrorLevel) + if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout { + t.Fatal("failed to set log level") + } + SetLevel(Disabled) + if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout { + t.Fatal("failed to set log level") + } +} \ No newline at end of file diff --git a/gee-orm/day4-chain-operation/schema/schema.go b/gee-orm/day4-chain-operation/schema/schema.go new file mode 100644 index 0000000..2c9b927 --- /dev/null +++ b/gee-orm/day4-chain-operation/schema/schema.go @@ -0,0 +1,75 @@ +package schema + +import ( + "geeorm/dialect" + "go/ast" + "reflect" +) + +// Field represents a column of database +type Field struct { + Name string + Type string + Tag string +} + +// Schema represents a table of database +type Schema struct { + Model interface{} + Name string + Fields []*Field + FieldNames []string + fieldMap map[string]*Field +} + +// GetField returns field by name +func (schema *Schema) GetField(name string) *Field { + return schema.fieldMap[name] +} + +// Values return the values of dest's member variables +func (schema *Schema) RecordValues(dest interface{}) []interface{} { + destValue := reflect.Indirect(reflect.ValueOf(dest)) + var fieldValues []interface{} + for _, field := range schema.Fields { + fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface()) + } + return fieldValues +} + +type ITableName interface { + TableName() string +} + +// Parse a struct to a Schema instance +func Parse(dest interface{}, d dialect.Dialect) *Schema { + modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() + var tableName string + t, ok := dest.(ITableName) + if !ok { + tableName = modelType.Name() + } else { + tableName = t.TableName() + } + schema := &Schema{ + Model: dest, + Name: tableName, + fieldMap: make(map[string]*Field), + } + for i := 0; i < modelType.NumField(); i++ { + p := modelType.Field(i) + if !p.Anonymous && ast.IsExported(p.Name) { + field := &Field{ + Name: p.Name, + Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))), + } + if v, ok := p.Tag.Lookup("geeorm"); ok { + field.Tag = v + } + schema.Fields = append(schema.Fields, field) + schema.FieldNames = append(schema.FieldNames, p.Name) + schema.fieldMap[p.Name] = field + } + } + return schema +} diff --git a/gee-orm/day4-chain-operation/schema/schema_test.go b/gee-orm/day4-chain-operation/schema/schema_test.go new file mode 100644 index 0000000..8f625cb --- /dev/null +++ b/gee-orm/day4-chain-operation/schema/schema_test.go @@ -0,0 +1,51 @@ +package schema + +import ( + "geeorm/dialect" + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +var TestDial, _ = dialect.GetDialect("sqlite3") + +func TestParse(t *testing.T) { + schema := Parse(&User{}, TestDial) + if schema.Name != "User" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } + if schema.GetField("Name").Tag != "PRIMARY KEY" { + t.Fatal("failed to parse primary key") + } +} + +func TestSchema_RecordValues(t *testing.T) { + schema := Parse(&User{}, TestDial) + values := schema.RecordValues(&User{"Tom", 18}) + + name := values[0].(string) + age := values[1].(int) + + if name != "Tom" || age != 18 { + t.Fatal("failed to get values") + } +} + +type UserTest struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func (u *UserTest) TableName() string { + return "ns_user_test" +} + +func TestSchema_TableName(t *testing.T) { + schema := Parse(&UserTest{}, TestDial) + if schema.Name != "ns_user_test" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } +} diff --git a/gee-orm/day4-chain-operation/session/raw.go b/gee-orm/day4-chain-operation/session/raw.go new file mode 100644 index 0000000..161fcb4 --- /dev/null +++ b/gee-orm/day4-chain-operation/session/raw.go @@ -0,0 +1,76 @@ +package session + +import ( + "database/sql" + "geeorm/clause" + "geeorm/dialect" + "geeorm/log" + "geeorm/schema" + "strings" +) + +// Session keep a pointer to sql.DB and provides all execution of all +// kind of database operations. +type Session struct { + db *sql.DB + dialect dialect.Dialect + refTable *schema.Schema + clause clause.Clause + sql strings.Builder + sqlVars []interface{} +} + +// New creates a instance of Session +func New(db *sql.DB, dialect dialect.Dialect) *Session { + return &Session{ + db: db, + dialect: dialect, + } +} + +// Clear initialize the state of a session +func (s *Session) Clear() { + s.sql.Reset() + s.sqlVars = nil + s.clause = clause.Clause{} +} + +// DB returns *sql.DB +func (s *Session) DB() *sql.DB { + return s.db +} + +// Exec raw sql with sqlVars +func (s *Session) Exec() (result sql.Result, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// QueryRow gets a record from db +func (s *Session) QueryRow() *sql.Row { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + return s.DB().QueryRow(s.sql.String(), s.sqlVars...) +} + +// QueryRows gets a list of records from db +func (s *Session) QueryRows() (rows *sql.Rows, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// Raw appends sql and sqlVars +func (s *Session) Raw(sql string, values ...interface{}) *Session { + s.sql.WriteString(sql) + s.sql.WriteString(" ") + s.sqlVars = append(s.sqlVars, values...) + return s +} \ No newline at end of file diff --git a/gee-orm/day4-chain-operation/session/raw_test.go b/gee-orm/day4-chain-operation/session/raw_test.go new file mode 100644 index 0000000..404bb6e --- /dev/null +++ b/gee-orm/day4-chain-operation/session/raw_test.go @@ -0,0 +1,48 @@ +package session + +import ( + "database/sql" + "os" + "testing" + + "geeorm/dialect" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + TestDB *sql.DB + TestDial, _ = dialect.GetDialect("sqlite3") +) + +func TestMain(m *testing.M) { + TestDB, _ = sql.Open("sqlite3", "../gee.db") + code := m.Run() + _ = TestDB.Close() + os.Exit(code) +} + +func NewSession() *Session { + return New(TestDB, TestDial) +} + +func TestSession_Exec(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + if count, err := result.RowsAffected(); err != nil || count != 2 { + t.Fatal("expect 2, but got", count) + } +} + +func TestSession_QueryRows(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + row := s.Raw("SELECT count(*) FROM User").QueryRow() + var count int + if err := row.Scan(&count); err != nil || count != 0 { + t.Fatal("failed to query db", err) + } +} diff --git a/gee-orm/day4-chain-operation/session/record.go b/gee-orm/day4-chain-operation/session/record.go new file mode 100644 index 0000000..cef890b --- /dev/null +++ b/gee-orm/day4-chain-operation/session/record.go @@ -0,0 +1,129 @@ +package session + +import ( + "errors" + "geeorm/clause" + "reflect" +) + +// Insert one or more records in database +func (s *Session) Insert(values ...interface{}) (int64, error) { + recordValues := make([]interface{}, 0) + for _, value := range values { + table := s.Model(value).RefTable() + s.clause.Set(clause.INSERT, table.Name, table.FieldNames) + recordValues = append(recordValues, table.RecordValues(value)) + } + + s.clause.Set(clause.VALUES, recordValues...) + sql, vars := s.clause.Build(clause.INSERT, clause.VALUES) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + + return result.RowsAffected() +} + +// Find gets all eligible records +func (s *Session) Find(values interface{}) error { + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + table := s.Model(reflect.New(destType).Elem().Interface()).RefTable() + + s.clause.Set(clause.SELECT, table.Name, table.FieldNames) + sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT) + rows, err := s.Raw(sql, vars...).QueryRows() + if err != nil { + return err + } + + for rows.Next() { + dest := reflect.New(destType).Elem() + var values []interface{} + for _, name := range table.FieldNames { + values = append(values, dest.FieldByName(name).Addr().Interface()) + } + if err := rows.Scan(values...); err != nil { + return err + } + destSlice.Set(reflect.Append(destSlice, dest)) + } + return rows.Close() +} + +// First gets the 1st row +func (s *Session) First(value interface{}) error { + dest := reflect.Indirect(reflect.ValueOf(value)) + destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem() + if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil { + return err + } + if destSlice.Len() == 0 { + return errors.New("NOT FOUND") + } + dest.Set(destSlice.Index(0)) + return nil +} + +// Limit adds limit condition to clause +func (s *Session) Limit(num int) *Session { + s.clause.Set(clause.LIMIT, num) + return s +} + +// Where adds limit condition to clause +func (s *Session) Where(desc string, args ...interface{}) *Session { + var vars []interface{} + s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...) + return s +} + +// OrderBy adds order by condition to clause +func (s *Session) OrderBy(desc string) *Session { + s.clause.Set(clause.ORDERBY, desc) + return s +} + +// Update records with where clause +// support map[string]interface{} +// also support kv list: "Name", "Tom", "Age", 18, .... +func (s *Session) Update(kv ...interface{}) (int64, error) { + m, ok := kv[0].(map[string]interface{}) + if !ok { + m = make(map[string]interface{}) + for i := 0; i < len(kv); i += 2 { + m[kv[i].(string)] = kv[i+1] + } + } + s.clause.Set(clause.UPDATE, s.RefTable().Name, m) + sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +// Delete records with where clause +func (s *Session) Delete() (int64, error) { + s.clause.Set(clause.DELETE, s.RefTable().Name) + sql, vars := s.clause.Build(clause.DELETE, clause.WHERE) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +// Count records with where clause +func (s *Session) Count() (int64, error) { + s.clause.Set(clause.COUNT, s.RefTable().Name) + sql, vars := s.clause.Build(clause.COUNT, clause.WHERE) + row := s.Raw(sql, vars...).QueryRow() + var tmp int64 + if err := row.Scan(&tmp); err != nil { + return 0, err + } + return tmp, nil +} diff --git a/gee-orm/day4-chain-operation/session/record_test.go b/gee-orm/day4-chain-operation/session/record_test.go new file mode 100644 index 0000000..5d482a0 --- /dev/null +++ b/gee-orm/day4-chain-operation/session/record_test.go @@ -0,0 +1,97 @@ +package session + +import "testing" + +var ( + user1 = &User{"Tom", 18} + user2 = &User{"Sam", 25} + user3 = &User{"Jack", 25} +) + +func testRecordInit(t *testing.T) *Session { + t.Helper() + s := NewSession().Model(&User{}) + err1 := s.DropTable() + err2 := s.CreateTable() + _, err3 := s.Insert(user1, user2) + if err1 != nil || err2 != nil || err3 != nil { + t.Fatal("failed init test records") + } + return s +} + +func TestSession_Insert(t *testing.T) { + s := testRecordInit(t) + affected, err := s.Insert(user3) + if err != nil || affected != 1 { + t.Fatal("failed to create record") + } +} + +func TestSession_Find(t *testing.T) { + s := testRecordInit(t) + var users []User + if err := s.Find(&users); err != nil || len(users) != 2 { + t.Fatal("failed to query all") + } +} + +func TestSession_First(t *testing.T) { + s := testRecordInit(t) + u := &User{} + err := s.First(u) + if err != nil || u.Name != "Tom" || u.Age != 18 { + t.Fatal("failed to query first") + } +} + +func TestSession_Limit(t *testing.T) { + s := testRecordInit(t) + var users []User + err := s.Limit(1).Find(&users) + if err != nil || len(users) != 1 { + t.Fatal("failed to query with limit condition") + } +} + +func TestSession_Where(t *testing.T) { + s := testRecordInit(t) + var users []User + _, err1 := s.Insert(user3) + err2 := s.Where("Age = ?", 25).Find(&users) + + if err1 != nil || err2 != nil || len(users) != 2 { + t.Fatal("failed to query with where condition") + } +} + +func TestSession_OrderBy(t *testing.T) { + s := testRecordInit(t) + u := &User{} + err := s.OrderBy("Age DESC").First(u) + + if err != nil || u.Age != 25 { + t.Fatal("failed to query with order by condition") + } +} + +func TestSession_Update(t *testing.T) { + s := testRecordInit(t) + affected, _ := s.Where("Name = ?", "Tom").Update("Age", 30) + u := &User{} + _ = s.OrderBy("Age DESC").First(u) + + if affected != 1 || u.Age != 30 { + t.Fatal("failed to update") + } +} + +func TestSession_DeleteAndCount(t *testing.T) { + s := testRecordInit(t) + affected, _ := s.Where("Name = ?", "Tom").Delete() + count, _ := s.Count() + + if affected != 1 || count != 1 { + t.Fatal("failed to delete or count") + } +} diff --git a/gee-orm/day4-chain-operation/session/table.go b/gee-orm/day4-chain-operation/session/table.go new file mode 100644 index 0000000..58e7b0f --- /dev/null +++ b/gee-orm/day4-chain-operation/session/table.go @@ -0,0 +1,54 @@ +package session + +import ( + "fmt" + "geeorm/log" + "reflect" + "strings" + + "geeorm/schema" +) + +// Model assigns refTable +func (s *Session) Model(value interface{}) *Session { + // nil or different model, update refTable + if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) { + s.refTable = schema.Parse(value, s.dialect) + } + return s +} + +// RefTable returns a Schema instance that contains all parsed fields +func (s *Session) RefTable() *schema.Schema { + if s.refTable == nil { + log.Error("Model is not set") + } + return s.refTable +} + +// CreateTable create a table in database with a model +func (s *Session) CreateTable() error { + table := s.RefTable() + var columns []string + for _, field := range table.Fields { + columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag)) + } + desc := strings.Join(columns, ",") + _, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec() + return err +} + +// DropTable drops a table with the name of model +func (s *Session) DropTable() error { + _, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec() + return err +} + +// HasTable returns true of the table exists +func (s *Session) HasTable() bool { + sql, values := s.dialect.TableExistSQL(s.RefTable().Name) + row := s.Raw(sql, values...).QueryRow() + var tmp string + _ = row.Scan(&tmp) + return tmp == s.RefTable().Name +} diff --git a/gee-orm/day4-chain-operation/session/table_test.go b/gee-orm/day4-chain-operation/session/table_test.go new file mode 100644 index 0000000..3bb7554 --- /dev/null +++ b/gee-orm/day4-chain-operation/session/table_test.go @@ -0,0 +1,28 @@ +package session + +import ( + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func TestSession_CreateTable(t *testing.T) { + s := NewSession().Model(&User{}) + _ = s.DropTable() + _ = s.CreateTable() + if !s.HasTable() { + t.Fatal("Failed to create table User") + } +} + +func TestSession_Model(t *testing.T) { + s := NewSession().Model(&User{}) + table := s.RefTable() + s.Model(&Session{}) + if table.Name != "User" || s.RefTable().Name != "Session" { + t.Fatal("Failed to change model") + } +} diff --git a/gee-orm/day5-hooks/clause/clause.go b/gee-orm/day5-hooks/clause/clause.go new file mode 100644 index 0000000..02fcf93 --- /dev/null +++ b/gee-orm/day5-hooks/clause/clause.go @@ -0,0 +1,51 @@ +package clause + +import ( + "strings" +) + +// Clause contains SQL conditions +type Clause struct { + sql map[Type]string + sqlVars map[Type][]interface{} +} + +// Type is the type of Clause +type Type int + +// Support types for Clause +const ( + INSERT Type = iota + VALUES + SELECT + LIMIT + WHERE + ORDERBY + UPDATE + DELETE + COUNT +) + +// Set adds a sub clause of specific type +func (c *Clause) Set(name Type, vars ...interface{}) { + if c.sql == nil { + c.sql = make(map[Type]string) + c.sqlVars = make(map[Type][]interface{}) + } + sql, vars := generators[name](vars...) + c.sql[name] = sql + c.sqlVars[name] = vars +} + +// Build generate the final SQL and SQLVars +func (c *Clause) Build(orders ...Type) (string, []interface{}) { + var sqls []string + var vars []interface{} + for _, order := range orders { + if sql, ok := c.sql[order]; ok { + sqls = append(sqls, sql) + vars = append(vars, c.sqlVars[order]...) + } + } + return strings.Join(sqls, " "), vars +} diff --git a/gee-orm/day5-hooks/clause/clause_test.go b/gee-orm/day5-hooks/clause/clause_test.go new file mode 100644 index 0000000..62e0ccb --- /dev/null +++ b/gee-orm/day5-hooks/clause/clause_test.go @@ -0,0 +1,74 @@ +package clause + +import ( + "reflect" + "testing" +) + +func TestClause_Set(t *testing.T) { + var clause Clause + clause.Set(INSERT, "User", []string{"Name", "Age"}) + sql := clause.sql[INSERT] + vars := clause.sqlVars[INSERT] + t.Log(sql, vars) + if sql != "INSERT INTO User (Name,Age)" || len(vars) != 0 { + t.Fatal("failed to get clause") + } +} + +func testSelect(t *testing.T) { + var clause Clause + clause.Set(LIMIT, 3) + clause.Set(SELECT, "User", []string{"*"}) + clause.Set(WHERE, "Name = ?", "Tom") + clause.Set(ORDERBY, "Age ASC") + sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT) + t.Log(sql, vars) + if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) { + t.Fatal("failed to build SQLVars") + } +} + +func testUpdate(t *testing.T) { + var clause Clause + clause.Set(UPDATE, "User", map[string]interface{}{"Age": 30}) + clause.Set(WHERE, "Name = ?", "Tom") + sql, vars := clause.Build(UPDATE, WHERE) + t.Log(sql, vars) + if sql != "UPDATE User SET Age = ? WHERE Name = ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{30, "Tom"}) { + t.Fatal("failed to build SQLVars") + } +} + +func testDelete(t *testing.T) { + var clause Clause + clause.Set(DELETE, "User") + clause.Set(WHERE, "Name = ?", "Tom") + + sql, vars := clause.Build(DELETE, WHERE) + t.Log(sql, vars) + if sql != "DELETE FROM User WHERE Name = ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{"Tom"}) { + t.Fatal("failed to build SQLVars") + } +} + +func TestClause_Build(t *testing.T) { + t.Run("select", func(t *testing.T) { + testSelect(t) + }) + t.Run("update", func(t *testing.T) { + testUpdate(t) + }) + t.Run("delete", func(t *testing.T) { + testDelete(t) + }) +} diff --git a/gee-orm/day5-hooks/clause/generator.go b/gee-orm/day5-hooks/clause/generator.go new file mode 100644 index 0000000..23635ba --- /dev/null +++ b/gee-orm/day5-hooks/clause/generator.go @@ -0,0 +1,101 @@ +package clause + +import ( + "fmt" + "strings" +) + +type generator func(values ...interface{}) (string, []interface{}) + +var generators map[Type]generator + +func init() { + generators = make(map[Type]generator) + generators[INSERT] = _insert + generators[VALUES] = _values + generators[SELECT] = _select + generators[LIMIT] = _limit + generators[WHERE] = _where + generators[ORDERBY] = _orderBy + generators[UPDATE] = _update + generators[DELETE] = _delete + generators[COUNT] = _count +} + +func genBindVars(num int) string { + var vars []string + for i := 0; i < num; i++ { + vars = append(vars, "?") + } + return strings.Join(vars, ", ") +} + +func _insert(values ...interface{}) (string, []interface{}) { + // INSERT INTO $tableName ($fields) + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{} +} + +func _values(values ...interface{}) (string, []interface{}) { + // VALUES ($v1), ($v2), ... + var bindStr string + var sql strings.Builder + var vars []interface{} + sql.WriteString("VALUES ") + for i, value := range values { + v := value.([]interface{}) + if bindStr == "" { + bindStr = genBindVars(len(v)) + } + sql.WriteString(fmt.Sprintf("(%v)", bindStr)) + if i+1 != len(values) { + sql.WriteString(", ") + } + vars = append(vars, v...) + } + return sql.String(), vars + +} + +func _select(values ...interface{}) (string, []interface{}) { + // SELECT $fields FROM $tableName + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{} +} + +func _limit(values ...interface{}) (string, []interface{}) { + // LIMIT $num + return "LIMIT ?", values +} + +func _where(values ...interface{}) (string, []interface{}) { + // WHERE $desc + desc, vars := values[0], values[1:] + return fmt.Sprintf("WHERE %s", desc), vars +} + +func _orderBy(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{} +} + +func _update(values ...interface{}) (string, []interface{}) { + tableName := values[0] + m := values[1].(map[string]interface{}) + var keys []string + var vars []interface{} + for k, v := range m { + keys = append(keys, k+" = ?") + vars = append(vars, v) + } + return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars +} + +func _delete(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{} +} + +func _count(values ...interface{}) (string, []interface{}) { + return _select(values[0], []string{"count(*)"}) +} diff --git a/gee-orm/day5-hooks/dialect/dialect.go b/gee-orm/day5-hooks/dialect/dialect.go new file mode 100644 index 0000000..4696314 --- /dev/null +++ b/gee-orm/day5-hooks/dialect/dialect.go @@ -0,0 +1,22 @@ +package dialect + +import "reflect" + +var dialectsMap = map[string]Dialect{} + +// Dialect is an interface contains methods that a dialect has to implement +type Dialect interface { + DataTypeOf(typ reflect.Value) string + TableExistSQL(tableName string) (string, []interface{}) +} + +// RegisterDialect register a dialect to the global variable +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect +} + +// Get the dialect from global variable if it exists +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} diff --git a/gee-orm/day5-hooks/dialect/sqlite3.go b/gee-orm/day5-hooks/dialect/sqlite3.go new file mode 100644 index 0000000..f3c3897 --- /dev/null +++ b/gee-orm/day5-hooks/dialect/sqlite3.go @@ -0,0 +1,45 @@ +package dialect + +import ( + "fmt" + "reflect" + "time" +) + +type sqlite3 struct{} + +var _ Dialect = (*sqlite3)(nil) + +func init() { + RegisterDialect("sqlite3", &sqlite3{}) +} + +// Get Data Type for sqlite3 Dialect +func (s *sqlite3) DataTypeOf(typ reflect.Value) string { + switch typ.Kind() { + case reflect.Bool: + return "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "integer" + case reflect.Int64, reflect.Uint64: + return "bigint" + case reflect.Float32, reflect.Float64: + return "real" + case reflect.String: + return "text" + case reflect.Array, reflect.Slice: + return "blob" + case reflect.Struct: + if _, ok := typ.Interface().(time.Time); ok { + return "datetime" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind())) +} + +// TableExistSQL returns SQL that judge whether the table exists in database +func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) { + args := []interface{}{tableName} + return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args +} diff --git a/gee-orm/day5-hooks/dialect/sqlite3_test.go b/gee-orm/day5-hooks/dialect/sqlite3_test.go new file mode 100644 index 0000000..3df5f07 --- /dev/null +++ b/gee-orm/day5-hooks/dialect/sqlite3_test.go @@ -0,0 +1,25 @@ +package dialect + +import ( + "reflect" + "testing" +) + +func TestDataTypeOf(t *testing.T) { + dial := &sqlite3{} + cases := []struct { + Value interface{} + Type string + }{ + {"Tom", "text"}, + {123, "integer"}, + {1.2, "real"}, + {[]int{1, 2, 3}, "blob"}, + } + + for _, c := range cases { + if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type { + t.Fatalf("expect %s, but got %s", c.Type, typ) + } + } +} diff --git a/gee-orm/day5-hooks/geeorm.go b/gee-orm/day5-hooks/geeorm.go new file mode 100644 index 0000000..b1881ce --- /dev/null +++ b/gee-orm/day5-hooks/geeorm.go @@ -0,0 +1,51 @@ +package geeorm + +import ( + "database/sql" + "geeorm/dialect" + "geeorm/log" + "geeorm/session" +) + +// Engine is the main struct of geeorm, manages all db sessions and transactions. +type Engine struct { + db *sql.DB + dialect dialect.Dialect +} + +// NewEngine create a instance of Engine +// connect database and ping it to test whether it's alive +func NewEngine(driver, source string) (e *Engine, err error) { + db, err := sql.Open(driver, source) + if err != nil { + log.Error(err) + return + } + // Send a ping to make sure the database connection is alive. + if err = db.Ping(); err != nil { + log.Error(err) + return + } + // make sure the specific dialect exists + dial, ok := dialect.GetDialect(driver) + if !ok { + log.Errorf("dialect %s Not Found", driver) + return + } + e = &Engine{db: db, dialect: dial} + log.Info("Connect database success") + return +} + +// Close database connection +func (engine *Engine) Close() { + if err := engine.db.Close(); err != nil { + log.Error("Failed to close database") + } + log.Info("Close database success") +} + +// NewSession creates a new session for next operations +func (engine *Engine) NewSession() *session.Session { + return session.New(engine.db, engine.dialect) +} diff --git a/gee-orm/day5-hooks/geeorm_test.go b/gee-orm/day5-hooks/geeorm_test.go new file mode 100644 index 0000000..c6da191 --- /dev/null +++ b/gee-orm/day5-hooks/geeorm_test.go @@ -0,0 +1,20 @@ +package geeorm + +import ( + _ "github.com/mattn/go-sqlite3" + "testing" +) + +func OpenDB(t *testing.T) *Engine { + t.Helper() + engine, err := NewEngine("sqlite3", "gee.db") + if err != nil { + t.Fatal("failed to connect", err) + } + return engine +} + +func TestNewEngine(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() +} diff --git a/gee-orm/day5-hooks/go.mod b/gee-orm/day5-hooks/go.mod new file mode 100644 index 0000000..043b1c6 --- /dev/null +++ b/gee-orm/day5-hooks/go.mod @@ -0,0 +1,5 @@ +module geeorm + +go 1.13 + +require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/gee-orm/day5-hooks/log/log.go b/gee-orm/day5-hooks/log/log.go new file mode 100644 index 0000000..eacc0c6 --- /dev/null +++ b/gee-orm/day5-hooks/log/log.go @@ -0,0 +1,47 @@ +package log + +import ( + "io/ioutil" + "log" + "os" + "sync" +) + +var ( + errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile) + infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile) + loggers = []*log.Logger{errorLog, infoLog} + mu sync.Mutex +) + +// log methods +var ( + Error = errorLog.Println + Errorf = errorLog.Printf + Info = infoLog.Println + Infof = infoLog.Printf +) + +// log levels +const ( + InfoLevel = iota + ErrorLevel + Disabled +) + +// SetLevel controls log level +func SetLevel(level int) { + mu.Lock() + defer mu.Unlock() + + for _, logger := range loggers { + logger.SetOutput(os.Stdout) + } + + if ErrorLevel < level { + errorLog.SetOutput(ioutil.Discard) + } + if InfoLevel < level { + infoLog.SetOutput(ioutil.Discard) + } +} diff --git a/gee-orm/day5-hooks/log/log_test.go b/gee-orm/day5-hooks/log/log_test.go new file mode 100644 index 0000000..8cd403c --- /dev/null +++ b/gee-orm/day5-hooks/log/log_test.go @@ -0,0 +1,17 @@ +package log + +import ( + "os" + "testing" +) + +func TestSetLevel(t *testing.T) { + SetLevel(ErrorLevel) + if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout { + t.Fatal("failed to set log level") + } + SetLevel(Disabled) + if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout { + t.Fatal("failed to set log level") + } +} \ No newline at end of file diff --git a/gee-orm/day5-hooks/schema/schema.go b/gee-orm/day5-hooks/schema/schema.go new file mode 100644 index 0000000..93d36da --- /dev/null +++ b/gee-orm/day5-hooks/schema/schema.go @@ -0,0 +1,76 @@ +package schema + +import ( + "geeorm/dialect" + "go/ast" + "reflect" +) + +// Field represents a column of database +type Field struct { + Name string + Type string + Tag string +} + +// Schema represents a table of database +type Schema struct { + Model interface{} + Name string + Fields []*Field + FieldNames []string + fieldMap map[string]*Field +} + +// GetField returns field by name +func (schema *Schema) GetField(name string) *Field { + return schema.fieldMap[name] +} + +// Values return the values of dest's member variables +func (schema *Schema) RecordValues(dest interface{}) []interface{} { + destValue := reflect.Indirect(reflect.ValueOf(dest)) + var fieldValues []interface{} + for _, field := range schema.Fields { + fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface()) + } + return fieldValues +} + +type ITableName interface { + TableName() string +} + +// Parse a struct to a Schema instance +func Parse(dest interface{}, d dialect.Dialect) *Schema { + modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() + var tableName string + t, ok := dest.(ITableName) + if !ok { + tableName = modelType.Name() + } else { + tableName = t.TableName() + } + schema := &Schema{ + Model: dest, + Name: tableName, + fieldMap: make(map[string]*Field), + } + + for i := 0; i < modelType.NumField(); i++ { + p := modelType.Field(i) + if !p.Anonymous && ast.IsExported(p.Name) { + field := &Field{ + Name: p.Name, + Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))), + } + if v, ok := p.Tag.Lookup("geeorm"); ok { + field.Tag = v + } + schema.Fields = append(schema.Fields, field) + schema.FieldNames = append(schema.FieldNames, p.Name) + schema.fieldMap[p.Name] = field + } + } + return schema +} diff --git a/gee-orm/day5-hooks/schema/schema_test.go b/gee-orm/day5-hooks/schema/schema_test.go new file mode 100644 index 0000000..8f625cb --- /dev/null +++ b/gee-orm/day5-hooks/schema/schema_test.go @@ -0,0 +1,51 @@ +package schema + +import ( + "geeorm/dialect" + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +var TestDial, _ = dialect.GetDialect("sqlite3") + +func TestParse(t *testing.T) { + schema := Parse(&User{}, TestDial) + if schema.Name != "User" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } + if schema.GetField("Name").Tag != "PRIMARY KEY" { + t.Fatal("failed to parse primary key") + } +} + +func TestSchema_RecordValues(t *testing.T) { + schema := Parse(&User{}, TestDial) + values := schema.RecordValues(&User{"Tom", 18}) + + name := values[0].(string) + age := values[1].(int) + + if name != "Tom" || age != 18 { + t.Fatal("failed to get values") + } +} + +type UserTest struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func (u *UserTest) TableName() string { + return "ns_user_test" +} + +func TestSchema_TableName(t *testing.T) { + schema := Parse(&UserTest{}, TestDial) + if schema.Name != "ns_user_test" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } +} diff --git a/gee-orm/day5-hooks/session/hooks.go b/gee-orm/day5-hooks/session/hooks.go new file mode 100644 index 0000000..d73c3c2 --- /dev/null +++ b/gee-orm/day5-hooks/session/hooks.go @@ -0,0 +1,35 @@ +package session + +import ( + "geeorm/log" + "reflect" +) + +// Hooks constants +const ( + BeforeQuery = "BeforeQuery" + AfterQuery = "AfterQuery" + BeforeUpdate = "BeforeUpdate" + AfterUpdate = "AfterUpdate" + BeforeDelete = "BeforeDelete" + AfterDelete = "AfterDelete" + BeforeInsert = "BeforeInsert" + AfterInsert = "AfterInsert" +) + +// CallMethod calls the registered hooks +func (s *Session) CallMethod(method string, value interface{}) { + fm := reflect.ValueOf(s.RefTable().Model).MethodByName(method) + if value != nil { + fm = reflect.ValueOf(value).MethodByName(method) + } + param := []reflect.Value{reflect.ValueOf(s)} + if fm.IsValid() { + if v := fm.Call(param); len(v) > 0 { + if err, ok := v[0].Interface().(error); ok { + log.Error(err) + } + } + } + return +} diff --git a/gee-orm/day5-hooks/session/hooks_test.go b/gee-orm/day5-hooks/session/hooks_test.go new file mode 100644 index 0000000..f896d01 --- /dev/null +++ b/gee-orm/day5-hooks/session/hooks_test.go @@ -0,0 +1,37 @@ +package session + +import ( + "geeorm/log" + "testing" +) + +type Account struct { + ID int `geeorm:"PRIMARY KEY"` + Password string +} + +func (account *Account) BeforeInsert(s *Session) error { + log.Info("before inert", account) + account.ID += 1000 + return nil +} + +func (account *Account) AfterQuery(s *Session) error { + log.Info("after query", account) + account.Password = "******" + return nil +} + +func TestSession_CallMethod(t *testing.T) { + s := NewSession().Model(&Account{}) + _ = s.DropTable() + _ = s.CreateTable() + _, _ = s.Insert(&Account{1, "123456"}, &Account{2, "qwerty"}) + + u := &Account{} + + err := s.First(u) + if err != nil || u.ID != 1001 || u.Password != "******" { + t.Fatal("Failed to call hooks after query, got", u) + } +} diff --git a/gee-orm/day5-hooks/session/raw.go b/gee-orm/day5-hooks/session/raw.go new file mode 100644 index 0000000..161fcb4 --- /dev/null +++ b/gee-orm/day5-hooks/session/raw.go @@ -0,0 +1,76 @@ +package session + +import ( + "database/sql" + "geeorm/clause" + "geeorm/dialect" + "geeorm/log" + "geeorm/schema" + "strings" +) + +// Session keep a pointer to sql.DB and provides all execution of all +// kind of database operations. +type Session struct { + db *sql.DB + dialect dialect.Dialect + refTable *schema.Schema + clause clause.Clause + sql strings.Builder + sqlVars []interface{} +} + +// New creates a instance of Session +func New(db *sql.DB, dialect dialect.Dialect) *Session { + return &Session{ + db: db, + dialect: dialect, + } +} + +// Clear initialize the state of a session +func (s *Session) Clear() { + s.sql.Reset() + s.sqlVars = nil + s.clause = clause.Clause{} +} + +// DB returns *sql.DB +func (s *Session) DB() *sql.DB { + return s.db +} + +// Exec raw sql with sqlVars +func (s *Session) Exec() (result sql.Result, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// QueryRow gets a record from db +func (s *Session) QueryRow() *sql.Row { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + return s.DB().QueryRow(s.sql.String(), s.sqlVars...) +} + +// QueryRows gets a list of records from db +func (s *Session) QueryRows() (rows *sql.Rows, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// Raw appends sql and sqlVars +func (s *Session) Raw(sql string, values ...interface{}) *Session { + s.sql.WriteString(sql) + s.sql.WriteString(" ") + s.sqlVars = append(s.sqlVars, values...) + return s +} \ No newline at end of file diff --git a/gee-orm/day5-hooks/session/raw_test.go b/gee-orm/day5-hooks/session/raw_test.go new file mode 100644 index 0000000..404bb6e --- /dev/null +++ b/gee-orm/day5-hooks/session/raw_test.go @@ -0,0 +1,48 @@ +package session + +import ( + "database/sql" + "os" + "testing" + + "geeorm/dialect" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + TestDB *sql.DB + TestDial, _ = dialect.GetDialect("sqlite3") +) + +func TestMain(m *testing.M) { + TestDB, _ = sql.Open("sqlite3", "../gee.db") + code := m.Run() + _ = TestDB.Close() + os.Exit(code) +} + +func NewSession() *Session { + return New(TestDB, TestDial) +} + +func TestSession_Exec(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + if count, err := result.RowsAffected(); err != nil || count != 2 { + t.Fatal("expect 2, but got", count) + } +} + +func TestSession_QueryRows(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + row := s.Raw("SELECT count(*) FROM User").QueryRow() + var count int + if err := row.Scan(&count); err != nil || count != 0 { + t.Fatal("failed to query db", err) + } +} diff --git a/gee-orm/day5-hooks/session/record.go b/gee-orm/day5-hooks/session/record.go new file mode 100644 index 0000000..fdfca4d --- /dev/null +++ b/gee-orm/day5-hooks/session/record.go @@ -0,0 +1,136 @@ +package session + +import ( + "errors" + "geeorm/clause" + "reflect" +) + +// Insert one or more records in database +func (s *Session) Insert(values ...interface{}) (int64, error) { + recordValues := make([]interface{}, 0) + for _, value := range values { + s.CallMethod(BeforeInsert, value) + table := s.Model(value).RefTable() + s.clause.Set(clause.INSERT, table.Name, table.FieldNames) + recordValues = append(recordValues, table.RecordValues(value)) + } + + s.clause.Set(clause.VALUES, recordValues...) + sql, vars := s.clause.Build(clause.INSERT, clause.VALUES) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + s.CallMethod(AfterInsert, nil) + return result.RowsAffected() +} + +// Find gets all eligible records +func (s *Session) Find(values interface{}) error { + s.CallMethod(BeforeQuery, nil) + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + table := s.Model(reflect.New(destType).Elem().Interface()).RefTable() + + s.clause.Set(clause.SELECT, table.Name, table.FieldNames) + sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT) + rows, err := s.Raw(sql, vars...).QueryRows() + if err != nil { + return err + } + + for rows.Next() { + dest := reflect.New(destType).Elem() + var values []interface{} + for _, name := range table.FieldNames { + values = append(values, dest.FieldByName(name).Addr().Interface()) + } + if err := rows.Scan(values...); err != nil { + return err + } + s.CallMethod(AfterQuery, dest.Addr().Interface()) + destSlice.Set(reflect.Append(destSlice, dest)) + } + return rows.Close() +} + +// First gets the 1st row +func (s *Session) First(value interface{}) error { + dest := reflect.Indirect(reflect.ValueOf(value)) + destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem() + if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil { + return err + } + if destSlice.Len() == 0 { + return errors.New("NOT FOUND") + } + dest.Set(destSlice.Index(0)) + return nil +} + +// Limit adds limit condition to clause +func (s *Session) Limit(num int) *Session { + s.clause.Set(clause.LIMIT, num) + return s +} + +// Where adds limit condition to clause +func (s *Session) Where(desc string, args ...interface{}) *Session { + var vars []interface{} + s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...) + return s +} + +// OrderBy adds order by condition to clause +func (s *Session) OrderBy(desc string) *Session { + s.clause.Set(clause.ORDERBY, desc) + return s +} + +// Update records with where clause +// support map[string]interface{} +// also support kv list: "Name", "Tom", "Age", 18, .... +func (s *Session) Update(kv ...interface{}) (int64, error) { + s.CallMethod(BeforeUpdate, nil) + m, ok := kv[0].(map[string]interface{}) + if !ok { + m = make(map[string]interface{}) + for i := 0; i < len(kv); i += 2 { + m[kv[i].(string)] = kv[i+1] + } + } + s.clause.Set(clause.UPDATE, s.RefTable().Name, m) + sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + s.CallMethod(AfterUpdate, nil) + return result.RowsAffected() +} + +// Delete records with where clause +func (s *Session) Delete() (int64, error) { + s.CallMethod(BeforeDelete, nil) + s.clause.Set(clause.DELETE, s.RefTable().Name) + sql, vars := s.clause.Build(clause.DELETE, clause.WHERE) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + s.CallMethod(AfterDelete, nil) + return result.RowsAffected() +} + +// Count records with where clause +func (s *Session) Count() (int64, error) { + s.clause.Set(clause.COUNT, s.RefTable().Name) + sql, vars := s.clause.Build(clause.COUNT, clause.WHERE) + row := s.Raw(sql, vars...).QueryRow() + var tmp int64 + if err := row.Scan(&tmp); err != nil { + return 0, err + } + return tmp, nil +} diff --git a/gee-orm/day5-hooks/session/record_test.go b/gee-orm/day5-hooks/session/record_test.go new file mode 100644 index 0000000..5d482a0 --- /dev/null +++ b/gee-orm/day5-hooks/session/record_test.go @@ -0,0 +1,97 @@ +package session + +import "testing" + +var ( + user1 = &User{"Tom", 18} + user2 = &User{"Sam", 25} + user3 = &User{"Jack", 25} +) + +func testRecordInit(t *testing.T) *Session { + t.Helper() + s := NewSession().Model(&User{}) + err1 := s.DropTable() + err2 := s.CreateTable() + _, err3 := s.Insert(user1, user2) + if err1 != nil || err2 != nil || err3 != nil { + t.Fatal("failed init test records") + } + return s +} + +func TestSession_Insert(t *testing.T) { + s := testRecordInit(t) + affected, err := s.Insert(user3) + if err != nil || affected != 1 { + t.Fatal("failed to create record") + } +} + +func TestSession_Find(t *testing.T) { + s := testRecordInit(t) + var users []User + if err := s.Find(&users); err != nil || len(users) != 2 { + t.Fatal("failed to query all") + } +} + +func TestSession_First(t *testing.T) { + s := testRecordInit(t) + u := &User{} + err := s.First(u) + if err != nil || u.Name != "Tom" || u.Age != 18 { + t.Fatal("failed to query first") + } +} + +func TestSession_Limit(t *testing.T) { + s := testRecordInit(t) + var users []User + err := s.Limit(1).Find(&users) + if err != nil || len(users) != 1 { + t.Fatal("failed to query with limit condition") + } +} + +func TestSession_Where(t *testing.T) { + s := testRecordInit(t) + var users []User + _, err1 := s.Insert(user3) + err2 := s.Where("Age = ?", 25).Find(&users) + + if err1 != nil || err2 != nil || len(users) != 2 { + t.Fatal("failed to query with where condition") + } +} + +func TestSession_OrderBy(t *testing.T) { + s := testRecordInit(t) + u := &User{} + err := s.OrderBy("Age DESC").First(u) + + if err != nil || u.Age != 25 { + t.Fatal("failed to query with order by condition") + } +} + +func TestSession_Update(t *testing.T) { + s := testRecordInit(t) + affected, _ := s.Where("Name = ?", "Tom").Update("Age", 30) + u := &User{} + _ = s.OrderBy("Age DESC").First(u) + + if affected != 1 || u.Age != 30 { + t.Fatal("failed to update") + } +} + +func TestSession_DeleteAndCount(t *testing.T) { + s := testRecordInit(t) + affected, _ := s.Where("Name = ?", "Tom").Delete() + count, _ := s.Count() + + if affected != 1 || count != 1 { + t.Fatal("failed to delete or count") + } +} diff --git a/gee-orm/day5-hooks/session/table.go b/gee-orm/day5-hooks/session/table.go new file mode 100644 index 0000000..58e7b0f --- /dev/null +++ b/gee-orm/day5-hooks/session/table.go @@ -0,0 +1,54 @@ +package session + +import ( + "fmt" + "geeorm/log" + "reflect" + "strings" + + "geeorm/schema" +) + +// Model assigns refTable +func (s *Session) Model(value interface{}) *Session { + // nil or different model, update refTable + if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) { + s.refTable = schema.Parse(value, s.dialect) + } + return s +} + +// RefTable returns a Schema instance that contains all parsed fields +func (s *Session) RefTable() *schema.Schema { + if s.refTable == nil { + log.Error("Model is not set") + } + return s.refTable +} + +// CreateTable create a table in database with a model +func (s *Session) CreateTable() error { + table := s.RefTable() + var columns []string + for _, field := range table.Fields { + columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag)) + } + desc := strings.Join(columns, ",") + _, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec() + return err +} + +// DropTable drops a table with the name of model +func (s *Session) DropTable() error { + _, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec() + return err +} + +// HasTable returns true of the table exists +func (s *Session) HasTable() bool { + sql, values := s.dialect.TableExistSQL(s.RefTable().Name) + row := s.Raw(sql, values...).QueryRow() + var tmp string + _ = row.Scan(&tmp) + return tmp == s.RefTable().Name +} diff --git a/gee-orm/day5-hooks/session/table_test.go b/gee-orm/day5-hooks/session/table_test.go new file mode 100644 index 0000000..3bb7554 --- /dev/null +++ b/gee-orm/day5-hooks/session/table_test.go @@ -0,0 +1,28 @@ +package session + +import ( + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func TestSession_CreateTable(t *testing.T) { + s := NewSession().Model(&User{}) + _ = s.DropTable() + _ = s.CreateTable() + if !s.HasTable() { + t.Fatal("Failed to create table User") + } +} + +func TestSession_Model(t *testing.T) { + s := NewSession().Model(&User{}) + table := s.RefTable() + s.Model(&Session{}) + if table.Name != "User" || s.RefTable().Name != "Session" { + t.Fatal("Failed to change model") + } +} diff --git a/gee-orm/day6-transaction/clause/clause.go b/gee-orm/day6-transaction/clause/clause.go new file mode 100644 index 0000000..02fcf93 --- /dev/null +++ b/gee-orm/day6-transaction/clause/clause.go @@ -0,0 +1,51 @@ +package clause + +import ( + "strings" +) + +// Clause contains SQL conditions +type Clause struct { + sql map[Type]string + sqlVars map[Type][]interface{} +} + +// Type is the type of Clause +type Type int + +// Support types for Clause +const ( + INSERT Type = iota + VALUES + SELECT + LIMIT + WHERE + ORDERBY + UPDATE + DELETE + COUNT +) + +// Set adds a sub clause of specific type +func (c *Clause) Set(name Type, vars ...interface{}) { + if c.sql == nil { + c.sql = make(map[Type]string) + c.sqlVars = make(map[Type][]interface{}) + } + sql, vars := generators[name](vars...) + c.sql[name] = sql + c.sqlVars[name] = vars +} + +// Build generate the final SQL and SQLVars +func (c *Clause) Build(orders ...Type) (string, []interface{}) { + var sqls []string + var vars []interface{} + for _, order := range orders { + if sql, ok := c.sql[order]; ok { + sqls = append(sqls, sql) + vars = append(vars, c.sqlVars[order]...) + } + } + return strings.Join(sqls, " "), vars +} diff --git a/gee-orm/day6-transaction/clause/clause_test.go b/gee-orm/day6-transaction/clause/clause_test.go new file mode 100644 index 0000000..62e0ccb --- /dev/null +++ b/gee-orm/day6-transaction/clause/clause_test.go @@ -0,0 +1,74 @@ +package clause + +import ( + "reflect" + "testing" +) + +func TestClause_Set(t *testing.T) { + var clause Clause + clause.Set(INSERT, "User", []string{"Name", "Age"}) + sql := clause.sql[INSERT] + vars := clause.sqlVars[INSERT] + t.Log(sql, vars) + if sql != "INSERT INTO User (Name,Age)" || len(vars) != 0 { + t.Fatal("failed to get clause") + } +} + +func testSelect(t *testing.T) { + var clause Clause + clause.Set(LIMIT, 3) + clause.Set(SELECT, "User", []string{"*"}) + clause.Set(WHERE, "Name = ?", "Tom") + clause.Set(ORDERBY, "Age ASC") + sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT) + t.Log(sql, vars) + if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) { + t.Fatal("failed to build SQLVars") + } +} + +func testUpdate(t *testing.T) { + var clause Clause + clause.Set(UPDATE, "User", map[string]interface{}{"Age": 30}) + clause.Set(WHERE, "Name = ?", "Tom") + sql, vars := clause.Build(UPDATE, WHERE) + t.Log(sql, vars) + if sql != "UPDATE User SET Age = ? WHERE Name = ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{30, "Tom"}) { + t.Fatal("failed to build SQLVars") + } +} + +func testDelete(t *testing.T) { + var clause Clause + clause.Set(DELETE, "User") + clause.Set(WHERE, "Name = ?", "Tom") + + sql, vars := clause.Build(DELETE, WHERE) + t.Log(sql, vars) + if sql != "DELETE FROM User WHERE Name = ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{"Tom"}) { + t.Fatal("failed to build SQLVars") + } +} + +func TestClause_Build(t *testing.T) { + t.Run("select", func(t *testing.T) { + testSelect(t) + }) + t.Run("update", func(t *testing.T) { + testUpdate(t) + }) + t.Run("delete", func(t *testing.T) { + testDelete(t) + }) +} diff --git a/gee-orm/day6-transaction/clause/generator.go b/gee-orm/day6-transaction/clause/generator.go new file mode 100644 index 0000000..23635ba --- /dev/null +++ b/gee-orm/day6-transaction/clause/generator.go @@ -0,0 +1,101 @@ +package clause + +import ( + "fmt" + "strings" +) + +type generator func(values ...interface{}) (string, []interface{}) + +var generators map[Type]generator + +func init() { + generators = make(map[Type]generator) + generators[INSERT] = _insert + generators[VALUES] = _values + generators[SELECT] = _select + generators[LIMIT] = _limit + generators[WHERE] = _where + generators[ORDERBY] = _orderBy + generators[UPDATE] = _update + generators[DELETE] = _delete + generators[COUNT] = _count +} + +func genBindVars(num int) string { + var vars []string + for i := 0; i < num; i++ { + vars = append(vars, "?") + } + return strings.Join(vars, ", ") +} + +func _insert(values ...interface{}) (string, []interface{}) { + // INSERT INTO $tableName ($fields) + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{} +} + +func _values(values ...interface{}) (string, []interface{}) { + // VALUES ($v1), ($v2), ... + var bindStr string + var sql strings.Builder + var vars []interface{} + sql.WriteString("VALUES ") + for i, value := range values { + v := value.([]interface{}) + if bindStr == "" { + bindStr = genBindVars(len(v)) + } + sql.WriteString(fmt.Sprintf("(%v)", bindStr)) + if i+1 != len(values) { + sql.WriteString(", ") + } + vars = append(vars, v...) + } + return sql.String(), vars + +} + +func _select(values ...interface{}) (string, []interface{}) { + // SELECT $fields FROM $tableName + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{} +} + +func _limit(values ...interface{}) (string, []interface{}) { + // LIMIT $num + return "LIMIT ?", values +} + +func _where(values ...interface{}) (string, []interface{}) { + // WHERE $desc + desc, vars := values[0], values[1:] + return fmt.Sprintf("WHERE %s", desc), vars +} + +func _orderBy(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{} +} + +func _update(values ...interface{}) (string, []interface{}) { + tableName := values[0] + m := values[1].(map[string]interface{}) + var keys []string + var vars []interface{} + for k, v := range m { + keys = append(keys, k+" = ?") + vars = append(vars, v) + } + return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars +} + +func _delete(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{} +} + +func _count(values ...interface{}) (string, []interface{}) { + return _select(values[0], []string{"count(*)"}) +} diff --git a/gee-orm/day6-transaction/dialect/dialect.go b/gee-orm/day6-transaction/dialect/dialect.go new file mode 100644 index 0000000..4696314 --- /dev/null +++ b/gee-orm/day6-transaction/dialect/dialect.go @@ -0,0 +1,22 @@ +package dialect + +import "reflect" + +var dialectsMap = map[string]Dialect{} + +// Dialect is an interface contains methods that a dialect has to implement +type Dialect interface { + DataTypeOf(typ reflect.Value) string + TableExistSQL(tableName string) (string, []interface{}) +} + +// RegisterDialect register a dialect to the global variable +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect +} + +// Get the dialect from global variable if it exists +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} diff --git a/gee-orm/day6-transaction/dialect/sqlite3.go b/gee-orm/day6-transaction/dialect/sqlite3.go new file mode 100644 index 0000000..f3c3897 --- /dev/null +++ b/gee-orm/day6-transaction/dialect/sqlite3.go @@ -0,0 +1,45 @@ +package dialect + +import ( + "fmt" + "reflect" + "time" +) + +type sqlite3 struct{} + +var _ Dialect = (*sqlite3)(nil) + +func init() { + RegisterDialect("sqlite3", &sqlite3{}) +} + +// Get Data Type for sqlite3 Dialect +func (s *sqlite3) DataTypeOf(typ reflect.Value) string { + switch typ.Kind() { + case reflect.Bool: + return "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "integer" + case reflect.Int64, reflect.Uint64: + return "bigint" + case reflect.Float32, reflect.Float64: + return "real" + case reflect.String: + return "text" + case reflect.Array, reflect.Slice: + return "blob" + case reflect.Struct: + if _, ok := typ.Interface().(time.Time); ok { + return "datetime" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind())) +} + +// TableExistSQL returns SQL that judge whether the table exists in database +func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) { + args := []interface{}{tableName} + return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args +} diff --git a/gee-orm/day6-transaction/dialect/sqlite3_test.go b/gee-orm/day6-transaction/dialect/sqlite3_test.go new file mode 100644 index 0000000..3df5f07 --- /dev/null +++ b/gee-orm/day6-transaction/dialect/sqlite3_test.go @@ -0,0 +1,25 @@ +package dialect + +import ( + "reflect" + "testing" +) + +func TestDataTypeOf(t *testing.T) { + dial := &sqlite3{} + cases := []struct { + Value interface{} + Type string + }{ + {"Tom", "text"}, + {123, "integer"}, + {1.2, "real"}, + {[]int{1, 2, 3}, "blob"}, + } + + for _, c := range cases { + if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type { + t.Fatalf("expect %s, but got %s", c.Type, typ) + } + } +} diff --git a/gee-orm/day6-transaction/geeorm.go b/gee-orm/day6-transaction/geeorm.go new file mode 100644 index 0000000..a08fd46 --- /dev/null +++ b/gee-orm/day6-transaction/geeorm.go @@ -0,0 +1,75 @@ +package geeorm + +import ( + "database/sql" + "geeorm/dialect" + "geeorm/log" + "geeorm/session" +) + +// Engine is the main struct of geeorm, manages all db sessions and transactions. +type Engine struct { + db *sql.DB + dialect dialect.Dialect +} + +// NewEngine create a instance of Engine +// connect database and ping it to test whether it's alive +func NewEngine(driver, source string) (e *Engine, err error) { + db, err := sql.Open(driver, source) + if err != nil { + log.Error(err) + return + } + // Send a ping to make sure the database connection is alive. + if err = db.Ping(); err != nil { + log.Error(err) + return + } + // make sure the specific dialect exists + dial, ok := dialect.GetDialect(driver) + if !ok { + log.Errorf("dialect %s Not Found", driver) + return + } + e = &Engine{db: db, dialect: dial} + log.Info("Connect database success") + return +} + +// Close database connection +func (engine *Engine) Close() { + if err := engine.db.Close(); err != nil { + log.Error("Failed to close database") + } + log.Info("Close database success") +} + +// NewSession creates a new session for next operations +func (engine *Engine) NewSession() *session.Session { + return session.New(engine.db, engine.dialect) +} + +// TxFunc will be called between tx.Begin() and tx.Commit() +// https://stackoverflow.com/questions/16184238/database-sql-tx-detecting-commit-or-rollback +type TxFunc func(*session.Session) (interface{}, error) + +// Transaction executes sql wrapped in a transaction, then automatically commit if no error occurs +func (engine *Engine) Transaction(f TxFunc) (result interface{}, err error) { + s := engine.NewSession() + if err := s.Begin(); err != nil { + return nil, err + } + defer func() { + if p := recover(); p != nil { + _ = s.Rollback() + panic(p) // re-throw panic after Rollback + } else if err != nil { + _ = s.Rollback() // err is non-nil; don't change it + } else { + err = s.Commit() // err is nil; if Commit returns error update err + } + }() + + return f(s) +} diff --git a/gee-orm/day6-transaction/geeorm_test.go b/gee-orm/day6-transaction/geeorm_test.go new file mode 100644 index 0000000..c3cf12a --- /dev/null +++ b/gee-orm/day6-transaction/geeorm_test.go @@ -0,0 +1,69 @@ +package geeorm + +import ( + "errors" + "geeorm/session" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func OpenDB(t *testing.T) *Engine { + t.Helper() + engine, err := NewEngine("sqlite3", "gee.db") + if err != nil { + t.Fatal("failed to connect", err) + } + return engine +} + +func TestNewEngine(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() +} + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func transactionRollback(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() + s := engine.NewSession() + _ = s.Model(&User{}).DropTable() + _, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) { + _ = s.Model(&User{}).CreateTable() + _, err = s.Insert(&User{"Tom", 18}) + return nil, errors.New("Error") + }) + if err == nil || s.HasTable() { + t.Fatal("failed to rollback") + } +} + +func transactionCommit(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() + s := engine.NewSession() + _ = s.Model(&User{}).DropTable() + _, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) { + _ = s.Model(&User{}).CreateTable() + _, err = s.Insert(&User{"Tom", 18}) + return + }) + u := &User{} + _ = s.First(u) + if err != nil || u.Name != "Tom" { + t.Fatal("failed to commit") + } +} + +func TestEngine_Transaction(t *testing.T) { + t.Run("rollback", func(t *testing.T) { + transactionRollback(t) + }) + t.Run("commit", func(t *testing.T) { + transactionCommit(t) + }) +} diff --git a/gee-orm/day6-transaction/go.mod b/gee-orm/day6-transaction/go.mod new file mode 100644 index 0000000..043b1c6 --- /dev/null +++ b/gee-orm/day6-transaction/go.mod @@ -0,0 +1,5 @@ +module geeorm + +go 1.13 + +require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/gee-orm/day6-transaction/log/log.go b/gee-orm/day6-transaction/log/log.go new file mode 100644 index 0000000..eacc0c6 --- /dev/null +++ b/gee-orm/day6-transaction/log/log.go @@ -0,0 +1,47 @@ +package log + +import ( + "io/ioutil" + "log" + "os" + "sync" +) + +var ( + errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile) + infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile) + loggers = []*log.Logger{errorLog, infoLog} + mu sync.Mutex +) + +// log methods +var ( + Error = errorLog.Println + Errorf = errorLog.Printf + Info = infoLog.Println + Infof = infoLog.Printf +) + +// log levels +const ( + InfoLevel = iota + ErrorLevel + Disabled +) + +// SetLevel controls log level +func SetLevel(level int) { + mu.Lock() + defer mu.Unlock() + + for _, logger := range loggers { + logger.SetOutput(os.Stdout) + } + + if ErrorLevel < level { + errorLog.SetOutput(ioutil.Discard) + } + if InfoLevel < level { + infoLog.SetOutput(ioutil.Discard) + } +} diff --git a/gee-orm/day6-transaction/log/log_test.go b/gee-orm/day6-transaction/log/log_test.go new file mode 100644 index 0000000..8cd403c --- /dev/null +++ b/gee-orm/day6-transaction/log/log_test.go @@ -0,0 +1,17 @@ +package log + +import ( + "os" + "testing" +) + +func TestSetLevel(t *testing.T) { + SetLevel(ErrorLevel) + if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout { + t.Fatal("failed to set log level") + } + SetLevel(Disabled) + if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout { + t.Fatal("failed to set log level") + } +} \ No newline at end of file diff --git a/gee-orm/day6-transaction/schema/schema.go b/gee-orm/day6-transaction/schema/schema.go new file mode 100644 index 0000000..2c9b927 --- /dev/null +++ b/gee-orm/day6-transaction/schema/schema.go @@ -0,0 +1,75 @@ +package schema + +import ( + "geeorm/dialect" + "go/ast" + "reflect" +) + +// Field represents a column of database +type Field struct { + Name string + Type string + Tag string +} + +// Schema represents a table of database +type Schema struct { + Model interface{} + Name string + Fields []*Field + FieldNames []string + fieldMap map[string]*Field +} + +// GetField returns field by name +func (schema *Schema) GetField(name string) *Field { + return schema.fieldMap[name] +} + +// Values return the values of dest's member variables +func (schema *Schema) RecordValues(dest interface{}) []interface{} { + destValue := reflect.Indirect(reflect.ValueOf(dest)) + var fieldValues []interface{} + for _, field := range schema.Fields { + fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface()) + } + return fieldValues +} + +type ITableName interface { + TableName() string +} + +// Parse a struct to a Schema instance +func Parse(dest interface{}, d dialect.Dialect) *Schema { + modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() + var tableName string + t, ok := dest.(ITableName) + if !ok { + tableName = modelType.Name() + } else { + tableName = t.TableName() + } + schema := &Schema{ + Model: dest, + Name: tableName, + fieldMap: make(map[string]*Field), + } + for i := 0; i < modelType.NumField(); i++ { + p := modelType.Field(i) + if !p.Anonymous && ast.IsExported(p.Name) { + field := &Field{ + Name: p.Name, + Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))), + } + if v, ok := p.Tag.Lookup("geeorm"); ok { + field.Tag = v + } + schema.Fields = append(schema.Fields, field) + schema.FieldNames = append(schema.FieldNames, p.Name) + schema.fieldMap[p.Name] = field + } + } + return schema +} diff --git a/gee-orm/day6-transaction/schema/schema_test.go b/gee-orm/day6-transaction/schema/schema_test.go new file mode 100644 index 0000000..8f625cb --- /dev/null +++ b/gee-orm/day6-transaction/schema/schema_test.go @@ -0,0 +1,51 @@ +package schema + +import ( + "geeorm/dialect" + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +var TestDial, _ = dialect.GetDialect("sqlite3") + +func TestParse(t *testing.T) { + schema := Parse(&User{}, TestDial) + if schema.Name != "User" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } + if schema.GetField("Name").Tag != "PRIMARY KEY" { + t.Fatal("failed to parse primary key") + } +} + +func TestSchema_RecordValues(t *testing.T) { + schema := Parse(&User{}, TestDial) + values := schema.RecordValues(&User{"Tom", 18}) + + name := values[0].(string) + age := values[1].(int) + + if name != "Tom" || age != 18 { + t.Fatal("failed to get values") + } +} + +type UserTest struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func (u *UserTest) TableName() string { + return "ns_user_test" +} + +func TestSchema_TableName(t *testing.T) { + schema := Parse(&UserTest{}, TestDial) + if schema.Name != "ns_user_test" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } +} diff --git a/gee-orm/day6-transaction/session/hooks.go b/gee-orm/day6-transaction/session/hooks.go new file mode 100644 index 0000000..d73c3c2 --- /dev/null +++ b/gee-orm/day6-transaction/session/hooks.go @@ -0,0 +1,35 @@ +package session + +import ( + "geeorm/log" + "reflect" +) + +// Hooks constants +const ( + BeforeQuery = "BeforeQuery" + AfterQuery = "AfterQuery" + BeforeUpdate = "BeforeUpdate" + AfterUpdate = "AfterUpdate" + BeforeDelete = "BeforeDelete" + AfterDelete = "AfterDelete" + BeforeInsert = "BeforeInsert" + AfterInsert = "AfterInsert" +) + +// CallMethod calls the registered hooks +func (s *Session) CallMethod(method string, value interface{}) { + fm := reflect.ValueOf(s.RefTable().Model).MethodByName(method) + if value != nil { + fm = reflect.ValueOf(value).MethodByName(method) + } + param := []reflect.Value{reflect.ValueOf(s)} + if fm.IsValid() { + if v := fm.Call(param); len(v) > 0 { + if err, ok := v[0].Interface().(error); ok { + log.Error(err) + } + } + } + return +} diff --git a/gee-orm/day6-transaction/session/hooks_test.go b/gee-orm/day6-transaction/session/hooks_test.go new file mode 100644 index 0000000..f896d01 --- /dev/null +++ b/gee-orm/day6-transaction/session/hooks_test.go @@ -0,0 +1,37 @@ +package session + +import ( + "geeorm/log" + "testing" +) + +type Account struct { + ID int `geeorm:"PRIMARY KEY"` + Password string +} + +func (account *Account) BeforeInsert(s *Session) error { + log.Info("before inert", account) + account.ID += 1000 + return nil +} + +func (account *Account) AfterQuery(s *Session) error { + log.Info("after query", account) + account.Password = "******" + return nil +} + +func TestSession_CallMethod(t *testing.T) { + s := NewSession().Model(&Account{}) + _ = s.DropTable() + _ = s.CreateTable() + _, _ = s.Insert(&Account{1, "123456"}, &Account{2, "qwerty"}) + + u := &Account{} + + err := s.First(u) + if err != nil || u.ID != 1001 || u.Password != "******" { + t.Fatal("Failed to call hooks after query, got", u) + } +} diff --git a/gee-orm/day6-transaction/session/raw.go b/gee-orm/day6-transaction/session/raw.go new file mode 100644 index 0000000..5bdd039 --- /dev/null +++ b/gee-orm/day6-transaction/session/raw.go @@ -0,0 +1,90 @@ +package session + +import ( + "database/sql" + "geeorm/clause" + "geeorm/dialect" + "geeorm/log" + "geeorm/schema" + "strings" +) + +// Session keep a pointer to sql.DB and provides all execution of all +// kind of database operations. +type Session struct { + db *sql.DB + dialect dialect.Dialect + tx *sql.Tx + refTable *schema.Schema + clause clause.Clause + sql strings.Builder + sqlVars []interface{} +} + +// New creates a instance of Session +func New(db *sql.DB, dialect dialect.Dialect) *Session { + return &Session{ + db: db, + dialect: dialect, + } +} + +// Clear initialize the state of a session +func (s *Session) Clear() { + s.sql.Reset() + s.sqlVars = nil + s.clause = clause.Clause{} +} + +// CommonDB is a minimal function set of db +type CommonDB interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + Exec(query string, args ...interface{}) (sql.Result, error) +} + +var _ CommonDB = (*sql.DB)(nil) +var _ CommonDB = (*sql.Tx)(nil) + +// DB returns tx if a tx begins. otherwise return *sql.DB +func (s *Session) DB() CommonDB { + if s.tx != nil { + return s.tx + } + return s.db +} + +// Exec raw sql with sqlVars +func (s *Session) Exec() (result sql.Result, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// QueryRow gets a record from db +func (s *Session) QueryRow() *sql.Row { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + return s.DB().QueryRow(s.sql.String(), s.sqlVars...) +} + +// QueryRows gets a list of records from db +func (s *Session) QueryRows() (rows *sql.Rows, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// Raw appends sql and sqlVars +func (s *Session) Raw(sql string, values ...interface{}) *Session { + s.sql.WriteString(sql) + s.sql.WriteString(" ") + s.sqlVars = append(s.sqlVars, values...) + return s +} \ No newline at end of file diff --git a/gee-orm/day6-transaction/session/raw_test.go b/gee-orm/day6-transaction/session/raw_test.go new file mode 100644 index 0000000..404bb6e --- /dev/null +++ b/gee-orm/day6-transaction/session/raw_test.go @@ -0,0 +1,48 @@ +package session + +import ( + "database/sql" + "os" + "testing" + + "geeorm/dialect" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + TestDB *sql.DB + TestDial, _ = dialect.GetDialect("sqlite3") +) + +func TestMain(m *testing.M) { + TestDB, _ = sql.Open("sqlite3", "../gee.db") + code := m.Run() + _ = TestDB.Close() + os.Exit(code) +} + +func NewSession() *Session { + return New(TestDB, TestDial) +} + +func TestSession_Exec(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + if count, err := result.RowsAffected(); err != nil || count != 2 { + t.Fatal("expect 2, but got", count) + } +} + +func TestSession_QueryRows(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + row := s.Raw("SELECT count(*) FROM User").QueryRow() + var count int + if err := row.Scan(&count); err != nil || count != 0 { + t.Fatal("failed to query db", err) + } +} diff --git a/gee-orm/day6-transaction/session/record.go b/gee-orm/day6-transaction/session/record.go new file mode 100644 index 0000000..fdfca4d --- /dev/null +++ b/gee-orm/day6-transaction/session/record.go @@ -0,0 +1,136 @@ +package session + +import ( + "errors" + "geeorm/clause" + "reflect" +) + +// Insert one or more records in database +func (s *Session) Insert(values ...interface{}) (int64, error) { + recordValues := make([]interface{}, 0) + for _, value := range values { + s.CallMethod(BeforeInsert, value) + table := s.Model(value).RefTable() + s.clause.Set(clause.INSERT, table.Name, table.FieldNames) + recordValues = append(recordValues, table.RecordValues(value)) + } + + s.clause.Set(clause.VALUES, recordValues...) + sql, vars := s.clause.Build(clause.INSERT, clause.VALUES) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + s.CallMethod(AfterInsert, nil) + return result.RowsAffected() +} + +// Find gets all eligible records +func (s *Session) Find(values interface{}) error { + s.CallMethod(BeforeQuery, nil) + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + table := s.Model(reflect.New(destType).Elem().Interface()).RefTable() + + s.clause.Set(clause.SELECT, table.Name, table.FieldNames) + sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT) + rows, err := s.Raw(sql, vars...).QueryRows() + if err != nil { + return err + } + + for rows.Next() { + dest := reflect.New(destType).Elem() + var values []interface{} + for _, name := range table.FieldNames { + values = append(values, dest.FieldByName(name).Addr().Interface()) + } + if err := rows.Scan(values...); err != nil { + return err + } + s.CallMethod(AfterQuery, dest.Addr().Interface()) + destSlice.Set(reflect.Append(destSlice, dest)) + } + return rows.Close() +} + +// First gets the 1st row +func (s *Session) First(value interface{}) error { + dest := reflect.Indirect(reflect.ValueOf(value)) + destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem() + if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil { + return err + } + if destSlice.Len() == 0 { + return errors.New("NOT FOUND") + } + dest.Set(destSlice.Index(0)) + return nil +} + +// Limit adds limit condition to clause +func (s *Session) Limit(num int) *Session { + s.clause.Set(clause.LIMIT, num) + return s +} + +// Where adds limit condition to clause +func (s *Session) Where(desc string, args ...interface{}) *Session { + var vars []interface{} + s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...) + return s +} + +// OrderBy adds order by condition to clause +func (s *Session) OrderBy(desc string) *Session { + s.clause.Set(clause.ORDERBY, desc) + return s +} + +// Update records with where clause +// support map[string]interface{} +// also support kv list: "Name", "Tom", "Age", 18, .... +func (s *Session) Update(kv ...interface{}) (int64, error) { + s.CallMethod(BeforeUpdate, nil) + m, ok := kv[0].(map[string]interface{}) + if !ok { + m = make(map[string]interface{}) + for i := 0; i < len(kv); i += 2 { + m[kv[i].(string)] = kv[i+1] + } + } + s.clause.Set(clause.UPDATE, s.RefTable().Name, m) + sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + s.CallMethod(AfterUpdate, nil) + return result.RowsAffected() +} + +// Delete records with where clause +func (s *Session) Delete() (int64, error) { + s.CallMethod(BeforeDelete, nil) + s.clause.Set(clause.DELETE, s.RefTable().Name) + sql, vars := s.clause.Build(clause.DELETE, clause.WHERE) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + s.CallMethod(AfterDelete, nil) + return result.RowsAffected() +} + +// Count records with where clause +func (s *Session) Count() (int64, error) { + s.clause.Set(clause.COUNT, s.RefTable().Name) + sql, vars := s.clause.Build(clause.COUNT, clause.WHERE) + row := s.Raw(sql, vars...).QueryRow() + var tmp int64 + if err := row.Scan(&tmp); err != nil { + return 0, err + } + return tmp, nil +} diff --git a/gee-orm/day6-transaction/session/record_test.go b/gee-orm/day6-transaction/session/record_test.go new file mode 100644 index 0000000..5d482a0 --- /dev/null +++ b/gee-orm/day6-transaction/session/record_test.go @@ -0,0 +1,97 @@ +package session + +import "testing" + +var ( + user1 = &User{"Tom", 18} + user2 = &User{"Sam", 25} + user3 = &User{"Jack", 25} +) + +func testRecordInit(t *testing.T) *Session { + t.Helper() + s := NewSession().Model(&User{}) + err1 := s.DropTable() + err2 := s.CreateTable() + _, err3 := s.Insert(user1, user2) + if err1 != nil || err2 != nil || err3 != nil { + t.Fatal("failed init test records") + } + return s +} + +func TestSession_Insert(t *testing.T) { + s := testRecordInit(t) + affected, err := s.Insert(user3) + if err != nil || affected != 1 { + t.Fatal("failed to create record") + } +} + +func TestSession_Find(t *testing.T) { + s := testRecordInit(t) + var users []User + if err := s.Find(&users); err != nil || len(users) != 2 { + t.Fatal("failed to query all") + } +} + +func TestSession_First(t *testing.T) { + s := testRecordInit(t) + u := &User{} + err := s.First(u) + if err != nil || u.Name != "Tom" || u.Age != 18 { + t.Fatal("failed to query first") + } +} + +func TestSession_Limit(t *testing.T) { + s := testRecordInit(t) + var users []User + err := s.Limit(1).Find(&users) + if err != nil || len(users) != 1 { + t.Fatal("failed to query with limit condition") + } +} + +func TestSession_Where(t *testing.T) { + s := testRecordInit(t) + var users []User + _, err1 := s.Insert(user3) + err2 := s.Where("Age = ?", 25).Find(&users) + + if err1 != nil || err2 != nil || len(users) != 2 { + t.Fatal("failed to query with where condition") + } +} + +func TestSession_OrderBy(t *testing.T) { + s := testRecordInit(t) + u := &User{} + err := s.OrderBy("Age DESC").First(u) + + if err != nil || u.Age != 25 { + t.Fatal("failed to query with order by condition") + } +} + +func TestSession_Update(t *testing.T) { + s := testRecordInit(t) + affected, _ := s.Where("Name = ?", "Tom").Update("Age", 30) + u := &User{} + _ = s.OrderBy("Age DESC").First(u) + + if affected != 1 || u.Age != 30 { + t.Fatal("failed to update") + } +} + +func TestSession_DeleteAndCount(t *testing.T) { + s := testRecordInit(t) + affected, _ := s.Where("Name = ?", "Tom").Delete() + count, _ := s.Count() + + if affected != 1 || count != 1 { + t.Fatal("failed to delete or count") + } +} diff --git a/gee-orm/day6-transaction/session/table.go b/gee-orm/day6-transaction/session/table.go new file mode 100644 index 0000000..58e7b0f --- /dev/null +++ b/gee-orm/day6-transaction/session/table.go @@ -0,0 +1,54 @@ +package session + +import ( + "fmt" + "geeorm/log" + "reflect" + "strings" + + "geeorm/schema" +) + +// Model assigns refTable +func (s *Session) Model(value interface{}) *Session { + // nil or different model, update refTable + if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) { + s.refTable = schema.Parse(value, s.dialect) + } + return s +} + +// RefTable returns a Schema instance that contains all parsed fields +func (s *Session) RefTable() *schema.Schema { + if s.refTable == nil { + log.Error("Model is not set") + } + return s.refTable +} + +// CreateTable create a table in database with a model +func (s *Session) CreateTable() error { + table := s.RefTable() + var columns []string + for _, field := range table.Fields { + columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag)) + } + desc := strings.Join(columns, ",") + _, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec() + return err +} + +// DropTable drops a table with the name of model +func (s *Session) DropTable() error { + _, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec() + return err +} + +// HasTable returns true of the table exists +func (s *Session) HasTable() bool { + sql, values := s.dialect.TableExistSQL(s.RefTable().Name) + row := s.Raw(sql, values...).QueryRow() + var tmp string + _ = row.Scan(&tmp) + return tmp == s.RefTable().Name +} diff --git a/gee-orm/day6-transaction/session/table_test.go b/gee-orm/day6-transaction/session/table_test.go new file mode 100644 index 0000000..3bb7554 --- /dev/null +++ b/gee-orm/day6-transaction/session/table_test.go @@ -0,0 +1,28 @@ +package session + +import ( + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func TestSession_CreateTable(t *testing.T) { + s := NewSession().Model(&User{}) + _ = s.DropTable() + _ = s.CreateTable() + if !s.HasTable() { + t.Fatal("Failed to create table User") + } +} + +func TestSession_Model(t *testing.T) { + s := NewSession().Model(&User{}) + table := s.RefTable() + s.Model(&Session{}) + if table.Name != "User" || s.RefTable().Name != "Session" { + t.Fatal("Failed to change model") + } +} diff --git a/gee-orm/day6-transaction/session/transaction.go b/gee-orm/day6-transaction/session/transaction.go new file mode 100644 index 0000000..3cdb451 --- /dev/null +++ b/gee-orm/day6-transaction/session/transaction.go @@ -0,0 +1,31 @@ +package session + +import "geeorm/log" + +// Begin a transaction +func (s *Session) Begin() (err error) { + log.Info("transaction begin") + if s.tx, err = s.db.Begin(); err != nil { + log.Error(err) + return + } + return +} + +// Commit a transaction +func (s *Session) Commit() (err error) { + log.Info("transaction commit") + if err = s.tx.Commit(); err != nil { + log.Error(err) + } + return +} + +// Rollback a transaction +func (s *Session) Rollback() (err error) { + log.Info("transaction rollback") + if err = s.tx.Rollback(); err != nil { + log.Error(err) + } + return +} diff --git a/gee-orm/day7-migrate/clause/clause.go b/gee-orm/day7-migrate/clause/clause.go new file mode 100644 index 0000000..02fcf93 --- /dev/null +++ b/gee-orm/day7-migrate/clause/clause.go @@ -0,0 +1,51 @@ +package clause + +import ( + "strings" +) + +// Clause contains SQL conditions +type Clause struct { + sql map[Type]string + sqlVars map[Type][]interface{} +} + +// Type is the type of Clause +type Type int + +// Support types for Clause +const ( + INSERT Type = iota + VALUES + SELECT + LIMIT + WHERE + ORDERBY + UPDATE + DELETE + COUNT +) + +// Set adds a sub clause of specific type +func (c *Clause) Set(name Type, vars ...interface{}) { + if c.sql == nil { + c.sql = make(map[Type]string) + c.sqlVars = make(map[Type][]interface{}) + } + sql, vars := generators[name](vars...) + c.sql[name] = sql + c.sqlVars[name] = vars +} + +// Build generate the final SQL and SQLVars +func (c *Clause) Build(orders ...Type) (string, []interface{}) { + var sqls []string + var vars []interface{} + for _, order := range orders { + if sql, ok := c.sql[order]; ok { + sqls = append(sqls, sql) + vars = append(vars, c.sqlVars[order]...) + } + } + return strings.Join(sqls, " "), vars +} diff --git a/gee-orm/day7-migrate/clause/clause_test.go b/gee-orm/day7-migrate/clause/clause_test.go new file mode 100644 index 0000000..62e0ccb --- /dev/null +++ b/gee-orm/day7-migrate/clause/clause_test.go @@ -0,0 +1,74 @@ +package clause + +import ( + "reflect" + "testing" +) + +func TestClause_Set(t *testing.T) { + var clause Clause + clause.Set(INSERT, "User", []string{"Name", "Age"}) + sql := clause.sql[INSERT] + vars := clause.sqlVars[INSERT] + t.Log(sql, vars) + if sql != "INSERT INTO User (Name,Age)" || len(vars) != 0 { + t.Fatal("failed to get clause") + } +} + +func testSelect(t *testing.T) { + var clause Clause + clause.Set(LIMIT, 3) + clause.Set(SELECT, "User", []string{"*"}) + clause.Set(WHERE, "Name = ?", "Tom") + clause.Set(ORDERBY, "Age ASC") + sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT) + t.Log(sql, vars) + if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) { + t.Fatal("failed to build SQLVars") + } +} + +func testUpdate(t *testing.T) { + var clause Clause + clause.Set(UPDATE, "User", map[string]interface{}{"Age": 30}) + clause.Set(WHERE, "Name = ?", "Tom") + sql, vars := clause.Build(UPDATE, WHERE) + t.Log(sql, vars) + if sql != "UPDATE User SET Age = ? WHERE Name = ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{30, "Tom"}) { + t.Fatal("failed to build SQLVars") + } +} + +func testDelete(t *testing.T) { + var clause Clause + clause.Set(DELETE, "User") + clause.Set(WHERE, "Name = ?", "Tom") + + sql, vars := clause.Build(DELETE, WHERE) + t.Log(sql, vars) + if sql != "DELETE FROM User WHERE Name = ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{"Tom"}) { + t.Fatal("failed to build SQLVars") + } +} + +func TestClause_Build(t *testing.T) { + t.Run("select", func(t *testing.T) { + testSelect(t) + }) + t.Run("update", func(t *testing.T) { + testUpdate(t) + }) + t.Run("delete", func(t *testing.T) { + testDelete(t) + }) +} diff --git a/gee-orm/day7-migrate/clause/generator.go b/gee-orm/day7-migrate/clause/generator.go new file mode 100644 index 0000000..23635ba --- /dev/null +++ b/gee-orm/day7-migrate/clause/generator.go @@ -0,0 +1,101 @@ +package clause + +import ( + "fmt" + "strings" +) + +type generator func(values ...interface{}) (string, []interface{}) + +var generators map[Type]generator + +func init() { + generators = make(map[Type]generator) + generators[INSERT] = _insert + generators[VALUES] = _values + generators[SELECT] = _select + generators[LIMIT] = _limit + generators[WHERE] = _where + generators[ORDERBY] = _orderBy + generators[UPDATE] = _update + generators[DELETE] = _delete + generators[COUNT] = _count +} + +func genBindVars(num int) string { + var vars []string + for i := 0; i < num; i++ { + vars = append(vars, "?") + } + return strings.Join(vars, ", ") +} + +func _insert(values ...interface{}) (string, []interface{}) { + // INSERT INTO $tableName ($fields) + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{} +} + +func _values(values ...interface{}) (string, []interface{}) { + // VALUES ($v1), ($v2), ... + var bindStr string + var sql strings.Builder + var vars []interface{} + sql.WriteString("VALUES ") + for i, value := range values { + v := value.([]interface{}) + if bindStr == "" { + bindStr = genBindVars(len(v)) + } + sql.WriteString(fmt.Sprintf("(%v)", bindStr)) + if i+1 != len(values) { + sql.WriteString(", ") + } + vars = append(vars, v...) + } + return sql.String(), vars + +} + +func _select(values ...interface{}) (string, []interface{}) { + // SELECT $fields FROM $tableName + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{} +} + +func _limit(values ...interface{}) (string, []interface{}) { + // LIMIT $num + return "LIMIT ?", values +} + +func _where(values ...interface{}) (string, []interface{}) { + // WHERE $desc + desc, vars := values[0], values[1:] + return fmt.Sprintf("WHERE %s", desc), vars +} + +func _orderBy(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{} +} + +func _update(values ...interface{}) (string, []interface{}) { + tableName := values[0] + m := values[1].(map[string]interface{}) + var keys []string + var vars []interface{} + for k, v := range m { + keys = append(keys, k+" = ?") + vars = append(vars, v) + } + return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars +} + +func _delete(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{} +} + +func _count(values ...interface{}) (string, []interface{}) { + return _select(values[0], []string{"count(*)"}) +} diff --git a/gee-orm/day7-migrate/dialect/dialect.go b/gee-orm/day7-migrate/dialect/dialect.go new file mode 100644 index 0000000..4696314 --- /dev/null +++ b/gee-orm/day7-migrate/dialect/dialect.go @@ -0,0 +1,22 @@ +package dialect + +import "reflect" + +var dialectsMap = map[string]Dialect{} + +// Dialect is an interface contains methods that a dialect has to implement +type Dialect interface { + DataTypeOf(typ reflect.Value) string + TableExistSQL(tableName string) (string, []interface{}) +} + +// RegisterDialect register a dialect to the global variable +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect +} + +// Get the dialect from global variable if it exists +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} diff --git a/gee-orm/day7-migrate/dialect/sqlite3.go b/gee-orm/day7-migrate/dialect/sqlite3.go new file mode 100644 index 0000000..f3c3897 --- /dev/null +++ b/gee-orm/day7-migrate/dialect/sqlite3.go @@ -0,0 +1,45 @@ +package dialect + +import ( + "fmt" + "reflect" + "time" +) + +type sqlite3 struct{} + +var _ Dialect = (*sqlite3)(nil) + +func init() { + RegisterDialect("sqlite3", &sqlite3{}) +} + +// Get Data Type for sqlite3 Dialect +func (s *sqlite3) DataTypeOf(typ reflect.Value) string { + switch typ.Kind() { + case reflect.Bool: + return "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "integer" + case reflect.Int64, reflect.Uint64: + return "bigint" + case reflect.Float32, reflect.Float64: + return "real" + case reflect.String: + return "text" + case reflect.Array, reflect.Slice: + return "blob" + case reflect.Struct: + if _, ok := typ.Interface().(time.Time); ok { + return "datetime" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind())) +} + +// TableExistSQL returns SQL that judge whether the table exists in database +func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) { + args := []interface{}{tableName} + return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args +} diff --git a/gee-orm/day7-migrate/dialect/sqlite3_test.go b/gee-orm/day7-migrate/dialect/sqlite3_test.go new file mode 100644 index 0000000..3df5f07 --- /dev/null +++ b/gee-orm/day7-migrate/dialect/sqlite3_test.go @@ -0,0 +1,25 @@ +package dialect + +import ( + "reflect" + "testing" +) + +func TestDataTypeOf(t *testing.T) { + dial := &sqlite3{} + cases := []struct { + Value interface{} + Type string + }{ + {"Tom", "text"}, + {123, "integer"}, + {1.2, "real"}, + {[]int{1, 2, 3}, "blob"}, + } + + for _, c := range cases { + if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type { + t.Fatalf("expect %s, but got %s", c.Type, typ) + } + } +} diff --git a/gee-orm/day7-migrate/geeorm.go b/gee-orm/day7-migrate/geeorm.go new file mode 100644 index 0000000..fe6d477 --- /dev/null +++ b/gee-orm/day7-migrate/geeorm.go @@ -0,0 +1,128 @@ +package geeorm + +import ( + "database/sql" + "fmt" + "geeorm/dialect" + "geeorm/log" + "geeorm/session" + "strings" +) + +// Engine is the main struct of geeorm, manages all db sessions and transactions. +type Engine struct { + db *sql.DB + dialect dialect.Dialect +} + +// NewEngine create a instance of Engine +// connect database and ping it to test whether it's alive +func NewEngine(driver, source string) (e *Engine, err error) { + db, err := sql.Open(driver, source) + if err != nil { + log.Error(err) + return + } + // Send a ping to make sure the database connection is alive. + if err = db.Ping(); err != nil { + log.Error(err) + return + } + // make sure the specific dialect exists + dial, ok := dialect.GetDialect(driver) + if !ok { + log.Errorf("dialect %s Not Found", driver) + return + } + e = &Engine{db: db, dialect: dial} + log.Info("Connect database success") + return +} + +// Close database connection +func (engine *Engine) Close() { + if err := engine.db.Close(); err != nil { + log.Error("Failed to close database") + return + } + log.Info("Close database success") +} + +// NewSession creates a new session for next operations +func (engine *Engine) NewSession() *session.Session { + return session.New(engine.db, engine.dialect) +} + +// TxFunc will be called between tx.Begin() and tx.Commit() +// https://stackoverflow.com/questions/16184238/database-sql-tx-detecting-commit-or-rollback +type TxFunc func(*session.Session) (interface{}, error) + +// Transaction executes sql wrapped in a transaction, then automatically commit if no error occurs +func (engine *Engine) Transaction(f TxFunc) (result interface{}, err error) { + s := engine.NewSession() + if err := s.Begin(); err != nil { + return nil, err + } + defer func() { + if p := recover(); p != nil { + _ = s.Rollback() + panic(p) // re-throw panic after Rollback + } else if err != nil { + _ = s.Rollback() // err is non-nil; don't change it + } else { + err = s.Commit() // err is nil; if Commit returns error update err + } + }() + + return f(s) +} + +// difference returns a - b +func difference(a []string, b []string) (diff []string) { + mapB := make(map[string]bool) + for _, v := range b { + mapB[v] = true + } + for _, v := range a { + if _, ok := mapB[v]; !ok { + diff = append(diff, v) + } + } + return +} + +// Migrate table +func (engine *Engine) Migrate(value interface{}) error { + _, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) { + if !s.Model(value).HasTable() { + log.Infof("table %s doesn't exist", s.RefTable().Name) + return nil, s.CreateTable() + } + table := s.RefTable() + rows, _ := s.Raw(fmt.Sprintf("SELECT * FROM %s LIMIT 1", table.Name)).QueryRows() + columns, _ := rows.Columns() + addCols := difference(table.FieldNames, columns) + delCols := difference(columns, table.FieldNames) + log.Infof("added cols %v, deleted cols %v", addCols, delCols) + + for _, col := range addCols { + f := table.GetField(col) + sqlStr := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s;", table.Name, f.Name, f.Type) + if _, err = s.Raw(sqlStr).Exec(); err != nil { + return + } + } + + if len(delCols) == 0 { + return + } + tmp := "tmp_" + table.Name + fieldStr := strings.Join(table.FieldNames, ", ") + s.Raw(fmt.Sprintf("CREATE TABLE %s AS SELECT %s from %s;", tmp, fieldStr, table.Name)) + s.Raw(fmt.Sprintf("DROP TABLE %s;", table.Name)) + s.Raw(fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", tmp, table.Name)) + _, err = s.Exec() + return + }) + return err +} diff --git a/gee-orm/day7-migrate/geeorm_test.go b/gee-orm/day7-migrate/geeorm_test.go new file mode 100644 index 0000000..4ccacf7 --- /dev/null +++ b/gee-orm/day7-migrate/geeorm_test.go @@ -0,0 +1,86 @@ +package geeorm + +import ( + "errors" + "geeorm/session" + "reflect" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func OpenDB(t *testing.T) *Engine { + t.Helper() + engine, err := NewEngine("sqlite3", "gee.db") + if err != nil { + t.Fatal("failed to connect", err) + } + return engine +} + +func TestNewEngine(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() +} + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func transactionRollback(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() + s := engine.NewSession() + _ = s.Model(&User{}).DropTable() + _, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) { + _ = s.Model(&User{}).CreateTable() + _, err = s.Insert(&User{"Tom", 18}) + return nil, errors.New("Error") + }) + if err == nil || s.HasTable() { + t.Fatal("failed to rollback") + } +} + +func transactionCommit(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() + s := engine.NewSession() + _ = s.Model(&User{}).DropTable() + _, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) { + _ = s.Model(&User{}).CreateTable() + _, err = s.Insert(&User{"Tom", 18}) + return + }) + u := &User{} + _ = s.First(u) + if err != nil || u.Name != "Tom" { + t.Fatal("failed to commit") + } +} + +func TestEngine_Transaction(t *testing.T) { + t.Run("rollback", func(t *testing.T) { + transactionRollback(t) + }) + t.Run("commit", func(t *testing.T) { + transactionCommit(t) + }) +} + +func TestEngine_Migrate(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() + s := engine.NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text PRIMARY KEY, XXX integer);").Exec() + _, _ = s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + engine.Migrate(&User{}) + + rows, _ := s.Raw("SELECT * FROM User").QueryRows() + columns, _ := rows.Columns() + if !reflect.DeepEqual(columns, []string{"Name", "Age"}) { + t.Fatal("Failed to migrate table User, got columns", columns) + } +} diff --git a/gee-orm/day7-migrate/go.mod b/gee-orm/day7-migrate/go.mod new file mode 100644 index 0000000..043b1c6 --- /dev/null +++ b/gee-orm/day7-migrate/go.mod @@ -0,0 +1,5 @@ +module geeorm + +go 1.13 + +require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/gee-orm/day7-migrate/log/log.go b/gee-orm/day7-migrate/log/log.go new file mode 100644 index 0000000..eacc0c6 --- /dev/null +++ b/gee-orm/day7-migrate/log/log.go @@ -0,0 +1,47 @@ +package log + +import ( + "io/ioutil" + "log" + "os" + "sync" +) + +var ( + errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile) + infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile) + loggers = []*log.Logger{errorLog, infoLog} + mu sync.Mutex +) + +// log methods +var ( + Error = errorLog.Println + Errorf = errorLog.Printf + Info = infoLog.Println + Infof = infoLog.Printf +) + +// log levels +const ( + InfoLevel = iota + ErrorLevel + Disabled +) + +// SetLevel controls log level +func SetLevel(level int) { + mu.Lock() + defer mu.Unlock() + + for _, logger := range loggers { + logger.SetOutput(os.Stdout) + } + + if ErrorLevel < level { + errorLog.SetOutput(ioutil.Discard) + } + if InfoLevel < level { + infoLog.SetOutput(ioutil.Discard) + } +} diff --git a/gee-orm/day7-migrate/log/log_test.go b/gee-orm/day7-migrate/log/log_test.go new file mode 100644 index 0000000..8cd403c --- /dev/null +++ b/gee-orm/day7-migrate/log/log_test.go @@ -0,0 +1,17 @@ +package log + +import ( + "os" + "testing" +) + +func TestSetLevel(t *testing.T) { + SetLevel(ErrorLevel) + if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout { + t.Fatal("failed to set log level") + } + SetLevel(Disabled) + if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout { + t.Fatal("failed to set log level") + } +} \ No newline at end of file diff --git a/gee-orm/day7-migrate/schema/schema.go b/gee-orm/day7-migrate/schema/schema.go new file mode 100644 index 0000000..2c9b927 --- /dev/null +++ b/gee-orm/day7-migrate/schema/schema.go @@ -0,0 +1,75 @@ +package schema + +import ( + "geeorm/dialect" + "go/ast" + "reflect" +) + +// Field represents a column of database +type Field struct { + Name string + Type string + Tag string +} + +// Schema represents a table of database +type Schema struct { + Model interface{} + Name string + Fields []*Field + FieldNames []string + fieldMap map[string]*Field +} + +// GetField returns field by name +func (schema *Schema) GetField(name string) *Field { + return schema.fieldMap[name] +} + +// Values return the values of dest's member variables +func (schema *Schema) RecordValues(dest interface{}) []interface{} { + destValue := reflect.Indirect(reflect.ValueOf(dest)) + var fieldValues []interface{} + for _, field := range schema.Fields { + fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface()) + } + return fieldValues +} + +type ITableName interface { + TableName() string +} + +// Parse a struct to a Schema instance +func Parse(dest interface{}, d dialect.Dialect) *Schema { + modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() + var tableName string + t, ok := dest.(ITableName) + if !ok { + tableName = modelType.Name() + } else { + tableName = t.TableName() + } + schema := &Schema{ + Model: dest, + Name: tableName, + fieldMap: make(map[string]*Field), + } + for i := 0; i < modelType.NumField(); i++ { + p := modelType.Field(i) + if !p.Anonymous && ast.IsExported(p.Name) { + field := &Field{ + Name: p.Name, + Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))), + } + if v, ok := p.Tag.Lookup("geeorm"); ok { + field.Tag = v + } + schema.Fields = append(schema.Fields, field) + schema.FieldNames = append(schema.FieldNames, p.Name) + schema.fieldMap[p.Name] = field + } + } + return schema +} diff --git a/gee-orm/day7-migrate/schema/schema_test.go b/gee-orm/day7-migrate/schema/schema_test.go new file mode 100644 index 0000000..8f625cb --- /dev/null +++ b/gee-orm/day7-migrate/schema/schema_test.go @@ -0,0 +1,51 @@ +package schema + +import ( + "geeorm/dialect" + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +var TestDial, _ = dialect.GetDialect("sqlite3") + +func TestParse(t *testing.T) { + schema := Parse(&User{}, TestDial) + if schema.Name != "User" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } + if schema.GetField("Name").Tag != "PRIMARY KEY" { + t.Fatal("failed to parse primary key") + } +} + +func TestSchema_RecordValues(t *testing.T) { + schema := Parse(&User{}, TestDial) + values := schema.RecordValues(&User{"Tom", 18}) + + name := values[0].(string) + age := values[1].(int) + + if name != "Tom" || age != 18 { + t.Fatal("failed to get values") + } +} + +type UserTest struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func (u *UserTest) TableName() string { + return "ns_user_test" +} + +func TestSchema_TableName(t *testing.T) { + schema := Parse(&UserTest{}, TestDial) + if schema.Name != "ns_user_test" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } +} diff --git a/gee-orm/day7-migrate/session/hooks.go b/gee-orm/day7-migrate/session/hooks.go new file mode 100644 index 0000000..d73c3c2 --- /dev/null +++ b/gee-orm/day7-migrate/session/hooks.go @@ -0,0 +1,35 @@ +package session + +import ( + "geeorm/log" + "reflect" +) + +// Hooks constants +const ( + BeforeQuery = "BeforeQuery" + AfterQuery = "AfterQuery" + BeforeUpdate = "BeforeUpdate" + AfterUpdate = "AfterUpdate" + BeforeDelete = "BeforeDelete" + AfterDelete = "AfterDelete" + BeforeInsert = "BeforeInsert" + AfterInsert = "AfterInsert" +) + +// CallMethod calls the registered hooks +func (s *Session) CallMethod(method string, value interface{}) { + fm := reflect.ValueOf(s.RefTable().Model).MethodByName(method) + if value != nil { + fm = reflect.ValueOf(value).MethodByName(method) + } + param := []reflect.Value{reflect.ValueOf(s)} + if fm.IsValid() { + if v := fm.Call(param); len(v) > 0 { + if err, ok := v[0].Interface().(error); ok { + log.Error(err) + } + } + } + return +} diff --git a/gee-orm/day7-migrate/session/hooks_test.go b/gee-orm/day7-migrate/session/hooks_test.go new file mode 100644 index 0000000..f896d01 --- /dev/null +++ b/gee-orm/day7-migrate/session/hooks_test.go @@ -0,0 +1,37 @@ +package session + +import ( + "geeorm/log" + "testing" +) + +type Account struct { + ID int `geeorm:"PRIMARY KEY"` + Password string +} + +func (account *Account) BeforeInsert(s *Session) error { + log.Info("before inert", account) + account.ID += 1000 + return nil +} + +func (account *Account) AfterQuery(s *Session) error { + log.Info("after query", account) + account.Password = "******" + return nil +} + +func TestSession_CallMethod(t *testing.T) { + s := NewSession().Model(&Account{}) + _ = s.DropTable() + _ = s.CreateTable() + _, _ = s.Insert(&Account{1, "123456"}, &Account{2, "qwerty"}) + + u := &Account{} + + err := s.First(u) + if err != nil || u.ID != 1001 || u.Password != "******" { + t.Fatal("Failed to call hooks after query, got", u) + } +} diff --git a/gee-orm/day7-migrate/session/raw.go b/gee-orm/day7-migrate/session/raw.go new file mode 100644 index 0000000..5bdd039 --- /dev/null +++ b/gee-orm/day7-migrate/session/raw.go @@ -0,0 +1,90 @@ +package session + +import ( + "database/sql" + "geeorm/clause" + "geeorm/dialect" + "geeorm/log" + "geeorm/schema" + "strings" +) + +// Session keep a pointer to sql.DB and provides all execution of all +// kind of database operations. +type Session struct { + db *sql.DB + dialect dialect.Dialect + tx *sql.Tx + refTable *schema.Schema + clause clause.Clause + sql strings.Builder + sqlVars []interface{} +} + +// New creates a instance of Session +func New(db *sql.DB, dialect dialect.Dialect) *Session { + return &Session{ + db: db, + dialect: dialect, + } +} + +// Clear initialize the state of a session +func (s *Session) Clear() { + s.sql.Reset() + s.sqlVars = nil + s.clause = clause.Clause{} +} + +// CommonDB is a minimal function set of db +type CommonDB interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + Exec(query string, args ...interface{}) (sql.Result, error) +} + +var _ CommonDB = (*sql.DB)(nil) +var _ CommonDB = (*sql.Tx)(nil) + +// DB returns tx if a tx begins. otherwise return *sql.DB +func (s *Session) DB() CommonDB { + if s.tx != nil { + return s.tx + } + return s.db +} + +// Exec raw sql with sqlVars +func (s *Session) Exec() (result sql.Result, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// QueryRow gets a record from db +func (s *Session) QueryRow() *sql.Row { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + return s.DB().QueryRow(s.sql.String(), s.sqlVars...) +} + +// QueryRows gets a list of records from db +func (s *Session) QueryRows() (rows *sql.Rows, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// Raw appends sql and sqlVars +func (s *Session) Raw(sql string, values ...interface{}) *Session { + s.sql.WriteString(sql) + s.sql.WriteString(" ") + s.sqlVars = append(s.sqlVars, values...) + return s +} \ No newline at end of file diff --git a/gee-orm/day7-migrate/session/raw_test.go b/gee-orm/day7-migrate/session/raw_test.go new file mode 100644 index 0000000..404bb6e --- /dev/null +++ b/gee-orm/day7-migrate/session/raw_test.go @@ -0,0 +1,48 @@ +package session + +import ( + "database/sql" + "os" + "testing" + + "geeorm/dialect" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + TestDB *sql.DB + TestDial, _ = dialect.GetDialect("sqlite3") +) + +func TestMain(m *testing.M) { + TestDB, _ = sql.Open("sqlite3", "../gee.db") + code := m.Run() + _ = TestDB.Close() + os.Exit(code) +} + +func NewSession() *Session { + return New(TestDB, TestDial) +} + +func TestSession_Exec(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + if count, err := result.RowsAffected(); err != nil || count != 2 { + t.Fatal("expect 2, but got", count) + } +} + +func TestSession_QueryRows(t *testing.T) { + s := NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + row := s.Raw("SELECT count(*) FROM User").QueryRow() + var count int + if err := row.Scan(&count); err != nil || count != 0 { + t.Fatal("failed to query db", err) + } +} diff --git a/gee-orm/day7-migrate/session/record.go b/gee-orm/day7-migrate/session/record.go new file mode 100644 index 0000000..fdfca4d --- /dev/null +++ b/gee-orm/day7-migrate/session/record.go @@ -0,0 +1,136 @@ +package session + +import ( + "errors" + "geeorm/clause" + "reflect" +) + +// Insert one or more records in database +func (s *Session) Insert(values ...interface{}) (int64, error) { + recordValues := make([]interface{}, 0) + for _, value := range values { + s.CallMethod(BeforeInsert, value) + table := s.Model(value).RefTable() + s.clause.Set(clause.INSERT, table.Name, table.FieldNames) + recordValues = append(recordValues, table.RecordValues(value)) + } + + s.clause.Set(clause.VALUES, recordValues...) + sql, vars := s.clause.Build(clause.INSERT, clause.VALUES) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + s.CallMethod(AfterInsert, nil) + return result.RowsAffected() +} + +// Find gets all eligible records +func (s *Session) Find(values interface{}) error { + s.CallMethod(BeforeQuery, nil) + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + table := s.Model(reflect.New(destType).Elem().Interface()).RefTable() + + s.clause.Set(clause.SELECT, table.Name, table.FieldNames) + sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT) + rows, err := s.Raw(sql, vars...).QueryRows() + if err != nil { + return err + } + + for rows.Next() { + dest := reflect.New(destType).Elem() + var values []interface{} + for _, name := range table.FieldNames { + values = append(values, dest.FieldByName(name).Addr().Interface()) + } + if err := rows.Scan(values...); err != nil { + return err + } + s.CallMethod(AfterQuery, dest.Addr().Interface()) + destSlice.Set(reflect.Append(destSlice, dest)) + } + return rows.Close() +} + +// First gets the 1st row +func (s *Session) First(value interface{}) error { + dest := reflect.Indirect(reflect.ValueOf(value)) + destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem() + if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil { + return err + } + if destSlice.Len() == 0 { + return errors.New("NOT FOUND") + } + dest.Set(destSlice.Index(0)) + return nil +} + +// Limit adds limit condition to clause +func (s *Session) Limit(num int) *Session { + s.clause.Set(clause.LIMIT, num) + return s +} + +// Where adds limit condition to clause +func (s *Session) Where(desc string, args ...interface{}) *Session { + var vars []interface{} + s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...) + return s +} + +// OrderBy adds order by condition to clause +func (s *Session) OrderBy(desc string) *Session { + s.clause.Set(clause.ORDERBY, desc) + return s +} + +// Update records with where clause +// support map[string]interface{} +// also support kv list: "Name", "Tom", "Age", 18, .... +func (s *Session) Update(kv ...interface{}) (int64, error) { + s.CallMethod(BeforeUpdate, nil) + m, ok := kv[0].(map[string]interface{}) + if !ok { + m = make(map[string]interface{}) + for i := 0; i < len(kv); i += 2 { + m[kv[i].(string)] = kv[i+1] + } + } + s.clause.Set(clause.UPDATE, s.RefTable().Name, m) + sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + s.CallMethod(AfterUpdate, nil) + return result.RowsAffected() +} + +// Delete records with where clause +func (s *Session) Delete() (int64, error) { + s.CallMethod(BeforeDelete, nil) + s.clause.Set(clause.DELETE, s.RefTable().Name) + sql, vars := s.clause.Build(clause.DELETE, clause.WHERE) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + s.CallMethod(AfterDelete, nil) + return result.RowsAffected() +} + +// Count records with where clause +func (s *Session) Count() (int64, error) { + s.clause.Set(clause.COUNT, s.RefTable().Name) + sql, vars := s.clause.Build(clause.COUNT, clause.WHERE) + row := s.Raw(sql, vars...).QueryRow() + var tmp int64 + if err := row.Scan(&tmp); err != nil { + return 0, err + } + return tmp, nil +} diff --git a/gee-orm/day7-migrate/session/record_test.go b/gee-orm/day7-migrate/session/record_test.go new file mode 100644 index 0000000..5d482a0 --- /dev/null +++ b/gee-orm/day7-migrate/session/record_test.go @@ -0,0 +1,97 @@ +package session + +import "testing" + +var ( + user1 = &User{"Tom", 18} + user2 = &User{"Sam", 25} + user3 = &User{"Jack", 25} +) + +func testRecordInit(t *testing.T) *Session { + t.Helper() + s := NewSession().Model(&User{}) + err1 := s.DropTable() + err2 := s.CreateTable() + _, err3 := s.Insert(user1, user2) + if err1 != nil || err2 != nil || err3 != nil { + t.Fatal("failed init test records") + } + return s +} + +func TestSession_Insert(t *testing.T) { + s := testRecordInit(t) + affected, err := s.Insert(user3) + if err != nil || affected != 1 { + t.Fatal("failed to create record") + } +} + +func TestSession_Find(t *testing.T) { + s := testRecordInit(t) + var users []User + if err := s.Find(&users); err != nil || len(users) != 2 { + t.Fatal("failed to query all") + } +} + +func TestSession_First(t *testing.T) { + s := testRecordInit(t) + u := &User{} + err := s.First(u) + if err != nil || u.Name != "Tom" || u.Age != 18 { + t.Fatal("failed to query first") + } +} + +func TestSession_Limit(t *testing.T) { + s := testRecordInit(t) + var users []User + err := s.Limit(1).Find(&users) + if err != nil || len(users) != 1 { + t.Fatal("failed to query with limit condition") + } +} + +func TestSession_Where(t *testing.T) { + s := testRecordInit(t) + var users []User + _, err1 := s.Insert(user3) + err2 := s.Where("Age = ?", 25).Find(&users) + + if err1 != nil || err2 != nil || len(users) != 2 { + t.Fatal("failed to query with where condition") + } +} + +func TestSession_OrderBy(t *testing.T) { + s := testRecordInit(t) + u := &User{} + err := s.OrderBy("Age DESC").First(u) + + if err != nil || u.Age != 25 { + t.Fatal("failed to query with order by condition") + } +} + +func TestSession_Update(t *testing.T) { + s := testRecordInit(t) + affected, _ := s.Where("Name = ?", "Tom").Update("Age", 30) + u := &User{} + _ = s.OrderBy("Age DESC").First(u) + + if affected != 1 || u.Age != 30 { + t.Fatal("failed to update") + } +} + +func TestSession_DeleteAndCount(t *testing.T) { + s := testRecordInit(t) + affected, _ := s.Where("Name = ?", "Tom").Delete() + count, _ := s.Count() + + if affected != 1 || count != 1 { + t.Fatal("failed to delete or count") + } +} diff --git a/gee-orm/day7-migrate/session/table.go b/gee-orm/day7-migrate/session/table.go new file mode 100644 index 0000000..58e7b0f --- /dev/null +++ b/gee-orm/day7-migrate/session/table.go @@ -0,0 +1,54 @@ +package session + +import ( + "fmt" + "geeorm/log" + "reflect" + "strings" + + "geeorm/schema" +) + +// Model assigns refTable +func (s *Session) Model(value interface{}) *Session { + // nil or different model, update refTable + if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) { + s.refTable = schema.Parse(value, s.dialect) + } + return s +} + +// RefTable returns a Schema instance that contains all parsed fields +func (s *Session) RefTable() *schema.Schema { + if s.refTable == nil { + log.Error("Model is not set") + } + return s.refTable +} + +// CreateTable create a table in database with a model +func (s *Session) CreateTable() error { + table := s.RefTable() + var columns []string + for _, field := range table.Fields { + columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag)) + } + desc := strings.Join(columns, ",") + _, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec() + return err +} + +// DropTable drops a table with the name of model +func (s *Session) DropTable() error { + _, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec() + return err +} + +// HasTable returns true of the table exists +func (s *Session) HasTable() bool { + sql, values := s.dialect.TableExistSQL(s.RefTable().Name) + row := s.Raw(sql, values...).QueryRow() + var tmp string + _ = row.Scan(&tmp) + return tmp == s.RefTable().Name +} diff --git a/gee-orm/day7-migrate/session/table_test.go b/gee-orm/day7-migrate/session/table_test.go new file mode 100644 index 0000000..3bb7554 --- /dev/null +++ b/gee-orm/day7-migrate/session/table_test.go @@ -0,0 +1,28 @@ +package session + +import ( + "testing" +) + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func TestSession_CreateTable(t *testing.T) { + s := NewSession().Model(&User{}) + _ = s.DropTable() + _ = s.CreateTable() + if !s.HasTable() { + t.Fatal("Failed to create table User") + } +} + +func TestSession_Model(t *testing.T) { + s := NewSession().Model(&User{}) + table := s.RefTable() + s.Model(&Session{}) + if table.Name != "User" || s.RefTable().Name != "Session" { + t.Fatal("Failed to change model") + } +} diff --git a/gee-orm/day7-migrate/session/transaction.go b/gee-orm/day7-migrate/session/transaction.go new file mode 100644 index 0000000..3cdb451 --- /dev/null +++ b/gee-orm/day7-migrate/session/transaction.go @@ -0,0 +1,31 @@ +package session + +import "geeorm/log" + +// Begin a transaction +func (s *Session) Begin() (err error) { + log.Info("transaction begin") + if s.tx, err = s.db.Begin(); err != nil { + log.Error(err) + return + } + return +} + +// Commit a transaction +func (s *Session) Commit() (err error) { + log.Info("transaction commit") + if err = s.tx.Commit(); err != nil { + log.Error(err) + } + return +} + +// Rollback a transaction +func (s *Session) Rollback() (err error) { + log.Info("transaction rollback") + if err = s.tx.Rollback(); err != nil { + log.Error(err) + } + return +} diff --git a/gee-orm/doc/geeorm-day1.md b/gee-orm/doc/geeorm-day1.md new file mode 100644 index 0000000..9fe3adf --- /dev/null +++ b/gee-orm/doc/geeorm-day1.md @@ -0,0 +1,416 @@ +--- +title: 动手写ORM框架 - GeeORM第一天 database/sql 基础 +date: 2020-03-07 23:00:00 +description: 7天用 Go语言/golang 从零实现 ORM 框架 GeeORM 教程(7 days implement golang object relational mapping framework from scratch tutorial),动手写 ORM 框架,参照 gorm, xorm 的实现。介绍了 SQLite 的基础操作(连接数据库,创建表、增删记录等),使用 Go 标准库 database/sql 操作 SQLite 数据库,包括执行(Exec),查询(Query, QueryRow)。 +tags: +- Go +nav: 从零实现 +categories: +- ORM框架 - GeeORM +keywords: +- Go语言 +- 从零实现ORM框架 +- database/sql +- sqlite +image: post/geeorm/geeorm_sm.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day1 database/sql 基础 +--- + +本文是[7天用Go从零实现ORM框架GeeORM](https://geektutu.com/post/geeorm.html)的第一篇。介绍了 + +- SQLite 的基础操作(连接数据库,创建表、增删记录等)。 +- 使用 Go 语言标准库 database/sql 连接并操作 SQLite 数据库,并简单封装。**代码约150行** + +## 1 初识 SQLite + +> SQLite is a C-language library that implements a small, fast, self-contained, high-reliability, full-featured, SQL database engine. +> -- [SQLite 官网](https://sqlite.org/index.html) + +SQLite 是一款轻量级的,遵守 ACID 事务原则的关系型数据库。SQLite 可以直接嵌入到代码中,不需要像 MySQL、PostgreSQL 需要启动独立的服务才能使用。SQLite 将数据存储在单一的磁盘文件中,使用起来非常方便。也非常适合初学者用来学习关系型数据的使用。GeeORM 的所有的开发和测试均基于 SQLite。 + +在 Ubuntu 上,安装 SQLite 只需要一行命令,无需配置即可使用。 + +```bash +apt-get install sqlite3 +``` + +接下来,连接数据库(gee.db),如若 gee.db 不存在,则会新建。如果连接成功,就进入到了 SQLite 的命令行模式,执行 `.help` 可以看到所有的帮助命令。 + +```bash +> sqlite3 gee.db +SQLite version 3.22.0 2018-01-22 18:45:57 +Enter ".help" for usage hints. +sqlite> +``` + +使用 SQL 语句新建一张表 `User`,包含两个字段,字符串 Name 和 整型 Age。 + +```bash +sqlite> CREATE TABLE User(Name text, Age integer); +``` + +插入两条数据 + +```bash +sqlite> INSERT INTO User(Name, Age) VALUES ("Tom", 18), ("Jack", 25); +``` + +执行简单的查询操作,在执行之前使用 `.head on` 打开显示列名的开关,这样查询结果看上去更直观。 + +```bash +sqlite> .head on + +# 查找 `Age > 20` 的记录; +sqlite> SELECT * FROM User WHERE Age > 20; +Name|Age +Jack|25 + +# 统计记录个数。 +sqlite> SELECT COUNT(*) FROM User; +COUNT(*) +2 +``` + +使用 `.table` 查看当前数据库中所有的表(table),执行 `.schema ` 查看建表的 SQL 语句。 + +```bash +sqlite> .table +User + +sqlite> .schema User +CREATE TABLE User(Name text, Age integer); +``` + +SQLite 的使用暂时介绍这么多,了解了以上使用方法已经足够我们完成今天的任务了。如果想了解更多用法,可参考 [SQLite 常用命令](https://geektutu.com/post/cheat-sheet-sqlite.html)。 + + +## 2 database/sql 标准库 + +Go 语言提供了标准库 `database/sql` 用于和数据库的交互,接下来我们写一个 Demo,看一看这个库的用法。 + +```go +package main + +import ( + "database/sql" + "log" + + _ "github.com/mattn/go-sqlite3" +) + +func main() { + db, _ := sql.Open("sqlite3", "gee.db") + defer func() { _ = db.Close() }() + _, _ = db.Exec("DROP TABLE IF EXISTS User;") + _, _ = db.Exec("CREATE TABLE User(Name text);") + result, err := db.Exec("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam") + if err == nil { + affected, _ := result.RowsAffected() + log.Println(affected) + } + row := db.QueryRow("SELECT Name FROM User LIMIT 1") + var name string + if err := row.Scan(&name); err == nil { + log.Println(name) + } +} +``` + +> go-sqlite3 依赖于 gcc,如果这份代码在 Windows 上运行的话,需要安装 [mingw](http://mingw.org/) 或其他包含有 gcc 编译器的工具包。 + +执行 `go run .`,输出如下。 + +```bash +> go run . +2020/03/07 20:28:37 2 +2020/03/07 20:28:37 Tom +``` + +- 使用 `sql.Open()` 连接数据库,第一个参数是驱动名称,import 语句 `_ "github.com/mattn/go-sqlite3"` 包导入时会注册 sqlite3 的驱动,第二个参数是数据库的名称,对于 SQLite 来说,也就是文件名,不存在会新建。返回一个 `sql.DB` 实例的指针。 +- `Exec()` 用于执行 SQL 语句,如果是查询语句,不会返回相关的记录。所以查询语句通常使用 `Query()` 和 `QueryRow()`,前者可以返回多条记录,后者只返回一条记录。 +- `Exec()`、`Query()`、`QueryRow()` 接受1或多个入参,第一个入参是 SQL 语句,后面的入参是 SQL 语句中的占位符 `?` 对应的值,占位符一般用来防 SQL 注入。 +- `QueryRow()` 的返回值类型是 `*sql.Row`,`row.Scan()` 接受1或多个指针作为参数,可以获取对应列(column)的值,在这个示例中,只有 `Name` 一列,因此传入字符串指针 `&name` 即可获取到查询的结果。 + + +掌握了基础的 SQL 语句和 Go 标准库 `database/sql` 的使用,可以开始实现 ORM 框架的雏形了。 + +## 3 实现一个简单的 log 库 + +开发一个框架/库并不容易,详细的日志能够帮助我们快速地定位问题。因此,在写核心代码之前,我们先用几十行代码实现一个简单的 log 库。 + +> 为什么不直接使用原生的 log 库呢?log 标准库没有日志分级,不打印文件和行号,这就意味着我们很难快速知道是哪个地方发生了错误。 + +这个简易的 log 库具备以下特性: + +- 支持日志分级(Info、Error、Disabled 三级)。 +- 不同层级日志显示时使用不同的颜色区分。 +- 显示打印日志代码对应的文件名和行号。 + +```bash +go mod init geeorm +``` + +首先创建一个名为 geeorm 的 module,并新建文件 log/log.go,用于放置和日志相关的代码。GeeORM 现在长这个样子: + +```bash +day1-database-sql/ + |--log/ + |--log.go + |--go.mod +``` + +第一步,创建 2 个日志实例分别用于打印 Info 和 Error 日志。 + +[day1-database-sql/log/log.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day1-database-sql/log) + +```go +package log + +import ( + "io/ioutil" + "log" + "os" + "sync" +) + +var ( + errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile) + infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile) + loggers = []*log.Logger{errorLog, infoLog} + mu sync.Mutex +) + +// log methods +var ( + Error = errorLog.Println + Errorf = errorLog.Printf + Info = infoLog.Println + Infof = infoLog.Printf +) +``` + +- `[info ]` 颜色为蓝色,`[error]` 为红色。 +- 使用 `log.Lshortfile` 支持显示文件名和代码行号。 +- 暴露 `Error`,`Errorf`,`Info`,`Infof` 4个方法。 + +第二步呢,支持设置日志的层级(InfoLevel, ErrorLevel, Disabled)。 + +```go +// log levels +const ( + InfoLevel = iota + ErrorLevel + Disabled +) + +// SetLevel controls log level +func SetLevel(level int) { + mu.Lock() + defer mu.Unlock() + + for _, logger := range loggers { + logger.SetOutput(os.Stdout) + } + + if ErrorLevel < level { + errorLog.SetOutput(ioutil.Discard) + } + if InfoLevel < level { + infoLog.SetOutput(ioutil.Discard) + } +} +``` + +- 这一部分的实现非常简单,三个层级声明为三个常量,通过控制 `Output`,来控制日志是否打印。 +- 如果设置为 ErrorLevel,infoLog 的输出会被定向到 `ioutil.Discard`,即不打印该日志。 + +至此呢,一个简单的支持分级的 log 库就实现完成了。 + +## 4 核心结构 Session + +我们在根目录下新建一个文件夹 session,用于实现与数据库的交互。今天我们只实现直接调用 SQL 语句进行原生交互的部分,这部分代码实现在 `session/raw.go` 中。 + +[day1-database-sql/session/raw.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day1-database-sql/session) + +```go +package session + +import ( + "database/sql" + "geeorm/log" + "strings" +) + +type Session struct { + db *sql.DB + sql strings.Builder + sqlVars []interface{} +} + +func New(db *sql.DB) *Session { + return &Session{db: db} +} + +func (s *Session) Clear() { + s.sql.Reset() + s.sqlVars = nil +} + +func (s *Session) DB() *sql.DB { + return s.db +} + +func (s *Session) Raw(sql string, values ...interface{}) *Session { + s.sql.WriteString(sql) + s.sql.WriteString(" ") + s.sqlVars = append(s.sqlVars, values...) + return s +} +``` + +- Session 结构体目前只包含三个成员变量,第一个是 `db *sql.DB`,即使用 `sql.Open()` 方法连接数据库成功之后返回的指针。 +- 第二个和第三个成员变量用来拼接 SQL 语句和 SQL 语句中占位符的对应值。用户调用 `Raw()` 方法即可改变这两个变量的值。 + +接下来呢,封装 `Exec()`、`Query()` 和 `QueryRow()` 三个原生方法。 + +```go +// Exec raw sql with sqlVars +func (s *Session) Exec() (result sql.Result, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} + +// QueryRow gets a record from db +func (s *Session) QueryRow() *sql.Row { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + return s.DB().QueryRow(s.sql.String(), s.sqlVars...) +} + +// QueryRows gets a list of records from db +func (s *Session) QueryRows() (rows *sql.Rows, err error) { + defer s.Clear() + log.Info(s.sql.String(), s.sqlVars) + if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil { + log.Error(err) + } + return +} +``` + +- 封装有 2 个目的,一是统一打印日志(包括 执行的SQL 语句和错误日志)。 +- 二是执行完成后,清空 `(s *Session).sql` 和 `(s *Session).sqlVars` 两个变量。这样 Session 可以复用,开启一次会话,可以执行多次 SQL。 + +## 5 核心结构 Engine + +Session 负责与数据库的交互,那交互前的准备工作(比如连接/测试数据库),交互后的收尾工作(关闭连接)等就交给 Engine 来负责了。Engine 是 GeeORM 与用户交互的入口。代码位于根目录的 `geeorm.go`。 + +[day1-database-sql/geeorm.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day1-database-sql) + +```go +package geeorm + +import ( + "database/sql" + + "geeorm/log" + "geeorm/session" +) + +type Engine struct { + db *sql.DB +} + +func NewEngine(driver, source string) (e *Engine, err error) { + db, err := sql.Open(driver, source) + if err != nil { + log.Error(err) + return + } + // Send a ping to make sure the database connection is alive. + if err = db.Ping(); err != nil { + log.Error(err) + return + } + e = &Engine{db: db} + log.Info("Connect database success") + return +} + +func (engine *Engine) Close() { + if err := engine.db.Close(); err != nil { + log.Error("Failed to close database") + } + log.Info("Close database success") +} + +func (engine *Engine) NewSession() *session.Session { + return session.New(engine.db) +} +``` + +Engine 的逻辑非常简单,最重要的方法是 `NewEngine`,`NewEngine` 主要做了两件事。 + +- 连接数据库,返回 `*sql.DB`。 +- 调用 `db.Ping()`,检查数据库是否能够正常连接。 + +另外呢,提供了 Engine 提供了 `NewSession()` 方法,这样可以通过 `Engine` 实例创建会话,进而与数据库进行交互了。到这一步,整个 GeeORM 的框架雏形已经出来了。 + +```bash +day1-database-sql/ + |--log/ # 日志 + |--log.go + |--session/ # 数据库交互 + |--raw.go + |--geeorm.go # 用户交互 + |--go.mod +``` + +## 6 测试 + +GeeORM 的单元测试是比较完备的,可以参考 `log_test.go`、`raw_test.go` 和 `geeorm_test.go` 等几个测试文件,在这里呢,就不一一讲解了。接下来呢,我们将 geeorm 视为第三方库来使用。 + +在根目录下新建 cmd_test 目录放置测试代码,新建文件 main.go。 + +[day1-database-sql/cmd_test/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day1-database-sql/cmd_test) + +```go +package main + +import ( + "geeorm" + "geeorm/log" + + _ "github.com/mattn/go-sqlite3" +) + +func main() { + engine, _ := geeorm.NewEngine("sqlite3", "gee.db") + defer engine.Close() + s := engine.NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text);").Exec() + result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + count, _ := result.RowsAffected() + fmt.Printf("Exec success, %d affected\n", count) +} +``` + +执行 `go run main.go`,将会看到如下的输出: + +![geeorm log](geeorm-day1/geeorm_log.png) + +日志中出现了一行报错信息,*table User already exists*,因为我们在 main 函数中执行了两次创建表 `User` 的语句。可以看到,每一行日志均标明了报错的文件和行号,而且不同层级日志的颜色是不同的。 + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [SQLite 常用命令速查表](https://geektutu.com/post/cheat-sheet-sqlite.html) \ No newline at end of file diff --git a/gee-orm/doc/geeorm-day1/geeorm_log.png b/gee-orm/doc/geeorm-day1/geeorm_log.png new file mode 100755 index 0000000..dc58aaf Binary files /dev/null and b/gee-orm/doc/geeorm-day1/geeorm_log.png differ diff --git a/gee-orm/doc/geeorm-day2.md b/gee-orm/doc/geeorm-day2.md new file mode 100644 index 0000000..f4a1794 --- /dev/null +++ b/gee-orm/doc/geeorm-day2.md @@ -0,0 +1,390 @@ +--- +title: 动手写ORM框架 - GeeORM第二天 对象表结构映射 +date: 2020-03-08 00:20:00 +description: 7天用 Go语言/golang 从零实现 ORM 框架 GeeORM 教程(7 days implement golang object relational mapping framework from scratch tutorial),动手写 ORM 框架,参照 gorm, xorm 的实现。使用反射(reflect)获取任意 struct 对象的名称和字段,映射为数据中的表;使用 dialect 隔离不同数据库之间的差异,便于扩展;数据库表的创建(create)、删除(drop)。 +tags: +- Go +nav: 从零实现 +categories: +- ORM框架 - GeeORM +keywords: +- Go语言 +- 从零实现ORM框架 +- database/sql +- sqlite +- reflect +- table mapping +image: post/geeorm/geeorm_sm.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day2 对象表结构映射 +--- + +本文是[7天用Go从零实现ORM框架GeeORM](https://geektutu.com/post/geeorm.html)的第二篇。 + +- 使用 dialect 隔离不同数据库之间的差异,便于扩展。 +- 使用反射(reflect)获取任意 struct 对象的名称和字段,映射为数据中的表。 +- 数据库表的创建(create)、删除(drop)。**代码约150行** + +## 1 Dialect + +SQL 语句中的类型和 Go 语言中的类型是不同的,例如Go 语言中的 `int`、`int8`、`int16` 等类型均对应 SQLite 中的 `integer` 类型。因此实现 ORM 映射的第一步,需要思考如何将 Go 语言的类型映射为数据库中的类型。 + +同时,不同数据库支持的数据类型也是有差异的,即使功能相同,在 SQL 语句的表达上也可能有差异。ORM 框架往往需要兼容多种数据库,因此我们需要将差异的这一部分提取出来,每一种数据库分别实现,实现最大程度的复用和解耦。这部分代码称之为 `dialect`。 + +在根目录下新建文件夹 dialect,并在 dialect 文件夹下新建文件 `dialect.go`,抽象出各个数据库差异的部分。 + +[day2-reflect-schema/dialect/dialect.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day2-reflect-schema/dialect) + +```go +package dialect + +import "reflect" + +var dialectsMap = map[string]Dialect{} + +type Dialect interface { + DataTypeOf(typ reflect.Value) string + TableExistSQL(tableName string) (string, []interface{}) +} + +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect +} + +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} +``` + +`Dialect` 接口包含 2 个方法: + +- `DataTypeOf` 用于将 Go 语言的类型转换为该数据库的数据类型。 +- `TableExistSQL` 返回某个表是否存在的 SQL 语句,参数是表名(table)。 + +当然,不同数据库之间的差异远远不止这两个地方,随着 ORM 框架功能的增多,dialect 的实现也会逐渐丰富起来,同时框架的其他部分不会受到影响。 + +同时,声明了 `RegisterDialect` 和 `GetDialect` 两个方法用于注册和获取 dialect 实例。如果新增加对某个数据库的支持,那么调用 `RegisterDialect` 即可注册到全局。 + +接下来,在`dialect` 目录下新建文件 `sqlite3.go` 增加对 SQLite 的支持。 + +[day2-reflect-schema/dialect/sqlite3.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day2-reflect-schema/dialect) + +```go +package dialect + +import ( + "fmt" + "reflect" + "time" +) + +type sqlite3 struct{} + +var _ Dialect = (*sqlite3)(nil) + +func init() { + RegisterDialect("sqlite3", &sqlite3{}) +} + +func (s *sqlite3) DataTypeOf(typ reflect.Value) string { + switch typ.Kind() { + case reflect.Bool: + return "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "integer" + case reflect.Int64, reflect.Uint64: + return "bigint" + case reflect.Float32, reflect.Float64: + return "real" + case reflect.String: + return "text" + case reflect.Array, reflect.Slice: + return "blob" + case reflect.Struct: + if _, ok := typ.Interface().(time.Time); ok { + return "datetime" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind())) +} + +func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) { + args := []interface{}{tableName} + return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args +} +``` + +- `sqlite3.go` 的实现虽然比较繁琐,但是整体逻辑还是非常清晰的。`DataTypeOf` 将 Go 语言的类型映射为 SQLite 的数据类型。`TableExistSQL` 返回了在 SQLite 中判断表 `tableName` 是否存在的 SQL 语句。 +- 实现了 `init()` 函数,包在第一次加载时,会将 sqlite3 的 dialect 自动注册到全局。 + +## 2 Schema + +Dialect 实现了一些特定的 SQL 语句的转换,接下来我们将要实现 ORM 框架中最为核心的转换——对象(object)和表(table)的转换。给定一个任意的对象,转换为关系型数据库中的表结构。 + +在数据库中创建一张表需要哪些要素呢? + +- 表名(table name) —— 结构体名(struct name) +- 字段名和字段类型 —— 成员变量和类型。 +- 额外的约束条件(例如非空、主键等) —— 成员变量的Tag(Go 语言通过 Tag 实现,Java、Python 等语言通过注解实现) + +举一个实际的例子: + +```go +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} +``` + +期望对应的 schema 语句: + +```sql +CREATE TABLE `User` (`Name` text PRIMARY KEY, `Age` integer); +``` + +我们将这部分代码的实现放置在一个子包 `schema/schema.go` 中。 + +[day2-reflect-schema/schema/schema.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day2-reflect-schema/schema) + +```go +package schema + +import ( + "geeorm/dialect" + "go/ast" + "reflect" +) + +// Field represents a column of database +type Field struct { + Name string + Type string + Tag string +} + +// Schema represents a table of database +type Schema struct { + Model interface{} + Name string + Fields []*Field + FieldNames []string + fieldMap map[string]*Field +} + +func (schema *Schema) GetField(name string) *Field { + return schema.fieldMap[name] +} +``` + +- Field 包含 3 个成员变量,字段名 Name、类型 Type、和约束条件 Tag +- Schema 主要包含被映射的对象 Model、表名 Name 和字段 Fields。 +- FieldNames 包含所有的字段名(列名),fieldMap 记录字段名和 Field 的映射关系,方便之后直接使用,无需遍历 Fields。 + +接下来实现 Parse 函数,将任意的对象解析为 Schema 实例。 + +```go +func Parse(dest interface{}, d dialect.Dialect) *Schema { + modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() + schema := &Schema{ + Model: dest, + Name: modelType.Name(), + fieldMap: make(map[string]*Field), + } + + for i := 0; i < modelType.NumField(); i++ { + p := modelType.Field(i) + if !p.Anonymous && ast.IsExported(p.Name) { + field := &Field{ + Name: p.Name, + Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))), + } + if v, ok := p.Tag.Lookup("geeorm"); ok { + field.Tag = v + } + schema.Fields = append(schema.Fields, field) + schema.FieldNames = append(schema.FieldNames, p.Name) + schema.fieldMap[p.Name] = field + } + } + return schema +} +``` + +- `TypeOf()` 和 `ValueOf()` 是 reflect 包最为基本也是最重要的 2 个方法,分别用来返回入参的类型和值。因为设计的入参是一个对象的指针,因此需要 `reflect.Indirect()` 获取指针指向的实例。 +- `modelType.Name()` 获取到结构体的名称作为表名。 +- `NumField()` 获取实例的字段的个数,然后通过下标获取到特定字段 `p := modelType.Field(i)`。 +- `p.Name` 即字段名,`p.Type` 即字段类型,通过 `(Dialect).DataTypeOf()` 转换为数据库的字段类型,`p.Tag` 即额外的约束条件。 + +写一个测试用例来验证 Parse 函数。 + +```go +// schema_test.go +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +var TestDial, _ = dialect.GetDialect("sqlite3") + +func TestParse(t *testing.T) { + schema := Parse(&User{}, TestDial) + if schema.Name != "User" || len(schema.Fields) != 2 { + t.Fatal("failed to parse User struct") + } + if schema.GetField("Name").Tag != "PRIMARY KEY" { + t.Fatal("failed to parse primary key") + } +} +``` + +## 3 Session + +Session 的核心功能是与数据库进行交互。因此,我们将数据库表的增/删操作实现在子包 session 中。在此之前,Session 的结构需要做一些调整。 + +```go +type Session struct { + db *sql.DB + dialect dialect.Dialect + refTable *schema.Schema + sql strings.Builder + sqlVars []interface{} +} + +func New(db *sql.DB, dialect dialect.Dialect) *Session { + return &Session{ + db: db, + dialect: dialect, + } +} +``` + +- `Session` 成员变量新增 dialect 和 refTable +- 构造函数 `New` 的参数改为 2 个,db 和 dialect。 + +在文件夹 `session` 下新建 `table.go` 用于放置操作数据库表相关的代码。 + +[day2-reflect-schema/session/table.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day2-reflect-schema/session) + +```go +func (s *Session) Model(value interface{}) *Session { + // nil or different model, update refTable + if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) { + s.refTable = schema.Parse(value, s.dialect) + } + return s +} + +func (s *Session) RefTable() *schema.Schema { + if s.refTable == nil { + log.Error("Model is not set") + } + return s.refTable +} +``` + +- `Model()` 方法用于给 refTable 赋值。解析操作是比较耗时的,因此将解析的结果保存在成员变量 refTable 中,即使 `Model()` 被调用多次,如果传入的结构体名称不发生变化,则不会更新 refTable 的值。 +- `RefTable()` 方法返回 refTable 的值,如果 refTable 未被赋值,则打印错误日志。 + +接下来实现数据库表的创建、删除和判断是否存在的功能。三个方法的实现逻辑是相似的,利用 `RefTable()` 返回的数据库表和字段的信息,拼接出 SQL 语句,调用原生 SQL 接口执行。 + +```go +func (s *Session) CreateTable() error { + table := s.RefTable() + var columns []string + for _, field := range table.Fields { + columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag)) + } + desc := strings.Join(columns, ",") + _, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec() + return err +} + +func (s *Session) DropTable() error { + _, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec() + return err +} + +func (s *Session) HasTable() bool { + sql, values := s.dialect.TableExistSQL(s.RefTable().Name) + row := s.Raw(sql, values...).QueryRow() + var tmp string + _ = row.Scan(&tmp) + return tmp == s.RefTable().Name +} +``` + +在 `table_test.go` 中实现对应的测试用例: + +```go +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func TestSession_CreateTable(t *testing.T) { + s := NewSession().Model(&User{}) + _ = s.DropTable() + _ = s.CreateTable() + if !s.HasTable() { + t.Fatal("Failed to create table User") + } +} +``` + +## 4 Engine + +因为 Session 构造函数增加了对 dialect 的依赖,Engine 需要作一些细微的调整。 + +[day2-reflect-schema/geeorm.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day2-reflect-schema) + +```go +type Engine struct { + db *sql.DB + dialect dialect.Dialect +} + +func NewEngine(driver, source string) (e *Engine, err error) { + db, err := sql.Open(driver, source) + if err != nil { + log.Error(err) + return + } + // Send a ping to make sure the database connection is alive. + if err = db.Ping(); err != nil { + log.Error(err) + return + } + // make sure the specific dialect exists + dial, ok := dialect.GetDialect(driver) + if !ok { + log.Errorf("dialect %s Not Found", driver) + return + } + e = &Engine{db: db, dialect: dial} + log.Info("Connect database success") + return +} + +func (engine *Engine) NewSession() *session.Session { + return session.New(engine.db, engine.dialect) +} +``` + +- `NewEngine` 创建 Engine 实例时,获取 driver 对应的 dialect。 +- `NewSession` 创建 Session 实例时,传递 dialect 给构造函数 New。 + +至此,第二天的内容已经完成了,总结一下今天的成果: + +- 1)为适配不同的数据库,映射数据类型和特定的 SQL 语句,创建 Dialect 层屏蔽数据库差异。 +- 2)设计 Schema,利用反射(reflect)完成结构体和数据库表结构的映射,包括表名、字段名、字段类型、字段 tag 等。 +- 3)构造创建(create)、删除(drop)、存在性(table exists) 的 SQL 语句完成数据库表的基本操作。 + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [Go Reflect 提高反射性能](https://geektutu.com/post/hpg-reflect.html) +- [SQLite 常用命令速查表](https://geektutu.com/post/cheat-sheet-sqlite.html) \ No newline at end of file diff --git a/gee-orm/doc/geeorm-day3.md b/gee-orm/doc/geeorm-day3.md new file mode 100644 index 0000000..d541ec3 --- /dev/null +++ b/gee-orm/doc/geeorm-day3.md @@ -0,0 +1,402 @@ +--- +title: 动手写ORM框架 - GeeORM第三天 记录新增和查询 +date: 2020-03-08 01:00:00 +description: 7天用 Go语言/golang 从零实现 ORM 框架 GeeORM 教程(7 days implement golang object relational mapping framework from scratch tutorial),动手写 ORM 框架,参照 gorm, xorm 的实现。实现新增(insert)记录的功能;使用反射(reflect)将数据库的记录转换为对应的结构体实例,实现查询(select)功能。 +tags: +- Go +nav: 从零实现 +categories: +- ORM框架 - GeeORM +keywords: +- Go语言 +- 从零实现ORM框架 +- database/sql +- sqlite +- insert into +- select from +image: post/geeorm/geeorm_sm.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day3 记录新增和查询 +--- + +本文是[7天用Go从零实现ORM框架GeeORM](https://geektutu.com/post/geeorm.html)的第三篇。 + +- 实现新增(insert)记录的功能。 +- 使用反射(reflect)将数据库的记录转换为对应的结构体实例,实现查询(select)功能。**代码约150行** + +## 1 Clause 构造 SQL 语句 + +从第三天开始,GeeORM 需要涉及一些较为复杂的操作,例如查询操作。查询语句一般由很多个子句(clause) 构成。SELECT 语句的构成通常是这样的: + +```sql +SELECT col1, col2, ... + FROM table_name + WHERE [ conditions ] + GROUP BY col1 + HAVING [ conditions ] +``` + +也就是说,如果想一次构造出完整的 SQL 语句是比较困难的,因此我们将构造 SQL 语句这一部分独立出来,放在子package clause 中实现。 + +首先在 `clause/generator.go` 中实现各个子句的生成规则。 + +[day3-save-query/clause/generator.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day3-save-query/clause) + + +```go +package clause + +import ( + "fmt" + "strings" +) + +type generator func(values ...interface{}) (string, []interface{}) + +var generators map[Type]generator + +func init() { + generators = make(map[Type]generator) + generators[INSERT] = _insert + generators[VALUES] = _values + generators[SELECT] = _select + generators[LIMIT] = _limit + generators[WHERE] = _where + generators[ORDERBY] = _orderBy +} + +func genBindVars(num int) string { + var vars []string + for i := 0; i < num; i++ { + vars = append(vars, "?") + } + return strings.Join(vars, ", ") +} + +func _insert(values ...interface{}) (string, []interface{}) { + // INSERT INTO $tableName ($fields) + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{} +} + +func _values(values ...interface{}) (string, []interface{}) { + // VALUES ($v1), ($v2), ... + var bindStr string + var sql strings.Builder + var vars []interface{} + sql.WriteString("VALUES ") + for i, value := range values { + v := value.([]interface{}) + if bindStr == "" { + bindStr = genBindVars(len(v)) + } + sql.WriteString(fmt.Sprintf("(%v)", bindStr)) + if i+1 != len(values) { + sql.WriteString(", ") + } + vars = append(vars, v...) + } + return sql.String(), vars + +} + +func _select(values ...interface{}) (string, []interface{}) { + // SELECT $fields FROM $tableName + tableName := values[0] + fields := strings.Join(values[1].([]string), ",") + return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{} +} + +func _limit(values ...interface{}) (string, []interface{}) { + // LIMIT $num + return "LIMIT ?", values +} + +func _where(values ...interface{}) (string, []interface{}) { + // WHERE $desc + desc, vars := values[0], values[1:] + return fmt.Sprintf("WHERE %s", desc), vars +} + +func _orderBy(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{} +} +``` + +然后在 `clause/clause.go` 中实现结构体 `Clause` 拼接各个独立的子句。 + +[day3-save-query/clause/clause.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day3-save-query/clause) + +```go +package clause + +import "strings" + +type Clause struct { + sql map[Type]string + sqlVars map[Type][]interface{} +} + +type Type int +const ( + INSERT Type = iota + VALUES + SELECT + LIMIT + WHERE + ORDERBY +) + +func (c *Clause) Set(name Type, vars ...interface{}) { + if c.sql == nil { + c.sql = make(map[Type]string) + c.sqlVars = make(map[Type][]interface{}) + } + sql, vars := generators[name](vars...) + c.sql[name] = sql + c.sqlVars[name] = vars +} + +func (c *Clause) Build(orders ...Type) (string, []interface{}) { + var sqls []string + var vars []interface{} + for _, order := range orders { + if sql, ok := c.sql[order]; ok { + sqls = append(sqls, sql) + vars = append(vars, c.sqlVars[order]...) + } + } + return strings.Join(sqls, " "), vars +} +``` + +- `Set` 方法根据 `Type` 调用对应的 generator,生成该子句对应的 SQL 语句。 +- `Build` 方法根据传入的 `Type` 的顺序,构造出最终的 SQL 语句。 + +在 `clause_test.go` 实现对应的测试用例: + +```go +func testSelect(t *testing.T) { + var clause Clause + clause.Set(LIMIT, 3) + clause.Set(SELECT, "User", []string{"*"}) + clause.Set(WHERE, "Name = ?", "Tom") + clause.Set(ORDERBY, "Age ASC") + sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT) + t.Log(sql, vars) + if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" { + t.Fatal("failed to build SQL") + } + if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) { + t.Fatal("failed to build SQLVars") + } +} + +func TestClause_Build(t *testing.T) { + t.Run("select", func(t *testing.T) { + testSelect(t) + }) +} +``` + +## 2 实现 Insert 功能 + +首先为 Session 添加成员变量 clause + +```go +// session/raw.go +type Session struct { + db *sql.DB + dialect dialect.Dialect + refTable *schema.Schema + clause clause.Clause + sql strings.Builder + sqlVars []interface{} +} + +func (s *Session) Clear() { + s.sql.Reset() + s.sqlVars = nil + s.clause = clause.Clause{} +} +``` + +clause 已经支持生成简单的插入(INSERT) 和 查询(SELECT) 的 SQL 语句,那么紧接着我们就可以在 session 中实现对应的功能了。 + +INSERT 对应的 SQL 语句一般是这样的: + +```sql +INSERT INTO table_name(col1, col2, col3, ...) VALUES + (A1, A2, A3, ...), + (B1, B2, B3, ...), + ... +``` + +在 ORM 框架中期望 Insert 的调用方式如下: + +```go +s := geeorm.NewEngine("sqlite3", "gee.db").NewSession() +u1 := &User{Name: "Tom", Age: 18} +u2 := &User{Name: "Sam", Age: 25} +s.Insert(u1, u2, ...) +``` + +也就是说,我们还需要一个步骤,根据数据库中列的顺序,从对象中找到对应的值,按顺序平铺。即 `u1`、`u2` 转换为 `("Tom", 18), ("Same", 25)` 这样的格式。 + +因此在实现 Insert 功能之前,还需要给 `Schema` 新增一个函数 `RecordValues` 完成上述的转换。 + +[day3-save-query/schema/schema.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day3-save-query/schema) + +```go +func (schema *Schema) RecordValues(dest interface{}) []interface{} { + destValue := reflect.Indirect(reflect.ValueOf(dest)) + var fieldValues []interface{} + for _, field := range schema.Fields { + fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface()) + } + return fieldValues +} +``` + +在 session 文件夹下新建 record.go,用于实现记录增删查改相关的代码。 + +[day3-save-query/session/record.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day3-save-query/session) + +```go +package session + +import ( + "geeorm/clause" + "reflect" +) + +func (s *Session) Insert(values ...interface{}) (int64, error) { + recordValues := make([]interface{}, 0) + for _, value := range values { + table := s.Model(value).RefTable() + s.clause.Set(clause.INSERT, table.Name, table.FieldNames) + recordValues = append(recordValues, table.RecordValues(value)) + } + + s.clause.Set(clause.VALUES, recordValues...) + sql, vars := s.clause.Build(clause.INSERT, clause.VALUES) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + + return result.RowsAffected() +} +``` + +后续所有构造 SQL 语句的方式都将与 `Insert` 中构造 SQL 语句的方式一致。分两步: + +- 1)多次调用 `clause.Set()` 构造好每一个子句。 +- 2)调用一次 `clause.Build()` 按照传入的顺序构造出最终的 SQL 语句。 + +构造完成后,调用 `Raw().Exec()` 方法执行。 + +## 3 实现 Find 功能 + +期望的调用方式是这样的:传入一个切片指针,查询的结果保存在切片中。 + +```go +s := geeorm.NewEngine("sqlite3", "gee.db").NewSession() +var users []User +s.Find(&users); +``` + +Find 功能的难点和 Insert 恰好反了过来。Insert 需要将已经存在的对象的每一个字段的值平铺开来,而 Find 则是需要根据平铺开的字段的值构造出对象。同样,也需要用到反射(reflect)。 + +```go +func (s *Session) Find(values interface{}) error { + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + table := s.Model(reflect.New(destType).Elem().Interface()).RefTable() + + s.clause.Set(clause.SELECT, table.Name, table.FieldNames) + sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT) + rows, err := s.Raw(sql, vars...).QueryRows() + if err != nil { + return err + } + + for rows.Next() { + dest := reflect.New(destType).Elem() + var values []interface{} + for _, name := range table.FieldNames { + values = append(values, dest.FieldByName(name).Addr().Interface()) + } + if err := rows.Scan(values...); err != nil { + return err + } + destSlice.Set(reflect.Append(destSlice, dest)) + } + return rows.Close() +} +``` + +Find 的代码实现比较复杂,主要分为以下几步: + +- 1) `destSlice.Type().Elem()` 获取切片的单个元素的类型 `destType`,使用 `reflect.New()` 方法创建一个 `destType` 的实例,作为 `Model()` 的入参,映射出表结构 `RefTable()`。 +- 2)根据表结构,使用 clause 构造出 SELECT 语句,查询到所有符合条件的记录 `rows`。 +- 3)遍历每一行记录,利用反射创建 `destType` 的实例 `dest`,将 `dest` 的所有字段平铺开,构造切片 `values`。 +- 4)调用 `rows.Scan()` 将该行记录每一列的值依次赋值给 values 中的每一个字段。 +- 5)将 `dest` 添加到切片 `destSlice` 中。循环直到所有的记录都添加到切片 `destSlice` 中。 + +## 4 测试 + +在 session 文件夹下新建 `record_test.go`,创建测试用例。 + +> `User` 和 `NewSession()` 的定义位于 raw_test.go 中。 + +[day3-save-query/session/record_test.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day3-save-query/session) + +```go +package session + +import "testing" + +var ( + user1 = &User{"Tom", 18} + user2 = &User{"Sam", 25} + user3 = &User{"Jack", 25} +) + +func testRecordInit(t *testing.T) *Session { + t.Helper() + s := NewSession().Model(&User{}) + err1 := s.DropTable() + err2 := s.CreateTable() + _, err3 := s.Insert(user1, user2) + if err1 != nil || err2 != nil || err3 != nil { + t.Fatal("failed init test records") + } + return s +} + +func TestSession_Insert(t *testing.T) { + s := testRecordInit(t) + affected, err := s.Insert(user3) + if err != nil || affected != 1 { + t.Fatal("failed to create record") + } +} + +func TestSession_Find(t *testing.T) { + s := testRecordInit(t) + var users []User + if err := s.Find(&users); err != nil || len(users) != 2 { + t.Fatal("failed to query all") + } +} +``` + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [SQLite 常用命令速查表](https://geektutu.com/post/cheat-sheet-sqlite.html) +- [Laws Of Reflection - golang.org](https://blog.golang.org/laws-of-reflection) \ No newline at end of file diff --git a/gee-orm/doc/geeorm-day4.md b/gee-orm/doc/geeorm-day4.md new file mode 100644 index 0000000..8cfcc75 --- /dev/null +++ b/gee-orm/doc/geeorm-day4.md @@ -0,0 +1,285 @@ +--- +title: 动手写ORM框架 - GeeORM第四天 链式操作与更新删除 +date: 2020-03-08 16:00:00 +description: 7天用 Go语言/golang 从零实现 ORM 框架 GeeORM 教程(7 days implement golang object relational mapping framework from scratch tutorial),动手写 ORM 框架,参照 gorm, xorm 的实现。通过链式(chain)操作,支持查询条件(where, order by, limit 等)的叠加;实现记录的更新(update)、删除(delete)和统计(count)功能。 +tags: +- Go +nav: 从零实现 +categories: +- ORM框架 - GeeORM +keywords: +- Go语言 +- 从零实现ORM框架 +- database/sql +- sqlite +- chain operation +- delete from +image: post/geeorm/geeorm_sm.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day4 链式操作与更新删除 +--- + +本文是[7天用Go从零实现ORM框架GeeORM](https://geektutu.com/post/geeorm.html)的第四篇。 + +- 通过链式(chain)操作,支持查询条件(where, order by, limit 等)的叠加。 +- 实现记录的更新(update)、删除(delete)和统计(count)功能。**代码约100行** + +## 1 支持 Update、Delete 和 Count + +### 1.1 子句生成器 + +clause 负责构造 SQL 语句,如果需要增加对更新(update)、删除(delete)和统计(count)功能的支持,第一步自然是在 clause 中实现 update、delete 和 count 子句的生成器。 + +第一步:在原来的基础上,新增 UPDATE、DELETE、COUNT 三个 `Type` 类型的枚举值。 + +[day4-chain-operation/clause/clause.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day4-chain-operation/clause) + +```go +// Support types for Clause +const ( + INSERT Type = iota + VALUES + SELECT + LIMIT + WHERE + ORDERBY + UPDATE + DELETE + COUNT +) +``` + +第二步:实现对应字句的 generator,并注册到全局变量 `generators` 中 + +[day4-chain-operation/clause/generator.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day4-chain-operation/clause) + +```go +func init() { + generators = make(map[Type]generator) + generators[INSERT] = _insert + generators[VALUES] = _values + generators[SELECT] = _select + generators[LIMIT] = _limit + generators[WHERE] = _where + generators[ORDERBY] = _orderBy + generators[UPDATE] = _update + generators[DELETE] = _delete + generators[COUNT] = _count +} + +func _update(values ...interface{}) (string, []interface{}) { + tableName := values[0] + m := values[1].(map[string]interface{}) + var keys []string + var vars []interface{} + for k, v := range m { + keys = append(keys, k+" = ?") + vars = append(vars, v) + } + return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars +} + +func _delete(values ...interface{}) (string, []interface{}) { + return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{} +} + +func _count(values ...interface{}) (string, []interface{}) { + return _select(values[0], []string{"count(*)"}) +} +``` + +- `_update` 设计入参是2个,第一个参数是表名(table),第二个参数是 map 类型,表示待更新的键值对。 +- `_delete` 只有一个入参,即表名。 +- `_count` 只有一个入参,即表名,并复用了 `_select` 生成器。 + + +### 1.2 Update 方法 + +子句的 generator 已经准备好了,接下来和 Insert、Find 等方法一样,在 `session/record.go` 中按照一定顺序拼接 SQL 语句并调用就可以了。 + +[day4-chain-operation/session/record.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day4-chain-operation/session) + +```go +// support map[string]interface{} +// also support kv list: "Name", "Tom", "Age", 18, .... +func (s *Session) Update(kv ...interface{}) (int64, error) { + m, ok := kv[0].(map[string]interface{}) + if !ok { + m = make(map[string]interface{}) + for i := 0; i < len(kv); i += 2 { + m[kv[i].(string)] = kv[i+1] + } + } + s.clause.Set(clause.UPDATE, s.RefTable().Name, m) + sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + return result.RowsAffected() +} +``` + +Update 方法比较特别的一点在于,Update 接受 2 种入参,平铺开来的键值对和 map 类型的键值对。因为 generator 接受的参数是 map 类型的键值对,因此 `Update` 方法会动态地判断传入参数的类型,如果是不是 map 类型,则会自动转换。 + + +### 1.3 Delete 方法 + +```go +// Delete records with where clause +func (s *Session) Delete() (int64, error) { + s.clause.Set(clause.DELETE, s.RefTable().Name) + sql, vars := s.clause.Build(clause.DELETE, clause.WHERE) + result, err := s.Raw(sql, vars...).Exec() + if err != nil { + return 0, err + } + return result.RowsAffected() +} +``` + +### 1.4 Count 方法 + +```go +// Count records with where clause +func (s *Session) Count() (int64, error) { + s.clause.Set(clause.COUNT, s.RefTable().Name) + sql, vars := s.clause.Build(clause.COUNT, clause.WHERE) + row := s.Raw(sql, vars...).QueryRow() + var tmp int64 + if err := row.Scan(&tmp); err != nil { + return 0, err + } + return tmp, nil +} +``` + +## 2 链式调用(chain) + +链式调用是一种简化代码的编程方式,能够使代码更简洁、易读。链式调用的原理也非常简单,某个对象调用某个方法后,将该对象的引用/指针返回,即可以继续调用该对象的其他方法。通常来说,当某个对象需要一次调用多个方法来设置其属性时,就非常适合改造为链式调用了。 + +SQL 语句的构造过程就非常符合这个条件。SQL 语句由多个子句构成,典型的例如 SELECT 语句,往往需要设置查询条件(WHERE)、限制返回行数(LIMIT)等。理想的调用方式应该是这样的: + +```go +s := geeorm.NewEngine("sqlite3", "gee.db").NewSession() +var users []User +s.Where("Age > 18").Limit(3).Find(&users) +``` + +从上面的示例中,可以看出,`WHERE`、`LIMIT`、`ORDER BY` 等查询条件语句非常适合链式调用。这几个子句的 generator 在之前就已经实现了,那我们接下来在 `session/record.go` 中添加对应的方法即可。 + +[day4-chain-operation/session/record.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day4-chain-operation/session) + +```go +// Limit adds limit condition to clause +func (s *Session) Limit(num int) *Session { + s.clause.Set(clause.LIMIT, num) + return s +} + +// Where adds limit condition to clause +func (s *Session) Where(desc string, args ...interface{}) *Session { + var vars []interface{} + s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...) + return s +} + +// OrderBy adds order by condition to clause +func (s *Session) OrderBy(desc string) *Session { + s.clause.Set(clause.ORDERBY, desc) + return s +} +``` + +## 3 First 只返回一条记录 + +很多时候,我们期望 SQL 语句只返回一条记录,比如根据某个童鞋的学号查询他的信息,返回结果有且只有一条。结合链式调用,我们可以非常容易地实现 First 方法。 + +```go +func (s *Session) First(value interface{}) error { + dest := reflect.Indirect(reflect.ValueOf(value)) + destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem() + if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil { + return err + } + if destSlice.Len() == 0 { + return errors.New("NOT FOUND") + } + dest.Set(destSlice.Index(0)) + return nil +} +``` + +First 方法可以这么使用: + +```go +u := &User{} +_ = s.OrderBy("Age DESC").First(u) +``` + +> 实现原理:根据传入的类型,利用反射构造切片,调用 `Limit(1)` 限制返回的行数,调用 `Find` 方法获取到查询结果。 + +## 4 测试 + +接下来呢,我们在 `record_test.go` 中添加几个测试用例,检测功能是否正常。 + +```go +package session + +import "testing" + +var ( + user1 = &User{"Tom", 18} + user2 = &User{"Sam", 25} + user3 = &User{"Jack", 25} +) + +func testRecordInit(t *testing.T) *Session { + t.Helper() + s := NewSession().Model(&User{}) + err1 := s.DropTable() + err2 := s.CreateTable() + _, err3 := s.Insert(user1, user2) + if err1 != nil || err2 != nil || err3 != nil { + t.Fatal("failed init test records") + } + return s +} + +func TestSession_Limit(t *testing.T) { + s := testRecordInit(t) + var users []User + err := s.Limit(1).Find(&users) + if err != nil || len(users) != 1 { + t.Fatal("failed to query with limit condition") + } +} + +func TestSession_Update(t *testing.T) { + s := testRecordInit(t) + affected, _ := s.Where("Name = ?", "Tom").Update("Age", 30) + u := &User{} + _ = s.OrderBy("Age DESC").First(u) + + if affected != 1 || u.Age != 30 { + t.Fatal("failed to update") + } +} + +func TestSession_DeleteAndCount(t *testing.T) { + s := testRecordInit(t) + affected, _ := s.Where("Name = ?", "Tom").Delete() + count, _ := s.Count() + + if affected != 1 || count != 1 { + t.Fatal("failed to delete or count") + } +} +``` + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [SQLite 常用命令速查表](https://geektutu.com/post/cheat-sheet-sqlite.html) \ No newline at end of file diff --git a/gee-orm/doc/geeorm-day5.md b/gee-orm/doc/geeorm-day5.md new file mode 100644 index 0000000..3d6ffda --- /dev/null +++ b/gee-orm/doc/geeorm-day5.md @@ -0,0 +1,157 @@ +--- +title: 动手写ORM框架 - GeeORM第五天 实现钩子(Hooks) +date: 2020-03-08 18:00:00 +description: 7天用 Go语言/golang 从零实现 ORM 框架 GeeORM 教程(7 days implement golang object relational mapping framework from scratch tutorial),动手写 ORM 框架,参照 gorm, xorm 的实现。通过反射(reflect)获取结构体绑定的钩子(hooks),并调用;支持增删查改(CRUD)前后调用钩子。 +tags: +- Go +nav: 从零实现 +categories: +- ORM框架 - GeeORM +keywords: +- Go语言 +- 从零实现ORM框架 +- database/sql +- sqlite +- hooks +- BeforeUpdate +image: post/geeorm/geeorm_sm.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day5 实现钩子 +--- + +本文是[7天用Go从零实现ORM框架GeeORM](https://geektutu.com/post/geeorm.html)的第五篇。 + +- 通过反射(reflect)获取结构体绑定的钩子(hooks),并调用。 +- 支持增删查改(CRUD)前后调用钩子。**代码约50行** + +## 1 Hook 机制 + +Hook,翻译为钩子,其主要思想是提前在可能增加功能的地方埋好(预设)一个钩子,当我们需要重新修改或者增加这个地方的逻辑的时候,把扩展的类或者方法挂载到这个点即可。钩子的应用非常广泛,例如 Github 支持的 travis 持续集成服务,当有 `git push` 事件发生时,会触发 travis 拉取新的代码进行构建。IDE 中钩子也非常常见,比如,当按下 `Ctrl + s` 后,自动格式化代码。再比如前端常用的 `hot reload` 机制,前端代码发生变更时,自动编译打包,通知浏览器自动刷新页面,实现所写即所得。 + +钩子机制设计的好坏,取决于扩展点选择的是否合适。例如对于持续集成来说,代码如果不发生变更,反复构建是没有意义的,因此钩子应设计在代码可能发生变更的地方,比如 MR、PR 合并前后。 + +那对于 ORM 框架来说,合适的扩展点在哪里呢?很显然,记录的增删查改前后都是非常合适的。 + +比如,我们设计一个 `Account` 类,`Account` 包含有一个隐私字段 `Password`,那么每次查询后都需要做脱敏处理,才能继续使用。如果提供了 `AfterQuery` 的钩子,查询后,自动地将 `Password` 字段的值脱敏,是不是能省去很多冗余的代码呢? + +## 2 实现钩子 + +GeeORM 的钩子与结构体绑定,即每个结构体需要实现各自的钩子。hook 相关的代码实现在 `session/hooks.go` 中。 + +[day5-hooks/session/hooks.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day5-hooks/session) + +```go +package session + +import ( + "geeorm/log" + "reflect" +) + +// Hooks constants +const ( + BeforeQuery = "BeforeQuery" + AfterQuery = "AfterQuery" + BeforeUpdate = "BeforeUpdate" + AfterUpdate = "AfterUpdate" + BeforeDelete = "BeforeDelete" + AfterDelete = "AfterDelete" + BeforeInsert = "BeforeInsert" + AfterInsert = "AfterInsert" +) + +// CallMethod calls the registered hooks +func (s *Session) CallMethod(method string, value interface{}) { + fm := reflect.ValueOf(s.RefTable().Model).MethodByName(method) + if value != nil { + fm = reflect.ValueOf(value).MethodByName(method) + } + param := []reflect.Value{reflect.ValueOf(s)} + if fm.IsValid() { + if v := fm.Call(param); len(v) > 0 { + if err, ok := v[0].Interface().(error); ok { + log.Error(err) + } + } + } + return +} +``` + +- 钩子机制同样是通过反射来实现的,`s.RefTable().Model` 或 `value` 即当前会话正在操作的对象,使用 `MethodByName` 方法反射得到该对象的方法。 +- 将 `s *Session` 作为入参调用。每一个钩子的入参类型均是 `*Session`。 + +接下来,将 `CallMethod()` 方法在 Find、Insert、Update、Delete 方法内部调用即可。例如,`Find` 方法修改为: + +```go +// Find gets all eligible records +func (s *Session) Find(values interface{}) error { + s.CallMethod(BeforeQuery, nil) + // ... + for rows.Next() { + dest := reflect.New(destType).Elem() + // ... + s.CallMethod(AfterQuery, dest.Addr().Interface()) + // ... + } + return rows.Close() +} +``` + +- `AfterQuery` 钩子可以操作每一行记录。 + +## 3 测试 + +新建 `session/hooks.go` 文件添加对应的测试用例。 + +```go +package session + +import ( + "geeorm/log" + "testing" +) + +type Account struct { + ID int `geeorm:"PRIMARY KEY"` + Password string +} + +func (account *Account) BeforeInsert(s *Session) error { + log.Info("before inert", account) + account.ID += 1000 + return nil +} + +func (account *Account) AfterQuery(s *Session) error { + log.Info("after query", account) + account.Password = "******" + return nil +} + +func TestSession_CallMethod(t *testing.T) { + s := NewSession().Model(&Account{}) + _ = s.DropTable() + _ = s.CreateTable() + _, _ = s.Insert(&Account{1, "123456"}, &Account{2, "qwerty"}) + + u := &Account{} + + err := s.First(u) + if err != nil || u.ID != 1001 || u.Password != "******" { + t.Fatal("Failed to call hooks after query, got", u) + } +} +``` + +在这个测试用例中,测试了 `BeforeInsert` 和 `AfterQuery` 2 个钩子。 + +- `BeforeInsert` 将 account.ID 的值增加 1000 +- `AfterQuery` 将密码脱敏,显示为 6 个 `*`。 + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [SQLite 常用命令速查表](https://geektutu.com/post/cheat-sheet-sqlite.html) \ No newline at end of file diff --git a/gee-orm/doc/geeorm-day6.md b/gee-orm/doc/geeorm-day6.md new file mode 100644 index 0000000..abe43a9 --- /dev/null +++ b/gee-orm/doc/geeorm-day6.md @@ -0,0 +1,273 @@ +--- +title: 动手写ORM框架 - GeeORM第六天 支持事务(Transaction) +date: 2020-03-08 21:00:00 +description: 7天用 Go语言/golang 从零实现 ORM 框架 GeeORM 教程(7 days implement golang object relational mapping framework from scratch tutorial),动手写 ORM 框架,参照 gorm, xorm 的实现。介绍数据库中的事务(transaction);封装事务,用户自定义回调函数实现原子操作。 +tags: +- Go +nav: 从零实现 +categories: +- ORM框架 - GeeORM +keywords: +- Go语言 +- 从零实现ORM框架 +- database/sql +- sqlite +- transaction +image: post/geeorm/geeorm_sm.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day6 支持事务 +--- + +本文是[7天用Go从零实现ORM框架GeeORM](https://geektutu.com/post/geeorm.html)的第六篇。 + +- 介绍数据库中的事务(transaction)。 +- 封装事务,用户自定义回调函数实现原子操作。**代码约100行** + +## 1 事务的 ACID 属性 + +> 数据库事务(transaction)是访问并可能操作各种数据项的一个数据库操作序列,这些操作要么全部执行,要么全部不执行,是一个不可分割的工作单位。事务由事务开始与事务结束之间执行的全部数据库操作组成。 + +举一个简单的例子,转账。A 转账给 B 一万元,那么数据库至少需要执行 2 个操作: + +- 1)A 的账户减掉一万元。 +- 2)B 的账户增加一万元。 + +这两个操作要么全部执行,代表转账成功。任意一个操作失败了,之前的操作都必须回退,代表转账失败。一个操作完成,另一个操作失败,这种结果是不能够接受的。这种场景就非常适合利用数据库事务的特性来解决。 + +如果一个数据库支持事务,那么必须具备 ACID 四个属性。 + +- 1)原子性(Atomicity):事务中的全部操作在数据库中是不可分割的,要么全部完成,要么全部不执行。 +- 2)一致性(Consistency): 几个并行执行的事务,其执行结果必须与按某一顺序 串行执行的结果相一致。 +- 3)隔离性(Isolation):事务的执行不受其他事务的干扰,事务执行的中间结果对其他事务必须是透明的。 +- 4)持久性(Durability):对于任意已提交事务,系统必须保证该事务对数据库的改变不被丢失,即使数据库出现故障。 + +## 2 SQLite 和 Go 标准库中的事务 + +SQLite 中创建一个事务的原生 SQL 长什么样子呢? + +```sql +sqlite> BEGIN; +sqlite> DELETE FROM User WHERE Age > 25; +sqlite> INSERT INTO User VALUES ("Tom", 25), ("Jack", 18); +sqlite> COMMIT; +``` + +`BEGIN` 开启事务,`COMMIT` 提交事务,`ROLLBACK` 回滚事务。任何一个事务,均以 `BEGIN` 开始,`COMMIT` 或 `ROLLBACK` 结束。 + +Go 语言标准库 database/sql 提供了支持事务的接口。用一个简单的例子,看一看 Go 语言标准是如何支持事务的。 + +```go +package main + +import ( + "database/sql" + _ "github.com/mattn/go-sqlite3" + "log" +) + +func main() { + db, _ := sql.Open("sqlite3", "gee.db") + defer func() { _ = db.Close() }() + _, _ = db.Exec("CREATE TABLE IF NOT EXISTS User(`Name` text);") + + tx, _ := db.Begin() + _, err1 := tx.Exec("INSERT INTO User(`Name`) VALUES (?)", "Tom") + _, err2 := tx.Exec("INSERT INTO User(`Name`) VALUES (?)", "Jack") + if err1 != nil || err2 != nil { + _ = tx.Rollback() + log.Println("Rollback", err1, err2) + } else { + _ = tx.Commit() + log.Println("Commit") + } +} +``` + +Go 语言中实现事务和 SQL 原生语句其实是非常接近的。调用 `db.Begin()` 得到 `*sql.Tx` 对象,使用 `tx.Exec()` 执行一系列操作,如果发生错误,通过 `tx.Rollback()` 回滚,如果没有发生错误,则通过 `tx.Commit()` 提交。 + +## 3 GeeORM 支持事务 + +GeeORM 之前的操作均是执行完即自动提交的,每个操作是相互独立的。之前直接使用 `sql.DB` 对象执行 SQL 语句,如果要支持事务,需要更改为 `sql.Tx` 执行。在 Session 结构体中新增成员变量 `tx *sql.Tx`,当 `tx` 不为空时,则使用 `tx` 执行 SQL 语句,否则使用 `db` 执行 SQL 语句。这样既兼容了原有的执行方式,又提供了对事务的支持。 + +[day6-transaction/session/raw.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day6-transaction/session) + +```go +type Session struct { + db *sql.DB + dialect dialect.Dialect + tx *sql.Tx + refTable *schema.Schema + clause clause.Clause + sql strings.Builder + sqlVars []interface{} +} + +// CommonDB is a minimal function set of db +type CommonDB interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + Exec(query string, args ...interface{}) (sql.Result, error) +} + +var _ CommonDB = (*sql.DB)(nil) +var _ CommonDB = (*sql.Tx)(nil) + +// DB returns tx if a tx begins. otherwise return *sql.DB +func (s *Session) DB() CommonDB { + if s.tx != nil { + return s.tx + } + return s.db +} +``` + +新建文件 `session/transaction.go` 封装事务的 Begin、Commit 和 Rollback 三个接口。 + +[day6-transaction/session/transaction.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day6-transaction/session) + +```go +package session + +import "geeorm/log" + +func (s *Session) Begin() (err error) { + log.Info("transaction begin") + if s.tx, err = s.db.Begin(); err != nil { + log.Error(err) + return + } + return +} + +func (s *Session) Commit() (err error) { + log.Info("transaction commit") + if err = s.tx.Commit(); err != nil { + log.Error(err) + } + return +} + +func (s *Session) Rollback() (err error) { + log.Info("transaction rollback") + if err = s.tx.Rollback(); err != nil { + log.Error(err) + } + return +} +``` + +- 调用 `s.db.Begin()` 得到 `*sql.Tx` 对象,赋值给 s.tx。 +- 封装的另一个目的是统一打印日志,方便定位问题。 + + +最后一步,在 `geeorm.go` 中为用户提供傻瓜式/一键式使用的接口。 + +[day6-transaction/geeorm.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day6-transaction) + +```go +type TxFunc func(*session.Session) (interface{}, error) + +func (engine *Engine) Transaction(f TxFunc) (result interface{}, err error) { + s := engine.NewSession() + if err := s.Begin(); err != nil { + return nil, err + } + defer func() { + if p := recover(); p != nil { + _ = s.Rollback() + panic(p) // re-throw panic after Rollback + } else if err != nil { + _ = s.Rollback() // err is non-nil; don't change it + } else { + err = s.Commit() // err is nil; if Commit returns error update err + } + }() + + return f(s) +} +``` + +> Transaction 的实现参考了 [stackoverflow](https://stackoverflow.com/questions/16184238/database-sql-tx-detecting-commit-or-rollback) + +用户只需要将所有的操作放到一个回调函数中,作为入参传递给 `engine.Transaction()`,发生任何错误,自动回滚,如果没有错误发生,则提交。 + +## 4 测试 + +在 `geeorm_test.go` 中添加测试用例看看 Transaction 如何工作的吧。 + +```go +func OpenDB(t *testing.T) *Engine { + t.Helper() + engine, err := NewEngine("sqlite3", "gee.db") + if err != nil { + t.Fatal("failed to connect", err) + } + return engine +} + +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func TestEngine_Transaction(t *testing.T) { + t.Run("rollback", func(t *testing.T) { + transactionRollback(t) + }) + t.Run("commit", func(t *testing.T) { + transactionCommit(t) + }) +} +``` + +首先是 rollback 的用例: + +```go +func transactionRollback(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() + s := engine.NewSession() + _ = s.Model(&User{}).DropTable() + _, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) { + _ = s.Model(&User{}).CreateTable() + _, err = s.Insert(&User{"Tom", 18}) + return nil, errors.New("Error") + }) + if err == nil || s.HasTable() { + t.Fatal("failed to rollback") + } +} +``` + +- 在这个用例中,如何执行成功,则会创建一张表 `User`,并插入一条记录。 +- 故意返回了一个自定义 error,最终事务回滚,表创建失败。 + +接下来是 commit 的用例: + +```go +func transactionCommit(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() + s := engine.NewSession() + _ = s.Model(&User{}).DropTable() + _, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) { + _ = s.Model(&User{}).CreateTable() + _, err = s.Insert(&User{"Tom", 18}) + return + }) + u := &User{} + _ = s.First(u) + if err != nil || u.Name != "Tom" { + t.Fatal("failed to commit") + } +} +``` + +- 创建表和插入记录均成功执行,最终通过 `s.First()` 方法查询到插入的记录。 + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [SQLite 常用命令速查表](https://geektutu.com/post/cheat-sheet-sqlite.html) \ No newline at end of file diff --git a/gee-orm/doc/geeorm-day7.md b/gee-orm/doc/geeorm-day7.md new file mode 100644 index 0000000..efd96db --- /dev/null +++ b/gee-orm/doc/geeorm-day7.md @@ -0,0 +1,170 @@ +--- +title: 动手写ORM框架 - GeeORM第七天 数据库迁移(Migrate) +date: 2020-03-08 23:00:00 +description: 7天用 Go语言/golang 从零实现 ORM 框架 GeeORM 教程(7 days implement golang object relational mapping framework from scratch tutorial),动手写 ORM 框架,参照 gorm, xorm 的实现。结构体(struct)变更时,数据库表的字段(field)自动迁移(migrate);仅支持字段新增与删除,不支持字段类型变更。 +tags: +- Go +nav: 从零实现 +categories: +- ORM框架 - GeeORM +keywords: +- Go语言 +- 从零实现ORM框架 +- database/sql +- sqlite +- migrate +image: post/geeorm/geeorm_sm.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day7 数据库迁移 +--- + +本文是[7天用Go从零实现ORM框架GeeORM](https://geektutu.com/post/geeorm.html)的第七篇。 + +- 结构体(struct)变更时,数据库表的字段(field)自动迁移(migrate)。 +- 仅支持字段新增与删除,不支持字段类型变更。**代码约70行** + +## 1 使用 SQL 语句 Migrate + +数据库 Migrate 一直是数据库运维人员最为头痛的问题,如果仅仅是一张表增删字段还比较容易,那如果涉及到外键等复杂的关联关系,数据库的迁移就会变得非常困难。 + +GeeORM 的 Migrate 操作仅针对最为简单的场景,即支持字段的新增与删除,不支持字段类型变更。 + +在实现 Migrate 之前,我们先看看如何使用原生的 SQL 语句增删字段。 + +### 1.1 新增字段 + +```sql +ALTER TABLE table_name ADD COLUMN col_name, col_type; +``` + +大部分数据支持使用 `ALTER` 关键字新增字段,或者重命名字段。 + +### 1.2 删除字段 + +> 参考 [sqlite delete or add column - stackoverflow](https://stackoverflow.com/questions/8442147/how-to-delete-or-add-column-in-sqlite) + +对于 SQLite 来说,删除字段并不像新增字段那么容易,一个比较可行的方法需要执行下列几个步骤: + +```sql +CREATE TABLE new_table AS SELECT col1, col2, ... from old_table +DROP TABLE old_table +ALTER TABLE new_table RENAME TO old_table; +``` + +- 第一步:从 `old_table` 中挑选需要保留的字段到 `new_table` 中。 +- 第二步:删除 `old_table`。 +- 第三步:重命名 `new_table` 为 `old_table`。 + +## 2 GeeORM 实现 Migrate + +按照原生的 SQL 命令,利用之前实现的事务,在 `geeorm.go` 中实现 Migrate 方法。 + +```go +// difference returns a - b +func difference(a []string, b []string) (diff []string) { + mapB := make(map[string]bool) + for _, v := range b { + mapB[v] = true + } + for _, v := range a { + if _, ok := mapB[v]; !ok { + diff = append(diff, v) + } + } + return +} + +// Migrate table +func (engine *Engine) Migrate(value interface{}) error { + _, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) { + if !s.Model(value).HasTable() { + log.Infof("table %s doesn't exist", s.RefTable().Name) + return nil, s.CreateTable() + } + table := s.RefTable() + rows, _ := s.Raw(fmt.Sprintf("SELECT * FROM %s LIMIT 1", table.Name)).QueryRows() + columns, _ := rows.Columns() + addCols := difference(table.FieldNames, columns) + delCols := difference(columns, table.FieldNames) + log.Infof("added cols %v, deleted cols %v", addCols, delCols) + + for _, col := range addCols { + f := table.GetField(col) + sqlStr := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s;", table.Name, f.Name, f.Type) + if _, err = s.Raw(sqlStr).Exec(); err != nil { + return + } + } + + if len(delCols) == 0 { + return + } + tmp := "tmp_" + table.Name + fieldStr := strings.Join(table.FieldNames, ", ") + s.Raw(fmt.Sprintf("CREATE TABLE %s AS SELECT %s from %s;", tmp, fieldStr, table.Name)) + s.Raw(fmt.Sprintf("DROP TABLE %s;", table.Name)) + s.Raw(fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", tmp, table.Name)) + _, err = s.Exec() + return + }) + return err +} +``` + +- `difference` 用来计算前后两个字段切片的差集。新表 - 旧表 = 新增字段,旧表 - 新表 = 删除字段。 +- 使用 `ALTER` 语句新增字段。 +- 使用创建新表并重命名的方式删除字段。 + +## 3 测试 + +在 `geeorm_test.go` 中添加 Migrate 的测试用例: + +```go +type User struct { + Name string `geeorm:"PRIMARY KEY"` + Age int +} + +func TestEngine_Migrate(t *testing.T) { + engine := OpenDB(t) + defer engine.Close() + s := engine.NewSession() + _, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec() + _, _ = s.Raw("CREATE TABLE User(Name text PRIMARY KEY, XXX integer);").Exec() + _, _ = s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec() + engine.Migrate(&User{}) + + rows, _ := s.Raw("SELECT * FROM User").QueryRows() + columns, _ := rows.Columns() + if !reflect.DeepEqual(columns, []string{"Name", "Age"}) { + t.Fatal("Failed to migrate table User, got columns", columns) + } +} +``` + +- 首先假设原有的 `User` 包含两个字段 `Name` 和 `XXX`,在一次业务变更之后,`User` 结构体的字段变更为 `Name` 和 `Age`。 +- 即需要删除原有字段 `XXX`,并新增字段 `Age`。 +- 调用 `Migrate(&User{})` 之后,新表的结构为 `Name`,`Age` + +## 4 总结 + +GeeORM 的整体实现比较粗糙,比如数据库的迁移仅仅考虑了最简单的场景。实现的特性也比较少,比如结构体嵌套的场景,外键的场景,复合主键的场景都没有覆盖。ORM 框架的代码规模一般都比较大,如果想尽可能地逼近数据库,就需要大量的代码来实现相关的特性;二是数据库之间的差异也是比较大的,实现的功能越多,数据库之间的差异就会越突出,有时候为了达到较好的性能,就不得不为每个数据做特殊处理;还有些 ORM 框架同时支持关系型数据库和非关系型数据库,这就要求框架本身有更高层次的抽象,不能局限在 SQL 这一层。 + +GeeORM 仅 800 左右的代码是不可能做到这一点的。不过,GeeORM 的目的并不是实现一个可以在生产使用的 ORM 框架,而是希望尽可能多地介绍 ORM 框架大致的实现原理,例如 + +- 在框架中如何屏蔽不同数据库之间的差异; +- 数据库中表结构和编程语言中的对象是如何映射的; +- 如何优雅地模拟查询条件,链式调用是个不错的选择; +- 为什么 ORM 框架通常会提供 hooks 扩展的能力; +- 事务的原理和 ORM 框架如何集成对事务的支持; +- 一些难点问题,例如数据库迁移。 +- ... + +基于这几点,我觉得 GeeORM 的目的达到了。 + +## 附 推荐阅读 + +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [SQLite 常用命令速查表](https://geektutu.com/post/cheat-sheet-sqlite.html) +- [sqlite delete or add column - stackoverflow](https://stackoverflow.com/questions/8442147/how-to-delete-or-add-column-in-sqlite) diff --git a/gee-orm/doc/geeorm.md b/gee-orm/doc/geeorm.md new file mode 100644 index 0000000..6db5c98 --- /dev/null +++ b/gee-orm/doc/geeorm.md @@ -0,0 +1,141 @@ +--- +title: 7天用Go从零实现ORM框架GeeORM +date: 2020-03-01 01:00:00 +description: 7天用 Go语言/golang 从零实现 ORM 框架 GeeORM 教程(7 days implement golang object relational mapping framework from scratch tutorial),动手写 ORM 框架,参照 gorm, xorm 的实现。功能包括对象和表结构的相互映射,表的创建删除(table),记录的增删查改,事务支持(transaction),数据库迁移(migrate),钩子(hooks)等。 +tags: +- Go +nav: 从零实现 +categories: +- ORM框架 - GeeORM +keywords: +- Go语言 +- 从零实现ORM框架 +- 动手写ORM框架 +- database/sql +- sqlite3 +image: post/geeorm/geeorm_sm.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day0 序言 +--- + +![golang ORM framework](geeorm/geeorm.jpg) + +## 1 谈谈 ORM 框架 + +> 对象关系映射(Object Relational Mapping,简称ORM)是通过使用描述对象和数据库之间映射的元数据,将面向对象语言程序中的对象自动持久化到关系数据库中。 + +那对象和数据库是如何映射的呢? + +| 数据库 | 面向对象的编程语言 | +|:---:|:---:| +| 表(table) | 类(class/struct) | +| 记录(record, row) | 对象 (object) | +| 字段(field, column) | 对象属性(attribute) | + +举一个具体的例子,来理解 ORM。 + +```sql +CREATE TABLE `User` (`Name` text, `Age` integer); +INSERT INTO `User` (`Name`, `Age`) VALUES ("Tom", 18); +SELECT * FROM `User`; +``` + +第一条 SQL 语句,在数据库中创建了表 `User`,并且定义了 2 个字段 `Name` 和 `Age`;第二条 SQL 语句往表中添加了一条记录;最后一条语句返回表中的所有记录。 + +假如我们使用了 ORM 框架,可以这么写: + +```go +type User struct { + Name string + Age int +} + +orm.CreateTable(&User{}) +orm.Save(&User{"Tom", 18}) +var users []User +orm.Find(&users) +``` + +ORM 框架相当于对象和数据库中间的一个桥梁,借助 ORM 可以避免写繁琐的 SQL 语言,仅仅通过操作具体的对象,就能够完成对关系型数据库的操作。 + +那如何实现一个 ORM 框架呢? + +- `CreateTable` 方法需要从参数 `&User{}` 得到对应的结构体的名称 User 作为表名,成员变量 Name, Age 作为列名,同时还需要知道成员变量对应的类型。 +- `Save` 方法则需要知道每个成员变量的值。 +- `Find` 方法仅从传入的空切片 `&[]User`,得到对应的结构体名也就是表名 User,并从数据库中取到所有的记录,将其转换成 User 对象,添加到切片中。 + +如果这些方法只接受 User 类型的参数,那是很容易实现的。但是 ORM 框架是通用的,也就是说可以将任意合法的对象转换成数据库中的表和记录。例如: + +```go +type Account struct { + Username string + Password string +} + +orm.CreateTable(&Account{}) +``` + +这就面临了一个很重要的问题:如何根据任意类型的指针,得到其对应的结构体的信息。这涉及到了 Go 语言的反射机制(reflect),通过反射,可以获取到对象对应的结构体名称,成员变量、方法等信息,例如: + +```go +typ := reflect.Indirect(reflect.ValueOf(&Account{})).Type() +fmt.Println(typ.Name()) // Account + +for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + fmt.Println(field.Name) // Username Password +} +``` + +- `reflect.ValueOf()` 获取指针对应的反射值。 +- `reflect.Indirect()` 获取指针指向的对象的反射值。 +- `(reflect.Type).Name()` 返回类名(字符串)。 +- `(reflect.Type).Field(i)` 获取第 i 个成员变量。 + +除了对象和表结构/记录的映射以外,设计 ORM 框架还需要关注什么问题呢? + +1)MySQL,PostgreSQL,SQLite 等数据库的 SQL 语句是有区别的,ORM 框架如何在开发者不感知的情况下适配多种数据库? + +2)如何对象的字段发生改变,数据库表结构能够自动更新,即是否支持数据库自动迁移(migrate)? + +3)数据库支持的功能很多,例如事务(transaction),ORM 框架能实现哪些? + +4)... + +## 2 关于 GeeORM + +数据库的特性非常多,简单的增删查改使用 ORM 替代 SQL 语句是没有问题的,但是也有很多特性难以用 ORM 替代,比如复杂的多表关联查询,ORM 也可能支持,但是基于性能的考虑,开发者自己写 SQL 语句很可能更高效。 + +因此,设计实现一个 ORM 框架,就需要给功能特性排优先级了。 + +Go 语言中使用比较广泛 ORM 框架是 [gorm](https://github.com/jinzhu/gorm) 和 [xorm](https://github.com/go-xorm/xorm)。除了基础的功能,比如表的操作,记录的增删查改,gorm 还实现了关联关系(一对一、一对多等),回调插件等;xorm 实现了读写分离(支持配置多个数据库),数据同步,导入导出等。 + +gorm 正在彻底重构 v1 版本,短期内看不到发布 v2 的可能。相比于 gorm-v1,xorm 在设计上更清晰。GeeORM 的设计主要参考了 xorm,一些细节上的实现参考了 gorm。GeeORM 的目的主要是了解 ORM 框架设计的原理,具体实现上鲁棒性做得不够,一些复杂的特性,例如 gorm 的关联关系,xorm 的读写分离没有实现。目前支持的特性有: + +- 表的创建、删除、迁移。 +- 记录的增删查改,查询条件的链式操作。 +- 单一主键的设置(primary key)。 +- 钩子(在创建/更新/删除/查找之前或之后) +- 事务(transaction)。 +- ... + +`GeeORM` 分7天实现,每天完成的部分都是可以独立运行和测试的,就像搭积木一样,一个个独立的特性组合在一起就是最终的 ORM 框架。每天的代码在 100 行左右,同时配有较为完备的单元测试用例。 + +## 3 目录 + +- 第一天:[database/sql 基础](https://geektutu.com/post/geeorm-day1.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day1-database-sql) +- 第二天:[对象表结构映射](https://geektutu.com/post/geeorm-day2.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day2-reflect-schema) +- 第三天:[记录新增和查询](https://geektutu.com/post/geeorm-day3.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day3-save-query) +- 第四天:[链式操作与更新删除](https://geektutu.com/post/geeorm-day4.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day4-chain-operation) +- 第五天:[实现钩子(Hooks)](https://geektutu.com/post/geeorm-day5.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day5-hooks) +- 第六天:[支持事务(Transaction)](https://geektutu.com/post/geeorm-day6.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day6-transaction) +- 第七天:[数据库迁移(Migrate)](https://geektutu.com/post/geeorm-day7.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day7-migrate) + + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) +- [Go Reflect 提高反射性能](https://geektutu.com/post/hpg-reflect.html) +- [SQLite 常用命令速查表](https://geektutu.com/post/cheat-sheet-sqlite.html) \ No newline at end of file diff --git a/gee-orm/doc/geeorm/geeorm.jpg b/gee-orm/doc/geeorm/geeorm.jpg new file mode 100644 index 0000000..6b980ac Binary files /dev/null and b/gee-orm/doc/geeorm/geeorm.jpg differ diff --git a/gee-orm/doc/geeorm/geeorm_sm.jpg b/gee-orm/doc/geeorm/geeorm_sm.jpg new file mode 100644 index 0000000..960a985 Binary files /dev/null and b/gee-orm/doc/geeorm/geeorm_sm.jpg differ diff --git a/gee-orm/run_test.sh b/gee-orm/run_test.sh new file mode 100755 index 0000000..23475a2 --- /dev/null +++ b/gee-orm/run_test.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -eou pipefail + +cur=$PWD +for item in "$cur"/day*/ +do + echo "$item" + cd "$item" + go test geeorm/... 2>&1 | grep -v warning +done \ No newline at end of file diff --git a/gee-rpc/day1-codec/codec/codec.go b/gee-rpc/day1-codec/codec/codec.go new file mode 100644 index 0000000..20b6ba7 --- /dev/null +++ b/gee-rpc/day1-codec/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" // not implemented +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day1-codec/codec/gob.go b/gee-rpc/day1-codec/codec/gob.go new file mode 100644 index 0000000..d9ef2e6 --- /dev/null +++ b/gee-rpc/day1-codec/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err = c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return + } + if err = c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return + } + return +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day1-codec/go.mod b/gee-rpc/day1-codec/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day1-codec/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day1-codec/main/main.go b/gee-rpc/day1-codec/main/main.go new file mode 100644 index 0000000..2bc6a8a --- /dev/null +++ b/gee-rpc/day1-codec/main/main.go @@ -0,0 +1,49 @@ +package main + +import ( + "encoding/json" + "fmt" + "geerpc" + "geerpc/codec" + "log" + "net" + "time" +) + +func startServer(addr chan string) { + // pick a free port + l, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatal("network error:", err) + } + log.Println("start rpc server on", l.Addr()) + addr <- l.Addr().String() + geerpc.Accept(l) +} + +func main() { + log.SetFlags(0) + addr := make(chan string) + go startServer(addr) + + // in fact, following code is like a simple geerpc client + conn, _ := net.Dial("tcp", <-addr) + defer func() { _ = conn.Close() }() + + time.Sleep(time.Second) + // send options + _ = json.NewEncoder(conn).Encode(geerpc.DefaultOption) + cc := codec.NewGobCodec(conn) + // send request & receive response + for i := 0; i < 5; i++ { + h := &codec.Header{ + ServiceMethod: "Foo.Sum", + Seq: uint64(i), + } + _ = cc.Write(h, fmt.Sprintf("geerpc req %d", h.Seq)) + _ = cc.ReadHeader(h) + var reply string + _ = cc.ReadBody(&reply) + log.Println("reply:", reply) + } +} diff --git a/gee-rpc/day1-codec/server.go b/gee-rpc/day1-codec/server.go new file mode 100644 index 0000000..fb93e4f --- /dev/null +++ b/gee-rpc/day1-codec/server.go @@ -0,0 +1,149 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "reflect" + "sync" +) + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, +} + +// Server represents an RPC Server. +type Server struct{} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn)) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + // TODO: now we don't know the type of request argv + // day 1, just suppose it's string + req.argv = reflect.New(reflect.TypeOf("")) + if err = cc.ReadBody(req.argv.Interface()); err != nil { + log.Println("rpc server: read argv err:", err) + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) { + // TODO, should call registered rpc methods to get the right replyv + // day 1, just print argv and send a hello message + defer wg.Done() + log.Println(req.h, req.argv.Elem()) + req.replyv = reflect.ValueOf(fmt.Sprintf("geerpc resp %d", req.h.Seq)) + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } diff --git a/gee-rpc/day2-client/client.go b/gee-rpc/day2-client/client.go new file mode 100644 index 0000000..41cff0f --- /dev/null +++ b/gee-rpc/day2-client/client.go @@ -0,0 +1,246 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "sync" +) + +// Call represents an active RPC. +type Call struct { + Seq uint64 + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing { + return ErrShutdown + } + client.closing = true + return client.cc.Close() +} + +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing || client.shutdown { + return 0, ErrShutdown + } + call.Seq = client.seq + client.pending[call.Seq] = call + client.seq++ + return call.Seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + client.shutdown = true + for _, call := range client.pending { + call.Error = err + call.done() + } +} + +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(serviceMethod string, args, reply interface{}) error { + call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done + return call.Error +} + +func parseOptions(opts ...*Option) (*Option, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return DefaultOption, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = DefaultOption.MagicNumber + if opt.CodecType == "" { + opt.CodecType = DefaultOption.CodecType + } + return opt, nil +} + +func NewClient(conn net.Conn, opt *Option) (*Client, error) { + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err := fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return nil, err + } + // send options with server + if err := json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + _ = conn.Close() + return nil, err + } + return newClientCodec(f(conn), opt), nil +} + +func newClientCodec(cc codec.Codec, opt *Option) *Client { + client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call + cc: cc, + opt: opt, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if err != nil { + _ = conn.Close() + } + }() + return NewClient(conn, opt) +} diff --git a/gee-rpc/day2-client/codec/codec.go b/gee-rpc/day2-client/codec/codec.go new file mode 100644 index 0000000..20b6ba7 --- /dev/null +++ b/gee-rpc/day2-client/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" // not implemented +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day2-client/codec/gob.go b/gee-rpc/day2-client/codec/gob.go new file mode 100644 index 0000000..d9ef2e6 --- /dev/null +++ b/gee-rpc/day2-client/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err = c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return + } + if err = c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return + } + return +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day2-client/go.mod b/gee-rpc/day2-client/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day2-client/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day2-client/main/main.go b/gee-rpc/day2-client/main/main.go new file mode 100644 index 0000000..099eb50 --- /dev/null +++ b/gee-rpc/day2-client/main/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + "geerpc" + "log" + "net" + "sync" + "time" +) + +func startServer(addr chan string) { + // pick a free port + l, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatal("network error:", err) + } + log.Println("start rpc server on", l.Addr()) + addr <- l.Addr().String() + geerpc.Accept(l) +} + +func main() { + log.SetFlags(0) + addr := make(chan string) + go startServer(addr) + client, _ := geerpc.Dial("tcp", <-addr) + defer func() { _ = client.Close() }() + + time.Sleep(time.Second) + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + args := fmt.Sprintf("geerpc req %d", i) + var reply string + if err := client.Call("Foo.Sum", args, &reply); err != nil { + log.Fatal("call Foo.Sum error:", err) + } + log.Println("reply:", reply) + }(i) + } + wg.Wait() +} diff --git a/gee-rpc/day2-client/server.go b/gee-rpc/day2-client/server.go new file mode 100644 index 0000000..fb93e4f --- /dev/null +++ b/gee-rpc/day2-client/server.go @@ -0,0 +1,149 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "reflect" + "sync" +) + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, +} + +// Server represents an RPC Server. +type Server struct{} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn)) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + // TODO: now we don't know the type of request argv + // day 1, just suppose it's string + req.argv = reflect.New(reflect.TypeOf("")) + if err = cc.ReadBody(req.argv.Interface()); err != nil { + log.Println("rpc server: read argv err:", err) + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) { + // TODO, should call registered rpc methods to get the right replyv + // day 1, just print argv and send a hello message + defer wg.Done() + log.Println(req.h, req.argv.Elem()) + req.replyv = reflect.ValueOf(fmt.Sprintf("geerpc resp %d", req.h.Seq)) + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } diff --git a/gee-rpc/day3-service/client.go b/gee-rpc/day3-service/client.go new file mode 100644 index 0000000..5beef93 --- /dev/null +++ b/gee-rpc/day3-service/client.go @@ -0,0 +1,245 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "sync" +) + +// Call represents an active RPC. +type Call struct { + Seq uint64 + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing { + return ErrShutdown + } + client.closing = true + return client.cc.Close() +} + +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing || client.shutdown { + return 0, ErrShutdown + } + call.Seq = client.seq + client.pending[call.Seq] = call + client.seq++ + return call.Seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + client.shutdown = true + for _, call := range client.pending { + call.Error = err + call.done() + } +} + +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(serviceMethod string, args, reply interface{}) error { + call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done + return call.Error +} + +func parseOptions(opts ...*Option) (*Option, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return DefaultOption, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = DefaultOption.MagicNumber + if opt.CodecType == "" { + opt.CodecType = DefaultOption.CodecType + } + return opt, nil +} + +func NewClient(conn net.Conn, opt *Option) (*Client, error) { + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err := fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return nil, err + } + // send options with server + if err := json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + return nil, err + } + return newClientCodec(f(conn), opt), nil +} + +func newClientCodec(cc codec.Codec, opt *Option) *Client { + client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call + cc: cc, + opt: opt, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if err != nil { + _ = conn.Close() + } + }() + return NewClient(conn, opt) +} diff --git a/gee-rpc/day3-service/codec/codec.go b/gee-rpc/day3-service/codec/codec.go new file mode 100644 index 0000000..20b6ba7 --- /dev/null +++ b/gee-rpc/day3-service/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" // not implemented +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day3-service/codec/gob.go b/gee-rpc/day3-service/codec/gob.go new file mode 100644 index 0000000..d9ef2e6 --- /dev/null +++ b/gee-rpc/day3-service/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err = c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return + } + if err = c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return + } + return +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day3-service/go.mod b/gee-rpc/day3-service/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day3-service/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day3-service/main/main.go b/gee-rpc/day3-service/main/main.go new file mode 100644 index 0000000..89add53 --- /dev/null +++ b/gee-rpc/day3-service/main/main.go @@ -0,0 +1,58 @@ +package main + +import ( + "geerpc" + "log" + "net" + "sync" + "time" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func startServer(addr chan string) { + var foo Foo + if err := geerpc.Register(&foo); err != nil { + log.Fatal("register error:", err) + } + // pick a free port + l, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatal("network error:", err) + } + log.Println("start rpc server on", l.Addr()) + addr <- l.Addr().String() + geerpc.Accept(l) +} + +func main() { + log.SetFlags(0) + addr := make(chan string) + go startServer(addr) + client, _ := geerpc.Dial("tcp", <-addr) + defer func() { _ = client.Close() }() + + time.Sleep(time.Second) + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + args := &Args{Num1: i, Num2: i * i} + var reply int + if err := client.Call("Foo.Sum", args, &reply); err != nil { + log.Fatal("call Foo.Sum error:", err) + } + log.Printf("%d + %d = %d", args.Num1, args.Num2, reply) + }(i) + } + wg.Wait() +} diff --git a/gee-rpc/day3-service/server.go b/gee-rpc/day3-service/server.go new file mode 100644 index 0000000..4634394 --- /dev/null +++ b/gee-rpc/day3-service/server.go @@ -0,0 +1,203 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "geerpc/codec" + "io" + "log" + "net" + "reflect" + "strings" + "sync" +) + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, +} + +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn)) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request + mtype *methodType + svc *service +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { + dot := strings.LastIndex(serviceMethod, ".") + if dot < 0 { + err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) + return + } + serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc server: can't find service " + serviceName) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc server: can't find method " + methodName) + } + return +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + req.svc, req.mtype, err = server.findService(h.ServiceMethod) + if err != nil { + return req, err + } + req.argv = req.mtype.newArgv() + req.replyv = req.mtype.newReplyv() + + // make sure that argvi is a pointer, ReadBody need a pointer as parameter + argvi := req.argv.Interface() + if req.argv.Type().Kind() != reflect.Ptr { + argvi = req.argv.Addr().Interface() + } + if err = cc.ReadBody(argvi); err != nil { + log.Println("rpc server: read body err:", err) + return req, err + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) { + defer wg.Done() + err := req.svc.call(req.mtype, req.argv, req.replyv) + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +func (server *Server) Register(rcvr interface{}) error { + s := newService(rcvr) + if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { + return errors.New("rpc: service already defined: " + s.name) + } + return nil +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } diff --git a/gee-rpc/day3-service/service.go b/gee-rpc/day3-service/service.go new file mode 100644 index 0000000..306683c --- /dev/null +++ b/gee-rpc/day3-service/service.go @@ -0,0 +1,99 @@ +package geerpc + +import ( + "go/ast" + "log" + "reflect" + "sync/atomic" +) + +type methodType struct { + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint64 +} + +func (m *methodType) NumCalls() uint64 { + return atomic.LoadUint64(&m.numCalls) +} + +func (m *methodType) newArgv() reflect.Value { + var argv reflect.Value + // arg may be a pointer type, or a value type + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) + } else { + argv = reflect.New(m.ArgType).Elem() + } + return argv +} + +func (m *methodType) newReplyv() reflect.Value { + // reply must be a pointer type + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) + } + return replyv +} + +type service struct { + name string + typ reflect.Type + rcvr reflect.Value + method map[string]*methodType +} + +func newService(rcvr interface{}) *service { + s := new(service) + s.rcvr = reflect.ValueOf(rcvr) + s.name = reflect.Indirect(s.rcvr).Type().Name() + s.typ = reflect.TypeOf(rcvr) + if !ast.IsExported(s.name) { + log.Fatalf("rpc server: %s is not a valid service name", s.name) + } + s.registerMethods() + return s +} + +func (s *service) registerMethods() { + s.method = make(map[string]*methodType) + for i := 0; i < s.typ.NumMethod(); i++ { + method := s.typ.Method(i) + mType := method.Type + if mType.NumIn() != 3 || mType.NumOut() != 1 { + continue + } + if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + continue + } + argType, replyType := mType.In(1), mType.In(2) + if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { + continue + } + s.method[method.Name] = &methodType{ + method: method, + ArgType: argType, + ReplyType: replyType, + } + log.Printf("rpc server: register %s.%s\n", s.name, method.Name) + } +} + +func (s *service) call(m *methodType, argv, replyv reflect.Value) error { + atomic.AddUint64(&m.numCalls, 1) + f := m.method.Func + returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) + if errInter := returnValues[0].Interface(); errInter != nil { + return errInter.(error) + } + return nil +} + +func isExportedOrBuiltinType(t reflect.Type) bool { + return ast.IsExported(t.Name()) || t.PkgPath() == "" +} diff --git a/gee-rpc/day3-service/service_test.go b/gee-rpc/day3-service/service_test.go new file mode 100644 index 0000000..c8266df --- /dev/null +++ b/gee-rpc/day3-service/service_test.go @@ -0,0 +1,48 @@ +package geerpc + +import ( + "fmt" + "reflect" + "testing" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +// it's not a exported Method +func (f Foo) sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func TestNewService(t *testing.T) { + var foo Foo + s := newService(&foo) + _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method)) + mType := s.method["Sum"] + _assert(mType != nil, "wrong Method, Sum shouldn't nil") +} + +func TestMethodType_Call(t *testing.T) { + var foo Foo + s := newService(&foo) + mType := s.method["Sum"] + + argv := mType.newArgv() + replyv := mType.newReplyv() + argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3})) + err := s.call(mType, argv, replyv) + _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum") +} diff --git a/gee-rpc/day4-timeout/client.go b/gee-rpc/day4-timeout/client.go new file mode 100644 index 0000000..bc3602a --- /dev/null +++ b/gee-rpc/day4-timeout/client.go @@ -0,0 +1,278 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "sync" + "time" +) + +// Call represents an active RPC. +type Call struct { + Seq uint64 + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing { + return ErrShutdown + } + client.closing = true + return client.cc.Close() +} + +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing || client.shutdown { + return 0, ErrShutdown + } + call.Seq = client.seq + client.pending[call.Seq] = call + client.seq++ + return call.Seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + client.shutdown = true + for _, call := range client.pending { + call.Error = err + call.done() + } +} + +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + call := client.Go(serviceMethod, args, reply, make(chan *Call, 1)) + select { + case <-ctx.Done(): + client.removeCall(call.Seq) + return errors.New("rpc client: call failed: " + ctx.Err().Error()) + case call := <-call.Done: + return call.Error + } +} + +func parseOptions(opts ...*Option) (*Option, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return DefaultOption, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = DefaultOption.MagicNumber + if opt.CodecType == "" { + opt.CodecType = DefaultOption.CodecType + } + return opt, nil +} + +func NewClient(conn net.Conn, opt *Option) (client *Client, err error) { + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err = fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return + } + // send options with server + if err = json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + return + } + return newClientCodec(f(conn), opt), nil +} + +func newClientCodec(cc codec.Codec, opt *Option) *Client { + client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call + cc: cc, + opt: opt, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} + +type clientResult struct { + client *Client + err error +} + +type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error) + +func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + conn, err := net.DialTimeout(network, address, opt.ConnectTimeout) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if err != nil { + _ = conn.Close() + } + }() + ch := make(chan clientResult) + go func() { + client, err := f(conn, opt) + ch <- clientResult{client: client, err: err} + }() + if opt.ConnectTimeout == 0 { + result := <-ch + return result.client, result.err + } + select { + case <-time.After(opt.ConnectTimeout): + return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout) + case result := <-ch: + return result.client, result.err + } +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(NewClient, network, address, opts...) +} diff --git a/gee-rpc/day4-timeout/client_test.go b/gee-rpc/day4-timeout/client_test.go new file mode 100644 index 0000000..4488455 --- /dev/null +++ b/gee-rpc/day4-timeout/client_test.go @@ -0,0 +1,67 @@ +package geerpc + +import ( + "context" + "net" + "strings" + "testing" + "time" +) + +type Bar int + +func (b Bar) Timeout(argv int, reply *int) error { + time.Sleep(time.Second * 2) + return nil +} + +func startServer(addr chan string) { + var b Bar + _ = Register(&b) + // pick a free port + l, _ := net.Listen("tcp", ":0") + addr <- l.Addr().String() + Accept(l) +} + +func TestClient_dialTimeout(t *testing.T) { + t.Parallel() + l, _ := net.Listen("tcp", ":0") + + f := func(conn net.Conn, opt *Option) (client *Client, err error) { + _ = conn.Close() + time.Sleep(time.Second * 2) + return nil, nil + } + t.Run("timeout", func(t *testing.T) { + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error") + }) + t.Run("0", func(t *testing.T) { + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0}) + _assert(err == nil, "0 means no limit") + }) +} + +func TestClient_Call(t *testing.T) { + t.Parallel() + addrCh := make(chan string) + go startServer(addrCh) + addr := <-addrCh + time.Sleep(time.Second) + t.Run("client timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + var reply int + err := client.Call(ctx, "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error") + }) + t.Run("server handle timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr, &Option{ + HandleTimeout: time.Second, + }) + var reply int + err := client.Call(context.Background(), "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error") + }) +} diff --git a/gee-rpc/day4-timeout/codec/codec.go b/gee-rpc/day4-timeout/codec/codec.go new file mode 100644 index 0000000..20b6ba7 --- /dev/null +++ b/gee-rpc/day4-timeout/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" // not implemented +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day4-timeout/codec/gob.go b/gee-rpc/day4-timeout/codec/gob.go new file mode 100644 index 0000000..d9ef2e6 --- /dev/null +++ b/gee-rpc/day4-timeout/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err = c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return + } + if err = c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return + } + return +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day4-timeout/go.mod b/gee-rpc/day4-timeout/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day4-timeout/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day4-timeout/main/main.go b/gee-rpc/day4-timeout/main/main.go new file mode 100644 index 0000000..efcf75c --- /dev/null +++ b/gee-rpc/day4-timeout/main/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "context" + "geerpc" + "log" + "net" + "sync" + "time" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func startServer(addr chan string) { + var foo Foo + if err := geerpc.Register(&foo); err != nil { + log.Fatal("register error:", err) + } + // pick a free port + l, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatal("network error:", err) + } + log.Println("start rpc server on", l.Addr()) + addr <- l.Addr().String() + geerpc.Accept(l) +} + +func main() { + log.SetFlags(0) + addr := make(chan string) + go startServer(addr) + client, _ := geerpc.Dial("tcp", <-addr) + defer func() { _ = client.Close() }() + + time.Sleep(time.Second) + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + args := &Args{Num1: i, Num2: i * i} + var reply int + if err := client.Call(context.Background(), "Foo.Sum", args, &reply); err != nil { + log.Fatal("call Foo.Sum error:", err) + } + log.Printf("%d + %d = %d", args.Num1, args.Num2, reply) + }(i) + } + wg.Wait() +} diff --git a/gee-rpc/day4-timeout/server.go b/gee-rpc/day4-timeout/server.go new file mode 100644 index 0000000..a049914 --- /dev/null +++ b/gee-rpc/day4-timeout/server.go @@ -0,0 +1,228 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "reflect" + "strings" + "sync" + "time" +) + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body + ConnectTimeout time.Duration // 0 means no limit + HandleTimeout time.Duration +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, + ConnectTimeout: time.Second * 10, +} + +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn), &opt) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec, opt *Option) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request + mtype *methodType + svc *service +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { + dot := strings.LastIndex(serviceMethod, ".") + if dot < 0 { + err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) + return + } + serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc server: can't find service " + serviceName) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc server: can't find method " + methodName) + } + return +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + req.svc, req.mtype, err = server.findService(h.ServiceMethod) + if err != nil { + return req, err + } + req.argv = req.mtype.newArgv() + req.replyv = req.mtype.newReplyv() + + // make sure that argvi is a pointer, ReadBody need a pointer as parameter + argvi := req.argv.Interface() + if req.argv.Type().Kind() != reflect.Ptr { + argvi = req.argv.Addr().Interface() + } + if err = cc.ReadBody(argvi); err != nil { + log.Println("rpc server: read body err:", err) + return req, err + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { + defer wg.Done() + called := make(chan struct{}) + sent := make(chan struct{}) + go func() { + err := req.svc.call(req.mtype, req.argv, req.replyv) + called <- struct{}{} + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + sent <- struct{}{} + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) + sent <- struct{}{} + }() + + if timeout == 0 { + <-called + <-sent + return + } + select { + case <-time.After(timeout): + req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout) + server.sendResponse(cc, req.h, invalidRequest, sending) + case <-called: + <-sent + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +func (server *Server) Register(rcvr interface{}) error { + s := newService(rcvr) + if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { + return errors.New("rpc: service already defined: " + s.name) + } + return nil +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } diff --git a/gee-rpc/day4-timeout/service.go b/gee-rpc/day4-timeout/service.go new file mode 100644 index 0000000..306683c --- /dev/null +++ b/gee-rpc/day4-timeout/service.go @@ -0,0 +1,99 @@ +package geerpc + +import ( + "go/ast" + "log" + "reflect" + "sync/atomic" +) + +type methodType struct { + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint64 +} + +func (m *methodType) NumCalls() uint64 { + return atomic.LoadUint64(&m.numCalls) +} + +func (m *methodType) newArgv() reflect.Value { + var argv reflect.Value + // arg may be a pointer type, or a value type + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) + } else { + argv = reflect.New(m.ArgType).Elem() + } + return argv +} + +func (m *methodType) newReplyv() reflect.Value { + // reply must be a pointer type + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) + } + return replyv +} + +type service struct { + name string + typ reflect.Type + rcvr reflect.Value + method map[string]*methodType +} + +func newService(rcvr interface{}) *service { + s := new(service) + s.rcvr = reflect.ValueOf(rcvr) + s.name = reflect.Indirect(s.rcvr).Type().Name() + s.typ = reflect.TypeOf(rcvr) + if !ast.IsExported(s.name) { + log.Fatalf("rpc server: %s is not a valid service name", s.name) + } + s.registerMethods() + return s +} + +func (s *service) registerMethods() { + s.method = make(map[string]*methodType) + for i := 0; i < s.typ.NumMethod(); i++ { + method := s.typ.Method(i) + mType := method.Type + if mType.NumIn() != 3 || mType.NumOut() != 1 { + continue + } + if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + continue + } + argType, replyType := mType.In(1), mType.In(2) + if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { + continue + } + s.method[method.Name] = &methodType{ + method: method, + ArgType: argType, + ReplyType: replyType, + } + log.Printf("rpc server: register %s.%s\n", s.name, method.Name) + } +} + +func (s *service) call(m *methodType, argv, replyv reflect.Value) error { + atomic.AddUint64(&m.numCalls, 1) + f := m.method.Func + returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) + if errInter := returnValues[0].Interface(); errInter != nil { + return errInter.(error) + } + return nil +} + +func isExportedOrBuiltinType(t reflect.Type) bool { + return ast.IsExported(t.Name()) || t.PkgPath() == "" +} diff --git a/gee-rpc/day4-timeout/service_test.go b/gee-rpc/day4-timeout/service_test.go new file mode 100644 index 0000000..c8266df --- /dev/null +++ b/gee-rpc/day4-timeout/service_test.go @@ -0,0 +1,48 @@ +package geerpc + +import ( + "fmt" + "reflect" + "testing" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +// it's not a exported Method +func (f Foo) sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func TestNewService(t *testing.T) { + var foo Foo + s := newService(&foo) + _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method)) + mType := s.method["Sum"] + _assert(mType != nil, "wrong Method, Sum shouldn't nil") +} + +func TestMethodType_Call(t *testing.T) { + var foo Foo + s := newService(&foo) + mType := s.method["Sum"] + + argv := mType.newArgv() + replyv := mType.newReplyv() + argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3})) + err := s.call(mType, argv, replyv) + _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum") +} diff --git a/gee-rpc/day5-http-debug/client.go b/gee-rpc/day5-http-debug/client.go new file mode 100644 index 0000000..1a62b1e --- /dev/null +++ b/gee-rpc/day5-http-debug/client.go @@ -0,0 +1,323 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "strings" + "sync" + "time" +) + +// Call represents an active RPC. +type Call struct { + Seq uint64 + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing { + return ErrShutdown + } + client.closing = true + return client.cc.Close() +} + +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing || client.shutdown { + return 0, ErrShutdown + } + call.Seq = client.seq + client.pending[call.Seq] = call + client.seq++ + return call.Seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + client.shutdown = true + for _, call := range client.pending { + call.Error = err + call.done() + } +} + +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + call := client.Go(serviceMethod, args, reply, make(chan *Call, 1)) + select { + case <-ctx.Done(): + client.removeCall(call.Seq) + return errors.New("rpc client: call failed: " + ctx.Err().Error()) + case call := <-call.Done: + return call.Error + } +} + +func parseOptions(opts ...*Option) (*Option, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return DefaultOption, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = DefaultOption.MagicNumber + if opt.CodecType == "" { + opt.CodecType = DefaultOption.CodecType + } + return opt, nil +} + +func NewClient(conn net.Conn, opt *Option) (*Client, error) { + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err := fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return nil, err + } + // send options with server + if err := json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + _ = conn.Close() + return nil, err + } + return newClientCodec(f(conn), opt), nil +} + +func newClientCodec(cc codec.Codec, opt *Option) *Client { + client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call + cc: cc, + opt: opt, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} + +type clientResult struct { + client *Client + err error +} + +type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error) + +func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + conn, err := net.DialTimeout(network, address, opt.ConnectTimeout) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if err != nil { + _ = conn.Close() + } + }() + ch := make(chan clientResult) + go func() { + client, err := f(conn, opt) + ch <- clientResult{client: client, err: err} + }() + if opt.ConnectTimeout == 0 { + result := <-ch + return result.client, result.err + } + select { + case <-time.After(opt.ConnectTimeout): + return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout) + case result := <-ch: + return result.client, result.err + } +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(NewClient, network, address, opts...) +} + +// NewHTTPClient new a Client instance via HTTP as transport protocol +func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) { + _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)) + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err == nil && resp.Status == connected { + return NewClient(conn, opt) + } + if err == nil { + err = errors.New("unexpected HTTP response: " + resp.Status) + } + return nil, err +} + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(NewHTTPClient, network, address, opts...) +} + +// XDial calls different functions to connect to a RPC server +// according the first parameter rpcAddr. +// rpcAddr is a general format (protocol@addr) to represent a rpc server +// eg, http@10.0.0.1:7001, tcp@10.0.0.1:9999, unix@/tmp/geerpc.sock +func XDial(rpcAddr string, opts ...*Option) (*Client, error) { + parts := strings.Split(rpcAddr, "@") + if len(parts) != 2 { + return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr) + } + protocol, addr := parts[0], parts[1] + switch protocol { + case "http": + return DialHTTP("tcp", addr, opts...) + default: + // tcp, unix or other transport protocol + return Dial(protocol, addr, opts...) + } +} diff --git a/gee-rpc/day5-http-debug/client_test.go b/gee-rpc/day5-http-debug/client_test.go new file mode 100644 index 0000000..3b13cb0 --- /dev/null +++ b/gee-rpc/day5-http-debug/client_test.go @@ -0,0 +1,88 @@ +package geerpc + +import ( + "context" + "net" + "os" + "runtime" + "strings" + "testing" + "time" +) + +type Bar int + +func (b Bar) Timeout(argv int, reply *int) error { + time.Sleep(time.Second * 2) + return nil +} + +func startServer(addr chan string) { + var b Bar + _ = Register(&b) + // pick a free port + l, _ := net.Listen("tcp", ":0") + addr <- l.Addr().String() + Accept(l) +} + +func TestClient_dialTimeout(t *testing.T) { + t.Parallel() + l, _ := net.Listen("tcp", ":0") + + f := func(conn net.Conn, opt *Option) (client *Client, err error) { + _ = conn.Close() + time.Sleep(time.Second * 2) + return nil, nil + } + t.Run("timeout", func(t *testing.T) { + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error") + }) + t.Run("0", func(t *testing.T) { + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0}) + _assert(err == nil, "0 means no limit") + }) +} + +func TestClient_Call(t *testing.T) { + t.Parallel() + addrCh := make(chan string) + go startServer(addrCh) + addr := <-addrCh + time.Sleep(time.Second) + t.Run("client timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + var reply int + err := client.Call(ctx, "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error") + }) + t.Run("server handle timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr, &Option{ + HandleTimeout: time.Second, + }) + var reply int + err := client.Call(context.Background(), "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error") + }) +} + +func TestXDial(t *testing.T) { + if runtime.GOOS == "linux" { + ch := make(chan struct{}) + addr := "/tmp/geerpc.sock" + go func() { + _ = os.Remove(addr) + l, err := net.Listen("unix", addr) + if err != nil { + t.Fatal("failed to listen unix socket") + } + ch <- struct{}{} + Accept(l) + }() + <-ch + _, err := XDial("unix@" + addr) + _assert(err == nil, "failed to connect unix socket") + } +} diff --git a/gee-rpc/day5-http-debug/codec/codec.go b/gee-rpc/day5-http-debug/codec/codec.go new file mode 100644 index 0000000..20b6ba7 --- /dev/null +++ b/gee-rpc/day5-http-debug/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" // not implemented +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day5-http-debug/codec/gob.go b/gee-rpc/day5-http-debug/codec/gob.go new file mode 100644 index 0000000..d9ef2e6 --- /dev/null +++ b/gee-rpc/day5-http-debug/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err = c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return + } + if err = c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return + } + return +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day5-http-debug/debug.go b/gee-rpc/day5-http-debug/debug.go new file mode 100644 index 0000000..ece1ffd --- /dev/null +++ b/gee-rpc/day5-http-debug/debug.go @@ -0,0 +1,60 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "fmt" + "html/template" + "net/http" +) + +const debugText = ` + + GeeRPC Services + {{range .}} +
+ Service {{.Name}} +
+
+ + {{range $name, $mtype := .Method}} + + + + + {{end}} +
MethodCalls
{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error{{$mtype.NumCalls}}
+ {{end}} + + ` + +var debug = template.Must(template.New("RPC debug").Parse(debugText)) + +type debugHTTP struct { + *Server +} + +type debugService struct { + Name string + Method map[string]*methodType +} + +// Runs at /debug/geerpc +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Build a sorted version of the data. + var services []debugService + server.serviceMap.Range(func(namei, svci interface{}) bool { + svc := svci.(*service) + services = append(services, debugService{ + Name: namei.(string), + Method: svc.method, + }) + return true + }) + err := debug.Execute(w, services) + if err != nil { + _, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error()) + } +} diff --git a/gee-rpc/day5-http-debug/go.mod b/gee-rpc/day5-http-debug/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day5-http-debug/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day5-http-debug/main/main.go b/gee-rpc/day5-http-debug/main/main.go new file mode 100644 index 0000000..cfd1b88 --- /dev/null +++ b/gee-rpc/day5-http-debug/main/main.go @@ -0,0 +1,58 @@ +package main + +import ( + "context" + "geerpc" + "log" + "net" + "net/http" + "sync" + "time" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func startServer(addrCh chan string) { + var foo Foo + l, _ := net.Listen("tcp", ":9999") + _ = geerpc.Register(&foo) + geerpc.HandleHTTP() + addrCh <- l.Addr().String() + _ = http.Serve(l, nil) +} + +func call(addrCh chan string) { + client, _ := geerpc.DialHTTP("tcp", <-addrCh) + defer func() { _ = client.Close() }() + + time.Sleep(time.Second) + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + args := &Args{Num1: i, Num2: i * i} + var reply int + if err := client.Call(context.Background(), "Foo.Sum", args, &reply); err != nil { + log.Fatal("call Foo.Sum error:", err) + } + log.Printf("%d + %d = %d", args.Num1, args.Num2, reply) + }(i) + } + wg.Wait() +} + +func main() { + log.SetFlags(0) + ch := make(chan string) + go call(ch) + startServer(ch) +} diff --git a/gee-rpc/day5-http-debug/server.go b/gee-rpc/day5-http-debug/server.go new file mode 100644 index 0000000..38fad20 --- /dev/null +++ b/gee-rpc/day5-http-debug/server.go @@ -0,0 +1,266 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "reflect" + "strings" + "sync" + "time" +) + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body + ConnectTimeout time.Duration // 0 means no limit + HandleTimeout time.Duration +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, + ConnectTimeout: time.Second * 10, +} + +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn), &opt) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec, opt *Option) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request + mtype *methodType + svc *service +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { + dot := strings.LastIndex(serviceMethod, ".") + if dot < 0 { + err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) + return + } + serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc server: can't find service " + serviceName) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc server: can't find method " + methodName) + } + return +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + req.svc, req.mtype, err = server.findService(h.ServiceMethod) + if err != nil { + return req, err + } + req.argv = req.mtype.newArgv() + req.replyv = req.mtype.newReplyv() + + // make sure that argvi is a pointer, ReadBody need a pointer as parameter + argvi := req.argv.Interface() + if req.argv.Type().Kind() != reflect.Ptr { + argvi = req.argv.Addr().Interface() + } + if err = cc.ReadBody(argvi); err != nil { + log.Println("rpc server: read body err:", err) + return req, err + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { + defer wg.Done() + called := make(chan struct{}) + sent := make(chan struct{}) + go func() { + err := req.svc.call(req.mtype, req.argv, req.replyv) + called <- struct{}{} + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + sent <- struct{}{} + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) + sent <- struct{}{} + }() + + if timeout == 0 { + <-called + <-sent + return + } + select { + case <-time.After(timeout): + req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout) + server.sendResponse(cc, req.h, invalidRequest, sending) + case <-called: + <-sent + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +func (server *Server) Register(rcvr interface{}) error { + s := newService(rcvr) + if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { + return errors.New("rpc: service already defined: " + s.name) + } + return nil +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } + +const ( + connected = "200 Connected to Gee RPC" + defaultRPCPath = "/_geeprc_" + defaultDebugPath = "/debug/geerpc" +) + +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = io.WriteString(w, "405 must CONNECT\n") + return + } + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) + return + } + _, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath, +// and a debugging handler on debugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP() { + http.Handle(defaultRPCPath, server) + http.Handle(defaultDebugPath, debugHTTP{server}) + log.Println("rpc server debug path:", defaultDebugPath) +} + +// HandleHTTP is a convenient approach for default server to register HTTP handlers +func HandleHTTP() { + DefaultServer.HandleHTTP() +} diff --git a/gee-rpc/day5-http-debug/service.go b/gee-rpc/day5-http-debug/service.go new file mode 100644 index 0000000..306683c --- /dev/null +++ b/gee-rpc/day5-http-debug/service.go @@ -0,0 +1,99 @@ +package geerpc + +import ( + "go/ast" + "log" + "reflect" + "sync/atomic" +) + +type methodType struct { + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint64 +} + +func (m *methodType) NumCalls() uint64 { + return atomic.LoadUint64(&m.numCalls) +} + +func (m *methodType) newArgv() reflect.Value { + var argv reflect.Value + // arg may be a pointer type, or a value type + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) + } else { + argv = reflect.New(m.ArgType).Elem() + } + return argv +} + +func (m *methodType) newReplyv() reflect.Value { + // reply must be a pointer type + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) + } + return replyv +} + +type service struct { + name string + typ reflect.Type + rcvr reflect.Value + method map[string]*methodType +} + +func newService(rcvr interface{}) *service { + s := new(service) + s.rcvr = reflect.ValueOf(rcvr) + s.name = reflect.Indirect(s.rcvr).Type().Name() + s.typ = reflect.TypeOf(rcvr) + if !ast.IsExported(s.name) { + log.Fatalf("rpc server: %s is not a valid service name", s.name) + } + s.registerMethods() + return s +} + +func (s *service) registerMethods() { + s.method = make(map[string]*methodType) + for i := 0; i < s.typ.NumMethod(); i++ { + method := s.typ.Method(i) + mType := method.Type + if mType.NumIn() != 3 || mType.NumOut() != 1 { + continue + } + if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + continue + } + argType, replyType := mType.In(1), mType.In(2) + if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { + continue + } + s.method[method.Name] = &methodType{ + method: method, + ArgType: argType, + ReplyType: replyType, + } + log.Printf("rpc server: register %s.%s\n", s.name, method.Name) + } +} + +func (s *service) call(m *methodType, argv, replyv reflect.Value) error { + atomic.AddUint64(&m.numCalls, 1) + f := m.method.Func + returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) + if errInter := returnValues[0].Interface(); errInter != nil { + return errInter.(error) + } + return nil +} + +func isExportedOrBuiltinType(t reflect.Type) bool { + return ast.IsExported(t.Name()) || t.PkgPath() == "" +} diff --git a/gee-rpc/day5-http-debug/service_test.go b/gee-rpc/day5-http-debug/service_test.go new file mode 100644 index 0000000..c8266df --- /dev/null +++ b/gee-rpc/day5-http-debug/service_test.go @@ -0,0 +1,48 @@ +package geerpc + +import ( + "fmt" + "reflect" + "testing" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +// it's not a exported Method +func (f Foo) sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func TestNewService(t *testing.T) { + var foo Foo + s := newService(&foo) + _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method)) + mType := s.method["Sum"] + _assert(mType != nil, "wrong Method, Sum shouldn't nil") +} + +func TestMethodType_Call(t *testing.T) { + var foo Foo + s := newService(&foo) + mType := s.method["Sum"] + + argv := mType.newArgv() + replyv := mType.newReplyv() + argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3})) + err := s.call(mType, argv, replyv) + _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum") +} diff --git a/gee-rpc/day6-load-balance/client.go b/gee-rpc/day6-load-balance/client.go new file mode 100644 index 0000000..1a62b1e --- /dev/null +++ b/gee-rpc/day6-load-balance/client.go @@ -0,0 +1,323 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "strings" + "sync" + "time" +) + +// Call represents an active RPC. +type Call struct { + Seq uint64 + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing { + return ErrShutdown + } + client.closing = true + return client.cc.Close() +} + +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing || client.shutdown { + return 0, ErrShutdown + } + call.Seq = client.seq + client.pending[call.Seq] = call + client.seq++ + return call.Seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + client.shutdown = true + for _, call := range client.pending { + call.Error = err + call.done() + } +} + +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + call := client.Go(serviceMethod, args, reply, make(chan *Call, 1)) + select { + case <-ctx.Done(): + client.removeCall(call.Seq) + return errors.New("rpc client: call failed: " + ctx.Err().Error()) + case call := <-call.Done: + return call.Error + } +} + +func parseOptions(opts ...*Option) (*Option, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return DefaultOption, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = DefaultOption.MagicNumber + if opt.CodecType == "" { + opt.CodecType = DefaultOption.CodecType + } + return opt, nil +} + +func NewClient(conn net.Conn, opt *Option) (*Client, error) { + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err := fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return nil, err + } + // send options with server + if err := json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + _ = conn.Close() + return nil, err + } + return newClientCodec(f(conn), opt), nil +} + +func newClientCodec(cc codec.Codec, opt *Option) *Client { + client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call + cc: cc, + opt: opt, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} + +type clientResult struct { + client *Client + err error +} + +type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error) + +func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + conn, err := net.DialTimeout(network, address, opt.ConnectTimeout) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if err != nil { + _ = conn.Close() + } + }() + ch := make(chan clientResult) + go func() { + client, err := f(conn, opt) + ch <- clientResult{client: client, err: err} + }() + if opt.ConnectTimeout == 0 { + result := <-ch + return result.client, result.err + } + select { + case <-time.After(opt.ConnectTimeout): + return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout) + case result := <-ch: + return result.client, result.err + } +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(NewClient, network, address, opts...) +} + +// NewHTTPClient new a Client instance via HTTP as transport protocol +func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) { + _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)) + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err == nil && resp.Status == connected { + return NewClient(conn, opt) + } + if err == nil { + err = errors.New("unexpected HTTP response: " + resp.Status) + } + return nil, err +} + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(NewHTTPClient, network, address, opts...) +} + +// XDial calls different functions to connect to a RPC server +// according the first parameter rpcAddr. +// rpcAddr is a general format (protocol@addr) to represent a rpc server +// eg, http@10.0.0.1:7001, tcp@10.0.0.1:9999, unix@/tmp/geerpc.sock +func XDial(rpcAddr string, opts ...*Option) (*Client, error) { + parts := strings.Split(rpcAddr, "@") + if len(parts) != 2 { + return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr) + } + protocol, addr := parts[0], parts[1] + switch protocol { + case "http": + return DialHTTP("tcp", addr, opts...) + default: + // tcp, unix or other transport protocol + return Dial(protocol, addr, opts...) + } +} diff --git a/gee-rpc/day6-load-balance/client_test.go b/gee-rpc/day6-load-balance/client_test.go new file mode 100644 index 0000000..3b13cb0 --- /dev/null +++ b/gee-rpc/day6-load-balance/client_test.go @@ -0,0 +1,88 @@ +package geerpc + +import ( + "context" + "net" + "os" + "runtime" + "strings" + "testing" + "time" +) + +type Bar int + +func (b Bar) Timeout(argv int, reply *int) error { + time.Sleep(time.Second * 2) + return nil +} + +func startServer(addr chan string) { + var b Bar + _ = Register(&b) + // pick a free port + l, _ := net.Listen("tcp", ":0") + addr <- l.Addr().String() + Accept(l) +} + +func TestClient_dialTimeout(t *testing.T) { + t.Parallel() + l, _ := net.Listen("tcp", ":0") + + f := func(conn net.Conn, opt *Option) (client *Client, err error) { + _ = conn.Close() + time.Sleep(time.Second * 2) + return nil, nil + } + t.Run("timeout", func(t *testing.T) { + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error") + }) + t.Run("0", func(t *testing.T) { + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0}) + _assert(err == nil, "0 means no limit") + }) +} + +func TestClient_Call(t *testing.T) { + t.Parallel() + addrCh := make(chan string) + go startServer(addrCh) + addr := <-addrCh + time.Sleep(time.Second) + t.Run("client timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + var reply int + err := client.Call(ctx, "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error") + }) + t.Run("server handle timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr, &Option{ + HandleTimeout: time.Second, + }) + var reply int + err := client.Call(context.Background(), "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error") + }) +} + +func TestXDial(t *testing.T) { + if runtime.GOOS == "linux" { + ch := make(chan struct{}) + addr := "/tmp/geerpc.sock" + go func() { + _ = os.Remove(addr) + l, err := net.Listen("unix", addr) + if err != nil { + t.Fatal("failed to listen unix socket") + } + ch <- struct{}{} + Accept(l) + }() + <-ch + _, err := XDial("unix@" + addr) + _assert(err == nil, "failed to connect unix socket") + } +} diff --git a/gee-rpc/day6-load-balance/codec/codec.go b/gee-rpc/day6-load-balance/codec/codec.go new file mode 100644 index 0000000..20b6ba7 --- /dev/null +++ b/gee-rpc/day6-load-balance/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" // not implemented +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day6-load-balance/codec/gob.go b/gee-rpc/day6-load-balance/codec/gob.go new file mode 100644 index 0000000..d9ef2e6 --- /dev/null +++ b/gee-rpc/day6-load-balance/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err = c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return + } + if err = c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return + } + return +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day6-load-balance/debug.go b/gee-rpc/day6-load-balance/debug.go new file mode 100644 index 0000000..ece1ffd --- /dev/null +++ b/gee-rpc/day6-load-balance/debug.go @@ -0,0 +1,60 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "fmt" + "html/template" + "net/http" +) + +const debugText = ` + + GeeRPC Services + {{range .}} +
+ Service {{.Name}} +
+ + + {{range $name, $mtype := .Method}} + + + + + {{end}} +
MethodCalls
{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error{{$mtype.NumCalls}}
+ {{end}} + + ` + +var debug = template.Must(template.New("RPC debug").Parse(debugText)) + +type debugHTTP struct { + *Server +} + +type debugService struct { + Name string + Method map[string]*methodType +} + +// Runs at /debug/geerpc +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Build a sorted version of the data. + var services []debugService + server.serviceMap.Range(func(namei, svci interface{}) bool { + svc := svci.(*service) + services = append(services, debugService{ + Name: namei.(string), + Method: svc.method, + }) + return true + }) + err := debug.Execute(w, services) + if err != nil { + _, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error()) + } +} diff --git a/gee-rpc/day6-load-balance/go.mod b/gee-rpc/day6-load-balance/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day6-load-balance/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day6-load-balance/main/main.go b/gee-rpc/day6-load-balance/main/main.go new file mode 100644 index 0000000..2a476b7 --- /dev/null +++ b/gee-rpc/day6-load-balance/main/main.go @@ -0,0 +1,101 @@ +package main + +import ( + "context" + "geerpc" + "geerpc/xclient" + "log" + "net" + "sync" + "time" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func (f Foo) Sleep(args Args, reply *int) error { + time.Sleep(time.Second * time.Duration(args.Num1)) + *reply = args.Num1 + args.Num2 + return nil +} + +func startServer(addrCh chan string) { + var foo Foo + l, _ := net.Listen("tcp", ":0") + server := geerpc.NewServer() + _ = server.Register(&foo) + addrCh <- l.Addr().String() + server.Accept(l) +} + +func foo(xc *xclient.XClient, ctx context.Context, typ, serviceMethod string, args *Args) { + var reply int + var err error + switch typ { + case "call": + err = xc.Call(ctx, serviceMethod, args, &reply) + case "broadcast": + err = xc.Broadcast(ctx, serviceMethod, args, &reply) + } + if err != nil { + log.Printf("%s %s error: %v", typ, serviceMethod, err) + } else { + log.Printf("%s %s success: %d + %d = %d", typ, serviceMethod, args.Num1, args.Num2, reply) + } +} + +func call(addr1, addr2 string) { + d := xclient.NewMultiServerDiscovery([]string{"tcp@" + addr1, "tcp@" + addr2}) + xc := xclient.NewXClient(d, xclient.RandomSelect, nil) + defer func() { _ = xc.Close() }() + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + foo(xc, context.Background(), "call", "Foo.Sum", &Args{Num1: i, Num2: i * i}) + }(i) + } + wg.Wait() +} + +func broadcast(addr1, addr2 string) { + d := xclient.NewMultiServerDiscovery([]string{"tcp@" + addr1, "tcp@" + addr2}) + xc := xclient.NewXClient(d, xclient.RandomSelect, nil) + defer func() { _ = xc.Close() }() + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + foo(xc, context.Background(), "broadcast", "Foo.Sum", &Args{Num1: i, Num2: i * i}) + // expect 2 - 5 timeout + ctx, _ := context.WithTimeout(context.Background(), time.Second*2) + foo(xc, ctx, "broadcast", "Foo.Sleep", &Args{Num1: i, Num2: i * i}) + }(i) + } + wg.Wait() +} + +func main() { + log.SetFlags(0) + ch1 := make(chan string) + ch2 := make(chan string) + // start two servers + go startServer(ch1) + go startServer(ch2) + + addr1 := <-ch1 + addr2 := <-ch2 + + time.Sleep(time.Second) + call(addr1, addr2) + broadcast(addr1, addr2) +} diff --git a/gee-rpc/day6-load-balance/server.go b/gee-rpc/day6-load-balance/server.go new file mode 100644 index 0000000..38fad20 --- /dev/null +++ b/gee-rpc/day6-load-balance/server.go @@ -0,0 +1,266 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "reflect" + "strings" + "sync" + "time" +) + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body + ConnectTimeout time.Duration // 0 means no limit + HandleTimeout time.Duration +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, + ConnectTimeout: time.Second * 10, +} + +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn), &opt) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec, opt *Option) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request + mtype *methodType + svc *service +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { + dot := strings.LastIndex(serviceMethod, ".") + if dot < 0 { + err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) + return + } + serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc server: can't find service " + serviceName) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc server: can't find method " + methodName) + } + return +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + req.svc, req.mtype, err = server.findService(h.ServiceMethod) + if err != nil { + return req, err + } + req.argv = req.mtype.newArgv() + req.replyv = req.mtype.newReplyv() + + // make sure that argvi is a pointer, ReadBody need a pointer as parameter + argvi := req.argv.Interface() + if req.argv.Type().Kind() != reflect.Ptr { + argvi = req.argv.Addr().Interface() + } + if err = cc.ReadBody(argvi); err != nil { + log.Println("rpc server: read body err:", err) + return req, err + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { + defer wg.Done() + called := make(chan struct{}) + sent := make(chan struct{}) + go func() { + err := req.svc.call(req.mtype, req.argv, req.replyv) + called <- struct{}{} + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + sent <- struct{}{} + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) + sent <- struct{}{} + }() + + if timeout == 0 { + <-called + <-sent + return + } + select { + case <-time.After(timeout): + req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout) + server.sendResponse(cc, req.h, invalidRequest, sending) + case <-called: + <-sent + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +func (server *Server) Register(rcvr interface{}) error { + s := newService(rcvr) + if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { + return errors.New("rpc: service already defined: " + s.name) + } + return nil +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } + +const ( + connected = "200 Connected to Gee RPC" + defaultRPCPath = "/_geeprc_" + defaultDebugPath = "/debug/geerpc" +) + +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = io.WriteString(w, "405 must CONNECT\n") + return + } + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) + return + } + _, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath, +// and a debugging handler on debugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP() { + http.Handle(defaultRPCPath, server) + http.Handle(defaultDebugPath, debugHTTP{server}) + log.Println("rpc server debug path:", defaultDebugPath) +} + +// HandleHTTP is a convenient approach for default server to register HTTP handlers +func HandleHTTP() { + DefaultServer.HandleHTTP() +} diff --git a/gee-rpc/day6-load-balance/service.go b/gee-rpc/day6-load-balance/service.go new file mode 100644 index 0000000..306683c --- /dev/null +++ b/gee-rpc/day6-load-balance/service.go @@ -0,0 +1,99 @@ +package geerpc + +import ( + "go/ast" + "log" + "reflect" + "sync/atomic" +) + +type methodType struct { + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint64 +} + +func (m *methodType) NumCalls() uint64 { + return atomic.LoadUint64(&m.numCalls) +} + +func (m *methodType) newArgv() reflect.Value { + var argv reflect.Value + // arg may be a pointer type, or a value type + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) + } else { + argv = reflect.New(m.ArgType).Elem() + } + return argv +} + +func (m *methodType) newReplyv() reflect.Value { + // reply must be a pointer type + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) + } + return replyv +} + +type service struct { + name string + typ reflect.Type + rcvr reflect.Value + method map[string]*methodType +} + +func newService(rcvr interface{}) *service { + s := new(service) + s.rcvr = reflect.ValueOf(rcvr) + s.name = reflect.Indirect(s.rcvr).Type().Name() + s.typ = reflect.TypeOf(rcvr) + if !ast.IsExported(s.name) { + log.Fatalf("rpc server: %s is not a valid service name", s.name) + } + s.registerMethods() + return s +} + +func (s *service) registerMethods() { + s.method = make(map[string]*methodType) + for i := 0; i < s.typ.NumMethod(); i++ { + method := s.typ.Method(i) + mType := method.Type + if mType.NumIn() != 3 || mType.NumOut() != 1 { + continue + } + if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + continue + } + argType, replyType := mType.In(1), mType.In(2) + if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { + continue + } + s.method[method.Name] = &methodType{ + method: method, + ArgType: argType, + ReplyType: replyType, + } + log.Printf("rpc server: register %s.%s\n", s.name, method.Name) + } +} + +func (s *service) call(m *methodType, argv, replyv reflect.Value) error { + atomic.AddUint64(&m.numCalls, 1) + f := m.method.Func + returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) + if errInter := returnValues[0].Interface(); errInter != nil { + return errInter.(error) + } + return nil +} + +func isExportedOrBuiltinType(t reflect.Type) bool { + return ast.IsExported(t.Name()) || t.PkgPath() == "" +} diff --git a/gee-rpc/day6-load-balance/service_test.go b/gee-rpc/day6-load-balance/service_test.go new file mode 100644 index 0000000..c8266df --- /dev/null +++ b/gee-rpc/day6-load-balance/service_test.go @@ -0,0 +1,48 @@ +package geerpc + +import ( + "fmt" + "reflect" + "testing" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +// it's not a exported Method +func (f Foo) sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func TestNewService(t *testing.T) { + var foo Foo + s := newService(&foo) + _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method)) + mType := s.method["Sum"] + _assert(mType != nil, "wrong Method, Sum shouldn't nil") +} + +func TestMethodType_Call(t *testing.T) { + var foo Foo + s := newService(&foo) + mType := s.method["Sum"] + + argv := mType.newArgv() + replyv := mType.newReplyv() + argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3})) + err := s.call(mType, argv, replyv) + _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum") +} diff --git a/gee-rpc/day6-load-balance/xclient/discovery.go b/gee-rpc/day6-load-balance/xclient/discovery.go new file mode 100644 index 0000000..70d1cbb --- /dev/null +++ b/gee-rpc/day6-load-balance/xclient/discovery.go @@ -0,0 +1,87 @@ +package xclient + +import ( + "errors" + "math" + "math/rand" + "sync" + "time" +) + +type SelectMode int + +const ( + RandomSelect SelectMode = iota // select randomly + RoundRobinSelect // select using Robbin algorithm +) + +type Discovery interface { + Refresh() error // refresh from remote registry + Update(servers []string) error + Get(mode SelectMode) (string, error) + GetAll() ([]string, error) +} + +var _ Discovery = (*MultiServersDiscovery)(nil) + +// MultiServersDiscovery is a discovery for multi servers without a registry center +// user provides the server addresses explicitly instead +type MultiServersDiscovery struct { + r *rand.Rand // generate random number + mu sync.RWMutex // protect following + servers []string + index int // record the selected position for robin algorithm +} + +// Refresh doesn't make sense for MultiServersDiscovery, so ignore it +func (d *MultiServersDiscovery) Refresh() error { + return nil +} + +// Update the servers of discovery dynamically if needed +func (d *MultiServersDiscovery) Update(servers []string) error { + d.mu.Lock() + defer d.mu.Unlock() + d.servers = servers + return nil +} + +// Get a server according to mode +func (d *MultiServersDiscovery) Get(mode SelectMode) (string, error) { + d.mu.Lock() + defer d.mu.Unlock() + n := len(d.servers) + if n == 0 { + return "", errors.New("rpc discovery: no available servers") + } + switch mode { + case RandomSelect: + return d.servers[d.r.Intn(n)], nil + case RoundRobinSelect: + s := d.servers[d.index%n] // servers could be updated, so mode n to ensure safety + d.index = (d.index + 1) % n + return s, nil + default: + return "", errors.New("rpc discovery: not supported select mode") + } +} + +// returns all servers in discovery +func (d *MultiServersDiscovery) GetAll() ([]string, error) { + d.mu.RLock() + defer d.mu.RUnlock() + // return a copy of d.servers + servers := make([]string, len(d.servers), len(d.servers)) + copy(servers, d.servers) + return servers, nil +} + +// NewMultiServerDiscovery creates a MultiServersDiscovery instance +func NewMultiServerDiscovery(servers []string) *MultiServersDiscovery { + d := &MultiServersDiscovery{ + servers: servers, + r: rand.New(rand.NewSource(time.Now().UnixNano())), + } + d.index = d.r.Intn(math.MaxInt32 - 1) + return d +} diff --git a/gee-rpc/day6-load-balance/xclient/xclient.go b/gee-rpc/day6-load-balance/xclient/xclient.go new file mode 100644 index 0000000..3194d27 --- /dev/null +++ b/gee-rpc/day6-load-balance/xclient/xclient.go @@ -0,0 +1,109 @@ +package xclient + +import ( + "context" + . "geerpc" + "io" + "reflect" + "sync" +) + +type XClient struct { + d Discovery + mode SelectMode + opt *Option + mu sync.Mutex // protect following + clients map[string]*Client +} + +var _ io.Closer = (*XClient)(nil) + +func NewXClient(d Discovery, mode SelectMode, opt *Option) *XClient { + return &XClient{d: d, mode: mode, opt: opt, clients: make(map[string]*Client)} +} + +func (xc *XClient) Close() error { + xc.mu.Lock() + defer xc.mu.Unlock() + for key, client := range xc.clients { + // I have no idea how to deal with error, just ignore it. + _ = client.Close() + delete(xc.clients, key) + } + return nil +} + +func (xc *XClient) dial(rpcAddr string) (*Client, error) { + xc.mu.Lock() + defer xc.mu.Unlock() + client, ok := xc.clients[rpcAddr] + if ok && !client.IsAvailable() { + _ = client.Close() + delete(xc.clients, rpcAddr) + client = nil + } + if client == nil { + var err error + client, err = XDial(rpcAddr, xc.opt) + if err != nil { + return nil, err + } + xc.clients[rpcAddr] = client + } + return client, nil +} + +func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod string, args, reply interface{}) error { + client, err := xc.dial(rpcAddr) + if err != nil { + return err + } + return client.Call(ctx, serviceMethod, args, reply) +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +// xc will choose a proper server. +func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + rpcAddr, err := xc.d.Get(xc.mode) + if err != nil { + return err + } + return xc.call(rpcAddr, ctx, serviceMethod, args, reply) +} + +// Broadcast invokes the named function for every server registered in discovery +func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error { + servers, err := xc.d.GetAll() + if err != nil { + return err + } + var wg sync.WaitGroup + var mu sync.Mutex // protect e and replyDone + var e error + replyDone := reply == nil // if reply is nil, don't need to set value + ctx, cancel := context.WithCancel(ctx) + for _, rpcAddr := range servers { + wg.Add(1) + go func(rpcAddr string) { + defer wg.Done() + var clonedReply interface{} + if reply != nil { + clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface() + } + err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply) + mu.Lock() + if err != nil && e == nil { + e = err + cancel() // if any call failed, cancel unfinished calls + } + if err == nil && !replyDone { + reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem()) + replyDone = true + } + mu.Unlock() + }(rpcAddr) + } + wg.Wait() + return e +} diff --git a/gee-rpc/day7-registry/client.go b/gee-rpc/day7-registry/client.go new file mode 100644 index 0000000..1a62b1e --- /dev/null +++ b/gee-rpc/day7-registry/client.go @@ -0,0 +1,323 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "strings" + "sync" + "time" +) + +// Call represents an active RPC. +type Call struct { + Seq uint64 + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing { + return ErrShutdown + } + client.closing = true + return client.cc.Close() +} + +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing || client.shutdown { + return 0, ErrShutdown + } + call.Seq = client.seq + client.pending[call.Seq] = call + client.seq++ + return call.Seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + client.shutdown = true + for _, call := range client.pending { + call.Error = err + call.done() + } +} + +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + call := client.Go(serviceMethod, args, reply, make(chan *Call, 1)) + select { + case <-ctx.Done(): + client.removeCall(call.Seq) + return errors.New("rpc client: call failed: " + ctx.Err().Error()) + case call := <-call.Done: + return call.Error + } +} + +func parseOptions(opts ...*Option) (*Option, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return DefaultOption, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = DefaultOption.MagicNumber + if opt.CodecType == "" { + opt.CodecType = DefaultOption.CodecType + } + return opt, nil +} + +func NewClient(conn net.Conn, opt *Option) (*Client, error) { + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err := fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return nil, err + } + // send options with server + if err := json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + _ = conn.Close() + return nil, err + } + return newClientCodec(f(conn), opt), nil +} + +func newClientCodec(cc codec.Codec, opt *Option) *Client { + client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call + cc: cc, + opt: opt, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} + +type clientResult struct { + client *Client + err error +} + +type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error) + +func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + conn, err := net.DialTimeout(network, address, opt.ConnectTimeout) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if err != nil { + _ = conn.Close() + } + }() + ch := make(chan clientResult) + go func() { + client, err := f(conn, opt) + ch <- clientResult{client: client, err: err} + }() + if opt.ConnectTimeout == 0 { + result := <-ch + return result.client, result.err + } + select { + case <-time.After(opt.ConnectTimeout): + return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout) + case result := <-ch: + return result.client, result.err + } +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(NewClient, network, address, opts...) +} + +// NewHTTPClient new a Client instance via HTTP as transport protocol +func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) { + _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)) + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err == nil && resp.Status == connected { + return NewClient(conn, opt) + } + if err == nil { + err = errors.New("unexpected HTTP response: " + resp.Status) + } + return nil, err +} + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(NewHTTPClient, network, address, opts...) +} + +// XDial calls different functions to connect to a RPC server +// according the first parameter rpcAddr. +// rpcAddr is a general format (protocol@addr) to represent a rpc server +// eg, http@10.0.0.1:7001, tcp@10.0.0.1:9999, unix@/tmp/geerpc.sock +func XDial(rpcAddr string, opts ...*Option) (*Client, error) { + parts := strings.Split(rpcAddr, "@") + if len(parts) != 2 { + return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr) + } + protocol, addr := parts[0], parts[1] + switch protocol { + case "http": + return DialHTTP("tcp", addr, opts...) + default: + // tcp, unix or other transport protocol + return Dial(protocol, addr, opts...) + } +} diff --git a/gee-rpc/day7-registry/client_test.go b/gee-rpc/day7-registry/client_test.go new file mode 100644 index 0000000..3b13cb0 --- /dev/null +++ b/gee-rpc/day7-registry/client_test.go @@ -0,0 +1,88 @@ +package geerpc + +import ( + "context" + "net" + "os" + "runtime" + "strings" + "testing" + "time" +) + +type Bar int + +func (b Bar) Timeout(argv int, reply *int) error { + time.Sleep(time.Second * 2) + return nil +} + +func startServer(addr chan string) { + var b Bar + _ = Register(&b) + // pick a free port + l, _ := net.Listen("tcp", ":0") + addr <- l.Addr().String() + Accept(l) +} + +func TestClient_dialTimeout(t *testing.T) { + t.Parallel() + l, _ := net.Listen("tcp", ":0") + + f := func(conn net.Conn, opt *Option) (client *Client, err error) { + _ = conn.Close() + time.Sleep(time.Second * 2) + return nil, nil + } + t.Run("timeout", func(t *testing.T) { + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error") + }) + t.Run("0", func(t *testing.T) { + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0}) + _assert(err == nil, "0 means no limit") + }) +} + +func TestClient_Call(t *testing.T) { + t.Parallel() + addrCh := make(chan string) + go startServer(addrCh) + addr := <-addrCh + time.Sleep(time.Second) + t.Run("client timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + var reply int + err := client.Call(ctx, "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error") + }) + t.Run("server handle timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr, &Option{ + HandleTimeout: time.Second, + }) + var reply int + err := client.Call(context.Background(), "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error") + }) +} + +func TestXDial(t *testing.T) { + if runtime.GOOS == "linux" { + ch := make(chan struct{}) + addr := "/tmp/geerpc.sock" + go func() { + _ = os.Remove(addr) + l, err := net.Listen("unix", addr) + if err != nil { + t.Fatal("failed to listen unix socket") + } + ch <- struct{}{} + Accept(l) + }() + <-ch + _, err := XDial("unix@" + addr) + _assert(err == nil, "failed to connect unix socket") + } +} diff --git a/gee-rpc/day7-registry/codec/codec.go b/gee-rpc/day7-registry/codec/codec.go new file mode 100644 index 0000000..20b6ba7 --- /dev/null +++ b/gee-rpc/day7-registry/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" // not implemented +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day7-registry/codec/gob.go b/gee-rpc/day7-registry/codec/gob.go new file mode 100644 index 0000000..d9ef2e6 --- /dev/null +++ b/gee-rpc/day7-registry/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err = c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return + } + if err = c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return + } + return +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day7-registry/debug.go b/gee-rpc/day7-registry/debug.go new file mode 100644 index 0000000..ece1ffd --- /dev/null +++ b/gee-rpc/day7-registry/debug.go @@ -0,0 +1,60 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "fmt" + "html/template" + "net/http" +) + +const debugText = ` + + GeeRPC Services + {{range .}} +
+ Service {{.Name}} +
+ + + {{range $name, $mtype := .Method}} + + + + + {{end}} +
MethodCalls
{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error{{$mtype.NumCalls}}
+ {{end}} + + ` + +var debug = template.Must(template.New("RPC debug").Parse(debugText)) + +type debugHTTP struct { + *Server +} + +type debugService struct { + Name string + Method map[string]*methodType +} + +// Runs at /debug/geerpc +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Build a sorted version of the data. + var services []debugService + server.serviceMap.Range(func(namei, svci interface{}) bool { + svc := svci.(*service) + services = append(services, debugService{ + Name: namei.(string), + Method: svc.method, + }) + return true + }) + err := debug.Execute(w, services) + if err != nil { + _, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error()) + } +} diff --git a/gee-rpc/day7-registry/go.mod b/gee-rpc/day7-registry/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day7-registry/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day7-registry/main/main.go b/gee-rpc/day7-registry/main/main.go new file mode 100644 index 0000000..1797707 --- /dev/null +++ b/gee-rpc/day7-registry/main/main.go @@ -0,0 +1,114 @@ +package main + +import ( + "context" + "geerpc" + "geerpc/registry" + "geerpc/xclient" + "log" + "net" + "net/http" + "sync" + "time" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func (f Foo) Sleep(args Args, reply *int) error { + time.Sleep(time.Second * time.Duration(args.Num1)) + *reply = args.Num1 + args.Num2 + return nil +} + +func startRegistry(wg *sync.WaitGroup) { + l, _ := net.Listen("tcp", ":9999") + registry.HandleHTTP() + wg.Done() + _ = http.Serve(l, nil) +} + +func startServer(registryAddr string, wg *sync.WaitGroup) { + var foo Foo + l, _ := net.Listen("tcp", ":0") + server := geerpc.NewServer() + _ = server.Register(&foo) + registry.Heartbeat(registryAddr, "tcp@"+l.Addr().String(), 0) + wg.Done() + server.Accept(l) +} + +func foo(xc *xclient.XClient, ctx context.Context, typ, serviceMethod string, args *Args) { + var reply int + var err error + switch typ { + case "call": + err = xc.Call(ctx, serviceMethod, args, &reply) + case "broadcast": + err = xc.Broadcast(ctx, serviceMethod, args, &reply) + } + if err != nil { + log.Printf("%s %s error: %v", typ, serviceMethod, err) + } else { + log.Printf("%s %s success: %d + %d = %d", typ, serviceMethod, args.Num1, args.Num2, reply) + } +} + +func call(registry string) { + d := xclient.NewGeeRegistryDiscovery(registry, 0) + xc := xclient.NewXClient(d, xclient.RandomSelect, nil) + defer func() { _ = xc.Close() }() + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + foo(xc, context.Background(), "call", "Foo.Sum", &Args{Num1: i, Num2: i * i}) + }(i) + } + wg.Wait() +} + +func broadcast(registry string) { + d := xclient.NewGeeRegistryDiscovery(registry, 0) + xc := xclient.NewXClient(d, xclient.RandomSelect, nil) + defer func() { _ = xc.Close() }() + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + foo(xc, context.Background(), "broadcast", "Foo.Sum", &Args{Num1: i, Num2: i * i}) + // expect 2 - 5 timeout + ctx, _ := context.WithTimeout(context.Background(), time.Second*2) + foo(xc, ctx, "broadcast", "Foo.Sleep", &Args{Num1: i, Num2: i * i}) + }(i) + } + wg.Wait() +} + +func main() { + log.SetFlags(0) + registryAddr := "http://localhost:9999/_geerpc_/registry" + var wg sync.WaitGroup + wg.Add(1) + go startRegistry(&wg) + wg.Wait() + + time.Sleep(time.Second) + wg.Add(2) + go startServer(registryAddr, &wg) + go startServer(registryAddr, &wg) + wg.Wait() + + time.Sleep(time.Second) + call(registryAddr) + broadcast(registryAddr) +} diff --git a/gee-rpc/day7-registry/registry/registry.go b/gee-rpc/day7-registry/registry/registry.go new file mode 100644 index 0000000..29d23a8 --- /dev/null +++ b/gee-rpc/day7-registry/registry/registry.go @@ -0,0 +1,125 @@ +package registry + +import ( + "log" + "net/http" + "sort" + "strings" + "sync" + "time" +) + +// GeeRegistry is a simple register center, provide following functions. +// add a server and receive heartbeat to keep it alive. +// returns all alive servers and delete dead servers sync simultaneously. +type GeeRegistry struct { + timeout time.Duration + mu sync.Mutex // protect following + servers map[string]*ServerItem +} + +type ServerItem struct { + Addr string + start time.Time +} + +const ( + defaultPath = "/_geerpc_/registry" + defaultTimeout = time.Minute * 5 +) + +// New create a registry instance with timeout setting +func New(timeout time.Duration) *GeeRegistry { + return &GeeRegistry{ + servers: make(map[string]*ServerItem), + timeout: timeout, + } +} + +var DefaultGeeRegister = New(defaultTimeout) + +func (r *GeeRegistry) putServer(addr string) { + r.mu.Lock() + defer r.mu.Unlock() + s := r.servers[addr] + if s == nil { + r.servers[addr] = &ServerItem{Addr: addr, start: time.Now()} + } else { + s.start = time.Now() // if exists, update start time to keep alive + } +} + +func (r *GeeRegistry) aliveServers() []string { + r.mu.Lock() + defer r.mu.Unlock() + var alive []string + for addr, s := range r.servers { + if r.timeout == 0 || s.start.Add(r.timeout).After(time.Now()) { + alive = append(alive, addr) + } else { + delete(r.servers, addr) + } + } + sort.Strings(alive) + return alive +} + +// Runs at /_geerpc_/registry +func (r *GeeRegistry) ServeHTTP(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case "GET": + // keep it simple, server is in req.Header + w.Header().Set("X-Geerpc-Servers", strings.Join(r.aliveServers(), ",")) + case "POST": + // keep it simple, server is in req.Header + addr := req.Header.Get("X-Geerpc-Server") + if addr == "" { + w.WriteHeader(http.StatusInternalServerError) + return + } + r.putServer(addr) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +// HandleHTTP registers an HTTP handler for GeeRegistry messages on registryPath +func (r *GeeRegistry) HandleHTTP(registryPath string) { + http.Handle(registryPath, r) + log.Println("rpc registry path:", registryPath) +} + +func HandleHTTP() { + DefaultGeeRegister.HandleHTTP(defaultPath) +} + +// Heartbeat send a heartbeat message every once in a while +// it's a helper function for a server to register or send heartbeat +func Heartbeat(registry, addr string, duration time.Duration) { + if duration == 0 { + // make sure there is enough time to send heart beat + // before it's removed from registry + duration = defaultTimeout - time.Duration(1)*time.Minute + } + var err error + err = sendHeartbeat(registry, addr) + go func() { + t := time.NewTicker(duration) + for err == nil { + <-t.C + err = sendHeartbeat(registry, addr) + } + }() +} + +func sendHeartbeat(registry, addr string) error { + log.Println(addr, "send heart beat to registry", registry) + httpClient := &http.Client{} + req, _ := http.NewRequest("POST", registry, nil) + req.Header.Set("X-Geerpc-Server", addr) + if _, err := httpClient.Do(req); err != nil { + log.Println("rpc server: heart beat err:", err) + return err + } + return nil +} diff --git a/gee-rpc/day7-registry/server.go b/gee-rpc/day7-registry/server.go new file mode 100644 index 0000000..38fad20 --- /dev/null +++ b/gee-rpc/day7-registry/server.go @@ -0,0 +1,266 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "reflect" + "strings" + "sync" + "time" +) + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body + ConnectTimeout time.Duration // 0 means no limit + HandleTimeout time.Duration +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, + ConnectTimeout: time.Second * 10, +} + +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn), &opt) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec, opt *Option) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request + mtype *methodType + svc *service +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { + dot := strings.LastIndex(serviceMethod, ".") + if dot < 0 { + err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) + return + } + serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc server: can't find service " + serviceName) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc server: can't find method " + methodName) + } + return +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + req.svc, req.mtype, err = server.findService(h.ServiceMethod) + if err != nil { + return req, err + } + req.argv = req.mtype.newArgv() + req.replyv = req.mtype.newReplyv() + + // make sure that argvi is a pointer, ReadBody need a pointer as parameter + argvi := req.argv.Interface() + if req.argv.Type().Kind() != reflect.Ptr { + argvi = req.argv.Addr().Interface() + } + if err = cc.ReadBody(argvi); err != nil { + log.Println("rpc server: read body err:", err) + return req, err + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { + defer wg.Done() + called := make(chan struct{}) + sent := make(chan struct{}) + go func() { + err := req.svc.call(req.mtype, req.argv, req.replyv) + called <- struct{}{} + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + sent <- struct{}{} + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) + sent <- struct{}{} + }() + + if timeout == 0 { + <-called + <-sent + return + } + select { + case <-time.After(timeout): + req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout) + server.sendResponse(cc, req.h, invalidRequest, sending) + case <-called: + <-sent + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +func (server *Server) Register(rcvr interface{}) error { + s := newService(rcvr) + if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { + return errors.New("rpc: service already defined: " + s.name) + } + return nil +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } + +const ( + connected = "200 Connected to Gee RPC" + defaultRPCPath = "/_geeprc_" + defaultDebugPath = "/debug/geerpc" +) + +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = io.WriteString(w, "405 must CONNECT\n") + return + } + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) + return + } + _, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath, +// and a debugging handler on debugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP() { + http.Handle(defaultRPCPath, server) + http.Handle(defaultDebugPath, debugHTTP{server}) + log.Println("rpc server debug path:", defaultDebugPath) +} + +// HandleHTTP is a convenient approach for default server to register HTTP handlers +func HandleHTTP() { + DefaultServer.HandleHTTP() +} diff --git a/gee-rpc/day7-registry/service.go b/gee-rpc/day7-registry/service.go new file mode 100644 index 0000000..306683c --- /dev/null +++ b/gee-rpc/day7-registry/service.go @@ -0,0 +1,99 @@ +package geerpc + +import ( + "go/ast" + "log" + "reflect" + "sync/atomic" +) + +type methodType struct { + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint64 +} + +func (m *methodType) NumCalls() uint64 { + return atomic.LoadUint64(&m.numCalls) +} + +func (m *methodType) newArgv() reflect.Value { + var argv reflect.Value + // arg may be a pointer type, or a value type + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) + } else { + argv = reflect.New(m.ArgType).Elem() + } + return argv +} + +func (m *methodType) newReplyv() reflect.Value { + // reply must be a pointer type + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) + } + return replyv +} + +type service struct { + name string + typ reflect.Type + rcvr reflect.Value + method map[string]*methodType +} + +func newService(rcvr interface{}) *service { + s := new(service) + s.rcvr = reflect.ValueOf(rcvr) + s.name = reflect.Indirect(s.rcvr).Type().Name() + s.typ = reflect.TypeOf(rcvr) + if !ast.IsExported(s.name) { + log.Fatalf("rpc server: %s is not a valid service name", s.name) + } + s.registerMethods() + return s +} + +func (s *service) registerMethods() { + s.method = make(map[string]*methodType) + for i := 0; i < s.typ.NumMethod(); i++ { + method := s.typ.Method(i) + mType := method.Type + if mType.NumIn() != 3 || mType.NumOut() != 1 { + continue + } + if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + continue + } + argType, replyType := mType.In(1), mType.In(2) + if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { + continue + } + s.method[method.Name] = &methodType{ + method: method, + ArgType: argType, + ReplyType: replyType, + } + log.Printf("rpc server: register %s.%s\n", s.name, method.Name) + } +} + +func (s *service) call(m *methodType, argv, replyv reflect.Value) error { + atomic.AddUint64(&m.numCalls, 1) + f := m.method.Func + returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) + if errInter := returnValues[0].Interface(); errInter != nil { + return errInter.(error) + } + return nil +} + +func isExportedOrBuiltinType(t reflect.Type) bool { + return ast.IsExported(t.Name()) || t.PkgPath() == "" +} diff --git a/gee-rpc/day7-registry/service_test.go b/gee-rpc/day7-registry/service_test.go new file mode 100644 index 0000000..c8266df --- /dev/null +++ b/gee-rpc/day7-registry/service_test.go @@ -0,0 +1,48 @@ +package geerpc + +import ( + "fmt" + "reflect" + "testing" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +// it's not a exported Method +func (f Foo) sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func TestNewService(t *testing.T) { + var foo Foo + s := newService(&foo) + _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method)) + mType := s.method["Sum"] + _assert(mType != nil, "wrong Method, Sum shouldn't nil") +} + +func TestMethodType_Call(t *testing.T) { + var foo Foo + s := newService(&foo) + mType := s.method["Sum"] + + argv := mType.newArgv() + replyv := mType.newReplyv() + argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3})) + err := s.call(mType, argv, replyv) + _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum") +} diff --git a/gee-rpc/day7-registry/xclient/discovery.go b/gee-rpc/day7-registry/xclient/discovery.go new file mode 100644 index 0000000..70d1cbb --- /dev/null +++ b/gee-rpc/day7-registry/xclient/discovery.go @@ -0,0 +1,87 @@ +package xclient + +import ( + "errors" + "math" + "math/rand" + "sync" + "time" +) + +type SelectMode int + +const ( + RandomSelect SelectMode = iota // select randomly + RoundRobinSelect // select using Robbin algorithm +) + +type Discovery interface { + Refresh() error // refresh from remote registry + Update(servers []string) error + Get(mode SelectMode) (string, error) + GetAll() ([]string, error) +} + +var _ Discovery = (*MultiServersDiscovery)(nil) + +// MultiServersDiscovery is a discovery for multi servers without a registry center +// user provides the server addresses explicitly instead +type MultiServersDiscovery struct { + r *rand.Rand // generate random number + mu sync.RWMutex // protect following + servers []string + index int // record the selected position for robin algorithm +} + +// Refresh doesn't make sense for MultiServersDiscovery, so ignore it +func (d *MultiServersDiscovery) Refresh() error { + return nil +} + +// Update the servers of discovery dynamically if needed +func (d *MultiServersDiscovery) Update(servers []string) error { + d.mu.Lock() + defer d.mu.Unlock() + d.servers = servers + return nil +} + +// Get a server according to mode +func (d *MultiServersDiscovery) Get(mode SelectMode) (string, error) { + d.mu.Lock() + defer d.mu.Unlock() + n := len(d.servers) + if n == 0 { + return "", errors.New("rpc discovery: no available servers") + } + switch mode { + case RandomSelect: + return d.servers[d.r.Intn(n)], nil + case RoundRobinSelect: + s := d.servers[d.index%n] // servers could be updated, so mode n to ensure safety + d.index = (d.index + 1) % n + return s, nil + default: + return "", errors.New("rpc discovery: not supported select mode") + } +} + +// returns all servers in discovery +func (d *MultiServersDiscovery) GetAll() ([]string, error) { + d.mu.RLock() + defer d.mu.RUnlock() + // return a copy of d.servers + servers := make([]string, len(d.servers), len(d.servers)) + copy(servers, d.servers) + return servers, nil +} + +// NewMultiServerDiscovery creates a MultiServersDiscovery instance +func NewMultiServerDiscovery(servers []string) *MultiServersDiscovery { + d := &MultiServersDiscovery{ + servers: servers, + r: rand.New(rand.NewSource(time.Now().UnixNano())), + } + d.index = d.r.Intn(math.MaxInt32 - 1) + return d +} diff --git a/gee-rpc/day7-registry/xclient/discovery_gee.go b/gee-rpc/day7-registry/xclient/discovery_gee.go new file mode 100644 index 0000000..865c30e --- /dev/null +++ b/gee-rpc/day7-registry/xclient/discovery_gee.go @@ -0,0 +1,74 @@ +package xclient + +import ( + "log" + "net/http" + "strings" + "time" +) + +type GeeRegistryDiscovery struct { + *MultiServersDiscovery + registry string + timeout time.Duration + lastUpdate time.Time +} + +const defaultUpdateTimeout = time.Second * 10 + +func (d *GeeRegistryDiscovery) Update(servers []string) error { + d.mu.Lock() + defer d.mu.Unlock() + d.servers = servers + d.lastUpdate = time.Now() + return nil +} + +func (d *GeeRegistryDiscovery) Refresh() error { + d.mu.Lock() + defer d.mu.Unlock() + if d.lastUpdate.Add(d.timeout).After(time.Now()) { + return nil + } + log.Println("rpc registry: refresh servers from registry", d.registry) + resp, err := http.Get(d.registry) + if err != nil { + log.Println("rpc registry refresh err:", err) + return err + } + servers := strings.Split(resp.Header.Get("X-Geerpc-Servers"), ",") + d.servers = make([]string, 0, len(servers)) + for _, server := range servers { + if strings.TrimSpace(server) != "" { + d.servers = append(d.servers, strings.TrimSpace(server)) + } + } + d.lastUpdate = time.Now() + return nil +} + +func (d *GeeRegistryDiscovery) Get(mode SelectMode) (string, error) { + if err := d.Refresh(); err != nil { + return "", err + } + return d.MultiServersDiscovery.Get(mode) +} + +func (d *GeeRegistryDiscovery) GetAll() ([]string, error) { + if err := d.Refresh(); err != nil { + return nil, err + } + return d.MultiServersDiscovery.GetAll() +} + +func NewGeeRegistryDiscovery(registerAddr string, timeout time.Duration) *GeeRegistryDiscovery { + if timeout == 0 { + timeout = defaultUpdateTimeout + } + d := &GeeRegistryDiscovery{ + MultiServersDiscovery: NewMultiServerDiscovery(make([]string, 0)), + registry: registerAddr, + timeout: timeout, + } + return d +} diff --git a/gee-rpc/day7-registry/xclient/xclient.go b/gee-rpc/day7-registry/xclient/xclient.go new file mode 100644 index 0000000..3194d27 --- /dev/null +++ b/gee-rpc/day7-registry/xclient/xclient.go @@ -0,0 +1,109 @@ +package xclient + +import ( + "context" + . "geerpc" + "io" + "reflect" + "sync" +) + +type XClient struct { + d Discovery + mode SelectMode + opt *Option + mu sync.Mutex // protect following + clients map[string]*Client +} + +var _ io.Closer = (*XClient)(nil) + +func NewXClient(d Discovery, mode SelectMode, opt *Option) *XClient { + return &XClient{d: d, mode: mode, opt: opt, clients: make(map[string]*Client)} +} + +func (xc *XClient) Close() error { + xc.mu.Lock() + defer xc.mu.Unlock() + for key, client := range xc.clients { + // I have no idea how to deal with error, just ignore it. + _ = client.Close() + delete(xc.clients, key) + } + return nil +} + +func (xc *XClient) dial(rpcAddr string) (*Client, error) { + xc.mu.Lock() + defer xc.mu.Unlock() + client, ok := xc.clients[rpcAddr] + if ok && !client.IsAvailable() { + _ = client.Close() + delete(xc.clients, rpcAddr) + client = nil + } + if client == nil { + var err error + client, err = XDial(rpcAddr, xc.opt) + if err != nil { + return nil, err + } + xc.clients[rpcAddr] = client + } + return client, nil +} + +func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod string, args, reply interface{}) error { + client, err := xc.dial(rpcAddr) + if err != nil { + return err + } + return client.Call(ctx, serviceMethod, args, reply) +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +// xc will choose a proper server. +func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + rpcAddr, err := xc.d.Get(xc.mode) + if err != nil { + return err + } + return xc.call(rpcAddr, ctx, serviceMethod, args, reply) +} + +// Broadcast invokes the named function for every server registered in discovery +func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error { + servers, err := xc.d.GetAll() + if err != nil { + return err + } + var wg sync.WaitGroup + var mu sync.Mutex // protect e and replyDone + var e error + replyDone := reply == nil // if reply is nil, don't need to set value + ctx, cancel := context.WithCancel(ctx) + for _, rpcAddr := range servers { + wg.Add(1) + go func(rpcAddr string) { + defer wg.Done() + var clonedReply interface{} + if reply != nil { + clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface() + } + err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply) + mu.Lock() + if err != nil && e == nil { + e = err + cancel() // if any call failed, cancel unfinished calls + } + if err == nil && !replyDone { + reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem()) + replyDone = true + } + mu.Unlock() + }(rpcAddr) + } + wg.Wait() + return e +} diff --git a/gee-rpc/doc/geerpc-day1.md b/gee-rpc/doc/geerpc-day1.md new file mode 100644 index 0000000..76df010 --- /dev/null +++ b/gee-rpc/doc/geerpc-day1.md @@ -0,0 +1,442 @@ +--- +title: 动手写RPC框架 - GeeRPC第一天 服务端与消息编码 +date: 2020-10-06 17:00:00 +description: 7天用 Go语言/golang 从零实现 RPC 框架 GeeRPC 教程(7 days implement golang remote procedure call framework from scratch tutorial),动手写 RPC 框架,参照 golang 标准库 net/rpc 的实现,实现了服务端(server)、支持异步和并发的客户端(client)、消息编码与解码(message encoding and decoding)、服务注册(service register)、支持 TCP/Unix/HTTP 等多种传输协议。第一天实现了一个简单的服务端和消息的编码与解码。 +tags: +- Go +nav: 从零实现 +categories: +- RPC框架 - GeeRPC +keywords: +- Go语言 +- 从零实现RPC框架 +- Codec +- 序列化 +- 反序列化 +image: post/geerpc/geerpc.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day1 服务端与消息编码 +--- + +![golang RPC framework](geerpc/geerpc.jpg) + +本文是[7天用Go从零实现RPC框架GeeRPC](https://geektutu.com/post/geerpc.html)的第一篇。 + +- 使用 `encoding/gob` 实现消息的编解码(序列化与反序列化) +- 实现一个简易的服务端,仅接受消息,不处理,代码约 200 行 + + +## 消息的序列化与反序列化 + +一个典型的 RPC 调用如下: + +```go +err = client.Call("Arith.Multiply", args, &reply) +``` + +客户端发送的请求包括服务名 `Arith`,方法名 `Multiply`,参数 `args` 三个,服务端的响应包括错误 `error`,返回值 `reply` 2 个。我们将请求和响应中的参数和返回值抽象为 body,剩余的信息放在 header 中,那么就可以抽象出数据结构 Header: + +[day1-codec/codec/codec.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day1-codec) + +```go +package codec + +import "io" + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} +``` + +- ServiceMethod 是服务名和方法名,通常与 Go 语言中的结构体和方法相映射。 +- Seq 是请求的序号,也可以认为是某个请求的 ID,用来区分不同的请求。 +- Error 是错误信息,客户端置为空,服务端如果如果发生错误,将错误信息置于 Error 中。 + + +我们将和消息编解码相关的代码都放到 codec 子目录中,在此之前,还需要在根目录下使用 `go mod init geerpc` 初始化项目,方便后续子 package 之间的引用。 + +进一步,抽象出对消息体进行编解码的接口 Codec,抽象出接口是为了实现不同的 Codec 实例: + +```go +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} +``` + +紧接着,抽象出 Codec 的构造函数,客户端和服务端可以通过 Codec 的 `Type` 得到构造函数,从而创建 Codec 实例。这部分代码和工厂模式类似,与工厂模式不同的是,返回的是构造函数,而非实例。 + +```go +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" // not implemented +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} +``` + +我们定义了 2 种 Codec,`Gob` 和 `Json`,但是实际代码中只实现了 `Gob` 一种,事实上,2 者的实现非常接近,甚至只需要把 `gob` 换成 `json` 即可。 + +首先定义 `GobCodec` 结构体,这个结构体由四部分构成,`conn` 是由构建函数传入,通常是通过 TCP 或者 Unix 建立 socket 时得到的链接实例,dec 和 enc 对应 gob 的 Decoder 和 Encoder,buf 是为了防止阻塞而创建的带缓冲的 `Writer`,一般这么做能提升性能。 + +[day1-codec/codec/gob.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day1-codec) + +```go +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} +``` + +接着实现 `ReadHeader`、`ReadBody`、`Write` 和 `Close` 方法。 + +```go +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err := c.enc.Encode(h); err != nil { + log.Println("rpc codec: gob error encoding header:", err) + return err + } + if err := c.enc.Encode(body); err != nil { + log.Println("rpc codec: gob error encoding body:", err) + return err + } + return nil +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} +``` + +## 通信过程 + +客户端与服务端的通信需要协商一些内容,例如 HTTP 报文,分为 header 和 body 2 部分,body 的格式和长度通过 header 中的 `Content-Type` 和 `Content-Length` 指定,服务端通过解析 header 就能够知道如何从 body 中读取需要的信息。对于 RPC 协议来说,这部分协商是需要自主设计的。为了提升性能,一般在报文的最开始会规划固定的字节,来协商相关的信息。比如第1个字节用来表示序列化方式,第2个字节表示压缩方式,第3-6字节表示 header 的长度,7-10 字节表示 body 的长度。 + +对于 GeeRPC 来说,目前需要协商的唯一一项内容是消息的编解码方式。我们将这部分信息,放到结构体 `Option` 中承载。目前,已经进入到服务端的实现阶段了。 + +[day1-codec/server.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day1-codec) + +```go +package geerpc + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, +} +``` + +一般来说,涉及协议协商的这部分信息,需要设计固定的字节来传输的。但是为了实现上更简单,GeeRPC 客户端固定采用 JSON 编码 Option,后续的 header 和 body 的编码方式由 Option 中的 CodeType 指定,服务端首先使用 JSON 解码 Option,然后通过 Option 的 CodeType 解码剩余的内容。即报文将以这样的形式发送: + +```bash +| Option{MagicNumber: xxx, CodecType: xxx} | Header{ServiceMethod ...} | Body interface{} | +| <------ 固定 JSON 编码 ------> | <------- 编码方式由 CodeType 决定 ------->| +``` + +在一次连接中,Option 固定在报文的最开始,Header 和 Body 可以有多个,即报文可能是这样的。 + +```go +| Option | Header1 | Body1 | Header2 | Body2 | ... +``` + +## 服务端的实现 + +通信过程已经定义清楚了,那么服务端的实现就比较直接了。 + +[day1-codec/server.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day1-codec) + +```go +// Server represents an RPC Server. +type Server struct{} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } +``` + +- 首先定义了结构体 `Server`,没有任何的成员字段。 +- 实现了 `Accept` 方式,`net.Listener` 作为参数,for 循环等待 socket 连接建立,并开启子协程处理,处理过程交给了 `ServerConn` 方法。 +- DefaultServer 是一个默认的 `Server` 实例,主要为了用户使用方便。 + +如果想启动服务,过程是非常简单的,传入 listener 即可,tcp 协议和 unix 协议都支持。 + +```go +lis, _ := net.Listen("tcp", ":9999") +geerpc.Accept(lis) +``` + +`ServeConn` 的实现就和之前讨论的通信过程紧密相关了,首先使用 `json.NewDecoder` 反序列化得到 Option 实例,检查 MagicNumber 和 CodeType 的值是否正确。然后根据 CodeType 得到对应的消息编解码器,接下来的处理交给 `serverCodec`。 + +```go +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn)) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg) + } + wg.Wait() + _ = cc.Close() +} +``` + +`serveCodec` 的过程非常简单。主要包含三个阶段 + +- 读取请求 readRequest +- 处理请求 handleRequest +- 回复请求 sendResponse + +之前提到过,在一次连接中,允许接收多个请求,即多个 request header 和 request body,因此这里使用了 for 无限制地等待请求的到来,直到发生错误(例如连接被关闭,接收到的报文有问题等),这里需要注意的点有三个: + +- handleRequest 使用了协程并发执行请求。 +- 处理请求是并发的,但是回复请求的报文必须是逐个发送的,并发容易导致多个回复报文交织在一起,客户端无法解析。在这里使用锁(sending)保证。 +- 尽力而为,只有在 header 解析失败时,才终止循环。 + +```go +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + // TODO: now we don't know the type of request argv + // day 1, just suppose it's string + req.argv = reflect.New(reflect.TypeOf("")) + if err = cc.ReadBody(req.argv.Interface()); err != nil { + log.Println("rpc server: read argv err:", err) + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) { + // TODO, should call registered rpc methods to get the right replyv + // day 1, just print argv and send a hello message + defer wg.Done() + log.Println(req.h, req.argv.Elem()) + req.replyv = reflect.ValueOf(fmt.Sprintf("geerpc resp %d", req.h.Seq)) + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) +} +``` + +目前还不能判断 body 的类型,因此在 readRequest 和 handleRequest 中,day1 将 body 作为字符串处理。接收到请求,打印 header,并回复 `geerpc resp ${req.h.Seq}`。这一部分后续再实现。 + + +## main 函数(一个简易的客户端) + +day1 的内容就到此为止了,在这里我们已经实现了一个消息的编解码器 `GobCodec`,并且客户端与服务端实现了简单的协议交换(protocol exchange),即允许客户端使用不同的编码方式。同时实现了服务端的雏形,建立连接,读取、处理并回复客户端的请求。 + +接下来,我们就在 main 函数中看看如何使用刚实现的 GeeRPC 吧。 + +[day1-codec/main/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day1-codec) + +```go +package main + +import ( + "encoding/json" + "fmt" + "geerpc" + "geerpc/codec" + "log" + "net" + "time" +) + +func startServer(addr chan string) { + // pick a free port + l, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatal("network error:", err) + } + log.Println("start rpc server on", l.Addr()) + addr <- l.Addr().String() + geerpc.Accept(l) +} + +func main() { + addr := make(chan string) + go startServer(addr) + + // in fact, following code is like a simple geerpc client + conn, _ := net.Dial("tcp", <-addr) + defer func() { _ = conn.Close() }() + + time.Sleep(time.Second) + // send options + _ = json.NewEncoder(conn).Encode(geerpc.DefaultOption) + cc := codec.NewGobCodec(conn) + // send request & receive response + for i := 0; i < 5; i++ { + h := &codec.Header{ + ServiceMethod: "Foo.Sum", + Seq: uint64(i), + } + _ = cc.Write(h, fmt.Sprintf("geerpc req %d", h.Seq)) + _ = cc.ReadHeader(h) + var reply string + _ = cc.ReadBody(&reply) + log.Println("reply:", reply) + } +} +``` + +- 在 `startServer` 中使用了信道 `addr`,确保服务端端口监听成功,客户端再发起请求。 +- 客户端首先发送 `Option` 进行协议交换,接下来发送消息头 `h := &codec.Header{}`,和消息体 `geerpc req ${h.Seq}`。 +- 最后解析服务端的响应 `reply`,并打印出来。 + +执行结果如下: + +```bash +start rpc server on [::]:63662 +&{Foo.Sum 0 } geerpc req 0 +reply: geerpc resp 0 +&{Foo.Sum 1 } geerpc req 1 +reply: geerpc resp 1 +&{Foo.Sum 2 } geerpc req 2 +reply: geerpc resp 2 +&{Foo.Sum 3 } geerpc req 3 +reply: geerpc resp 3 +&{Foo.Sum 4 } geerpc req 4 +reply: geerpc resp 4 +``` + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html) diff --git a/gee-rpc/doc/geerpc-day2.md b/gee-rpc/doc/geerpc-day2.md new file mode 100644 index 0000000..da70f3a --- /dev/null +++ b/gee-rpc/doc/geerpc-day2.md @@ -0,0 +1,401 @@ +--- +title: 动手写RPC框架 - GeeRPC第二天 支持并发与异步的客户端 +date: 2020-10-07 18:00:00 +description: 7天用 Go语言/golang 从零实现 RPC 框架 GeeRPC 教程(7 days implement golang remote procedure call framework from scratch tutorial),动手写 RPC 框架,参照 golang 标准库 net/rpc 的实现,实现了服务端(server)、支持异步和并发的客户端(client)、消息编码与解码(message encoding and decoding)、服务注册(service register)、支持 TCP/Unix/HTTP 等多种传输协议。第二天实现了一个支持异步(asynchronous)和并发(concurrent)的客户端。 +tags: +- Go +nav: 从零实现 +categories: +- RPC框架 - GeeRPC +keywords: +- Go语言 +- 从零实现RPC框架 +- 客户端 +- 异步 +- 并发 +image: post/geerpc/geerpc.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day2 高性能客户端 +--- + +![golang RPC framework](geerpc/geerpc.jpg) + +本文是[7天用Go从零实现RPC框架GeeRPC](https://geektutu.com/post/geerpc.html)的第二篇。 + +- 实现一个支持异步和并发的高性能客户端,代码约 250 行 + + +## Call 的设计 + +对 `net/rpc` 而言,一个函数需要能够被远程调用,需要满足如下五个条件: + +- the method's type is exported. +- the method is exported. +- the method has two arguments, both exported (or builtin) types. +- the method's second argument is a pointer. +- the method has return type error. + +更直观一些: + +```go +func (t *T) MethodName(argType T1, replyType *T2) error +``` + +根据上述要求,首先我们封装了结构体 Call 来承载一次 RPC 调用所需要的信息。 + +[day2-client/client.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day2-client) + +```go +// Call represents an active RPC. +type Call struct { + Seq uint64 + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} +``` + +为了支持异步调用,Call 结构体中添加了一个字段 Done,Done 的类型是 `chan *Call`,当调用结束时,会调用 `call.done()` 通知调用方。 + + +## 实现 Client + +接下来,我们将实现 GeeRPC 客户端最核心的部分 Client。 + +```go +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing { + return ErrShutdown + } + client.closing = true + return client.cc.Close() +} + +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} +``` + +Client 的字段比较复杂: + +- cc 是消息的编解码器,和服务端类似,用来序列化将要发送出去的请求,以及反序列化接收到的响应。 +- sending 是一个互斥锁,和服务端类似,为了保证请求的有序发送,即防止出现多个请求报文混淆。 +- header 是每个请求的消息头,header 只有在请求发送时才需要,而请求发送是互斥的,因此每个客户端只需要一个,声明在 Client 结构体中可以复用。 +- seq 用于给发送的请求编号,每个请求拥有唯一编号。 +- pending 存储未处理完的请求,键是编号,值是 Call 实例。 +- closing 和 shutdown 任意一个值置为 true,则表示 Client 处于不可用的状态,但有些许的差别,closing 是用户主动关闭的,即调用 `Close` 方法,而 shutdown 置为 true 一般是有错误发生。 + +紧接着,实现和 Call 相关的三个方法。 + +```go +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing || client.shutdown { + return 0, ErrShutdown + } + call.Seq = client.seq + client.pending[call.Seq] = call + client.seq++ + return call.Seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + client.shutdown = true + for _, call := range client.pending { + call.Error = err + call.done() + } +} +``` + +- registerCall:将参数 call 添加到 client.pending 中,并更新 client.seq。 +- removeCall:根据 seq,从 client.pending 中移除对应的 call,并返回。 +- terminateCalls:服务端或客户端发生错误时调用,将 shutdown 设置为 true,且将错误信息通知所有 pending 状态的 call。 + +对一个客户端端来说,接收响应、发送请求是最重要的 2 个功能。那么首先实现接收功能,接收到的响应有三种情况: + +- call 不存在,可能是请求没有发送完整,或者因为其他原因被取消,但是服务端仍旧处理了。 +- call 存在,但服务端处理出错,即 h.Error 不为空。 +- call 存在,服务端处理正常,那么需要从 body 中读取 Reply 的值。 + +```go +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} +``` + +创建 Client 实例时,首先需要完成一开始的协议交换,即发送 `Option` 信息给服务端。协商好消息的编解码方式之后,再创建一个子协程调用 `receive()` 接收响应。 + +```go +func NewClient(conn net.Conn, opt *Option) (*Client, error) { + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err := fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return nil, err + } + // send options with server + if err := json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + _ = conn.Close() + return nil, err + } + return newClientCodec(f(conn), opt), nil +} + +func newClientCodec(cc codec.Codec, opt *Option) *Client { + client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call + cc: cc, + opt: opt, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} +``` + +还需要实现 `Dial` 函数,便于用户传入服务端地址,创建 Client 实例。为了简化用户调用,通过 `...*Option` 将 Option 实现为可选参数。 + +```go +func parseOptions(opts ...*Option) (*Option, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return DefaultOption, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = DefaultOption.MagicNumber + if opt.CodecType == "" { + opt.CodecType = DefaultOption.CodecType + } + return opt, nil +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if client == nil { + _ = conn.Close() + } + }() + return NewClient(conn, opt) +} +``` + +此时,GeeRPC 客户端已经具备了完整的创建连接和接收响应的能力了,最后还需要实现发送请求的能力。 + +```go +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(serviceMethod string, args, reply interface{}) error { + call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done + return call.Error +} +``` + +- `Go` 和 `Call` 是客户端暴露给用户的两个 RPC 服务调用接口,`Go` 是一个异步接口,返回 call 实例。 +- `Call` 是对 `Go` 的封装,阻塞 call.Done,等待响应返回,是一个同步接口。 + +至此,一个支持异步和并发的 GeeRPC 客户端已经完成。 + +## Demo + +第一天 GeeRPC 只实现了服务端,因此我们在 main 函数中手动模拟了整个通信过程,今天我们就将 main 函数中通信部分替换为今天的客户端吧。 + +[day2-client/main/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day2-client) + +startServer 没有发生变化。 + +```go +func startServer(addr chan string) { + // pick a free port + l, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatal("network error:", err) + } + log.Println("start rpc server on", l.Addr()) + addr <- l.Addr().String() + geerpc.Accept(l) +} +``` + +在 main 函数中使用了 `client.Call` 并发了 5 个 RPC 同步调用,参数和返回值的类型均为 string。 + +```go +func main() { + log.SetFlags(0) + addr := make(chan string) + go startServer(addr) + client, _ := geerpc.Dial("tcp", <-addr) + defer func() { _ = client.Close() }() + + time.Sleep(time.Second) + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + args := fmt.Sprintf("geerpc req %d", i) + var reply string + if err := client.Call("Foo.Sum", args, &reply); err != nil { + log.Fatal("call Foo.Sum error:", err) + } + log.Println("reply:", reply) + }(i) + } + wg.Wait() +} +``` + +运行结果如下: + +```bash +start rpc server on [::]:50658 +&{Foo.Sum 5 } geerpc req 3 +&{Foo.Sum 1 } geerpc req 0 +&{Foo.Sum 3 } geerpc req 1 +&{Foo.Sum 2 } geerpc req 4 +&{Foo.Sum 4 } geerpc req 2 +reply: geerpc resp 1 +reply: geerpc resp 5 +reply: geerpc resp 3 +reply: geerpc resp 2 +reply: geerpc resp 4 +``` + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html) \ No newline at end of file diff --git a/gee-rpc/doc/geerpc-day3.md b/gee-rpc/doc/geerpc-day3.md new file mode 100644 index 0000000..c1b720c --- /dev/null +++ b/gee-rpc/doc/geerpc-day3.md @@ -0,0 +1,497 @@ +--- +title: 动手写RPC框架 - GeeRPC第三天 服务注册(service register) +date: 2020-10-07 19:00:00 +description: 7天用 Go语言/golang 从零实现 RPC 框架 GeeRPC 教程(7 days implement golang remote procedure call framework from scratch tutorial),动手写 RPC 框架,参照 golang 标准库 net/rpc 的实现,实现了服务端(server)、支持异步和并发的客户端(client)、消息编码与解码(message encoding and decoding)、服务注册(service register)、支持 TCP/Unix/HTTP 等多种传输协议。第三天实现了服务注册,即将 Go 语言结构体通过反射映射为服务。 +tags: +- Go +nav: 从零实现 +categories: +- RPC框架 - GeeRPC +keywords: +- Go语言 +- 从零实现RPC框架 +- 反射 +- 服务 +image: post/geerpc/geerpc.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day3 服务注册 +--- + +![golang RPC framework](geerpc/geerpc.jpg) + +本文是[7天用Go从零实现RPC框架GeeRPC](https://geektutu.com/post/geerpc.html)的第三篇。 + +- 通过反射实现服务注册功能 +- 在服务端实现服务调用,代码约 150 行 + +## 结构体映射为服务 + +RPC 框架的一个基础能力是:像调用本地程序一样调用远程服务。那如何将程序映射为服务呢?那么对 Go 来说,这个问题就变成了如何将结构体的方法映射为服务。 + +对 `net/rpc` 而言,一个函数需要能够被远程调用,需要满足如下五个条件: + +- the method's type is exported. -- 方法所属类型是导出的。 +- the method is exported. -- 方式是导出的。 +- the method has two arguments, both exported (or builtin) types. -- 两个入参,均为导出或内置类型。 +- the method's second argument is a pointer. -- 第二个入参必须是一个指针。 +- the method has return type error. -- 返回值为 error 类型。 + +更直观一些: + +```go +func (t *T) MethodName(argType T1, replyType *T2) error +``` + +假设客户端发过来一个请求,包含 ServiceMethod 和 Argv。 + +```json +{ + "ServiceMethod": "T.MethodName" + "Argv":"0101110101..." // 序列化之后的字节流 +} +``` + +通过 "T.MethodName" 可以确定调用的是类型 T 的 MethodName,如果硬编码实现这个功能,很可能是这样: + +```go +switch req.ServiceMethod { + case "T.MethodName": + t := new(t) + reply := new(T2) + var argv T1 + gob.NewDecoder(conn).Decode(&argv) + err := t.MethodName(argv, reply) + server.sendMessage(reply, err) + case "Foo.Sum": + f := new(Foo) + ... +} +``` + +也就是说,如果使用硬编码的方式来实现结构体与服务的映射,那么每暴露一个方法,就需要编写等量的代码。那有没有什么方式,能够将这个映射过程自动化呢?可以借助反射。 + +通过反射,我们能够非常容易地获取某个结构体的所有方法,并且能够通过方法,获取到该方法所有的参数类型与返回值。例如: + +```go +func main() { + var wg sync.WaitGroup + typ := reflect.TypeOf(&wg) + for i := 0; i < typ.NumMethod(); i++ { + method := typ.Method(i) + argv := make([]string, 0, method.Type.NumIn()) + returns := make([]string, 0, method.Type.NumOut()) + // j 从 1 开始,第 0 个入参是 wg 自己。 + for j := 1; j < method.Type.NumIn(); j++ { + argv = append(argv, method.Type.In(j).Name()) + } + for j := 0; j < method.Type.NumOut(); j++ { + returns = append(returns, method.Type.Out(j).Name()) + } + log.Printf("func (w *%s) %s(%s) %s", + typ.Elem().Name(), + method.Name, + strings.Join(argv, ","), + strings.Join(returns, ",")) + } +} +``` + +运行的结果是: + +```go +func (w *WaitGroup) Add(int) +func (w *WaitGroup) Done() +func (w *WaitGroup) Wait() +``` + +## 通过反射实现 service + +前面两天我们完成了客户端和服务端,客户端相对来说功能是比较完整的,但是服务端的功能并不完整,仅仅将请求的 header 打印了出来,并没有真正地处理。那今天的主要目的是补全这部分功能。首先通过反射实现结构体与服务的映射关系,代码独立放置在 `service.go` 中。 + +[day3-service/service.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day3-service) + +第一步,定义结构体 methodType: + +```go +type methodType struct { + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint64 +} + +func (m *methodType) NumCalls() uint64 { + return atomic.LoadUint64(&m.numCalls) +} + +func (m *methodType) newArgv() reflect.Value { + var argv reflect.Value + // arg may be a pointer type, or a value type + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) + } else { + argv = reflect.New(m.ArgType).Elem() + } + return argv +} + +func (m *methodType) newReplyv() reflect.Value { + // reply must be a pointer type + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) + } + return replyv +} +``` + +每一个 methodType 实例包含了一个方法的完整信息。包括 + +- method:方法本身 +- ArgType:第一个参数的类型 +- ReplyType:第二个参数的类型 +- numCalls:后续统计方法调用次数时会用到 + +另外,我们还实现了 2 个方法 `newArgv` 和 `newReplyv`,用于创建对应类型的实例。`newArgv` 方法有一个小细节,指针类型和值类型创建实例的方式有细微区别。 + +第二步,定义结构体 service: + +```go +type service struct { + name string + typ reflect.Type + rcvr reflect.Value + method map[string]*methodType +} +``` + +service 的定义也是非常简洁的,name 即映射的结构体的名称,比如 `T`,比如 `WaitGroup`;typ 是结构体的类型;rcvr 即结构体的实例本身,保留 rcvr 是因为在调用时需要 rcvr 作为第 0 个参数;method 是 map 类型,存储映射的结构体的所有符合条件的方法。 + +接下来,完成构造函数 `newService`,入参是任意需要映射为服务的结构体实例。 + +```go +func newService(rcvr interface{}) *service { + s := new(service) + s.rcvr = reflect.ValueOf(rcvr) + s.name = reflect.Indirect(s.rcvr).Type().Name() + s.typ = reflect.TypeOf(rcvr) + if !ast.IsExported(s.name) { + log.Fatalf("rpc server: %s is not a valid service name", s.name) + } + s.registerMethods() + return s +} + +func (s *service) registerMethods() { + s.method = make(map[string]*methodType) + for i := 0; i < s.typ.NumMethod(); i++ { + method := s.typ.Method(i) + mType := method.Type + if mType.NumIn() != 3 || mType.NumOut() != 1 { + continue + } + if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + continue + } + argType, replyType := mType.In(1), mType.In(2) + if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { + continue + } + s.method[method.Name] = &methodType{ + method: method, + ArgType: argType, + ReplyType: replyType, + } + log.Printf("rpc server: register %s.%s\n", s.name, method.Name) + } +} + +func isExportedOrBuiltinType(t reflect.Type) bool { + return ast.IsExported(t.Name()) || t.PkgPath() == "" +} +``` + +`registerMethods` 过滤出了符合条件的方法: + +- 两个导出或内置类型的入参(反射时为 3 个,第 0 个是自身,类似于 python 的 self,java 中的 this) +- 返回值有且只有 1 个,类型为 error + +最后,我们还需要实现 `call` 方法,即能够通过反射值调用方法。 + +```go +func (s *service) call(m *methodType, argv, replyv reflect.Value) error { + atomic.AddUint64(&m.numCalls, 1) + f := m.method.Func + returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) + if errInter := returnValues[0].Interface(); errInter != nil { + return errInter.(error) + } + return nil +} +``` + +## service 的测试用例 + +为了保证 service 实现的正确性,我们为 service.go 写了几个测试用例。 + +[day3-service/service_test.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day3-service) + +定义结构体 Foo,实现 2 个方法,导出方法 Sum 和 非导出方法 sum。 + +```go +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +// it's not a exported Method +func (f Foo) sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} +``` + +测试 newService 和 call 方法。 + +```go +func TestNewService(t *testing.T) { + var foo Foo + s := newService(&foo) + _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method)) + mType := s.method["Sum"] + _assert(mType != nil, "wrong Method, Sum shouldn't nil") +} + +func TestMethodType_Call(t *testing.T) { + var foo Foo + s := newService(&foo) + mType := s.method["Sum"] + + argv := mType.newArgv() + replyv := mType.newReplyv() + argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3})) + err := s.call(mType, argv, replyv) + _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum") +} +``` + +## 集成到服务端 + +通过反射结构体已经映射为服务,但请求的处理过程还没有完成。从接收到请求到回复还差以下几个步骤:第一步,根据入参类型,将请求的 body 反序列化;第二步,调用 `service.call`,完成方法调用;第三步,将 reply 序列化为字节流,构造响应报文,返回。 + +回到代码本身,补全之前在 `server.go` 中遗留的 2 个 TODO 任务 `readRequest` 和 `handleRequest` 即可。 + +在这之前,我们还需要为 Server 实现一个方法 `Register`。 + +[day3-service/server.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day3-service) + +```go +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map +} + +// Register publishes in the server the set of methods of the +func (server *Server) Register(rcvr interface{}) error { + s := newService(rcvr) + if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { + return errors.New("rpc: service already defined: " + s.name) + } + return nil +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } +``` + +配套实现 `findService` 方法,即通过 `ServiceMethod` 从 serviceMap 中找到对应的 service + +```go +func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { + dot := strings.LastIndex(serviceMethod, ".") + if dot < 0 { + err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) + return + } + serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc server: can't find service " + serviceName) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc server: can't find method " + methodName) + } + return +} +``` + +`findService` 的实现看似比较繁琐,但是逻辑还是非常清晰的。因为 ServiceMethod 的构成是 "Service.Method",因此先将其分割成 2 部分,第一部分是 Service 的名称,第二部分即方法名。现在 serviceMap 中找到对应的 service 实例,再从 service 实例的 method 中,找到对应的 methodType。 + +准备工具已经就绪,我们首先补全 readRequest 方法: + +```go +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request + mtype *methodType + svc *service +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + req.svc, req.mtype, err = server.findService(h.ServiceMethod) + if err != nil { + return req, err + } + req.argv = req.mtype.newArgv() + req.replyv = req.mtype.newReplyv() + + // make sure that argvi is a pointer, ReadBody need a pointer as parameter + argvi := req.argv.Interface() + if req.argv.Type().Kind() != reflect.Ptr { + argvi = req.argv.Addr().Interface() + } + if err = cc.ReadBody(argvi); err != nil { + log.Println("rpc server: read body err:", err) + return req, err + } + return req, nil +} +``` + +readRequest 方法中最重要的部分,即通过 `newArgv()` 和 `newReplyv()` 两个方法创建出两个入参实例,然后通过 `cc.ReadBody()` 将请求报文反序列化为第一个入参 argv,在这里同样需要注意 argv 可能是值类型,也可能是指针类型,所以处理方式有点差异。 + +接下来补全 handleRequest 方法: + +```go +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) { + defer wg.Done() + err := req.svc.call(req.mtype, req.argv, req.replyv) + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) +} +``` + +相对于 readRequest,handleRequest 的实现非常简单,通过 `req.svc.call` 完成方法调用,将 replyv 传递给 sendResponse 完成序列化即可。 + +到这里,今天的所有内容已经实现完成,成功在服务端实现了服务注册与调用。 + +## Demo + +最后,还是需要写一个可执行程序(main)验证今天的成果。 + +[day3-service/main/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day3-service) + +第一步,定义结构体 Foo 和方法 Sum + +```go +package main + +import ( + "geerpc" + "log" + "net" + "sync" + "time" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} +``` + +第二步,注册 Foo 到 Server 中,并启动 RPC 服务 + +```go +func startServer(addr chan string) { + var foo Foo + if err := geerpc.Register(&foo); err != nil { + log.Fatal("register error:", err) + } + // pick a free port + l, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatal("network error:", err) + } + log.Println("start rpc server on", l.Addr()) + addr <- l.Addr().String() + geerpc.Accept(l) +} +``` + +第三步,构造参数,发送 RPC 请求,并打印结果。 + +```go +func main() { + log.SetFlags(0) + addr := make(chan string) + go startServer(addr) + client, _ := geerpc.Dial("tcp", <-addr) + defer func() { _ = client.Close() }() + + time.Sleep(time.Second) + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + args := &Args{Num1: i, Num2: i * i} + var reply int + if err := client.Call("Foo.Sum", args, &reply); err != nil { + log.Fatal("call Foo.Sum error:", err) + } + log.Printf("%d + %d = %d", args.Num1, args.Num2, reply) + }(i) + } + wg.Wait() +} +``` + +运行结果如下: + +```bash +rpc server: register Foo.Sum +start rpc server on [::]:57509 +1 + 1 = 2 +2 + 4 = 6 +3 + 9 = 12 +0 + 0 = 0 +4 + 16 = 20 +``` + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html) diff --git a/gee-rpc/doc/geerpc-day4.md b/gee-rpc/doc/geerpc-day4.md new file mode 100644 index 0000000..4d89c05 --- /dev/null +++ b/gee-rpc/doc/geerpc-day4.md @@ -0,0 +1,270 @@ +--- +title: 动手写RPC框架 - GeeRPC第四天 超时处理(timeout) +date: 2020-10-07 23:00:00 +description: 7天用 Go语言/golang 从零实现 RPC 框架 GeeRPC 教程(7 days implement golang remote procedure call framework from scratch tutorial),动手写 RPC 框架,参照 golang 标准库 net/rpc 的实现,实现了服务端(server)、支持异步和并发的客户端(client)、消息编码与解码(message encoding and decoding)、服务注册(service register)、支持 TCP/Unix/HTTP 等多种传输协议。第四天为RPC框架提供了处理超时的能力(timeout processing)。 +tags: +- Go +nav: 从零实现 +categories: +- RPC框架 - GeeRPC +keywords: +- Go语言 +- 从零实现RPC框架 +- 连接超时 +image: post/geerpc/geerpc.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day4 超时处理 +--- + +![golang RPC framework](geerpc/geerpc.jpg) + +本文是[7天用Go从零实现RPC框架GeeRPC](https://geektutu.com/post/geerpc.html)的第四篇。 + +- 增加连接超时的处理机制 +- 增加服务端处理超时的处理机制,代码约 100 行 + +## 为什么需要超时处理机制 + +超时处理是 RPC 框架一个比较基本的能力,如果缺少超时处理机制,无论是服务端还是客户端都容易因为网络或其他错误导致挂死,资源耗尽,这些问题的出现大大地降低了服务的可用性。因此,我们需要在 RPC 框架中加入超时处理的能力。 + +纵观整个远程调用的过程,需要客户端处理超时的地方有: + +- 与服务端建立连接,导致的超时 +- 发送请求到服务端,写报文导致的超时 +- 等待服务端处理时,等待处理导致的超时(比如服务端已挂死,迟迟不响应) +- 从服务端接收响应时,读报文导致的超时 + +需要服务端处理超时的地方有: + +- 读取客户端请求报文时,读报文导致的超时 +- 发送响应报文时,写报文导致的超时 +- 调用映射服务的方法时,处理报文导致的超时 + + +GeeRPC 在 3 个地方添加了超时处理机制。分别是: + +1)客户端创建连接时 +2)客户端 `Client.Call()` 整个过程导致的超时(包含发送报文,等待处理,接收报文所有阶段) +3)服务端处理报文,即 `Server.handleRequest` 超时。 + +## 创建连接超时 + +为了实现上的简单,将超时设定放在了 Option 中。ConnectTimeout 默认值为 10s,HandleTimeout 默认值为 0,即不设限。 + +```go +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body + ConnectTimeout time.Duration // 0 means no limit + HandleTimeout time.Duration +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, + ConnectTimeout: time.Second * 10, +} +``` + +客户端连接超时,只需要为 Dial 添加一层超时处理的外壳即可。 + +[day4-timeout/client.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day4-timeout) + +```go +type clientResult struct { + client *Client + err error +} + +type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error) + +func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + conn, err := net.DialTimeout(network, address, opt.ConnectTimeout) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if err != nil { + _ = conn.Close() + } + }() + ch := make(chan clientResult) + go func() { + client, err := f(conn, opt) + ch <- clientResult{client: client, err: err} + }() + if opt.ConnectTimeout == 0 { + result := <-ch + return result.client, result.err + } + select { + case <-time.After(opt.ConnectTimeout): + return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout) + case result := <-ch: + return result.client, result.err + } +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(NewClient, network, address, opts...) +} +``` + +在这里实现了一个超时处理的外壳 `dialTimeout`,这个壳将 NewClient 作为入参,在 2 个地方添加了超时处理的机制。 + +1) 将 `net.Dial` 替换为 `net.DialTimeout`,如果连接创建超时,将返回错误。 +2)使用子协程执行 NewClient,执行完成后则通过信道 ch 发送结果,如果 `time.After()` 信道先接收到消息,则说明 NewClient 执行超时,返回错误。 + +## Client.Call 超时 + +`Client.Call` 的超时处理机制,使用 context 包实现,控制权交给用户,控制更为灵活。 + +```go +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + call := client.Go(serviceMethod, args, reply, make(chan *Call, 1)) + select { + case <-ctx.Done(): + client.removeCall(call.Seq) + return errors.New("rpc client: call failed: " + ctx.Err().Error()) + case call := <-call.Done: + return call.Error + } +} +``` + +用户可以使用 `context.WithTimeout` 创建具备超时检测能力的 context 对象来控制。例如: + +```go +ctx, _ := context.WithTimeout(context.Background(), time.Second) +var reply int +err := client.Call(ctx, "Foo.Sum", &Args{1, 2}, &reply) +... +``` + +## 服务端处理超时 + +这一部分的实现与客户端很接近,使用 `time.After()` 结合 `select+chan` 完成。 + +[day4-timeout/server.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day4-timeout) + +```go +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { + defer wg.Done() + called := make(chan struct{}) + sent := make(chan struct{}) + go func() { + err := req.svc.call(req.mtype, req.argv, req.replyv) + called <- struct{}{} + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + sent <- struct{}{} + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) + sent <- struct{}{} + }() + + if timeout == 0 { + <-called + <-sent + return + } + select { + case <-time.After(timeout): + req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout) + server.sendResponse(cc, req.h, invalidRequest, sending) + case <-called: + <-sent + } +} +``` + +这里需要确保 `sendResponse` 仅调用一次,因此将整个过程拆分为 `called` 和 `sent` 两个阶段,在这段代码中只会发生如下两种情况: + +1) called 信道接收到消息,代表处理没有超时,继续执行 sendResponse。 +2) `time.After()` 先于 called 接收到消息,说明处理已经超时,called 和 sent 都将被阻塞。在 `case <-time.After(timeout)` 处调用 `sendResponse`。 + +## 测试用例 + +第一个测试用例,用于测试连接超时。NewClient 函数耗时 2s,ConnectionTimeout 分别设置为 1s 和 0 两种场景。 + +[day4-timeout/client_test.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day4-timeout) + +```go +func TestClient_dialTimeout(t *testing.T) { + t.Parallel() + l, _ := net.Listen("tcp", ":0") + + f := func(conn net.Conn, opt *Option) (client *Client, err error) { + _ = conn.Close() + time.Sleep(time.Second * 2) + return nil, nil + } + t.Run("timeout", func(t *testing.T) { + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error") + }) + t.Run("0", func(t *testing.T) { + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0}) + _assert(err == nil, "0 means no limit") + }) +} +``` + +第二个测试用例,用于测试处理超时。`Bar.Timeout` 耗时 2s,场景一:客户端设置超时时间为 1s,服务端无限制;场景二,服务端设置超时时间为1s,客户端无限制。 + +```go +type Bar int + +func (b Bar) Timeout(argv int, reply *int) error { + time.Sleep(time.Second * 2) + return nil +} + +func startServer(addr chan string) { + var b Bar + _ = Register(&b) + // pick a free port + l, _ := net.Listen("tcp", ":0") + addr <- l.Addr().String() + Accept(l) +} + +func TestClient_Call(t *testing.T) { + t.Parallel() + addrCh := make(chan string) + go startServer(addrCh) + addr := <-addrCh + time.Sleep(time.Second) + t.Run("client timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + var reply int + err := client.Call(ctx, "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error") + }) + t.Run("server handle timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr, &Option{ + HandleTimeout: time.Second, + }) + var reply int + err := client.Call(context.Background(), "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error") + }) +} +``` + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html) \ No newline at end of file diff --git a/gee-rpc/doc/geerpc-day5.md b/gee-rpc/doc/geerpc-day5.md new file mode 100644 index 0000000..c6813da --- /dev/null +++ b/gee-rpc/doc/geerpc-day5.md @@ -0,0 +1,376 @@ +--- +title: 动手写RPC框架 - GeeRPC第五天 支持HTTP协议 +date: 2020-10-08 11:00:00 +description: 7天用 Go语言/golang 从零实现 RPC 框架 GeeRPC 教程(7 days implement golang remote procedure call framework from scratch tutorial),动手写 RPC 框架,参照 golang 标准库 net/rpc 的实现,实现了服务端(server)、支持异步和并发的客户端(client)、消息编码与解码(message encoding and decoding)、服务注册(service register)、支持 TCP/Unix/HTTP 等多种传输协议。第五天支持了 HTTP 协议,并且提供了一个简单的 DEBUG 页面。 +tags: +- Go +nav: 从零实现 +categories: +- RPC框架 - GeeRPC +keywords: +- Go语言 +- 从零实现RPC框架 +- HTTP +- debug +image: post/geerpc/geerpc.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day5 支持HTTP协议 +--- + +![golang RPC framework](geerpc/geerpc.jpg) + +本文是[7天用Go从零实现RPC框架GeeRPC](https://geektutu.com/post/geerpc.html)的第五篇。 + +- 支持 HTTP 协议 +- 基于 HTTP 实现一个简单的 Debug 页面,代码约 150 行。 + +## 支持 HTTP 协议需要做什么? + +Web 开发中,我们经常使用 HTTP 协议中的 HEAD、GET、POST 等方式发送请求,等待响应。但 RPC 的消息格式与标准的 HTTP 协议并不兼容,在这种情况下,就需要一个协议的转换过程。HTTP 协议的 CONNECT 方法恰好提供了这个能力,CONNECT 一般用于代理服务。 + +假设浏览器与服务器之间的 HTTPS 通信都是加密的,浏览器通过代理服务器发起 HTTPS 请求时,由于请求的站点地址和端口号都是加密保存在 HTTPS 请求报文头中的,代理服务器如何知道往哪里发送请求呢?为了解决这个问题,浏览器通过 HTTP 明文形式向代理服务器发送一个 CONNECT 请求告诉代理服务器目标地址和端口,代理服务器接收到这个请求后,会在对应端口与目标站点建立一个 TCP 连接,连接建立成功后返回 HTTP 200 状态码告诉浏览器与该站点的加密通道已经完成。接下来代理服务器仅需透传浏览器和服务器之间的加密数据包即可,代理服务器无需解析 HTTPS 报文。 + +举一个简单例子: + +1) 浏览器向代理服务器发送 CONNECT 请求。 + +```bash +CONNECT geektutu.com:443 HTTP/1.0 +``` + +2) 代理服务器返回 HTTP 200 状态码表示连接已经建立。 + +```bash +HTTP/1.0 200 Connection Established +``` + +3) 之后浏览器和服务器开始 HTTPS 握手并交换加密数据,代理服务器只负责传输彼此的数据包,并不能读取具体数据内容(代理服务器也可以选择安装可信根证书解密 HTTPS 报文)。 + +事实上,这个过程其实是通过代理服务器将 HTTP 协议转换为 HTTPS 协议的过程。对 RPC 服务端来,需要做的是将 HTTP 协议转换为 RPC 协议,对客户端来说,需要新增通过 HTTP CONNECT 请求创建连接的逻辑。 + + +## 服务端支持 HTTP 协议 + +那通信过程应该是这样的: + +1) 客户端向 RPC 服务器发送 CONNECT 请求 + +```bash +CONNECT 10.0.0.1:9999/_geerpc_ HTTP/1.0 +``` + +2) RPC 服务器返回 HTTP 200 状态码表示连接建立。 + +```bash +HTTP/1.0 200 Connected to Gee RPC +``` + +3) 客户端使用创建好的连接发送 RPC 报文,先发送 Option,再发送 N 个请求报文,服务端处理 RPC 请求并响应。 + +在 `server.go` 中新增如下的方法: + +[day5-http-debug/server.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day5-http-debug) + +```go +const ( + connected = "200 Connected to Gee RPC" + defaultRPCPath = "/_geeprc_" + defaultDebugPath = "/debug/geerpc" +) + +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = io.WriteString(w, "405 must CONNECT\n") + return + } + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) + return + } + _, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP() { + http.Handle(defaultRPCPath, server) +} + +// HandleHTTP is a convenient approach for default server to register HTTP handlers +func HandleHTTP() { + DefaultServer.HandleHTTP() +} +``` + +`defaultDebugPath` 是为后续 DEBUG 页面预留的地址。 + +在 Go 语言中处理 HTTP 请求是非常简单的一件事,Go 标准库中 `http.Handle` 的实现如下: + +```go +package http +// Handle registers the handler for the given pattern +// in the DefaultServeMux. +// The documentation for ServeMux explains how patterns are matched. +func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } +``` + +第一个参数是支持通配的字符串 pattern,在这里,我们固定传入 `/_geeprc_`,第二个参数是 Handler 类型,Handler 是一个接口类型,定义如下: + +```go +type Handler interface { + ServeHTTP(w ResponseWriter, r *Request) +} +``` + +也就是说,只需要实现接口 Handler 即可作为一个 HTTP Handler 处理 HTTP 请求。接口 Handler 只定义了一个方法 `ServeHTTP`,实现该方法即可。 + +> 关于 http.Handler 的更多信息,推荐阅读 [Go语言动手写Web框架 - Gee第一天 http.Handler](https://geektutu.com/post/gee-day1.html) + +## 客户端支持 HTTP 协议 + +服务端已经能够接受 CONNECT 请求,并返回了 200 状态码 `HTTP/1.0 200 Connected to Gee RPC`,客户端要做的,发起 CONNECT 请求,检查返回状态码即可成功建立连接。 + +[day5-http-debug/client.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day5-http-debug) + +```go +// NewHTTPClient new a Client instance via HTTP as transport protocol +func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) { + _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)) + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err == nil && resp.Status == connected { + return NewClient(conn, opt) + } + if err == nil { + err = errors.New("unexpected HTTP response: " + resp.Status) + } + return nil, err +} + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(NewHTTPClient, network, address, opts...) +} +``` + +通过 HTTP CONNECT 请求建立连接之后,后续的通信过程就交给 NewClient 了。 + +为了简化调用,提供了一个统一入口 `XDial` + +```go +// XDial calls different functions to connect to a RPC server +// according the first parameter rpcAddr. +// rpcAddr is a general format (protocol@addr) to represent a rpc server +// eg, http@10.0.0.1:7001, tcp@10.0.0.1:9999, unix@/tmp/geerpc.sock +func XDial(rpcAddr string, opts ...*Option) (*Client, error) { + parts := strings.Split(rpcAddr, "@") + if len(parts) != 2 { + return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr) + } + protocol, addr := parts[0], parts[1] + switch protocol { + case "http": + return DialHTTP("tcp", addr, opts...) + default: + // tcp, unix or other transport protocol + return Dial(protocol, addr, opts...) + } +} +``` + +添加一个测试用例试一试,这个测试用例使用了 unix 协议创建 socket 连接,适用于本机内部的通信,使用上与 TCP 协议并无区别。 + +[day5-http-debug/client_test.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day5-http-debug) + +```go +func TestXDial(t *testing.T) { + if runtime.GOOS == "linux" { + ch := make(chan struct{}) + addr := "/tmp/geerpc.sock" + go func() { + _ = os.Remove(addr) + l, err := net.Listen("unix", addr) + if err != nil { + t.Fatal("failed to listen unix socket") + } + ch <- struct{}{} + Accept(l) + }() + <-ch + _, err := XDial("unix@" + addr) + _assert(err == nil, "failed to connect unix socket") + } +} +``` + + +## 实现简单的 DEBUG 页面 + +支持 HTTP 协议的好处在于,RPC 服务仅仅使用了监听端口的 `/_geerpc` 路径,在其他路径上我们可以提供诸如日志、统计等更为丰富的功能。接下来我们在 `/debug/geerpc` 上展示服务的调用统计视图。 + +[day5-http-debug/debug.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day5-http-debug) + +```go +package geerpc + +import ( + "fmt" + "html/template" + "net/http" +) + +const debugText = ` + + GeeRPC Services + {{range .}} +
+ Service {{.Name}} +
+ + + {{range $name, $mtype := .Method}} + + + + + {{end}} +
MethodCalls
{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error{{$mtype.NumCalls}}
+ {{end}} + + ` + +var debug = template.Must(template.New("RPC debug").Parse(debugText)) + +type debugHTTP struct { + *Server +} + +type debugService struct { + Name string + Method map[string]*methodType +} + +// Runs at /debug/geerpc +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Build a sorted version of the data. + var services []debugService + server.serviceMap.Range(func(namei, svci interface{}) bool { + svc := svci.(*service) + services = append(services, debugService{ + Name: namei.(string), + Method: svc.method, + }) + return true + }) + err := debug.Execute(w, services) + if err != nil { + _, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error()) + } +} +``` + +在这里,我们将返回一个 HTML 报文,这个报文将展示注册所有的 service 的每一个方法的调用情况。 + +将 debugHTTP 实例绑定到地址 `/debug/geerpc`。 + +```go +func (server *Server) HandleHTTP() { + http.Handle(defaultRPCPath, server) + http.Handle(defaultDebugPath, debugHTTP{server}) + log.Println("rpc server debug path:", defaultDebugPath) +} +``` + +## Demo + +OK,我们已经迫不及待地想看看最终的效果了。 + +[day5-http-debug/main/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day5-http-debug) + +和之前的例子相比较,将 startServer 中的 `geerpc.Accept()` 替换为了 `geerpc.HandleHTTP()`,端口固定为 9999。 + +```go +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func startServer(addrCh chan string) { + var foo Foo + l, _ := net.Listen("tcp", ":9999") + _ = geerpc.Register(&foo) + geerpc.HandleHTTP() + addrCh <- l.Addr().String() + _ = http.Serve(l, nil) +} +``` + +客户端将 `Dial` 替换为 `DialHTTP`,其余地方没有发生改变。 + +```go +func call(addrCh chan string) { + client, _ := geerpc.DialHTTP("tcp", <-addrCh) + defer func() { _ = client.Close() }() + + time.Sleep(time.Second) + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + args := &Args{Num1: i, Num2: i * i} + var reply int + if err := client.Call(context.Background(), "Foo.Sum", args, &reply); err != nil { + log.Fatal("call Foo.Sum error:", err) + } + log.Printf("%d + %d = %d", args.Num1, args.Num2, reply) + }(i) + } + wg.Wait() +} + +func main() { + log.SetFlags(0) + ch := make(chan string) + go call(ch) + startServer(ch) +} +``` + +main 函数中,我们在最后调用 `startServer`,服务启动后将一直等待。 + +运行结果如下: + +```bash +main$ go run . +rpc server: register Foo.Sum +rpc server debug path: /debug/geerpc +3 + 9 = 12 +2 + 4 = 6 +4 + 16 = 20 +0 + 0 = 0 +1 + 1 = 2 +``` + +服务已经启动,此时我们如果在浏览器中访问 `localhost:9999/debug/geerpc`,将会看到: + +![geerpc services debug](geerpc-day5/geerpc_debug.png) + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html) \ No newline at end of file diff --git a/gee-rpc/doc/geerpc-day5/geerpc_debug.png b/gee-rpc/doc/geerpc-day5/geerpc_debug.png new file mode 100644 index 0000000..55a396d Binary files /dev/null and b/gee-rpc/doc/geerpc-day5/geerpc_debug.png differ diff --git a/gee-rpc/doc/geerpc-day6.md b/gee-rpc/doc/geerpc-day6.md new file mode 100644 index 0000000..7b0ab97 --- /dev/null +++ b/gee-rpc/doc/geerpc-day6.md @@ -0,0 +1,440 @@ +--- +title: 动手写RPC框架 - GeeRPC第六天 负载均衡(load balance) +date: 2020-10-08 14:00:00 +description: 7天用 Go语言/golang 从零实现 RPC 框架 GeeRPC 教程(7 days implement golang remote procedure call framework from scratch tutorial),动手写 RPC 框架,参照 golang 标准库 net/rpc 的实现,实现了服务端(server)、支持异步和并发的客户端(client)、消息编码与解码(message encoding and decoding)、服务注册(service register)、支持 TCP/Unix/HTTP 等多种传输协议。第六天实现了2种简单的负载均衡(load balance)算法,随机选择和 Round Robin 轮询调度算法。 +tags: +- Go +nav: 从零实现 +categories: +- RPC框架 - GeeRPC +keywords: +- Go语言 +- 从零实现RPC框架 +- 负载均衡 +- 轮询调度 +image: post/geerpc/geerpc.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day6 负载均衡 +--- + +![golang RPC framework](geerpc/geerpc.jpg) + +本文是[7天用Go从零实现RPC框架GeeRPC](https://geektutu.com/post/geerpc.html)的第六篇。 + +- 通过随机选择和 Round Robin 轮询调度算法实现服务端负载均衡,代码约 250 行 + +## 负载均衡策略 + +假设有多个服务实例,每个实例提供相同的功能,为了提高整个系统的吞吐量,每个实例部署在不同的机器上。客户端可以选择任意一个实例进行调用,获取想要的结果。那如何选择呢?取决了负载均衡的策略。对于 RPC 框架来说,我们可以很容易地想到这么几种策略: + +- 随机选择策略 - 从服务列表中随机选择一个。 +- 轮询算法(Round Robin) - 依次调度不同的服务器,每次调度执行 i = (i + 1) mode n。 +- 加权轮询(Weight Round Robin) - 在轮询算法的基础上,为每个服务实例设置一个权重,高性能的机器赋予更高的权重,也可以根据服务实例的当前的负载情况做动态的调整,例如考虑最近5分钟部署服务器的 CPU、内存消耗情况。 +- 哈希/一致性哈希策略 - 依据请求的某些特征,计算一个 hash 值,根据 hash 值将请求发送到对应的机器。一致性 hash 还可以解决服务实例动态添加情况下,调度抖动的问题。一致性哈希的一个典型应用场景是分布式缓存服务。感兴趣可以阅读[动手写分布式缓存 - GeeCache第四天 一致性哈希(hash)](https://geektutu.com/post/geecache-day4.html) +- ... + +## 服务发现 + +负载均衡的前提是有多个服务实例,那我们首先实现一个最基础的服务发现模块 Discovery。为了与通信部分解耦,这部分的代码统一放置在 xclient 子目录下。 + +定义 2 个类型: + +- SelectMode 代表不同的负载均衡策略,简单起见,GeeRPC 仅实现 Random 和 RoundRobin 两种策略。 +- Discovery 是一个接口类型,包含了服务发现所需要的最基本的接口。 + - Refresh() 从注册中心更新服务列表 + - Update(servers []string) 手动更新服务列表 + - Get(mode SelectMode) 根据负载均衡策略,选择一个服务实例 + - GetAll() 返回所有的服务实例 + +[day6-load-balance/xclient/discovery.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day6-load-balance) + +```go +package xclient + +import ( + "errors" + "math" + "math/rand" + "sync" + "time" +) + +type SelectMode int + +const ( + RandomSelect SelectMode = iota // select randomly + RoundRobinSelect // select using Robbin algorithm +) + +type Discovery interface { + Refresh() error // refresh from remote registry + Update(servers []string) error + Get(mode SelectMode) (string, error) + GetAll() ([]string, error) +} +``` + +紧接着,我们实现一个不需要注册中心,服务列表由手工维护的服务发现的结构体:MultiServersDiscovery + +```go +// MultiServersDiscovery is a discovery for multi servers without a registry center +// user provides the server addresses explicitly instead +type MultiServersDiscovery struct { + r *rand.Rand // generate random number + mu sync.RWMutex // protect following + servers []string + index int // record the selected position for robin algorithm +} + +// NewMultiServerDiscovery creates a MultiServersDiscovery instance +func NewMultiServerDiscovery(servers []string) *MultiServersDiscovery { + d := &MultiServersDiscovery{ + servers: servers, + r: rand.New(rand.NewSource(time.Now().UnixNano())), + } + d.index = d.r.Intn(math.MaxInt32 - 1) + return d +} +``` + +- r 是一个产生随机数的实例,初始化时使用时间戳设定随机数种子,避免每次产生相同的随机数序列。 +- index 记录 Round Robin 算法已经轮询到的位置,为了避免每次从 0 开始,初始化时随机设定一个值。 + +然后,实现 Discovery 接口 + +```go +var _ Discovery = (*MultiServersDiscovery)(nil) + +// Refresh doesn't make sense for MultiServersDiscovery, so ignore it +func (d *MultiServersDiscovery) Refresh() error { + return nil +} + +// Update the servers of discovery dynamically if needed +func (d *MultiServersDiscovery) Update(servers []string) error { + d.mu.Lock() + defer d.mu.Unlock() + d.servers = servers + return nil +} + +// Get a server according to mode +func (d *MultiServersDiscovery) Get(mode SelectMode) (string, error) { + d.mu.Lock() + defer d.mu.Unlock() + n := len(d.servers) + if n == 0 { + return "", errors.New("rpc discovery: no available servers") + } + switch mode { + case RandomSelect: + return d.servers[d.r.Intn(n)], nil + case RoundRobinSelect: + s := d.servers[d.index%n] // servers could be updated, so mode n to ensure safety + d.index = (d.index + 1) % n + return s, nil + default: + return "", errors.New("rpc discovery: not supported select mode") + } +} + +// returns all servers in discovery +func (d *MultiServersDiscovery) GetAll() ([]string, error) { + d.mu.RLock() + defer d.mu.RUnlock() + // return a copy of d.servers + servers := make([]string, len(d.servers), len(d.servers)) + copy(servers, d.servers) + return servers, nil +} +``` + +## 支持负载均衡的客户端 + +接下来,我们向用户暴露一个支持负载均衡的客户端 XClient。 + +[day6-load-balance/xclient/xclient.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day6-load-balance) + +```go +package xclient + +import ( + "context" + . "geerpc" + "io" + "reflect" + "sync" +) + +type XClient struct { + d Discovery + mode SelectMode + opt *Option + mu sync.Mutex // protect following + clients map[string]*Client +} + +var _ io.Closer = (*XClient)(nil) + +func NewXClient(d Discovery, mode SelectMode, opt *Option) *XClient { + return &XClient{d: d, mode: mode, opt: opt, clients: make(map[string]*Client)} +} + +func (xc *XClient) Close() error { + xc.mu.Lock() + defer xc.mu.Unlock() + for key, client := range xc.clients { + // I have no idea how to deal with error, just ignore it. + _ = client.Close() + delete(xc.clients, key) + } + return nil +} +``` + +XClient 的构造函数需要传入三个参数,服务发现实例 Discovery、负载均衡模式 SelectMode 以及协议选项 Option。为了尽量地复用已经创建好的 Socket 连接,使用 clients 保存创建成功的 Client 实例,并提供 Close 方法在结束后,关闭已经建立的连接。 + +接下来,实现客户端最基本的功能 `Call`。 + +```go +func (xc *XClient) dial(rpcAddr string) (*Client, error) { + xc.mu.Lock() + defer xc.mu.Unlock() + client, ok := xc.clients[rpcAddr] + if ok && !client.IsAvailable() { + _ = client.Close() + delete(xc.clients, rpcAddr) + client = nil + } + if client == nil { + var err error + client, err = XDial(rpcAddr, xc.opt) + if err != nil { + return nil, err + } + xc.clients[rpcAddr] = client + } + return client, nil +} + +func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod string, args, reply interface{}) error { + client, err := xc.dial(rpcAddr) + if err != nil { + return err + } + return client.Call(ctx, serviceMethod, args, reply) +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +// xc will choose a proper server. +func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + rpcAddr, err := xc.d.Get(xc.mode) + if err != nil { + return err + } + return xc.call(rpcAddr, ctx, serviceMethod, args, reply) +} +``` + +我们将复用 Client 的能力封装在方法 `dial` 中,dial 的处理逻辑如下: + +1) 检查 `xc.clients` 是否有缓存的 Client,如果有,检查是否是可用状态,如果是则返回缓存的 Client,如果不可用,则从缓存中删除。 +2) 如果步骤 1) 没有返回缓存的 Client,则说明需要创建新的 Client,缓存并返回。 + +另外,我们为 XClient 添加一个常用功能:`Broadcast`。 + +```go +// Broadcast invokes the named function for every server registered in discovery +func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error { + servers, err := xc.d.GetAll() + if err != nil { + return err + } + var wg sync.WaitGroup + var mu sync.Mutex // protect e and replyDone + var e error + replyDone := reply == nil // if reply is nil, don't need to set value + ctx, cancel := context.WithCancel(ctx) + for _, rpcAddr := range servers { + wg.Add(1) + go func(rpcAddr string) { + defer wg.Done() + var clonedReply interface{} + if reply != nil { + clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface() + } + err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply) + mu.Lock() + if err != nil && e == nil { + e = err + cancel() // if any call failed, cancel unfinished calls + } + if err == nil && !replyDone { + reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem()) + replyDone = true + } + mu.Unlock() + }(rpcAddr) + } + wg.Wait() + return e +} +``` + +Broadcast 将请求广播到所有的服务实例,如果任意一个实例发生错误,则返回其中一个错误;如果调用成功,则返回其中一个的结果。有以下几点需要注意: + +1) 为了提升性能,请求是并发的。 +2) 并发情况下需要使用互斥锁保证 error 和 reply 能被正确赋值。 +3) 借助 context.WithCancel 确保有错误发生时,快速失败。 + +## Demo + +又到了 Demo 环节,我们还是借助一个简单的 Demo 验证今天的成果吧。 + +首先,启动 RPC 服务的代码还是类似的,Sum 是正常的方法,Sleep 用于验证 XClient 的超时机制能否正常运作。 + +[day6-load-balance/main/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day6-load-balance) + +```go +package main + +import ( + "context" + "geerpc" + "geerpc/xclient" + "log" + "net" + "sync" + "time" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func (f Foo) Sleep(args Args, reply *int) error { + time.Sleep(time.Second * time.Duration(args.Num1)) + *reply = args.Num1 + args.Num2 + return nil +} + +func startServer(addrCh chan string) { + var foo Foo + l, _ := net.Listen("tcp", ":0") + server := geerpc.NewServer() + _ = server.Register(&foo) + addrCh <- l.Addr().String() + server.Accept(l) +} +``` + +封装一个方法 `foo`,便于在 `Call` 或 `Broadcast` 之后统一打印成功或失败的日志。 + +```go +func foo(xc *xclient.XClient, ctx context.Context, typ, serviceMethod string, args *Args) { + var reply int + var err error + switch typ { + case "call": + err = xc.Call(ctx, serviceMethod, args, &reply) + case "broadcast": + err = xc.Broadcast(ctx, serviceMethod, args, &reply) + } + if err != nil { + log.Printf("%s %s error: %v", typ, serviceMethod, err) + } else { + log.Printf("%s %s success: %d + %d = %d", typ, serviceMethod, args.Num1, args.Num2, reply) + } +} +``` + +call 调用单个服务实例,broadcast 调用所有服务实例 + +```go +func call(addr1, addr2 string) { + d := xclient.NewMultiServerDiscovery([]string{"tcp@" + addr1, "tcp@" + addr2}) + xc := xclient.NewXClient(d, xclient.RandomSelect, nil) + defer func() { _ = xc.Close() }() + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + foo(xc, context.Background(), "call", "Foo.Sum", &Args{Num1: i, Num2: i * i}) + }(i) + } + wg.Wait() +} + +func broadcast(addr1, addr2 string) { + d := xclient.NewMultiServerDiscovery([]string{"tcp@" + addr1, "tcp@" + addr2}) + xc := xclient.NewXClient(d, xclient.RandomSelect, nil) + defer func() { _ = xc.Close() }() + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + foo(xc, context.Background(), "broadcast", "Foo.Sum", &Args{Num1: i, Num2: i * i}) + // expect 2 - 5 timeout + ctx, _ := context.WithTimeout(context.Background(), time.Second*2) + foo(xc, ctx, "broadcast", "Foo.Sleep", &Args{Num1: i, Num2: i * i}) + }(i) + } + wg.Wait() +} + + +func main() { + log.SetFlags(0) + ch1 := make(chan string) + ch2 := make(chan string) + // start two servers + go startServer(ch1) + go startServer(ch2) + + addr1 := <-ch1 + addr2 := <-ch2 + + time.Sleep(time.Second) + call(addr1, addr2) + broadcast(addr1, addr2) +} +``` + +运行结果如下: + +```go +rpc server: register Foo.Sleep +rpc server: register Foo.Sum +rpc server: register Foo.Sleep +rpc server: register Foo.Sum +call Foo.Sum success: 4 + 16 = 20 +call Foo.Sum success: 0 + 0 = 0 +call Foo.Sum success: 3 + 9 = 12 +call Foo.Sum success: 2 + 4 = 6 +call Foo.Sum success: 1 + 1 = 2 +broadcast Foo.Sum success: 3 + 9 = 12 +broadcast Foo.Sum success: 1 + 1 = 2 +broadcast Foo.Sum success: 0 + 0 = 0 +broadcast Foo.Sum success: 4 + 16 = 20 +broadcast Foo.Sum success: 2 + 4 = 6 +broadcast Foo.Sleep success: 0 + 0 = 0 +broadcast Foo.Sleep success: 1 + 1 = 2 +broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded +broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded +broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded +``` + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html) \ No newline at end of file diff --git a/gee-rpc/doc/geerpc-day7.md b/gee-rpc/doc/geerpc-day7.md new file mode 100644 index 0000000..c7bb6c2 --- /dev/null +++ b/gee-rpc/doc/geerpc-day7.md @@ -0,0 +1,397 @@ +--- +title: 动手写RPC框架 - GeeRPC第七天 服务发现与注册中心(registry) +date: 2020-10-08 16:00:00 +description: 7天用 Go语言/golang 从零实现 RPC 框架 GeeRPC 教程(7 days implement golang remote procedure call framework from scratch tutorial),动手写 RPC 框架,参照 golang 标准库 net/rpc 的实现,实现了服务端(server)、支持异步和并发的客户端(client)、消息编码与解码(message encoding and decoding)、服务注册(service register)、支持 TCP/Unix/HTTP 等多种传输协议。第七天实现了一个简单的注册中心(registry),具备超时移除、接收心跳(heartbeat)等能力,并且实现了一个简单的服务发现(server discovery)模块。 +tags: +- Go +nav: 从零实现 +categories: +- RPC框架 - GeeRPC +keywords: +- Go语言 +- 从零实现RPC框架 +- 注册中心 +- 服务发现 +image: post/geerpc/geerpc.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day7 服务发现与注册中心 +--- + +![golang RPC framework](geerpc/geerpc.jpg) + +本文是[7天用Go从零实现RPC框架GeeRPC](https://geektutu.com/post/geerpc.html)的第七篇。 + +- 实现一个简单的注册中心,支持服务注册、接收心跳等功能 +- 客户端实现基于注册中心的服务发现机制,代码约 250 行 + +## 注册中心的位置 + +![geerpc registry](geerpc-day7/registry.jpg) + +注册中心的位置如上图所示。注册中心的好处在于,客户端和服务端都只需要感知注册中心的存在,而无需感知对方的存在。更具体一些: + +1) 服务端启动后,向注册中心发送注册消息,注册中心得知该服务已经启动,处于可用状态。一般来说,服务端还需要定期向注册中心发送心跳,证明自己还活着。 +2) 客户端向注册中心询问,当前哪天服务是可用的,注册中心将可用的服务列表返回客户端。 +3) 客户端根据注册中心得到的服务列表,选择其中一个发起调用。 + +如果没有注册中心,就像 GeeRPC 第六天实现的一样,客户端需要硬编码服务端的地址,而且没有机制保证服务端是否处于可用状态。当然注册中心的功能还有很多,比如配置的动态同步、通知机制等。比较常用的注册中心有 [etcd](https://github.com/etcd-io/etcd)、[zookeeper](https://github.com/apache/zookeeper)、[consul](https://github.com/hashicorp/consul),一般比较出名的微服务或者 RPC 框架,这些主流的注册中心都是支持的。 + + +## Gee Registry + +主流的注册中心 etcd、zookeeper 等功能强大,与这类注册中心的对接代码量是比较大的,需要实现的接口很多。GeeRPC 选择自己实现一个简单的支持心跳保活的注册中心。 + +GeeRegistry 的代码独立放置在子目录 registry 中。 + +首先定义 GeeRegistry 结构体,默认超时时间设置为 5 min,也就是说,任何注册的服务超过 5 min,即视为不可用状态。 + +[day7-registry/registry/registry.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day7-registry) + +```go +// GeeRegistry is a simple register center, provide following functions. +// add a server and receive heartbeat to keep it alive. +// returns all alive servers and delete dead servers sync simultaneously. +type GeeRegistry struct { + timeout time.Duration + mu sync.Mutex // protect following + servers map[string]*ServerItem +} + +type ServerItem struct { + Addr string + start time.Time +} + +const ( + defaultPath = "/_geerpc_/registry" + defaultTimeout = time.Minute * 5 +) + +// New create a registry instance with timeout setting +func New(timeout time.Duration) *GeeRegistry { + return &GeeRegistry{ + servers: make(map[string]*ServerItem), + timeout: timeout, + } +} + +var DefaultGeeRegister = New(defaultTimeout) +``` + +为 GeeRegistry 实现添加服务实例和返回服务列表的方法。 + +- putServer:添加服务实例,如果服务已经存在,则更新 start。 +- aliveServers:返回可用的服务列表,如果存在超时的服务,则删除。 + +```go +func (r *GeeRegistry) putServer(addr string) { + r.mu.Lock() + defer r.mu.Unlock() + s := r.servers[addr] + if s == nil { + r.servers[addr] = &ServerItem{Addr: addr, start: time.Now()} + } else { + s.start = time.Now() // if exists, update start time to keep alive + } +} + +func (r *GeeRegistry) aliveServers() []string { + r.mu.Lock() + defer r.mu.Unlock() + var alive []string + for addr, s := range r.servers { + if r.timeout == 0 || s.start.Add(r.timeout).After(time.Now()) { + alive = append(alive, addr) + } else { + delete(r.servers, addr) + } + } + sort.Strings(alive) + return alive +} +``` + +为了实现上的简单,GeeRegistry 采用 HTTP 协议提供服务,且所有的有用信息都承载在 HTTP Header 中。 + +- Get:返回所有可用的服务列表,通过自定义字段 X-Geerpc-Servers 承载。 +- Post:添加服务实例或发送心跳,通过自定义字段 X-Geerpc-Server 承载。 + +```go +// Runs at /_geerpc_/registry +func (r *GeeRegistry) ServeHTTP(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case "GET": + // keep it simple, server is in req.Header + w.Header().Set("X-Geerpc-Servers", strings.Join(r.aliveServers(), ",")) + case "POST": + // keep it simple, server is in req.Header + addr := req.Header.Get("X-Geerpc-Server") + if addr == "" { + w.WriteHeader(http.StatusInternalServerError) + return + } + r.putServer(addr) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +// HandleHTTP registers an HTTP handler for GeeRegistry messages on registryPath +func (r *GeeRegistry) HandleHTTP(registryPath string) { + http.Handle(registryPath, r) + log.Println("rpc registry path:", registryPath) +} + +func HandleHTTP() { + DefaultGeeRegister.HandleHTTP(defaultPath) +} +``` + +另外,提供 Heartbeat 方法,便于服务启动时定时向注册中心发送心跳,默认周期比注册中心设置的过期时间少 1 min。 + +```go +// Heartbeat send a heartbeat message every once in a while +// it's a helper function for a server to register or send heartbeat +func Heartbeat(registry, addr string, duration time.Duration) { + if duration == 0 { + // make sure there is enough time to send heart beat + // before it's removed from registry + duration = defaultTimeout - time.Duration(1)*time.Minute + } + var err error + err = sendHeartbeat(registry, addr) + go func() { + t := time.NewTicker(duration) + for err == nil { + <-t.C + err = sendHeartbeat(registry, addr) + } + }() +} + +func sendHeartbeat(registry, addr string) error { + log.Println(addr, "send heart beat to registry", registry) + httpClient := &http.Client{} + req, _ := http.NewRequest("POST", registry, nil) + req.Header.Set("X-Geerpc-Server", addr) + if _, err := httpClient.Do(req); err != nil { + log.Println("rpc server: heart beat err:", err) + return err + } + return nil +} +``` + +## GeeRegistryDiscovery + +在 xclient 中对应实现 Discovery。 + +[day7-registry/xclient/discovery_gee.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day7-registry) + +```go +package xclient + +type GeeRegistryDiscovery struct { + *MultiServersDiscovery + registry string + timeout time.Duration + lastUpdate time.Time +} + +const defaultUpdateTimeout = time.Second * 10 + +func NewGeeRegistryDiscovery(registerAddr string, timeout time.Duration) *GeeRegistryDiscovery { + if timeout == 0 { + timeout = defaultUpdateTimeout + } + d := &GeeRegistryDiscovery{ + MultiServersDiscovery: NewMultiServerDiscovery(make([]string, 0)), + registry: registerAddr, + timeout: timeout, + } + return d +} +``` + +- GeeRegistryDiscovery 嵌套了 MultiServersDiscovery,很多能力可以复用。 +- registry 即注册中心的地址 +- timeout 服务列表的过期时间 +- lastUpdate 是代表最后从注册中心更新服务列表的时间,默认 10s 过期,即 10s 之后,需要从注册中心更新新的列表。 + +实现 Update 和 Refresh 方法,超时重新获取的逻辑在 Refresh 中实现: + +```go +func (d *GeeRegistryDiscovery) Update(servers []string) error { + d.mu.Lock() + defer d.mu.Unlock() + d.servers = servers + d.lastUpdate = time.Now() + return nil +} + +func (d *GeeRegistryDiscovery) Refresh() error { + d.mu.Lock() + defer d.mu.Unlock() + if d.lastUpdate.Add(d.timeout).After(time.Now()) { + return nil + } + log.Println("rpc registry: refresh servers from registry", d.registry) + resp, err := http.Get(d.registry) + if err != nil { + log.Println("rpc registry refresh err:", err) + return err + } + servers := strings.Split(resp.Header.Get("X-Geerpc-Servers"), ",") + d.servers = make([]string, 0, len(servers)) + for _, server := range servers { + if strings.TrimSpace(server) != "" { + d.servers = append(d.servers, strings.TrimSpace(server)) + } + } + d.lastUpdate = time.Now() + return nil +} +``` + +`Get` 和 `GetAll` 与 MultiServersDiscovery 相似,唯一的不同在于,GeeRegistryDiscovery 需要先调用 Refresh 确保服务列表没有过期。 + +```go +func (d *GeeRegistryDiscovery) Get(mode SelectMode) (string, error) { + if err := d.Refresh(); err != nil { + return "", err + } + return d.MultiServersDiscovery.Get(mode) +} + +func (d *GeeRegistryDiscovery) GetAll() ([]string, error) { + if err := d.Refresh(); err != nil { + return nil, err + } + return d.MultiServersDiscovery.GetAll() +} +``` + +## Demo + +最后,依旧通过简单的 Demo 验证今天的成果。 + +添加函数 startRegistry,稍微修改 startServer,添加调用注册中心的 `Heartbeat` 方法的逻辑,定期向注册中心发送心跳保活。 + +[day7-registry/main/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day7-registry) + +```go +func startRegistry(wg *sync.WaitGroup) { + l, _ := net.Listen("tcp", ":9999") + registry.HandleHTTP() + wg.Done() + _ = http.Serve(l, nil) +} + +func startServer(registryAddr string, wg *sync.WaitGroup) { + var foo Foo + l, _ := net.Listen("tcp", ":0") + server := geerpc.NewServer() + _ = server.Register(&foo) + registry.Heartbeat(registryAddr, "tcp@"+l.Addr().String(), 0) + wg.Done() + server.Accept(l) +} +``` + +接下来,将 call 和 broadcast 的 MultiServersDiscovery 替换为 GeeRegistryDiscovery,不再需要硬编码服务列表。 + +```go +func call(registry string) { + d := xclient.NewGeeRegistryDiscovery(registry, 0) + xc := xclient.NewXClient(d, xclient.RandomSelect, nil) + defer func() { _ = xc.Close() }() + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + foo(xc, context.Background(), "call", "Foo.Sum", &Args{Num1: i, Num2: i * i}) + }(i) + } + wg.Wait() +} + +func broadcast(registry string) { + d := xclient.NewGeeRegistryDiscovery(registry, 0) + xc := xclient.NewXClient(d, xclient.RandomSelect, nil) + defer func() { _ = xc.Close() }() + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + foo(xc, context.Background(), "broadcast", "Foo.Sum", &Args{Num1: i, Num2: i * i}) + // expect 2 - 5 timeout + ctx, _ := context.WithTimeout(context.Background(), time.Second*2) + foo(xc, ctx, "broadcast", "Foo.Sleep", &Args{Num1: i, Num2: i * i}) + }(i) + } + wg.Wait() +} +``` + +最后在 main 函数中,将所有的逻辑串联起来,确保注册中心启动后,再启动 RPC 服务端,最后客户端远程调用。 + +```go +func main() { + log.SetFlags(0) + registryAddr := "http://localhost:9999/_geerpc_/registry" + var wg sync.WaitGroup + wg.Add(1) + go startRegistry(&wg) + wg.Wait() + + time.Sleep(time.Second) + wg.Add(2) + go startServer(registryAddr, &wg) + go startServer(registryAddr, &wg) + wg.Wait() + + time.Sleep(time.Second) + call(registryAddr) + broadcast(registryAddr) +} +``` + +运行结果如下: + +```go +rpc registry path: /_geerpc_/registry +rpc server: register Foo.Sleep +rpc server: register Foo.Sum +tcp@[::]:56276 send heart beat to registry http://localhost:9999/_geerpc_/registry +rpc server: register Foo.Sleep +rpc server: register Foo.Sum +tcp@[::]:56277 send heart beat to registry http://localhost:9999/_geerpc_/registry +rpc registry: refresh servers from registry http://localhost:9999/_geerpc_/registry +call Foo.Sum success: 3 + 9 = 12 +call Foo.Sum success: 4 + 16 = 20 +call Foo.Sum success: 1 + 1 = 2 +call Foo.Sum success: 0 + 0 = 0 +call Foo.Sum success: 2 + 4 = 6 +rpc registry: refresh servers from registry http://localhost:9999/_geerpc_/registry +broadcast Foo.Sum success: 4 + 16 = 20 +broadcast Foo.Sum success: 1 + 1 = 2 +broadcast Foo.Sum success: 3 + 9 = 12 +broadcast Foo.Sum success: 0 + 0 = 0 +broadcast Foo.Sum success: 2 + 4 = 6 +broadcast Foo.Sleep success: 0 + 0 = 0 +broadcast Foo.Sleep success: 1 + 1 = 2 +broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded +broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded +broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded +``` + +到这里,七天用 Go 从零实现 RPC 框架的教程也结束了。我们用七天时间参照 golang 标准库 net/rpc,实现了服务端以及支持并发的客户端,并且支持选择不同的序列化与反序列化方式;为了防止服务挂死,在其中一些关键部分添加了超时处理机制;支持 TCP、Unix、HTTP 等多种传输协议;支持多种负载均衡模式,最后还实现了一个简易的服务注册和发现中心。 + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html) \ No newline at end of file diff --git a/gee-rpc/doc/geerpc-day7/registry.jpg b/gee-rpc/doc/geerpc-day7/registry.jpg new file mode 100644 index 0000000..7261132 Binary files /dev/null and b/gee-rpc/doc/geerpc-day7/registry.jpg differ diff --git a/gee-rpc/doc/geerpc.md b/gee-rpc/doc/geerpc.md new file mode 100644 index 0000000..a010128 --- /dev/null +++ b/gee-rpc/doc/geerpc.md @@ -0,0 +1,68 @@ +--- +title: 7天用Go从零实现RPC框架GeeRPC +date: 2020-10-06 16:00:00 +description: 7天用 Go语言/golang 从零实现 RPC 框架 GeeRPC 教程(7 days implement golang remote procedure call framework from scratch tutorial),动手写 RPC 框架,参照 golang 标准库 net/rpc 的实现,实现了服务端(server)、支持异步和并发的客户端(client)、消息编码与解码(message encoding and decoding)、服务注册(service register)、支持 TCP/Unix/HTTP 等多种传输协议。并在此基础上新增了协议交换(protocol exchange)、注册中心(registry)、服务发现(service discovery)、负载均衡(load balance)、超时处理(timeout processing)等特性。 +tags: +- Go +nav: 从零实现 +categories: +- RPC框架 - GeeRPC +keywords: +- Go语言 +- 从零实现RPC框架 +- 动手写RPC框架 +- 服务注册与发现 +- 负载均衡 +image: post/geerpc/geerpc.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day0 序言 +--- + +![golang RPC framework](geerpc/geerpc.jpg) + +## 1 谈谈 RPC 框架 + +RPC(Remote Procedure Call,远程过程调用)是一种计算机通信协议,允许调用不同进程空间的程序。RPC 的客户端和服务器可以在一台机器上,也可以在不同的机器上。程序员使用时,就像调用本地程序一样,无需关注内部的实现细节。 + +不同的应用程序之间的通信方式有很多,比如浏览器和服务器之间广泛使用的基于 HTTP 协议的 Restful API。与 RPC 相比,Restful API 有相对统一的标准,因而更通用,兼容性更好,支持不同的语言。HTTP 协议是基于文本的,一般具备更好的可读性。但是缺点也很明显: + +- Restful 接口需要额外的定义,无论是客户端还是服务端,都需要额外的代码来处理,而 RPC 调用则更接近于直接调用。 +- 基于 HTTP 协议的 Restful 报文冗余,承载了过多的无效信息,而 RPC 通常使用自定义的协议格式,减少冗余报文。 +- RPC 可以采用更高效的序列化协议,将文本转为二进制传输,获得更高的性能。 +- 因为 RPC 的灵活性,所以更容易扩展和集成诸如注册中心、负载均衡等功能。 + +## 2 RPC 框架需要解决什么问题 + +RPC 框架需要解决什么问题?或者我们换一个问题,为什么需要 RPC 框架? + +我们可以想象下两台机器上,两个应用程序之间需要通信,那么首先,需要确定采用的传输协议是什么?如果这个两个应用程序位于不同的机器,那么一般会选择 TCP 协议或者 HTTP 协议;那如果两个应用程序位于相同的机器,也可以选择 Unix Socket 协议。传输协议确定之后,还需要确定报文的编码格式,比如采用最常用的 JSON 或者 XML,那如果报文比较大,还可能会选择 protobuf 等其他的编码方式,甚至编码之后,再进行压缩。接收端获取报文则需要相反的过程,先解压再解码。 + +解决了传输协议和报文编码的问题,接下来还需要解决一系列的可用性问题,例如,连接超时了怎么办?是否支持异步请求和并发? + +如果服务端的实例很多,客户端并不关心这些实例的地址和部署位置,只关心自己能否获取到期待的结果,那就引出了注册中心(registry)和负载均衡(load balance)的问题。简单地说,即客户端和服务端互相不感知对方的存在,服务端启动时将自己注册到注册中心,客户端调用时,从注册中心获取到所有可用的实例,选择一个来调用。这样服务端和客户端只需要感知注册中心的存在就够了。注册中心通常还需要实现服务动态添加、删除,使用心跳确保服务处于可用状态等功能。 + +再进一步,假设服务端是不同的团队提供的,如果没有统一的 RPC 框架,各个团队的服务提供方就需要各自实现一套消息编解码、连接池、收发线程、超时处理等“业务之外”的重复技术劳动,造成整体的低效。因此,“业务之外”的这部分公共的能力,即是 RPC 框架所需要具备的能力。 + +## 3 关于 GeeRPC + +Go 语言广泛地应用于云计算和微服务,成熟的 RPC 框架和微服务框架汗牛充栋。`grpc`、`rpcx`、`go-micro` 等都是非常成熟的框架。一般而言,RPC 是微服务框架的一个子集,微服务框架可以自己实现 RPC 部分,当然,也可以选择不同的 RPC 框架作为通信基座。 + +考虑性能和功能,上述成熟的框架代码量都比较庞大,而且通常和第三方库,例如 `protobuf`、`etcd`、`zookeeper` 等有比较深的耦合,难以直观地窥视框架的本质。GeeRPC 的目的是以最少的代码,实现 RPC 框架中最为重要的部分,帮助大家理解 RPC 框架在设计时需要考虑什么。代码简洁是第一位的,功能是第二位的。 + +因此,GeeRPC 选择从零实现 Go 语言官方的标准库 `net/rpc`,并在此基础上,新增了协议交换(protocol exchange)、注册中心(registry)、服务发现(service discovery)、负载均衡(load balance)、超时处理(timeout processing)等特性。分七天完成,最终代码约 1000 行。 + +## 4 目录 + +- 第一天 - [服务端与消息编码](https://geektutu.com/post/geerpc-day1.html) | [Code](ghttps://github.com/geektutu/7days-golang/tree/master/ee-rpc/day1-codec) +- 第二天 - [支持并发与异步的客户端](https://geektutu.com/post/geerpc-day2.html) | [Code](ghttps://github.com/geektutu/7days-golang/tree/master/ee-rpc/day2-client) +- 第三天 - [服务注册(service register)](https://geektutu.com/post/geerpc-day3.html) | [Code](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day3-service ) +- 第四天 - [超时处理(timeout)](https://geektutu.com/post/geerpc-day4.html) | [Code](ghttps://github.com/geektutu/7days-golang/tree/master/ee-rpc/day4-timeout ) +- 第五天 - [支持HTTP协议](https://geektutu.com/post/geerpc-day5.html) | [Code](ghttps://github.com/geektutu/7days-golang/tree/master/ee-rpc/day5-http-debug) +- 第六天 - [负载均衡(load balance)](https://geektutu.com/post/geerpc-day6.html) | [Code](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day6-load-balance) +- 第七天 - [服务发现与注册中心(registry)](https://geektutu.com/post/geerpc-day7.html) | [Code](https://github.com/geektutu/7days-golang/tree/master/gee-rpc/day7-registry) + +## 附 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html) \ No newline at end of file diff --git a/gee-rpc/doc/geerpc/geerpc.jpg b/gee-rpc/doc/geerpc/geerpc.jpg new file mode 100644 index 0000000..5a86ebe Binary files /dev/null and b/gee-rpc/doc/geerpc/geerpc.jpg differ diff --git a/gee-web/README.md b/gee-web/README.md new file mode 100644 index 0000000..d56ed8d --- /dev/null +++ b/gee-web/README.md @@ -0,0 +1,224 @@ +# 7 Days Go Web Framework Gee from Scratch + +
+README 中文版本 +
+ +## 7天用Go从零实现Web框架Gee + +![Gee](doc/gee/gee.jpg) + +### Content + +- [第一天:前置知识(http.Handler接口)](https://geektutu.com/post/gee-day1.html) +- [第二天:上下文设计(Context)](https://geektutu.com/post/gee-day2.html) +- [第三天:Trie树路由(Router)](https://geektutu.com/post/gee-day3.html) +- [第四天:分组控制(Group)](https://geektutu.com/post/gee-day4.html) +- [第五天:中间件(Middleware)](https://geektutu.com/post/gee-day5.html) +- [第六天:HTML模板(Template)](https://geektutu.com/post/gee-day6.html) +- [第七天:错误恢复(Panic Recover)](https://geektutu.com/post/gee-day7.html) + +
+
+ +## Day 1 - Static Route + +```go +func main() { + r := gee.New() + r.GET("/", func(w http.ResponseWriter, req *http.Request) { + fmt.Fprintf(w, "URL.Path = %q\n", req.URL.Path) + }) + + r.GET("/hello", func(w http.ResponseWriter, req *http.Request) { + for k, v := range req.Header { + fmt.Fprintf(w, "Header[%q] = %q\n", k, v) + } + }) + + r.Run(":9999") +} +``` + +## Day 2 - Context Design + +```go +func main() { + r := gee.New() + r.GET("/", func(c *gee.Context) { + c.HTML(http.StatusOK, "

Hello Gee

") + }) + r.GET("/hello", func(c *gee.Context) { + // expect /hello?name=geektutu + c.String(http.StatusOK, "hello %s, you're at %s\n", c.Query("name"), c.Path) + }) + + r.POST("/login", func(c *gee.Context) { + c.JSON(http.StatusOK, &map[string]string{ + "username": c.PostForm("username"), + "password": c.PostForm("password"), + }) + }) + + r.Run(":9999") +} +``` + +## Day 3 - Dynamic Route + +```go +func main() { + r := gee.New() + r.GET("/", func(c *gee.Context) { + c.HTML(http.StatusOK, "

Hello Gee

") + }) + + r.GET("/hello", func(c *gee.Context) { + // expect /hello?name=geektutu + c.String(http.StatusOK, "hello %s, you're at %s\n", c.Query("name"), c.Path) + }) + + r.GET("/hello/:name", func(c *gee.Context) { + // expect /hello/geektutu + c.String(http.StatusOK, "hello %s, you're at %s\n", c.Param("name"), c.Path) + }) + + r.GET("/assets/*filepath", func(c *gee.Context) { + c.JSON(http.StatusOK, gee.H{"filepath": c.Param("filepath")}) + }) + + r.Run(":9999") +} +``` + +## Day 4 - Nesting Group Control + +```go +func main() { + r := gee.New() + v1 := r.Group("/v1") + { + v1.GET("/", func(c *gee.Context) { + c.HTML(http.StatusOK, "

Hello Gee

") + }) + + v1.GET("/hello", func(c *gee.Context) { + // expect /hello?name=geektutu + c.String(http.StatusOK, "hello %s, you're at %s\n", c.Query("name"), c.Path) + }) + } + v2 := r.Group("/v2") + { + v2.GET("/hello/:name", func(c *gee.Context) { + // expect /hello/geektutu + c.String(http.StatusOK, "hello %s, you're at %s\n", c.Param("name"), c.Path) + }) + v2.POST("/login", func(c *gee.Context) { + c.JSON(http.StatusOK, &map[string]string{ + "username": c.PostForm("username"), + "password": c.PostForm("password"), + }) + }) + + } + + r.Run(":9999") +} +``` + +## Day 5 - Middleware + +```go +func onlyForV2() gee.HandlerFunc { + return func(c *gee.Context) { + // Start timer + t := time.Now() + // if a server error occurred + c.Fail(500, "Internal Server Error") + // Calculate resolution time + log.Printf("[%d] %s in %v for group v2", c.StatusCode, c.Req.RequestURI, time.Since(t)) + } +} + +func main() { + r := gee.New() + r.Use(gee.Logger()) // global midlleware + r.GET("/", func(c *gee.Context) { + c.HTML(http.StatusOK, "

Hello Gee

") + }) + + v2 := r.Group("/v2") + v2.Use(onlyForV2()) // v2 group middleware + { + v2.GET("/hello/:name", func(c *gee.Context) { + // expect /hello/geektutu + c.String(http.StatusOK, "hello %s, you're at %s\n", c.Param("name"), c.Path) + }) + } + + r.Run(":9999") +} +``` + +## Day 6 - HTML Template + +```go +type student struct { + Name string + Age int8 +} + +func FormatAsDate(t time.Time) string { + year, month, day := t.Date() + return fmt.Sprintf("%d-%02d-%02d", year, month, day) +} + +func main() { + r := gee.New() + r.Use(gee.Logger()) + r.SetFuncMap(template.FuncMap{ + "FormatAsDate": FormatAsDate, + }) + r.LoadHTMLGlob("templates/*") + r.Static("/assets", "./static") + + stu1 := &student{Name: "Geektutu", Age: 20} + stu2 := &student{Name: "Jack", Age: 22} + r.GET("/", func(c *gee.Context) { + c.HTML(http.StatusOK, "css.tmpl", nil) + }) + r.GET("/students", func(c *gee.Context) { + c.HTML(http.StatusOK, "arr.tmpl", gee.H{ + "title": "gee", + "stuArr": [2]*student{stu1, stu2}, + }) + }) + + r.GET("/date", func(c *gee.Context) { + c.HTML(http.StatusOK, "custom_func.tmpl", gee.H{ + "title": "gee", + "now": time.Date(2019, 8, 17, 0, 0, 0, 0, time.UTC), + }) + }) + + r.Run(":9999") +} +``` + +## Day 7 - Panic Recover + +```go +func main() { + r := gee.Default() + r.GET("/", func(c *gee.Context) { + c.String(http.StatusOK, "Hello Geektutu\n") + }) + // index out of range for testing Recovery() + r.GET("/panic", func(c *gee.Context) { + names := []string{"geektutu"} + c.String(http.StatusOK, names[100]) + }) + + r.Run(":9999") +} +``` \ No newline at end of file diff --git a/gee-web/day1-http-base/base1/go.mod b/gee-web/day1-http-base/base1/go.mod new file mode 100644 index 0000000..8d2394a --- /dev/null +++ b/gee-web/day1-http-base/base1/go.mod @@ -0,0 +1,3 @@ +module example + +go 1.13 diff --git a/day1-http-base/base1/main.go b/gee-web/day1-http-base/base1/main.go similarity index 100% rename from day1-http-base/base1/main.go rename to gee-web/day1-http-base/base1/main.go diff --git a/gee-web/day1-http-base/base2/go.mod b/gee-web/day1-http-base/base2/go.mod new file mode 100644 index 0000000..8d2394a --- /dev/null +++ b/gee-web/day1-http-base/base2/go.mod @@ -0,0 +1,3 @@ +module example + +go 1.13 diff --git a/day1-http-base/base2/main.go b/gee-web/day1-http-base/base2/main.go similarity index 100% rename from day1-http-base/base2/main.go rename to gee-web/day1-http-base/base2/main.go diff --git a/day1-http-base/base3/gee/gee.go b/gee-web/day1-http-base/base3/gee/gee.go similarity index 95% rename from day1-http-base/base3/gee/gee.go rename to gee-web/day1-http-base/base3/gee/gee.go index 8c38b4c..3aefd75 100644 --- a/day1-http-base/base3/gee/gee.go +++ b/gee-web/day1-http-base/base3/gee/gee.go @@ -2,6 +2,7 @@ package gee import ( "fmt" + "log" "net/http" ) @@ -20,6 +21,7 @@ func New() *Engine { func (engine *Engine) addRoute(method string, pattern string, handler HandlerFunc) { key := method + "-" + pattern + log.Printf("Route %4s - %s", method, pattern) engine.router[key] = handler } diff --git a/gee-web/day1-http-base/base3/gee/go.mod b/gee-web/day1-http-base/base3/gee/go.mod new file mode 100644 index 0000000..c944c8a --- /dev/null +++ b/gee-web/day1-http-base/base3/gee/go.mod @@ -0,0 +1,3 @@ +module gee + +go 1.13 diff --git a/gee-web/day1-http-base/base3/go.mod b/gee-web/day1-http-base/base3/go.mod new file mode 100644 index 0000000..b27ebc4 --- /dev/null +++ b/gee-web/day1-http-base/base3/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require gee v0.0.0 + +replace gee => ./gee diff --git a/day1-http-base/base3/main.go b/gee-web/day1-http-base/base3/main.go similarity index 98% rename from day1-http-base/base3/main.go rename to gee-web/day1-http-base/base3/main.go index 5f55dc2..5cdac75 100644 --- a/day1-http-base/base3/main.go +++ b/gee-web/day1-http-base/base3/main.go @@ -12,7 +12,7 @@ import ( "fmt" "net/http" - "./gee" + "gee" ) func main() { diff --git a/day2-context/gee/context.go b/gee-web/day2-context/gee/context.go similarity index 100% rename from day2-context/gee/context.go rename to gee-web/day2-context/gee/context.go index 795bf16..72a7fe7 100644 --- a/day2-context/gee/context.go +++ b/gee-web/day2-context/gee/context.go @@ -46,14 +46,14 @@ func (c *Context) SetHeader(key string, value string) { } func (c *Context) String(code int, format string, values ...interface{}) { - c.Status(code) c.SetHeader("Content-Type", "text/plain") + c.Status(code) c.Writer.Write([]byte(fmt.Sprintf(format, values...))) } func (c *Context) JSON(code int, obj interface{}) { - c.Status(code) c.SetHeader("Content-Type", "application/json") + c.Status(code) encoder := json.NewEncoder(c.Writer) if err := encoder.Encode(obj); err != nil { http.Error(c.Writer, err.Error(), 500) @@ -66,7 +66,7 @@ func (c *Context) Data(code int, data []byte) { } func (c *Context) HTML(code int, html string) { - c.Status(code) c.SetHeader("Content-Type", "text/html") + c.Status(code) c.Writer.Write([]byte(html)) } diff --git a/day3-router/gee/gee.go b/gee-web/day2-context/gee/gee.go similarity index 95% rename from day3-router/gee/gee.go rename to gee-web/day2-context/gee/gee.go index ce0d4da..802bb5e 100644 --- a/day3-router/gee/gee.go +++ b/gee-web/day2-context/gee/gee.go @@ -1,6 +1,7 @@ package gee import ( + "log" "net/http" ) @@ -18,6 +19,7 @@ func New() *Engine { } func (engine *Engine) addRoute(method string, pattern string, handler HandlerFunc) { + log.Printf("Route %4s - %s", method, pattern) engine.router.addRoute(method, pattern, handler) } diff --git a/gee-web/day2-context/gee/go.mod b/gee-web/day2-context/gee/go.mod new file mode 100644 index 0000000..c944c8a --- /dev/null +++ b/gee-web/day2-context/gee/go.mod @@ -0,0 +1,3 @@ +module gee + +go 1.13 diff --git a/day2-context/gee/router.go b/gee-web/day2-context/gee/router.go similarity index 100% rename from day2-context/gee/router.go rename to gee-web/day2-context/gee/router.go diff --git a/gee-web/day2-context/go.mod b/gee-web/day2-context/go.mod new file mode 100644 index 0000000..b27ebc4 --- /dev/null +++ b/gee-web/day2-context/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require gee v0.0.0 + +replace gee => ./gee diff --git a/day2-context/main.go b/gee-web/day2-context/main.go similarity index 99% rename from day2-context/main.go rename to gee-web/day2-context/main.go index a493d3f..6f86bee 100644 --- a/day2-context/main.go +++ b/gee-web/day2-context/main.go @@ -25,7 +25,7 @@ $ curl "http://localhost:9999/xxx" import ( "net/http" - "./gee" + "gee" ) func main() { diff --git a/day3-router/gee/context.go b/gee-web/day3-router/gee/context.go similarity index 100% rename from day3-router/gee/context.go rename to gee-web/day3-router/gee/context.go index 2733bdb..cf79939 100644 --- a/day3-router/gee/context.go +++ b/gee-web/day3-router/gee/context.go @@ -52,14 +52,14 @@ func (c *Context) SetHeader(key string, value string) { } func (c *Context) String(code int, format string, values ...interface{}) { - c.Status(code) c.SetHeader("Content-Type", "text/plain") + c.Status(code) c.Writer.Write([]byte(fmt.Sprintf(format, values...))) } func (c *Context) JSON(code int, obj interface{}) { - c.Status(code) c.SetHeader("Content-Type", "application/json") + c.Status(code) encoder := json.NewEncoder(c.Writer) if err := encoder.Encode(obj); err != nil { http.Error(c.Writer, err.Error(), 500) @@ -72,7 +72,7 @@ func (c *Context) Data(code int, data []byte) { } func (c *Context) HTML(code int, html string) { - c.Status(code) c.SetHeader("Content-Type", "text/html") + c.Status(code) c.Writer.Write([]byte(html)) } diff --git a/day2-context/gee/gee.go b/gee-web/day3-router/gee/gee.go similarity index 95% rename from day2-context/gee/gee.go rename to gee-web/day3-router/gee/gee.go index ce0d4da..802bb5e 100644 --- a/day2-context/gee/gee.go +++ b/gee-web/day3-router/gee/gee.go @@ -1,6 +1,7 @@ package gee import ( + "log" "net/http" ) @@ -18,6 +19,7 @@ func New() *Engine { } func (engine *Engine) addRoute(method string, pattern string, handler HandlerFunc) { + log.Printf("Route %4s - %s", method, pattern) engine.router.addRoute(method, pattern, handler) } diff --git a/gee-web/day3-router/gee/go.mod b/gee-web/day3-router/gee/go.mod new file mode 100644 index 0000000..c944c8a --- /dev/null +++ b/gee-web/day3-router/gee/go.mod @@ -0,0 +1,3 @@ +module gee + +go 1.13 diff --git a/day3-router/gee/router.go b/gee-web/day3-router/gee/router.go similarity index 100% rename from day3-router/gee/router.go rename to gee-web/day3-router/gee/router.go diff --git a/day3-router/gee/router_test.go b/gee-web/day3-router/gee/router_test.go similarity index 100% rename from day3-router/gee/router_test.go rename to gee-web/day3-router/gee/router_test.go diff --git a/day3-router/gee/trie.go b/gee-web/day3-router/gee/trie.go similarity index 100% rename from day3-router/gee/trie.go rename to gee-web/day3-router/gee/trie.go diff --git a/gee-web/day3-router/go.mod b/gee-web/day3-router/go.mod new file mode 100644 index 0000000..b27ebc4 --- /dev/null +++ b/gee-web/day3-router/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require gee v0.0.0 + +replace gee => ./gee diff --git a/day3-router/main.go b/gee-web/day3-router/main.go similarity index 99% rename from day3-router/main.go rename to gee-web/day3-router/main.go index 60a7d9a..25f623e 100644 --- a/day3-router/main.go +++ b/gee-web/day3-router/main.go @@ -29,7 +29,7 @@ $ curl "http://localhost:9999/xxx" import ( "net/http" - "./gee" + "gee" ) func main() { diff --git a/day4-group/gee/context.go b/gee-web/day4-group/gee/context.go similarity index 100% rename from day4-group/gee/context.go rename to gee-web/day4-group/gee/context.go index 2733bdb..cf79939 100644 --- a/day4-group/gee/context.go +++ b/gee-web/day4-group/gee/context.go @@ -52,14 +52,14 @@ func (c *Context) SetHeader(key string, value string) { } func (c *Context) String(code int, format string, values ...interface{}) { - c.Status(code) c.SetHeader("Content-Type", "text/plain") + c.Status(code) c.Writer.Write([]byte(fmt.Sprintf(format, values...))) } func (c *Context) JSON(code int, obj interface{}) { - c.Status(code) c.SetHeader("Content-Type", "application/json") + c.Status(code) encoder := json.NewEncoder(c.Writer) if err := encoder.Encode(obj); err != nil { http.Error(c.Writer, err.Error(), 500) @@ -72,7 +72,7 @@ func (c *Context) Data(code int, data []byte) { } func (c *Context) HTML(code int, html string) { - c.Status(code) c.SetHeader("Content-Type", "text/html") + c.Status(code) c.Writer.Write([]byte(html)) } diff --git a/day4-group/gee/gee.go b/gee-web/day4-group/gee/gee.go similarity index 97% rename from day4-group/gee/gee.go rename to gee-web/day4-group/gee/gee.go index 8c5d374..1e5cda3 100644 --- a/day4-group/gee/gee.go +++ b/gee-web/day4-group/gee/gee.go @@ -1,6 +1,7 @@ package gee import ( + "log" "net/http" ) @@ -46,6 +47,7 @@ func (group *RouterGroup) Group(prefix string) *RouterGroup { func (group *RouterGroup) addRoute(method string, comp string, handler HandlerFunc) { pattern := group.prefix + comp + log.Printf("Route %4s - %s", method, pattern) group.engine.router.addRoute(method, pattern, handler) } diff --git a/gee-web/day4-group/gee/gee_test.go b/gee-web/day4-group/gee/gee_test.go new file mode 100644 index 0000000..f0c9577 --- /dev/null +++ b/gee-web/day4-group/gee/gee_test.go @@ -0,0 +1,16 @@ +package gee + +import "testing" + +func TestNestedGroup(t *testing.T) { + r := New() + v1 := r.Group("/v1") + v2 := v1.Group("/v2") + v3 := v2.Group("/v3") + if v2.prefix != "/v1/v2" { + t.Fatal("v2 prefix should be /v1/v2") + } + if v3.prefix != "/v1/v2/v3" { + t.Fatal("v2 prefix should be /v1/v2") + } +} diff --git a/gee-web/day4-group/gee/go.mod b/gee-web/day4-group/gee/go.mod new file mode 100644 index 0000000..c944c8a --- /dev/null +++ b/gee-web/day4-group/gee/go.mod @@ -0,0 +1,3 @@ +module gee + +go 1.13 diff --git a/day4-group/gee/router.go b/gee-web/day4-group/gee/router.go similarity index 100% rename from day4-group/gee/router.go rename to gee-web/day4-group/gee/router.go diff --git a/day4-group/gee/router_test.go b/gee-web/day4-group/gee/router_test.go similarity index 100% rename from day4-group/gee/router_test.go rename to gee-web/day4-group/gee/router_test.go diff --git a/day4-group/gee/trie.go b/gee-web/day4-group/gee/trie.go similarity index 100% rename from day4-group/gee/trie.go rename to gee-web/day4-group/gee/trie.go diff --git a/gee-web/day4-group/go.mod b/gee-web/day4-group/go.mod new file mode 100644 index 0000000..b27ebc4 --- /dev/null +++ b/gee-web/day4-group/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require gee v0.0.0 + +replace gee => ./gee diff --git a/day4-group/main.go b/gee-web/day4-group/main.go similarity index 99% rename from day4-group/main.go rename to gee-web/day4-group/main.go index 68c2d74..4336a9c 100644 --- a/day4-group/main.go +++ b/gee-web/day4-group/main.go @@ -37,7 +37,7 @@ $ curl "http://localhost:9999/hello" import ( "net/http" - "./gee" + "gee" ) func main() { diff --git a/day5-middleware/gee/context.go b/gee-web/day5-middleware/gee/context.go similarity index 100% rename from day5-middleware/gee/context.go rename to gee-web/day5-middleware/gee/context.go index 63eca76..1885e0c 100644 --- a/day5-middleware/gee/context.go +++ b/gee-web/day5-middleware/gee/context.go @@ -69,14 +69,14 @@ func (c *Context) SetHeader(key string, value string) { } func (c *Context) String(code int, format string, values ...interface{}) { - c.Status(code) c.SetHeader("Content-Type", "text/plain") + c.Status(code) c.Writer.Write([]byte(fmt.Sprintf(format, values...))) } func (c *Context) JSON(code int, obj interface{}) { - c.Status(code) c.SetHeader("Content-Type", "application/json") + c.Status(code) encoder := json.NewEncoder(c.Writer) if err := encoder.Encode(obj); err != nil { http.Error(c.Writer, err.Error(), 500) @@ -89,7 +89,7 @@ func (c *Context) Data(code int, data []byte) { } func (c *Context) HTML(code int, html string) { - c.Status(code) c.SetHeader("Content-Type", "text/html") + c.Status(code) c.Writer.Write([]byte(html)) } diff --git a/day5-middleware/gee/gee.go b/gee-web/day5-middleware/gee/gee.go similarity index 97% rename from day5-middleware/gee/gee.go rename to gee-web/day5-middleware/gee/gee.go index e880d19..5c237c1 100644 --- a/day5-middleware/gee/gee.go +++ b/gee-web/day5-middleware/gee/gee.go @@ -1,6 +1,7 @@ package gee import ( + "log" "net/http" "strings" ) @@ -52,7 +53,7 @@ func (group *RouterGroup) Use(middlewares ...HandlerFunc) { func (group *RouterGroup) addRoute(method string, comp string, handler HandlerFunc) { pattern := group.prefix + comp - + log.Printf("Route %4s - %s", method, pattern) group.engine.router.addRoute(method, pattern, handler) } diff --git a/gee-web/day5-middleware/gee/gee_test.go b/gee-web/day5-middleware/gee/gee_test.go new file mode 100644 index 0000000..f0c9577 --- /dev/null +++ b/gee-web/day5-middleware/gee/gee_test.go @@ -0,0 +1,16 @@ +package gee + +import "testing" + +func TestNestedGroup(t *testing.T) { + r := New() + v1 := r.Group("/v1") + v2 := v1.Group("/v2") + v3 := v2.Group("/v3") + if v2.prefix != "/v1/v2" { + t.Fatal("v2 prefix should be /v1/v2") + } + if v3.prefix != "/v1/v2/v3" { + t.Fatal("v2 prefix should be /v1/v2") + } +} diff --git a/gee-web/day5-middleware/gee/go.mod b/gee-web/day5-middleware/gee/go.mod new file mode 100644 index 0000000..c944c8a --- /dev/null +++ b/gee-web/day5-middleware/gee/go.mod @@ -0,0 +1,3 @@ +module gee + +go 1.13 diff --git a/day5-middleware/gee/logger.go b/gee-web/day5-middleware/gee/logger.go similarity index 100% rename from day5-middleware/gee/logger.go rename to gee-web/day5-middleware/gee/logger.go diff --git a/day5-middleware/gee/router.go b/gee-web/day5-middleware/gee/router.go similarity index 100% rename from day5-middleware/gee/router.go rename to gee-web/day5-middleware/gee/router.go diff --git a/day5-middleware/gee/router_test.go b/gee-web/day5-middleware/gee/router_test.go similarity index 100% rename from day5-middleware/gee/router_test.go rename to gee-web/day5-middleware/gee/router_test.go diff --git a/day5-middleware/gee/trie.go b/gee-web/day5-middleware/gee/trie.go similarity index 100% rename from day5-middleware/gee/trie.go rename to gee-web/day5-middleware/gee/trie.go diff --git a/gee-web/day5-middleware/go.mod b/gee-web/day5-middleware/go.mod new file mode 100644 index 0000000..b27ebc4 --- /dev/null +++ b/gee-web/day5-middleware/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require gee v0.0.0 + +replace gee => ./gee diff --git a/day5-middleware/main.go b/gee-web/day5-middleware/main.go similarity index 99% rename from day5-middleware/main.go rename to gee-web/day5-middleware/main.go index 469de0d..b66dbdc 100644 --- a/day5-middleware/main.go +++ b/gee-web/day5-middleware/main.go @@ -24,7 +24,7 @@ import ( "net/http" "time" - "./gee" + "gee" ) func onlyForV2() gee.HandlerFunc { diff --git a/day7-panic-recover/gee/context.go b/gee-web/day6-template/gee/context.go similarity index 96% rename from day7-panic-recover/gee/context.go rename to gee-web/day6-template/gee/context.go index 4e16ca2..9c47b0c 100644 --- a/day7-panic-recover/gee/context.go +++ b/gee-web/day6-template/gee/context.go @@ -71,14 +71,14 @@ func (c *Context) SetHeader(key string, value string) { } func (c *Context) String(code int, format string, values ...interface{}) { - c.Status(code) c.SetHeader("Content-Type", "text/plain") + c.Status(code) c.Writer.Write([]byte(fmt.Sprintf(format, values...))) } func (c *Context) JSON(code int, obj interface{}) { - c.Status(code) c.SetHeader("Content-Type", "application/json") + c.Status(code) encoder := json.NewEncoder(c.Writer) if err := encoder.Encode(obj); err != nil { http.Error(c.Writer, err.Error(), 500) @@ -93,8 +93,8 @@ func (c *Context) Data(code int, data []byte) { // HTML template render // refer https://golang.org/pkg/html/template/ func (c *Context) HTML(code int, name string, data interface{}) { - c.Writer.WriteHeader(code) - c.Writer.Header().Set("Content-Type", "text/html") + c.SetHeader("Content-Type", "text/html") + c.Status(code) if err := c.engine.htmlTemplates.ExecuteTemplate(c.Writer, name, data); err != nil { c.Fail(500, err.Error()) } diff --git a/day6-template/gee/gee.go b/gee-web/day6-template/gee/gee.go similarity index 98% rename from day6-template/gee/gee.go rename to gee-web/day6-template/gee/gee.go index 357f6e9..d09404c 100644 --- a/day6-template/gee/gee.go +++ b/gee-web/day6-template/gee/gee.go @@ -2,6 +2,7 @@ package gee import ( "html/template" + "log" "net/http" "path" "strings" @@ -56,7 +57,7 @@ func (group *RouterGroup) Use(middlewares ...HandlerFunc) { func (group *RouterGroup) addRoute(method string, comp string, handler HandlerFunc) { pattern := group.prefix + comp - + log.Printf("Route %4s - %s", method, pattern) group.engine.router.addRoute(method, pattern, handler) } diff --git a/gee-web/day6-template/gee/gee_test.go b/gee-web/day6-template/gee/gee_test.go new file mode 100644 index 0000000..f0c9577 --- /dev/null +++ b/gee-web/day6-template/gee/gee_test.go @@ -0,0 +1,16 @@ +package gee + +import "testing" + +func TestNestedGroup(t *testing.T) { + r := New() + v1 := r.Group("/v1") + v2 := v1.Group("/v2") + v3 := v2.Group("/v3") + if v2.prefix != "/v1/v2" { + t.Fatal("v2 prefix should be /v1/v2") + } + if v3.prefix != "/v1/v2/v3" { + t.Fatal("v2 prefix should be /v1/v2") + } +} diff --git a/gee-web/day6-template/gee/go.mod b/gee-web/day6-template/gee/go.mod new file mode 100644 index 0000000..c944c8a --- /dev/null +++ b/gee-web/day6-template/gee/go.mod @@ -0,0 +1,3 @@ +module gee + +go 1.13 diff --git a/day6-template/gee/logger.go b/gee-web/day6-template/gee/logger.go similarity index 100% rename from day6-template/gee/logger.go rename to gee-web/day6-template/gee/logger.go diff --git a/day6-template/gee/router.go b/gee-web/day6-template/gee/router.go similarity index 100% rename from day6-template/gee/router.go rename to gee-web/day6-template/gee/router.go diff --git a/day6-template/gee/router_test.go b/gee-web/day6-template/gee/router_test.go similarity index 100% rename from day6-template/gee/router_test.go rename to gee-web/day6-template/gee/router_test.go diff --git a/day6-template/gee/trie.go b/gee-web/day6-template/gee/trie.go similarity index 100% rename from day6-template/gee/trie.go rename to gee-web/day6-template/gee/trie.go diff --git a/gee-web/day6-template/go.mod b/gee-web/day6-template/go.mod new file mode 100644 index 0000000..b27ebc4 --- /dev/null +++ b/gee-web/day6-template/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require gee v0.0.0 + +replace gee => ./gee diff --git a/day6-template/main.go b/gee-web/day6-template/main.go similarity index 94% rename from day6-template/main.go rename to gee-web/day6-template/main.go index 2400183..dea17b2 100644 --- a/day6-template/main.go +++ b/gee-web/day6-template/main.go @@ -39,7 +39,7 @@ import ( "net/http" "time" - "./gee" + "gee" ) type student struct { @@ -47,7 +47,7 @@ type student struct { Age int8 } -func formatAsDate(t time.Time) string { +func FormatAsDate(t time.Time) string { year, month, day := t.Date() return fmt.Sprintf("%d-%02d-%02d", year, month, day) } @@ -56,7 +56,7 @@ func main() { r := gee.New() r.Use(gee.Logger()) r.SetFuncMap(template.FuncMap{ - "formatAsDate": formatAsDate, + "FormatAsDate": FormatAsDate, }) r.LoadHTMLGlob("templates/*") r.Static("/assets", "./static") diff --git a/day6-template/static/css/geektutu.css b/gee-web/day6-template/static/css/geektutu.css similarity index 100% rename from day6-template/static/css/geektutu.css rename to gee-web/day6-template/static/css/geektutu.css diff --git a/day6-template/static/file1.txt b/gee-web/day6-template/static/file1.txt similarity index 100% rename from day6-template/static/file1.txt rename to gee-web/day6-template/static/file1.txt diff --git a/day6-template/templates/arr.tmpl b/gee-web/day6-template/templates/arr.tmpl similarity index 100% rename from day6-template/templates/arr.tmpl rename to gee-web/day6-template/templates/arr.tmpl diff --git a/day6-template/templates/css.tmpl b/gee-web/day6-template/templates/css.tmpl similarity index 100% rename from day6-template/templates/css.tmpl rename to gee-web/day6-template/templates/css.tmpl diff --git a/day6-template/templates/custom_func.tmpl b/gee-web/day6-template/templates/custom_func.tmpl similarity index 68% rename from day6-template/templates/custom_func.tmpl rename to gee-web/day6-template/templates/custom_func.tmpl index 2d50a8e..c267ebd 100644 --- a/day6-template/templates/custom_func.tmpl +++ b/gee-web/day6-template/templates/custom_func.tmpl @@ -2,7 +2,7 @@

hello, {{.title}}

-

Date: {{.now | formatAsDate}}

+

Date: {{.now | FormatAsDate}}

diff --git a/day6-template/gee/context.go b/gee-web/day7-panic-recover/gee/context.go similarity index 96% rename from day6-template/gee/context.go rename to gee-web/day7-panic-recover/gee/context.go index 4e16ca2..9c47b0c 100644 --- a/day6-template/gee/context.go +++ b/gee-web/day7-panic-recover/gee/context.go @@ -71,14 +71,14 @@ func (c *Context) SetHeader(key string, value string) { } func (c *Context) String(code int, format string, values ...interface{}) { - c.Status(code) c.SetHeader("Content-Type", "text/plain") + c.Status(code) c.Writer.Write([]byte(fmt.Sprintf(format, values...))) } func (c *Context) JSON(code int, obj interface{}) { - c.Status(code) c.SetHeader("Content-Type", "application/json") + c.Status(code) encoder := json.NewEncoder(c.Writer) if err := encoder.Encode(obj); err != nil { http.Error(c.Writer, err.Error(), 500) @@ -93,8 +93,8 @@ func (c *Context) Data(code int, data []byte) { // HTML template render // refer https://golang.org/pkg/html/template/ func (c *Context) HTML(code int, name string, data interface{}) { - c.Writer.WriteHeader(code) - c.Writer.Header().Set("Content-Type", "text/html") + c.SetHeader("Content-Type", "text/html") + c.Status(code) if err := c.engine.htmlTemplates.ExecuteTemplate(c.Writer, name, data); err != nil { c.Fail(500, err.Error()) } diff --git a/day7-panic-recover/gee/gee.go b/gee-web/day7-panic-recover/gee/gee.go similarity index 98% rename from day7-panic-recover/gee/gee.go rename to gee-web/day7-panic-recover/gee/gee.go index 9eeb356..f0dd8fd 100644 --- a/day7-panic-recover/gee/gee.go +++ b/gee-web/day7-panic-recover/gee/gee.go @@ -2,6 +2,7 @@ package gee import ( "html/template" + "log" "net/http" "path" "strings" @@ -63,7 +64,7 @@ func (group *RouterGroup) Use(middlewares ...HandlerFunc) { func (group *RouterGroup) addRoute(method string, comp string, handler HandlerFunc) { pattern := group.prefix + comp - + log.Printf("Route %4s - %s", method, pattern) group.engine.router.addRoute(method, pattern, handler) } diff --git a/gee-web/day7-panic-recover/gee/gee_test.go b/gee-web/day7-panic-recover/gee/gee_test.go new file mode 100644 index 0000000..f0c9577 --- /dev/null +++ b/gee-web/day7-panic-recover/gee/gee_test.go @@ -0,0 +1,16 @@ +package gee + +import "testing" + +func TestNestedGroup(t *testing.T) { + r := New() + v1 := r.Group("/v1") + v2 := v1.Group("/v2") + v3 := v2.Group("/v3") + if v2.prefix != "/v1/v2" { + t.Fatal("v2 prefix should be /v1/v2") + } + if v3.prefix != "/v1/v2/v3" { + t.Fatal("v2 prefix should be /v1/v2") + } +} diff --git a/gee-web/day7-panic-recover/gee/go.mod b/gee-web/day7-panic-recover/gee/go.mod new file mode 100644 index 0000000..c944c8a --- /dev/null +++ b/gee-web/day7-panic-recover/gee/go.mod @@ -0,0 +1,3 @@ +module gee + +go 1.13 diff --git a/day7-panic-recover/gee/logger.go b/gee-web/day7-panic-recover/gee/logger.go similarity index 100% rename from day7-panic-recover/gee/logger.go rename to gee-web/day7-panic-recover/gee/logger.go diff --git a/day7-panic-recover/gee/recovery.go b/gee-web/day7-panic-recover/gee/recovery.go similarity index 100% rename from day7-panic-recover/gee/recovery.go rename to gee-web/day7-panic-recover/gee/recovery.go diff --git a/day7-panic-recover/gee/router.go b/gee-web/day7-panic-recover/gee/router.go similarity index 100% rename from day7-panic-recover/gee/router.go rename to gee-web/day7-panic-recover/gee/router.go diff --git a/day7-panic-recover/gee/router_test.go b/gee-web/day7-panic-recover/gee/router_test.go similarity index 100% rename from day7-panic-recover/gee/router_test.go rename to gee-web/day7-panic-recover/gee/router_test.go diff --git a/day7-panic-recover/gee/trie.go b/gee-web/day7-panic-recover/gee/trie.go similarity index 100% rename from day7-panic-recover/gee/trie.go rename to gee-web/day7-panic-recover/gee/trie.go diff --git a/gee-web/day7-panic-recover/go.mod b/gee-web/day7-panic-recover/go.mod new file mode 100644 index 0000000..b27ebc4 --- /dev/null +++ b/gee-web/day7-panic-recover/go.mod @@ -0,0 +1,7 @@ +module example + +go 1.13 + +require gee v0.0.0 + +replace gee => ./gee diff --git a/gee-web/day7-panic-recover/main.go b/gee-web/day7-panic-recover/main.go new file mode 100644 index 0000000..c5b309b --- /dev/null +++ b/gee-web/day7-panic-recover/main.go @@ -0,0 +1,53 @@ +package main + +/* +$ curl "http://localhost:9999" +Hello Geektutu +$ curl "http://localhost:9999/panic" +{"message":"Internal Server Error"} +$ curl "http://localhost:9999" +Hello Geektutu + +>>> log +2020/01/09 01:00:10 Route GET - / +2020/01/09 01:00:10 Route GET - /panic +2020/01/09 01:00:22 [200] / in 25.364µs +2020/01/09 01:00:32 runtime error: index out of range +Traceback: + /usr/local/Cellar/go/1.12.5/libexec/src/runtime/panic.go:523 + /usr/local/Cellar/go/1.12.5/libexec/src/runtime/panic.go:44 + /Users/7days-golang/day7-panic-recover/main.go:47 + /Users/7days-golang/day7-panic-recover/gee/context.go:41 + /Users/7days-golang/day7-panic-recover/gee/recovery.go:37 + /Users/7days-golang/day7-panic-recover/gee/context.go:41 + /Users/7days-golang/day7-panic-recover/gee/logger.go:15 + /Users/7days-golang/day7-panic-recover/gee/context.go:41 + /Users/7days-golang/day7-panic-recover/gee/router.go:99 + /Users/7days-golang/day7-panic-recover/gee/gee.go:130 + /usr/local/Cellar/go/1.12.5/libexec/src/net/http/server.go:2775 + /usr/local/Cellar/go/1.12.5/libexec/src/net/http/server.go:1879 + /usr/local/Cellar/go/1.12.5/libexec/src/runtime/asm_amd64.s:1338 + +2020/01/09 01:00:32 [500] /panic in 395.846µs +2020/01/09 01:00:38 [200] / in 6.985µs +*/ + +import ( + "net/http" + + "gee" +) + +func main() { + r := gee.Default() + r.GET("/", func(c *gee.Context) { + c.String(http.StatusOK, "Hello Geektutu\n") + }) + // index out of range for testing Recovery() + r.GET("/panic", func(c *gee.Context) { + names := []string{"geektutu"} + c.String(http.StatusOK, names[100]) + }) + + r.Run(":9999") +} diff --git a/doc/gee-day1.md b/gee-web/doc/gee-day1.md similarity index 92% rename from doc/gee-day1.md rename to gee-web/doc/gee-day1.md index 14385d2..9a09b37 100644 --- a/doc/gee-day1.md +++ b/gee-web/doc/gee-day1.md @@ -4,8 +4,9 @@ date: 2019-08-12 00:10:10 description: 7天用 Go语言 从零实现Web框架教程(7 days implement golang web framework from scratch tutorial),用 Go语言/golang 动手写Web框架,从零实现一个Web框架,以 Gin 为原型从零设计一个Web框架。本文介绍了Go标准库 net/http 和 http.Handler 接口的使用,拦截所有的 HTTP 请求,交给Gee框架处理。 tags: - Go +nav: 从零实现 categories: -- 从零实现 +- Web框架 - Gee keywords: - Go语言 - 从零实现Web框架 @@ -13,6 +14,8 @@ keywords: - net/http image: post/gee/gee.jpg github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day1 HTTP 基础 --- 本文是 [7天用Go从零实现Web框架Gee教程系列](https://geektutu.com/post/gee.html)的第一篇。 @@ -24,7 +27,7 @@ github: https://github.com/geektutu/7days-golang Go语言内置了 `net/http`库,封装了HTTP网络编程的基础的接口,我们实现的`Gee` Web 框架便是基于`net/http`的。我们接下来通过一个例子,简单介绍下这个库的使用。 -**[day1-http-base/base1/main.go](https://github.com/geektutu/7days-golang/tree/master/day1-http-base/base1)** +**[day1-http-base/base1/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day1-http-base/base1)** ```go package main @@ -82,7 +85,7 @@ func ListenAndServe(address string, h Handler) error 第二个参数的类型是什么呢?通过查看`net/http`的源码可以发现,`Handler`是一个接口,需要实现方法 _ServeHTTP_ ,也就是说,只要传入任何实现了 _ServerHTTP_ 接口的实例,所有的HTTP请求,就都交给了该实例处理了。马上来试一试吧。 -**[day1-http-base/base2/main.go](https://github.com/geektutu/7days-golang/tree/master/day1-http-base/base2)** +**[day1-http-base/base2/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day1-http-base/base2)** ```go package main @@ -130,12 +133,32 @@ func main() { ```bash gee/ |--gee.go + |--go.mod main.go +go.mod ``` +### go.mod + +**[day1-http-base/base3/go.mod](https://github.com/geektutu/7days-golang/tree/master/gee-web/day1-http-base/base3)** + +```bash +module example + +go 1.13 + +require gee v0.0.0 + +replace gee => ./gee +``` + +- 在 `go.mod` 中使用 `replace` 将 gee 指向 `./gee` + +> 从 go 1.11 版本开始,引用相对路径的 package 需要使用上述方式。 + ### main.go -**[day1-http-base/base3/main.go](https://github.com/geektutu/7days-golang/tree/master/day1-http-base/base3)** +**[day1-http-base/base3/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day1-http-base/base3)** ```go package main @@ -144,7 +167,7 @@ import ( "fmt" "net/http" - "./gee" + "gee" ) func main() { @@ -167,7 +190,7 @@ func main() { ### gee.go -**[day1-http-base/base3/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/day1-http-base/base3)** +**[day1-http-base/base3/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day1-http-base/base3)** ```go package gee diff --git a/doc/gee-day2.md b/gee-web/doc/gee-day2.md similarity index 96% rename from doc/gee-day2.md rename to gee-web/doc/gee-day2.md index e4b5f89..9bd011f 100644 --- a/doc/gee-day2.md +++ b/gee-web/doc/gee-day2.md @@ -4,8 +4,9 @@ date: 2019-08-19 00:10:10 description: 7天用 Go语言 从零实现Web框架教程(7 days implement golang web framework from scratch tutorial),用 Go语言/golang 动手写Web框架,从零实现一个Web框架,以 Gin 为原型从零设计一个Web框架。本文介绍了请求上下文(Context)的设计理念,封装了返回JSON/String/Data/HTML等类型响应的方法。 tags: - Go +nav: 从零实现 categories: -- 从零实现 +- Web框架 - Gee keywords: - Go语言 - 从零实现Web框架 @@ -13,6 +14,8 @@ keywords: - Context image: post/gee/gee.jpg github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day2 上下文 --- 本文是 [7天用Go从零实现Web框架Gee教程系列](https://geektutu.com/post/gee.html)的第二篇。 @@ -25,7 +28,7 @@ github: https://github.com/geektutu/7days-golang 为了展示第二天的成果,我们看一看在使用时的效果。 -[day2-context/main.go](https://github.com/geektutu/7days-golang/tree/master/day2-context) +[day2-context/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day2-context) ```go @@ -90,7 +93,7 @@ c.JSON(http.StatusOK, gee.H{ ### 具体实现 -[day2-context/gee/context.go](https://github.com/geektutu/7days-golang/tree/master/day2-context) +[day2-context/gee/context.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day2-context) ```go type H map[string]interface{} @@ -133,14 +136,14 @@ func (c *Context) SetHeader(key string, value string) { } func (c *Context) String(code int, format string, values ...interface{}) { - c.Status(code) c.SetHeader("Content-Type", "text/plain") + c.Status(code) c.Writer.Write([]byte(fmt.Sprintf(format, values...))) } func (c *Context) JSON(code int, obj interface{}) { - c.Status(code) c.SetHeader("Content-Type", "application/json") + c.Status(code) encoder := json.NewEncoder(c.Writer) if err := encoder.Encode(obj); err != nil { http.Error(c.Writer, err.Error(), 500) @@ -153,8 +156,8 @@ func (c *Context) Data(code int, data []byte) { } func (c *Context) HTML(code int, html string) { - c.Status(code) c.SetHeader("Content-Type", "text/html") + c.Status(code) c.Writer.Write([]byte(html)) } ``` @@ -168,7 +171,7 @@ func (c *Context) HTML(code int, html string) { 我们将和路由相关的方法和结构提取了出来,放到了一个新的文件中`router.go`,方便我们下一次对 router 的功能进行增强,例如提供动态路由的支持。 router 的 handle 方法作了一个细微的调整,即 handler 的参数,变成了 Context。 -[day2-context/gee/router.go](https://github.com/geektutu/7days-golang/tree/master/day2-context) +[day2-context/gee/router.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day2-context) ```go type router struct { @@ -180,6 +183,7 @@ func newRouter() *router { } func (r *router) addRoute(method string, pattern string, handler HandlerFunc) { + log.Printf("Route %4s - %s", method, pattern) key := method + "-" + pattern r.handlers[key] = handler } @@ -196,7 +200,7 @@ func (r *router) handle(c *Context) { ## 框架入口 -[day2-context/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/day2-context) +[day2-context/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day2-context) ```go // HandlerFunc defines the request handler used by gee diff --git a/doc/gee-day3.md b/gee-web/doc/gee-day3.md similarity index 98% rename from doc/gee-day3.md rename to gee-web/doc/gee-day3.md index f838554..3586a2f 100644 --- a/doc/gee-day3.md +++ b/gee-web/doc/gee-day3.md @@ -4,8 +4,9 @@ date: 2019-08-28 00:10:10 description: 7天用 Go语言 从零实现Web框架教程(7 days implement golang web framework from scratch tutorial),用 Go语言/golang 动手写Web框架,从零实现一个Web框架,以 Gin 为原型从零设计一个Web框架。本文介绍了如何用 Trie 前缀树实现路由 Route。支持简单的参数解析和通配符的场景。 tags: - Go +nav: 从零实现 categories: -- 从零实现 +- Web框架 - Gee keywords: - Go语言 - 从零实现Web框架 @@ -13,11 +14,13 @@ keywords: - Route image: post/gee-day3/trie_router.jpg github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day3 前缀树路由 --- 本文是 [7天用Go从零实现Web框架Gee教程系列](https://geektutu.com/post/gee.html)的第三篇。 -- 使用 Tire 树实现动态路由(dynamic route)解析。 +- 使用 Trie 树实现动态路由(dynamic route)解析。 - 支持两种模式`:name`和`*filepath`,**代码约150行**。 ## Trie 树简介 @@ -52,7 +55,7 @@ HTTP请求的路径恰好是由`/`分隔的多段构成的,因此,每一段 首先我们需要设计树节点上应该存储那些信息。 -**[day3-router/gee/trie.go](https://github.com/geektutu/7days-golang/tree/master/day3-router/gee)** +**[day3-router/gee/trie.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day3-router)** ```go type node struct { diff --git a/doc/gee-day3/trie_eg.jpg b/gee-web/doc/gee-day3/trie_eg.jpg similarity index 100% rename from doc/gee-day3/trie_eg.jpg rename to gee-web/doc/gee-day3/trie_eg.jpg diff --git a/doc/gee-day3/trie_router.jpg b/gee-web/doc/gee-day3/trie_router.jpg similarity index 100% rename from doc/gee-day3/trie_router.jpg rename to gee-web/doc/gee-day3/trie_router.jpg diff --git a/doc/gee-day4.md b/gee-web/doc/gee-day4.md similarity index 97% rename from doc/gee-day4.md rename to gee-web/doc/gee-day4.md index c6263f0..d48202f 100644 --- a/doc/gee-day4.md +++ b/gee-web/doc/gee-day4.md @@ -4,8 +4,9 @@ date: 2019-09-01 15:10:10 description: 7天用 Go语言 从零实现Web框架教程(7 days implement golang web framework from scratch tutorial),用 Go语言/golang 动手写Web框架,从零实现一个Web框架,以 Gin 为原型从零设计一个Web框架。本文介绍了分组控制(Group Control)的意义,以及嵌套分组路由的实现。 tags: - Go +nav: 从零实现 categories: -- 从零实现 +- Web框架 - Gee keywords: - Go语言 - 从零实现Web框架 @@ -13,6 +14,8 @@ keywords: - Group Control image: post/gee-day4/group.jpg github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day4 分组控制 --- 本文是 [7天用Go从零实现Web框架Gee教程系列](https://geektutu.com/post/gee.html)的第四篇。 @@ -49,7 +52,7 @@ v1.GET("/", func(c *gee.Context) { 所以,最后的 Group 的定义是这样的: -**[day4-group/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/day4-group/gee)** +**[day4-group/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day4-group)** ```go RouterGroup struct { @@ -96,6 +99,7 @@ func (group *RouterGroup) Group(prefix string) *RouterGroup { func (group *RouterGroup) addRoute(method string, comp string, handler HandlerFunc) { pattern := group.prefix + comp + log.Printf("Route %4s - %s", method, pattern) group.engine.router.addRoute(method, pattern, handler) } diff --git a/doc/gee-day4/group.jpg b/gee-web/doc/gee-day4/group.jpg similarity index 100% rename from doc/gee-day4/group.jpg rename to gee-web/doc/gee-day4/group.jpg diff --git a/doc/gee-day5.md b/gee-web/doc/gee-day5.md similarity index 96% rename from doc/gee-day5.md rename to gee-web/doc/gee-day5.md index 763e8a7..7a540a1 100644 --- a/doc/gee-day5.md +++ b/gee-web/doc/gee-day5.md @@ -4,8 +4,9 @@ date: 2019-09-01 20:10:10 description: 7天用 Go语言 从零实现Web框架教程(7 days implement golang web framework from scratch tutorial),用 Go语言/golang 动手写Web框架,从零实现一个Web框架,以 Gin 为原型从零设计一个Web框架。本文介绍了如何为Web框架添加中间件的功能(middlewares)。 tags: - Go +nav: 从零实现 categories: -- 从零实现 +- Web框架 - Gee keywords: - Go语言 - 从零实现Web框架 @@ -13,6 +14,8 @@ keywords: - Middlewares image: post/gee-day5/middleware.jpg github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day5 中间件 --- 本文是 [7天用Go从零实现Web框架Gee教程系列](https://geektutu.com/post/gee.html)的第五篇。 @@ -33,7 +36,7 @@ github: https://github.com/geektutu/7days-golang Gee 的中间件的定义与路由映射的 Handler 一致,处理的输入是`Context`对象。插入点是框架接收到请求初始化`Context`对象后,允许用户使用自己定义的中间件做一些额外的处理,例如记录日志等,以及对`Context`进行二次加工。另外通过调用`(*Context).Next()`函数,中间件可等待用户自己定义的 `Handler`处理结束后,再做一些额外的操作,例如计算本次处理所用时间等。即 Gee 的中间件支持用户在请求被处理的前后,做一些额外的操作。举个例子,我们希望最终能够支持如下定义的中间件,`c.Next()`表示等待执行其他的中间件或用户的`Handler`: -****[day4-group/gee/logger.go](https://github.com/geektutu/7days-golang/tree/master/day5-middleware/gee)**** +****[day4-group/gee/logger.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day5-middleware)**** ```go func Logger() HandlerFunc { @@ -56,7 +59,7 @@ func Logger() HandlerFunc { 为此,我们给`Context`添加了2个参数,定义了`Next`方法: -**[day4-group/gee/context.go](https://github.com/geektutu/7days-golang/tree/master/day5-middleware/gee)** +**[day4-group/gee/context.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day5-middleware)** ```go type Context struct { @@ -130,7 +133,7 @@ func B(c *Context) { - 定义`Use`函数,将中间件应用到某个 Group 。 -**[day4-group/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/day5-middleware/gee)** +**[day4-group/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day5-middleware)** ```go // Use is defined to add middleware to the group @@ -155,7 +158,7 @@ ServeHTTP 函数也有变化,当我们接收到一个具体请求时,要判 - handle 函数中,将从路由匹配得到的 Handler 添加到 `c.handlers`列表中,执行`c.Next()`。 -**[day4-group/gee/router.go](https://github.com/geektutu/7days-golang/tree/master/day5-middleware/gee)** +**[day4-group/gee/router.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day5-middleware)** ```go func (r *router) handle(c *Context) { diff --git a/doc/gee-day5/middleware.jpg b/gee-web/doc/gee-day5/middleware.jpg similarity index 100% rename from doc/gee-day5/middleware.jpg rename to gee-web/doc/gee-day5/middleware.jpg diff --git a/doc/gee-day6.md b/gee-web/doc/gee-day6.md similarity index 88% rename from doc/gee-day6.md rename to gee-web/doc/gee-day6.md index ad4845b..1a6283a 100644 --- a/doc/gee-day6.md +++ b/gee-web/doc/gee-day6.md @@ -4,8 +4,9 @@ date: 2019-09-08 20:10:00 description: 7天用 Go语言 从零实现Web框架教程(7 days implement golang web framework from scratch tutorial),用 Go语言/golang 动手写Web框架,从零实现一个Web框架,以 Gin 为原型从零设计一个Web框架。本文介绍了如何为Web框架添加HTML模板(HTML Template)以及静态文件(Serve Static Files)的功能。 tags: - Go +nav: 从零实现 categories: -- 从零实现 +- Web框架 - Gee keywords: - Go语言 - 从零实现Web框架 @@ -13,6 +14,8 @@ keywords: - Template image: post/gee-day6/html.png github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day6 模板 Template --- 本文是 [7天用Go从零实现Web框架Gee教程系列](https://geektutu.com/post/gee.html)的第六篇。 @@ -36,7 +39,7 @@ github: https://github.com/geektutu/7days-golang 找到文件后,如何返回这一步,`net/http`库已经实现了。因此,gee 框架要做的,仅仅是解析请求的地址,映射到服务器上文件的真实地址,交给`http.FileServer`处理就好了。 -[day6-template/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/day6-template/gee) +[day6-template/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day6-template) ```go // create static handler @@ -103,18 +106,38 @@ func (engine *Engine) LoadHTMLGlob(pattern string) { 接下来,对原来的 `(*Context).HTML()`方法做了些小修改,使之支持根据模板文件名选择模板进行渲染。 -[day6-template/gee/context.go](https://github.com/geektutu/7days-golang/tree/master/day6-template/gee) +[day6-template/gee/context.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day6-template) ```go +type Context struct { + // ... + // engine pointer + engine *Engine +} + func (c *Context) HTML(code int, name string, data interface{}) { - c.Writer.WriteHeader(code) - c.Writer.Header().Set("Content-Type", "text/html") + c.SetHeader("Content-Type", "text/html") + c.Status(code) if err := c.engine.htmlTemplates.ExecuteTemplate(c.Writer, name, data); err != nil { c.Fail(500, err.Error()) } } ``` +我们在 `Context` 中添加了成员变量 `engine *Engine`,这样就能够通过 Context 访问 Engine 中的 HTML 模板。实例化 Context 时,还需要给 `c.engine` 赋值。 + +[day6-template/gee/gee.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day6-template) + +```go +func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // ... + c := newContext(w, req) + c.handlers = middlewares + c.engine = engine + engine.router.handle(c) +} +``` + ## 使用Demo 最终的目录结构 @@ -140,7 +163,7 @@ func (c *Context) HTML(code int, name string, data interface{}) { ``` -[day6-template/main.go](https://github.com/geektutu/7days-golang/tree/master/day6-template/gee) +[day6-template/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day6-template) ```go type student struct { @@ -148,7 +171,7 @@ type student struct { Age int8 } -func formatAsDate(t time.Time) string { +func FormatAsDate(t time.Time) string { year, month, day := t.Date() return fmt.Sprintf("%d-%02d-%02d", year, month, day) } @@ -157,7 +180,7 @@ func main() { r := gee.New() r.Use(gee.Logger()) r.SetFuncMap(template.FuncMap{ - "formatAsDate": formatAsDate, + "FormatAsDate": FormatAsDate, }) r.LoadHTMLGlob("templates/*") r.Static("/assets", "./static") diff --git a/doc/gee-day6/html.png b/gee-web/doc/gee-day6/html.png similarity index 100% rename from doc/gee-day6/html.png rename to gee-web/doc/gee-day6/html.png diff --git a/doc/gee-day6/static.jpg b/gee-web/doc/gee-day6/static.jpg similarity index 100% rename from doc/gee-day6/static.jpg rename to gee-web/doc/gee-day6/static.jpg diff --git a/gee-web/doc/gee-day7.md b/gee-web/doc/gee-day7.md new file mode 100644 index 0000000..f64de84 --- /dev/null +++ b/gee-web/doc/gee-day7.md @@ -0,0 +1,292 @@ +--- +title: Go语言动手写Web框架 - Gee第七天 错误恢复(Panic Recover) +date: 2020-01-09 01:00:00 +description: 7天用 Go语言 从零实现Web框架教程(7 days implement golang web framework from scratch tutorial),用 Go语言/golang 动手写Web框架,从零实现一个Web框架,以 Gin 为原型从零设计一个Web框架。本文介绍了如何为Web框架增加错误处理机制。 +tags: +- Go +nav: 从零实现 +categories: +- Web框架 - Gee +keywords: +- Go语言 +- 从零实现Web框架 +- 动手写Web框架 +- Panic +- Recover +image: post/gee-day7/go-panic.png +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day7 错误恢复 +--- + +本文是[7天用Go从零实现Web框架Gee教程系列](https://geektutu.com/post/gee.html)的第七篇。 + +- 实现错误处理机制。 + +## panic + +Go 语言中,比较常见的错误处理方法是返回 error,由调用者决定后续如何处理。但是如果是无法恢复的错误,可以手动触发 panic,当然如果在程序运行过程中出现了类似于数组越界的错误,panic 也会被触发。panic 会中止当前执行的程序,退出。 + +下面是主动触发的例子: + +```go +// hello.go +func main() { + fmt.Println("before panic") + panic("crash") + fmt.Println("after panic") +} +``` + +```bash +$ go run hello.go + +before panic +panic: crash + +goroutine 1 [running]: +main.main() + ~/go_demo/hello/hello.go:7 +0x95 +exit status 2 +``` + +下面是数组越界触发的 panic + +```go +// hello.go +func main() { + arr := []int{1, 2, 3} + fmt.Println(arr[4]) +} +``` + +```bash +$ go run hello.go +panic: runtime error: index out of range [4] with length 3 +``` + +## defer + +panic 会导致程序被中止,但是在退出前,会先处理完当前协程上已经defer 的任务,执行完成后再退出。效果类似于 java 语言的 `try...catch`。 + +```go +// hello.go +func main() { + defer func() { + fmt.Println("defer func") + }() + + arr := []int{1, 2, 3} + fmt.Println(arr[4]) +} +``` + +```go +$ go run hello.go +defer func +panic: runtime error: index out of range [4] with length 3 +``` + +可以 defer 多个任务,在同一个函数中 defer 多个任务,会逆序执行。即先执行最后 defer 的任务。 + +在这里,defer 的任务执行完成之后,panic 还会继续被抛出,导致程序非正常结束。 + +## recover + +Go 语言还提供了 recover 函数,可以避免因为 panic 发生而导致整个程序终止,recover 函数只在 defer 中生效。 + +```go +// hello.go +func test_recover() { + defer func() { + fmt.Println("defer func") + if err := recover(); err != nil { + fmt.Println("recover success") + } + }() + + arr := []int{1, 2, 3} + fmt.Println(arr[4]) + fmt.Println("after panic") +} + +func main() { + test_recover() + fmt.Println("after recover") +} +``` + +```go +$ go run hello.go +defer func +recover success +after recover +``` + +我们可以看到,recover 捕获了 panic,程序正常结束。*test_recover()* 中的 *after panic* 没有打印,这是正确的,当 panic 被触发时,控制权就被交给了 defer 。就像在 java 中,`try`代码块中发生了异常,控制权交给了 `catch`,接下来执行 catch 代码块中的代码。而在 *main()* 中打印了 *after recover*,说明程序已经恢复正常,继续往下执行直到结束。 + +## Gee 的错误处理机制 + +对一个 Web 框架而言,错误处理机制是非常必要的。可能是框架本身没有完备的测试,导致在某些情况下出现空指针异常等情况。也有可能用户不正确的参数,触发了某些异常,例如数组越界,空指针等。如果因为这些原因导致系统宕机,必然是不可接受的。 + +我们在[第六天](https://geektutu.com/post/gee-day6.html)实现的框架并没有加入异常处理机制,如果代码中存在会触发 panic 的 BUG,很容易宕掉。 + +例如下面的代码: + +```go +func main() { + r := gee.New() + r.GET("/panic", func(c *gee.Context) { + names := []string{"geektutu"} + c.String(http.StatusOK, names[100]) + }) + r.Run(":9999") +} +``` +在上面的代码中,我们为 gee 注册了路由 `/panic`,而这个路由的处理函数内部存在数组越界 `names[100]`,如果访问 *localhost:9999/panic*,Web 服务就会宕掉。 + +今天,我们将在 gee 中添加一个非常简单的错误处理机制,即在此类错误发生时,向用户返回 *Internal Server Error*,并且在日志中打印必要的错误信息,方便进行错误定位。 + +我们之前实现了中间件机制,错误处理也可以作为一个中间件,增强 gee 框架的能力。 + +新增文件 **gee/recovery.go**,在这个文件中实现中间件 `Recovery`。 + +```go +func Recovery() HandlerFunc { + return func(c *Context) { + defer func() { + if err := recover(); err != nil { + message := fmt.Sprintf("%s", err) + log.Printf("%s\n\n", trace(message)) + c.Fail(http.StatusInternalServerError, "Internal Server Error") + } + }() + + c.Next() + } +} +``` + +`Recovery` 的实现非常简单,使用 defer 挂载上错误恢复的函数,在这个函数中调用 *recover()*,捕获 panic,并且将堆栈信息打印在日志中,向用户返回 *Internal Server Error*。 + +你可能注意到,这里有一个 *trace()* 函数,这个函数是用来获取触发 panic 的堆栈信息,完整代码如下: + +[day7-panic-recover/gee/recovery.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day7-panic-recover) + +```go +package gee + +import ( + "fmt" + "log" + "net/http" + "runtime" + "strings" +) + +// print stack trace for debug +func trace(message string) string { + var pcs [32]uintptr + n := runtime.Callers(3, pcs[:]) // skip first 3 caller + + var str strings.Builder + str.WriteString(message + "\nTraceback:") + for _, pc := range pcs[:n] { + fn := runtime.FuncForPC(pc) + file, line := fn.FileLine(pc) + str.WriteString(fmt.Sprintf("\n\t%s:%d", file, line)) + } + return str.String() +} + +func Recovery() HandlerFunc { + return func(c *Context) { + defer func() { + if err := recover(); err != nil { + message := fmt.Sprintf("%s", err) + log.Printf("%s\n\n", trace(message)) + c.Fail(http.StatusInternalServerError, "Internal Server Error") + } + }() + + c.Next() + } +} +``` + +在 *trace()* 中,调用了 `runtime.Callers(3, pcs[:])`,Callers 用来返回调用栈的程序计数器, 第 0 个 Caller 是 Callers 本身,第 1 个是上一层 trace,第 2 个是再上一层的 `defer func`。因此,为了日志简洁一点,我们跳过了前 3 个 Caller。 + +接下来,通过 `runtime.FuncForPC(pc)` 获取对应的函数,在通过 `fn.FileLine(pc)` 获取到调用该函数的文件名和行号,打印在日志中。 + +至此,gee 框架的错误处理机制就完成了。 + +## 使用 Demo + +[day7-panic-recover/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-web/day7-panic-recover) + +```go +package main + +import ( + "net/http" + + "gee" +) + +func main() { + r := gee.Default() + r.GET("/", func(c *gee.Context) { + c.String(http.StatusOK, "Hello Geektutu\n") + }) + // index out of range for testing Recovery() + r.GET("/panic", func(c *gee.Context) { + names := []string{"geektutu"} + c.String(http.StatusOK, names[100]) + }) + + r.Run(":9999") +} +``` + +接下来进行测试,先访问主页,访问一个有BUG的 `/panic`,服务正常返回。接下来我们再一次成功访问了主页,说明服务完全运转正常。 + +```bash +$ curl "http://localhost:9999" +Hello Geektutu +$ curl "http://localhost:9999/panic" +{"message":"Internal Server Error"} +$ curl "http://localhost:9999" +Hello Geektutu +``` + +我们可以在后台日志中看到如下内容,引发错误的原因和堆栈信息都被打印了出来,通过日志,我们可以很容易地知道,在*day7-panic-recover/main.go:47* 的地方出现了 `index out of range` 错误。 + +```bash +2020/01/09 01:00:10 Route GET - / +2020/01/09 01:00:10 Route GET - /panic +2020/01/09 01:00:22 [200] / in 25.364µs +2020/01/09 01:00:32 runtime error: index out of range +Traceback: + /usr/local/Cellar/go/1.12.5/libexec/src/runtime/panic.go:523 + /usr/local/Cellar/go/1.12.5/libexec/src/runtime/panic.go:44 + /tmp/7days-golang/day7-panic-recover/main.go:47 + /tmp/7days-golang/day7-panic-recover/gee/context.go:41 + /tmp/7days-golang/day7-panic-recover/gee/recovery.go:37 + /tmp/7days-golang/day7-panic-recover/gee/context.go:41 + /tmp/7days-golang/day7-panic-recover/gee/logger.go:15 + /tmp/7days-golang/day7-panic-recover/gee/context.go:41 + /tmp/7days-golang/day7-panic-recover/gee/router.go:99 + /tmp/7days-golang/day7-panic-recover/gee/gee.go:130 + /usr/local/Cellar/go/1.12.5/libexec/src/net/http/server.go:2775 + /usr/local/Cellar/go/1.12.5/libexec/src/net/http/server.go:1879 + /usr/local/Cellar/go/1.12.5/libexec/src/runtime/asm_amd64.s:1338 + +2020/01/09 01:00:32 [500] /panic in 395.846µs +2020/01/09 01:00:38 [200] / in 6.985µs +``` + +## 参考 + +- [Package runtime - golang.org](https://golang.org/pkg/runtime/) +- [Is it possible get information about caller function in Golang? - StackOverflow](https://stackoverflow.com/questions/35212985/is-it-possible-get-information-about-caller-function-in-golang) + diff --git a/gee-web/doc/gee-day7/go-panic.png b/gee-web/doc/gee-day7/go-panic.png new file mode 100644 index 0000000..1682e75 Binary files /dev/null and b/gee-web/doc/gee-day7/go-panic.png differ diff --git a/doc/gee.md b/gee-web/doc/gee.md similarity index 78% rename from doc/gee.md rename to gee-web/doc/gee.md index 03f2c79..9dbfc54 100644 --- a/doc/gee.md +++ b/gee-web/doc/gee.md @@ -4,8 +4,9 @@ date: 2019-08-11 02:10:10 description: 7天用 Go语言 从零实现Web框架教程(7 days implement golang web framework from scratch tutorial),用 Go语言/golang 动手写Web框架,从零实现一个Web框架,以 Gin 为原型从零设计一个Web框架。 tags: - Go +nav: 从零实现 categories: -- 从零实现 +- Web框架 - Gee keywords: - Gee教程 - 从零实现Web框架 @@ -13,6 +14,8 @@ keywords: - from scratch image: post/gee/gee.jpg github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: Day0 序言 --- ![gee](gee/gee.jpg) @@ -62,10 +65,16 @@ func handler(w http.ResponseWriter, r *http.Request) { ## 目录 -- [第一天:前置知识(http.Handler接口)](https://geektutu.com/post/gee-day1.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/day1-http-base) -- [第二天:上下文设计(Context)](https://geektutu.com/post/gee-day2.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/day2-context) -- [第三天:Tire树路由(Router)](https://geektutu.com/post/gee-day3.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/day3-router) -- [第四天:分组控制(Group)](https://geektutu.com/post/gee-day4.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/day4-group) -- [第五天:中间件(Middleware)](https://geektutu.com/post/gee-day5.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/day5-middleware) -- [第六天:HTML模板(Template)](https://geektutu.com/post/gee-day6.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/day6-template) -- 第七天:错误恢复(Panic Recover),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/day7-panic-recover) \ No newline at end of file +- 第一天:[前置知识(http.Handler接口)](https://geektutu.com/post/gee-day1.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/gee-web/day1-http-base) +- 第二天:[上下文设计(Context)](https://geektutu.com/post/gee-day2.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/gee-web/day2-context) +- 第三天:[Trie树路由(Router)](https://geektutu.com/post/gee-day3.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/gee-web/day3-router) +- 第四天:[分组控制(Group)](https://geektutu.com/post/gee-day4.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/gee-web/day4-group) +- 第五天:[中间件(Middleware)](https://geektutu.com/post/gee-day5.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/gee-web/day5-middleware) +- 第六天:[HTML模板(Template)](https://geektutu.com/post/gee-day6.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/gee-web/day6-template) +- 第七天:[错误恢复(Panic Recover)](https://geektutu.com/post/gee-day7.html),[Code - Github](https://github.com/geektutu/7days-golang/tree/master/gee-web/day7-panic-recover) + +## 推荐阅读 + +- [Go 语言简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Test 单元测试简明教程](https://geektutu.com/post/quick-golang.html) +- [Go Gin 简明教程](https://geektutu.com/post/quick-go-gin.html) \ No newline at end of file diff --git a/doc/gee/gee.jpg b/gee-web/doc/gee/gee.jpg similarity index 100% rename from doc/gee/gee.jpg rename to gee-web/doc/gee/gee.jpg diff --git a/questions/7days-golang-q1.md b/questions/7days-golang-q1.md new file mode 100644 index 0000000..9020b8c --- /dev/null +++ b/questions/7days-golang-q1.md @@ -0,0 +1,203 @@ +--- +title: Go 接口型函数的使用场景 +date: 2020-10-25 12:30:00 +description: Go 语言/golang 中函数式接口或接口型函数的实现与价值,什么是接口型函数,为什么不直接将函数作为参数,而是封装为一个接口。Go 语言标准库 net/http 中是如何使用接口型函数的。 +tags: +- Go +nav: 从零实现 +categories: +- 7days-golang Q & A +keywords: +- 函数式接口 +- 接口型函数 +- net/http +image: post/7days-golang-q1/7days-golang-qa.jpg +github: https://github.com/geektutu/7days-golang +book: 七天用Go从零实现系列 +book_title: 接口型函数 +--- + +![7days-golang 有价值的问题](7days-golang-q1/7days-golang-qa.jpg) + +## 问题 + +在 [动手写分布式缓存 - GeeCache第二天 单机并发缓存](https://geektutu.com/post/geecache-day2.html) 这篇文章中,有一个接口型函数的实现: + +```go +// A Getter loads data for a key. +type Getter interface { + Get(key string) ([]byte, error) +} + +// A GetterFunc implements Getter with a function. +type GetterFunc func(key string) ([]byte, error) + +// Get implements Getter interface function +func (f GetterFunc) Get(key string) ([]byte, error) { + return f(key) +} +``` + +这里呢,定义了一个接口 `Getter`,只包含一个方法 `Get(key string) ([]byte, error)`,紧接着定义了一个函数类型 `GetterFunc`,GetterFunc 参数和返回值与 Getter 中 Get 方法是一致的。而且 GetterFunc 还定义了 Get 方式,并在 Get 方法中调用自己,这样就实现了接口 Getter。所以 GetterFunc 是一个实现了接口的函数类型,简称为接口型函数。 + +这个接口型函数的实现就引起了好几个童鞋的关注。接口型函数只能应用于接口内部只定义了一个方法的情况,例如接口 Getter 内部有且只有一个方法 Get。既然只有一个方法,为什么还要多此一举,封装为一个接口呢?定义参数的时候,直接用 GetterFunc 这个函数类型不就好了,让用户直接传入一个函数作为参数,不更简单吗? + +所以呢,接口型函数的价值什么? + + +## 价值 + +我们想象这么一个使用场景,`GetFromSource` 的作用是从某数据源获取结果,接口类型 Getter 是其中一个参数,代表某数据源: + +```go +func GetFromSource(getter Getter, key string) []byte { + buf, err := getter.Get(key) + if err == nil { + return buf + } + return nil +} +``` + +我们可以有多种方式调用该函数: + +- 方式一:GetterFunc 类型的函数作为参数 + +```go +GetFromSource(GetterFunc(func(key string) ([]byte, error) { + return []byte(key), nil +}), "hello") +``` + +支持匿名函数,也支持普通的函数: + +```go +func test(key string) ([]byte, error) { + return []byte(key), nil +} + +func main() { + GetFromSource(GetterFunc(test), "hello") +} +``` + +将 test 强制类型转换为 GetterFunc,GetterFunc 实现了接口 Getter,是一个合法参数。这种方式适用于逻辑较为简单的场景。 + + +- 方式二:实现了 Getter 接口的结构体作为参数 + +```go +type DB struct{ url string} + +func (db *DB) Query(sql string, args ...string) string { + // ... + return "hello" +} + +func (db *DB) Get(key string) ([]byte, error) { + // ... + v := db.Query("SELECT NAME FROM TABLE WHEN NAME= ?", key) + return []byte(v), nil +} + +func main() { + GetFromSource(new(DB), "hello") +} +``` + +DB 实现了接口 Getter,也是一个合法参数。这种方式适用于逻辑较为复杂的场景,如果对数据库的操作需要很多信息,地址、用户名、密码,还有很多中间状态需要保持,比如超时、重连、加锁等等。这种情况下,更适合封装为一个结构体作为参数。 + +这样,既能够将普通的函数类型(需类型转换)作为参数,也可以将结构体作为参数,使用更为灵活,可读性也更好,这就是接口型函数的价值。 + +## 使用场景 + +这个特性在 groupcache 等大量的 Go 语言开源项目中被广泛使用,标准库中用得也不少,`net/http` 的 Handler 和 HandlerFunc 就是一个典型。 + +我们先看一下 Handler 的定义: + +```go +type Handler interface { + ServeHTTP(ResponseWriter, *Request) +} +type HandlerFunc func(ResponseWriter, *Request) + +func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { + f(w, r) +} +``` + +> 摘自 Go 语言源代码 [net/http/server.go](https://github.com/golang/go/blob/master/src/net/http/server.go) + +我们可以 `http.Handle` 来映射请求路径和处理函数,Handle 的定义如下: + +```go +func Handle(pattern string, handler Handler) +``` + +第二个参数是即接口类型 Handler,我们可以这么用。 + +```go +func home(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("hello, index page")) +} + +func main() { + http.Handle("/home", http.HandlerFunc(home)) + _ = http.ListenAndServe("localhost:8000", nil) +} +``` + +通常,我们还会使用另外一个函数 `http.HandleFunc`,HandleFunc 的定义如下: + +```go +func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) +``` + +第二个参数是一个普通的函数类型,那可以直接将 home 传递给 HandleFunc: + +```go +func main() { + http.HandleFunc("/home", home) + _ = http.ListenAndServe("localhost:8000", nil) +} +``` + +那如果我们看过 HandleFunc 的内部实现的话,就会知道两种写法是完全等价的,内部将第二种写法转换为了第一种写法。 + +```go +func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + if handler == nil { + panic("http: nil handler") + } + mux.Handle(pattern, HandlerFunc(handler)) +} +``` + +如果你仔细观察,会发现 `http.ListenAndServe` 的第二个参数也是接口类型 `Handler`,我们使用了标准库 `net/http` 内置的路由,因此呢,传入的值是 nil。那如果这个地方我们传入的是一个实现了 `Handler` 接口的结构体呢?就可以完全托管所有的 HTTP 请求,后续怎么路由,怎么处理,请求前后增加什么功能,都可以自定义了。慢慢地,就变成了一个功能丰富的 Web 框架了。如果你感兴趣呢,可以阅读 [7天用Go从零实现Web框架Gee教程](https://geektutu.com/post/gee.html)。 + +## 其他语言类似特性 + +如果有 Java 编程经验的同学可能比较有感触。Java 1.5 中是不支持直接传入函数的,参数要么是接口,要么是对象。举一个最简单的例子,列表自定义排序时,需要实现一个匿名的 Comparator 类,重写 compare 方法。 + +```java +Collections.sort(list, new Comparator(){ + @Override + public int compare(Integer o1, Integer o2) { + return o2 - o1; + } +}); +``` + +Java 1.8 中引入了大量的函数式编程的特性,其中 lambda 表达式和函数式接口就是一个很好的简化 Java 写法的特性。Java 1.8 中,上述的例子可以简化为: + +```java +Collections.sort(list, (Integer o1, Integer o2) -> o2 - o1 ); +``` + +即从需要构造一个匿名对象简化为只需要一个 lambda 函数表达式,可以认为是面向对象与函数式编程的一种结合。同样地,这种写法只支持只定义了一个方法的接口类型。正是这种结合,可以达到实现相同代码,代码量更少的目的。 + +## 附 参考 + +- [7days-golang 有价值的问题讨论汇总贴](https://github.com/geektutu/7days-golang/issues/24) +- [GeeCache第二天 单机并发缓存 - Github 评论区](https://github.com/geektutu/blog/issues/64) \ No newline at end of file diff --git a/questions/7days-golang-q1/7days-golang-qa.jpg b/questions/7days-golang-q1/7days-golang-qa.jpg new file mode 100644 index 0000000..c783be8 Binary files /dev/null and b/questions/7days-golang-q1/7days-golang-qa.jpg differ