Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for fd passing #75

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 90 additions & 4 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ package ttrpc
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"net"
"sync"

"github.com/pkg/errors"
"golang.org/x/sys/unix"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
Expand All @@ -36,8 +38,9 @@ const (
type messageType uint8

const (
messageTypeRequest messageType = 0x1
messageTypeResponse messageType = 0x2
messageTypeRequest messageType = 0x1
messageTypeResponse messageType = 0x2
messageTypeFileDescriptor messageType = 0x3
)

// messageHeader represents the fixed-length message header of 10 bytes sent
Expand Down Expand Up @@ -98,7 +101,7 @@ func newChannel(conn net.Conn) *channel {
// the correct consumer. The bytes on the underlying channel
// will be discarded.
func (ch *channel) recv() (messageHeader, []byte, error) {
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
mh, err := readMessageHeader(ch.hrbuf[:messageHeaderLength], ch.conn)
if err != nil {
return messageHeader{}, nil, err
}
Expand All @@ -112,13 +115,78 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
}

p := ch.getmbuf(int(mh.Length))
if _, err := io.ReadFull(ch.br, p); err != nil {
if _, err := io.ReadFull(ch.conn, p[:int(mh.Length)]); err != nil {
return messageHeader{}, nil, errors.Wrapf(err, "failed reading message")
}

return mh, p, nil
}

func (ch *channel) recvFD(files *FileList) error {
var (
fds []int
err error
)

switch t := ch.conn.(type) {
case FdReceiver:
fds, err = t.Recvfd()
if err != nil {
return err
}
case unixReader:
oob := ch.getmbuf(unix.CmsgSpace(len(files.List) * 4))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 4?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the magic number for fd messages, as I recall.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The size of the message?

defer ch.putmbuf(oob)

_, oobn, _, _, err := t.ReadMsgUnix(make([]byte, 1), oob)
if err != nil {
return err
}

ls, err := unix.ParseSocketControlMessage(oob[:oobn])
if err != nil {
return fmt.Errorf("error parsing socket controll message: %w", err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

controll -> control

}

for _, m := range ls {
fdsTemp, err := unix.ParseUnixRights(&m)
if err != nil {
return fmt.Errorf("error parsing unix rights message: %w", err)
}
fds = append(fds, fdsTemp...)
}
default:
return fmt.Errorf("receiving file descriptors is not supported on the transport")
}

if len(files.List) != len(fds) {
return fmt.Errorf("received %d file descriptors, expected %d", len(fds), len(files.List))
}
for i, fd := range fds {
files.List[i].Fileno = int64(fd)
}
return nil
}

func (ch *channel) sendFd(streamID uint32, mt messageType, files *FileList) error {
fds := make([]int, len(files.List))

for i, f := range files.List {
fds[i] = int(f.Fileno)
}

switch t := ch.conn.(type) {
case unixWriter:
// Must send at least a single byte over unix sockets for the ancillary data to be accepted.
_, _, err := t.WriteMsgUnix(make([]byte, 1), unix.UnixRights(fds...), nil)
return err
case FdSender:
return t.SendFd(fds)
default:
return fmt.Errorf("sending file descriptors is not supported on the transport")
}
}

func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
return err
Expand Down Expand Up @@ -151,3 +219,21 @@ func (ch *channel) getmbuf(size int) []byte {
func (ch *channel) putmbuf(p []byte) {
buffers.Put(&p)
}

// FdReceiver is an interface used that the transport may implement to receive file descriptors from the client
type FdReceiver interface {
Recvfd() ([]int, error)
}

// FdSender is an interface used that the transport may implement to send file descriptors to the server.
type FdSender interface {
SendFd([]int) error
}

type unixReader interface {
ReadMsgUnix(p, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error)
}

type unixWriter interface {
WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error)
}
Comment on lines +233 to +239
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some doc comments? It is hard to understand the roles of p, b, oob.

79 changes: 68 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,42 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
}

type callRequest struct {
ctx context.Context
req *Request
resp *Response // response will be written back here
errs chan error // error written here on completion
ctx context.Context
req *Request
resp *Response // response will be written back here
errs chan error // error written here on completion
files *FileList
}

func (c *Client) Sendfd(ctx context.Context, files []*os.File) ([]int64, error) {
ls := make([]*File, len(files))
for i, f := range files {
ls[i] = &File{
Name: f.Name(),
Fileno: int64(f.Fd()),
}
}

resp := &Response{}
fl := &FileList{List: ls}
if err := c.dispatch(ctx, nil, resp, fl); err != nil {
return nil, err
}
if resp.Status != nil && resp.Status.Code != int32(codes.OK) {
return nil, status.ErrorProto(resp.Status)
}

fl.Reset()

if err := c.codec.Unmarshal(resp.Payload, fl); err != nil {
return nil, err
}

fds := make([]int64, len(fl.List))
for i, f := range fl.List {
fds[i] = f.Fileno
}
return fds, nil
}

func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
Expand Down Expand Up @@ -129,7 +161,9 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
info := &UnaryClientInfo{
FullMethod: fullPath(service, method),
}
if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
if err := c.interceptor(ctx, creq, cresp, info, func(ctx context.Context, req *Request, resp *Response) error {
return c.dispatch(ctx, req, resp, nil)
}); err != nil {
return err
}

Expand All @@ -143,13 +177,14 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
return nil
}

func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response, files *FileList) error {
errs := make(chan error, 1)
call := &callRequest{
ctx: ctx,
req: req,
resp: resp,
errs: errs,
ctx: ctx,
req: req,
resp: resp,
errs: errs,
files: files,
}

select {
Expand Down Expand Up @@ -270,13 +305,35 @@ func (c *Client) run() {
for {
select {
case call := <-calls:
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
var (
data interface{}
mt messageType
)

switch {
case call.files != nil:
data = call.files
mt = messageTypeFileDescriptor
case call.req != nil:
data = call.req
mt = messageTypeRequest
}

if err := c.send(streamID, mt, data); err != nil {
call.errs <- err
continue
}

waiters[streamID] = call
streamID += 2 // enforce odd client initiated request ids

if call.files != nil {
if err := c.channel.sendFd(streamID, mt, call.files); err != nil {
call.errs <- err
continue
}
}

case msg := <-incoming:
call, ok := waiters[msg.StreamID]
if !ok {
Expand Down
146 changes: 146 additions & 0 deletions fd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
Copyright The containerd Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ttrpc

import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"strconv"
"testing"
"time"
)

func TestSendRecvFd(t *testing.T) {
var (
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(1*time.Minute))
addr, listener = newTestListener(t)
)

defer cancel()

// Spin up an out of process ttrpc server
if err := listenerCmd(ctx, t.Name(), listener); err != nil {
t.Fatal(err)
}

var (
client, cleanup = newTestClient(t, addr)

tclient = testFdClient{client}
)
defer cleanup()

r, w, err := os.Pipe()
if err != nil {
t.Fatal(err, "error creating test pipe")
}
defer r.Close()

type readResp struct {
buf []byte
err error
}

expect := []byte("hello")

chResp := make(chan readResp, 1)
go func() {
buf := make([]byte, len(expect))
_, err := io.ReadFull(r, buf)
chResp <- readResp{buf, err}
}()

if err := tclient.Test(ctx, w); err != nil {
t.Fatal(err)
}

select {
case <-ctx.Done():
t.Fatal(ctx.Err())
case resp := <-chResp:
if resp.err != nil {
t.Error(err)
}
if !bytes.Equal(resp.buf, expect) {
t.Fatalf("got unexpected respone data, exepcted %q, got %q", string(expect), string(resp.buf))
}
}
}

type testFdPayload struct {
Fds []int64 `protobuf:"varint,1,opt,name=fds,proto3"`
}

func (r *testFdPayload) Reset() { *r = testFdPayload{} }
func (r *testFdPayload) String() string { return fmt.Sprintf("%+#v", r) }
func (r *testFdPayload) ProtoMessage() {}

type testingServerFd struct {
respData []byte
}

func (s *testingServerFd) Test(ctx context.Context, req *testFdPayload) error {
for i, fd := range req.Fds {
f := os.NewFile(uintptr(fd), "TEST_FILE_"+strconv.Itoa(i))
go func() {
f.Write(s.respData)
f.Close()
}()
}

return nil
}

type testFdClient struct {
client *Client
}

func (c *testFdClient) Test(ctx context.Context, files ...*os.File) error {
fds, err := c.client.Sendfd(ctx, files)
if err != nil {
return fmt.Errorf("error sending fds: %w", err)
}

tp := testFdPayload{}
return c.client.Call(ctx, "Test", "Test", &testFdPayload{Fds: fds}, &tp)
}

func handleTestSendRecvFd(l net.Listener) error {
s, err := NewServer()
if err != nil {
return err
}
testImpl := &testingServerFd{respData: []byte("hello")}

s.Register("Test", map[string]Method{
"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
req := &testFdPayload{}

if err := unmarshal(req); err != nil {
return nil, err
}

return &testFdPayload{}, testImpl.Test(ctx, req)
},
})

return s.Serve(context.TODO(), l)
}
Loading