class LockFile:
    """
    A filesystem-based lock file that is unlocked upon process termination.

    This class uses locking mechanisms provided by the kernel. Currently,
    that's `fcntl.flock`, due to its more intuitive semantics, but if
    compatibility becomes a problem, it should be possible to switch to
    `fcntl.lockf` or use platform-specific functions.

    The lock file stores the PID of the owning process and a "status" - an
    arbitrary string that the owning process can set. Format of the file is:
        PID;status

    The lock file is not removed when unlocked, nor is its content cleared -
    mere presence of a lock file does not mean that the lock is actually taken.

    Status is preserved between subsequent `acquire`s. If a process executes
    this code:
    ```
    lf = LockFile('lockfile')
    lf.acquire()
    lf.write_status('abc')
    lf.release()
    ```
    then the following code executed later by the same or a different process
    will pass without assertion error:
    ```
    lf = LockFile('lockfile')
    lf.acquire()
    assert lf.read_status() == 'abc'
    lf.release()
    ```

    This lock is not reentrant. If you already own it and try to lock it again,
    AssertionError is raised. In other words, the following code is incorrect:
    ```
    lf = LockFile('lockfile')
    lf.acquire()
    lf.acquire()  # AssertionError will be raised
    ```

    This class is not thread safe.

    LockFile also supports the context management protocol, but because
    currently it is only used in ccm to prevent concurrent downloads,
    the messages printed in __enter__ are specific to this use case.
    If this class is ever needed somewhere else, this can be changed,
    either by changing messages or the API.

    Leaving the `with` block only releases the lock; the underlying file
    stays open so the object can be re-acquired. Call `close()` when the
    object is no longer needed.

    Attributes
    ----------
    _filename: str | bytes | os.PathLike
        Path to a lockfile - used for logging
    _file: TextIO
        File handle which will be used to take a lock.
    _locked: bool
        True if lock is currently acquired by this object.
    """
    _filename: Union[str, bytes, os.PathLike]
    _file: TextIO
    _locked: bool

    def __init__(self, filename: Union[str, bytes, os.PathLike]):
        """Open (creating it if necessary) the lock file; the lock is NOT taken yet."""
        # We use append because:
        # - if a file doesn't exist, we need to create it
        # - we don't want to truncate existing file
        # - we want RW access to file
        # 'a+' is the only mode that satisfies all of this.
        self._filename = filename
        self._file = open(filename, 'a+')
        self._locked = False

    def acquire(self, blocking=True) -> 'tuple[bool, Optional[int]]':
        """Try to take the lock.

        If `blocking` is True (default), wait until the lock can be taken.
        If `blocking` is False, return immediately when the lock is held
        through another open file handle.

        Returns
        -------
        (True, None)
            The lock was taken. The status stored in the file is kept and
            the stored PID is replaced with the PID of this process.
        (False, pid)
            The lock is held elsewhere; `pid` is the PID read from the lock
            file, or None if the file content could not be parsed.

        Raises
        ------
        AssertionError
            If this object already holds the lock.
        OSError
            If `flock` fails for a reason other than lock contention.
        """
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if self._locked:
            raise AssertionError('LockFile is not reentrant: lock is already acquired by this object')

        flags = fcntl.LOCK_EX if blocking else fcntl.LOCK_EX | fcntl.LOCK_NB
        try:
            fcntl.flock(self._file, flags)
        except (BlockingIOError, PermissionError):
            # EWOULDBLOCK/EAGAIN (or EACCES) means "somebody else holds the lock".
            # Any other OSError is a real failure and must not be reported as contention.
            blocking_pid, _ = self.read_contents()
            return False, blocking_pid

        self._locked = True
        # Rewrite the file so it carries our PID while keeping the previous status.
        self.write_status(self.read_status() or '')
        return True, None

    def release(self):
        """Release the lock. Raises AssertionError if the lock is not held by this object."""
        if not self._locked:
            raise AssertionError('LockFile.release() called without holding the lock')
        fcntl.flock(self._file, fcntl.LOCK_UN)
        self._locked = False

    def close(self):
        """Release the lock if it is held and close the underlying file handle.

        The object must not be used after this call.
        """
        if self._locked:
            self.release()
        self._file.close()

    def write_status(self, new_status: str):
        """Replace the file content with `<current PID>;<new_status>`.

        Raises AssertionError if the lock is not held by this object.
        """
        if not self._locked:
            raise AssertionError('LockFile.write_status() called without holding the lock')
        self._file.seek(0, 0)
        self._file.truncate()
        self._file.write(f'{os.getpid()};{new_status}')
        # Flush so that other handles reading the file see the new content.
        self._file.flush()

    def read_contents(self) -> 'tuple[Optional[int], Optional[str]]':
        """Read the lock file and return the pair
        (pid of owning process, last status).

        Returns (None, None) when the file is empty or does not follow the
        `PID;status` format. Holding the lock is not required.
        """
        self._file.seek(0, 0)
        file_data = self._file.read()
        pid_text, separator, status = file_data.partition(';')
        if not separator:
            return None, None
        try:
            return int(pid_text), status
        except ValueError:
            return None, None

    def read_status(self) -> Optional[str]:
        """Return the status stored in the lock file, or None if it cannot be parsed."""
        return self.read_contents()[1]

    def __enter__(self) -> 'LockFile':
        """Take the lock, polling for up to 60 minutes if it is held elsewhere.

        Raises TimeoutError if the lock could not be taken within that time.
        """
        success, blocking_pid_opt = self.acquire(blocking=False)
        if success:
            return self
        print(f"Another download running into '{os.path.dirname(self._filename)}', "
              f"by process '{blocking_pid_opt}'. Waiting for parallel downloading to finish. "
              f"If process '{blocking_pid_opt}' got stuck, kill it in order to continue the operation here.")

        if not wait_for(func=lambda: self.acquire(blocking=False)[0], timeout=3600):
            raise TimeoutError(f"Relocatables download still runs in parallel from another test after 60 min. "
                               f"Placeholder file still locked: {self._filename}")
        return self

    def __exit__(self, *args):
        """Release the lock; the file handle stays open (see `close`)."""
        self.release()


def _validate_install_dir_layout(install_dir):
    """Raise ArgumentError unless `install_dir` contains the expected bin/conf layout."""
    # Windows requires absolute pathing on installation dir - abort if specified cygwin style
    if is_win():
        if ':' not in install_dir:
            raise ArgumentError(f'{install_dir} does not appear to be a cassandra or dse installation directory. Please use absolute pathing (e.g. C:/cassandra.')

    bin_dir = os.path.join(install_dir, BIN_DIR)
    if isScylla(install_dir):
        install_dir, _mode = scylla_extract_install_dir_and_mode(install_dir)
        bin_dir = install_dir
        conf_dir = os.path.join(install_dir, SCYLLA_CONF_DIR)
    elif isDse(install_dir):
        conf_dir = os.path.join(install_dir, DSE_CASSANDRA_CONF_DIR)
    elif isOpscenter(install_dir):
        conf_dir = os.path.join(install_dir, OPSCENTER_CONF_DIR)
    else:
        conf_dir = os.path.join(install_dir, CASSANDRA_CONF_DIR)

    cnd = os.path.exists(bin_dir)
    cnd = cnd and os.path.exists(conf_dir)
    if isScylla(install_dir):
        # NOTE(review): this overwrites the bin_dir/conf_dir result instead of
        # and-ing with it - kept as in the original; confirm it is intentional.
        cnd = os.path.exists(os.path.join(conf_dir, SCYLLA_CONF))
    elif not isOpscenter(install_dir):
        cnd = cnd and os.path.exists(os.path.join(conf_dir, CASSANDRA_CONF))
    if not cnd:
        raise ArgumentError(f'{install_dir} does not appear to be a cassandra or dse installation directory')


def validate_install_dir(install_dir):
    """Check that `install_dir` looks like a usable installation directory.

    The check runs while holding the download lock file inside `install_dir`,
    so it waits (up to the LockFile timeout) while that lock is held elsewhere.

    Raises
    ------
    ArgumentError
        If `install_dir` is None, is not an existing directory, or does not
        contain the expected bin/conf layout.
    TimeoutError
        If the download lock could not be taken in time.
    """
    if install_dir is None:
        raise ArgumentError('Undefined installation directory')

    # LockFile creates the lock file inside install_dir; without this guard a
    # missing directory would surface as a raw FileNotFoundError from open()
    # instead of the ArgumentError callers expect.
    if not os.path.isdir(install_dir):
        raise ArgumentError(f'{install_dir} does not appear to be a cassandra or dse installation directory')

    # If a relocatables download is running in parallel from another test, it holds
    # the lock on this file; taking the lock here waits until it is released.
    lock = LockFile(os.path.join(install_dir, DOWNLOAD_IN_PROGRESS_FILE))
    try:
        with lock:
            _validate_install_dir_layout(install_dir)
    finally:
        lock.close()
ArgumentError, CCMError, get_default_path, rmdirs, validate_install_dir, get_scylla_version, aws_bucket_ls, - DOWNLOAD_IN_PROGRESS_FILE, wait_for_parallel_download_finish, print_if_standalone) + DOWNLOAD_IN_PROGRESS_FILE, print_if_standalone, LockFile) from ccmlib.utils.download import download_file, download_version_from_s3 from ccmlib.utils.version import parse_version @@ -279,19 +279,15 @@ def setup(version, verbose=True): # Give a chance not to start few downloads in the exactly same second time.sleep(random.randint(0, 5)) - # If another parallel downloading has been started already, wait while it will be completed - if download_in_progress_file.exists(): - print(f"Another download running into '{version_dir}'. Waiting for parallel downloading finished") - wait_for_parallel_download_finish(placeholder_file=download_in_progress_file.absolute()) - else: - try: - os.makedirs(version_dir) - except FileExistsError as exc: - # If parallel process created the folder first, let to the parallel download to finish - print(f"Another download running into '{version_dir}'. Waiting for parallel downloading finished") - wait_for_parallel_download_finish(placeholder_file=download_in_progress_file.absolute()) - else: - download_in_progress_file.touch() + os.makedirs(version_dir, exist_ok=True) + with LockFile(download_in_progress_file) as f: + if f.read_status() != 'done': + # First ensure that we are working on a clean directory + # This prevents lockfile deletion by download_packages, as it doesn't have to clean the directory. 
+ for p in Path(version_dir).iterdir(): + if p.name != DOWNLOAD_IN_PROGRESS_FILE: + shutil.rmtree(p) + try: package_version, packages = download_packages(version_dir=version_dir, packages=packages, s3_url=s3_url, scylla_product=scylla_product, version=version, verbose=verbose) @@ -308,7 +304,6 @@ def setup(version, verbose=True): else: raise - download_in_progress_file.touch() args = dict(install_dir=os.path.join(version_dir, CORE_PACKAGE_DIR_NAME), target_dir=version_dir, package_version=package_version) @@ -319,7 +314,7 @@ def setup(version, verbose=True): else: run_scylla_install_script(**args) print(f"Completed to install Scylla in the folder '{version_dir}'") - download_in_progress_file.unlink() + f.write_status('done') scylla_ext_opts = os.environ.get('SCYLLA_EXT_OPTS', '') scylla_manager_package = os.environ.get('SCYLLA_MANAGER_PACKAGE') @@ -349,9 +344,8 @@ def download_packages(version_dir, packages, s3_url, scylla_product, version, ve if packages.scylla_unified_package: package_version = download_version(version=version, verbose=verbose, url=packages.scylla_unified_package, target_dir=tmp_download, unified=True) - shutil.rmtree(version_dir) target_dir = Path(version_dir) / CORE_PACKAGE_DIR_NAME - target_dir.parent.mkdir(parents=True, exist_ok=True) + target_dir.parent.mkdir(parents=False, exist_ok=True) shutil.move(tmp_download, target_dir) else: package_version = download_version(version=version, verbose=verbose, url=packages.scylla_package, @@ -363,7 +357,6 @@ def download_packages(version_dir, packages, s3_url, scylla_product, version, ve download_version(version=version, verbose=verbose, url=packages.scylla_jmx_package, target_dir=os.path.join(tmp_download, 'scylla-jmx')) - shutil.rmtree(version_dir) shutil.move(tmp_download, version_dir) return package_version, packages diff --git a/tests/test_common.py b/tests/test_common.py index 0af428fa..0a1c45d7 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,4 +1,6 @@ -from ccmlib.common 
# These tests assume that LockFile uses fcntl.flock.
# If it switches to anything else, the tests need to be adjusted.

def _make_lockfile_path():
    """Create an empty temporary file and return its path.

    The descriptor returned by mkstemp is closed right away: LockFile opens
    the path on its own, so keeping it would only leak a file descriptor.
    """
    fd, path = tempfile.mkstemp(prefix='ccm-test-lockfile')
    os.close(fd)
    return path


def test_lockfile_basic():
    """Acquiring stores our PID with an empty status; write_status replaces the status."""
    path = _make_lockfile_path()
    try:
        lf = LockFile(path)
        assert lf.acquire(blocking=True) == (True, None)

        assert lf.read_contents() == (os.getpid(), '')
        lf.write_status('abc')
        assert lf.read_contents() == (os.getpid(), 'abc')

        lf.release()
    finally:
        os.remove(path)


def test_lockfile_locks():
    """Two LockFile objects on the same path exclude each other, even within one process."""
    path = _make_lockfile_path()
    try:
        lf1 = LockFile(path)
        lf2 = LockFile(path)
        with lf1:
            assert lf2.acquire(blocking=False) == (False, os.getpid())
        # Leaving the `with` block released lf1, so lf2 can take the lock now.
        assert lf2.acquire(blocking=False) == (True, None)
        assert lf1.acquire(blocking=False) == (False, os.getpid())
        lf2.release()
    finally:
        os.remove(path)


def test_lockfile_retain_status_by_default():
    """The status written under one acquisition is still there after re-acquiring."""
    path = _make_lockfile_path()
    try:
        lf = LockFile(path)

        assert lf.acquire(blocking=False)[0] is True
        lf.write_status('some_status_1')
        assert lf.read_status() == 'some_status_1'
        lf.release()

        # Status should be retained from previous lock.
        assert lf.acquire(blocking=False)[0] is True
        assert lf.read_status() == 'some_status_1'
        lf.release()
    finally:
        os.remove(path)