From 1b78c37452bca4f6eddcda45db4ca816188b9d5f Mon Sep 17 00:00:00 2001 From: gzdaijie Date: Sat, 3 Oct 2020 10:54:28 +0800 Subject: [PATCH] gee-rpc: implement day5-timeout --- gee-rpc/day1-codec/main/main.go | 6 +- gee-rpc/day1-codec/server.go | 6 +- gee-rpc/day2-client/client.go | 39 ++- gee-rpc/day2-client/server.go | 6 +- gee-rpc/day3-service/client.go | 39 ++- gee-rpc/day3-service/server.go | 8 +- .../client.go | 92 +++--- gee-rpc/day4-timeout/client_test.go | 63 ++++ .../codec/codec.go | 0 .../codec/gob.go | 0 .../{day4-http-debug => day4-timeout}/go.mod | 0 gee-rpc/day4-timeout/main/main.go | 56 ++++ gee-rpc/day4-timeout/server.go | 228 +++++++++++++ .../service.go | 0 .../service_test.go | 0 gee-rpc/day5-http-debug/client.go | 310 ++++++++++++++++++ gee-rpc/day5-http-debug/client_test.go | 63 ++++ gee-rpc/day5-http-debug/codec/codec.go | 34 ++ gee-rpc/day5-http-debug/codec/gob.go | 57 ++++ .../debug.go | 0 gee-rpc/day5-http-debug/go.mod | 3 + .../main/main.go | 3 +- .../server.go | 63 +++- gee-rpc/day5-http-debug/service.go | 99 ++++++ gee-rpc/day5-http-debug/service_test.go | 48 +++ 25 files changed, 1117 insertions(+), 106 deletions(-) rename gee-rpc/{day4-http-debug => day4-timeout}/client.go (74%) create mode 100644 gee-rpc/day4-timeout/client_test.go rename gee-rpc/{day4-http-debug => day4-timeout}/codec/codec.go (100%) rename gee-rpc/{day4-http-debug => day4-timeout}/codec/gob.go (100%) rename gee-rpc/{day4-http-debug => day4-timeout}/go.mod (100%) create mode 100644 gee-rpc/day4-timeout/main/main.go create mode 100644 gee-rpc/day4-timeout/server.go rename gee-rpc/{day4-http-debug => day4-timeout}/service.go (100%) rename gee-rpc/{day4-http-debug => day4-timeout}/service_test.go (100%) create mode 100644 gee-rpc/day5-http-debug/client.go create mode 100644 gee-rpc/day5-http-debug/client_test.go create mode 100644 gee-rpc/day5-http-debug/codec/codec.go create mode 100644 gee-rpc/day5-http-debug/codec/gob.go rename gee-rpc/{day4-http-debug => day5-http-debug}/debug.go (100%) create mode 100644 gee-rpc/day5-http-debug/go.mod rename gee-rpc/{day4-http-debug => day5-http-debug}/main/main.go (90%) rename gee-rpc/{day4-http-debug => day5-http-debug}/server.go (80%) create mode 100644 gee-rpc/day5-http-debug/service.go create mode 100644 gee-rpc/day5-http-debug/service_test.go diff --git a/gee-rpc/day1-codec/main/main.go b/gee-rpc/day1-codec/main/main.go index a715209..8f29531 100644 --- a/gee-rpc/day1-codec/main/main.go +++ b/gee-rpc/day1-codec/main/main.go @@ -29,11 +29,7 @@ func main() { defer func() { _ = conn.Close() }() // send options - _ = json.NewEncoder(conn).Encode(&geerpc.Options{ - MagicNumber: geerpc.MagicNumber, - CodecType: codec.GobType, - }) - + _ = json.NewEncoder(conn).Encode(geerpc.DefaultOption) cc := codec.NewGobCodec(conn) // send request & receive response for i := 0; i < 5; i++ { diff --git a/gee-rpc/day1-codec/server.go b/gee-rpc/day1-codec/server.go index 06ef763..fb93e4f 100644 --- a/gee-rpc/day1-codec/server.go +++ b/gee-rpc/day1-codec/server.go @@ -17,12 +17,12 @@ import ( const MagicNumber = 0x3bef5c -type Options struct { +type Option struct { MagicNumber int // MagicNumber marks this's a geerpc request CodecType codec.Type // client may choose different Codec to encode body } -var defaultOptions = &Options{ +var DefaultOption = &Option{ MagicNumber: MagicNumber, CodecType: codec.GobType, } @@ -42,7 +42,7 @@ var DefaultServer = NewServer() // ServeConn blocks, serving the connection until the client hangs up. func (server *Server) ServeConn(conn io.ReadWriteCloser) { defer func() { _ = conn.Close() }() - var opt Options + var opt Option if err := json.NewDecoder(conn).Decode(&opt); err != nil { log.Println("rpc server: options error: ", err) return diff --git a/gee-rpc/day2-client/client.go b/gee-rpc/day2-client/client.go index b73c461..f9b09ee 100644 --- a/gee-rpc/day2-client/client.go +++ b/gee-rpc/day2-client/client.go @@ -17,6 +17,7 @@ import ( // Call represents an active RPC. type Call struct { + Seq uint64 ServiceMethod string // format "." Args interface{} // arguments to the function Reply interface{} // reply from the function @@ -34,7 +35,7 @@ func (call *Call) done() { // multiple goroutines simultaneously. type Client struct { cc codec.Codec - opt *Options + opt *Option sending sync.Mutex // protect following header codec.Header mu sync.Mutex // protect following @@ -96,6 +97,7 @@ func (client *Client) send(call *Call) { // register this call. seq, err := client.registerCall(call) + call.Seq = seq if err != nil { call.Error = err call.done() @@ -173,35 +175,31 @@ func (client *Client) Call(serviceMethod string, args, reply interface{}) error return call.Error } -func parseOptions(opts ...*Options) (*Options, error) { +func parseOptions(opts ...*Option) (*Option, error) { // if opts is nil or pass nil as parameter if len(opts) == 0 || opts[0] == nil { - return defaultOptions, nil + return DefaultOption, nil } if len(opts) != 1 { return nil, errors.New("number of options is more than 1") } opt := opts[0] - opt.MagicNumber = defaultOptions.MagicNumber + opt.MagicNumber = DefaultOption.MagicNumber if opt.CodecType == "" { - opt.CodecType = defaultOptions.CodecType + opt.CodecType = DefaultOption.CodecType } return opt, nil } -func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) { - opt, err := parseOptions(opts...) - if err != nil { - return nil, err - } +func NewClient(conn net.Conn, opt *Option) (*Client, error) { f := codec.NewCodecFuncMap[opt.CodecType] if f == nil { - err = fmt.Errorf("invalid codec type %s", opt.CodecType) + err := fmt.Errorf("invalid codec type %s", opt.CodecType) log.Println("rpc client: codec error:", err) return nil, err } // send options with server - if err = json.NewEncoder(conn).Encode(opt); err != nil { + if err := json.NewEncoder(conn).Encode(opt); err != nil { log.Println("rpc client: options error: ", err) _ = conn.Close() return nil, err @@ -209,8 +207,9 @@ func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) { return newClientCodec(f(conn), opt), nil } -func newClientCodec(cc codec.Codec, opt *Options) *Client { +func newClientCodec(cc codec.Codec, opt *Option) *Client { client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call cc: cc, opt: opt, pending: make(map[uint64]*Call), @@ -219,11 +218,19 @@ func newClientCodec(cc codec.Codec, opt *Options) *Client { return client } -// Dial connects to an RPC server at the specified network address -func Dial(network, address string, opts ...*Options) (*Client, error) { +func dial(network, address string, opt *Option) (*Client, error) { conn, err := net.Dial(network, address) if err != nil { return nil, err } - return NewClient(conn, opts...) + return NewClient(conn, opt) +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + return dial(network, address, opt) } diff --git a/gee-rpc/day2-client/server.go b/gee-rpc/day2-client/server.go index 06ef763..fb93e4f 100644 --- a/gee-rpc/day2-client/server.go +++ b/gee-rpc/day2-client/server.go @@ -17,12 +17,12 @@ import ( const MagicNumber = 0x3bef5c -type Options struct { +type Option struct { MagicNumber int // MagicNumber marks this's a geerpc request CodecType codec.Type // client may choose different Codec to encode body } -var defaultOptions = &Options{ +var DefaultOption = &Option{ MagicNumber: MagicNumber, CodecType: codec.GobType, } @@ -42,7 +42,7 @@ var DefaultServer = NewServer() // ServeConn blocks, serving the connection until the client hangs up. func (server *Server) ServeConn(conn io.ReadWriteCloser) { defer func() { _ = conn.Close() }() - var opt Options + var opt Option if err := json.NewDecoder(conn).Decode(&opt); err != nil { log.Println("rpc server: options error: ", err) return diff --git a/gee-rpc/day3-service/client.go b/gee-rpc/day3-service/client.go index b73c461..f9b09ee 100644 --- a/gee-rpc/day3-service/client.go +++ b/gee-rpc/day3-service/client.go @@ -17,6 +17,7 @@ import ( // Call represents an active RPC. type Call struct { + Seq uint64 ServiceMethod string // format "." Args interface{} // arguments to the function Reply interface{} // reply from the function @@ -34,7 +35,7 @@ func (call *Call) done() { // multiple goroutines simultaneously. type Client struct { cc codec.Codec - opt *Options + opt *Option sending sync.Mutex // protect following header codec.Header mu sync.Mutex // protect following @@ -96,6 +97,7 @@ func (client *Client) send(call *Call) { // register this call. seq, err := client.registerCall(call) + call.Seq = seq if err != nil { call.Error = err call.done() @@ -173,35 +175,31 @@ func (client *Client) Call(serviceMethod string, args, reply interface{}) error return call.Error } -func parseOptions(opts ...*Options) (*Options, error) { +func parseOptions(opts ...*Option) (*Option, error) { // if opts is nil or pass nil as parameter if len(opts) == 0 || opts[0] == nil { - return defaultOptions, nil + return DefaultOption, nil } if len(opts) != 1 { return nil, errors.New("number of options is more than 1") } opt := opts[0] - opt.MagicNumber = defaultOptions.MagicNumber + opt.MagicNumber = DefaultOption.MagicNumber if opt.CodecType == "" { - opt.CodecType = defaultOptions.CodecType + opt.CodecType = DefaultOption.CodecType } return opt, nil } -func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) { - opt, err := parseOptions(opts...) - if err != nil { - return nil, err - } +func NewClient(conn net.Conn, opt *Option) (*Client, error) { f := codec.NewCodecFuncMap[opt.CodecType] if f == nil { - err = fmt.Errorf("invalid codec type %s", opt.CodecType) + err := fmt.Errorf("invalid codec type %s", opt.CodecType) log.Println("rpc client: codec error:", err) return nil, err } // send options with server - if err = json.NewEncoder(conn).Encode(opt); err != nil { + if err := json.NewEncoder(conn).Encode(opt); err != nil { log.Println("rpc client: options error: ", err) _ = conn.Close() return nil, err @@ -209,8 +207,9 @@ func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) { return newClientCodec(f(conn), opt), nil } -func newClientCodec(cc codec.Codec, opt *Options) *Client { +func newClientCodec(cc codec.Codec, opt *Option) *Client { client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call cc: cc, opt: opt, pending: make(map[uint64]*Call), @@ -219,11 +218,19 @@ func newClientCodec(cc codec.Codec, opt *Options) *Client { return client } -// Dial connects to an RPC server at the specified network address -func Dial(network, address string, opts ...*Options) (*Client, error) { +func dial(network, address string, opt *Option) (*Client, error) { conn, err := net.Dial(network, address) if err != nil { return nil, err } - return NewClient(conn, opts...) + return NewClient(conn, opt) +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + return dial(network, address, opt) } diff --git a/gee-rpc/day3-service/server.go b/gee-rpc/day3-service/server.go index 6558dad..4634394 100644 --- a/gee-rpc/day3-service/server.go +++ b/gee-rpc/day3-service/server.go @@ -18,12 +18,12 @@ import ( const MagicNumber = 0x3bef5c -type Options struct { +type Option struct { MagicNumber int // MagicNumber marks this's a geerpc request CodecType codec.Type // client may choose different Codec to encode body } -var defaultOptions = &Options{ +var DefaultOption = &Option{ MagicNumber: MagicNumber, CodecType: codec.GobType, } @@ -45,7 +45,7 @@ var DefaultServer = NewServer() // ServeConn blocks, serving the connection until the client hangs up. func (server *Server) ServeConn(conn io.ReadWriteCloser) { defer func() { _ = conn.Close() }() - var opt Options + var opt Option if err := json.NewDecoder(conn).Decode(&opt); err != nil { log.Println("rpc server: options error: ", err) return @@ -162,6 +162,8 @@ func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync. err := req.svc.call(req.mtype, req.argv, req.replyv) if err != nil { req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + return } server.sendResponse(cc, req.h, req.replyv.Interface(), sending) } diff --git a/gee-rpc/day4-http-debug/client.go b/gee-rpc/day4-timeout/client.go similarity index 74% rename from gee-rpc/day4-http-debug/client.go rename to gee-rpc/day4-timeout/client.go index 7ae8938..c9900fa 100644 --- a/gee-rpc/day4-http-debug/client.go +++ b/gee-rpc/day4-timeout/client.go @@ -5,7 +5,7 @@ package geerpc import ( - "bufio" + "context" "encoding/json" "errors" "fmt" @@ -13,12 +13,13 @@ import ( "io" "log" "net" - "net/http" "sync" + "time" ) // Call represents an active RPC. type Call struct { + Seq uint64 ServiceMethod string // format "." Args interface{} // arguments to the function Reply interface{} // reply from the function @@ -36,7 +37,7 @@ func (call *Call) done() { // multiple goroutines simultaneously. type Client struct { cc codec.Codec - opt *Options + opt *Option sending sync.Mutex // protect following header codec.Header mu sync.Mutex // protect following @@ -98,6 +99,7 @@ func (client *Client) send(call *Call) { // register this call. seq, err := client.registerCall(call) + call.Seq = seq if err != nil { call.Error = err call.done() @@ -170,40 +172,42 @@ func (client *Client) Go(serviceMethod string, args, reply interface{}, done cha // Call invokes the named function, waits for it to complete, // and returns its error status. -func (client *Client) Call(serviceMethod string, args, reply interface{}) error { - call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done - return call.Error +func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + call := client.Go(serviceMethod, args, reply, make(chan *Call, 1)) + select { + case <-ctx.Done(): + client.removeCall(call.Seq) + return errors.New("rpc client: call failed: " + ctx.Err().Error()) + case call := <-call.Done: + return call.Error + } } -func parseOptions(opts ...*Options) (*Options, error) { +func parseOptions(opts ...*Option) (*Option, error) { // if opts is nil or pass nil as parameter if len(opts) == 0 || opts[0] == nil { - return defaultOptions, nil + return DefaultOption, nil } if len(opts) != 1 { return nil, errors.New("number of options is more than 1") } opt := opts[0] - opt.MagicNumber = defaultOptions.MagicNumber + opt.MagicNumber = DefaultOption.MagicNumber if opt.CodecType == "" { - opt.CodecType = defaultOptions.CodecType + opt.CodecType = DefaultOption.CodecType } return opt, nil } -func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) { - opt, err := parseOptions(opts...) - if err != nil { - return nil, err - } +func NewClient(conn net.Conn, opt *Option) (*Client, error) { f := codec.NewCodecFuncMap[opt.CodecType] if f == nil { - err = fmt.Errorf("invalid codec type %s", opt.CodecType) + err := fmt.Errorf("invalid codec type %s", opt.CodecType) log.Println("rpc client: codec error:", err) return nil, err } // send options with server - if err = json.NewEncoder(conn).Encode(opt); err != nil { + if err := json.NewEncoder(conn).Encode(opt); err != nil { log.Println("rpc client: options error: ", err) _ = conn.Close() return nil, err @@ -211,8 +215,9 @@ func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) { return newClientCodec(f(conn), opt), nil } -func newClientCodec(cc codec.Codec, opt *Options) *Client { +func newClientCodec(cc codec.Codec, opt *Option) *Client { client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call cc: cc, opt: opt, pending: make(map[uint64]*Call), @@ -221,39 +226,44 @@ func newClientCodec(cc codec.Codec, opt *Options) *Client { return client } -// Dial connects to an RPC server at the specified network address -func Dial(network, address string, opts ...*Options) (*Client, error) { +func dial(network, address string, opt *Option) (*Client, error) { conn, err := net.Dial(network, address) if err != nil { return nil, err } - return NewClient(conn, opts...) + return NewClient(conn, opt) } -// DialHTTP connects to an HTTP RPC server at the specified network address -// listening on the default HTTP RPC path. -func DialHTTP(network, address string, opts ...*Options) (*Client, error) { - return DialHTTPPath(network, address, defaultRPCPath, opts...) +type clientResult struct { + client *Client + err error } -// DialHTTPPath connects to an HTTP RPC server -// at the specified network address and path. -func DialHTTPPath(network, address, path string, opts ...*Options) (*Client, error) { - conn, err := net.Dial(network, address) - if err != nil { - return nil, err +func dialTimeout(f func() (client *Client, err error), timeout time.Duration) (*Client, error) { + if timeout == 0 { + return f() + } + ch := make(chan clientResult) + go func() { + client, err := f() + ch <- clientResult{client: client, err: err} + }() + select { + case <-time.After(timeout): + return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", timeout) + case result := <-ch: + return result.client, result.err } - _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", path)) +} - // Require successful HTTP response - // before switching to RPC protocol. - resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) - if err == nil && resp.Status == connected { - return NewClient(conn, opts...) +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err } - if err == nil { - err = errors.New("unexpected HTTP response: " + resp.Status) + f := func() (client *Client, err error) { + return dial(network, address, opt) } - _ = conn.Close() - return nil, err + return dialTimeout(f, opt.ConnectTimeout) } diff --git a/gee-rpc/day4-timeout/client_test.go b/gee-rpc/day4-timeout/client_test.go new file mode 100644 index 0000000..5669f9c --- /dev/null +++ b/gee-rpc/day4-timeout/client_test.go @@ -0,0 +1,63 @@ +package geerpc + +import ( + "context" + "net" + "strings" + "testing" + "time" +) + +type Bar int + +func (b Bar) Timeout(argv int, reply *int) error { + time.Sleep(time.Second * 2) + return nil +} + +func startServer(addr chan string) { + var b Bar + _ = Register(&b) + // pick a free port + l, _ := net.Listen("tcp", ":0") + addr <- l.Addr().String() + Accept(l) +} + +func TestClient_dialTimeout(t *testing.T) { + t.Parallel() + f := func() (client *Client, err error) { + time.Sleep(time.Second * 2) + return nil, nil + } + t.Run("timeout", func(t *testing.T) { + _, err := dialTimeout(f, time.Second) + _assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error") + }) + t.Run("0", func(t *testing.T) { + _, err := dialTimeout(f, 0) + _assert(err == nil, "0 means no limit") + }) +} + +func TestClient_Call(t *testing.T) { + t.Parallel() + addrCh := make(chan string) + go startServer(addrCh) + addr := <-addrCh + t.Run("client timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + var reply int + err := client.Call(ctx, "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error") + }) + t.Run("server handle timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr, &Option{ + HandleTimeout: time.Second, + }) + var reply int + err := client.Call(context.Background(), "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error") + }) +} diff --git a/gee-rpc/day4-http-debug/codec/codec.go b/gee-rpc/day4-timeout/codec/codec.go similarity index 100% rename from gee-rpc/day4-http-debug/codec/codec.go rename to gee-rpc/day4-timeout/codec/codec.go diff --git a/gee-rpc/day4-http-debug/codec/gob.go b/gee-rpc/day4-timeout/codec/gob.go similarity index 100% rename from gee-rpc/day4-http-debug/codec/gob.go rename to gee-rpc/day4-timeout/codec/gob.go diff --git a/gee-rpc/day4-http-debug/go.mod b/gee-rpc/day4-timeout/go.mod similarity index 100% rename from gee-rpc/day4-http-debug/go.mod rename to gee-rpc/day4-timeout/go.mod diff --git a/gee-rpc/day4-timeout/main/main.go b/gee-rpc/day4-timeout/main/main.go new file mode 100644 index 0000000..e5e6050 --- /dev/null +++ b/gee-rpc/day4-timeout/main/main.go @@ -0,0 +1,56 @@ +package main + +import ( + "context" + "geerpc" + "log" + "net" + "sync" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func startServer(addr chan string) { + var foo Foo + if err := geerpc.Register(&foo); err != nil { + log.Fatal("register error:", err) + } + // pick a free port + l, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatal("network error:", err) + } + log.Println("start rpc server on", l.Addr()) + addr <- l.Addr().String() + geerpc.Accept(l) +} + +func main() { + addr := make(chan string) + go startServer(addr) + client, _ := geerpc.Dial("tcp", <-addr) + defer func() { _ = client.Close() }() + + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + args := &Args{Num1: i, Num2: i * i} + var reply int + if err := client.Call(context.Background(), "Foo.Sum", args, &reply); err != nil { + log.Fatal("call Foo.Sum error:", err) + } + log.Printf("%d + %d = %d", args.Num1, args.Num2, reply) + }(i) + } + wg.Wait() +} diff --git a/gee-rpc/day4-timeout/server.go b/gee-rpc/day4-timeout/server.go new file mode 100644 index 0000000..a049914 --- /dev/null +++ b/gee-rpc/day4-timeout/server.go @@ -0,0 +1,228 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "reflect" + "strings" + "sync" + "time" +) + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body + ConnectTimeout time.Duration // 0 means no limit + HandleTimeout time.Duration +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, + ConnectTimeout: time.Second * 10, +} + +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn), &opt) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec, opt *Option) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request + mtype *methodType + svc *service +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { + dot := strings.LastIndex(serviceMethod, ".") + if dot < 0 { + err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) + return + } + serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc server: can't find service " + serviceName) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc server: can't find method " + methodName) + } + return +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + req.svc, req.mtype, err = server.findService(h.ServiceMethod) + if err != nil { + return req, err + } + req.argv = req.mtype.newArgv() + req.replyv = req.mtype.newReplyv() + + // make sure that argvi is a pointer, ReadBody need a pointer as parameter + argvi := req.argv.Interface() + if req.argv.Type().Kind() != reflect.Ptr { + argvi = req.argv.Addr().Interface() + } + if err = cc.ReadBody(argvi); err != nil { + log.Println("rpc server: read body err:", err) + return req, err + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { + defer wg.Done() + called := make(chan struct{}) + sent := make(chan struct{}) + go func() { + err := req.svc.call(req.mtype, req.argv, req.replyv) + called <- struct{}{} + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + sent <- struct{}{} + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) + sent <- struct{}{} + }() + + if timeout == 0 { + <-called + <-sent + return + } + select { + case <-time.After(timeout): + req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout) + server.sendResponse(cc, req.h, invalidRequest, sending) + case <-called: + <-sent + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +func (server *Server) Register(rcvr interface{}) error { + s := newService(rcvr) + if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { + return errors.New("rpc: service already defined: " + s.name) + } + return nil +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } diff --git a/gee-rpc/day4-http-debug/service.go b/gee-rpc/day4-timeout/service.go similarity index 100% rename from gee-rpc/day4-http-debug/service.go rename to gee-rpc/day4-timeout/service.go diff --git a/gee-rpc/day4-http-debug/service_test.go b/gee-rpc/day4-timeout/service_test.go similarity index 100% rename from gee-rpc/day4-http-debug/service_test.go rename to gee-rpc/day4-timeout/service_test.go diff --git a/gee-rpc/day5-http-debug/client.go b/gee-rpc/day5-http-debug/client.go new file mode 100644 index 0000000..5f336d7 --- /dev/null +++ b/gee-rpc/day5-http-debug/client.go @@ -0,0 +1,310 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "sync" + "time" +) + +// Call represents an active RPC. +type Call struct { + Seq uint64 + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closed bool // user has called Close +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closed { + return ErrShutdown + } + client.closed = true + return client.cc.Close() +} + +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closed { + return 0, ErrShutdown + } + seq := client.seq + client.pending[seq] = call + client.seq++ + return seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + for _, call := range client.pending { + call.Error = err + call.done() + } +} + +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + call.Seq = seq + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + call := client.Go(serviceMethod, args, reply, make(chan *Call, 1)) + select { + case <-ctx.Done(): + client.removeCall(call.Seq) + return errors.New("rpc client: call failed: " + ctx.Err().Error()) + case call := <-call.Done: + return call.Error + } +} + +func parseOptions(opts ...*Option) (*Option, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return DefaultOption, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = DefaultOption.MagicNumber + if opt.CodecType == "" { + opt.CodecType = DefaultOption.CodecType + } + return opt, nil +} + +func NewClient(conn net.Conn, opt *Option) (*Client, error) { + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err := fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return nil, err + } + // send options with server + if err := json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + _ = conn.Close() + return nil, err + } + return newClientCodec(f(conn), opt), nil +} + +func newClientCodec(cc codec.Codec, opt *Option) *Client { + client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call + cc: cc, + opt: opt, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} + +func dial(network, address string, opt *Option) (*Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return NewClient(conn, opt) +} + +type clientResult struct { + client *Client + err error +} + +func dialTimeout(f func() (client *Client, err error), timeout time.Duration) (*Client, error) { + if timeout == 0 { + return f() + } + ch := make(chan clientResult) + go func() { + client, err := f() + ch <- clientResult{client: client, err: err} + }() + select { + case <-time.After(timeout): + return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", timeout) + case result := <-ch: + return result.client, result.err + } +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + f := func() (client *Client, err error) { + return dial(network, address, opt) + } + return dialTimeout(f, opt.ConnectTimeout) +} + +func dialHTTPPath(network, address, path string, opt *Option) (*Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", path)) + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err == nil && resp.Status == connected { + return NewClient(conn, opt) + } + if err == nil { + err = errors.New("unexpected HTTP response: " + resp.Status) + } + _ = conn.Close() + return nil, err +} + +// DialHTTPPath connects to an HTTP RPC server +// at the specified network address and path. +func DialHTTPPath(network, address, path string, opts ...*Option) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + f := func() (*Client, error) { + return dialHTTPPath(network, address, path, opt) + } + return dialTimeout(f, opt.ConnectTimeout) +} + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string, opts ...*Option) (*Client, error) { + return DialHTTPPath(network, address, defaultRPCPath, opts...) +} diff --git a/gee-rpc/day5-http-debug/client_test.go b/gee-rpc/day5-http-debug/client_test.go new file mode 100644 index 0000000..5669f9c --- /dev/null +++ b/gee-rpc/day5-http-debug/client_test.go @@ -0,0 +1,63 @@ +package geerpc + +import ( + "context" + "net" + "strings" + "testing" + "time" +) + +type Bar int + +func (b Bar) Timeout(argv int, reply *int) error { + time.Sleep(time.Second * 2) + return nil +} + +func startServer(addr chan string) { + var b Bar + _ = Register(&b) + // pick a free port + l, _ := net.Listen("tcp", ":0") + addr <- l.Addr().String() + Accept(l) +} + +func TestClient_dialTimeout(t *testing.T) { + t.Parallel() + f := func() (client *Client, err error) { + time.Sleep(time.Second * 2) + return nil, nil + } + t.Run("timeout", func(t *testing.T) { + _, err := dialTimeout(f, time.Second) + _assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error") + }) + t.Run("0", func(t *testing.T) { + _, err := dialTimeout(f, 0) + _assert(err == nil, "0 means no limit") + }) +} + +func TestClient_Call(t *testing.T) { + t.Parallel() + addrCh := make(chan string) + go startServer(addrCh) + addr := <-addrCh + t.Run("client timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + var reply int + err := client.Call(ctx, "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error") + }) + t.Run("server handle timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr, &Option{ + HandleTimeout: time.Second, + }) + var reply int + err := client.Call(context.Background(), "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error") + }) +} diff --git a/gee-rpc/day5-http-debug/codec/codec.go b/gee-rpc/day5-http-debug/codec/codec.go new file mode 100644 index 0000000..ba28fba --- /dev/null +++ b/gee-rpc/day5-http-debug/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day5-http-debug/codec/gob.go b/gee-rpc/day5-http-debug/codec/gob.go new file mode 100644 index 0000000..808d97b --- /dev/null +++ b/gee-rpc/day5-http-debug/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err := c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return err + } + if err := c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return err + } + return nil +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day4-http-debug/debug.go b/gee-rpc/day5-http-debug/debug.go similarity index 100% rename from gee-rpc/day4-http-debug/debug.go rename to gee-rpc/day5-http-debug/debug.go diff --git a/gee-rpc/day5-http-debug/go.mod b/gee-rpc/day5-http-debug/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day5-http-debug/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day4-http-debug/main/main.go b/gee-rpc/day5-http-debug/main/main.go similarity index 90% rename from gee-rpc/day4-http-debug/main/main.go rename to gee-rpc/day5-http-debug/main/main.go index 13bb592..f25d909 100644 --- a/gee-rpc/day4-http-debug/main/main.go +++ b/gee-rpc/day5-http-debug/main/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "geerpc" "log" "net/http" @@ -38,7 +39,7 @@ func call() { defer wg.Done() args := &Args{Num1: i, Num2: i * i} var reply int - if err := client.Call("Foo.Sum", args, &reply); err != nil { + if err := client.Call(context.Background(), "Foo.Sum", args, &reply); err != nil { log.Fatal("call Foo.Sum error:", err) } log.Printf("%d + %d = %d", args.Num1, args.Num2, reply) diff --git a/gee-rpc/day4-http-debug/server.go b/gee-rpc/day5-http-debug/server.go similarity index 80% rename from gee-rpc/day4-http-debug/server.go rename to gee-rpc/day5-http-debug/server.go index ffb85ee..5be1c66 100644 --- a/gee-rpc/day4-http-debug/server.go +++ b/gee-rpc/day5-http-debug/server.go @@ -7,6 +7,7 @@ package geerpc import ( "encoding/json" "errors" + "fmt" "geerpc/codec" "io" "log" @@ -15,18 +16,22 @@ import ( "reflect" "strings" "sync" + "time" ) const MagicNumber = 0x3bef5c -type Options struct { - MagicNumber int // MagicNumber marks this's a geerpc request - CodecType codec.Type // client may choose different Codec to encode body +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body + ConnectTimeout time.Duration // 0 means no limit + HandleTimeout time.Duration } -var defaultOptions = &Options{ - MagicNumber: MagicNumber, - CodecType: codec.GobType, +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, + ConnectTimeout: time.Second * 10, } // Server represents an RPC Server. @@ -46,7 +51,7 @@ var DefaultServer = NewServer() // ServeConn blocks, serving the connection until the client hangs up. func (server *Server) ServeConn(conn io.ReadWriteCloser) { defer func() { _ = conn.Close() }() - var opt Options + var opt Option if err := json.NewDecoder(conn).Decode(&opt); err != nil { log.Println("rpc server: options error: ", err) return @@ -60,13 +65,13 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) { log.Printf("rpc server: invalid codec type %s", opt.CodecType) return } - server.serveCodec(f(conn)) + server.serveCodec(f(conn), &opt) } // invalidRequest is a placeholder for response argv when error occurs var invalidRequest = struct{}{} -func (server *Server) serveCodec(cc codec.Codec) { +func (server *Server) serveCodec(cc codec.Codec, opt *Option) { sending := new(sync.Mutex) // make sure to send a complete response wg := new(sync.WaitGroup) // wait until all request are handled for { @@ -80,7 +85,7 @@ func (server *Server) serveCodec(cc codec.Codec) { continue } wg.Add(1) - go server.handleRequest(cc, req, sending, wg) + go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout) } wg.Wait() _ = cc.Close() @@ -108,7 +113,7 @@ func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { dot := strings.LastIndex(serviceMethod, ".") if dot < 0 { - err = errors.New("rpc server: service/Method request ill-formed: " + serviceMethod) + err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) return } serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] @@ -120,7 +125,7 @@ func (server *Server) findService(serviceMethod string) (svc *service, mtype *me svc = svci.(*service) mtype = svc.method[methodName] if mtype == nil { - err = errors.New("rpc server: can't find Method " + methodName) + err = errors.New("rpc server: can't find method " + methodName) } return } @@ -158,13 +163,35 @@ func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interfa } } -func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) { +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { defer wg.Done() - err := req.svc.call(req.mtype, req.argv, req.replyv) - if err != nil { - req.h.Error = err.Error() + called := make(chan struct{}) + sent := make(chan struct{}) + go func() { + err := req.svc.call(req.mtype, req.argv, req.replyv) + called <- struct{}{} + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + sent <- struct{}{} + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) + sent <- struct{}{} + }() + + if timeout == 0 { + <-called + <-sent + return + } + select { + case <-time.After(timeout): + req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout) + server.sendResponse(cc, req.h, invalidRequest, sending) + case <-called: + <-sent } - server.sendResponse(cc, req.h, req.replyv.Interface(), sending) } // Accept accepts connections on the listener and serves requests @@ -186,7 +213,7 @@ func Accept(lis net.Listener) { DefaultServer.Accept(lis) } // Register publishes in the server the set of methods of the // receiver value that satisfy the following conditions: -// - exported Method of exported type +// - exported method of exported type // - two arguments, both of exported type // - the second argument is a pointer // - one return value, of type error diff --git a/gee-rpc/day5-http-debug/service.go b/gee-rpc/day5-http-debug/service.go new file mode 100644 index 0000000..306683c --- /dev/null +++ b/gee-rpc/day5-http-debug/service.go @@ -0,0 +1,99 @@ +package geerpc + +import ( + "go/ast" + "log" + "reflect" + "sync/atomic" +) + +type methodType struct { + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint64 +} + +func (m *methodType) NumCalls() uint64 { + return atomic.LoadUint64(&m.numCalls) +} + +func (m *methodType) newArgv() reflect.Value { + var argv reflect.Value + // arg may be a pointer type, or a value type + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) + } else { + argv = reflect.New(m.ArgType).Elem() + } + return argv +} + +func (m *methodType) newReplyv() reflect.Value { + // reply must be a pointer type + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) + } + return replyv +} + +type service struct { + name string + typ reflect.Type + rcvr reflect.Value + method map[string]*methodType +} + +func newService(rcvr interface{}) *service { + s := new(service) + s.rcvr = reflect.ValueOf(rcvr) + s.name = reflect.Indirect(s.rcvr).Type().Name() + s.typ = reflect.TypeOf(rcvr) + if !ast.IsExported(s.name) { + log.Fatalf("rpc server: %s is not a valid service name", s.name) + } + s.registerMethods() + return s +} + +func (s *service) registerMethods() { + s.method = make(map[string]*methodType) + for i := 0; i < s.typ.NumMethod(); i++ { + method := s.typ.Method(i) + mType := method.Type + if mType.NumIn() != 3 || mType.NumOut() != 1 { + continue + } + if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + continue + } + argType, replyType := mType.In(1), mType.In(2) + if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { + continue + } + s.method[method.Name] = &methodType{ + method: method, + ArgType: argType, + ReplyType: replyType, + } + log.Printf("rpc server: register %s.%s\n", s.name, method.Name) + } +} + +func (s *service) call(m *methodType, argv, replyv reflect.Value) error { + atomic.AddUint64(&m.numCalls, 1) + f := m.method.Func + returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) + if errInter := returnValues[0].Interface(); errInter != nil { + return errInter.(error) + } + return nil +} + +func isExportedOrBuiltinType(t reflect.Type) bool { + return ast.IsExported(t.Name()) || t.PkgPath() == "" +} diff --git a/gee-rpc/day5-http-debug/service_test.go b/gee-rpc/day5-http-debug/service_test.go new file mode 100644 index 0000000..c8266df --- /dev/null +++ b/gee-rpc/day5-http-debug/service_test.go @@ -0,0 +1,48 @@ +package geerpc + +import ( + "fmt" + "reflect" + "testing" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +// it's not a exported Method +func (f Foo) sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func TestNewService(t *testing.T) { + var foo Foo + s := newService(&foo) + _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method)) + mType := s.method["Sum"] + _assert(mType != nil, "wrong Method, Sum shouldn't nil") +} + +func TestMethodType_Call(t *testing.T) { + var foo Foo + s := newService(&foo) + mType := s.method["Sum"] + + argv := mType.newArgv() + replyv := mType.newReplyv() + argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3})) + err := s.call(mType, argv, replyv) + _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum") +}