Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve buffering issues in TeeStream and capture_output #3449

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
124 changes: 111 additions & 13 deletions pyomo/common/tee.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,32 @@
logger = logging.getLogger(__name__)


class _SignalFlush(object):
def __init__(self, ostream, handle):
super().__setattr__('_ostream', ostream)
super().__setattr__('_handle', handle)

def flush(self):
self._ostream.flush()
self._handle.flush = True

def __getattr__(self, attr):
return getattr(self._ostream, attr)

def __setattr__(self, attr, val):
return setattr(self._ostream, attr, val)


class _AutoFlush(_SignalFlush):
def write(self, data):
self._ostream.write(data)
self.flush()

def writelines(self, data):
self._ostream.writelines(data)
self.flush()


class redirect_fd(object):
"""Redirect a file descriptor to a new file or file descriptor.

Expand Down Expand Up @@ -152,10 +178,33 @@ def __exit__(self, t, v, traceback):


class capture_output(object):
"""
Drop-in substitute for PyUtilib's capture_output.
Takes in a StringIO, file-like object, or filename and temporarily
redirects output to a string buffer.
"""Context manager to capture output sent to sys.stdout and sys.stderr

This is a drop-in substitute for PyUtilib's capture_output to
temporarily redirect output to the provided stream or file.

Parameters
----------
output : io.TextIOBase, TeeStream, str, or None

Output stream where all captured stdout/stderr data is sent. If
a ``str`` is provided, it is used as a file name and opened
(potentially overwriting any existing file). If ``None``, a
:class:`io.StringIO` object is created and used.

capture_fd : bool

If True, we will also redirect the low-level file descriptors
associated with stdout (1) and stderr (2) to the ``output``.
This is useful for capturing output emitted directly to the
process stdout / stderr by external compiled modules.

Returns
-------
io.TextIOBase

This is the output stream object where all data is sent.

"""

def __init__(self, output=None, capture_fd=False):
Expand All @@ -169,19 +218,22 @@ def __init__(self, output=None, capture_fd=False):
self.fd_redirect = None

def __enter__(self):
self.old = (sys.stdout, sys.stderr)
if isinstance(self.output, str):
self.output_stream = open(self.output, 'w')
else:
self.output_stream = self.output
self.old = (sys.stdout, sys.stderr)
self.tee = TeeStream(self.output_stream)
if isinstance(self.output, TeeStream):
self.tee = self.output
else:
self.tee = TeeStream(self.output_stream)
self.tee.__enter__()
sys.stdout = self.tee.STDOUT
sys.stderr = self.tee.STDERR
if self.capture_fd:
self.fd_redirect = (
redirect_fd(1, sys.stdout.fileno()),
redirect_fd(2, sys.stderr.fileno()),
redirect_fd(1, self.tee.STDOUT.fileno(), synchronize=False),
redirect_fd(2, self.tee.STDERR.fileno(), synchronize=False),
)
self.fd_redirect[0].__enter__()
self.fd_redirect[1].__enter__()
Expand Down Expand Up @@ -220,6 +272,7 @@ class _StreamHandle(object):
def __init__(self, mode, buffering, encoding, newline):
self.buffering = buffering
self.newlines = newline
self.flush = False
self.read_pipe, self.write_pipe = os.pipe()
if not buffering and 'b' not in mode:
# While we support "unbuffered" behavior in text mode,
Expand All @@ -233,6 +286,13 @@ def __init__(self, mode, buffering, encoding, newline):
newline=newline,
closefd=False,
)
if not self.buffering and buffering:
# We want this stream to be unbuffered, but Python doesn't
# allow it for text strreams. Mock up an unbuffered stream
# using AutoFlush
self.write_file = _AutoFlush(self.write_file, self)
else:
self.write_file = _SignalFlush(self.write_file, self)
self.decoder_buffer = b''
try:
self.encoding = encoding or self.write_file.encoding
Expand Down Expand Up @@ -310,7 +370,7 @@ def decodeIncomingBuffer(self):
def writeOutputBuffer(self, ostreams):
if not self.encoding:
ostring, self.output_buffer = self.output_buffer, b''
elif self.buffering == 1:
elif self.buffering > 0:
EOL = self.output_buffer.rfind(self.newlines or '\n') + 1
ostring = self.output_buffer[:EOL]
self.output_buffer = self.output_buffer[EOL:]
Expand Down Expand Up @@ -340,25 +400,41 @@ def writeOutputBuffer(self, ostreams):


class TeeStream(object):
def __init__(self, *ostreams, encoding=None):
self.ostreams = ostreams
def __init__(self, *ostreams, encoding=None, buffering=-1):
self.user_ostreams = ostreams
self.ostreams = []
self.encoding = encoding
self.buffering = buffering
self._stdout = None
self._stderr = None
self._handles = []
self._active_handles = []
self._threads = []
for s in ostreams:
try:
fileno = s.fileno()
except:
self.ostreams.append(s)
continue
s = os.fdopen(os.dup(fileno), mode=getattr(s, 'mode', None), closefd=True)
self.ostreams.append(s)

@property
def STDOUT(self):
if self._stdout is None:
self._stdout = self.open(buffering=1)
b = self.buffering
if b == -1:
b = 1
self._stdout = self.open(buffering=b)
return self._stdout

@property
def STDERR(self):
if self._stderr is None:
self._stderr = self.open(buffering=0)
b = self.buffering
if b == -1:
b = 0
self._stderr = self.open(buffering=b)
return self._stderr

def open(self, mode='w', buffering=-1, encoding=None, newline=None):
Expand Down Expand Up @@ -422,6 +498,9 @@ def close(self, in_exception=False):
self._active_handles.clear()
self._stdout = None
self._stderr = None
for orig, local in zip(self.user_ostreams, self.ostreams):
if orig is not local:
local.close()

def __enter__(self):
return self
Expand Down Expand Up @@ -452,7 +531,11 @@ def _start(self, handle):
pass

def _streamReader(self, handle):
flush = False
while True:
if handle.flush:
flush = True
handle.flush = False
new_data = os.read(handle.read_pipe, io.DEFAULT_BUFFER_SIZE)
if not new_data:
break
Expand All @@ -463,6 +546,11 @@ def _streamReader(self, handle):
handle.decodeIncomingBuffer()
# Now, output whatever we have decoded to the output streams
handle.writeOutputBuffer(self.ostreams)
if flush:
flush = False
if self.buffering:
for s in self.ostreams:
s.flush()
#
# print("STREAM READER: DONE")

Expand All @@ -489,9 +577,13 @@ def _mergedReader(self):
_fast_poll_ct = _poll_rampup
else:
new_data = None
flush = False
if _mswindows:
for handle in list(handles):
try:
if handle.flush:
flush = True
handle.flush = False
pipe = get_osfhandle(handle.read_pipe)
numAvail = PeekNamedPipe(pipe, 0)[1]
if numAvail:
Expand Down Expand Up @@ -520,6 +612,9 @@ def _mergedReader(self):
continue

handle = ready_handles[0]
if handle.flush:
flush = True
handle.flush = False
new_data = os.read(handle.read_pipe, io.DEFAULT_BUFFER_SIZE)
if not new_data:
handles.remove(handle)
Expand All @@ -532,5 +627,8 @@ def _mergedReader(self):

# Now, output whatever we have decoded to the output streams
handle.writeOutputBuffer(self.ostreams)
if flush and self.buffering:
for s in self.ostreams:
s.flush()
#
# print("MERGED READER: DONE")
Loading
Loading