Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ssh folder pull #707

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 50 additions & 26 deletions devlib/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,10 +668,22 @@ async def _prepare_xfer(self, action, sources, dest, pattern=None, as_root=False
transfering multiple sources.
"""

once = functools.lru_cache(maxsize=None)
def once(f):
cache = dict()

@functools.wraps(f)
async def wrapper(path):
try:
return cache[path]
except KeyError:
x = await f(path)
cache[path] = x
return x

return wrapper

_target_cache = {}
def target_paths_kind(paths, as_root=False):
async def target_paths_kind(paths, as_root=False):
def process(x):
x = x.strip()
if x == 'notexist':
Expand All @@ -691,7 +703,7 @@ def process(x):
)
for path in _paths
)
res = self.execute(cmd, as_root=as_root)
res = await self.execute.asyn(cmd, as_root=as_root)
_target_cache.update(zip(_paths, map(process, res.split())))

return [
Expand All @@ -700,7 +712,7 @@ def process(x):
]

_host_cache = {}
def host_paths_kind(paths, as_root=False):
async def host_paths_kind(paths, as_root=False):
def path_kind(path):
if os.path.isdir(path):
return 'dir'
Expand All @@ -727,47 +739,55 @@ def path_kind(path):
src_excep = HostError
src_path_kind = host_paths_kind

_dst_mkdir = once(self.makedirs)
_dst_mkdir = once(self.makedirs.asyn)
dst_path_join = self.path.join
dst_paths_kind = target_paths_kind
dst_remove_file = once(functools.partial(self.remove, as_root=as_root))

@once
async def dst_remove_file(path):
return await self.remove.asyn(path, as_root=as_root)
elif action == 'pull':
src_excep = TargetStableError
src_path_kind = target_paths_kind

_dst_mkdir = once(functools.partial(os.makedirs, exist_ok=True))
@once
async def _dst_mkdir(path):
return os.makedirs(path, exist_ok=True)
dst_path_join = os.path.join
dst_paths_kind = host_paths_kind
dst_remove_file = once(os.remove)

@once
async def dst_remove_file(path):
return os.remove(path)
else:
raise ValueError('Unknown action "{}"'.format(action))

# Handle the case where path is None
def dst_mkdir(path):
async def dst_mkdir(path):
if path:
_dst_mkdir(path)
await _dst_mkdir(path)

def rewrite_dst(src, dst):
async def rewrite_dst(src, dst):
new_dst = dst_path_join(dst, os.path.basename(src))

src_kind, = src_path_kind([src], as_root)
src_kind, = await src_path_kind([src], as_root)
# Batch both checks to avoid a costly extra execute()
dst_kind, new_dst_kind = dst_paths_kind([dst, new_dst], as_root)
dst_kind, new_dst_kind = await dst_paths_kind([dst, new_dst], as_root)

if src_kind == 'file':
if dst_kind == 'dir':
if new_dst_kind == 'dir':
raise IsADirectoryError(new_dst)
if new_dst_kind == 'file':
dst_remove_file(new_dst)
await dst_remove_file(new_dst)
return new_dst
else:
return new_dst
elif dst_kind == 'file':
dst_remove_file(dst)
await dst_remove_file(dst)
return dst
else:
dst_mkdir(os.path.dirname(dst))
await dst_mkdir(os.path.dirname(dst))
return dst
elif src_kind == 'dir':
if dst_kind == 'dir':
Expand All @@ -781,7 +801,7 @@ def rewrite_dst(src, dst):
elif dst_kind == 'file':
raise FileExistsError(dst_kind)
else:
dst_mkdir(os.path.dirname(dst))
await dst_mkdir(os.path.dirname(dst))
return dst
else:
raise FileNotFoundError(src)
Expand All @@ -790,18 +810,19 @@ def rewrite_dst(src, dst):
if not sources:
raise src_excep('No file matching source pattern: {}'.format(pattern))

if dst_paths_kind([dest]) != ['dir']:
if (await dst_paths_kind([dest])) != ['dir']:
raise NotADirectoryError('A folder dest is required for multiple matches but destination is a file: {}'.format(dest))

async def f(src):
return await rewrite_dst(src, dest)
mapping = await self.async_manager.map_concurrently(f, sources)

# TODO: since rewrite_dst() will currently return a different path for
# each source, it will not bring anything. In order to be useful,
# connections need to be able to understand that if the destination is
# an empty folder, the source is supposed to be transfered into it with
# the same basename.
return groupby_value({
src: rewrite_dst(src, dest)
for src in sources
})
return groupby_value(mapping)

@asyn.asyncf
@call_conn
Expand All @@ -824,10 +845,11 @@ def do_push(sources, dest):

if as_root:
for sources, dest in mapping.items():
for source in sources:
async def f(source):
async with self._xfer_cache_path(source) as device_tempfile:
do_push([source], device_tempfile)
await self.execute.asyn("mv -f -- {} {}".format(quote(device_tempfile), quote(dest)), as_root=True)
await self.async_manager.map_concurrently(f, sources)
else:
for sources, dest in mapping.items():
do_push(sources, dest)
Expand Down Expand Up @@ -902,11 +924,13 @@ def do_pull(sources, dest):

if via_temp:
for sources, dest in mapping.items():
for source in sources:
async def f(source):
async with self._xfer_cache_path(source) as device_tempfile:
await self.execute.asyn("cp -r -- {} {}".format(quote(source), quote(device_tempfile)), as_root=as_root)
await self.execute.asyn("{} chmod 0644 -- {}".format(self.busybox, quote(device_tempfile)), as_root=as_root)
cp_cmd = f"{quote(self.busybox)} cp -rL -- {quote(source)} {quote(device_tempfile)}"
chmod_cmd = f"{quote(self.busybox)} chmod 0644 -- {quote(device_tempfile)}"
await self.execute.asyn(f"{cp_cmd} && {chmod_cmd}", as_root=as_root)
do_pull([device_tempfile], dest)
await self.async_manager.map_concurrently(f, sources)
else:
for sources, dest in mapping.items():
do_pull(sources, dest)
Expand Down
13 changes: 12 additions & 1 deletion devlib/utils/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,18 @@ def _push_path(self, sftp, src, dst, callback=None):
push(sftp, src, dst, callback)

def _pull_file(self, sftp, src, dst, callback):
sftp.get(src, dst, callback=callback)
try:
sftp.get(src, dst, callback=callback)
except Exception as e:
# A file may have been created by Paramiko, but we want to clean
# that up, particularly if we tried to pull a folder and failed,
# otherwise this will make subsequent attempts at pulling the
# folder fail since the destination will exist.
try:
os.remove(dst)
except Exception:
pass
raise e

def _pull_folder(self, sftp, src, dst, callback):
os.makedirs(dst)
Expand Down