diff --git a/devlib/target.py b/devlib/target.py index 9f2a5c128..19f8a7281 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -668,10 +668,22 @@ async def _prepare_xfer(self, action, sources, dest, pattern=None, as_root=False transfering multiple sources. """ - once = functools.lru_cache(maxsize=None) + def once(f): + cache = dict() + + @functools.wraps(f) + async def wrapper(path): + try: + return cache[path] + except KeyError: + x = await f(path) + cache[path] = x + return x + + return wrapper _target_cache = {} - def target_paths_kind(paths, as_root=False): + async def target_paths_kind(paths, as_root=False): def process(x): x = x.strip() if x == 'notexist': @@ -691,7 +703,7 @@ def process(x): ) for path in _paths ) - res = self.execute(cmd, as_root=as_root) + res = await self.execute.asyn(cmd, as_root=as_root) _target_cache.update(zip(_paths, map(process, res.split()))) return [ @@ -700,7 +712,7 @@ def process(x): ] _host_cache = {} - def host_paths_kind(paths, as_root=False): + async def host_paths_kind(paths, as_root=False): def path_kind(path): if os.path.isdir(path): return 'dir' @@ -727,47 +739,55 @@ def path_kind(path): src_excep = HostError src_path_kind = host_paths_kind - _dst_mkdir = once(self.makedirs) + _dst_mkdir = once(self.makedirs.asyn) dst_path_join = self.path.join dst_paths_kind = target_paths_kind - dst_remove_file = once(functools.partial(self.remove, as_root=as_root)) + + @once + async def dst_remove_file(path): + return await self.remove.asyn(path, as_root=as_root) elif action == 'pull': src_excep = TargetStableError src_path_kind = target_paths_kind - _dst_mkdir = once(functools.partial(os.makedirs, exist_ok=True)) + @once + async def _dst_mkdir(path): + return os.makedirs(path, exist_ok=True) dst_path_join = os.path.join dst_paths_kind = host_paths_kind - dst_remove_file = once(os.remove) + + @once + async def dst_remove_file(path): + return os.remove(path) else: raise ValueError('Unknown action "{}"'.format(action)) # Handle the case where path is None - def dst_mkdir(path): + async def dst_mkdir(path): if path: - _dst_mkdir(path) + await _dst_mkdir(path) - def rewrite_dst(src, dst): + async def rewrite_dst(src, dst): new_dst = dst_path_join(dst, os.path.basename(src)) - src_kind, = src_path_kind([src], as_root) + src_kind, = await src_path_kind([src], as_root) # Batch both checks to avoid a costly extra execute() - dst_kind, new_dst_kind = dst_paths_kind([dst, new_dst], as_root) + dst_kind, new_dst_kind = await dst_paths_kind([dst, new_dst], as_root) if src_kind == 'file': if dst_kind == 'dir': if new_dst_kind == 'dir': raise IsADirectoryError(new_dst) if new_dst_kind == 'file': - dst_remove_file(new_dst) + await dst_remove_file(new_dst) return new_dst else: return new_dst elif dst_kind == 'file': - dst_remove_file(dst) + await dst_remove_file(dst) return dst else: - dst_mkdir(os.path.dirname(dst)) + await dst_mkdir(os.path.dirname(dst)) return dst elif src_kind == 'dir': if dst_kind == 'dir': @@ -781,7 +801,7 @@ def rewrite_dst(src, dst): elif dst_kind == 'file': raise FileExistsError(dst_kind) else: - dst_mkdir(os.path.dirname(dst)) + await dst_mkdir(os.path.dirname(dst)) return dst else: raise FileNotFoundError(src) @@ -790,18 +810,19 @@ def rewrite_dst(src, dst): if not sources: raise src_excep('No file matching source pattern: {}'.format(pattern)) - if dst_paths_kind([dest]) != ['dir']: + if (await dst_paths_kind([dest])) != ['dir']: raise NotADirectoryError('A folder dest is required for multiple matches but destination is a file: {}'.format(dest)) + async def f(src): + return await rewrite_dst(src, dest) + mapping = await self.async_manager.map_concurrently(f, sources) + # TODO: since rewrite_dst() will currently return a different path for # each source, it will not bring anything. In order to be useful, # connections need to be able to understand that if the destination is # an empty folder, the source is supposed to be transfered into it with # the same basename. - return groupby_value({ - src: rewrite_dst(src, dest) - for src in sources - }) + return groupby_value(mapping) @asyn.asyncf @call_conn @@ -824,10 +845,11 @@ def do_push(sources, dest): if as_root: for sources, dest in mapping.items(): - for source in sources: + async def f(source): async with self._xfer_cache_path(source) as device_tempfile: do_push([source], device_tempfile) await self.execute.asyn("mv -f -- {} {}".format(quote(device_tempfile), quote(dest)), as_root=True) + await self.async_manager.map_concurrently(f, sources) else: for sources, dest in mapping.items(): do_push(sources, dest) @@ -902,11 +924,13 @@ def do_pull(sources, dest): if via_temp: for sources, dest in mapping.items(): - for source in sources: + async def f(source): async with self._xfer_cache_path(source) as device_tempfile: - await self.execute.asyn("cp -r -- {} {}".format(quote(source), quote(device_tempfile)), as_root=as_root) - await self.execute.asyn("{} chmod 0644 -- {}".format(self.busybox, quote(device_tempfile)), as_root=as_root) + cp_cmd = f"{quote(self.busybox)} cp -rL -- {quote(source)} {quote(device_tempfile)}" + chmod_cmd = f"{quote(self.busybox)} chmod 0644 -- {quote(device_tempfile)}" + await self.execute.asyn(f"{cp_cmd} && {chmod_cmd}", as_root=as_root) do_pull([device_tempfile], dest) + await self.async_manager.map_concurrently(f, sources) else: for sources, dest in mapping.items(): do_pull(sources, dest) diff --git a/devlib/utils/ssh.py b/devlib/utils/ssh.py index 9fe4c613a..81aa405c4 100644 --- a/devlib/utils/ssh.py +++ b/devlib/utils/ssh.py @@ -499,7 +499,18 @@ def _push_path(self, sftp, src, dst, callback=None): push(sftp, src, dst, callback) def _pull_file(self, sftp, src, dst, callback): - sftp.get(src, dst, callback=callback) + try: + sftp.get(src, dst, callback=callback) + except Exception as e: + # A file may have been created by Paramiko, but we want to clean + # that up, particularly if we tried to pull a folder and failed, + # otherwise this will make subsequent attempts at pulling the + # folder fail since the destination will exist. + try: + os.remove(dst) + except Exception: + pass + raise e def _pull_folder(self, sftp, src, dst, callback): os.makedirs(dst)