diff --git a/channel.go b/channel.go index aa8c9541c..88b20bf55 100644 --- a/channel.go +++ b/channel.go @@ -19,11 +19,13 @@ package ttrpc import ( "bufio" "encoding/binary" + "fmt" "io" "net" "sync" "github.com/pkg/errors" + "golang.org/x/sys/unix" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -36,8 +38,9 @@ const ( type messageType uint8 const ( - messageTypeRequest messageType = 0x1 - messageTypeResponse messageType = 0x2 + messageTypeRequest messageType = 0x1 + messageTypeResponse messageType = 0x2 + messageTypeFileDescriptor messageType = 0x3 ) // messageHeader represents the fixed-length message header of 10 bytes sent @@ -98,7 +101,7 @@ func newChannel(conn net.Conn) *channel { // the correct consumer. The bytes on the underlying channel // will be discarded. func (ch *channel) recv() (messageHeader, []byte, error) { - mh, err := readMessageHeader(ch.hrbuf[:], ch.br) + mh, err := readMessageHeader(ch.hrbuf[:messageHeaderLength], ch.conn) if err != nil { return messageHeader{}, nil, err } @@ -112,13 +115,78 @@ func (ch *channel) recv() (messageHeader, []byte, error) { } p := ch.getmbuf(int(mh.Length)) - if _, err := io.ReadFull(ch.br, p); err != nil { + if _, err := io.ReadFull(ch.conn, p[:int(mh.Length)]); err != nil { return messageHeader{}, nil, errors.Wrapf(err, "failed reading message") } return mh, p, nil } +func (ch *channel) recvFD(files *FileList) error { + var ( + fds []int + err error + ) + + switch t := ch.conn.(type) { + case FdReceiver: + fds, err = t.Recvfd() + if err != nil { + return err + } + case unixReader: + oob := ch.getmbuf(unix.CmsgSpace(len(files.List) * 4)) + defer ch.putmbuf(oob) + + _, oobn, _, _, err := t.ReadMsgUnix(make([]byte, 1), oob) + if err != nil { + return err + } + + ls, err := unix.ParseSocketControlMessage(oob[:oobn]) + if err != nil { + return fmt.Errorf("error parsing socket controll message: %w", err) + } + + for _, m := range ls { + fdsTemp, err := unix.ParseUnixRights(&m) + if err != nil { + return fmt.Errorf("error parsing unix rights message: %w", err) + } + fds = append(fds, fdsTemp...) + } + default: + return fmt.Errorf("receiving file descriptors is not supported on the transport") + } + + if len(files.List) != len(fds) { + return fmt.Errorf("received %d file descriptors, expected %d", len(fds), len(files.List)) + } + for i, fd := range fds { + files.List[i].Fileno = int64(fd) + } + return nil +} + +func (ch *channel) sendFd(streamID uint32, mt messageType, files *FileList) error { + fds := make([]int, len(files.List)) + + for i, f := range files.List { + fds[i] = int(f.Fileno) + } + + switch t := ch.conn.(type) { + case unixWriter: + // Must send at least a single byte over unix sockets for the ancillary data to be accepted. + _, _, err := t.WriteMsgUnix(make([]byte, 1), unix.UnixRights(fds...), nil) + return err + case FdSender: + return t.SendFd(fds) + default: + return fmt.Errorf("sending file descriptors is not supported on the transport") + } +} + func (ch *channel) send(streamID uint32, t messageType, p []byte) error { if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil { return err @@ -151,3 +219,21 @@ func (ch *channel) getmbuf(size int) []byte { func (ch *channel) putmbuf(p []byte) { buffers.Put(&p) } + +// FdReceiver is an interface used that the transport may implement to receive file descriptors from the client +type FdReceiver interface { + Recvfd() ([]int, error) +} + +// FdSender is an interface used that the transport may implement to send file descriptors to the server. +type FdSender interface { + SendFd([]int) error +} + +type unixReader interface { + ReadMsgUnix(p, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error) +} + +type unixWriter interface { + WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error) +} diff --git a/client.go b/client.go index 30c9b73f3..c1a8eeea4 100644 --- a/client.go +++ b/client.go @@ -96,10 +96,42 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client { } type callRequest struct { - ctx context.Context - req *Request - resp *Response // response will be written back here - errs chan error // error written here on completion + ctx context.Context + req *Request + resp *Response // response will be written back here + errs chan error // error written here on completion + files *FileList +} + +func (c *Client) Sendfd(ctx context.Context, files []*os.File) ([]int64, error) { + ls := make([]*File, len(files)) + for i, f := range files { + ls[i] = &File{ + Name: f.Name(), + Fileno: int64(f.Fd()), + } + } + + resp := &Response{} + fl := &FileList{List: ls} + if err := c.dispatch(ctx, nil, resp, fl); err != nil { + return nil, err + } + if resp.Status != nil && resp.Status.Code != int32(codes.OK) { + return nil, status.ErrorProto(resp.Status) + } + + fl.Reset() + + if err := c.codec.Unmarshal(resp.Payload, fl); err != nil { + return nil, err + } + + fds := make([]int64, len(fl.List)) + for i, f := range fl.List { + fds[i] = f.Fileno + } + return fds, nil } func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error { @@ -129,7 +161,9 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int info := &UnaryClientInfo{ FullMethod: fullPath(service, method), } - if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil { + if err := c.interceptor(ctx, creq, cresp, info, func(ctx context.Context, req *Request, resp *Response) error { + return c.dispatch(ctx, req, resp, nil) + }); err != nil { return err } @@ -143,13 +177,14 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int return nil } -func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { +func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response, files *FileList) error { errs := make(chan error, 1) call := &callRequest{ - ctx: ctx, - req: req, - resp: resp, - errs: errs, + ctx: ctx, + req: req, + resp: resp, + errs: errs, + files: files, } select { @@ -270,13 +305,35 @@ func (c *Client) run() { for { select { case call := <-calls: - if err := c.send(streamID, messageTypeRequest, call.req); err != nil { + var ( + data interface{} + mt messageType + ) + + switch { + case call.files != nil: + data = call.files + mt = messageTypeFileDescriptor + case call.req != nil: + data = call.req + mt = messageTypeRequest + } + + if err := c.send(streamID, mt, data); err != nil { call.errs <- err continue } waiters[streamID] = call streamID += 2 // enforce odd client initiated request ids + + if call.files != nil { + if err := c.channel.sendFd(streamID, mt, call.files); err != nil { + call.errs <- err + continue + } + } + case msg := <-incoming: call, ok := waiters[msg.StreamID] if !ok { diff --git a/fd_test.go b/fd_test.go new file mode 100644 index 000000000..85710a41e --- /dev/null +++ b/fd_test.go @@ -0,0 +1,146 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "strconv" + "testing" + "time" +) + +func TestSendRecvFd(t *testing.T) { + var ( + ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(1*time.Minute)) + addr, listener = newTestListener(t) + ) + + defer cancel() + + // Spin up an out of process ttrpc server + if err := listenerCmd(ctx, t.Name(), listener); err != nil { + t.Fatal(err) + } + + var ( + client, cleanup = newTestClient(t, addr) + + tclient = testFdClient{client} + ) + defer cleanup() + + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err, "error creating test pipe") + } + defer r.Close() + + type readResp struct { + buf []byte + err error + } + + expect := []byte("hello") + + chResp := make(chan readResp, 1) + go func() { + buf := make([]byte, len(expect)) + _, err := io.ReadFull(r, buf) + chResp <- readResp{buf, err} + }() + + if err := tclient.Test(ctx, w); err != nil { + t.Fatal(err) + } + + select { + case <-ctx.Done(): + t.Fatal(ctx.Err()) + case resp := <-chResp: + if resp.err != nil { + t.Error(err) + } + if !bytes.Equal(resp.buf, expect) { + t.Fatalf("got unexpected respone data, exepcted %q, got %q", string(expect), string(resp.buf)) + } + } +} + +type testFdPayload struct { + Fds []int64 `protobuf:"varint,1,opt,name=fds,proto3"` +} + +func (r *testFdPayload) Reset() { *r = testFdPayload{} } +func (r *testFdPayload) String() string { return fmt.Sprintf("%+#v", r) } +func (r *testFdPayload) ProtoMessage() {} + +type testingServerFd struct { + respData []byte +} + +func (s *testingServerFd) Test(ctx context.Context, req *testFdPayload) error { + for i, fd := range req.Fds { + f := os.NewFile(uintptr(fd), "TEST_FILE_"+strconv.Itoa(i)) + go func() { + f.Write(s.respData) + f.Close() + }() + } + + return nil +} + +type testFdClient struct { + client *Client +} + +func (c *testFdClient) Test(ctx context.Context, files ...*os.File) error { + fds, err := c.client.Sendfd(ctx, files) + if err != nil { + return fmt.Errorf("error sending fds: %w", err) + } + + tp := testFdPayload{} + return c.client.Call(ctx, "Test", "Test", &testFdPayload{Fds: fds}, &tp) +} + +func handleTestSendRecvFd(l net.Listener) error { + s, err := NewServer() + if err != nil { + return err + } + testImpl := &testingServerFd{respData: []byte("hello")} + + s.Register("Test", map[string]Method{ + "Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + req := &testFdPayload{} + + if err := unmarshal(req); err != nil { + return nil, err + } + + return &testFdPayload{}, testImpl.Test(ctx, req) + }, + }) + + return s.Serve(context.TODO(), l) +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 000000000..bf1b44d53 --- /dev/null +++ b/main_test.go @@ -0,0 +1,86 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "context" + "flag" + "fmt" + "net" + "os" + "os/exec" + "testing" +) + +func TestMain(m *testing.M) { + var ttrpcListener bool + flag.BoolVar(&ttrpcListener, "listener", false, "Makes the test binary run a ttrpc listener for testing purposes instead of running a test") + flag.Parse() + + if ttrpcListener { + handleListenerCmd() + } + + os.Exit(m.Run()) +} + +var listenerHandlers = map[string]func(net.Listener) error{ + "TestSendRecvFd": handleTestSendRecvFd, +} + +// Starts a ttrpc serverout of process. +// +// The caller is responsible for creating the listener. +// The listener implementation must be able to be converted to an *os.File +// The passed in listtener fd will be closed when this function returns +// (because the fd is copied and handed off to the other process). +func listenerCmd(ctx context.Context, handler string, l net.Listener) error { + defer l.Close() + + cmd := exec.CommandContext(ctx, os.Args[0], "-listener=true") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = append(cmd.Env, "TEST_HANDLER="+handler) + + f, err := l.(fileListener).File() + if err != nil { + return err + } + defer f.Close() + + cmd.ExtraFiles = []*os.File{f} + return cmd.Start() +} + +type fileListener interface { + File() (*os.File, error) +} + +func handleListenerCmd() { + h := listenerHandlers[os.Getenv("TEST_HANDLER")] + l, err := net.FileListener(os.NewFile(3, "TEST_LISTENER")) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + + if err := h(l); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(2) + } + os.Exit(0) +} diff --git a/server.go b/server.go index c18b4e43b..ae9f5f134 100644 --- a/server.go +++ b/server.go @@ -302,8 +302,9 @@ func (c *serverConn) close() error { func (c *serverConn) run(sctx context.Context) { type ( request struct { - id uint32 - req *Request + id uint32 + req *Request + files *FileList } response struct { @@ -376,20 +377,42 @@ func (c *serverConn) run(sctx context.Context) { continue } - if mh.Type != messageTypeRequest { - // we must ignore this for future compat. - continue - } - - var req Request - if err := c.server.codec.Unmarshal(p, &req); err != nil { + var ( + files *FileList + req *Request + ) + switch mh.Type { + case messageTypeRequest: + req = &Request{} + if err := c.server.codec.Unmarshal(p, req); err != nil { + ch.putmbuf(p) + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { + return + } + continue + } ch.putmbuf(p) - if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { - return + case messageTypeFileDescriptor: + files = &FileList{} + if err := c.server.codec.Unmarshal(p, files); err != nil { + ch.putmbuf(p) + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal file list error: %v", err)) { + return + } + continue } + + if err := ch.recvFD(files); err != nil { + ch.putmbuf(p) + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal file list error: %v", err)) { + return + } + } + ch.putmbuf(p) + default: + // Ignore other types for future compatability continue } - ch.putmbuf(p) if mh.StreamID%2 != 1 { // enforce odd client initiated identifiers. @@ -403,8 +426,9 @@ func (c *serverConn) run(sctx context.Context) { // because we have already accepted the client request. select { case requests <- request{ - id: mh.StreamID, - req: &req, + id: mh.StreamID, + req: req, + files: files, }: case <-done: return @@ -432,19 +456,31 @@ func (c *serverConn) run(sctx context.Context) { case request := <-requests: active++ go func(id uint32) { - ctx, cancel := getRequestContext(ctx, request.req) - defer cancel() - - p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) - resp := &Response{ - Status: status.Proto(), - Payload: p, + var resp Response + + switch { + case request.files != nil: + p, err := c.server.codec.Marshal(request.files) + if err != nil { + s, _ := status.FromError(err) + resp.Status = s.Proto() + } + resp.Payload = p + case request.req != nil: + ctx, cancel := getRequestContext(ctx, request.req) + defer cancel() + + p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) + resp.Status = status.Proto() + resp.Payload = p + default: + resp.Status = status.New(codes.Internal, "unknown request type").Proto() } select { case responses <- response{ id: id, - resp: resp, + resp: &resp, }: case <-done: } diff --git a/server_test.go b/server_test.go index 8094ca920..38aa128b2 100644 --- a/server_test.go +++ b/server_test.go @@ -103,6 +103,8 @@ func init() { proto.RegisterType((*testPayload)(nil), "testPayload") proto.RegisterType((*Request)(nil), "Request") proto.RegisterType((*Response)(nil), "Response") + proto.RegisterType((*FileList)(nil), "FileList") + proto.RegisterType((*testFdPayload)(nil), "testFdPayload") } func TestServer(t *testing.T) { diff --git a/types.go b/types.go index 9a1c19a72..090dcd2d0 100644 --- a/types.go +++ b/types.go @@ -61,3 +61,25 @@ type KeyValue struct { func (m *KeyValue) Reset() { *m = KeyValue{} } func (*KeyValue) ProtoMessage() {} func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) } + +type FileList struct { + List []*File `protobuf:"bytes,1,rep,name=list,proto3"` +} + +func (r *FileList) Reset() { *r = FileList{} } +func (r *FileList) String() string { return fmt.Sprintf("%+#v", r) } +func (r *FileList) ProtoMessage() {} + +// File represents a file descriptor that is transferred. +// Once the file descriptor is passed, the server should be able to call `os.NewFile(f.Fileno, f.Name)` +// The Fileno field will be filled in by the server side after the descriptor has been passed. +type File struct { + // Name is the name to be used when accepting the file descriptor + Name string `protobuf:"bytes,1,opt,name=name,proto3"` + // Fileno is the file descriptor id/pointer + Fileno int64 `protobuf:"varint,2,opt,name=timeout_nano,proto3"` +} + +func (r *File) Reset() { *r = File{} } +func (r *File) String() string { return fmt.Sprintf("%+#v", r) } +func (r *File) ProtoMessage() {}