Skip to content

Commit

Permalink
refactor getPacket to separate per-response-packet-types that even un…
Browse files Browse the repository at this point in the history
…pack the relevant field that we're actually interested in
  • Loading branch information
puellanivis committed Nov 15, 2024
1 parent 6366220 commit 98fad75
Showing 1 changed file with 71 additions and 35 deletions.
106 changes: 71 additions & 35 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,14 +383,14 @@ type respPacket[PKT any] interface {
sshfx.Packet
}

func getPacket[PKT any, P respPacket[PKT]](ctx context.Context, cancel <-chan struct{}, cl *Client, req sshfx.PacketMarshaller) (*PKT, error) {
func getPacket[RESP respPacket[PKT], PKT any](ctx context.Context, cancel <-chan struct{}, cl *Client, req sshfx.PacketMarshaller) (*PKT, error) {
raw, err := cl.conn.send(ctx, cancel, req)
if err != nil {
return nil, err
}
defer cl.conn.returnRaw(raw)

var resp P
var resp RESP

switch raw.PacketType {
case resp.Type():
Expand All @@ -414,6 +414,38 @@ func getPacket[PKT any, P respPacket[PKT]](ctx context.Context, cancel <-chan st
}
}

func (cl *Client) getPath(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (string, error) {
resp, err := getPacket[*sshfx.PathPseudoPacket](ctx, cancel, cl, req)
if err != nil {
return "", err
}
return resp.Path, nil
}

func (cl *Client) getHandle(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (string, error) {
resp, err := getPacket[*sshfx.HandlePacket](ctx, cancel, cl, req)
if err != nil {
return "", err
}
return resp.Handle, nil
}

func (cl *Client) getNames(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) ([]*sshfx.NameEntry, error) {
resp, err := getPacket[*sshfx.NamePacket](ctx, cancel, cl, req)
if err != nil {
return nil, err
}
return resp.Entries, nil
}

func (cl *Client) getAttrs(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (*sshfx.Attributes, error) {
resp, err := getPacket[*sshfx.AttrsPacket](ctx, cancel, cl, req)
if err != nil {
return nil, err
}
return &resp.Attrs, nil
}

func statusToError(status *sshfx.StatusPacket, okExpected bool) error {
switch status.StatusCode {
case sshfx.StatusOK:
Expand Down Expand Up @@ -689,35 +721,39 @@ func (cl *Client) MkdirAll(name string, perm fs.FileMode) error {
func (cl *Client) Remove(name string) error {
ctx := context.Background()

err := cl.sendPacket(ctx, nil, &sshfx.RemovePacket{
errFile := cl.sendPacket(ctx, nil, &sshfx.RemovePacket{
Path: name,
})
if err == nil {
if errFile == nil {
return nil
}

err1 := cl.sendPacket(ctx, nil, &sshfx.RmdirPacket{
errDir := cl.sendPacket(ctx, nil, &sshfx.RmdirPacket{
Path: name,
})
if err1 == nil {
if errDir == nil {
return nil
}

// Both failed: figure out which error to return.
if err != err1 {
attrs, err2 := getPacket[sshfx.AttrsPacket](ctx, nil, cl, &sshfx.StatPacket{
Path: name,
})
if err2 != nil {
err = err2
} else {
if perm, ok := attrs.Attrs.GetPermissions(); ok && perm.IsDir() {
err = err1
}
}

if errFile == errDir {
// If they are the same error, then just return that.
return wrapPathError("remove", name, errFile)
}

attrs, err := cl.getAttrs(ctx, nil, &sshfx.StatPacket{
Path: name,
})
if err != nil {
return wrapPathError("remove", name, err)
}

if perm, ok := attrs.GetPermissions(); ok && perm.IsDir() {
return wrapPathError("remove", name, errDir)
}

return wrapPathError("remove", name, err)
return wrapPathError("remove", name, errFile)
}

func (cl *Client) setstat(ctx context.Context, name string, attrs *sshfx.Attributes) error {
Expand Down Expand Up @@ -788,29 +824,29 @@ func (cl *Client) Chtimes(name string, atime, mtime time.Time) error {
// This is useful for converting path names containing ".." components,
// or relative pathnames without a leading slash into absolute paths.
func (cl *Client) RealPath(name string) (string, error) {
pkt, err := getPacket[sshfx.PathPseudoPacket](context.Background(), nil, cl, &sshfx.RealPathPacket{
path, err := cl.getPath(context.Background(), nil, &sshfx.RealPathPacket{
Path: name,
})
if err != nil {
return "", wrapPathError("realpath", name, err)
}

return pkt.Path, nil
return path, nil
}

// ReadLink returns the destination of the named symbolic link.
//
// The client cannot guarantee any specific way that a server handles a relative link destination.
// That is, you may receive a relative link destination, one that has been converted to an absolute path.
func (cl *Client) ReadLink(name string) (string, error) {
pkt, err := getPacket[sshfx.PathPseudoPacket](context.Background(), nil, cl, &sshfx.ReadLinkPacket{
path, err := cl.getPath(context.Background(), nil, &sshfx.ReadLinkPacket{
Path: name,
})
if err != nil {
return "", wrapPathError("readlink", name, err)
}

return pkt.Path, nil
return path, nil
}

// Rename renames (moves) oldpath to newpath.
Expand Down Expand Up @@ -916,7 +952,7 @@ func (cl *Client) ReadDirContext(ctx context.Context, name string) ([]fs.DirEntr
// Stat returns a FileInfo describing the named file.
// If the file is a symbolic link, the returned FileInfo describes the link's target.
func (cl *Client) Stat(name string) (fs.FileInfo, error) {
pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), nil, cl, &sshfx.StatPacket{
attrs, err := cl.getAttrs(context.Background(), nil, &sshfx.StatPacket{
Path: name,
})
if err != nil {
Expand All @@ -925,7 +961,7 @@ func (cl *Client) Stat(name string) (fs.FileInfo, error) {

return &sshfx.NameEntry{
Filename: name,
Attrs: pkt.Attrs,
Attrs: *attrs,
}, nil
}

Expand All @@ -935,7 +971,7 @@ func (cl *Client) Stat(name string) (fs.FileInfo, error) {
//
// The description returned may have server specific caveats and special cases that cannot be covered here.
func (cl *Client) LStat(name string) (fs.FileInfo, error) {
pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), nil, cl, &sshfx.LStatPacket{
attrs, err := cl.getAttrs(context.Background(), nil, &sshfx.LStatPacket{
Path: name,
})
if err != nil {
Expand All @@ -944,7 +980,7 @@ func (cl *Client) LStat(name string) (fs.FileInfo, error) {

return &sshfx.NameEntry{
Filename: name,
Attrs: pkt.Attrs,
Attrs: *attrs,
}, nil
}

Expand Down Expand Up @@ -1022,7 +1058,7 @@ type Dir struct {
//
// The semantics of SSH_FX_OPENDIR is such that the associated file handle is in a read-only mode.
func (cl *Client) OpenDir(name string) (*Dir, error) {
pkt, err := getPacket[sshfx.HandlePacket](context.Background(), nil, cl, &sshfx.OpenDirPacket{
handle, err := cl.getHandle(context.Background(), nil, &sshfx.OpenDirPacket{
Path: name,
})
if err != nil {
Expand All @@ -1034,7 +1070,7 @@ func (cl *Client) OpenDir(name string) (*Dir, error) {
name: name,
}

d.handle.init(pkt.Handle)
d.handle.init(handle)

return d, nil
}
Expand Down Expand Up @@ -1085,7 +1121,7 @@ func (d *Dir) rangedir(ctx context.Context) iter.Seq2[*sshfx.NameEntry, error] {
return
}

pkt, err := getPacket[sshfx.NamePacket](ctx, closed, d.cl, &sshfx.ReadDirPacket{
entries, err := d.cl.getNames(ctx, closed, &sshfx.ReadDirPacket{
Handle: handle,
})
if err != nil {
Expand All @@ -1095,10 +1131,10 @@ func (d *Dir) rangedir(ctx context.Context) iter.Seq2[*sshfx.NameEntry, error] {
return
}

for i, entry := range pkt.Entries {
for i, entry := range entries {
if !yield(entry, nil) {
// Early break, save the remaining entries we got for maybe later.
d.entries = append(d.entries, pkt.Entries[i+1:]...)
d.entries = append(d.entries, entries[i+1:]...)
return
}
}
Expand Down Expand Up @@ -1276,7 +1312,7 @@ func (cl *Client) Create(name string) (*File, error) {
// Note well: since all Write operations are down through an offset-specifying operation,
// the OpenFlagAppend flag is currently ignored.
func (cl *Client) OpenFile(name string, flag int, perm fs.FileMode) (*File, error) {
pkt, err := getPacket[sshfx.HandlePacket](context.Background(), nil, cl, &sshfx.OpenPacket{
handle, err := cl.getHandle(context.Background(), nil, &sshfx.OpenPacket{
Filename: name,
PFlags: toPortableFlags(flag),
Attrs: sshfx.Attributes{
Expand All @@ -1293,7 +1329,7 @@ func (cl *Client) OpenFile(name string, flag int, perm fs.FileMode) (*File, erro
name: name,
}

f.handle.init(pkt.Handle)
f.handle.init(handle)

return f, nil
}
Expand Down Expand Up @@ -1391,7 +1427,7 @@ func (f *File) Stat() (fs.FileInfo, error) {
return nil, f.wrapErr("fstat", err)
}

pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), closed, f.cl, &sshfx.FStatPacket{
attrs, err := f.cl.getAttrs(context.Background(), closed, &sshfx.FStatPacket{
Handle: handle,
})
if err != nil {
Expand All @@ -1400,7 +1436,7 @@ func (f *File) Stat() (fs.FileInfo, error) {

return &sshfx.NameEntry{
Filename: f.name,
Attrs: pkt.Attrs,
Attrs: *attrs,
}, nil
}

Expand Down

0 comments on commit 98fad75

Please sign in to comment.