Skip to content

Commit

Permalink
Add support for fd passing
Browse files Browse the repository at this point in the history
This adds a new message type for passing file descriptors.
How this works is:

1. Client sends a message with a header for messageTypeFileDescriptor
   along with the list of descriptors to be sent
2. Client sends 2nd message to actually pass along the descriptors
   (needed for unix sockets).
3. Server sees the message type and waits to receive the fd's.
4. Once fd's are seen the server responds with the real fd numbers that
   are used which an application can use in future calls.

To accomplish this reliably (on unix sockets) I had to drop the usage of
the bufio.Reader because we need to ensure exact message boundaries.

Within ttrpc this only support unix sockets and `net.Conn` implementations
that implement `SendFds`/`ReceiveFds` (this interface is totally
invented here).

Something to consider, I have not attempted to do fd passing on Windows
which will need other mechanisms entirely (and the conn's provided by
winio are not sufficient for fd passing).
I'm not sure if this new messaging will actually work on a Windows
implementation.

Perhaps the message tpye should be specifically for unix sockets? I'm
not sure how this would be enforced at the moment except by checking if
the `net.Conn` is a `*net.UnixConn`.

Signed-off-by: Brian Goff <[email protected]>
  • Loading branch information
cpuguy83 committed Jan 27, 2021
1 parent bfba540 commit 215aefa
Show file tree
Hide file tree
Showing 6 changed files with 366 additions and 37 deletions.
94 changes: 90 additions & 4 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ package ttrpc
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"net"
"sync"

"github.com/pkg/errors"
"golang.org/x/sys/unix"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
Expand All @@ -36,8 +38,9 @@ const (
type messageType uint8

const (
messageTypeRequest messageType = 0x1
messageTypeResponse messageType = 0x2
messageTypeRequest messageType = 0x1
messageTypeResponse messageType = 0x2
messageTypeFileDescriptor messageType = 0x3
)

// messageHeader represents the fixed-length message header of 10 bytes sent
Expand Down Expand Up @@ -98,7 +101,7 @@ func newChannel(conn net.Conn) *channel {
// the correct consumer. The bytes on the underlying channel
// will be discarded.
func (ch *channel) recv() (messageHeader, []byte, error) {
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
mh, err := readMessageHeader(ch.hrbuf[:messageHeaderLength], ch.conn)
if err != nil {
return messageHeader{}, nil, err
}
Expand All @@ -112,13 +115,78 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
}

p := ch.getmbuf(int(mh.Length))
if _, err := io.ReadFull(ch.br, p); err != nil {
if _, err := io.ReadFull(ch.conn, p[:int(mh.Length)]); err != nil {
return messageHeader{}, nil, errors.Wrapf(err, "failed reading message")
}

return mh, p, nil
}

func (ch *channel) recvFD(files *FileList) error {
var (
fds []int
err error
)

switch t := ch.conn.(type) {
case FdReceiver:
fds, err = t.ReceiveFds()
if err != nil {
return err
}
case unixReader:
oob := ch.getmbuf(unix.CmsgSpace(len(files.List) * 4))
defer ch.putmbuf(oob)

_, oobn, _, _, err := t.ReadMsgUnix(make([]byte, 1), oob)
if err != nil {
return err
}

ls, err := unix.ParseSocketControlMessage(oob[:oobn])
if err != nil {
return fmt.Errorf("error parsing socket controll message: %w", err)
}

for _, m := range ls {
fdsTemp, err := unix.ParseUnixRights(&m)
if err != nil {
return fmt.Errorf("error parsing unix rights message: %w", err)
}
fds = append(fds, fdsTemp...)
}
default:
return fmt.Errorf("receiving file descriptors is not supported on the transport")
}

if len(files.List) != len(fds) {
return fmt.Errorf("received %d file descriptors, expected %d", len(fds), len(files.List))
}
for i, fd := range fds {
files.List[i].Fileno = int64(fd)
}
return nil
}

func (ch *channel) sendFd(streamID uint32, mt messageType, files *FileList) error {
fds := make([]int, len(files.List))

for i, f := range files.List {
fds[i] = int(f.Fileno)
}

switch t := ch.conn.(type) {
case unixWriter:
// Must send at least a single byte over unix sockets for the ancillary data to be accepted.
_, _, err := t.WriteMsgUnix(make([]byte, 1), unix.UnixRights(fds...), nil)
return err
case FdSender:
return t.SendFds(fds)
default:
return fmt.Errorf("sending file descriptors is not supported on the transport")
}
}

func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
return err
Expand Down Expand Up @@ -151,3 +219,21 @@ func (ch *channel) getmbuf(size int) []byte {
func (ch *channel) putmbuf(p []byte) {
buffers.Put(&p)
}

// FdReceiver is an interface used that the transport may implement to receive file descriptors from the client
type FdReceiver interface {
ReceiveFds() ([]int, error)
}

// FdSender is an interface used that the transport may implement to send file descriptors to the server.
type FdSender interface {
SendFds([]int) error
}

type unixReader interface {
ReadMsgUnix(p, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error)
}

type unixWriter interface {
WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error)
}
79 changes: 68 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,42 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
}

type callRequest struct {
ctx context.Context
req *Request
resp *Response // response will be written back here
errs chan error // error written here on completion
ctx context.Context
req *Request
resp *Response // response will be written back here
errs chan error // error written here on completion
files *FileList
}

func (c *Client) Sendfd(ctx context.Context, files []*os.File) ([]int64, error) {
ls := make([]*File, len(files))
for i, f := range files {
ls[i] = &File{
Name: f.Name(),
Fileno: int64(f.Fd()),
}
}

resp := &Response{}
fl := &FileList{List: ls}
if err := c.dispatch(ctx, nil, resp, fl); err != nil {
return nil, err
}
if resp.Status != nil && resp.Status.Code != int32(codes.OK) {
return nil, status.ErrorProto(resp.Status)
}

fl.Reset()

if err := c.codec.Unmarshal(resp.Payload, fl); err != nil {
return nil, err
}

fds := make([]int64, len(fl.List))
for i, f := range fl.List {
fds[i] = f.Fileno
}
return fds, nil
}

func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
Expand Down Expand Up @@ -129,7 +161,9 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
info := &UnaryClientInfo{
FullMethod: fullPath(service, method),
}
if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
if err := c.interceptor(ctx, creq, cresp, info, func(ctx context.Context, req *Request, resp *Response) error {
return c.dispatch(ctx, req, resp, nil)
}); err != nil {
return err
}

Expand All @@ -143,13 +177,14 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
return nil
}

func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response, files *FileList) error {
errs := make(chan error, 1)
call := &callRequest{
ctx: ctx,
req: req,
resp: resp,
errs: errs,
ctx: ctx,
req: req,
resp: resp,
errs: errs,
files: files,
}

select {
Expand Down Expand Up @@ -270,13 +305,35 @@ func (c *Client) run() {
for {
select {
case call := <-calls:
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
var (
data interface{}
mt messageType
)

switch {
case call.files != nil:
data = call.files
mt = messageTypeFileDescriptor
case call.req != nil:
data = call.req
mt = messageTypeRequest
}

if err := c.send(streamID, mt, data); err != nil {
call.errs <- err
continue
}

waiters[streamID] = call
streamID += 2 // enforce odd client initiated request ids

if call.files != nil {
if err := c.channel.sendFd(streamID, mt, call.files); err != nil {
call.errs <- err
continue
}
}

case msg := <-incoming:
call, ok := waiters[msg.StreamID]
if !ok {
Expand Down
125 changes: 125 additions & 0 deletions fd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package ttrpc

import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"strconv"
"testing"
"time"
)

type HandshakerFunc func(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)

func (f HandshakerFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
return f(ctx, conn)
}

func TestSendRecevFd(t *testing.T) {
var (
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(1*time.Minute))
server = mustServer(t)(NewServer())
testImpl = &testingServerFd{respData: []byte("hello")}
addr, listener = newTestListener(t)
)

defer cancel()
defer listener.Close()

server.Register("Test", map[string]Method{
"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
req := &testFdPayload{}

if err := unmarshal(req); err != nil {
return nil, err
}

return &testFdPayload{}, testImpl.Test(ctx, req)
},
})

go server.Serve(ctx, listener)
defer server.Shutdown(ctx)

var (
client, cleanup = newTestClient(t, addr)

tclient = testFdClient{client}
)
defer cleanup()

r, w, err := os.Pipe()
if err != nil {
t.Fatal(err, "error creating test pipe")
}
defer r.Close()

type readResp struct {
buf []byte
err error
}

chResp := make(chan readResp, 1)
go func() {
buf := make([]byte, len(testImpl.respData))
_, err := io.ReadFull(r, buf)
chResp <- readResp{buf, err}
}()

if err := tclient.Test(ctx, w); err != nil {
t.Fatal(err)
}

select {
case <-ctx.Done():
t.Fatal(ctx.Err())
case resp := <-chResp:
if resp.err != nil {
t.Error(err)
}
if !bytes.Equal(resp.buf, testImpl.respData) {
t.Fatalf("got unexpected respone data, exepcted %q, got %q", string(testImpl.respData), string(resp.buf))
}
}
}

type testFdPayload struct {
Fds []int64 `protobuf:"varint,1,opt,name=fds,proto3"`
}

func (r *testFdPayload) Reset() { *r = testFdPayload{} }
func (r *testFdPayload) String() string { return fmt.Sprintf("%+#v", r) }
func (r *testFdPayload) ProtoMessage() {}

type testingServerFd struct {
respData []byte
}

func (s *testingServerFd) Test(ctx context.Context, req *testFdPayload) error {
for i, fd := range req.Fds {
f := os.NewFile(uintptr(fd), "TEST_FILE_"+strconv.Itoa(i))
go func() {
f.Write(s.respData)
f.Close()
}()
}

return nil
}

type testFdClient struct {
client *Client
}

func (c *testFdClient) Test(ctx context.Context, files ...*os.File) error {
fds, err := c.client.Sendfd(ctx, files)
if err != nil {
return fmt.Errorf("error sending fds: %w", err)
}

tp := testFdPayload{}
return c.client.Call(ctx, "Test", "Test", &testFdPayload{Fds: fds}, &tp)
}
Loading

0 comments on commit 215aefa

Please sign in to comment.