From a19e82d0772436e91e681c483496164dfc845002 Mon Sep 17 00:00:00 2001 From: Brian Goff Date: Wed, 27 Jan 2021 18:57:09 +0000 Subject: [PATCH] Add support for fd passing This adds a new message type for passing file descriptors. How this works is: 1. Client sends a message with a header for messageTypeFileDescriptor along with the list of descriptors to be sent 2. Client sends 2nd message to actually pass along the descriptors (needed for unix sockets). 3. Server sees the message type and waits to receive the fd's. 4. Once fd's are seen the server responds with the real fd numbers that are used which an application can use in future calls. To accomplish this reliably (on unix sockets) I had to drop the usage of the bufio.Reader because we need to ensure exact message boundaries. Within ttrpc this only support unix sockets and `net.Conn` implementations that implement `SendFds`/`ReceiveFds` (this interface is totally invented here). Something to consider, I have not attempted to do fd passing on Windows which will need other mechanisms entirely (and the conn's provided by winio are not sufficient for fd passing). I'm not sure if this new messaging will actually work on a Windows implementation. Perhaps the message tpye should be specifically for unix sockets? I'm not sure how this would be enforced at the moment except by checking if the `net.Conn` is a `*net.UnixConn`. Signed-off-by: Brian Goff --- channel.go | 94 +++++++++++++++++++++++++++++-- client.go | 79 ++++++++++++++++++++++---- fd_test.go | 146 +++++++++++++++++++++++++++++++++++++++++++++++++ main_test.go | 86 +++++++++++++++++++++++++++++ server.go | 80 +++++++++++++++++++-------- server_test.go | 2 + types.go | 22 ++++++++ 7 files changed, 472 insertions(+), 37 deletions(-) create mode 100644 fd_test.go create mode 100644 main_test.go diff --git a/channel.go b/channel.go index aa8c9541c..88b20bf55 100644 --- a/channel.go +++ b/channel.go @@ -19,11 +19,13 @@ package ttrpc import ( "bufio" "encoding/binary" + "fmt" "io" "net" "sync" "github.com/pkg/errors" + "golang.org/x/sys/unix" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -36,8 +38,9 @@ const ( type messageType uint8 const ( - messageTypeRequest messageType = 0x1 - messageTypeResponse messageType = 0x2 + messageTypeRequest messageType = 0x1 + messageTypeResponse messageType = 0x2 + messageTypeFileDescriptor messageType = 0x3 ) // messageHeader represents the fixed-length message header of 10 bytes sent @@ -98,7 +101,7 @@ func newChannel(conn net.Conn) *channel { // the correct consumer. The bytes on the underlying channel // will be discarded. func (ch *channel) recv() (messageHeader, []byte, error) { - mh, err := readMessageHeader(ch.hrbuf[:], ch.br) + mh, err := readMessageHeader(ch.hrbuf[:messageHeaderLength], ch.conn) if err != nil { return messageHeader{}, nil, err } @@ -112,13 +115,78 @@ func (ch *channel) recv() (messageHeader, []byte, error) { } p := ch.getmbuf(int(mh.Length)) - if _, err := io.ReadFull(ch.br, p); err != nil { + if _, err := io.ReadFull(ch.conn, p[:int(mh.Length)]); err != nil { return messageHeader{}, nil, errors.Wrapf(err, "failed reading message") } return mh, p, nil } +func (ch *channel) recvFD(files *FileList) error { + var ( + fds []int + err error + ) + + switch t := ch.conn.(type) { + case FdReceiver: + fds, err = t.Recvfd() + if err != nil { + return err + } + case unixReader: + oob := ch.getmbuf(unix.CmsgSpace(len(files.List) * 4)) + defer ch.putmbuf(oob) + + _, oobn, _, _, err := t.ReadMsgUnix(make([]byte, 1), oob) + if err != nil { + return err + } + + ls, err := unix.ParseSocketControlMessage(oob[:oobn]) + if err != nil { + return fmt.Errorf("error parsing socket controll message: %w", err) + } + + for _, m := range ls { + fdsTemp, err := unix.ParseUnixRights(&m) + if err != nil { + return fmt.Errorf("error parsing unix rights message: %w", err) + } + fds = append(fds, fdsTemp...) + } + default: + return fmt.Errorf("receiving file descriptors is not supported on the transport") + } + + if len(files.List) != len(fds) { + return fmt.Errorf("received %d file descriptors, expected %d", len(fds), len(files.List)) + } + for i, fd := range fds { + files.List[i].Fileno = int64(fd) + } + return nil +} + +func (ch *channel) sendFd(streamID uint32, mt messageType, files *FileList) error { + fds := make([]int, len(files.List)) + + for i, f := range files.List { + fds[i] = int(f.Fileno) + } + + switch t := ch.conn.(type) { + case unixWriter: + // Must send at least a single byte over unix sockets for the ancillary data to be accepted. + _, _, err := t.WriteMsgUnix(make([]byte, 1), unix.UnixRights(fds...), nil) + return err + case FdSender: + return t.SendFd(fds) + default: + return fmt.Errorf("sending file descriptors is not supported on the transport") + } +} + func (ch *channel) send(streamID uint32, t messageType, p []byte) error { if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil { return err @@ -151,3 +219,21 @@ func (ch *channel) getmbuf(size int) []byte { func (ch *channel) putmbuf(p []byte) { buffers.Put(&p) } + +// FdReceiver is an interface used that the transport may implement to receive file descriptors from the client +type FdReceiver interface { + Recvfd() ([]int, error) +} + +// FdSender is an interface used that the transport may implement to send file descriptors to the server. +type FdSender interface { + SendFd([]int) error +} + +type unixReader interface { + ReadMsgUnix(p, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error) +} + +type unixWriter interface { + WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error) +} diff --git a/client.go b/client.go index 30c9b73f3..c1a8eeea4 100644 --- a/client.go +++ b/client.go @@ -96,10 +96,42 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client { } type callRequest struct { - ctx context.Context - req *Request - resp *Response // response will be written back here - errs chan error // error written here on completion + ctx context.Context + req *Request + resp *Response // response will be written back here + errs chan error // error written here on completion + files *FileList +} + +func (c *Client) Sendfd(ctx context.Context, files []*os.File) ([]int64, error) { + ls := make([]*File, len(files)) + for i, f := range files { + ls[i] = &File{ + Name: f.Name(), + Fileno: int64(f.Fd()), + } + } + + resp := &Response{} + fl := &FileList{List: ls} + if err := c.dispatch(ctx, nil, resp, fl); err != nil { + return nil, err + } + if resp.Status != nil && resp.Status.Code != int32(codes.OK) { + return nil, status.ErrorProto(resp.Status) + } + + fl.Reset() + + if err := c.codec.Unmarshal(resp.Payload, fl); err != nil { + return nil, err + } + + fds := make([]int64, len(fl.List)) + for i, f := range fl.List { + fds[i] = f.Fileno + } + return fds, nil } func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error { @@ -129,7 +161,9 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int info := &UnaryClientInfo{ FullMethod: fullPath(service, method), } - if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil { + if err := c.interceptor(ctx, creq, cresp, info, func(ctx context.Context, req *Request, resp *Response) error { + return c.dispatch(ctx, req, resp, nil) + }); err != nil { return err } @@ -143,13 +177,14 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int return nil } -func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { +func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response, files *FileList) error { errs := make(chan error, 1) call := &callRequest{ - ctx: ctx, - req: req, - resp: resp, - errs: errs, + ctx: ctx, + req: req, + resp: resp, + errs: errs, + files: files, } select { @@ -270,13 +305,35 @@ func (c *Client) run() { for { select { case call := <-calls: - if err := c.send(streamID, messageTypeRequest, call.req); err != nil { + var ( + data interface{} + mt messageType + ) + + switch { + case call.files != nil: + data = call.files + mt = messageTypeFileDescriptor + case call.req != nil: + data = call.req + mt = messageTypeRequest + } + + if err := c.send(streamID, mt, data); err != nil { call.errs <- err continue } waiters[streamID] = call streamID += 2 // enforce odd client initiated request ids + + if call.files != nil { + if err := c.channel.sendFd(streamID, mt, call.files); err != nil { + call.errs <- err + continue + } + } + case msg := <-incoming: call, ok := waiters[msg.StreamID] if !ok { diff --git a/fd_test.go b/fd_test.go new file mode 100644 index 000000000..85710a41e --- /dev/null +++ b/fd_test.go @@ -0,0 +1,146 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "strconv" + "testing" + "time" +) + +func TestSendRecvFd(t *testing.T) { + var ( + ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(1*time.Minute)) + addr, listener = newTestListener(t) + ) + + defer cancel() + + // Spin up an out of process ttrpc server + if err := listenerCmd(ctx, t.Name(), listener); err != nil { + t.Fatal(err) + } + + var ( + client, cleanup = newTestClient(t, addr) + + tclient = testFdClient{client} + ) + defer cleanup() + + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err, "error creating test pipe") + } + defer r.Close() + + type readResp struct { + buf []byte + err error + } + + expect := []byte("hello") + + chResp := make(chan readResp, 1) + go func() { + buf := make([]byte, len(expect)) + _, err := io.ReadFull(r, buf) + chResp <- readResp{buf, err} + }() + + if err := tclient.Test(ctx, w); err != nil { + t.Fatal(err) + } + + select { + case <-ctx.Done(): + t.Fatal(ctx.Err()) + case resp := <-chResp: + if resp.err != nil { + t.Error(err) + } + if !bytes.Equal(resp.buf, expect) { + t.Fatalf("got unexpected respone data, exepcted %q, got %q", string(expect), string(resp.buf)) + } + } +} + +type testFdPayload struct { + Fds []int64 `protobuf:"varint,1,opt,name=fds,proto3"` +} + +func (r *testFdPayload) Reset() { *r = testFdPayload{} } +func (r *testFdPayload) String() string { return fmt.Sprintf("%+#v", r) } +func (r *testFdPayload) ProtoMessage() {} + +type testingServerFd struct { + respData []byte +} + +func (s *testingServerFd) Test(ctx context.Context, req *testFdPayload) error { + for i, fd := range req.Fds { + f := os.NewFile(uintptr(fd), "TEST_FILE_"+strconv.Itoa(i)) + go func() { + f.Write(s.respData) + f.Close() + }() + } + + return nil +} + +type testFdClient struct { + client *Client +} + +func (c *testFdClient) Test(ctx context.Context, files ...*os.File) error { + fds, err := c.client.Sendfd(ctx, files) + if err != nil { + return fmt.Errorf("error sending fds: %w", err) + } + + tp := testFdPayload{} + return c.client.Call(ctx, "Test", "Test", &testFdPayload{Fds: fds}, &tp) +} + +func handleTestSendRecvFd(l net.Listener) error { + s, err := NewServer() + if err != nil { + return err + } + testImpl := &testingServerFd{respData: []byte("hello")} + + s.Register("Test", map[string]Method{ + "Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + req := &testFdPayload{} + + if err := unmarshal(req); err != nil { + return nil, err + } + + return &testFdPayload{}, testImpl.Test(ctx, req) + }, + }) + + return s.Serve(context.TODO(), l) +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 000000000..bf1b44d53 --- /dev/null +++ b/main_test.go @@ -0,0 +1,86 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "context" + "flag" + "fmt" + "net" + "os" + "os/exec" + "testing" +) + +func TestMain(m *testing.M) { + var ttrpcListener bool + flag.BoolVar(&ttrpcListener, "listener", false, "Makes the test binary run a ttrpc listener for testing purposes instead of running a test") + flag.Parse() + + if ttrpcListener { + handleListenerCmd() + } + + os.Exit(m.Run()) +} + +var listenerHandlers = map[string]func(net.Listener) error{ + "TestSendRecvFd": handleTestSendRecvFd, +} + +// Starts a ttrpc serverout of process. +// +// The caller is responsible for creating the listener. +// The listener implementation must be able to be converted to an *os.File +// The passed in listtener fd will be closed when this function returns +// (because the fd is copied and handed off to the other process). +func listenerCmd(ctx context.Context, handler string, l net.Listener) error { + defer l.Close() + + cmd := exec.CommandContext(ctx, os.Args[0], "-listener=true") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = append(cmd.Env, "TEST_HANDLER="+handler) + + f, err := l.(fileListener).File() + if err != nil { + return err + } + defer f.Close() + + cmd.ExtraFiles = []*os.File{f} + return cmd.Start() +} + +type fileListener interface { + File() (*os.File, error) +} + +func handleListenerCmd() { + h := listenerHandlers[os.Getenv("TEST_HANDLER")] + l, err := net.FileListener(os.NewFile(3, "TEST_LISTENER")) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + + if err := h(l); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(2) + } + os.Exit(0) +} diff --git a/server.go b/server.go index c18b4e43b..ae9f5f134 100644 --- a/server.go +++ b/server.go @@ -302,8 +302,9 @@ func (c *serverConn) close() error { func (c *serverConn) run(sctx context.Context) { type ( request struct { - id uint32 - req *Request + id uint32 + req *Request + files *FileList } response struct { @@ -376,20 +377,42 @@ func (c *serverConn) run(sctx context.Context) { continue } - if mh.Type != messageTypeRequest { - // we must ignore this for future compat. - continue - } - - var req Request - if err := c.server.codec.Unmarshal(p, &req); err != nil { + var ( + files *FileList + req *Request + ) + switch mh.Type { + case messageTypeRequest: + req = &Request{} + if err := c.server.codec.Unmarshal(p, req); err != nil { + ch.putmbuf(p) + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { + return + } + continue + } ch.putmbuf(p) - if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { - return + case messageTypeFileDescriptor: + files = &FileList{} + if err := c.server.codec.Unmarshal(p, files); err != nil { + ch.putmbuf(p) + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal file list error: %v", err)) { + return + } + continue } + + if err := ch.recvFD(files); err != nil { + ch.putmbuf(p) + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal file list error: %v", err)) { + return + } + } + ch.putmbuf(p) + default: + // Ignore other types for future compatability continue } - ch.putmbuf(p) if mh.StreamID%2 != 1 { // enforce odd client initiated identifiers. @@ -403,8 +426,9 @@ func (c *serverConn) run(sctx context.Context) { // because we have already accepted the client request. select { case requests <- request{ - id: mh.StreamID, - req: &req, + id: mh.StreamID, + req: req, + files: files, }: case <-done: return @@ -432,19 +456,31 @@ func (c *serverConn) run(sctx context.Context) { case request := <-requests: active++ go func(id uint32) { - ctx, cancel := getRequestContext(ctx, request.req) - defer cancel() - - p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) - resp := &Response{ - Status: status.Proto(), - Payload: p, + var resp Response + + switch { + case request.files != nil: + p, err := c.server.codec.Marshal(request.files) + if err != nil { + s, _ := status.FromError(err) + resp.Status = s.Proto() + } + resp.Payload = p + case request.req != nil: + ctx, cancel := getRequestContext(ctx, request.req) + defer cancel() + + p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) + resp.Status = status.Proto() + resp.Payload = p + default: + resp.Status = status.New(codes.Internal, "unknown request type").Proto() } select { case responses <- response{ id: id, - resp: resp, + resp: &resp, }: case <-done: } diff --git a/server_test.go b/server_test.go index 8094ca920..38aa128b2 100644 --- a/server_test.go +++ b/server_test.go @@ -103,6 +103,8 @@ func init() { proto.RegisterType((*testPayload)(nil), "testPayload") proto.RegisterType((*Request)(nil), "Request") proto.RegisterType((*Response)(nil), "Response") + proto.RegisterType((*FileList)(nil), "FileList") + proto.RegisterType((*testFdPayload)(nil), "testFdPayload") } func TestServer(t *testing.T) { diff --git a/types.go b/types.go index 9a1c19a72..090dcd2d0 100644 --- a/types.go +++ b/types.go @@ -61,3 +61,25 @@ type KeyValue struct { func (m *KeyValue) Reset() { *m = KeyValue{} } func (*KeyValue) ProtoMessage() {} func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) } + +type FileList struct { + List []*File `protobuf:"bytes,1,rep,name=list,proto3"` +} + +func (r *FileList) Reset() { *r = FileList{} } +func (r *FileList) String() string { return fmt.Sprintf("%+#v", r) } +func (r *FileList) ProtoMessage() {} + +// File represents a file descriptor that is transferred. +// Once the file descriptor is passed, the server should be able to call `os.NewFile(f.Fileno, f.Name)` +// The Fileno field will be filled in by the server side after the descriptor has been passed. +type File struct { + // Name is the name to be used when accepting the file descriptor + Name string `protobuf:"bytes,1,opt,name=name,proto3"` + // Fileno is the file descriptor id/pointer + Fileno int64 `protobuf:"varint,2,opt,name=timeout_nano,proto3"` +} + +func (r *File) Reset() { *r = File{} } +func (r *File) String() string { return fmt.Sprintf("%+#v", r) } +func (r *File) ProtoMessage() {}