Skip to content

Commit

Permalink
gee-rpc: add Client struct filed opt to record options user will pass
Browse files Browse the repository at this point in the history
  • Loading branch information
geektutu committed Oct 2, 2020
1 parent b9f5ca1 commit dc840eb
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 54 deletions.
42 changes: 26 additions & 16 deletions gee-rpc/day2-client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func (call *Call) done() {
// multiple goroutines simultaneously.
type Client struct {
cc codec.Codec
opt *Options
sending sync.Mutex // protect following
header codec.Header
mu sync.Mutex // protect following
Expand Down Expand Up @@ -172,15 +173,26 @@ func (client *Client) Call(serviceMethod string, args, reply interface{}) error
return call.Error
}

func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) {
var err error
defer func() {
if err != nil {
_ = conn.Close()
}
}()
if opt.MagicNumber == 0 {
opt.MagicNumber = MagicNumber
func parseOptions(opts ...*Options) (*Options, error) {
// if opts is nil or pass nil as parameter
if len(opts) == 0 || opts[0] == nil {
return defaultOptions, nil
}
if len(opts) != 1 {
return nil, errors.New("number of options is more than 1")
}
opt := opts[0]
opt.MagicNumber = defaultOptions.MagicNumber
if opt.CodecType == "" {
opt.CodecType = defaultOptions.CodecType
}
return opt, nil
}

func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
Expand All @@ -191,14 +203,16 @@ func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) {
// send options with server
if err = json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
_ = conn.Close()
return nil, err
}
return newClientCodec(f(conn)), nil
return newClientCodec(f(conn), opt), nil
}

func newClientCodec(cc codec.Codec) *Client {
func newClientCodec(cc codec.Codec, opt *Options) *Client {
client := &Client{
cc: cc,
opt: opt,
pending: make(map[uint64]*Call),
}
go client.receive()
Expand All @@ -207,13 +221,9 @@ func newClientCodec(cc codec.Codec) *Client {

// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Options) (*Client, error) {
opt := defaultOptions
if len(opts) > 0 && opts[0] != nil {
opt = opts[0]
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClient(conn, opt)
return NewClient(conn, opts...)
}
42 changes: 26 additions & 16 deletions gee-rpc/day3-service/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func (call *Call) done() {
// multiple goroutines simultaneously.
type Client struct {
cc codec.Codec
opt *Options
sending sync.Mutex // protect following
header codec.Header
mu sync.Mutex // protect following
Expand Down Expand Up @@ -172,15 +173,26 @@ func (client *Client) Call(serviceMethod string, args, reply interface{}) error
return call.Error
}

func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) {
var err error
defer func() {
if err != nil {
_ = conn.Close()
}
}()
if opt.MagicNumber == 0 {
opt.MagicNumber = MagicNumber
func parseOptions(opts ...*Options) (*Options, error) {
// if opts is nil or pass nil as parameter
if len(opts) == 0 || opts[0] == nil {
return defaultOptions, nil
}
if len(opts) != 1 {
return nil, errors.New("number of options is more than 1")
}
opt := opts[0]
opt.MagicNumber = defaultOptions.MagicNumber
if opt.CodecType == "" {
opt.CodecType = defaultOptions.CodecType
}
return opt, nil
}

func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
Expand All @@ -191,14 +203,16 @@ func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) {
// send options with server
if err = json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
_ = conn.Close()
return nil, err
}
return newClientCodec(f(conn)), nil
return newClientCodec(f(conn), opt), nil
}

func newClientCodec(cc codec.Codec) *Client {
func newClientCodec(cc codec.Codec, opt *Options) *Client {
client := &Client{
cc: cc,
opt: opt,
pending: make(map[uint64]*Call),
}
go client.receive()
Expand All @@ -207,13 +221,9 @@ func newClientCodec(cc codec.Codec) *Client {

// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Options) (*Client, error) {
opt := defaultOptions
if len(opts) > 0 && opts[0] != nil {
opt = opts[0]
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClient(conn, opt)
return NewClient(conn, opts...)
}
50 changes: 28 additions & 22 deletions gee-rpc/day4-http-debug/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (

// Call represents an active RPC.
type Call struct {
ServiceMethod string // format "<service>.<Method>"
ServiceMethod string // format "<service>.<method>"
Args interface{} // arguments to the function
Reply interface{} // reply from the function
Error error // if error occurs, it will be set
Expand All @@ -36,6 +36,7 @@ func (call *Call) done() {
// multiple goroutines simultaneously.
type Client struct {
cc codec.Codec
opt *Options
sending sync.Mutex // protect following
header codec.Header
mu sync.Mutex // protect following
Expand Down Expand Up @@ -174,15 +175,26 @@ func (client *Client) Call(serviceMethod string, args, reply interface{}) error
return call.Error
}

func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) {
var err error
defer func() {
if err != nil {
_ = conn.Close()
}
}()
if opt.MagicNumber == 0 {
opt.MagicNumber = MagicNumber
func parseOptions(opts ...*Options) (*Options, error) {
// if opts is nil or pass nil as parameter
if len(opts) == 0 || opts[0] == nil {
return defaultOptions, nil
}
if len(opts) != 1 {
return nil, errors.New("number of options is more than 1")
}
opt := opts[0]
opt.MagicNumber = defaultOptions.MagicNumber
if opt.CodecType == "" {
opt.CodecType = defaultOptions.CodecType
}
return opt, nil
}

func NewClient(conn io.ReadWriteCloser, opts ...*Options) (*Client, error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
Expand All @@ -193,14 +205,16 @@ func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) {
// send options with server
if err = json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
_ = conn.Close()
return nil, err
}
return newClientCodec(f(conn)), nil
return newClientCodec(f(conn), opt), nil
}

func newClientCodec(cc codec.Codec) *Client {
func newClientCodec(cc codec.Codec, opt *Options) *Client {
client := &Client{
cc: cc,
opt: opt,
pending: make(map[uint64]*Call),
}
go client.receive()
Expand All @@ -209,15 +223,11 @@ func newClientCodec(cc codec.Codec) *Client {

// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Options) (*Client, error) {
opt := defaultOptions
if len(opts) > 0 && opts[0] != nil {
opt = opts[0]
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClient(conn, opt)
return NewClient(conn, opts...)
}

// DialHTTP connects to an HTTP RPC server at the specified network address
Expand All @@ -229,10 +239,6 @@ func DialHTTP(network, address string, opts ...*Options) (*Client, error) {
// DialHTTPPath connects to an HTTP RPC server
// at the specified network address and path.
func DialHTTPPath(network, address, path string, opts ...*Options) (*Client, error) {
opt := defaultOptions
if len(opts) > 0 && opts[0] != nil {
opt = opts[0]
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
Expand All @@ -243,7 +249,7 @@ func DialHTTPPath(network, address, path string, opts ...*Options) (*Client, err
// before switching to RPC protocol.
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
if err == nil && resp.Status == connected {
return NewClient(conn, opt)
return NewClient(conn, opts...)
}
if err == nil {
err = errors.New("unexpected HTTP response: " + resp.Status)
Expand Down

0 comments on commit dc840eb

Please sign in to comment.