From b0c73e9fbf03065efaaa18a75e05eb5ebb253058 Mon Sep 17 00:00:00 2001 From: gzdaijie Date: Mon, 5 Oct 2020 18:07:16 +0800 Subject: [PATCH] gee-rpc/day7 add registry --- README.md | 28 ++ gee-rpc/day1-codec/main/main.go | 2 + gee-rpc/day2-client/main/main.go | 2 + gee-rpc/day3-service/main/main.go | 2 + gee-rpc/day4-timeout/main/main.go | 2 + gee-rpc/day5-http-debug/main/main.go | 2 + .../client.go | 0 .../client_test.go | 0 .../codec/codec.go | 0 .../codec/gob.go | 0 .../debug.go | 0 .../go.mod | 0 .../main/main.go | 2 + .../server.go | 0 .../service.go | 0 .../service_test.go | 0 .../xclient/discovery.go | 31 +- .../xclient/xclient.go | 12 +- gee-rpc/day7-registry/client.go | 324 ++++++++++++++++++ gee-rpc/day7-registry/client_test.go | 84 +++++ gee-rpc/day7-registry/codec/codec.go | 34 ++ gee-rpc/day7-registry/codec/gob.go | 57 +++ gee-rpc/day7-registry/debug.go | 60 ++++ gee-rpc/day7-registry/go.mod | 3 + gee-rpc/day7-registry/main/main.go | 112 ++++++ gee-rpc/day7-registry/registry/registry.go | 123 +++++++ gee-rpc/day7-registry/server.go | 266 ++++++++++++++ gee-rpc/day7-registry/service.go | 99 ++++++ gee-rpc/day7-registry/service_test.go | 48 +++ gee-rpc/day7-registry/xclient/discovery.go | 83 +++++ .../day7-registry/xclient/discovery_gee.go | 74 ++++ gee-rpc/day7-registry/xclient/xclient.go | 109 ++++++ 32 files changed, 1546 insertions(+), 13 deletions(-) rename gee-rpc/{day6-discovery => day6-load-balance}/client.go (100%) rename gee-rpc/{day6-discovery => day6-load-balance}/client_test.go (100%) rename gee-rpc/{day6-discovery => day6-load-balance}/codec/codec.go (100%) rename gee-rpc/{day6-discovery => day6-load-balance}/codec/gob.go (100%) rename gee-rpc/{day6-discovery => day6-load-balance}/debug.go (100%) rename gee-rpc/{day6-discovery => day6-load-balance}/go.mod (100%) rename gee-rpc/{day6-discovery => day6-load-balance}/main/main.go (98%) rename gee-rpc/{day6-discovery => day6-load-balance}/server.go (100%) rename gee-rpc/{day6-discovery => day6-load-balance}/service.go (100%) rename gee-rpc/{day6-discovery => day6-load-balance}/service_test.go (100%) rename gee-rpc/{day6-discovery => day6-load-balance}/xclient/discovery.go (63%) rename gee-rpc/{day6-discovery => day6-load-balance}/xclient/xclient.go (93%) create mode 100644 gee-rpc/day7-registry/client.go create mode 100644 gee-rpc/day7-registry/client_test.go create mode 100644 gee-rpc/day7-registry/codec/codec.go create mode 100644 gee-rpc/day7-registry/codec/gob.go create mode 100644 gee-rpc/day7-registry/debug.go create mode 100644 gee-rpc/day7-registry/go.mod create mode 100644 gee-rpc/day7-registry/main/main.go create mode 100644 gee-rpc/day7-registry/registry/registry.go create mode 100644 gee-rpc/day7-registry/server.go create mode 100644 gee-rpc/day7-registry/service.go create mode 100644 gee-rpc/day7-registry/service_test.go create mode 100644 gee-rpc/day7-registry/xclient/discovery.go create mode 100644 gee-rpc/day7-registry/xclient/discovery_gee.go create mode 100644 gee-rpc/day7-registry/xclient/xclient.go diff --git a/README.md b/README.md index 37d366d..c2d9d02 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,8 @@ 推荐先阅读 **[Go 语言简明教程](https://geektutu.com/post/quick-golang.html)**,一篇文章了解Go的基本语法、并发编程,依赖管理等内容。 +另外推荐 **[Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html)**,加深对 Go 语言的理解。 + 期待关注我的「[知乎专栏](https://zhuanlan.zhihu.com/geekgo)」和「[微博](http://weibo.com/geektutu)」,查看最近的文章和动态。 ### 7天用Go从零实现Web框架 - Gee @@ -50,6 +52,20 @@ gorm 准备推出完全重写的 v2 版本(目前还在开发中),相对 gorm- - 第六天:[支持事务(Transaction)](https://geektutu.com/post/geeorm-day6.html) | [Code](gee-orm/day6-transaction) - 第七天:[数据库迁移(Migrate)](https://geektutu.com/post/geeorm-day7.html) | [Code](gee-orm/day7-migrate) + +### 7天用Go从零实现RPC框架 GeeRPC + +[GeeRPC](https://geektutu.com/post/geerpc.html) 是一个基于 [net/rpc](https://github.com/golang/go/tree/master/src/net/rpc) 开发的 RPC 框架 +GeeRPC 是基于 Go 语言标准库 `net/rpc` 实现的,添加了协议交换、服务注册与发现、负载均衡等功能,代码约 1k。 + +- 第一天 - [服务端与消息编码](https://geektutu.com/post/geerpc-day1.html) | [Code](gee-rpc/day1-codec) +- 第二天 - [支持并发与异步的客户端](https://geektutu.com/post/geerpc-day2.html) | [Code](gee-rpc/day2-client) +- 第三天 - [服务注册(service register)](https://geektutu.com/post/geerpc-day3.html) | [Code](gee-rpc/day3-service ) +- 第四天 - [超时处理(timeout)](https://geektutu.com/post/geerpc-day4.html) | [Code](gee-rpc/day4-timeout ) +- 第五天 - [支持HTTP协议](https://geektutu.com/post/geerpc-day5.html) | [Code](gee-rpc/day5-http-debug) +- 第六天 - [负载均衡(load balance)](https://geektutu.com/post/geerpc-day6.html) | [Code](gee-rpc/day6-load-balance) +- 第七天 - [服务发现与注册中心(registry)](https://geektutu.com/post/geerpc-day7.html) | [Code](gee-rpc/day7-registry) + ### WebAssembly 使用示例 具体的实践过程记录在 [Go WebAssembly 简明教程](https://geektutu.com/post/quick-go-wasm.html)。 @@ -102,6 +118,18 @@ Xorm's desgin is easier to understand than gorm-v1, so the main designs referenc - Day 6 - Support Transaction | [Code](gee-orm/day6-transaction) - Day 7 - Migrate Database | [Code](gee-orm/day7-migrate) +[GeeRPC](https://geektutu.com/post/geerpc.html) is a [net/rpc](https://github.com/golang/go/tree/master/src/net/rpc)-like RPC framework + +Based on golang standard library `net/rpc`, GeeRPC implements more features. eg, protocol exchange, service registration and discovery, load balance, etc. + +- Day 1 - Server Message Codec | [Code](gee-rpc/day1-codec) +- Day 2 - Concurrent Client | [Code](gee-rpc/day2-client) +- Day 3 - Service Register | [Code](gee-rpc/day3-service ) +- Day 4 - Timeout Processing | [Code](gee-rpc/day4-timeout ) +- Day 5 - Support HTTP Protocol | [Code](gee-rpc/day5-http-debug) +- Day 6 - Load Balance | [Code](gee-rpc/day6-load-balance) +- Day 7 - Discovery and Registry | [Code](gee-rpc/day7-registry) + ## Golang WebAssembly Demo - Demo 1 - Hello World [Code](demo-wasm/hello-world) diff --git a/gee-rpc/day1-codec/main/main.go b/gee-rpc/day1-codec/main/main.go index 8f29531..948b763 100644 --- a/gee-rpc/day1-codec/main/main.go +++ b/gee-rpc/day1-codec/main/main.go @@ -7,6 +7,7 @@ import ( "geerpc/codec" "log" "net" + "time" ) func startServer(addr chan string) { @@ -28,6 +29,7 @@ func main() { conn, _ := net.Dial("tcp", <-addr) defer func() { _ = conn.Close() }() + time.Sleep(time.Second) // send options _ = json.NewEncoder(conn).Encode(geerpc.DefaultOption) cc := codec.NewGobCodec(conn) diff --git a/gee-rpc/day2-client/main/main.go b/gee-rpc/day2-client/main/main.go index e291353..8502fe9 100644 --- a/gee-rpc/day2-client/main/main.go +++ b/gee-rpc/day2-client/main/main.go @@ -6,6 +6,7 @@ import ( "log" "net" "sync" + "time" ) func startServer(addr chan string) { @@ -25,6 +26,7 @@ func main() { client, _ := geerpc.Dial("tcp", <-addr) defer func() { _ = client.Close() }() + time.Sleep(time.Second) // send request & receive response var wg sync.WaitGroup for i := 0; i < 5; i++ { diff --git a/gee-rpc/day3-service/main/main.go b/gee-rpc/day3-service/main/main.go index c526e11..0f0b668 100644 --- a/gee-rpc/day3-service/main/main.go +++ b/gee-rpc/day3-service/main/main.go @@ -5,6 +5,7 @@ import ( "log" "net" "sync" + "time" ) type Foo int @@ -37,6 +38,7 @@ func main() { client, _ := geerpc.Dial("tcp", <-addr) defer func() { _ = client.Close() }() + time.Sleep(time.Second) // send request & receive response var wg sync.WaitGroup for i := 0; i < 5; i++ { diff --git a/gee-rpc/day4-timeout/main/main.go b/gee-rpc/day4-timeout/main/main.go index e5e6050..9693eb5 100644 --- a/gee-rpc/day4-timeout/main/main.go +++ b/gee-rpc/day4-timeout/main/main.go @@ -6,6 +6,7 @@ import ( "log" "net" "sync" + "time" ) type Foo int @@ -38,6 +39,7 @@ func main() { client, _ := geerpc.Dial("tcp", <-addr) defer func() { _ = client.Close() }() + time.Sleep(time.Second) // send request & receive response var wg sync.WaitGroup for i := 0; i < 5; i++ { diff --git a/gee-rpc/day5-http-debug/main/main.go b/gee-rpc/day5-http-debug/main/main.go index a71af74..6499b53 100644 --- a/gee-rpc/day5-http-debug/main/main.go +++ b/gee-rpc/day5-http-debug/main/main.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "sync" + "time" ) type Foo int @@ -31,6 +32,7 @@ func call(addrCh chan string) { client, _ := geerpc.DialHTTP("tcp", <-addrCh) defer func() { _ = client.Close() }() + time.Sleep(time.Second) // send request & receive response var wg sync.WaitGroup for i := 0; i < 5; i++ { diff --git a/gee-rpc/day6-discovery/client.go b/gee-rpc/day6-load-balance/client.go similarity index 100% rename from gee-rpc/day6-discovery/client.go rename to gee-rpc/day6-load-balance/client.go diff --git a/gee-rpc/day6-discovery/client_test.go b/gee-rpc/day6-load-balance/client_test.go similarity index 100% rename from gee-rpc/day6-discovery/client_test.go rename to gee-rpc/day6-load-balance/client_test.go diff --git a/gee-rpc/day6-discovery/codec/codec.go b/gee-rpc/day6-load-balance/codec/codec.go similarity index 100% rename from gee-rpc/day6-discovery/codec/codec.go rename to gee-rpc/day6-load-balance/codec/codec.go diff --git a/gee-rpc/day6-discovery/codec/gob.go b/gee-rpc/day6-load-balance/codec/gob.go similarity index 100% rename from gee-rpc/day6-discovery/codec/gob.go rename to gee-rpc/day6-load-balance/codec/gob.go diff --git a/gee-rpc/day6-discovery/debug.go b/gee-rpc/day6-load-balance/debug.go similarity index 100% rename from gee-rpc/day6-discovery/debug.go rename to gee-rpc/day6-load-balance/debug.go diff --git a/gee-rpc/day6-discovery/go.mod b/gee-rpc/day6-load-balance/go.mod similarity index 100% rename from gee-rpc/day6-discovery/go.mod rename to gee-rpc/day6-load-balance/go.mod diff --git a/gee-rpc/day6-discovery/main/main.go b/gee-rpc/day6-load-balance/main/main.go similarity index 98% rename from gee-rpc/day6-discovery/main/main.go rename to gee-rpc/day6-load-balance/main/main.go index 308da5c..d00f864 100644 --- a/gee-rpc/day6-discovery/main/main.go +++ b/gee-rpc/day6-load-balance/main/main.go @@ -92,6 +92,8 @@ func main() { addr1 := <-ch1 addr2 := <-ch2 + + time.Sleep(time.Second) call(addr1, addr2) broadcast(addr1, addr2) } diff --git a/gee-rpc/day6-discovery/server.go b/gee-rpc/day6-load-balance/server.go similarity index 100% rename from gee-rpc/day6-discovery/server.go rename to gee-rpc/day6-load-balance/server.go diff --git a/gee-rpc/day6-discovery/service.go b/gee-rpc/day6-load-balance/service.go similarity index 100% rename from gee-rpc/day6-discovery/service.go rename to gee-rpc/day6-load-balance/service.go diff --git a/gee-rpc/day6-discovery/service_test.go b/gee-rpc/day6-load-balance/service_test.go similarity index 100% rename from gee-rpc/day6-discovery/service_test.go rename to gee-rpc/day6-load-balance/service_test.go diff --git a/gee-rpc/day6-discovery/xclient/discovery.go b/gee-rpc/day6-load-balance/xclient/discovery.go similarity index 63% rename from gee-rpc/day6-discovery/xclient/discovery.go rename to gee-rpc/day6-load-balance/xclient/discovery.go index b473b5d..f823a7b 100644 --- a/gee-rpc/day6-discovery/xclient/discovery.go +++ b/gee-rpc/day6-load-balance/xclient/discovery.go @@ -1,6 +1,7 @@ package xclient import ( + "errors" "math/rand" "sync" "time" @@ -14,8 +15,10 @@ const ( ) type Discovery interface { - Get(mode SelectMode) string - All() []string + Refresh() error // refresh from remote registry + Update(servers []string) error + Get(mode SelectMode) (string, error) + GetAll() ([]string, error) } var _ Discovery = (*MultiServersDiscovery)(nil) @@ -29,38 +32,46 @@ type MultiServersDiscovery struct { index int // record the selected position for robin algorithm } +// Refresh doesn't make sense for MultiServersDiscovery, so ignore it +func (d *MultiServersDiscovery) Refresh() error { + return nil +} + // Update the servers of discovery dynamically if needed -func (d *MultiServersDiscovery) Update(servers []string) { +func (d *MultiServersDiscovery) Update(servers []string) error { d.mu.Lock() defer d.mu.Unlock() d.servers = servers + return nil } -func (d *MultiServersDiscovery) Get(mode SelectMode) string { +// Get a server according to mode +func (d *MultiServersDiscovery) Get(mode SelectMode) (string, error) { d.mu.Lock() defer d.mu.Unlock() if len(d.servers) == 0 { - return "" + return "", errors.New("rpc discovery: no available servers") } switch mode { case RandomSelect: - return d.servers[d.r.Intn(len(d.servers))] + return d.servers[d.r.Intn(len(d.servers))], nil case RoundRobinSelect: s := d.servers[d.index] d.index = (d.index + 1) % len(d.servers) - return s + return s, nil default: - return "" + return "", errors.New("rpc discovery: not supported select mode") } } -func (d *MultiServersDiscovery) All() []string { +// returns all servers in discovery +func (d *MultiServersDiscovery) GetAll() ([]string, error) { d.mu.RLock() defer d.mu.RUnlock() // return a copy of d.servers servers := make([]string, len(d.servers), len(d.servers)) copy(servers, d.servers) - return servers + return servers, nil } // NewMultiServerDiscovery creates a MultiServersDiscovery instance diff --git a/gee-rpc/day6-discovery/xclient/xclient.go b/gee-rpc/day6-load-balance/xclient/xclient.go similarity index 93% rename from gee-rpc/day6-discovery/xclient/xclient.go rename to gee-rpc/day6-load-balance/xclient/xclient.go index f3df99c..838b343 100644 --- a/gee-rpc/day6-discovery/xclient/xclient.go +++ b/gee-rpc/day6-load-balance/xclient/xclient.go @@ -65,15 +65,21 @@ func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod strin // and returns its error status. // xc will choose a proper server. func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { - rpcAddr := xc.d.Get(xc.mode) + rpcAddr, err := xc.d.Get(xc.mode) + if err != nil { + return err + } return xc.call(rpcAddr, ctx, serviceMethod, args, reply) } // Broadcast invokes the named function for every server registered in discovery func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error { - servers := xc.d.All() + servers, err := xc.d.GetAll() + if err != nil { + return err + } var wg sync.WaitGroup - var mu sync.Mutex + var mu sync.Mutex // protect e and replyDone var e error replyDone := reply == nil // if reply is nil, don't need to set value ctx, cancel := context.WithCancel(ctx) diff --git a/gee-rpc/day7-registry/client.go b/gee-rpc/day7-registry/client.go new file mode 100644 index 0000000..e9b4540 --- /dev/null +++ b/gee-rpc/day7-registry/client.go @@ -0,0 +1,324 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "strings" + "sync" + "time" +) + +// Call represents an active RPC. +type Call struct { + Seq uint64 + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing { + return ErrShutdown + } + client.closing = true + return client.cc.Close() +} + +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing || client.shutdown { + return 0, ErrShutdown + } + seq := client.seq + client.pending[seq] = call + client.seq++ + return seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + client.shutdown = true + for _, call := range client.pending { + call.Error = err + call.done() + } +} + +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + call.Seq = seq + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + call := client.Go(serviceMethod, args, reply, make(chan *Call, 1)) + select { + case <-ctx.Done(): + client.removeCall(call.Seq) + return errors.New("rpc client: call failed: " + ctx.Err().Error()) + case call := <-call.Done: + return call.Error + } +} + +func parseOptions(opts ...*Option) (*Option, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return DefaultOption, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = DefaultOption.MagicNumber + if opt.CodecType == "" { + opt.CodecType = DefaultOption.CodecType + } + return opt, nil +} + +func NewClient(conn net.Conn, opt *Option) (*Client, error) { + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err := fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return nil, err + } + // send options with server + if err := json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + _ = conn.Close() + return nil, err + } + return newClientCodec(f(conn), opt), nil +} + +func newClientCodec(cc codec.Codec, opt *Option) *Client { + client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call + cc: cc, + opt: opt, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} + +type clientResult struct { + client *Client + err error +} + +type dialFunc func(network, address string, opt *Option) (client *Client, err error) + +func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + ch := make(chan clientResult) + go func() { + client, err := f(network, address, opt) + ch <- clientResult{client: client, err: err} + }() + if opt.ConnectTimeout == 0 { + result := <-ch + return result.client, result.err + } + select { + case <-time.After(opt.ConnectTimeout): + return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout) + case result := <-ch: + return result.client, result.err + } +} + +func dial(network, address string, opt *Option) (*Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return NewClient(conn, opt) +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(dial, network, address, opts...) +} + +func dialHTTP(network, address string, opt *Option) (*Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)) + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err == nil && resp.Status == connected { + return NewClient(conn, opt) + } + if err == nil { + err = errors.New("unexpected HTTP response: " + resp.Status) + } + _ = conn.Close() + return nil, err +} + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(dialHTTP, network, address, opts...) +} + +// XDial use a general format to represent a rpc server +// eg, http@10.0.0.1:7001, tcp@10.0.0.1:9999, unix@/tmp/geerpc.sock +func XDial(rpcAddr string, opts ...*Option) (*Client, error) { + parts := strings.Split(rpcAddr, "@") + if len(parts) != 2 { + return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr) + } + protocol, addr := parts[0], parts[1] + switch protocol { + case "http": + return DialHTTP("tcp", addr, opts...) + default: + // tcp, unix or other transport protocol + return Dial(protocol, addr, opts...) + } +} diff --git a/gee-rpc/day7-registry/client_test.go b/gee-rpc/day7-registry/client_test.go new file mode 100644 index 0000000..10b9817 --- /dev/null +++ b/gee-rpc/day7-registry/client_test.go @@ -0,0 +1,84 @@ +package geerpc + +import ( + "context" + "net" + "os" + "runtime" + "strings" + "testing" + "time" +) + +type Bar int + +func (b Bar) Timeout(argv int, reply *int) error { + time.Sleep(time.Second * 2) + return nil +} + +func startServer(addr chan string) { + var b Bar + _ = Register(&b) + // pick a free port + l, _ := net.Listen("tcp", ":0") + addr <- l.Addr().String() + Accept(l) +} + +func TestClient_dialTimeout(t *testing.T) { + t.Parallel() + f := func(network, address string, opt *Option) (client *Client, err error) { + time.Sleep(time.Second * 2) + return nil, nil + } + t.Run("timeout", func(t *testing.T) { + _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error") + }) + t.Run("0", func(t *testing.T) { + _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0}) + _assert(err == nil, "0 means no limit") + }) +} + +func TestClient_Call(t *testing.T) { + t.Parallel() + addrCh := make(chan string) + go startServer(addrCh) + addr := <-addrCh + t.Run("client timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + var reply int + err := client.Call(ctx, "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error") + }) + t.Run("server handle timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr, &Option{ + HandleTimeout: time.Second, + }) + var reply int + err := client.Call(context.Background(), "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error") + }) +} + +func TestXDial(t *testing.T) { + if runtime.GOOS == "linux" { + ch := make(chan struct{}) + addr := "/tmp/geerpc.sock" + go func() { + _ = os.Remove(addr) + l, err := net.Listen("unix", addr) + if err != nil { + t.Fatal("failed to listen unix socket") + } + ch <- struct{}{} + Accept(l) + }() + <-ch + _, err := XDial("unix@" + addr) + _assert(err == nil, "failed to connect unix socket") + } +} diff --git a/gee-rpc/day7-registry/codec/codec.go b/gee-rpc/day7-registry/codec/codec.go new file mode 100644 index 0000000..20b6ba7 --- /dev/null +++ b/gee-rpc/day7-registry/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" // not implemented +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day7-registry/codec/gob.go b/gee-rpc/day7-registry/codec/gob.go new file mode 100644 index 0000000..808d97b --- /dev/null +++ b/gee-rpc/day7-registry/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err := c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return err + } + if err := c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return err + } + return nil +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day7-registry/debug.go b/gee-rpc/day7-registry/debug.go new file mode 100644 index 0000000..ece1ffd --- /dev/null +++ b/gee-rpc/day7-registry/debug.go @@ -0,0 +1,60 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "fmt" + "html/template" + "net/http" +) + +const debugText = ` + + GeeRPC Services + {{range .}} +
+ Service {{.Name}} +
+ + + {{range $name, $mtype := .Method}} + + + + + {{end}} +
MethodCalls
{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error{{$mtype.NumCalls}}
+ {{end}} + + ` + +var debug = template.Must(template.New("RPC debug").Parse(debugText)) + +type debugHTTP struct { + *Server +} + +type debugService struct { + Name string + Method map[string]*methodType +} + +// Runs at /debug/geerpc +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Build a sorted version of the data. + var services []debugService + server.serviceMap.Range(func(namei, svci interface{}) bool { + svc := svci.(*service) + services = append(services, debugService{ + Name: namei.(string), + Method: svc.method, + }) + return true + }) + err := debug.Execute(w, services) + if err != nil { + _, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error()) + } +} diff --git a/gee-rpc/day7-registry/go.mod b/gee-rpc/day7-registry/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day7-registry/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day7-registry/main/main.go b/gee-rpc/day7-registry/main/main.go new file mode 100644 index 0000000..0776312 --- /dev/null +++ b/gee-rpc/day7-registry/main/main.go @@ -0,0 +1,112 @@ +package main + +import ( + "context" + "geerpc" + registy "geerpc/registry" + "geerpc/xclient" + "log" + "net" + "net/http" + "sync" + "time" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func (f Foo) Sleep(args Args, reply *int) error { + time.Sleep(time.Second * time.Duration(args.Num1)) + *reply = args.Num1 + args.Num2 + return nil +} + +func startRegistry(wg *sync.WaitGroup) { + l, _ := net.Listen("tcp", ":9999") + registy.HandleHTTP() + wg.Done() + _ = http.Serve(l, nil) +} + +func startServer(registry string, wg *sync.WaitGroup) { + var foo Foo + l, _ := net.Listen("tcp", ":0") + server := geerpc.NewServer() + _ = server.Register(&foo) + registy.Heartbeat(registry, "tcp@"+l.Addr().String(), 0) + wg.Done() + server.Accept(l) +} + +func foo(xc *xclient.XClient, ctx context.Context, typ, serviceMethod string, args *Args) { + var reply int + var err error + switch typ { + case "call": + err = xc.Call(ctx, serviceMethod, args, &reply) + case "broadcast": + err = xc.Broadcast(ctx, serviceMethod, args, &reply) + } + if err != nil { + log.Printf("%s %s error: %v", typ, serviceMethod, err) + } else { + log.Printf("%s Foo.Sum success: %d + %d = %d", typ, args.Num1, args.Num2, reply) + } +} + +func call(registry string) { + d := xclient.NewGeeRegistryDiscovery(registry, 0) + xc := xclient.NewXClient(d, xclient.RandomSelect, nil) + defer func() { _ = xc.Close() }() + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + foo(xc, context.Background(), "call", "Foo.Sum", &Args{Num1: i, Num2: i * i}) + }(i) + } + wg.Wait() +} + +func broadcast(registry string) { + d := xclient.NewGeeRegistryDiscovery(registry, 0) + xc := xclient.NewXClient(d, xclient.RandomSelect, nil) + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + foo(xc, context.Background(), "broadcast", "Foo.Sum", &Args{Num1: i, Num2: i * i}) + // expect 2 - 5 timeout + ctx, _ := context.WithTimeout(context.Background(), time.Second*2) + foo(xc, ctx, "broadcast", "Foo.Sleep", &Args{Num1: i, Num2: i * i}) + }(i) + } + wg.Wait() +} + +func main() { + registryAddr := "http://localhost:9999/_geerpc_/registry" + var wg sync.WaitGroup + wg.Add(1) + go startRegistry(&wg) + wg.Wait() + + time.Sleep(time.Second) + wg.Add(2) + go startServer(registryAddr, &wg) + go startServer(registryAddr, &wg) + wg.Wait() + + time.Sleep(time.Second) + call(registryAddr) + broadcast(registryAddr) +} diff --git a/gee-rpc/day7-registry/registry/registry.go b/gee-rpc/day7-registry/registry/registry.go new file mode 100644 index 0000000..0d50b4d --- /dev/null +++ b/gee-rpc/day7-registry/registry/registry.go @@ -0,0 +1,123 @@ +package registy + +import ( + "log" + "net/http" + "strings" + "sync" + "time" +) + +// Registry is a simple register center, provide following functions. +// add a server and receive heartbeat to keep it alive. +// returns all alive servers and delete dead servers sync simultaneously. +type Registry struct { + timeout time.Duration + mu sync.Mutex // protect following + servers map[string]*ServerItem +} + +type ServerItem struct { + Addr string + start time.Time +} + +// New create a registry instance with timeout setting +func New(timeout time.Duration) *Registry { + return &Registry{ + servers: make(map[string]*ServerItem), + timeout: timeout, + } +} + +var DefaultRegister = New(defaultTimeout) + +func (r *Registry) putServer(addr string) { + r.mu.Lock() + defer r.mu.Unlock() + s := r.servers[addr] + if s == nil { + r.servers[addr] = &ServerItem{Addr: addr, start: time.Now()} + } else { + s.start = time.Now() // if exists, update start time to keep alive + } +} + +func (r *Registry) aliveServers() []string { + r.mu.Lock() + defer r.mu.Unlock() + var alive []string + for addr, s := range r.servers { + if r.timeout == 0 || s.start.Add(r.timeout).After(time.Now()) { + alive = append(alive, addr) + } else { + delete(r.servers, addr) + } + } + return alive +} + +const ( + defaultPath = "/_geerpc_/registry" + defaultTimeout = time.Minute * 5 +) + +// Runs at /_geerpc_/registry +func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case "GET": + // keep it simple, server is in req.Header + w.Header().Set("X-Geerpc-Servers", strings.Join(r.aliveServers(), ",")) + case "POST": + // keep it simple, server is in req.Header + addr := req.Header.Get("X-Geerpc-Server") + if addr == "" { + w.WriteHeader(http.StatusInternalServerError) + return + } + r.putServer(addr) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +// HandleHTTP registers an HTTP handler for Registry messages on registryPath +func (r *Registry) HandleHTTP(registryPath string) { + http.Handle(registryPath, r) + log.Println("rpc registry path:", registryPath) +} + +func HandleHTTP() { + DefaultRegister.HandleHTTP(defaultPath) +} + +// Heartbeat send a heartbeat message every once in a while +// it's a helper function for a server to register or send heartbeat +func Heartbeat(registry, addr string, duration time.Duration) { + if duration == 0 { + // make sure there is enough time to send heart beat + // before it's removed from registry + duration = defaultTimeout - time.Duration(1)*time.Minute + } + var err error + err = sendHeartbeat(registry, addr) + go func() { + t := time.NewTicker(duration) + for err == nil { + <-t.C + err = sendHeartbeat(registry, addr) + } + }() +} + +func sendHeartbeat(registry, addr string) error { + log.Println(addr, "send heart beat to registry") + httpClient := &http.Client{} + req, _ := http.NewRequest("POST", registry, nil) + req.Header.Set("X-Geerpc-Server", addr) + if _, err := httpClient.Do(req); err != nil { + log.Println("rpc server: heart beat err:", err) + return err + } + return nil +} diff --git a/gee-rpc/day7-registry/server.go b/gee-rpc/day7-registry/server.go new file mode 100644 index 0000000..38fad20 --- /dev/null +++ b/gee-rpc/day7-registry/server.go @@ -0,0 +1,266 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "reflect" + "strings" + "sync" + "time" +) + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body + ConnectTimeout time.Duration // 0 means no limit + HandleTimeout time.Duration +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, + ConnectTimeout: time.Second * 10, +} + +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn), &opt) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec, opt *Option) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request + mtype *methodType + svc *service +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { + dot := strings.LastIndex(serviceMethod, ".") + if dot < 0 { + err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) + return + } + serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc server: can't find service " + serviceName) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc server: can't find method " + methodName) + } + return +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + req.svc, req.mtype, err = server.findService(h.ServiceMethod) + if err != nil { + return req, err + } + req.argv = req.mtype.newArgv() + req.replyv = req.mtype.newReplyv() + + // make sure that argvi is a pointer, ReadBody need a pointer as parameter + argvi := req.argv.Interface() + if req.argv.Type().Kind() != reflect.Ptr { + argvi = req.argv.Addr().Interface() + } + if err = cc.ReadBody(argvi); err != nil { + log.Println("rpc server: read body err:", err) + return req, err + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { + defer wg.Done() + called := make(chan struct{}) + sent := make(chan struct{}) + go func() { + err := req.svc.call(req.mtype, req.argv, req.replyv) + called <- struct{}{} + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + sent <- struct{}{} + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) + sent <- struct{}{} + }() + + if timeout == 0 { + <-called + <-sent + return + } + select { + case <-time.After(timeout): + req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout) + server.sendResponse(cc, req.h, invalidRequest, sending) + case <-called: + <-sent + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +func (server *Server) Register(rcvr interface{}) error { + s := newService(rcvr) + if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { + return errors.New("rpc: service already defined: " + s.name) + } + return nil +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } + +const ( + connected = "200 Connected to Gee RPC" + defaultRPCPath = "/_geeprc_" + defaultDebugPath = "/debug/geerpc" +) + +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = io.WriteString(w, "405 must CONNECT\n") + return + } + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) + return + } + _, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath, +// and a debugging handler on debugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP() { + http.Handle(defaultRPCPath, server) + http.Handle(defaultDebugPath, debugHTTP{server}) + log.Println("rpc server debug path:", defaultDebugPath) +} + +// HandleHTTP is a convenient approach for default server to register HTTP handlers +func HandleHTTP() { + DefaultServer.HandleHTTP() +} diff --git a/gee-rpc/day7-registry/service.go b/gee-rpc/day7-registry/service.go new file mode 100644 index 0000000..306683c --- /dev/null +++ b/gee-rpc/day7-registry/service.go @@ -0,0 +1,99 @@ +package geerpc + +import ( + "go/ast" + "log" + "reflect" + "sync/atomic" +) + +type methodType struct { + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint64 +} + +func (m *methodType) NumCalls() uint64 { + return atomic.LoadUint64(&m.numCalls) +} + +func (m *methodType) newArgv() reflect.Value { + var argv reflect.Value + // arg may be a pointer type, or a value type + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) + } else { + argv = reflect.New(m.ArgType).Elem() + } + return argv +} + +func (m *methodType) newReplyv() reflect.Value { + // reply must be a pointer type + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) + } + return replyv +} + +type service struct { + name string + typ reflect.Type + rcvr reflect.Value + method map[string]*methodType +} + +func newService(rcvr interface{}) *service { + s := new(service) + s.rcvr = reflect.ValueOf(rcvr) + s.name = reflect.Indirect(s.rcvr).Type().Name() + s.typ = reflect.TypeOf(rcvr) + if !ast.IsExported(s.name) { + log.Fatalf("rpc server: %s is not a valid service name", s.name) + } + s.registerMethods() + return s +} + +func (s *service) registerMethods() { + s.method = make(map[string]*methodType) + for i := 0; i < s.typ.NumMethod(); i++ { + method := s.typ.Method(i) + mType := method.Type + if mType.NumIn() != 3 || mType.NumOut() != 1 { + continue + } + if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + continue + } + argType, replyType := mType.In(1), mType.In(2) + if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { + continue + } + s.method[method.Name] = &methodType{ + method: method, + ArgType: argType, + ReplyType: replyType, + } + log.Printf("rpc server: register %s.%s\n", s.name, method.Name) + } +} + +func (s *service) call(m *methodType, argv, replyv reflect.Value) error { + atomic.AddUint64(&m.numCalls, 1) + f := m.method.Func + returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) + if errInter := returnValues[0].Interface(); errInter != nil { + return errInter.(error) + } + return nil +} + +func isExportedOrBuiltinType(t reflect.Type) bool { + return ast.IsExported(t.Name()) || t.PkgPath() == "" +} diff --git a/gee-rpc/day7-registry/service_test.go b/gee-rpc/day7-registry/service_test.go new file mode 100644 index 0000000..c8266df --- /dev/null +++ b/gee-rpc/day7-registry/service_test.go @@ -0,0 +1,48 @@ +package geerpc + +import ( + "fmt" + "reflect" + "testing" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +// it's not a exported Method +func (f Foo) sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func TestNewService(t *testing.T) { + var foo Foo + s := newService(&foo) + _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method)) + mType := s.method["Sum"] + _assert(mType != nil, "wrong Method, Sum shouldn't nil") +} + +func TestMethodType_Call(t *testing.T) { + var foo Foo + s := newService(&foo) + mType := s.method["Sum"] + + argv := mType.newArgv() + replyv := mType.newReplyv() + argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3})) + err := s.call(mType, argv, replyv) + _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum") +} diff --git a/gee-rpc/day7-registry/xclient/discovery.go b/gee-rpc/day7-registry/xclient/discovery.go new file mode 100644 index 0000000..f823a7b --- /dev/null +++ b/gee-rpc/day7-registry/xclient/discovery.go @@ -0,0 +1,83 @@ +package xclient + +import ( + "errors" + "math/rand" + "sync" + "time" +) + +type SelectMode int + +const ( + RandomSelect SelectMode = iota // select randomly + RoundRobinSelect // select using Robbin algorithm +) + +type Discovery interface { + Refresh() error // refresh from remote registry + Update(servers []string) error + Get(mode SelectMode) (string, error) + GetAll() ([]string, error) +} + +var _ Discovery = (*MultiServersDiscovery)(nil) + +// MultiServersDiscovery is a discovery for multi servers without a registry center +// user provides the server addresses explicitly instead +type MultiServersDiscovery struct { + r *rand.Rand // generate random number + mu sync.RWMutex // protect following + servers []string + index int // record the selected position for robin algorithm +} + +// Refresh doesn't make sense for MultiServersDiscovery, so ignore it +func (d *MultiServersDiscovery) Refresh() error { + return nil +} + +// Update the servers of discovery dynamically if needed +func (d *MultiServersDiscovery) Update(servers []string) error { + d.mu.Lock() + defer d.mu.Unlock() + d.servers = servers + return nil +} + +// Get a server according to mode +func (d *MultiServersDiscovery) Get(mode SelectMode) (string, error) { + d.mu.Lock() + defer d.mu.Unlock() + if len(d.servers) == 0 { + return "", errors.New("rpc discovery: no available servers") + } + switch mode { + case RandomSelect: + return d.servers[d.r.Intn(len(d.servers))], nil + case RoundRobinSelect: + s := d.servers[d.index] + d.index = (d.index + 1) % len(d.servers) + return s, nil + default: + return "", errors.New("rpc discovery: not supported select mode") + } +} + +// returns all servers in discovery +func (d *MultiServersDiscovery) GetAll() ([]string, error) { + d.mu.RLock() + defer d.mu.RUnlock() + // return a copy of d.servers + servers := make([]string, len(d.servers), len(d.servers)) + copy(servers, d.servers) + return servers, nil +} + +// NewMultiServerDiscovery creates a MultiServersDiscovery instance +func NewMultiServerDiscovery(servers []string) *MultiServersDiscovery { + return &MultiServersDiscovery{ + servers: servers, + r: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} diff --git a/gee-rpc/day7-registry/xclient/discovery_gee.go b/gee-rpc/day7-registry/xclient/discovery_gee.go new file mode 100644 index 0000000..865c30e --- /dev/null +++ b/gee-rpc/day7-registry/xclient/discovery_gee.go @@ -0,0 +1,74 @@ +package xclient + +import ( + "log" + "net/http" + "strings" + "time" +) + +type GeeRegistryDiscovery struct { + *MultiServersDiscovery + registry string + timeout time.Duration + lastUpdate time.Time +} + +const defaultUpdateTimeout = time.Second * 10 + +func (d *GeeRegistryDiscovery) Update(servers []string) error { + d.mu.Lock() + defer d.mu.Unlock() + d.servers = servers + d.lastUpdate = time.Now() + return nil +} + +func (d *GeeRegistryDiscovery) Refresh() error { + d.mu.Lock() + defer d.mu.Unlock() + if d.lastUpdate.Add(d.timeout).After(time.Now()) { + return nil + } + log.Println("rpc registry: refresh servers from registry", d.registry) + resp, err := http.Get(d.registry) + if err != nil { + log.Println("rpc registry refresh err:", err) + return err + } + servers := strings.Split(resp.Header.Get("X-Geerpc-Servers"), ",") + d.servers = make([]string, 0, len(servers)) + for _, server := range servers { + if strings.TrimSpace(server) != "" { + d.servers = append(d.servers, strings.TrimSpace(server)) + } + } + d.lastUpdate = time.Now() + return nil +} + +func (d *GeeRegistryDiscovery) Get(mode SelectMode) (string, error) { + if err := d.Refresh(); err != nil { + return "", err + } + return d.MultiServersDiscovery.Get(mode) +} + +func (d *GeeRegistryDiscovery) GetAll() ([]string, error) { + if err := d.Refresh(); err != nil { + return nil, err + } + return d.MultiServersDiscovery.GetAll() +} + +func NewGeeRegistryDiscovery(registerAddr string, timeout time.Duration) *GeeRegistryDiscovery { + if timeout == 0 { + timeout = defaultUpdateTimeout + } + d := &GeeRegistryDiscovery{ + MultiServersDiscovery: NewMultiServerDiscovery(make([]string, 0)), + registry: registerAddr, + timeout: timeout, + } + return d +} diff --git a/gee-rpc/day7-registry/xclient/xclient.go b/gee-rpc/day7-registry/xclient/xclient.go new file mode 100644 index 0000000..838b343 --- /dev/null +++ b/gee-rpc/day7-registry/xclient/xclient.go @@ -0,0 +1,109 @@ +package xclient + +import ( + "context" + . "geerpc" + "io" + "reflect" + "sync" +) + +type XClient struct { + d Discovery + mode SelectMode + opt *Option + mu sync.Mutex // protect following + clients map[string]*Client +} + +var _ io.Closer = (*XClient)(nil) + +func NewXClient(d Discovery, mode SelectMode, opt *Option) *XClient { + return &XClient{d: d, mode: mode, opt: opt, clients: make(map[string]*Client)} +} + +func (xc *XClient) Close() error { + xc.mu.Lock() + defer xc.mu.Unlock() + for key, client := range xc.clients { + // I have no idea how to deal with error, just ignore it. + _ = client.Close() + delete(xc.clients, key) + } + return nil +} + +func (xc *XClient) dial(rpcAddr string) (*Client, error) { + xc.mu.Lock() + defer xc.mu.Unlock() + client, ok := xc.clients[rpcAddr] + if ok && !client.IsAvailable() { + _ = client.Close() + delete(xc.clients, rpcAddr) + client = nil + } + if client == nil { + var err error + client, err = XDial(rpcAddr, xc.opt) + if err != nil { + return nil, err + } + xc.clients[rpcAddr] = client + } + return client, nil +} + +func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod string, args, reply interface{}) error { + client, err := xc.dial(rpcAddr) + if err != nil { + return err + } + return client.Call(ctx, serviceMethod, args, reply) +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +// xc will choose a proper server. +func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + rpcAddr, err := xc.d.Get(xc.mode) + if err != nil { + return err + } + return xc.call(rpcAddr, ctx, serviceMethod, args, reply) +} + +// Broadcast invokes the named function for every server registered in discovery +func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error { + servers, err := xc.d.GetAll() + if err != nil { + return err + } + var wg sync.WaitGroup + var mu sync.Mutex // protect e and replyDone + var e error + replyDone := reply == nil // if reply is nil, don't need to set value + ctx, cancel := context.WithCancel(ctx) + for _, rpcAddr := range servers { + wg.Add(1) + go func() { + defer wg.Done() + var clonedReply interface{} + if reply != nil { + clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface() + } + err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply) + mu.Lock() + if err != nil && e == nil { + e = err + cancel() // if any call failed, cancel unfinished calls + } + if err == nil && !replyDone { + reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem()) + replyDone = true + } + mu.Unlock() + }() + } + wg.Wait() + return e +}