From 215aefac3f641dcc1ba06b63ee30d89606f637a1 Mon Sep 17 00:00:00 2001 From: Brian Goff Date: Wed, 27 Jan 2021 18:57:09 +0000 Subject: [PATCH] Add support for fd passing This adds a new message type for passing file descriptors. How this works is: 1. Client sends a message with a header for messageTypeFileDescriptor along with the list of descriptors to be sent 2. Client sends 2nd message to actually pass along the descriptors (needed for unix sockets). 3. Server sees the message type and waits to receive the fd's. 4. Once fd's are seen the server responds with the real fd numbers that are used which an application can use in future calls. To accomplish this reliably (on unix sockets) I had to drop the usage of the bufio.Reader because we need to ensure exact message boundaries. Within ttrpc this only support unix sockets and `net.Conn` implementations that implement `SendFds`/`ReceiveFds` (this interface is totally invented here). Something to consider, I have not attempted to do fd passing on Windows which will need other mechanisms entirely (and the conn's provided by winio are not sufficient for fd passing). I'm not sure if this new messaging will actually work on a Windows implementation. Perhaps the message tpye should be specifically for unix sockets? I'm not sure how this would be enforced at the moment except by checking if the `net.Conn` is a `*net.UnixConn`. Signed-off-by: Brian Goff --- channel.go | 94 +++++++++++++++++++++++++++++++++++-- client.go | 79 ++++++++++++++++++++++++++----- fd_test.go | 125 +++++++++++++++++++++++++++++++++++++++++++++++++ server.go | 80 ++++++++++++++++++++++--------- server_test.go | 2 + types.go | 23 +++++++++ 6 files changed, 366 insertions(+), 37 deletions(-) create mode 100644 fd_test.go diff --git a/channel.go b/channel.go index aa8c9541c..ff627f944 100644 --- a/channel.go +++ b/channel.go @@ -19,11 +19,13 @@ package ttrpc import ( "bufio" "encoding/binary" + "fmt" "io" "net" "sync" "github.com/pkg/errors" + "golang.org/x/sys/unix" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -36,8 +38,9 @@ const ( type messageType uint8 const ( - messageTypeRequest messageType = 0x1 - messageTypeResponse messageType = 0x2 + messageTypeRequest messageType = 0x1 + messageTypeResponse messageType = 0x2 + messageTypeFileDescriptor messageType = 0x3 ) // messageHeader represents the fixed-length message header of 10 bytes sent @@ -98,7 +101,7 @@ func newChannel(conn net.Conn) *channel { // the correct consumer. The bytes on the underlying channel // will be discarded. func (ch *channel) recv() (messageHeader, []byte, error) { - mh, err := readMessageHeader(ch.hrbuf[:], ch.br) + mh, err := readMessageHeader(ch.hrbuf[:messageHeaderLength], ch.conn) if err != nil { return messageHeader{}, nil, err } @@ -112,13 +115,78 @@ func (ch *channel) recv() (messageHeader, []byte, error) { } p := ch.getmbuf(int(mh.Length)) - if _, err := io.ReadFull(ch.br, p); err != nil { + if _, err := io.ReadFull(ch.conn, p[:int(mh.Length)]); err != nil { return messageHeader{}, nil, errors.Wrapf(err, "failed reading message") } return mh, p, nil } +func (ch *channel) recvFD(files *FileList) error { + var ( + fds []int + err error + ) + + switch t := ch.conn.(type) { + case FdReceiver: + fds, err = t.ReceiveFds() + if err != nil { + return err + } + case unixReader: + oob := ch.getmbuf(unix.CmsgSpace(len(files.List) * 4)) + defer ch.putmbuf(oob) + + _, oobn, _, _, err := t.ReadMsgUnix(make([]byte, 1), oob) + if err != nil { + return err + } + + ls, err := unix.ParseSocketControlMessage(oob[:oobn]) + if err != nil { + return fmt.Errorf("error parsing socket controll message: %w", err) + } + + for _, m := range ls { + fdsTemp, err := unix.ParseUnixRights(&m) + if err != nil { + return fmt.Errorf("error parsing unix rights message: %w", err) + } + fds = append(fds, fdsTemp...) + } + default: + return fmt.Errorf("receiving file descriptors is not supported on the transport") + } + + if len(files.List) != len(fds) { + return fmt.Errorf("received %d file descriptors, expected %d", len(fds), len(files.List)) + } + for i, fd := range fds { + files.List[i].Fileno = int64(fd) + } + return nil +} + +func (ch *channel) sendFd(streamID uint32, mt messageType, files *FileList) error { + fds := make([]int, len(files.List)) + + for i, f := range files.List { + fds[i] = int(f.Fileno) + } + + switch t := ch.conn.(type) { + case unixWriter: + // Must send at least a single byte over unix sockets for the ancillary data to be accepted. + _, _, err := t.WriteMsgUnix(make([]byte, 1), unix.UnixRights(fds...), nil) + return err + case FdSender: + return t.SendFds(fds) + default: + return fmt.Errorf("sending file descriptors is not supported on the transport") + } +} + func (ch *channel) send(streamID uint32, t messageType, p []byte) error { if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil { return err @@ -151,3 +219,21 @@ func (ch *channel) getmbuf(size int) []byte { func (ch *channel) putmbuf(p []byte) { buffers.Put(&p) } + +// FdReceiver is an interface used that the transport may implement to receive file descriptors from the client +type FdReceiver interface { + ReceiveFds() ([]int, error) +} + +// FdSender is an interface used that the transport may implement to send file descriptors to the server. +type FdSender interface { + SendFds([]int) error +} + +type unixReader interface { + ReadMsgUnix(p, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error) +} + +type unixWriter interface { + WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error) +} diff --git a/client.go b/client.go index 30c9b73f3..c1a8eeea4 100644 --- a/client.go +++ b/client.go @@ -96,10 +96,42 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client { } type callRequest struct { - ctx context.Context - req *Request - resp *Response // response will be written back here - errs chan error // error written here on completion + ctx context.Context + req *Request + resp *Response // response will be written back here + errs chan error // error written here on completion + files *FileList +} + +func (c *Client) Sendfd(ctx context.Context, files []*os.File) ([]int64, error) { + ls := make([]*File, len(files)) + for i, f := range files { + ls[i] = &File{ + Name: f.Name(), + Fileno: int64(f.Fd()), + } + } + + resp := &Response{} + fl := &FileList{List: ls} + if err := c.dispatch(ctx, nil, resp, fl); err != nil { + return nil, err + } + if resp.Status != nil && resp.Status.Code != int32(codes.OK) { + return nil, status.ErrorProto(resp.Status) + } + + fl.Reset() + + if err := c.codec.Unmarshal(resp.Payload, fl); err != nil { + return nil, err + } + + fds := make([]int64, len(fl.List)) + for i, f := range fl.List { + fds[i] = f.Fileno + } + return fds, nil } func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error { @@ -129,7 +161,9 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int info := &UnaryClientInfo{ FullMethod: fullPath(service, method), } - if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil { + if err := c.interceptor(ctx, creq, cresp, info, func(ctx context.Context, req *Request, resp *Response) error { + return c.dispatch(ctx, req, resp, nil) + }); err != nil { return err } @@ -143,13 +177,14 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int return nil } -func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { +func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response, files *FileList) error { errs := make(chan error, 1) call := &callRequest{ - ctx: ctx, - req: req, - resp: resp, - errs: errs, + ctx: ctx, + req: req, + resp: resp, + errs: errs, + files: files, } select { @@ -270,13 +305,35 @@ func (c *Client) run() { for { select { case call := <-calls: - if err := c.send(streamID, messageTypeRequest, call.req); err != nil { + var ( + data interface{} + mt messageType + ) + + switch { + case call.files != nil: + data = call.files + mt = messageTypeFileDescriptor + case call.req != nil: + data = call.req + mt = messageTypeRequest + } + + if err := c.send(streamID, mt, data); err != nil { call.errs <- err continue } waiters[streamID] = call streamID += 2 // enforce odd client initiated request ids + + if call.files != nil { + if err := c.channel.sendFd(streamID, mt, call.files); err != nil { + call.errs <- err + continue + } + } + case msg := <-incoming: call, ok := waiters[msg.StreamID] if !ok { diff --git a/fd_test.go b/fd_test.go new file mode 100644 index 000000000..685f4b722 --- /dev/null +++ b/fd_test.go @@ -0,0 +1,125 @@ +package ttrpc + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "strconv" + "testing" + "time" +) + +type HandshakerFunc func(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) + +func (f HandshakerFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { + return f(ctx, conn) +} + +func TestSendRecevFd(t *testing.T) { + var ( + ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(1*time.Minute)) + server = mustServer(t)(NewServer()) + testImpl = &testingServerFd{respData: []byte("hello")} + addr, listener = newTestListener(t) + ) + + defer cancel() + defer listener.Close() + + server.Register("Test", map[string]Method{ + "Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + req := &testFdPayload{} + + if err := unmarshal(req); err != nil { + return nil, err + } + + return &testFdPayload{}, testImpl.Test(ctx, req) + }, + }) + + go server.Serve(ctx, listener) + defer server.Shutdown(ctx) + + var ( + client, cleanup = newTestClient(t, addr) + + tclient = testFdClient{client} + ) + defer cleanup() + + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err, "error creating test pipe") + } + defer r.Close() + + type readResp struct { + buf []byte + err error + } + + chResp := make(chan readResp, 1) + go func() { + buf := make([]byte, len(testImpl.respData)) + _, err := io.ReadFull(r, buf) + chResp <- readResp{buf, err} + }() + + if err := tclient.Test(ctx, w); err != nil { + t.Fatal(err) + } + + select { + case <-ctx.Done(): + t.Fatal(ctx.Err()) + case resp := <-chResp: + if resp.err != nil { + t.Error(err) + } + if !bytes.Equal(resp.buf, testImpl.respData) { + t.Fatalf("got unexpected respone data, exepcted %q, got %q", string(testImpl.respData), string(resp.buf)) + } + } +} + +type testFdPayload struct { + Fds []int64 `protobuf:"varint,1,opt,name=fds,proto3"` +} + +func (r *testFdPayload) Reset() { *r = testFdPayload{} } +func (r *testFdPayload) String() string { return fmt.Sprintf("%+#v", r) } +func (r *testFdPayload) ProtoMessage() {} + +type testingServerFd struct { + respData []byte +} + +func (s *testingServerFd) Test(ctx context.Context, req *testFdPayload) error { + for i, fd := range req.Fds { + f := os.NewFile(uintptr(fd), "TEST_FILE_"+strconv.Itoa(i)) + go func() { + f.Write(s.respData) + f.Close() + }() + } + + return nil +} + +type testFdClient struct { + client *Client +} + +func (c *testFdClient) Test(ctx context.Context, files ...*os.File) error { + fds, err := c.client.Sendfd(ctx, files) + if err != nil { + return fmt.Errorf("error sending fds: %w", err) + } + + tp := testFdPayload{} + return c.client.Call(ctx, "Test", "Test", &testFdPayload{Fds: fds}, &tp) +} diff --git a/server.go b/server.go index c18b4e43b..ae9f5f134 100644 --- a/server.go +++ b/server.go @@ -302,8 +302,9 @@ func (c *serverConn) close() error { func (c *serverConn) run(sctx context.Context) { type ( request struct { - id uint32 - req *Request + id uint32 + req *Request + files *FileList } response struct { @@ -376,20 +377,42 @@ func (c *serverConn) run(sctx context.Context) { continue } - if mh.Type != messageTypeRequest { - // we must ignore this for future compat. - continue - } - - var req Request - if err := c.server.codec.Unmarshal(p, &req); err != nil { + var ( + files *FileList + req *Request + ) + switch mh.Type { + case messageTypeRequest: + req = &Request{} + if err := c.server.codec.Unmarshal(p, req); err != nil { + ch.putmbuf(p) + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { + return + } + continue + } ch.putmbuf(p) - if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { - return + case messageTypeFileDescriptor: + files = &FileList{} + if err := c.server.codec.Unmarshal(p, files); err != nil { + ch.putmbuf(p) + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal file list error: %v", err)) { + return + } + continue } + + if err := ch.recvFD(files); err != nil { + ch.putmbuf(p) + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal file list error: %v", err)) { + return + } + } + ch.putmbuf(p) + default: + // Ignore other types for future compatability continue } - ch.putmbuf(p) if mh.StreamID%2 != 1 { // enforce odd client initiated identifiers. @@ -403,8 +426,9 @@ func (c *serverConn) run(sctx context.Context) { // because we have already accepted the client request. select { case requests <- request{ - id: mh.StreamID, - req: &req, + id: mh.StreamID, + req: req, + files: files, }: case <-done: return @@ -432,19 +456,31 @@ func (c *serverConn) run(sctx context.Context) { case request := <-requests: active++ go func(id uint32) { - ctx, cancel := getRequestContext(ctx, request.req) - defer cancel() - - p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) - resp := &Response{ - Status: status.Proto(), - Payload: p, + var resp Response + + switch { + case request.files != nil: + p, err := c.server.codec.Marshal(request.files) + if err != nil { + s, _ := status.FromError(err) + resp.Status = s.Proto() + } + resp.Payload = p + case request.req != nil: + ctx, cancel := getRequestContext(ctx, request.req) + defer cancel() + + p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) + resp.Status = status.Proto() + resp.Payload = p + default: + resp.Status = status.New(codes.Internal, "unknown request type").Proto() } select { case responses <- response{ id: id, - resp: resp, + resp: &resp, }: case <-done: } diff --git a/server_test.go b/server_test.go index 8094ca920..38aa128b2 100644 --- a/server_test.go +++ b/server_test.go @@ -103,6 +103,8 @@ func init() { proto.RegisterType((*testPayload)(nil), "testPayload") proto.RegisterType((*Request)(nil), "Request") proto.RegisterType((*Response)(nil), "Response") + proto.RegisterType((*FileList)(nil), "FileList") + proto.RegisterType((*testFdPayload)(nil), "testFdPayload") } func TestServer(t *testing.T) { diff --git a/types.go b/types.go index 9a1c19a72..82038c6fb 100644 --- a/types.go +++ b/types.go @@ -28,6 +28,7 @@ type Request struct { Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` Metadata []*KeyValue `protobuf:"bytes,5,rep,name=metadata,proto3"` + // Files *FileList `protobuf:"bytes,6,rep,name=files,proto3"` } func (r *Request) Reset() { *r = Request{} } @@ -61,3 +62,25 @@ type KeyValue struct { func (m *KeyValue) Reset() { *m = KeyValue{} } func (*KeyValue) ProtoMessage() {} func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) } + +type FileList struct { + List []*File `protobuf:"bytes,1,rep,name=list,proto3"` +} + +func (r *FileList) Reset() { *r = FileList{} } +func (r *FileList) String() string { return fmt.Sprintf("%+#v", r) } +func (r *FileList) ProtoMessage() {} + +// File represents a file descriptor that is transferred. +// Once the file descriptor is passed, the server should be able to call `os.NewFile(f.Fileno, f.Name)` +// The Fileno field will be filled in by the server side after the descriptor has been passed. +type File struct { + // Name is the name to be used when accepting the file descriptor + Name string `protobuf:"bytes,1,opt,name=name,proto3"` + // Fileno is the file descriptor id/pointer + Fileno int64 `protobuf:"varint,2,opt,name=timeout_nano,proto3"` +} + +func (r *File) Reset() { *r = File{} } +func (r *File) String() string { return fmt.Sprintf("%+#v", r) } +func (r *File) ProtoMessage() {}