diff --git a/gee-rpc/day2-client/client.go b/gee-rpc/day2-client/client.go index f9b09ee..3df6ec9 100644 --- a/gee-rpc/day2-client/client.go +++ b/gee-rpc/day2-client/client.go @@ -34,14 +34,15 @@ func (call *Call) done() { // with a single Client, and a Client may be used by // multiple goroutines simultaneously. type Client struct { - cc codec.Codec - opt *Option - sending sync.Mutex // protect following - header codec.Header - mu sync.Mutex // protect following - seq uint64 - pending map[uint64]*Call - closed bool // user has called Close + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop } var _ io.Closer = (*Client)(nil) @@ -52,17 +53,24 @@ var ErrShutdown = errors.New("connection is shut down") func (client *Client) Close() error { client.mu.Lock() defer client.mu.Unlock() - if client.closed { + if client.closing { return ErrShutdown } - client.closed = true + client.closing = true return client.cc.Close() } +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + func (client *Client) registerCall(call *Call) (uint64, error) { client.mu.Lock() defer client.mu.Unlock() - if client.closed { + if client.closing || client.shutdown { return 0, ErrShutdown } seq := client.seq @@ -84,6 +92,7 @@ func (client *Client) terminateCalls(err error) { defer client.sending.Unlock() client.mu.Lock() defer client.mu.Unlock() + client.shutdown = true for _, call := range client.pending { call.Error = err call.done() diff --git a/gee-rpc/day3-service/client.go b/gee-rpc/day3-service/client.go index f9b09ee..3df6ec9 100644 --- a/gee-rpc/day3-service/client.go +++ b/gee-rpc/day3-service/client.go @@ -34,14 +34,15 @@ func (call *Call) done() { // with a single Client, and a Client may be used by // multiple goroutines simultaneously. type Client struct { - cc codec.Codec - opt *Option - sending sync.Mutex // protect following - header codec.Header - mu sync.Mutex // protect following - seq uint64 - pending map[uint64]*Call - closed bool // user has called Close + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop } var _ io.Closer = (*Client)(nil) @@ -52,17 +53,24 @@ var ErrShutdown = errors.New("connection is shut down") func (client *Client) Close() error { client.mu.Lock() defer client.mu.Unlock() - if client.closed { + if client.closing { return ErrShutdown } - client.closed = true + client.closing = true return client.cc.Close() } +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + func (client *Client) registerCall(call *Call) (uint64, error) { client.mu.Lock() defer client.mu.Unlock() - if client.closed { + if client.closing || client.shutdown { return 0, ErrShutdown } seq := client.seq @@ -84,6 +92,7 @@ func (client *Client) terminateCalls(err error) { defer client.sending.Unlock() client.mu.Lock() defer client.mu.Unlock() + client.shutdown = true for _, call := range client.pending { call.Error = err call.done() diff --git a/gee-rpc/day4-timeout/client.go b/gee-rpc/day4-timeout/client.go index c9900fa..b301647 100644 --- a/gee-rpc/day4-timeout/client.go +++ b/gee-rpc/day4-timeout/client.go @@ -36,14 +36,15 @@ func (call *Call) done() { // with a single Client, and a Client may be used by // multiple goroutines simultaneously. type Client struct { - cc codec.Codec - opt *Option - sending sync.Mutex // protect following - header codec.Header - mu sync.Mutex // protect following - seq uint64 - pending map[uint64]*Call - closed bool // user has called Close + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop } var _ io.Closer = (*Client)(nil) @@ -54,17 +55,24 @@ var ErrShutdown = errors.New("connection is shut down") func (client *Client) Close() error { client.mu.Lock() defer client.mu.Unlock() - if client.closed { + if client.closing { return ErrShutdown } - client.closed = true + client.closing = true return client.cc.Close() } +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + func (client *Client) registerCall(call *Call) (uint64, error) { client.mu.Lock() defer client.mu.Unlock() - if client.closed { + if client.closing || client.shutdown { return 0, ErrShutdown } seq := client.seq @@ -86,6 +94,7 @@ func (client *Client) terminateCalls(err error) { defer client.sending.Unlock() client.mu.Lock() defer client.mu.Unlock() + client.shutdown = true for _, call := range client.pending { call.Error = err call.done() @@ -226,44 +235,44 @@ func newClientCodec(cc codec.Codec, opt *Option) *Client { return client } -func dial(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) - if err != nil { - return nil, err - } - return NewClient(conn, opt) -} - type clientResult struct { client *Client err error } -func dialTimeout(f func() (client *Client, err error), timeout time.Duration) (*Client, error) { - if timeout == 0 { - return f() +type dialFunc func(network, address string, opt *Option) (client *Client, err error) + +func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err } ch := make(chan clientResult) go func() { - client, err := f() + client, err := f(network, address, opt) ch <- clientResult{client: client, err: err} }() + if opt.ConnectTimeout == 0 { + result := <-ch + return result.client, result.err + } select { - case <-time.After(timeout): - return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", timeout) + case <-time.After(opt.ConnectTimeout): + return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout) case result := <-ch: return result.client, result.err } } -// Dial connects to an RPC server at the specified network address -func Dial(network, address string, opts ...*Option) (*Client, error) { - opt, err := parseOptions(opts...) +func dial(network, address string, opt *Option) (*Client, error) { + conn, err := net.Dial(network, address) if err != nil { return nil, err } - f := func() (client *Client, err error) { - return dial(network, address, opt) - } - return dialTimeout(f, opt.ConnectTimeout) + return NewClient(conn, opt) +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(dial, network, address, opts...) } diff --git a/gee-rpc/day4-timeout/client_test.go b/gee-rpc/day4-timeout/client_test.go index 5669f9c..ab2fa64 100644 --- a/gee-rpc/day4-timeout/client_test.go +++ b/gee-rpc/day4-timeout/client_test.go @@ -26,16 +26,16 @@ func startServer(addr chan string) { func TestClient_dialTimeout(t *testing.T) { t.Parallel() - f := func() (client *Client, err error) { + f := func(network, address string, opt *Option) (client *Client, err error) { time.Sleep(time.Second * 2) return nil, nil } t.Run("timeout", func(t *testing.T) { - _, err := dialTimeout(f, time.Second) + _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second}) _assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error") }) t.Run("0", func(t *testing.T) { - _, err := dialTimeout(f, 0) + _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0}) _assert(err == nil, "0 means no limit") }) } diff --git a/gee-rpc/day5-http-debug/client.go b/gee-rpc/day5-http-debug/client.go index 5f336d7..e9b4540 100644 --- a/gee-rpc/day5-http-debug/client.go +++ b/gee-rpc/day5-http-debug/client.go @@ -15,6 +15,7 @@ import ( "log" "net" "net/http" + "strings" "sync" "time" ) @@ -38,14 +39,15 @@ func (call *Call) done() { // with a single Client, and a Client may be used by // multiple goroutines simultaneously. type Client struct { - cc codec.Codec - opt *Option - sending sync.Mutex // protect following - header codec.Header - mu sync.Mutex // protect following - seq uint64 - pending map[uint64]*Call - closed bool // user has called Close + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop } var _ io.Closer = (*Client)(nil) @@ -56,17 +58,24 @@ var ErrShutdown = errors.New("connection is shut down") func (client *Client) Close() error { client.mu.Lock() defer client.mu.Unlock() - if client.closed { + if client.closing { return ErrShutdown } - client.closed = true + client.closing = true return client.cc.Close() } +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + func (client *Client) registerCall(call *Call) (uint64, error) { client.mu.Lock() defer client.mu.Unlock() - if client.closed { + if client.closing || client.shutdown { return 0, ErrShutdown } seq := client.seq @@ -88,6 +97,7 @@ func (client *Client) terminateCalls(err error) { defer client.sending.Unlock() client.mu.Lock() defer client.mu.Unlock() + client.shutdown = true for _, call := range client.pending { call.Error = err call.done() @@ -228,54 +238,54 @@ func newClientCodec(cc codec.Codec, opt *Option) *Client { return client } -func dial(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) - if err != nil { - return nil, err - } - return NewClient(conn, opt) -} - type clientResult struct { client *Client err error } -func dialTimeout(f func() (client *Client, err error), timeout time.Duration) (*Client, error) { - if timeout == 0 { - return f() +type dialFunc func(network, address string, opt *Option) (client *Client, err error) + +func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err } ch := make(chan clientResult) go func() { - client, err := f() + client, err := f(network, address, opt) ch <- clientResult{client: client, err: err} }() + if opt.ConnectTimeout == 0 { + result := <-ch + return result.client, result.err + } select { - case <-time.After(timeout): - return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", timeout) + case <-time.After(opt.ConnectTimeout): + return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout) case result := <-ch: return result.client, result.err } } -// Dial connects to an RPC server at the specified network address -func Dial(network, address string, opts ...*Option) (*Client, error) { - opt, err := parseOptions(opts...) +func dial(network, address string, opt *Option) (*Client, error) { + conn, err := net.Dial(network, address) if err != nil { return nil, err } - f := func() (client *Client, err error) { - return dial(network, address, opt) - } - return dialTimeout(f, opt.ConnectTimeout) + return NewClient(conn, opt) } -func dialHTTPPath(network, address, path string, opt *Option) (*Client, error) { +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(dial, network, address, opts...) +} + +func dialHTTP(network, address string, opt *Option) (*Client, error) { conn, err := net.Dial(network, address) if err != nil { return nil, err } - _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", path)) + _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)) // Require successful HTTP response // before switching to RPC protocol. @@ -290,21 +300,25 @@ func dialHTTPPath(network, address, path string, opt *Option) (*Client, error) { return nil, err } -// DialHTTPPath connects to an HTTP RPC server -// at the specified network address and path. -func DialHTTPPath(network, address, path string, opts ...*Option) (*Client, error) { - opt, err := parseOptions(opts...) - if err != nil { - return nil, err - } - f := func() (*Client, error) { - return dialHTTPPath(network, address, path, opt) - } - return dialTimeout(f, opt.ConnectTimeout) -} - // DialHTTP connects to an HTTP RPC server at the specified network address // listening on the default HTTP RPC path. func DialHTTP(network, address string, opts ...*Option) (*Client, error) { - return DialHTTPPath(network, address, defaultRPCPath, opts...) + return dialTimeout(dialHTTP, network, address, opts...) +} + +// XDial use a general format to represent a rpc server +// eg, http@10.0.0.1:7001, tcp@10.0.0.1:9999, unix@/tmp/geerpc.sock +func XDial(rpcAddr string, opts ...*Option) (*Client, error) { + parts := strings.Split(rpcAddr, "@") + if len(parts) != 2 { + return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr) + } + protocol, addr := parts[0], parts[1] + switch protocol { + case "http": + return DialHTTP("tcp", addr, opts...) + default: + // tcp, unix or other transport protocol + return Dial(protocol, addr, opts...) + } } diff --git a/gee-rpc/day5-http-debug/client_test.go b/gee-rpc/day5-http-debug/client_test.go index 5669f9c..10b9817 100644 --- a/gee-rpc/day5-http-debug/client_test.go +++ b/gee-rpc/day5-http-debug/client_test.go @@ -3,6 +3,8 @@ package geerpc import ( "context" "net" + "os" + "runtime" "strings" "testing" "time" @@ -26,16 +28,16 @@ func startServer(addr chan string) { func TestClient_dialTimeout(t *testing.T) { t.Parallel() - f := func() (client *Client, err error) { + f := func(network, address string, opt *Option) (client *Client, err error) { time.Sleep(time.Second * 2) return nil, nil } t.Run("timeout", func(t *testing.T) { - _, err := dialTimeout(f, time.Second) + _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second}) _assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error") }) t.Run("0", func(t *testing.T) { - _, err := dialTimeout(f, 0) + _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0}) _assert(err == nil, "0 means no limit") }) } @@ -61,3 +63,22 @@ func TestClient_Call(t *testing.T) { _assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error") }) } + +func TestXDial(t *testing.T) { + if runtime.GOOS == "linux" { + ch := make(chan struct{}) + addr := "/tmp/geerpc.sock" + go func() { + _ = os.Remove(addr) + l, err := net.Listen("unix", addr) + if err != nil { + t.Fatal("failed to listen unix socket") + } + ch <- struct{}{} + Accept(l) + }() + <-ch + _, err := XDial("unix@" + addr) + _assert(err == nil, "failed to connect unix socket") + } +} diff --git a/gee-rpc/day5-http-debug/debug.go b/gee-rpc/day5-http-debug/debug.go index d76de55..ece1ffd 100644 --- a/gee-rpc/day5-http-debug/debug.go +++ b/gee-rpc/day5-http-debug/debug.go @@ -41,7 +41,7 @@ type debugService struct { Method map[string]*methodType } -// Runs at /debug/rpc +// Runs at /debug/geerpc func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Build a sorted version of the data. var services []debugService diff --git a/gee-rpc/day5-http-debug/main/main.go b/gee-rpc/day5-http-debug/main/main.go index f25d909..a71af74 100644 --- a/gee-rpc/day5-http-debug/main/main.go +++ b/gee-rpc/day5-http-debug/main/main.go @@ -4,9 +4,9 @@ import ( "context" "geerpc" "log" + "net" "net/http" "sync" - "time" ) type Foo int @@ -18,17 +18,17 @@ func (f Foo) Sum(args Args, reply *int) error { return nil } -func startServer(addr string) { +func startServer(addrCh chan string) { var foo Foo + l, _ := net.Listen("tcp", ":9999") _ = geerpc.Register(&foo) geerpc.HandleHTTP() - log.Fatal(http.ListenAndServe(addr, nil)) + addrCh <- l.Addr().String() + _ = http.Serve(l, nil) } -func call() { - // start server may cost some time - time.Sleep(time.Second) - client, _ := geerpc.DialHTTP("tcp", ":9999") +func call(addrCh chan string) { + client, _ := geerpc.DialHTTP("tcp", <-addrCh) defer func() { _ = client.Close() }() // send request & receive response @@ -49,6 +49,7 @@ func call() { } func main() { - go call() - startServer(":9999") + ch := make(chan string) + go call(ch) + startServer(ch) } diff --git a/gee-rpc/day5-http-debug/server.go b/gee-rpc/day5-http-debug/server.go index 5be1c66..38fad20 100644 --- a/gee-rpc/day5-http-debug/server.go +++ b/gee-rpc/day5-http-debug/server.go @@ -254,12 +254,13 @@ func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { // HandleHTTP registers an HTTP handler for RPC messages on rpcPath, // and a debugging handler on debugPath. // It is still necessary to invoke http.Serve(), typically in a go statement. -func (server *Server) HandleHTTP(rpcPath, debugPath string) { - http.Handle(rpcPath, server) - http.Handle(debugPath, debugHTTP{server}) - log.Println("rpc server debug path:", debugPath) +func (server *Server) HandleHTTP() { + http.Handle(defaultRPCPath, server) + http.Handle(defaultDebugPath, debugHTTP{server}) + log.Println("rpc server debug path:", defaultDebugPath) } +// HandleHTTP is a convenient approach for default server to register HTTP handlers func HandleHTTP() { - DefaultServer.HandleHTTP(defaultRPCPath, defaultDebugPath) + DefaultServer.HandleHTTP() } diff --git a/gee-rpc/day6-discovery/client.go b/gee-rpc/day6-discovery/client.go new file mode 100644 index 0000000..d958696 --- /dev/null +++ b/gee-rpc/day6-discovery/client.go @@ -0,0 +1,324 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "strings" + "sync" + "time" +) + +// Call represents an active RPC. +type Call struct { + Seq uint64 + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + opt *Option + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing { + return ErrShutdown + } + client.closing = true + return client.cc.Close() +} + +// IsAvailable return true if the client does work +func (client *Client) IsAvailable() bool { + client.mu.Lock() + defer client.mu.Unlock() + return !client.shutdown && !client.closing +} + +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closing || client.shutdown { + return 0, ErrShutdown + } + seq := client.seq + client.pending[seq] = call + client.seq++ + return seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + client.shutdown = true + for _, call := range client.pending { + call.Error = err + call.done() + } +} + +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + call.Seq = seq + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + call := client.Go(serviceMethod, args, reply, make(chan *Call, 1)) + select { + case <-ctx.Done(): + client.removeCall(call.Seq) + return errors.New("rpc client: call failed: " + ctx.Err().Error()) + case call := <-call.Done: + return call.Error + } +} + +func parseOptions(opts ...*Option) (*Option, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return DefaultOption, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = DefaultOption.MagicNumber + if opt.CodecType == "" { + opt.CodecType = DefaultOption.CodecType + } + return opt, nil +} + +func NewClient(conn net.Conn, opt *Option) (*Client, error) { + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err := fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return nil, err + } + // send options with server + if err := json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + _ = conn.Close() + return nil, err + } + return newClientCodec(f(conn), opt), nil +} + +func newClientCodec(cc codec.Codec, opt *Option) *Client { + client := &Client{ + seq: 1, // seq starts with 1, 0 means invalid call + cc: cc, + opt: opt, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} + +type clientResult struct { + client *Client + err error +} + +type dialFunc func(network, address string, opt *Option) (client *Client, err error) + +func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err + } + ch := make(chan clientResult) + go func() { + client, err := f(network, address, opt) + ch <- clientResult{client: client, err: err} + }() + if opt.ConnectTimeout == 0 { + result := <-ch + return result.client, result.err + } + select { + case <-time.After(opt.ConnectTimeout): + return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout) + case result := <-ch: + return result.client, result.err + } +} + +func dial(network, address string, opt *Option) (*Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return NewClient(conn, opt) +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(dial, network, address, opts...) +} + +func dialHTTP(network, address string, opt *Option) (*Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)) + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err == nil && resp.Status == connected { + return NewClient(conn, opt) + } + if err == nil { + err = errors.New("unexpected HTTP response: " + resp.Status) + } + _ = conn.Close() + return nil, err +} + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string, opts ...*Option) (*Client, error) { + return dialTimeout(dialHTTP, network, address, opts...) +} + +// XDial use a general format to represent a rpc server +// eg, http@10.0.0.1:7001, tcp@10.0.0.1:9999, unix@/tmp/geerpc.sock +func XDial(rpcAddr string, opts ...*Option) (*Client, error) { + parts := strings.Split(rpcAddr, "@") + if len(parts) != 2 { + return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr) + } + protocol, addr := parts[0], parts[1] + switch protocol { + case "http": + return DialHTTP("tcp", addr) + default: + // tcp, unix or other transport protocol + return Dial(protocol, addr) + } +} diff --git a/gee-rpc/day6-discovery/client_test.go b/gee-rpc/day6-discovery/client_test.go new file mode 100644 index 0000000..10b9817 --- /dev/null +++ b/gee-rpc/day6-discovery/client_test.go @@ -0,0 +1,84 @@ +package geerpc + +import ( + "context" + "net" + "os" + "runtime" + "strings" + "testing" + "time" +) + +type Bar int + +func (b Bar) Timeout(argv int, reply *int) error { + time.Sleep(time.Second * 2) + return nil +} + +func startServer(addr chan string) { + var b Bar + _ = Register(&b) + // pick a free port + l, _ := net.Listen("tcp", ":0") + addr <- l.Addr().String() + Accept(l) +} + +func TestClient_dialTimeout(t *testing.T) { + t.Parallel() + f := func(network, address string, opt *Option) (client *Client, err error) { + time.Sleep(time.Second * 2) + return nil, nil + } + t.Run("timeout", func(t *testing.T) { + _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error") + }) + t.Run("0", func(t *testing.T) { + _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0}) + _assert(err == nil, "0 means no limit") + }) +} + +func TestClient_Call(t *testing.T) { + t.Parallel() + addrCh := make(chan string) + go startServer(addrCh) + addr := <-addrCh + t.Run("client timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + var reply int + err := client.Call(ctx, "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error") + }) + t.Run("server handle timeout", func(t *testing.T) { + client, _ := Dial("tcp", addr, &Option{ + HandleTimeout: time.Second, + }) + var reply int + err := client.Call(context.Background(), "Bar.Timeout", 1, &reply) + _assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error") + }) +} + +func TestXDial(t *testing.T) { + if runtime.GOOS == "linux" { + ch := make(chan struct{}) + addr := "/tmp/geerpc.sock" + go func() { + _ = os.Remove(addr) + l, err := net.Listen("unix", addr) + if err != nil { + t.Fatal("failed to listen unix socket") + } + ch <- struct{}{} + Accept(l) + }() + <-ch + _, err := XDial("unix@" + addr) + _assert(err == nil, "failed to connect unix socket") + } +} diff --git a/gee-rpc/day6-discovery/codec/codec.go b/gee-rpc/day6-discovery/codec/codec.go new file mode 100644 index 0000000..ba28fba --- /dev/null +++ b/gee-rpc/day6-discovery/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day6-discovery/codec/gob.go b/gee-rpc/day6-discovery/codec/gob.go new file mode 100644 index 0000000..808d97b --- /dev/null +++ b/gee-rpc/day6-discovery/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err := c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return err + } + if err := c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return err + } + return nil +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day6-discovery/debug.go b/gee-rpc/day6-discovery/debug.go new file mode 100644 index 0000000..ece1ffd --- /dev/null +++ b/gee-rpc/day6-discovery/debug.go @@ -0,0 +1,60 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "fmt" + "html/template" + "net/http" +) + +const debugText = ` + + GeeRPC Services + {{range .}} +
+ Service {{.Name}} +
+ + + {{range $name, $mtype := .Method}} + + + + + {{end}} +
MethodCalls
{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error{{$mtype.NumCalls}}
+ {{end}} + + ` + +var debug = template.Must(template.New("RPC debug").Parse(debugText)) + +type debugHTTP struct { + *Server +} + +type debugService struct { + Name string + Method map[string]*methodType +} + +// Runs at /debug/geerpc +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Build a sorted version of the data. + var services []debugService + server.serviceMap.Range(func(namei, svci interface{}) bool { + svc := svci.(*service) + services = append(services, debugService{ + Name: namei.(string), + Method: svc.method, + }) + return true + }) + err := debug.Execute(w, services) + if err != nil { + _, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error()) + } +} diff --git a/gee-rpc/day6-discovery/go.mod b/gee-rpc/day6-discovery/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day6-discovery/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day6-discovery/main/main.go b/gee-rpc/day6-discovery/main/main.go new file mode 100644 index 0000000..a71af74 --- /dev/null +++ b/gee-rpc/day6-discovery/main/main.go @@ -0,0 +1,55 @@ +package main + +import ( + "context" + "geerpc" + "log" + "net" + "net/http" + "sync" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func startServer(addrCh chan string) { + var foo Foo + l, _ := net.Listen("tcp", ":9999") + _ = geerpc.Register(&foo) + geerpc.HandleHTTP() + addrCh <- l.Addr().String() + _ = http.Serve(l, nil) +} + +func call(addrCh chan string) { + client, _ := geerpc.DialHTTP("tcp", <-addrCh) + defer func() { _ = client.Close() }() + + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + args := &Args{Num1: i, Num2: i * i} + var reply int + if err := client.Call(context.Background(), "Foo.Sum", args, &reply); err != nil { + log.Fatal("call Foo.Sum error:", err) + } + log.Printf("%d + %d = %d", args.Num1, args.Num2, reply) + }(i) + } + wg.Wait() +} + +func main() { + ch := make(chan string) + go call(ch) + startServer(ch) +} diff --git a/gee-rpc/day6-discovery/server.go b/gee-rpc/day6-discovery/server.go new file mode 100644 index 0000000..38fad20 --- /dev/null +++ b/gee-rpc/day6-discovery/server.go @@ -0,0 +1,266 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "reflect" + "strings" + "sync" + "time" +) + +const MagicNumber = 0x3bef5c + +type Option struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body + ConnectTimeout time.Duration // 0 means no limit + HandleTimeout time.Duration +} + +var DefaultOption = &Option{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, + ConnectTimeout: time.Second * 10, +} + +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Option + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn), &opt) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec, opt *Option) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request + mtype *methodType + svc *service +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { + dot := strings.LastIndex(serviceMethod, ".") + if dot < 0 { + err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) + return + } + serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc server: can't find service " + serviceName) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc server: can't find method " + methodName) + } + return +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + req.svc, req.mtype, err = server.findService(h.ServiceMethod) + if err != nil { + return req, err + } + req.argv = req.mtype.newArgv() + req.replyv = req.mtype.newReplyv() + + // make sure that argvi is a pointer, ReadBody need a pointer as parameter + argvi := req.argv.Interface() + if req.argv.Type().Kind() != reflect.Ptr { + argvi = req.argv.Addr().Interface() + } + if err = cc.ReadBody(argvi); err != nil { + log.Println("rpc server: read body err:", err) + return req, err + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { + defer wg.Done() + called := make(chan struct{}) + sent := make(chan struct{}) + go func() { + err := req.svc.call(req.mtype, req.argv, req.replyv) + called <- struct{}{} + if err != nil { + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + sent <- struct{}{} + return + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) + sent <- struct{}{} + }() + + if timeout == 0 { + <-called + <-sent + return + } + select { + case <-time.After(timeout): + req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout) + server.sendResponse(cc, req.h, invalidRequest, sending) + case <-called: + <-sent + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +func (server *Server) Register(rcvr interface{}) error { + s := newService(rcvr) + if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { + return errors.New("rpc: service already defined: " + s.name) + } + return nil +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } + +const ( + connected = "200 Connected to Gee RPC" + defaultRPCPath = "/_geeprc_" + defaultDebugPath = "/debug/geerpc" +) + +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = io.WriteString(w, "405 must CONNECT\n") + return + } + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) + return + } + _, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath, +// and a debugging handler on debugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP() { + http.Handle(defaultRPCPath, server) + http.Handle(defaultDebugPath, debugHTTP{server}) + log.Println("rpc server debug path:", defaultDebugPath) +} + +// HandleHTTP is a convenient approach for default server to register HTTP handlers +func HandleHTTP() { + DefaultServer.HandleHTTP() +} diff --git a/gee-rpc/day6-discovery/service.go b/gee-rpc/day6-discovery/service.go new file mode 100644 index 0000000..306683c --- /dev/null +++ b/gee-rpc/day6-discovery/service.go @@ -0,0 +1,99 @@ +package geerpc + +import ( + "go/ast" + "log" + "reflect" + "sync/atomic" +) + +type methodType struct { + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint64 +} + +func (m *methodType) NumCalls() uint64 { + return atomic.LoadUint64(&m.numCalls) +} + +func (m *methodType) newArgv() reflect.Value { + var argv reflect.Value + // arg may be a pointer type, or a value type + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) + } else { + argv = reflect.New(m.ArgType).Elem() + } + return argv +} + +func (m *methodType) newReplyv() reflect.Value { + // reply must be a pointer type + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) + } + return replyv +} + +type service struct { + name string + typ reflect.Type + rcvr reflect.Value + method map[string]*methodType +} + +func newService(rcvr interface{}) *service { + s := new(service) + s.rcvr = reflect.ValueOf(rcvr) + s.name = reflect.Indirect(s.rcvr).Type().Name() + s.typ = reflect.TypeOf(rcvr) + if !ast.IsExported(s.name) { + log.Fatalf("rpc server: %s is not a valid service name", s.name) + } + s.registerMethods() + return s +} + +func (s *service) registerMethods() { + s.method = make(map[string]*methodType) + for i := 0; i < s.typ.NumMethod(); i++ { + method := s.typ.Method(i) + mType := method.Type + if mType.NumIn() != 3 || mType.NumOut() != 1 { + continue + } + if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + continue + } + argType, replyType := mType.In(1), mType.In(2) + if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { + continue + } + s.method[method.Name] = &methodType{ + method: method, + ArgType: argType, + ReplyType: replyType, + } + log.Printf("rpc server: register %s.%s\n", s.name, method.Name) + } +} + +func (s *service) call(m *methodType, argv, replyv reflect.Value) error { + atomic.AddUint64(&m.numCalls, 1) + f := m.method.Func + returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) + if errInter := returnValues[0].Interface(); errInter != nil { + return errInter.(error) + } + return nil +} + +func isExportedOrBuiltinType(t reflect.Type) bool { + return ast.IsExported(t.Name()) || t.PkgPath() == "" +} diff --git a/gee-rpc/day6-discovery/service_test.go b/gee-rpc/day6-discovery/service_test.go new file mode 100644 index 0000000..c8266df --- /dev/null +++ b/gee-rpc/day6-discovery/service_test.go @@ -0,0 +1,48 @@ +package geerpc + +import ( + "fmt" + "reflect" + "testing" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +// it's not a exported Method +func (f Foo) sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func TestNewService(t *testing.T) { + var foo Foo + s := newService(&foo) + _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method)) + mType := s.method["Sum"] + _assert(mType != nil, "wrong Method, Sum shouldn't nil") +} + +func TestMethodType_Call(t *testing.T) { + var foo Foo + s := newService(&foo) + mType := s.method["Sum"] + + argv := mType.newArgv() + replyv := mType.newReplyv() + argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3})) + err := s.call(mType, argv, replyv) + _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum") +} diff --git a/gee-rpc/day6-discovery/xclient/discovery.go b/gee-rpc/day6-discovery/xclient/discovery.go new file mode 100644 index 0000000..98b58d6 --- /dev/null +++ b/gee-rpc/day6-discovery/xclient/discovery.go @@ -0,0 +1,57 @@ +package xclient + +import ( + "math/rand" + "sync" + "time" +) + +type SelectMode int + +const ( + RandomSelect SelectMode = iota // select randomly + RobbinSelect // select using Robbin algorithm +) + +type Discovery interface { + Get(mode SelectMode) string +} + +var _ Discovery = (*MultiServersDiscovery)(nil) + +// MultiServersDiscovery is a discovery for multi servers without a registry center +// user provides the server addresses explicitly instead +type MultiServersDiscovery struct { + r *rand.Rand // generate random number + mu sync.RWMutex // protect following + servers []string +} + +// Update the servers of discovery dynamically if needed +func (d *MultiServersDiscovery) Update(servers []string) { + d.mu.Lock() + defer d.mu.Unlock() + d.servers = servers +} + +func (d *MultiServersDiscovery) Get(mode SelectMode) string { + d.mu.RLock() + defer d.mu.RUnlock() + if len(d.servers) == 0 { + return "" + } + switch mode { + case RandomSelect: + return d.servers[d.r.Intn(len(d.servers))] + default: + return "" + } +} + +// NewMultiServerDiscovery creates a MultiServersDiscovery instance +func NewMultiServerDiscovery(servers []string) *MultiServersDiscovery { + return &MultiServersDiscovery{ + servers: servers, + r: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} diff --git a/gee-rpc/day6-discovery/xclient/xclient.go b/gee-rpc/day6-discovery/xclient/xclient.go new file mode 100644 index 0000000..ff6992f --- /dev/null +++ b/gee-rpc/day6-discovery/xclient/xclient.go @@ -0,0 +1,48 @@ +package xclient + +import ( + "context" + . "geerpc" + "io" + "sync" +) + +type XClient struct { + d Discovery + mode SelectMode + opt *Option + clients sync.Map +} + +var _ io.Closer = (*XClient)(nil) + +func NewXClient(d Discovery, mode SelectMode, opt *Option) *XClient { + return &XClient{d: d, mode: mode, opt: opt} +} + +func (xc *XClient) Close() error { + xc.clients.Range(func(k, v interface{}) bool { + // I have no idea how to deal with error, just ignore it. + _ = v.(*Client).Close() + return true + }) + xc.clients = sync.Map{} + return nil +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +// xc will choose a proper server. +func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error { + rpcAddr := xc.d.Get(xc.mode) + client, ok := xc.clients.Load(rpcAddr) + if !ok { + var err error + client, err = XDial(rpcAddr, xc.opt) + if err != nil { + return err + } + xc.clients.Store(rpcAddr, client) + } + return client.(*Client).Call(ctx, serviceMethod, args, reply) +}