Skip to content

Commit

Permalink
Merge pull request #584 from peterverraedt/configure-maxTxPacket
Browse files Browse the repository at this point in the history
Add WithMaxTxPacket server option
  • Loading branch information
puellanivis authored Apr 26, 2024
2 parents 06342e8 + c1f47ba commit 5494656
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 40 deletions.
2 changes: 1 addition & 1 deletion packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error {
// So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length)
const dataHeaderLen = 4 + 1 + 4 + 4

func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte {
func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32, maxTxPacket uint32) []byte {
dataLen := p.Len
if dataLen > maxTxPacket {
dataLen = maxTxPacket
Expand Down
32 changes: 25 additions & 7 deletions request-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"sync"
)

var maxTxPacket uint32 = 1 << 15
const defaultMaxTxPacket uint32 = 1 << 15

// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
Expand All @@ -28,6 +28,7 @@ type RequestServer struct {
pktMgr *packetManager

startDirectory string
maxTxPacket uint32

mu sync.RWMutex
handleCount int
Expand Down Expand Up @@ -57,6 +58,22 @@ func WithStartDirectory(startDirectory string) RequestServerOption {
}
}

// WithRSMaxTxPacket sets the maximum size of the payload returned to the client,
// measured in bytes. The default value is 32768 bytes, and this option
// can only be used to increase it. Setting this option to a larger value
// should be safe, because the client decides the size of the requested payload.
//
// The default maximum packet size is 32768 bytes.
func WithRSMaxTxPacket(size uint32) RequestServerOption {
return func(rs *RequestServer) {
if size < defaultMaxTxPacket {
return
}

rs.maxTxPacket = size
}
}

// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
Expand All @@ -73,6 +90,7 @@ func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServ
pktMgr: newPktMgr(svrConn),

startDirectory: "/",
maxTxPacket: defaultMaxTxPacket,

openRequests: make(map[string]*Request),
}
Expand Down Expand Up @@ -260,7 +278,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
Method: "Stat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
Expand All @@ -272,32 +290,32 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
Method: "Setstat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case *sshFxpExtendedPacketPosixRename:
request := &Request{
Method: "PosixRename",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
case *sshFxpExtendedPacketStatVFS:
request := &Request{
Method: "StatVFS",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
case hasHandle:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt.id(), EBADF)
} else {
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case hasPath:
request := requestFromPacket(ctx, pkt, rs.startDirectory)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
request.close()
default:
rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
Expand Down
24 changes: 12 additions & 12 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,14 @@ func (r *Request) transferError(err error) {
}

// called from worker to handle packet/request
func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
switch r.Method {
case "Get":
return fileget(handlers.FileGet, r, pkt, alloc, orderID)
return fileget(handlers.FileGet, r, pkt, alloc, orderID, maxTxPacket)
case "Put":
return fileput(handlers.FilePut, r, pkt, alloc, orderID)
return fileput(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
case "Open":
return fileputget(handlers.FilePut, r, pkt, alloc, orderID)
return fileputget(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS":
return filecmd(handlers.FileCmd, r, pkt)
case "List":
Expand Down Expand Up @@ -392,13 +392,13 @@ func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
}

// wrap FileReader handler
func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
rd := r.getReaderAt()
if rd == nil {
return statusFromError(pkt.id(), errors.New("unexpected read packet"))
}

data, offset, _ := packetData(pkt, alloc, orderID)
data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)

n, err := rd.ReadAt(data, offset)
// only return EOF error if no data left to read
Expand All @@ -414,28 +414,28 @@ func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orde
}

// wrap FileWriter handler
func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
wr := r.getWriterAt()
if wr == nil {
return statusFromError(pkt.id(), errors.New("unexpected write packet"))
}

data, offset, _ := packetData(pkt, alloc, orderID)
data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)

_, err := wr.WriteAt(data, offset)
return statusFromError(pkt.id(), err)
}

// wrap OpenFileWriter handler
func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
rw := r.getWriterAtReaderAt()
if rw == nil {
return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
}

switch p := pkt.(type) {
case *sshFxpReadPacket:
data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset)
data, offset := p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset)

n, err := rw.ReadAt(data, offset)
// only return EOF error if no data left to read
Expand All @@ -461,10 +461,10 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o
}

// file data for additional read/write packets
func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) {
func packetData(p requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) (data []byte, offset int64, length uint32) {
switch p := p.(type) {
case *sshFxpReadPacket:
return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len
return p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset), p.Len
case *sshFxpWritePacket:
return p.Data, int64(p.Offset), p.Length
}
Expand Down
22 changes: 11 additions & 11 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func TestRequestGet(t *testing.T) {
for i, txt := range []string{"file-", "data."} {
pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a",
Offset: uint64(i * 5), Len: 5}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
dpkt := rpkt.(*sshFxpDataPacket)
assert.Equal(t, dpkt.id(), uint32(i))
assert.Equal(t, string(dpkt.Data), txt)
Expand All @@ -162,7 +162,7 @@ func TestRequestCustomError(t *testing.T) {
pkt := fakePacket{myid: 1}
cmdErr := errors.New("stat not supported")
handlers.returnError(cmdErr)
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.Equal(t, rpkt, statusFromError(pkt.myid, cmdErr))
}

Expand All @@ -173,11 +173,11 @@ func TestRequestPut(t *testing.T) {
request.state.writerAt, _ = handlers.FilePut.Filewrite(request)
pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5,
Data: []byte("file-")}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
checkOkStatus(t, rpkt)
pkt = &sshFxpWritePacket{ID: 1, Handle: "a", Offset: 5, Length: 5,
Data: []byte("data.")}
rpkt = request.call(handlers, pkt, nil, 0)
rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
checkOkStatus(t, rpkt)
assert.Equal(t, "file-data.", handlers.getOutString())
}
Expand All @@ -186,19 +186,19 @@ func TestRequestCmdr(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Mkdir")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
checkOkStatus(t, rpkt)

handlers.returnError(errTest)
rpkt = request.call(handlers, pkt, nil, 0)
rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.Equal(t, rpkt, statusFromError(pkt.myid, errTest))
}

func TestRequestInfoStat(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Stat")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
spkt, ok := rpkt.(*sshFxpStatResponse)
assert.True(t, ok)
assert.Equal(t, spkt.info.Name(), "request_test.go")
Expand All @@ -215,13 +215,13 @@ func TestRequestInfoList(t *testing.T) {
assert.Equal(t, hpkt.Handle, "1")
}
pkt = fakePacket{myid: 2}
request.call(handlers, pkt, nil, 0)
request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
}
func TestRequestInfoReadlink(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Readlink")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
npkt, ok := rpkt.(*sshFxpNamePacket)
if assert.True(t, ok) {
assert.IsType(t, &sshFxpNameAttr{}, npkt.NameAttrs[0])
Expand All @@ -234,7 +234,7 @@ func TestOpendirHandleReuse(t *testing.T) {
request := testRequest("Stat")
request.handle = "1"
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.IsType(t, &sshFxpStatResponse{}, rpkt)

request.Method = "List"
Expand All @@ -244,6 +244,6 @@ func TestOpendirHandleReuse(t *testing.T) {
hpkt := rpkt.(*sshFxpHandlePacket)
assert.Equal(t, hpkt.Handle, "1")
}
rpkt = request.call(handlers, pkt, nil, 0)
rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.IsType(t, &sshFxpNamePacket{}, rpkt)
}
38 changes: 29 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Server struct {
openFilesLock sync.RWMutex
handleCount int
workDir string
maxTxPacket uint32
}

func (svr *Server) nextHandle(f *os.File) string {
Expand Down Expand Up @@ -86,6 +87,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error)
debugStream: ioutil.Discard,
pktMgr: newPktMgr(svrConn),
openFiles: make(map[string]*os.File),
maxTxPacket: defaultMaxTxPacket,
}

for _, o := range options {
Expand Down Expand Up @@ -139,6 +141,24 @@ func WithServerWorkingDirectory(workDir string) ServerOption {
}
}

// WithMaxTxPacket sets the maximum size of the payload returned to the client,
// measured in bytes. The default value is 32768 bytes, and this option
// can only be used to increase it. Setting this option to a larger value
// should be safe, because the client decides the size of the requested payload.
//
// The default maximum packet size is 32768 bytes.
func WithMaxTxPacket(size uint32) ServerOption {
return func(s *Server) error {
if size < defaultMaxTxPacket {
return errors.New("size must be greater than or equal to 32768")
}

s.maxTxPacket = size

return nil
}
}

type rxPacket struct {
pktType fxp
pktBytes []byte
Expand Down Expand Up @@ -287,7 +307,7 @@ func handlePacket(s *Server, p orderedRequest) error {
f, ok := s.getHandle(p.Handle)
if ok {
err = nil
data := p.getDataSlice(s.pktMgr.alloc, orderID)
data := p.getDataSlice(s.pktMgr.alloc, orderID, s.maxTxPacket)
n, _err := f.ReadAt(data, int64(p.Offset))
if _err != nil && (_err != io.EOF || n == 0) {
err = _err
Expand Down Expand Up @@ -513,16 +533,16 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {

fs, err := p.unmarshalFileStat(p.Flags)

if err == nil && (p.Flags & sshFileXferAttrSize) != 0 {
if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
err = os.Truncate(path, int64(fs.Size))
}
if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 {
if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
err = os.Chmod(path, fs.FileMode())
}
if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 {
if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
err = os.Chown(path, int(fs.UID), int(fs.GID))
}
if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 {
if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
}

Expand All @@ -541,16 +561,16 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {

fs, err := p.unmarshalFileStat(p.Flags)

if err == nil && (p.Flags & sshFileXferAttrSize) != 0 {
if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
err = f.Truncate(int64(fs.Size))
}
if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 {
if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
err = f.Chmod(fs.FileMode())
}
if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 {
if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
err = f.Chown(int(fs.UID), int(fs.GID))
}
if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 {
if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
type chtimer interface {
Chtimes(atime, mtime time.Time) error
}
Expand Down

0 comments on commit 5494656

Please sign in to comment.