diff --git a/gee-rpc/day2-client/client.go b/gee-rpc/day2-client/client.go index 3fd885e..83119f0 100644 --- a/gee-rpc/day2-client/client.go +++ b/gee-rpc/day2-client/client.go @@ -205,16 +205,15 @@ func newClientCodec(cc codec.Codec) *Client { return client } -// DialWithOptions connects to an RPC server at the specified network address -func DialWithOptions(network, address string, opt *Options) (*Client, error) { +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Options) (*Client, error) { + opt := defaultOptions + if len(opts) > 0 && opts[0] != nil { + opt = opts[0] + } conn, err := net.Dial(network, address) if err != nil { return nil, err } return NewClient(conn, opt) } - -// Dial connects to an RPC server at the specified network address -func Dial(network, address string) (*Client, error) { - return DialWithOptions(network, address, defaultOptions) -} diff --git a/gee-rpc/day3-service/client.go b/gee-rpc/day3-service/client.go index 3fd885e..83119f0 100644 --- a/gee-rpc/day3-service/client.go +++ b/gee-rpc/day3-service/client.go @@ -205,16 +205,15 @@ func newClientCodec(cc codec.Codec) *Client { return client } -// DialWithOptions connects to an RPC server at the specified network address -func DialWithOptions(network, address string, opt *Options) (*Client, error) { +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Options) (*Client, error) { + opt := defaultOptions + if len(opts) > 0 && opts[0] != nil { + opt = opts[0] + } conn, err := net.Dial(network, address) if err != nil { return nil, err } return NewClient(conn, opt) } - -// Dial connects to an RPC server at the specified network address -func Dial(network, address string) (*Client, error) { - return DialWithOptions(network, address, defaultOptions) -} diff --git a/gee-rpc/day3-service/service.go b/gee-rpc/day3-service/service.go index fcdfe8d..e0711cd 100644 --- a/gee-rpc/day3-service/service.go +++ b/gee-rpc/day3-service/service.go @@ -8,31 +8,31 @@ import ( ) type methodType struct { - method reflect.Method - argType reflect.Type - replyType reflect.Type - numCalls uint64 + Method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + NumCalls uint64 } func (m *methodType) newArgv() reflect.Value { var argv reflect.Value // arg may be a pointer type, or a value type - if m.argType.Kind() == reflect.Ptr { - argv = reflect.New(m.argType.Elem()) + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) } else { - argv = reflect.New(m.argType).Elem() + argv = reflect.New(m.ArgType).Elem() } return argv } func (m *methodType) newReplyv() reflect.Value { // reply must be a pointer type - replyv := reflect.New(m.replyType.Elem()) - switch m.replyType.Elem().Kind() { + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { case reflect.Map: - replyv.Elem().Set(reflect.MakeMap(m.replyType.Elem())) + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) case reflect.Slice: - replyv.Elem().Set(reflect.MakeSlice(m.replyType.Elem(), 0, 0)) + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) } return replyv } @@ -72,17 +72,17 @@ func (s *service) registerMethods() { continue } s.method[method.Name] = &methodType{ - method: method, - argType: argType, - replyType: replyType, + Method: method, + ArgType: argType, + ReplyType: replyType, } log.Printf("rpc server: register %s.%s\n", s.name, method.Name) } } func (s *service) call(m *methodType, argv, replyv reflect.Value) error { - atomic.AddUint64(&m.numCalls, 1) - f := m.method.Func + atomic.AddUint64(&m.NumCalls, 1) + f := m.Method.Func returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) if errInter := returnValues[0].Interface(); errInter != nil { return errInter.(error) diff --git a/gee-rpc/day4-http-debug/client.go b/gee-rpc/day4-http-debug/client.go new file mode 100644 index 0000000..3b53c16 --- /dev/null +++ b/gee-rpc/day4-http-debug/client.go @@ -0,0 +1,254 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "sync" +) + +// Call represents an active RPC. +type Call struct { + ServiceMethod string // format "." + Args interface{} // arguments to the function + Reply interface{} // reply from the function + Error error // if error occurs, it will be set + Done chan *Call // Strobes when call is complete. +} + +func (call *Call) done() { + call.Done <- call +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + cc codec.Codec + sending sync.Mutex // protect following + header codec.Header + mu sync.Mutex // protect following + seq uint64 + pending map[uint64]*Call + closed bool // user has called Close +} + +var _ io.Closer = (*Client)(nil) + +var ErrShutdown = errors.New("connection is shut down") + +// Close the connection +func (client *Client) Close() error { + client.mu.Lock() + defer client.mu.Unlock() + if client.closed { + return ErrShutdown + } + client.closed = true + return client.cc.Close() +} + +func (client *Client) registerCall(call *Call) (uint64, error) { + client.mu.Lock() + defer client.mu.Unlock() + if client.closed { + return 0, ErrShutdown + } + seq := client.seq + client.pending[seq] = call + client.seq++ + return seq, nil +} + +func (client *Client) removeCall(seq uint64) *Call { + client.mu.Lock() + defer client.mu.Unlock() + call := client.pending[seq] + delete(client.pending, seq) + return call +} + +func (client *Client) terminateCalls(err error) { + client.sending.Lock() + defer client.sending.Unlock() + client.mu.Lock() + defer client.mu.Unlock() + for _, call := range client.pending { + call.Error = err + call.done() + } +} + +func (client *Client) send(call *Call) { + // make sure that the client will send a complete request + client.sending.Lock() + defer client.sending.Unlock() + + // register this call. + seq, err := client.registerCall(call) + if err != nil { + call.Error = err + call.done() + return + } + + // prepare request header + client.header.ServiceMethod = call.ServiceMethod + client.header.Seq = seq + client.header.Error = "" + + // encode and send the request + if err := client.cc.Write(&client.header, call.Args); err != nil { + call := client.removeCall(seq) + // call may be nil, it usually means that Write partially failed, + // client has received the response and handled + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) receive() { + var err error + for err == nil { + var h codec.Header + if err = client.cc.ReadHeader(&h); err != nil { + break + } + call := client.removeCall(h.Seq) + switch { + case call == nil: + // it usually means that Write partially failed + // and call was already removed. + err = client.cc.ReadBody(nil) + case h.Error != "": + call.Error = fmt.Errorf(h.Error) + err = client.cc.ReadBody(nil) + call.done() + default: + err = client.cc.ReadBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // error occurs, so terminateCalls pending calls + client.terminateCalls(err) +} + +// Go invokes the function asynchronously. +// It returns the Call structure representing the invocation. +func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { + if done == nil { + done = make(chan *Call, 10) + } else if cap(done) == 0 { + log.Panic("rpc client: done channel is unbuffered") + } + call := &Call{ + ServiceMethod: serviceMethod, + Args: args, + Reply: reply, + Done: done, + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, +// and returns its error status. +func (client *Client) Call(serviceMethod string, args, reply interface{}) error { + call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done + return call.Error +} + +func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) { + var err error + defer func() { + if err != nil { + _ = conn.Close() + } + }() + if opt.MagicNumber == 0 { + opt.MagicNumber = MagicNumber + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + err = fmt.Errorf("invalid codec type %s", opt.CodecType) + log.Println("rpc client: codec error:", err) + return nil, err + } + // send options with server + if err = json.NewEncoder(conn).Encode(opt); err != nil { + log.Println("rpc client: options error: ", err) + return nil, err + } + return newClientCodec(f(conn)), nil +} + +func newClientCodec(cc codec.Codec) *Client { + client := &Client{ + cc: cc, + pending: make(map[uint64]*Call), + } + go client.receive() + return client +} + +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Options) (*Client, error) { + opt := defaultOptions + if len(opts) > 0 && opts[0] != nil { + opt = opts[0] + } + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return NewClient(conn, opt) +} + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string, opts ...*Options) (*Client, error) { + return DialHTTPPath(network, address, defaultRPCPath, opts...) +} + +// DialHTTPPath connects to an HTTP RPC server +// at the specified network address and path. +func DialHTTPPath(network, address, path string, opts ...*Options) (*Client, error) { + opt := defaultOptions + if len(opts) > 0 && opts[0] != nil { + opt = opts[0] + } + var err error + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", path)) + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err == nil && resp.Status == connected { + return NewClient(conn, opt) + } + if err == nil { + err = errors.New("unexpected HTTP response: " + resp.Status) + } + _ = conn.Close() + return nil, err +} diff --git a/gee-rpc/day4-http-debug/codec/codec.go b/gee-rpc/day4-http-debug/codec/codec.go new file mode 100644 index 0000000..ba28fba --- /dev/null +++ b/gee-rpc/day4-http-debug/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import ( + "io" +) + +type Header struct { + ServiceMethod string // format "Service.Method" + Seq uint64 // sequence number chosen by client + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(io.ReadWriteCloser) Codec + +type Type string + +const ( + GobType Type = "application/gob" + JsonType Type = "application/json" +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec +} diff --git a/gee-rpc/day4-http-debug/codec/gob.go b/gee-rpc/day4-http-debug/codec/gob.go new file mode 100644 index 0000000..808d97b --- /dev/null +++ b/gee-rpc/day4-http-debug/codec/gob.go @@ -0,0 +1,57 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "io" + "log" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + if err := c.enc.Encode(h); err != nil { + log.Println("rpc: gob error encoding header:", err) + return err + } + if err := c.enc.Encode(body); err != nil { + log.Println("rpc: gob error encoding body:", err) + return err + } + return nil +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/gee-rpc/day4-http-debug/debug.go b/gee-rpc/day4-http-debug/debug.go new file mode 100644 index 0000000..d76de55 --- /dev/null +++ b/gee-rpc/day4-http-debug/debug.go @@ -0,0 +1,60 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "fmt" + "html/template" + "net/http" +) + +const debugText = ` + + GeeRPC Services + {{range .}} +
+ Service {{.Name}} +
+ + + {{range $name, $mtype := .Method}} + + + + + {{end}} +
MethodCalls
{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error{{$mtype.NumCalls}}
+ {{end}} + + ` + +var debug = template.Must(template.New("RPC debug").Parse(debugText)) + +type debugHTTP struct { + *Server +} + +type debugService struct { + Name string + Method map[string]*methodType +} + +// Runs at /debug/rpc +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Build a sorted version of the data. + var services []debugService + server.serviceMap.Range(func(namei, svci interface{}) bool { + svc := svci.(*service) + services = append(services, debugService{ + Name: namei.(string), + Method: svc.method, + }) + return true + }) + err := debug.Execute(w, services) + if err != nil { + _, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error()) + } +} diff --git a/gee-rpc/day4-http-debug/go.mod b/gee-rpc/day4-http-debug/go.mod new file mode 100644 index 0000000..0ec8aeb --- /dev/null +++ b/gee-rpc/day4-http-debug/go.mod @@ -0,0 +1,3 @@ +module geerpc + +go 1.13 diff --git a/gee-rpc/day4-http-debug/main/main.go b/gee-rpc/day4-http-debug/main/main.go new file mode 100644 index 0000000..13bb592 --- /dev/null +++ b/gee-rpc/day4-http-debug/main/main.go @@ -0,0 +1,53 @@ +package main + +import ( + "geerpc" + "log" + "net/http" + "sync" + "time" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func startServer(addr string) { + var foo Foo + _ = geerpc.Register(&foo) + geerpc.HandleHTTP() + log.Fatal(http.ListenAndServe(addr, nil)) +} + +func call() { + // start server may cost some time + time.Sleep(time.Second) + client, _ := geerpc.DialHTTP("tcp", ":9999") + defer func() { _ = client.Close() }() + + // send request & receive response + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + args := &Args{Num1: i, Num2: i * i} + var reply int + if err := client.Call("Foo.Sum", args, &reply); err != nil { + log.Fatal("call Foo.Sum error:", err) + } + log.Printf("%d + %d = %d", args.Num1, args.Num2, reply) + }(i) + } + wg.Wait() +} + +func main() { + go call() + startServer(":9999") +} diff --git a/gee-rpc/day4-http-debug/server.go b/gee-rpc/day4-http-debug/server.go new file mode 100644 index 0000000..ffb85ee --- /dev/null +++ b/gee-rpc/day4-http-debug/server.go @@ -0,0 +1,238 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package geerpc + +import ( + "encoding/json" + "errors" + "geerpc/codec" + "io" + "log" + "net" + "net/http" + "reflect" + "strings" + "sync" +) + +const MagicNumber = 0x3bef5c + +type Options struct { + MagicNumber int // MagicNumber marks this's a geerpc request + CodecType codec.Type // client may choose different Codec to encode body +} + +var defaultOptions = &Options{ + MagicNumber: MagicNumber, + CodecType: codec.GobType, +} + +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { _ = conn.Close() }() + var opt Options + if err := json.NewDecoder(conn).Decode(&opt); err != nil { + log.Println("rpc server: options error: ", err) + return + } + if opt.MagicNumber != MagicNumber { + log.Printf("rpc server: invalid magic number %x", opt.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opt.CodecType] + if f == nil { + log.Printf("rpc server: invalid codec type %s", opt.CodecType) + return + } + server.serveCodec(f(conn)) +} + +// invalidRequest is a placeholder for response argv when error occurs +var invalidRequest = struct{}{} + +func (server *Server) serveCodec(cc codec.Codec) { + sending := new(sync.Mutex) // make sure to send a complete response + wg := new(sync.WaitGroup) // wait until all request are handled + for { + req, err := server.readRequest(cc) + if err != nil { + if req == nil { + break // it's not possible to recover, so close the connection + } + req.h.Error = err.Error() + server.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go server.handleRequest(cc, req, sending, wg) + } + wg.Wait() + _ = cc.Close() +} + +// request stores all information of a call +type request struct { + h *codec.Header // header of request + argv, replyv reflect.Value // argv and replyv of request + mtype *methodType + svc *service +} + +func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} + +func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { + dot := strings.LastIndex(serviceMethod, ".") + if dot < 0 { + err = errors.New("rpc server: service/Method request ill-formed: " + serviceMethod) + return + } + serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc server: can't find service " + serviceName) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc server: can't find Method " + methodName) + } + return +} + +func (server *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := server.readRequestHeader(cc) + if err != nil { + return nil, err + } + req := &request{h: h} + req.svc, req.mtype, err = server.findService(h.ServiceMethod) + if err != nil { + return req, err + } + req.argv = req.mtype.newArgv() + req.replyv = req.mtype.newReplyv() + + // make sure that argvi is a pointer, ReadBody need a pointer as parameter + argvi := req.argv.Interface() + if req.argv.Type().Kind() != reflect.Ptr { + argvi = req.argv.Addr().Interface() + } + if err = cc.ReadBody(argvi); err != nil { + log.Println("rpc server: read body err:", err) + return req, err + } + return req, nil +} + +func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + if err := cc.Write(h, body); err != nil { + log.Println("rpc server: write response error:", err) + } +} + +func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) { + defer wg.Done() + err := req.svc.call(req.mtype, req.argv, req.replyv) + if err != nil { + req.h.Error = err.Error() + } + server.sendResponse(cc, req.h, req.replyv.Interface(), sending) +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Println("rpc server: accept error:", err) + return + } + go server.ServeConn(conn) + } +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported Method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +func (server *Server) Register(rcvr interface{}) error { + s := newService(rcvr) + if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { + return errors.New("rpc: service already defined: " + s.name) + } + return nil +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } + +const ( + connected = "200 Connected to Gee RPC" + defaultRPCPath = "/_geeprc_" + defaultDebugPath = "/debug/geerpc" +) + +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = io.WriteString(w, "405 must CONNECT\n") + return + } + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) + return + } + _, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath, +// and a debugging handler on debugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP(rpcPath, debugPath string) { + http.Handle(rpcPath, server) + http.Handle(debugPath, debugHTTP{server}) + log.Println("rpc server debug path:", debugPath) +} + +func HandleHTTP() { + DefaultServer.HandleHTTP(defaultRPCPath, defaultDebugPath) +} diff --git a/gee-rpc/day4-http-debug/service.go b/gee-rpc/day4-http-debug/service.go new file mode 100644 index 0000000..e0711cd --- /dev/null +++ b/gee-rpc/day4-http-debug/service.go @@ -0,0 +1,95 @@ +package geerpc + +import ( + "go/ast" + "log" + "reflect" + "sync/atomic" +) + +type methodType struct { + Method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + NumCalls uint64 +} + +func (m *methodType) newArgv() reflect.Value { + var argv reflect.Value + // arg may be a pointer type, or a value type + if m.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(m.ArgType.Elem()) + } else { + argv = reflect.New(m.ArgType).Elem() + } + return argv +} + +func (m *methodType) newReplyv() reflect.Value { + // reply must be a pointer type + replyv := reflect.New(m.ReplyType.Elem()) + switch m.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) + } + return replyv +} + +type service struct { + name string + typ reflect.Type + rcvr reflect.Value + method map[string]*methodType +} + +func newService(rcvr interface{}) *service { + s := new(service) + s.rcvr = reflect.ValueOf(rcvr) + s.name = reflect.Indirect(s.rcvr).Type().Name() + s.typ = reflect.TypeOf(rcvr) + if !ast.IsExported(s.name) { + log.Fatalf("rpc server: %s is not a valid service name", s.name) + } + s.registerMethods() + return s +} + +func (s *service) registerMethods() { + s.method = make(map[string]*methodType) + for i := 0; i < s.typ.NumMethod(); i++ { + method := s.typ.Method(i) + mType := method.Type + if mType.NumIn() != 3 || mType.NumOut() != 1 { + continue + } + if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + continue + } + argType, replyType := mType.In(1), mType.In(2) + if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { + continue + } + s.method[method.Name] = &methodType{ + Method: method, + ArgType: argType, + ReplyType: replyType, + } + log.Printf("rpc server: register %s.%s\n", s.name, method.Name) + } +} + +func (s *service) call(m *methodType, argv, replyv reflect.Value) error { + atomic.AddUint64(&m.NumCalls, 1) + f := m.Method.Func + returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) + if errInter := returnValues[0].Interface(); errInter != nil { + return errInter.(error) + } + return nil +} + +func isExportedOrBuiltinType(t reflect.Type) bool { + return ast.IsExported(t.Name()) || t.PkgPath() == "" +} diff --git a/gee-rpc/day4-http-debug/service_test.go b/gee-rpc/day4-http-debug/service_test.go new file mode 100644 index 0000000..7ba2786 --- /dev/null +++ b/gee-rpc/day4-http-debug/service_test.go @@ -0,0 +1,48 @@ +package geerpc + +import ( + "fmt" + "reflect" + "testing" +) + +type Foo int + +type Args struct{ Num1, Num2 int } + +func (f Foo) Sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +// it's not a exported Method +func (f Foo) sum(args Args, reply *int) error { + *reply = args.Num1 + args.Num2 + return nil +} + +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func TestNewService(t *testing.T) { + var foo Foo + s := newService(&foo) + _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method)) + mType := s.method["Sum"] + _assert(mType != nil, "wrong Method, Sum shouldn't nil") +} + +func TestMethodType_Call(t *testing.T) { + var foo Foo + s := newService(&foo) + mType := s.method["Sum"] + + argv := mType.newArgv() + replyv := mType.newReplyv() + argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3})) + err := s.call(mType, argv, replyv) + _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls == 1, "failed to call Foo.Sum") +}