From 7f9b0a092b8aaf9742b33752a20dc37318d70738 Mon Sep 17 00:00:00 2001 From: Murphy Law Date: Mon, 24 Feb 2020 00:57:32 -0800 Subject: [PATCH] zmq4: implement proper REQ and REP socket This fixes problems with REP sending messages to all connected peers rather than to the originating peer. The REQ/REP semantics now matches the recommended implementation behavior as specified in the ZMQ RFC. In particular: * REQ shall route outgoing messages to connected peers using a round-robin strategy * REQ shall block on sending, or return a suitable error, when it has no connected peers. * REP shall not block on sending Thus, this changes the behavior a bit as REP no longer blocks on Send. Fixes #70 --- czmq4_test.go | 14 +-- rep.go | 264 +++++++++++++++++++++++++++++++++++++++++++- req.go | 159 ++++++++++++++++++++++++-- zmq4_reqrep_test.go | 196 ++++++++++++++++++++++++++++++-- 4 files changed, 600 insertions(+), 33 deletions(-) diff --git a/czmq4_test.go b/czmq4_test.go index 421b1e4..c523d85 100644 --- a/czmq4_test.go +++ b/czmq4_test.go @@ -68,43 +68,43 @@ var ( { name: "tcp-creq-rep", endpoint: must(EndPoint("tcp")), - req: zmq4.NewCReq(bkg), + req1: zmq4.NewCReq(bkg), rep: zmq4.NewRep(bkg), }, { name: "tcp-req-crep", endpoint: must(EndPoint("tcp")), - req: zmq4.NewReq(bkg), + req1: zmq4.NewReq(bkg), rep: zmq4.NewCRep(bkg), }, { name: "tcp-creq-crep", endpoint: must(EndPoint("tcp")), - req: zmq4.NewCReq(bkg), + req1: zmq4.NewCReq(bkg), rep: zmq4.NewCRep(bkg), }, { name: "ipc-creq-rep", endpoint: "ipc://ipc-creq-rep", - req: zmq4.NewCReq(bkg), + req1: zmq4.NewCReq(bkg), rep: zmq4.NewRep(bkg), }, { name: "ipc-req-crep", endpoint: "ipc://ipc-req-crep", - req: zmq4.NewReq(bkg), + req1: zmq4.NewReq(bkg), rep: zmq4.NewCRep(bkg), }, { name: "ipc-creq-crep", endpoint: "ipc://ipc-creq-crep", - req: zmq4.NewCReq(bkg), + req1: zmq4.NewCReq(bkg), rep: zmq4.NewCRep(bkg), }, { name: "inproc-creq-crep", endpoint: "inproc://inproc-creq-crep", - req: zmq4.NewCReq(bkg), + req1: 
zmq4.NewCReq(bkg), rep: zmq4.NewCRep(bkg), }, } diff --git a/rep.go b/rep.go index 60b1eb7..fcc6819 100644 --- a/rep.go +++ b/rep.go @@ -7,12 +7,18 @@ package zmq4 import ( "context" "net" + "sync" + + "golang.org/x/xerrors" ) // NewRep returns a new REP ZeroMQ socket. // The returned socket value is initially unbound. func NewRep(ctx context.Context, opts ...Option) Socket { rep := &repSocket{newSocket(ctx, Rep, opts...)} + sharedState := newRepState() + rep.sck.w = newRepWriter(rep.sck.ctx, sharedState) + rep.sck.r = newRepReader(rep.sck.ctx, sharedState) return rep } @@ -29,22 +35,28 @@ func (rep *repSocket) Close() error { // Send puts the message on the outbound send queue. // Send blocks until the message can be queued or the send deadline expires. func (rep *repSocket) Send(msg Msg) error { - msg.Frames = append([][]byte{nil}, msg.Frames...) - return rep.sck.Send(msg) + ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout()) + defer cancel() + return rep.sck.w.write(ctx, msg) } // SendMulti puts the message on the outbound send queue. // SendMulti blocks until the message can be queued or the send deadline expires. // The message will be sent as a multipart message. func (rep *repSocket) SendMulti(msg Msg) error { - msg.Frames = append([][]byte{nil}, msg.Frames...) - return rep.sck.SendMulti(msg) + msg.multipart = true + ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout()) + defer cancel() + return rep.sck.w.write(ctx, msg) } // Recv receives a complete message. 
func (rep *repSocket) Recv() (Msg, error) { - msg, err := rep.sck.Recv() - if len(msg.Frames) > 1 { + ctx, cancel := context.WithCancel(rep.sck.ctx) + defer cancel() + var msg Msg + err := rep.sck.r.read(ctx, &msg) + if err == nil && len(msg.Frames) > 1 { msg.Frames = msg.Frames[1:] } return msg, err @@ -81,6 +93,246 @@ func (rep *repSocket) SetOption(name string, value interface{}) error { return rep.sck.SetOption(name, value) } +type repMsg struct { + conn *Conn + msg Msg +} + +type repReader struct { + ctx context.Context + state *repState + + mu sync.Mutex + conns []*Conn + + msgCh chan repMsg +} + +func newRepReader(ctx context.Context, state *repState) *repReader { + const qsize = 10 + return &repReader{ + ctx: ctx, + msgCh: make(chan repMsg, qsize), + state: state, + } +} + +func (r *repReader) addConn(c *Conn) { + go r.listen(r.ctx, c) + r.mu.Lock() + r.conns = append(r.conns, c) + r.mu.Unlock() +} + +func (r *repReader) rmConn(conn *Conn) { + r.mu.Lock() + defer r.mu.Unlock() + + cur := -1 + for i := range r.conns { + if r.conns[i] == conn { + cur = i + break + } + } + if cur >= 0 { + r.conns = append(r.conns[:cur], r.conns[cur+1:]...) 
+ } +} + +func (r *repReader) read(ctx context.Context, msg *Msg) error { + select { + case <-ctx.Done(): + return ctx.Err() + case repMsg := <-r.msgCh: + if repMsg.msg.err != nil { + return repMsg.msg.err + } + pre, innerMsg := splitReq(repMsg.msg) + if pre == nil { + return xerrors.Errorf("zmq4: invalid REP message") + } + *msg = innerMsg + r.state.Set(repMsg.conn, pre) + } + return nil +} + +func (r *repReader) listen(ctx context.Context, conn *Conn) { + defer r.rmConn(conn) + defer conn.Close() + + for { + msg := conn.read() + select { + case <-ctx.Done(): + return + default: + if msg.err != nil { + return + } + r.msgCh <- repMsg{conn, msg} + } + } +} + +func (r *repReader) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + var err error + for _, conn := range r.conns { + e := conn.Close() + if e != nil && err == nil { + err = e + } + } + r.conns = nil + return err +} + +func splitReq(envelope Msg) (preamble [][]byte, msg Msg) { + for i, frame := range envelope.Frames { + if len(frame) != 0 { + continue + } + preamble = envelope.Frames[:i+1] + if i+1 < len(envelope.Frames) { + msg = NewMsgFrom(envelope.Frames[i+1:]...) + } + } + return +} + +type repSendPayload struct { + conn *Conn + preamble [][]byte + msg Msg +} + +type repWriter struct { + ctx context.Context + state *repState + + mu sync.Mutex + conns []*Conn + + sendCh chan repSendPayload +} + +func (r repSendPayload) buildReplyMsg() Msg { + var frames = make([][]byte, 0, len(r.preamble)+len(r.msg.Frames)) + frames = append(frames, r.preamble...) + frames = append(frames, r.msg.Frames...) + return NewMsgFrom(frames...) 
+} + +func newRepWriter(ctx context.Context, state *repState) *repWriter { + r := &repWriter{ + ctx: ctx, + state: state, + sendCh: make(chan repSendPayload), + } + go r.run() + return r +} + +func (r *repWriter) addConn(w *Conn) { + r.mu.Lock() + r.conns = append(r.conns, w) + r.mu.Unlock() +} + +func (r *repWriter) rmConn(conn *Conn) { + r.mu.Lock() + defer r.mu.Unlock() + + cur := -1 + for i := range r.conns { + if r.conns[i] == conn { + cur = i + break + } + } + if cur >= 0 { + r.conns = append(r.conns[:cur], r.conns[cur+1:]...) + } +} + +func (r *repWriter) write(ctx context.Context, msg Msg) error { + conn, preamble := r.state.Get() + r.sendCh <- repSendPayload{conn, preamble, msg} + return nil +} + +func (r *repWriter) run() { + for { + select { + case <-r.ctx.Done(): + return + case payload, ok := <-r.sendCh: + if !ok { + return + } + r.sendPayload(payload) + } + } +} + +func (r *repWriter) sendPayload(payload repSendPayload) { + r.mu.Lock() + defer r.mu.Unlock() + for _, conn := range r.conns { + if conn == payload.conn { + reply := payload.buildReplyMsg() + // not much we can do at this point. Perhaps log the error? 
+ _ = conn.SendMsg(reply) + return + } + } +} + +func (r *repWriter) Close() error { + close(r.sendCh) + r.mu.Lock() + defer r.mu.Unlock() + + var err error + for _, conn := range r.conns { + e := conn.Close() + if e != nil && err == nil { + err = e + } + } + r.conns = nil + return err +} + +type repState struct { + mu sync.Mutex + conn *Conn + preamble [][]byte // includes delimiter +} + +func newRepState() *repState { + return &repState{} +} + +func (r *repState) Get() (conn *Conn, preamble [][]byte) { + r.mu.Lock() + conn = r.conn + preamble = r.preamble + r.mu.Unlock() + return +} + +func (r *repState) Set(conn *Conn, pre [][]byte) { + r.mu.Lock() + r.conn = conn + r.preamble = pre + r.mu.Unlock() +} + var ( _ Socket = (*repSocket)(nil) ) diff --git a/req.go b/req.go index 4457642..40b918c 100644 --- a/req.go +++ b/req.go @@ -7,18 +7,25 @@ package zmq4 import ( "context" "net" + "sync" + + "golang.org/x/xerrors" ) // NewReq returns a new REQ ZeroMQ socket. // The returned socket value is initially unbound. func NewReq(ctx context.Context, opts ...Option) Socket { - req := &reqSocket{newSocket(ctx, Req, opts...)} + state := &reqState{} + req := &reqSocket{newSocket(ctx, Req, opts...), state} + req.sck.r = newReqReader(req.sck.ctx, state) + req.sck.w = newReqWriter(req.sck.ctx, state) return req } // reqSocket is a REQ ZeroMQ socket. type reqSocket struct { - sck *socket + sck *socket + state *reqState } // Close closes the open Socket @@ -29,24 +36,27 @@ func (req *reqSocket) Close() error { // Send puts the message on the outbound send queue. // Send blocks until the message can be queued or the send deadline expires. func (req *reqSocket) Send(msg Msg) error { - msg.Frames = append([][]byte{nil}, msg.Frames...) - return req.sck.Send(msg) + ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout()) + defer cancel() + return req.sck.w.write(ctx, msg) } // SendMulti puts the message on the outbound send queue. 
// SendMulti blocks until the message can be queued or the send deadline expires. // The message will be sent as a multipart message. func (req *reqSocket) SendMulti(msg Msg) error { - msg.Frames = append([][]byte{nil}, msg.Frames...) - return req.sck.SendMulti(msg) + msg.multipart = true + ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout()) + defer cancel() + return req.sck.w.write(ctx, msg) } // Recv receives a complete message. func (req *reqSocket) Recv() (Msg, error) { - msg, err := req.sck.Recv() - if len(msg.Frames) > 1 { - msg.Frames = msg.Frames[1:] - } + ctx, cancel := context.WithCancel(req.sck.ctx) + defer cancel() + var msg Msg + err := req.sck.r.read(ctx, &msg) return msg, err } @@ -81,6 +91,135 @@ func (req *reqSocket) SetOption(name string, value interface{}) error { return req.sck.SetOption(name, value) } +type reqWriter struct { + mu sync.Mutex + conns []*Conn + nextConn int + state *reqState +} + +func newReqWriter(ctx context.Context, state *reqState) *reqWriter { + return &reqWriter{ + state: state, + } +} + +func (r *reqWriter) write(ctx context.Context, msg Msg) error { // write prepends the empty delimiter frame and sends msg to the first peer that accepts it, probing round-robin from nextConn + msg.Frames = append([][]byte{nil}, msg.Frames...) + + r.mu.Lock() + defer r.mu.Unlock() + var err error + for i := 0; i < len(r.conns); i++ { + cur := (i + r.nextConn) % len(r.conns) // wrap the whole sum: % binds tighter than +, so the unparenthesized form could index past the end of r.conns + conn := r.conns[cur] + err = conn.SendMsg(msg) + if err == nil { + r.nextConn = (cur + 1) % len(r.conns) // the next write starts at the peer after the one just used + r.state.Set(conn) + return nil + } + } + return xerrors.Errorf("zmq4: no connections available: %w", err) +} + +func (r *reqWriter) addConn(c *Conn) { + r.mu.Lock() + r.conns = append(r.conns, c) + r.mu.Unlock() +} + +func (r *reqWriter) rmConn(conn *Conn) { + r.mu.Lock() + defer r.mu.Unlock() + + cur := -1 + for i := range r.conns { + if r.conns[i] == conn { + cur = i + break + } + } + if cur >= 0 { + r.conns = append(r.conns[:cur], r.conns[cur+1:]...) 
+ } + + r.state.Reset(conn) +} + +func (r *reqWriter) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + var err error + for _, conn := range r.conns { + e := conn.Close() + if e != nil && err == nil { + err = e + } + } + r.conns = nil + return err +} + +type reqReader struct { + state *reqState +} + +func newReqReader(ctx context.Context, state *reqState) *reqReader { + return &reqReader{ + state: state, + } +} + +func (r *reqReader) addConn(c *Conn) {} +func (r *reqReader) rmConn(c *Conn) {} + +func (r *reqReader) Close() error { + return nil +} + +func (r *reqReader) read(ctx context.Context, msg *Msg) error { + curConn := r.state.Get() + if curConn == nil { + return xerrors.Errorf("zmq4: no connections available") + } + *msg = curConn.read() + if msg.err != nil { + return msg.err + } + if len(msg.Frames) > 1 { + msg.Frames = msg.Frames[1:] + } + return nil +} + +type reqState struct { + mu sync.Mutex + lastConn *Conn +} + +func (r *reqState) Set(conn *Conn) { + r.mu.Lock() + defer r.mu.Unlock() + r.lastConn = conn +} + +// Reset resets the state iff c matches the resident connection +func (r *reqState) Reset(c *Conn) { + r.mu.Lock() + defer r.mu.Unlock() + if r.lastConn == c { + r.lastConn = nil + } +} + +func (r *reqState) Get() *Conn { + r.mu.Lock() + defer r.mu.Unlock() + return r.lastConn +} + var ( _ Socket = (*reqSocket)(nil) ) diff --git a/zmq4_reqrep_test.go b/zmq4_reqrep_test.go index eb98e9f..d5858f5 100644 --- a/zmq4_reqrep_test.go +++ b/zmq4_reqrep_test.go @@ -20,19 +20,19 @@ var ( { name: "tcp-req-rep", endpoint: must(EndPoint("tcp")), - req: zmq4.NewReq(bkg), + req1: zmq4.NewReq(bkg), rep: zmq4.NewRep(bkg), }, { name: "ipc-req-rep", endpoint: "ipc://ipc-req-rep", - req: zmq4.NewReq(bkg), + req1: zmq4.NewReq(bkg), rep: zmq4.NewRep(bkg), }, { name: "inproc-req-rep", endpoint: "inproc://inproc-req-rep", - req: zmq4.NewReq(bkg), + req1: zmq4.NewReq(bkg), rep: zmq4.NewRep(bkg), }, } @@ -42,7 +42,8 @@ type testCaseReqRep struct { name string skip 
bool endpoint string - req zmq4.Socket + req1 zmq4.Socket + req2 zmq4.Socket rep zmq4.Socket } @@ -59,7 +60,7 @@ func TestReqRep(t *testing.T) { for i := range reqreps { tc := reqreps[i] t.Run(tc.name, func(t *testing.T) { - defer tc.req.Close() + defer tc.req1.Close() defer tc.rep.Close() ep := tc.endpoint @@ -73,7 +74,7 @@ func TestReqRep(t *testing.T) { ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second) defer timeout() - grp, ctx := errgroup.WithContext(ctx) + grp, _ := errgroup.WithContext(ctx) grp.Go(func() error { err := tc.rep.Listen(ep) @@ -112,12 +113,12 @@ func TestReqRep(t *testing.T) { }) grp.Go(func() error { - err := tc.req.Dial(ep) + err := tc.req1.Dial(ep) if err != nil { return xerrors.Errorf("could not dial: %w", err) } - if addr := tc.req.Addr(); addr != nil { + if addr := tc.req1.Addr(); addr != nil { return xerrors.Errorf("dialer with non-nil Addr") } @@ -129,11 +130,11 @@ func TestReqRep(t *testing.T) { {reqLang, repLang}, {reqQuit, repQuit}, } { - err = tc.req.Send(msg.req) + err = tc.req1.Send(msg.req) if err != nil { return xerrors.Errorf("could not send REQ message %v: %w", msg.req, err) } - rep, err := tc.req.Recv() + rep, err := tc.req1.Recv() if err != nil { return xerrors.Errorf("could not recv REP message %v: %w", msg.req, err) } @@ -151,3 +152,178 @@ func TestReqRep(t *testing.T) { }) } } + +func TestMultiReqRepIssue70(t *testing.T) { + var ( + reqName1 = zmq4.NewMsgString("NAME") + reqLang1 = zmq4.NewMsgString("LANG") + reqQuit1 = zmq4.NewMsgString("QUIT") + reqName2 = zmq4.NewMsgString("NAME2") + reqLang2 = zmq4.NewMsgString("LANG2") + reqQuit2 = zmq4.NewMsgString("QUIT2") + repName1 = zmq4.NewMsgString("zmq4") + repLang1 = zmq4.NewMsgString("Go") + repQuit1 = zmq4.NewMsgString("bye") + repName2 = zmq4.NewMsgString("zmq42") + repLang2 = zmq4.NewMsgString("Go2") + repQuit2 = zmq4.NewMsgString("bye2") + ) + + reqreps := []testCaseReqRep{ + { + name: "tcp-req-rep", + endpoint: must(EndPoint("tcp")), + req1: 
zmq4.NewReq(bkg), + req2: zmq4.NewReq(bkg), + rep: zmq4.NewRep(bkg), + }, + { + name: "ipc-req-rep", + endpoint: "ipc://ipc-req-rep", + req1: zmq4.NewReq(bkg), + req2: zmq4.NewReq(bkg), + rep: zmq4.NewRep(bkg), + }, + { + name: "inproc-req-rep", + endpoint: "inproc://inproc-req-rep", + req1: zmq4.NewReq(bkg), + req2: zmq4.NewReq(bkg), + rep: zmq4.NewRep(bkg), + }, + } + + for i := range reqreps { + tc := reqreps[i] + t.Run(tc.name, func(t *testing.T) { + defer tc.req1.Close() + defer tc.req2.Close() + defer tc.rep.Close() + + if tc.skip { + t.Skipf(tc.name) + } + t.Parallel() + + ep := tc.endpoint + cleanUp(ep) + + ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second) + defer timeout() + + grp, _ := errgroup.WithContext(ctx) + grp.Go(func() error { + err := tc.rep.Listen(ep) + if err != nil { + return xerrors.Errorf("could not listen: %w", err) + } + + if addr := tc.rep.Addr(); addr == nil { + return xerrors.Errorf("listener with nil Addr") + } + + loop1, loop2 := true, true + for loop1 || loop2 { + msg, err := tc.rep.Recv() + if err != nil { + return xerrors.Errorf("could not recv REQ message: %w", err) + } + var rep zmq4.Msg + switch string(msg.Frames[0]) { + case "NAME": + rep = repName1 + case "LANG": + rep = repLang1 + case "QUIT": + rep = repQuit1 + loop1 = false + case "NAME2": + rep = repName2 + case "LANG2": + rep = repLang2 + case "QUIT2": + rep = repQuit2 + loop2 = false + } + + err = tc.rep.Send(rep) + if err != nil { + return xerrors.Errorf("could not send REP message to %v: %w", msg, err) + } + } + return err + }) + grp.Go(func() error { + + err := tc.req2.Dial(ep) + if err != nil { + return xerrors.Errorf("could not dial: %w", err) + } + + if addr := tc.req2.Addr(); addr != nil { + return xerrors.Errorf("dialer with non-nil Addr") + } + + for _, msg := range []struct { + req zmq4.Msg + rep zmq4.Msg + }{ + {reqName2, repName2}, + {reqLang2, repLang2}, + {reqQuit2, repQuit2}, + } { + err = tc.req2.Send(msg.req) + if err != nil { + 
return xerrors.Errorf("could not send REQ message %v: %w", msg.req, err) + } + rep, err := tc.req2.Recv() + if err != nil { + return xerrors.Errorf("could not recv REP message %v: %w", msg.req, err) + } + + if got, want := rep, msg.rep; !reflect.DeepEqual(got, want) { + return xerrors.Errorf("got = %v, want= %v", got, want) + } + } + return err + }) + grp.Go(func() error { + + err := tc.req1.Dial(ep) + if err != nil { + return xerrors.Errorf("could not dial: %w", err) + } + + if addr := tc.req1.Addr(); addr != nil { + return xerrors.Errorf("dialer with non-nil Addr") + } + + for _, msg := range []struct { + req zmq4.Msg + rep zmq4.Msg + }{ + {reqName1, repName1}, + {reqLang1, repLang1}, + {reqQuit1, repQuit1}, + } { + err = tc.req1.Send(msg.req) + if err != nil { + return xerrors.Errorf("could not send REQ message %v: %w", msg.req, err) + } + rep, err := tc.req1.Recv() + if err != nil { + return xerrors.Errorf("could not recv REP message %v: %w", msg.req, err) + } + + if got, want := rep, msg.rep; !reflect.DeepEqual(got, want) { + return xerrors.Errorf("got = %v, want= %v", got, want) + } + } + return err + }) + if err := grp.Wait(); err != nil { + t.Fatalf("error: %+v", err) + } + }) + } +}