diff --git a/easybuild/tools/environment.py b/easybuild/tools/environment.py index 07f9ef6e50..e8976edea0 100644 --- a/easybuild/tools/environment.py +++ b/easybuild/tools/environment.py @@ -32,6 +32,7 @@ """ import copy import os +import contextlib from easybuild.base import fancylogger from easybuild.tools.build_log import EasyBuildError, dry_run_msg @@ -221,3 +222,67 @@ def sanitize_env(): # unset all $PYTHON* environment variables keys_to_unset = [key for key in os.environ if key.startswith('PYTHON')] unset_env_vars(keys_to_unset, verbose=False) + + +@contextlib.contextmanager +def wrap_env(override=None, prepend=None, append=None, sep=os.pathsep, strict=False): + """This function is a context manager that temporarily modifies environment variables. + It will override or prepend/append the values of the given dictionaries to the current environment and restore the + original environment when the context is exited. + + For path-like variables, a custom separator can be specified for each variable by passing a dictionary of strings + with the same keys as prepend and append. + If a key is not present in the sep dictionary, os.pathsep will be used unless strict is True, then an error + will be raised. + + Args: + override: A dictionary of environment variables to override. + prepend: A dictionary of environment variables to prepend to. + append: A dictionary of environment variables to append to. + sep: A string or a dictionary of strings to use as separator for each variable. + strict: If True, raise an error if a key is not present in the sep dictionary. + """ + if prepend is None: + prepend = {} + if append is None: + append = {} + if override is None: + override = {} + + path_keys = set(prepend.keys()) | set(append.keys()) + over_keys = set(override.keys()) + + duplicates = path_keys & over_keys + if duplicates: + raise EasyBuildError( + "The keys in override must not overlap with the keys in prepend or append: '%s'", + " ".join(duplicates) + ) + + orig = {} + for key in over_keys: + orig[key] = os.environ.get(key) + setvar(key, override[key]) + + for key in path_keys: + if isinstance(sep, dict): + if key not in sep: + if strict: + raise EasyBuildError( + "sep must be a dictionary of strings with keys for all keys in prepend and append" + ) + _sep = os.pathsep + else: + _sep = sep.get(key) + elif isinstance(sep, str): + _sep = sep + else: + raise EasyBuildError("sep must be a string or a dictionary of strings") + val = orig[key] = os.environ.get(key) + path = _sep.join(filter(None, [prepend.get(key, None), val, append.get(key, None)])) + setvar(key, path) + + try: + yield + finally: + restore_env_vars(orig) diff --git a/test/framework/environment.py b/test/framework/environment.py index f184c6b9c7..764a707f94 100644 --- a/test/framework/environment.py +++ b/test/framework/environment.py @@ -33,6 +33,7 @@ from unittest import TextTestRunner import easybuild.tools.environment as env +from easybuild.tools.build_log import EasyBuildError class EnvironmentTest(EnhancedTestCase): @@ -161,6 +162,91 @@ def test_sanitize_env(self): self.assertEqual(os.getenv('LD_PRELOAD'), None) + def test_wrap_env(self): + """Test wrap_env function.""" + + def reset_env(): + os.environ['TEST_VAR_1'] = '/bar:/foo' + os.environ['TEST_VAR_2'] = '/bar' + os.environ['TEST_VAR_3'] = '/foo' + + def check_env(): + self.assertEqual(os.getenv('TEST_VAR_1'), '/bar:/foo') + self.assertEqual(os.getenv('TEST_VAR_2'), '/bar') + self.assertEqual(os.getenv('TEST_VAR_3'), '/foo') + + def null_and_check(vars): + for var in vars: + os.environ[var] = '' + for var in vars: + self.assertEqual(os.getenv(var), '') + + prep = { + 'TEST_VAR_1': '/usr/bin:/usr/sbin', + 'TEST_VAR_2': '/usr/bin', + } + appd = { + 'TEST_VAR_1': '/usr/local/bin', + 'TEST_VAR_3': '/usr/local/sbin', + } + over = { + 'TEST_VAR_3': 'overridden', + } + seps = { + 'TEST_VAR_3': ';' + } + + # Test prepend and append + reset_env() + check_env() + with env.wrap_env(prepend=prep, append=appd, sep=seps): + self.assertEqual(os.getenv('TEST_VAR_1'), '/usr/bin:/usr/sbin:/bar:/foo:/usr/local/bin') + self.assertEqual(os.getenv('TEST_VAR_2'), '/usr/bin:/bar') + self.assertEqual(os.getenv('TEST_VAR_3'), '/foo;/usr/local/sbin') + # Test modifying the environment inside the context + null_and_check(['TEST_VAR_1', 'TEST_VAR_2', 'TEST_VAR_3']) + check_env() + + # Test sep with strict=True + def foo(): + with env.wrap_env(prepend=prep, append=appd, sep={}, strict=True): + pass + self.assertErrorRegex(EasyBuildError, "sep must be a .*", foo) + + # Test invalid value for sep + def foo(): + with env.wrap_env(prepend=prep, append=appd, sep=None): + pass + self.assertErrorRegex(EasyBuildError, "sep must be a .*", foo) + + # Test override + check_env() + with env.wrap_env(override=prep): + self.assertEqual(os.getenv('TEST_VAR_1'), '/usr/bin:/usr/sbin') + self.assertEqual(os.getenv('TEST_VAR_2'), '/usr/bin') + self.assertEqual(os.getenv('TEST_VAR_3'), '/foo') + # Test modifying the environment inside the context + null_and_check(['TEST_VAR_1', 'TEST_VAR_2']) + check_env() + + # Test override with prepend + with env.wrap_env(override=over, prepend=prep): + self.assertEqual(os.getenv('TEST_VAR_1'), '/usr/bin:/usr/sbin:/bar:/foo') + self.assertEqual(os.getenv('TEST_VAR_2'), '/usr/bin:/bar') + self.assertEqual(os.getenv('TEST_VAR_3'), 'overridden') + null_and_check(['TEST_VAR_1', 'TEST_VAR_2', 'TEST_VAR_3']) + check_env() + + # Test override duplicate key with prepend + def foo(): + with env.wrap_env(override=prep, prepend=prep, sep=None): + pass + self.assertErrorRegex( + EasyBuildError, + "The keys in override must not overlap with the keys in prepend or append.*", + foo + ) + def suite(): """ returns all the testcases in this module """