Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to modify path-like environment variable in a context #4681

Open
wants to merge 8 commits into
base: 5.0.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions easybuild/tools/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""
import copy
import os
import contextlib

from easybuild.base import fancylogger
from easybuild.tools.build_log import EasyBuildError, dry_run_msg
Expand Down Expand Up @@ -221,3 +222,67 @@ def sanitize_env():
# unset all $PYTHON* environment variables
keys_to_unset = [key for key in os.environ if key.startswith('PYTHON')]
unset_env_vars(keys_to_unset, verbose=False)


@contextlib.contextmanager
def wrap_env(override=None, prepend=None, append=None, sep=os.pathsep, strict=False):
"""This function is a context manager that temporarily modifies environment variables.
It will override or prepend/append the values of the given dictionaries to the current environment and restore the
original environment when the context is exited.

For path-like variables, a custom separator can be specified for each variable by passing a dictionary of strings
with the same keys as prepend and append.
If a key is not present in the sep dictionary, os.pathsep will be used unless strict is True, then an error
will be raised.
Crivella marked this conversation as resolved.
Show resolved Hide resolved

Args:
override: A dictionary of environment variables to override.
prepend: A dictionary of environment variables to prepend to.
append: A dictionary of environment variables to append to.
sep: A string or a dictionary of strings to use as separator for each variable.
strict: If True, raise an error if a key is not present in the sep dictionary.
"""
if prepend is None:
prepend = {}
if append is None:
append = {}
if override is None:
override = {}

path_keys = set(prepend.keys()) | set(append.keys())
over_keys = set(override.keys())

duplicates = path_keys & over_keys
if duplicates:
raise EasyBuildError(
"The keys in override must not overlap with the keys in prepend or append: '%s'",
" ".join(duplicates)
)

orig = {}
for key in over_keys:
orig[key] = os.environ.get(key)
setvar(key, override[key])

for key in path_keys:
if isinstance(sep, dict):
if key not in sep:
if strict:
raise EasyBuildError(
"sep must be a dictionary of strings with keys for all keys in prepend and append"
)
_sep = os.pathsep
else:
_sep = sep.get(key)
elif isinstance(sep, str):
_sep = sep
else:
raise EasyBuildError("sep must be a string or a dictionary of strings")
val = orig[key] = os.environ.get(key)
path = _sep.join(filter(None, [prepend.get(key, None), val, append.get(key, None)]))
setvar(key, path)

try:
yield
finally:
restore_env_vars(orig)
86 changes: 86 additions & 0 deletions test/framework/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from unittest import TextTestRunner

import easybuild.tools.environment as env
from easybuild.tools.build_log import EasyBuildError


class EnvironmentTest(EnhancedTestCase):
Expand Down Expand Up @@ -161,6 +162,91 @@ def test_sanitize_env(self):

self.assertEqual(os.getenv('LD_PRELOAD'), None)

def test_wrap_env(self):
"""Test wrap_env function."""

def reset_env():
os.environ['TEST_VAR_1'] = '/bar:/foo'
os.environ['TEST_VAR_2'] = '/bar'
os.environ['TEST_VAR_3'] = '/foo'

def check_env():
self.assertEqual(os.getenv('TEST_VAR_1'), '/bar:/foo')
self.assertEqual(os.getenv('TEST_VAR_2'), '/bar')
self.assertEqual(os.getenv('TEST_VAR_3'), '/foo')

def null_and_check(vars):
for var in vars:
os.environ[var] = ''
for var in vars:
self.assertEqual(os.getenv(var), '')

prep = {
'TEST_VAR_1': '/usr/bin:/usr/sbin',
'TEST_VAR_2': '/usr/bin',
}
appd = {
'TEST_VAR_1': '/usr/local/bin',
'TEST_VAR_3': '/usr/local/sbin',
}
over = {
'TEST_VAR_3': 'overridden',
}
seps = {
'TEST_VAR_3': ';'
}

# Test prepend and append
reset_env()
check_env()
with env.wrap_env(prepend=prep, append=appd, sep=seps):
self.assertEqual(os.getenv('TEST_VAR_1'), '/usr/bin:/usr/sbin:/bar:/foo:/usr/local/bin')
self.assertEqual(os.getenv('TEST_VAR_2'), '/usr/bin:/bar')
self.assertEqual(os.getenv('TEST_VAR_3'), '/foo;/usr/local/sbin')
# Test modifying the environment inside the context
null_and_check(['TEST_VAR_1', 'TEST_VAR_2', 'TEST_VAR_3'])
check_env()

# Test sep with strict=True
def foo():
with env.wrap_env(prepend=prep, append=appd, sep={}, strict=True):
pass
self.assertErrorRegex(EasyBuildError, "sep must be a .*", foo)

# Test invalid value for sep
def foo():
with env.wrap_env(prepend=prep, append=appd, sep=None):
pass
self.assertErrorRegex(EasyBuildError, "sep must be a .*", foo)

# Test override
check_env()
with env.wrap_env(override=prep):
self.assertEqual(os.getenv('TEST_VAR_1'), '/usr/bin:/usr/sbin')
self.assertEqual(os.getenv('TEST_VAR_2'), '/usr/bin')
self.assertEqual(os.getenv('TEST_VAR_3'), '/foo')
# Test modifying the environment inside the context
null_and_check(['TEST_VAR_1', 'TEST_VAR_2'])
check_env()

# Test override with prepend
with env.wrap_env(override=over, prepend=prep):
self.assertEqual(os.getenv('TEST_VAR_1'), '/usr/bin:/usr/sbin:/bar:/foo')
self.assertEqual(os.getenv('TEST_VAR_2'), '/usr/bin:/bar')
self.assertEqual(os.getenv('TEST_VAR_3'), 'overridden')
null_and_check(['TEST_VAR_1', 'TEST_VAR_2', 'TEST_VAR_3'])
check_env()

# Test override duplicate key with prepend
def foo():
with env.wrap_env(override=prep, prepend=prep, sep=None):
pass
self.assertErrorRegex(
EasyBuildError,
"The keys in override must not overlap with the keys in prepend or append.*",
foo
)


def suite():
""" returns all the testcases in this module """
Expand Down
Loading