diff --git a/tests/framework/microvm.py b/tests/framework/microvm.py index 493e3ea02b5c..d5b2cf500616 100644 --- a/tests/framework/microvm.py +++ b/tests/framework/microvm.py @@ -800,7 +800,7 @@ def restore_from_path(self, snap_dir: Path, **kwargs): return self.restore_from_snapshot(Snapshot.load_from(snap_dir), **kwargs) @lru_cache - def ssh_iface(self, iface_idx=0): + def ssh_iface(self, iface_idx=0, timeout=None): """Return a cached SSH connection on a given interface id.""" guest_ip = list(self.iface.values())[iface_idx]["iface"].guest_ip self.ssh_key = Path(self.ssh_key) @@ -812,9 +812,9 @@ def ssh_iface(self, iface_idx=0): ) @property - def ssh(self): + def ssh(self, timeout=None): """Return a cached SSH connection on the 1st interface""" - return self.ssh_iface(0) + return self.ssh_iface(0, timeout) class MicroVMFactory: diff --git a/tests/framework/utils.py b/tests/framework/utils.py index a6d796135131..2908081ed70a 100644 --- a/tests/framework/utils.py +++ b/tests/framework/utils.py @@ -448,7 +448,7 @@ def get_free_mem_ssh(ssh_connection): raise Exception("Available memory not found in `/proc/meminfo") -def run_cmd_sync(cmd, ignore_return_code=False, no_shell=False, cwd=None): +def run_cmd_sync(cmd, ignore_return_code=False, no_shell=False, cwd=None, timeout=None): """ Execute a given command. @@ -469,7 +469,7 @@ def run_cmd_sync(cmd, ignore_return_code=False, no_shell=False, cwd=None): ) # Capture stdout/stderr - stdout, stderr = proc.communicate() + stdout, stderr = proc.communicate(timeout=timeout) output_message = f"\n[{proc.pid}] Command:\n{cmd}" # Append stdout/stderr to the output message @@ -493,7 +493,7 @@ def run_cmd_sync(cmd, ignore_return_code=False, no_shell=False, cwd=None): return CommandReturn(proc.returncode, stdout.decode(), stderr.decode()) -def run_cmd(cmd, ignore_return_code=False, no_shell=False, cwd=None): +def run_cmd(cmd, ignore_return_code=False, no_shell=False, cwd=None, timeout=None): """ Run a command using the sync function that logs the output. @@ -503,7 +503,11 @@ def run_cmd(cmd, ignore_return_code=False, no_shell=False, cwd=None): :returns: tuple of (return code, stdout, stderr) """ return run_cmd_sync( - cmd=cmd, ignore_return_code=ignore_return_code, no_shell=no_shell, cwd=cwd + cmd=cmd, + ignore_return_code=ignore_return_code, + no_shell=no_shell, + cwd=cwd, + timeout=timeout, ) diff --git a/tests/host_tools/network.py b/tests/host_tools/network.py index 3ef2332e748c..b1a965af4c90 100644 --- a/tests/host_tools/network.py +++ b/tests/host_tools/network.py @@ -89,7 +89,7 @@ def _init_connection(self): if ecode != 0: raise ConnectionError - def run(self, cmd_string): + def run(self, cmd_string, timeout=None): """Execute the command passed as a string in the ssh context.""" return self._exec( [ @@ -97,10 +97,11 @@ def run(self, cmd_string): *self.options, f"{self.user}@{self.host}", cmd_string, - ] + ], + timeout ) - def _exec(self, cmd): + def _exec(self, cmd, timeout=None): """Private function that handles the ssh client invocation.""" # TODO: If a microvm runs in a particular network namespace, we have to @@ -111,7 +112,7 @@ def _exec(self, cmd): if self.netns_file_path is not None: ctx = Namespace(self.netns_file_path, "net") with ctx: - return utils.run_cmd(cmd, ignore_return_code=True) + return utils.run_cmd(cmd, ignore_return_code=True, timeout=timeout) def mac_from_ip(ip_address): diff --git a/tests/integration_tests/functional/test_balloon.py b/tests/integration_tests/functional/test_balloon.py index 517cededaee5..2f65a7c92cba 100644 --- a/tests/integration_tests/functional/test_balloon.py +++ b/tests/integration_tests/functional/test_balloon.py @@ -3,6 +3,7 @@ """Tests for guest-side operations on /balloon resources.""" import logging +from subprocess import TimeoutExpired import time import pytest @@ -69,14 +70,19 @@ def make_guest_dirty_memory(ssh_connection, amount_mib=32): # so that we avoid having the SSH connections hanging due to the OOM # killer kicking in. cmd = f"/usr/local/bin/fillmem {amount_mib} &" - exit_code, stdout, stderr = ssh_connection.run(cmd) - # add something to the logs for troubleshooting - if exit_code != 0: - logger.error("while running: %s", cmd) - logger.error("stdout: %s", stdout) - logger.error("stderr: %s", stderr) + try: + exit_code, stdout, stderr = ssh_connection.run(cmd, timeout=1.0) + # add something to the logs for troubleshooting + if exit_code != 0: + logger.error("while running: %s", cmd) + logger.error("stdout: %s", stdout) + logger.error("stderr: %s", stderr) + + cmd = "cat /tmp/fillmem_output.txt" + except TimeoutExpired: + # It's ok if this expires + pass - cmd = "cat /tmp/fillmem_output.txt" time.sleep(5)