Merge branch 'main' into autoslots-deepcopy
blnicho authored Nov 14, 2024
2 parents 729a433 + fe3f83f commit 4bed62c
Showing 2 changed files with 36 additions and 28 deletions.
8 changes: 3 additions & 5 deletions pyomo/common/tests/test_unittest.py
@@ -13,7 +13,6 @@
 import multiprocessing
 import os
 import time
-from io import StringIO
 
 import pyomo.common.unittest as unittest
 from pyomo.common.log import LoggingIntercept
@@ -190,7 +189,7 @@ def test_timeout(self):
     @unittest.timeout(0.01)
     def test_timeout_timeout(self):
         time.sleep(1)
-        self.assertEqual(0, 1)
+        self.assertEqual(0, 0)
 
     @unittest.timeout(10)
     def test_timeout_skip(self):
@@ -218,8 +217,7 @@ def test_bound_function(self):
         if multiprocessing.get_start_method() == 'fork':
             self.bound_function()
             return
-        LOG = StringIO()
-        with LoggingIntercept(LOG):
+        with LoggingIntercept() as LOG:
             with self.assertRaises((TypeError, EOFError, AttributeError)):
                 self.bound_function()
         self.assertIn("platform that does not support 'fork'", LOG.getvalue())
@@ -234,7 +232,7 @@ def test_bound_function_require_fork(self):
             self.bound_function_require_fork()
             return
         with self.assertRaisesRegex(
-            unittest.SkipTest, "timeout requires unavailable fork interface"
+            unittest.SkipTest, r"timeout\(\) requires unavailable fork interface"
         ):
             self.bound_function_require_fork()
 
56 changes: 33 additions & 23 deletions pyomo/common/unittest.py
@@ -308,11 +308,11 @@ def _assertStructuredAlmostEqual(
     raise exception(msg)
 
 
-def _runner(q, qualname):
+def _runner(pipe, qualname):
     "Utility wrapper for running functions, used by timeout()"
     resultType = _RunnerResult.call
-    if q in _runner.data:
-        fcn, args, kwargs = _runner.data[q]
+    if pipe in _runner.data:
+        fcn, args, kwargs = _runner.data[pipe]
     elif isinstance(qualname, str):
         # Use unittest to instantiate the TestCase and run it
         resultType = _RunnerResult.unittest
@@ -328,11 +328,10 @@ def fcn():
     else:
         qualname, fcn, args, kwargs = qualname
         _runner.data[qualname] = None
-    OUT = StringIO()
     try:
-        with capture_output(OUT):
+        with capture_output() as OUT:
             result = fcn(*args, **kwargs)
-        q.put((resultType, result, OUT.getvalue()))
+        pipe.send((resultType, result, OUT.getvalue()))
     except:
         import traceback
 
@@ -341,7 +340,7 @@ def fcn():
         e = etype(
             "%s\nOriginal traceback:\n%s" % (e, ''.join(traceback.format_tb(tb)))
         )
-        q.put((_RunnerResult.exception, e, OUT.getvalue()))
+        pipe.send((_RunnerResult.exception, e, OUT.getvalue()))
     finally:
         _runner.data.pop(qualname)
 
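The core of the _runner change above is swapping a multiprocessing.Queue for the send end of a one-way multiprocessing.Pipe. A small, self-contained sketch of that hand-off (the worker function and payload are illustrative, not Pyomo code):

    import multiprocessing


    def worker(conn):
        # Child end of the pipe: send one (status, payload) tuple and exit.
        conn.send(('ok', 42))
        conn.close()


    if __name__ == '__main__':
        # Pipe(False) creates a one-way channel: the first connection can only
        # receive and the second can only send (cf. pipe_recv / pipe_send).
        recv_end, send_end = multiprocessing.Pipe(False)
        proc = multiprocessing.Process(target=worker, args=(send_end,))
        proc.start()
        status, payload = recv_end.recv()  # replaces Queue.get()
        proc.join()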
@@ -418,18 +417,24 @@ def timeout_decorator(fcn):
         @functools.wraps(fcn)
         def test_timer(*args, **kwargs):
             qualname = '%s.%s' % (fcn.__module__, fcn.__qualname__)
+            # If qualname is in the data dict, then we are in the child
+            # process and are being asked to run the wrapped function.
             if qualname in _runner.data:
                 return fcn(*args, **kwargs)
+            # Parent process: spawn a subprocess to execute the wrapped
+            # function and monitor for timeout
             if require_fork and multiprocessing.get_start_method() != 'fork':
-                raise _unittest.SkipTest("timeout requires unavailable fork interface")
+                raise _unittest.SkipTest(
+                    "timeout() requires unavailable fork interface"
+                )
 
-            q = multiprocessing.Queue()
+            pipe_recv, pipe_send = multiprocessing.Pipe(False)
             if multiprocessing.get_start_method() == 'fork':
                 # Option 1: leverage fork if possible. This minimizes
                 # the reliance on serialization and ensures that the
                 # wrapped function operates in the same environment.
-                _runner.data[q] = (fcn, args, kwargs)
-                runner_args = (q, qualname)
+                _runner.data[pipe_send] = (fcn, args, kwargs)
+                runner_arg = qualname
             elif (
                 args
                 and fcn.__name__.startswith('test')
@@ -439,36 +444,41 @@ def test_timer(*args, **kwargs):
                 # unittest in the child process with this function as
                 # the sole target. This ensures that things like setUp
                 # and tearDown are correctly called.
-                runner_args = (q, qualname)
+                runner_arg = qualname
             else:
                 # Option 3: attempt to serialize the function and all
                 # arguments and send them to the (spawned) child
                 # process. The wrapped function cannot count on any
                 # environment configuration that it does not set up
                 # itself.
-                runner_args = (q, (qualname, test_timer, args, kwargs))
-            test_proc = multiprocessing.Process(target=_runner, args=runner_args)
+                runner_arg = (qualname, test_timer, args, kwargs)
+            test_proc = multiprocessing.Process(
+                target=_runner, args=(pipe_send, runner_arg)
+            )
             # Set daemon: if the parent process is killed, the child
             # process should be killed and collected.
             test_proc.daemon = True
             try:
                 test_proc.start()
             except:
-                if type(runner_args[1]) is tuple:
+                if type(runner_arg) is tuple:
                     logging.getLogger(__name__).error(
-                        "Exception raised spawning timeout subprocess "
+                        "Exception raised spawning timeout() subprocess "
                         "on a platform that does not support 'fork'. "
                         "It is likely that either the wrapped function or "
                         "one of its arguments is not serializable"
                     )
                 raise
             try:
-                resultType, result, stdout = q.get(True, seconds)
-            except queue.Empty:
-                test_proc.terminate()
-                raise timeout_raises(
-                    "test timed out after %s seconds" % (seconds,)
-                ) from None
+                if pipe_recv.poll(seconds):
+                    resultType, result, stdout = pipe_recv.recv()
+                else:
+                    test_proc.terminate()
+                    raise timeout_raises(
+                        "test timed out after %s seconds" % (seconds,)
+                    ) from None
             finally:
-                _runner.data.pop(q, None)
+                _runner.data.pop(pipe_send, None)
             sys.stdout.write(stdout)
             test_proc.join()
             if resultType == _RunnerResult.call:
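Taken together, the revised test_timer hands the send end of the pipe to a daemonic child process, waits on the receive end with poll(seconds), and terminates the child if nothing arrives in time. A condensed sketch of that control flow under the fork start method only (run_with_timeout and the closure-based child target are illustrative, not the Pyomo API):

    import functools
    import multiprocessing


    def run_with_timeout(seconds):
        # Simplified stand-in for timeout(): the fork start method is assumed,
        # so the closure below can be passed directly as the process target.
        def decorator(fcn):
            @functools.wraps(fcn)
            def wrapper(*args, **kwargs):
                recv_end, send_end = multiprocessing.Pipe(False)

                def child():
                    # Child: run the wrapped function and report the result.
                    send_end.send(fcn(*args, **kwargs))

                proc = multiprocessing.Process(target=child)
                proc.daemon = True  # reap the child if the parent dies
                proc.start()
                # poll(seconds) replaces Queue.get(True, seconds): it returns
                # False on timeout instead of raising queue.Empty.
                if recv_end.poll(seconds):
                    result = recv_end.recv()
                else:
                    proc.terminate()
                    raise TimeoutError('timed out after %s seconds' % seconds)
                proc.join()
                return result

            return wrapper

        return decorator

The real decorator additionally dispatches between the three options above and forwards captured stdout and child-side exceptions back to the parent, but the timeout mechanics are the same.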
