From dc840ebae2b5facf35f617c4874a9aaf20ffeb45 Mon Sep 17 00:00:00 2001 From: gzdaijie Date: Fri, 2 Oct 2020 23:44:13 +0800 Subject: [PATCH] gee-rpc: add Client struct filed opt to record options user will pass --- gee-rpc/day2-client/client.go | 42 ++++++++++++++++---------- gee-rpc/day3-service/client.go | 42 ++++++++++++++++---------- gee-rpc/day4-http-debug/client.go | 50 +++++++++++++++++-------------- 3 files changed, 80 insertions(+), 54 deletions(-) diff --git a/gee-rpc/day2-client/client.go b/gee-rpc/day2-client/client.go index 83119f0..b73c461 100644 --- a/gee-rpc/day2-client/client.go +++ b/gee-rpc/day2-client/client.go @@ -34,6 +34,7 @@ func (call *Call) done() { // multiple goroutines simultaneously. type Client struct { cc codec.Codec + opt *Options sending sync.Mutex // protect following header codec.Header mu sync.Mutex // protect following @@ -172,15 +173,26 @@ func (client *Client) Call(serviceMethod string, args, reply interface{}) error return call.Error } -func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) { - var err error - defer func() { - if err != nil { - _ = conn.Close() - } - }() - if opt.MagicNumber == 0 { - opt.MagicNumber = MagicNumber +func parseOptions(opts ...*Options) (*Options, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return defaultOptions, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = defaultOptions.MagicNumber + if opt.CodecType == "" { + opt.CodecType = defaultOptions.CodecType + } + return opt, nil +} + +func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err } f := codec.NewCodecFuncMap[opt.CodecType] if f == nil { @@ -191,14 +203,16 @@ func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) { // send options with server if err = json.NewEncoder(conn).Encode(opt); err != nil { log.Println("rpc client: options error: ", err) + _ = conn.Close() return nil, err } - return newClientCodec(f(conn)), nil + return newClientCodec(f(conn), opt), nil } -func newClientCodec(cc codec.Codec) *Client { +func newClientCodec(cc codec.Codec, opt *Options) *Client { client := &Client{ cc: cc, + opt: opt, pending: make(map[uint64]*Call), } go client.receive() @@ -207,13 +221,9 @@ func newClientCodec(cc codec.Codec) *Client { // Dial connects to an RPC server at the specified network address func Dial(network, address string, opts ...*Options) (*Client, error) { - opt := defaultOptions - if len(opts) > 0 && opts[0] != nil { - opt = opts[0] - } conn, err := net.Dial(network, address) if err != nil { return nil, err } - return NewClient(conn, opt) + return NewClient(conn, opts...) } diff --git a/gee-rpc/day3-service/client.go b/gee-rpc/day3-service/client.go index 83119f0..b73c461 100644 --- a/gee-rpc/day3-service/client.go +++ b/gee-rpc/day3-service/client.go @@ -34,6 +34,7 @@ func (call *Call) done() { // multiple goroutines simultaneously. type Client struct { cc codec.Codec + opt *Options sending sync.Mutex // protect following header codec.Header mu sync.Mutex // protect following @@ -172,15 +173,26 @@ func (client *Client) Call(serviceMethod string, args, reply interface{}) error return call.Error } -func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) { - var err error - defer func() { - if err != nil { - _ = conn.Close() - } - }() - if opt.MagicNumber == 0 { - opt.MagicNumber = MagicNumber +func parseOptions(opts ...*Options) (*Options, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return defaultOptions, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = defaultOptions.MagicNumber + if opt.CodecType == "" { + opt.CodecType = defaultOptions.CodecType + } + return opt, nil +} + +func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err } f := codec.NewCodecFuncMap[opt.CodecType] if f == nil { @@ -191,14 +203,16 @@ func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) { // send options with server if err = json.NewEncoder(conn).Encode(opt); err != nil { log.Println("rpc client: options error: ", err) + _ = conn.Close() return nil, err } - return newClientCodec(f(conn)), nil + return newClientCodec(f(conn), opt), nil } -func newClientCodec(cc codec.Codec) *Client { +func newClientCodec(cc codec.Codec, opt *Options) *Client { client := &Client{ cc: cc, + opt: opt, pending: make(map[uint64]*Call), } go client.receive() @@ -207,13 +221,9 @@ func newClientCodec(cc codec.Codec) *Client { // Dial connects to an RPC server at the specified network address func Dial(network, address string, opts ...*Options) (*Client, error) { - opt := defaultOptions - if len(opts) > 0 && opts[0] != nil { - opt = opts[0] - } conn, err := net.Dial(network, address) if err != nil { return nil, err } - return NewClient(conn, opt) + return NewClient(conn, opts...) } diff --git a/gee-rpc/day4-http-debug/client.go b/gee-rpc/day4-http-debug/client.go index 8ab5ef0..7ae8938 100644 --- a/gee-rpc/day4-http-debug/client.go +++ b/gee-rpc/day4-http-debug/client.go @@ -19,7 +19,7 @@ import ( // Call represents an active RPC. type Call struct { - ServiceMethod string // format "." + ServiceMethod string // format "." Args interface{} // arguments to the function Reply interface{} // reply from the function Error error // if error occurs, it will be set @@ -36,6 +36,7 @@ func (call *Call) done() { // multiple goroutines simultaneously. type Client struct { cc codec.Codec + opt *Options sending sync.Mutex // protect following header codec.Header mu sync.Mutex // protect following @@ -174,15 +175,26 @@ func (client *Client) Call(serviceMethod string, args, reply interface{}) error return call.Error } -func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) { - var err error - defer func() { - if err != nil { - _ = conn.Close() - } - }() - if opt.MagicNumber == 0 { - opt.MagicNumber = MagicNumber +func parseOptions(opts ...*Options) (*Options, error) { + // if opts is nil or pass nil as parameter + if len(opts) == 0 || opts[0] == nil { + return defaultOptions, nil + } + if len(opts) != 1 { + return nil, errors.New("number of options is more than 1") + } + opt := opts[0] + opt.MagicNumber = defaultOptions.MagicNumber + if opt.CodecType == "" { + opt.CodecType = defaultOptions.CodecType + } + return opt, nil +} + +func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) { + opt, err := parseOptions(opts...) + if err != nil { + return nil, err } f := codec.NewCodecFuncMap[opt.CodecType] if f == nil { @@ -193,14 +205,16 @@ func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) { // send options with server if err = json.NewEncoder(conn).Encode(opt); err != nil { log.Println("rpc client: options error: ", err) + _ = conn.Close() return nil, err } - return newClientCodec(f(conn)), nil + return newClientCodec(f(conn), opt), nil } -func newClientCodec(cc codec.Codec) *Client { +func newClientCodec(cc codec.Codec, opt *Options) *Client { client := &Client{ cc: cc, + opt: opt, pending: make(map[uint64]*Call), } go client.receive() @@ -209,15 +223,11 @@ func newClientCodec(cc codec.Codec) *Client { // Dial connects to an RPC server at the specified network address func Dial(network, address string, opts ...*Options) (*Client, error) { - opt := defaultOptions - if len(opts) > 0 && opts[0] != nil { - opt = opts[0] - } conn, err := net.Dial(network, address) if err != nil { return nil, err } - return NewClient(conn, opt) + return NewClient(conn, opts...) } // DialHTTP connects to an HTTP RPC server at the specified network address @@ -229,10 +239,6 @@ func DialHTTP(network, address string, opts ...*Options) (*Client, error) { // DialHTTPPath connects to an HTTP RPC server // at the specified network address and path. func DialHTTPPath(network, address, path string, opts ...*Options) (*Client, error) { - opt := defaultOptions - if len(opts) > 0 && opts[0] != nil { - opt = opts[0] - } conn, err := net.Dial(network, address) if err != nil { return nil, err @@ -243,7 +249,7 @@ func DialHTTPPath(network, address, path string, opts ...*Options) (*Client, err // before switching to RPC protocol. resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) if err == nil && resp.Status == connected { - return NewClient(conn, opt) + return NewClient(conn, opts...) } if err == nil { err = errors.New("unexpected HTTP response: " + resp.Status)