From 83e35021fda1693ecb7b4e0fb747a70c1c22263f Mon Sep 17 00:00:00 2001 From: David Brochart Date: Wed, 21 Dec 2022 09:06:08 +0100 Subject: [PATCH] Implement sub-shells --- ipykernel/athread.py | 54 +++ ipykernel/control.py | 27 +- ipykernel/debugger.py | 48 +-- ipykernel/inprocess/ipkernel.py | 2 +- ipykernel/iostream.py | 100 +++-- ipykernel/ipkernel.py | 182 ++++----- ipykernel/kernelapp.py | 84 +++- ipykernel/kernelbase.py | 562 ++++++++++++--------------- ipykernel/shell.py | 6 + ipykernel/tests/conftest.py | 8 +- ipykernel/tests/test_async.py | 2 +- ipykernel/tests/test_debugger.py | 1 + ipykernel/tests/test_embed_kernel.py | 3 +- ipykernel/tests/test_io.py | 31 +- ipykernel/tests/test_subshell.py | 66 ++++ pyproject.toml | 8 +- 16 files changed, 637 insertions(+), 547 deletions(-) create mode 100644 ipykernel/athread.py create mode 100644 ipykernel/shell.py create mode 100644 ipykernel/tests/test_subshell.py diff --git a/ipykernel/athread.py b/ipykernel/athread.py new file mode 100644 index 000000000..9aa5a86a2 --- /dev/null +++ b/ipykernel/athread.py @@ -0,0 +1,54 @@ +import asyncio +import threading + +import janus + + +class AThread(threading.Thread): + """A thread that can run async tasks.""" + + def __init__(self, name, awaitables=None): + super().__init__(name=name, daemon=True) + self._aws = list(awaitables) if awaitables is not None else [] + self._lock = threading.Lock() + self.__initialized = False + self._stopped = False + + def run(self): + asyncio.run(self._main()) + + async def _main(self): + with self._lock: + if self._stopped: + return + self._queue = janus.Queue() + self.__initialized = True + self._tasks = [asyncio.create_task(aw) for aw in self._aws] + + while True: + try: + aw = await self._queue.async_q.get() + except BaseException: + break + if aw is None: + break + self._tasks.append(asyncio.create_task(aw)) + + for task in self._tasks: + task.cancel() + + def create_task(self, awaitable): + """Create a task in the thread (thread-safe).""" + with self._lock: + if self.__initialized: + self._queue.sync_q.put(awaitable) + else: + self._aws.append(awaitable) + + def stop(self): + """Stop the thread (thread-safe).""" + with self._lock: + if self.__initialized: + self._queue.sync_q.put(None) + else: + self._stopped = True diff --git a/ipykernel/control.py b/ipykernel/control.py index d78a4ebe1..c6b5891ef 100644 --- a/ipykernel/control.py +++ b/ipykernel/control.py @@ -1,30 +1,11 @@ -"""A thread for a control channel.""" -from threading import Thread +from .athread import AThread -from tornado.ioloop import IOLoop - -class ControlThread(Thread): +class ControlThread(AThread): """A thread for a control channel.""" - def __init__(self, **kwargs): + def __init__(self): """Initialize the thread.""" - Thread.__init__(self, name="Control", **kwargs) - self.io_loop = IOLoop(make_current=False) + super().__init__(name="Control") self.pydev_do_not_trace = True self.is_pydev_daemon_thread = True - - def run(self): - """Run the thread.""" - self.name = "Control" - try: - self.io_loop.start() - finally: - self.io_loop.close() - - def stop(self): - """Stop the thread. - - This method is threadsafe. 
- """ - self.io_loop.add_callback(self.io_loop.stop) diff --git a/ipykernel/debugger.py b/ipykernel/debugger.py index 43ae68300..a29b67ab5 100644 --- a/ipykernel/debugger.py +++ b/ipykernel/debugger.py @@ -1,4 +1,5 @@ """Debugger implementation for the IPython kernel.""" +import asyncio import os import re import sys @@ -7,8 +8,6 @@ import zmq from IPython.core.getipython import get_ipython from IPython.core.inputtransformer2 import leading_empty_lines -from tornado.locks import Event -from tornado.queues import Queue from zmq.utils import jsonapi try: @@ -116,7 +115,7 @@ def __init__(self, event_callback, log): self.tcp_buffer = "" self._reset_tcp_pos() self.event_callback = event_callback - self.message_queue: Queue[t.Any] = Queue() + self.message_queue: asyncio.Queue[t.Any] = asyncio.Queue() self.log = log def _reset_tcp_pos(self): @@ -192,17 +191,17 @@ async def get_message(self): class DebugpyClient: """A client for debugpy.""" - def __init__(self, log, debugpy_stream, event_callback): + def __init__(self, log, debugpy_socket, event_callback): """Initialize the client.""" self.log = log - self.debugpy_stream = debugpy_stream + self.debugpy_socket = debugpy_socket self.event_callback = event_callback self.message_queue = DebugpyMessageQueue(self._forward_event, self.log) self.debugpy_host = "127.0.0.1" self.debugpy_port = -1 self.routing_id = None self.wait_for_attach = True - self.init_event = Event() + self.init_event = asyncio.Event() self.init_event_seq = -1 def _get_endpoint(self): @@ -215,9 +214,9 @@ def _forward_event(self, msg): self.init_event_seq = msg["seq"] self.event_callback(msg) - def _send_request(self, msg): + async def _send_request(self, msg): if self.routing_id is None: - self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID) + self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID) content = jsonapi.dumps( msg, default=json_default, @@ -232,7 +231,7 @@ def _send_request(self, msg): self.log.debug("DEBUGPYCLIENT:") self.log.debug(self.routing_id) self.log.debug(buf) - self.debugpy_stream.send_multipart((self.routing_id, buf)) + await self.debugpy_socket.send_multipart((self.routing_id, buf)) async def _wait_for_response(self): # Since events are never pushed to the message_queue @@ -250,7 +249,7 @@ async def _handle_init_sequence(self): "seq": int(self.init_event_seq) + 1, "command": "configurationDone", } - self._send_request(configurationDone) + await self._send_request(configurationDone) # 3] Waits for configurationDone response await self._wait_for_response() @@ -262,7 +261,7 @@ async def _handle_init_sequence(self): def get_host_port(self): """Get the host debugpy port.""" if self.debugpy_port == -1: - socket = self.debugpy_stream.socket + socket = self.debugpy_socket socket.bind_to_random_port("tcp://" + self.debugpy_host) self.endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode("utf-8") socket.unbind(self.endpoint) @@ -272,14 +271,14 @@ def get_host_port(self): def connect_tcp_socket(self): """Connect to the tcp socket.""" - self.debugpy_stream.socket.connect(self._get_endpoint()) - self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID) + self.debugpy_socket.connect(self._get_endpoint()) + self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID) def disconnect_tcp_socket(self): """Disconnect from the tcp socket.""" - self.debugpy_stream.socket.disconnect(self._get_endpoint()) + self.debugpy_socket.disconnect(self._get_endpoint()) self.routing_id = None - self.init_event = Event() + self.init_event = asyncio.Event() 
self.init_event_seq = -1 self.wait_for_attach = True @@ -289,7 +288,7 @@ def receive_dap_frame(self, frame): async def send_dap_request(self, msg): """Send a dap request.""" - self._send_request(msg) + await self._send_request(msg) if self.wait_for_attach and msg["command"] == "attach": rep = await self._handle_init_sequence() self.wait_for_attach = False @@ -319,17 +318,17 @@ class Debugger: static_debug_msg_types = ["debugInfo", "inspectVariables", "richInspectVariables", "modules"] def __init__( - self, log, debugpy_stream, event_callback, shell_socket, session, just_my_code=True + self, log, debugpy_socket, event_callback, shell_socket, session, just_my_code=True ): """Initialize the debugger.""" self.log = log - self.debugpy_client = DebugpyClient(log, debugpy_stream, self._handle_event) + self.debugpy_client = DebugpyClient(log, debugpy_socket, self._handle_event) self.shell_socket = shell_socket self.session = session self.is_started = False self.event_callback = event_callback self.just_my_code = just_my_code - self.stopped_queue: Queue[t.Any] = Queue() + self.stopped_queue: asyncio.Queue[t.Any] = asyncio.Queue() self.started_debug_handlers = {} for msg_type in Debugger.started_debug_msg_types: @@ -406,7 +405,7 @@ async def handle_stopped_event(self): def tcp_client(self): return self.debugpy_client - def start(self): + async def start(self): """Start the debugger.""" if not self.debugpy_initialized: tmp_dir = get_tmp_directory() @@ -424,7 +423,12 @@ def start(self): (self.shell_socket.getsockopt(ROUTING_ID)), ) - ident, msg = self.session.recv(self.shell_socket, mode=0) + msg = await self.shell_socket.recv_multipart() + idents, msg = self.session.feed_identities(msg, copy=True) + try: + msg = self.session.deserialize(msg, content=True, copy=True) + except BaseException: + self.log.error("Invalid Message", exc_info=True) self.debugpy_initialized = msg["content"]["status"] == "ok" # Don't remove leading empty lines when debugging so the breakpoints are correctly positioned @@ -685,7 +689,7 @@ async def process_request(self, message): if self.is_started: self.log.info("The debugger has already started") else: - self.is_started = self.start() + self.is_started = await self.start() if self.is_started: self.log.info("The debugger has started") else: diff --git a/ipykernel/inprocess/ipkernel.py b/ipykernel/inprocess/ipkernel.py index df34303b4..3707867e6 100644 --- a/ipykernel/inprocess/ipkernel.py +++ b/ipykernel/inprocess/ipkernel.py @@ -51,7 +51,7 @@ class InProcessKernel(IPythonKernel): _underlying_iopub_socket = Instance(DummySocket, ()) iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment] - shell_stream = Instance(DummySocket, ()) + # shell_stream = Instance(DummySocket, ()) @default("iopub_thread") def _default_iopub_thread(self): diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index 22b50a2a4..bfce08fc3 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -3,6 +3,7 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. 
+import asyncio import atexit import io import os @@ -18,8 +19,8 @@ import zmq from jupyter_client.session import extract_header -from tornado.ioloop import IOLoop -from zmq.eventloop.zmqstream import ZMQStream + +from .athread import AThread # ----------------------------------------------------------------------------- # Globals @@ -57,35 +58,30 @@ def __init__(self, socket, pipe=False): self.background_socket = BackgroundSocket(self) self._master_pid = os.getpid() self._pipe_flag = pipe - self.io_loop = IOLoop(make_current=False) + self.io_loop = asyncio.new_event_loop() if pipe: self._setup_pipe_in() self._local = threading.local() self._events: Deque[Callable[..., Any]] = deque() self._event_pipes: WeakSet[Any] = WeakSet() self._setup_event_pipe() - self.thread = threading.Thread(target=self._thread_main, name="IOPub") - self.thread.daemon = True + # self.thread = threading.Thread(target=self._thread_main, name="IOPub") + aws = [self._handle_event()] + if self._pipe_flag: + aws.append(self._handle_pipe_msg()) + self.thread = AThread("IOPub", awaitables=aws) self.thread.pydev_do_not_trace = True # type:ignore[attr-defined] self.thread.is_pydev_daemon_thread = True # type:ignore[attr-defined] - self.thread.name = "IOPub" - - def _thread_main(self): - """The inner loop that's actually run in a thread""" - self.io_loop.start() - self.io_loop.close(all_fds=True) def _setup_event_pipe(self): """Create the PULL socket listening for events that should fire in this thread.""" ctx = self.socket.context - pipe_in = ctx.socket(zmq.PULL) - pipe_in.linger = 0 + self._pipe_in0 = ctx.socket(zmq.PULL) + self._pipe_in0.linger = 0 _uuid = b2a_hex(os.urandom(16)).decode("ascii") iface = self._event_interface = "inproc://%s" % _uuid - pipe_in.bind(iface) - self._event_puller = ZMQStream(pipe_in, self.io_loop) - self._event_puller.on_recv(self._handle_event) + self._pipe_in0.bind(iface) @property def _event_pipe(self): @@ -94,7 +90,7 @@ def _event_pipe(self): event_pipe = self._local.event_pipe except AttributeError: # new thread, new event pipe - ctx = self.socket.context + ctx = zmq.Context(self.socket.context) event_pipe = ctx.socket(zmq.PUSH) event_pipe.linger = 0 event_pipe.connect(self._event_interface) @@ -104,7 +100,7 @@ def _event_pipe(self): self._event_pipes.add(event_pipe) return event_pipe - def _handle_event(self, msg): + async def _handle_event(self): """Handle an event on the event pipe Content of the message is ignored. @@ -112,12 +108,14 @@ def _handle_event(self, msg): Whenever *an* event arrives on the event stream, *all* waiting events are processed in order. 
""" - # freeze event count so new writes don't extend the queue - # while we are processing - n_events = len(self._events) - for _ in range(n_events): - event_f = self._events.popleft() - event_f() + while True: + await self._pipe_in0.recv() + # freeze event count so new writes don't extend the queue + # while we are processing + n_events = len(self._events) + for _ in range(n_events): + event_f = self._events.popleft() + event_f() def _setup_pipe_in(self): """setup listening pipe for IOPub from forked subprocesses""" @@ -126,30 +124,30 @@ def _setup_pipe_in(self): # use UUID to authenticate pipe messages self._pipe_uuid = os.urandom(16) - pipe_in = ctx.socket(zmq.PULL) - pipe_in.linger = 0 + self._pipe_in1 = ctx.socket(zmq.PULL) + self._pipe_in1.linger = 0 try: - self._pipe_port = pipe_in.bind_to_random_port("tcp://127.0.0.1") + self._pipe_port = self._pipe_in1.bind_to_random_port("tcp://127.0.0.1") except zmq.ZMQError as e: warnings.warn( "Couldn't bind IOPub Pipe to 127.0.0.1: %s" % e + "\nsubprocess output will be unavailable." ) self._pipe_flag = False - pipe_in.close() + self._pipe_in1.close() return - self._pipe_in = ZMQStream(pipe_in, self.io_loop) - self._pipe_in.on_recv(self._handle_pipe_msg) - def _handle_pipe_msg(self, msg): + async def _handle_pipe_msg(self): """handle a pipe message from a subprocess""" - if not self._pipe_flag or not self._is_master_process(): - return - if msg[0] != self._pipe_uuid: - print("Bad pipe message: %s", msg, file=sys.__stderr__) - return - self.send_multipart(msg[1:]) + while True: + msg = await self._pipe_in1.recv_multipart() + if not self._pipe_flag or not self._is_master_process(): + return + if msg[0] != self._pipe_uuid: + print("Bad pipe message: %s", msg, file=sys.__stderr__) + return + self.send_multipart(msg[1:]) def _setup_pipe_out(self): # must be new context after fork @@ -171,7 +169,7 @@ def _check_mp_mode(self): def start(self): """Start the IOPub thread""" - self.thread.name = "IOPub" + # self.thread.name = "IOPub" self.thread.start() # make sure we don't prevent process exit # I'm not sure why setting daemon=True above isn't enough, but it doesn't appear to be. @@ -181,7 +179,8 @@ def stop(self): """Stop the IOPub thread""" if not self.thread.is_alive(): return - self.io_loop.add_callback(self.io_loop.stop) + # self.io_loop.call_soon_threadsafe(self.io_loop.stop) + self.thread.stop() self.thread.join() # close *all* event pipes, created in any thread # event pipes can only be used from other threads while self.thread.is_alive() @@ -193,6 +192,9 @@ def close(self): """Close the IOPub thread.""" if self.closed: return + self._pipe_in0.close() + if self._pipe_flag: + self._pipe_in1.close() self.socket.close() self.socket = None @@ -208,6 +210,10 @@ def schedule(self, f): if self.thread.is_alive(): self._events.append(f) # wake event thread (message content is ignored) + # try: + # asyncio.get_running_loop() + # except BaseException: + # asyncio.set_event_loop(asyncio.new_event_loop()) self._event_pipe.send(b"") else: f() @@ -378,6 +384,8 @@ def __init__( ) # This is necessary for compatibility with Python built-in streams self.session = session + self._has_thread = False + self.watch_fd_thread = None if not isinstance(pub_thread, IOPubThread): # Backward-compat: given socket, not thread. Wrap in a thread. warnings.warn( @@ -388,6 +396,7 @@ def __init__( ) pub_thread = IOPubThread(pub_thread) pub_thread.start() + self._has_thread = True self.pub_thread = pub_thread self.name = name self.topic = b"stream." 
+ name.encode() @@ -449,10 +458,14 @@ def close(self): """Close the stream.""" if sys.platform.startswith("linux") or sys.platform.startswith("darwin"): self._should_watch = False - self.watch_fd_thread.join() + if self.watch_fd_thread is not None: + self.watch_fd_thread.join() if self._exc: etype, value, tb = self._exc traceback.print_exception(etype, value, tb) + if self._has_thread: + self.pub_thread.stop() + self.pub_thread.close() self.pub_thread = None @property @@ -469,10 +482,11 @@ def _schedule_flush(self): self._flush_pending = True # add_timeout has to be handed to the io thread via event pipe - def _schedule_in_thread(): - self._io_loop.call_later(self.flush_interval, self._flush) + # def _schedule_in_thread(): + # self._io_loop.call_later(self.flush_interval, self._flush) - self.pub_thread.schedule(_schedule_in_thread) + # self.pub_thread.schedule(_schedule_in_thread) + self.pub_thread.schedule(self._flush) def flush(self): """trigger actual zmq send diff --git a/ipykernel/ipkernel.py b/ipykernel/ipkernel.py index a4b975a4b..fca4e30f4 100644 --- a/ipykernel/ipkernel.py +++ b/ipykernel/ipkernel.py @@ -7,14 +7,12 @@ import sys import threading import typing as t -from contextlib import contextmanager -from functools import partial import comm +import zmq.asyncio from IPython.core import release from IPython.utils.tokenutil import line_at_cursor, token_at_cursor from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat -from zmq.eventloop.zmqstream import ZMQStream from .comm.comm import BaseComm from .comm.manager import CommManager @@ -22,7 +20,7 @@ from .debugger import Debugger, _is_debugpy_available from .eventloops import _use_appnope from .kernelbase import Kernel as KernelBase -from .kernelbase import _accepts_cell_id +from .kernelbase import _accepts_arg from .zmqshell import ZMQInteractiveShell try: @@ -39,6 +37,11 @@ _use_experimental_60_completion = False +def DEBUG(msg): + with open("debug.log", "a") as f: + f.write(f"{msg}\n") + + _EXPERIMENTAL_KEY_NAME = "_jupyter_types_experimental" @@ -52,6 +55,10 @@ def _create_comm(*args, **kwargs): _comm_manager: t.Optional[CommManager] = None +def _sigint_handler(*args): + raise KeyboardInterrupt + + def _get_comm_manager(*args, **kwargs): """Create a new CommManager.""" global _comm_manager @@ -77,7 +84,9 @@ class IPythonKernel(KernelBase): help="Set this flag to False to deactivate the use of experimental IPython completion APIs.", ).tag(config=True) - debugpy_stream = Instance(ZMQStream, allow_none=True) if _is_debugpy_available else None + debugpy_socket = ( + Instance(zmq.asyncio.Socket, allow_none=True) if _is_debugpy_available else None + ) user_module = Any() @@ -109,7 +118,7 @@ def __init__(self, **kwargs): if _is_debugpy_available: self.debugger = Debugger( self.log, - self.debugpy_stream, + self.debugpy_socket, self._publish_debug_event, self.debug_shell_socket, self.session, @@ -191,12 +200,18 @@ def __init__(self, **kwargs): "file_extension": ".py", } - def dispatch_debugpy(self, msg): + async def process_debugpy(self): + asyncio.create_task(self.dispatch_debugpy()) + asyncio.create_task(self.poll_stopped_queue()) + + async def dispatch_debugpy(self): if _is_debugpy_available: - # The first frame is the socket id, we can drop it - frame = msg[1].bytes.decode("utf-8") - self.log.debug("Debugpy received: %s", frame) - self.debugger.tcp_client.receive_dap_frame(frame) + while True: + msg = await self.debugpy_socket.recv_multipart() + # The first frame is the socket id, we can drop it 
+ frame = msg[1].decode("utf-8") + self.log.debug("Debugpy received: %s", frame) + self.debugger.tcp_client.receive_dap_frame(frame) @property def banner(self): @@ -210,15 +225,12 @@ async def poll_stopped_queue(self): def start(self): """Start the kernel.""" self.shell.exit_now = False - if self.debugpy_stream is None: - self.log.warning("debugpy_stream undefined, debugging will not be enabled") - else: - self.debugpy_stream.on_recv(self.dispatch_debugpy, copy=False) + if self.debugpy_socket is None: + self.log.warning("debugpy_socket undefined, debugging will not be enabled") super().start() - if self.debugpy_stream: - asyncio.run_coroutine_threadsafe( - self.poll_stopped_queue(), self.control_thread.io_loop.asyncio_loop - ) + if self.debugpy_socket: + # asyncio.run_coroutine_threadsafe(self.process_debugpy(), self.control_thread.loop) + self.control_thread.create_task(self.process_debugpy()) def set_parent(self, ident, parent, channel="shell"): """Overridden from parent to tell the display hook and output streams @@ -286,50 +298,6 @@ def execution_count(self, value): # execution counter. pass - @contextmanager - def _cancel_on_sigint(self, future): - """ContextManager for capturing SIGINT and cancelling a future - - SIGINT raises in the event loop when running async code, - but we want it to halt a coroutine. - - Ideally, it would raise KeyboardInterrupt, - but this turns it into a CancelledError. - At least it gets a decent traceback to the user. - """ - sigint_future: asyncio.Future[int] = asyncio.Future() - - # whichever future finishes first, - # cancel the other one - def cancel_unless_done(f, _ignored): - if f.cancelled() or f.done(): - return - f.cancel() - - # when sigint finishes, - # abort the coroutine with CancelledError - sigint_future.add_done_callback(partial(cancel_unless_done, future)) - # when the main future finishes, - # stop watching for SIGINT events - future.add_done_callback(partial(cancel_unless_done, sigint_future)) - - def handle_sigint(*args): - def set_sigint_result(): - if sigint_future.cancelled() or sigint_future.done(): - return - sigint_future.set_result(1) - - # use add_callback for thread safety - self.io_loop.add_callback(set_sigint_result) - - # set the custom sigint hander during this context - save_sigint = signal.signal(signal.SIGINT, handle_sigint) - try: - yield - finally: - # restore the previous sigint handler - signal.signal(signal.SIGINT, save_sigint) - async def do_execute( self, code, @@ -339,6 +307,7 @@ async def do_execute( allow_stdin=False, *, cell_id=None, + shell_id=None, ): """Handle code execution.""" shell = self.shell # we'll need this a lot here @@ -349,7 +318,7 @@ async def do_execute( if hasattr(shell, "run_cell_async") and hasattr(shell, "should_run_async"): run_cell = shell.run_cell_async should_run_async = shell.should_run_async - with_cell_id = _accepts_cell_id(run_cell) + with_cell_id = _accepts_arg(run_cell, "cell_id") else: should_run_async = lambda cell: False # noqa # older IPython, @@ -358,7 +327,8 @@ async def do_execute( async def run_cell(*args, **kwargs): return shell.run_cell(*args, **kwargs) - with_cell_id = _accepts_cell_id(shell.run_cell) + with_cell_id = _accepts_arg(shell.run_cell, "cell_id") + try: # default case: runner is asyncio and asyncio is already running @@ -371,6 +341,13 @@ async def run_cell(*args, **kwargs): transformed_cell = code preprocessing_exc_tuple = sys.exc_info() + kwargs = dict( + store_history=store_history, + silent=silent, + ) + if with_cell_id: + kwargs.update(cell_id=cell_id) + 
if ( _asyncio_runner and shell.loop_runner is _asyncio_runner @@ -381,47 +358,54 @@ async def run_cell(*args, **kwargs): preprocessing_exc_tuple=preprocessing_exc_tuple, ) ): - if with_cell_id: - coro = run_cell( - code, - store_history=store_history, - silent=silent, - transformed_cell=transformed_cell, - preprocessing_exc_tuple=preprocessing_exc_tuple, - cell_id=cell_id, - ) - else: - coro = run_cell( - code, - store_history=store_history, - silent=silent, - transformed_cell=transformed_cell, - preprocessing_exc_tuple=preprocessing_exc_tuple, - ) + kwargs.update( + transformed_cell=transformed_cell, + preprocessing_exc_tuple=preprocessing_exc_tuple, + ) + coro = run_cell(code, **kwargs) - coro_future = asyncio.ensure_future(coro) + async def run(): + res = await coro + if self.shell_msg_thread and shell_id is not None: + self.shells[shell_id]["interrupt"].sync_q.put(False) + else: + self.shells[shell_id]["interrupt"].put_nowait(False) + return res + + task = asyncio.create_task(run()) + + sigint_prev_handler = signal.signal(signal.SIGINT, _sigint_handler) + + if self.shell_msg_thread and shell_id is not None: + interrupt = await self.shells[shell_id]["interrupt"].async_q.get() + else: + interrupt = await self.shells[shell_id]["interrupt"].get() - with self._cancel_on_sigint(coro_future): + if interrupt: + task.cancel() res = None - try: - res = await coro_future - finally: - shell.events.trigger("post_execute") - if not silent: - shell.events.trigger("post_run_cell", res) + else: + res = await task + + signal.signal(signal.SIGINT, sigint_prev_handler) + + shell.events.trigger("post_execute") + if not silent: + shell.events.trigger("post_run_cell", res) + else: # runner isn't already running, # make synchronous call, # letting shell dispatch to loop runners - if with_cell_id: - res = shell.run_cell( - code, - store_history=store_history, - silent=silent, - cell_id=cell_id, - ) - else: - res = shell.run_cell(code, store_history=store_history, silent=silent) + + if shell_id is None: + sigint_prev_handler = signal.signal(signal.SIGINT, _sigint_handler) + try: + res = shell.run_cell(code, **kwargs) + finally: + if shell_id is None: + signal.signal(signal.SIGINT, sigint_prev_handler) + finally: self._restore_input() diff --git a/ipykernel/kernelapp.py b/ipykernel/kernelapp.py index a98439cfc..488169c2a 100644 --- a/ipykernel/kernelapp.py +++ b/ipykernel/kernelapp.py @@ -3,6 +3,7 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. 
+import asyncio
 import atexit
 import errno
 import logging
@@ -14,7 +15,9 @@
 from io import FileIO, TextIOWrapper
 from logging import StreamHandler

+import janus
 import zmq
+import zmq.asyncio
 from IPython.core.application import (  # type:ignore[attr-defined]
     BaseIPythonApplication,
     base_aliases,
@@ -27,7 +30,6 @@
 from jupyter_client.connect import ConnectionFileMixin
 from jupyter_client.session import Session, session_aliases, session_flags
 from jupyter_core.paths import jupyter_runtime_dir
-from tornado import ioloop
 from traitlets.traitlets import (
     Any,
     Bool,
@@ -41,7 +43,6 @@
 )
 from traitlets.utils import filefind
 from traitlets.utils.importstring import import_item
-from zmq.eventloop.zmqstream import ZMQStream

 from .control import ControlThread
 from .heartbeat import Heartbeat
@@ -50,8 +51,15 @@
 from .iostream import IOPubThread
 from .ipkernel import IPythonKernel
 from .parentpoller import ParentPollerUnix, ParentPollerWindows
+from .shell import ShellThread
 from .zmqshell import ZMQInteractiveShell

 # -----------------------------------------------------------------------------
 # Flags and Aliases
 # -----------------------------------------------------------------------------
@@ -132,6 +140,7 @@ class IPKernelApp(BaseIPythonApplication, InteractiveShellApp, ConnectionFileMix
     heartbeat = Instance(Heartbeat, allow_none=True)

     context = Any()
+    acontext = Any()
     shell_socket = Any()
     control_socket = Any()
     debugpy_socket = Any()
@@ -139,6 +148,7 @@ class IPKernelApp(BaseIPythonApplication, InteractiveShellApp, ConnectionFileMix
     stdin_socket = Any()
     iopub_socket = Any()
     iopub_thread = Any()
+    shell_msg_thread = Any()
     control_thread = Any()

     _ports = Dict()
@@ -308,10 +318,12 @@ def init_sockets(self):
         """Create a context, a session, and the kernel sockets."""
         self.log.info("Starting the kernel at pid: %i", os.getpid())
         assert self.context is None, "init_sockets cannot be called twice!"
+        assert self.acontext is None, "init_sockets cannot be called twice!"
self.context = context = zmq.Context() + self.acontext = acontext = zmq.asyncio.Context() atexit.register(self.close) - self.shell_socket = context.socket(zmq.ROUTER) + self.shell_socket = acontext.socket(zmq.ROUTER) self.shell_socket.linger = 1000 self.shell_port = self._bind_socket(self.shell_socket, self.shell_port) self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port) @@ -327,8 +339,8 @@ def init_sockets(self): # see ipython/ipykernel#270 and zeromq/libzmq#2892 self.shell_socket.router_handover = self.stdin_socket.router_handover = 1 - self.init_control(context) - self.init_iopub(context) + self.init_control(acontext) + self.init_iopub(acontext) def init_control(self, context): """Initialize the control channel.""" @@ -351,11 +363,12 @@ def init_control(self, context): # see ipython/ipykernel#270 and zeromq/libzmq#2892 self.control_socket.router_handover = 1 - self.control_thread = ControlThread(daemon=True) + self.shell_msg_thread = ShellThread("messages") + self.control_thread = ControlThread() - def init_iopub(self, context): + def init_iopub(self, acontext): """Initialize the iopub channel.""" - self.iopub_socket = context.socket(zmq.PUB) + self.iopub_socket = acontext.socket(zmq.PUB) self.iopub_socket.linger = 1000 self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port) self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port) @@ -387,6 +400,10 @@ def close(self): self.log.debug("Closing iopub channel") self.iopub_thread.stop() self.iopub_thread.close() + if self.shell_msg_thread and self.shell_msg_thread.is_alive(): + self.log.debug("Closing shell message thread") + self.shell_msg_thread.stop() + self.shell_msg_thread.join() if self.control_thread and self.control_thread.is_alive(): self.log.debug("Closing control thread") self.control_thread.stop() @@ -404,6 +421,7 @@ def close(self): socket.close() self.log.debug("Terminating zmq context") self.context.term() + self.acontext.term() self.log.debug("Terminated zmq context") def log_connection_info(self): @@ -531,19 +549,18 @@ def init_signal(self): def init_kernel(self): """Create the Kernel object itself""" - shell_stream = ZMQStream(self.shell_socket) - control_stream = ZMQStream(self.control_socket, self.control_thread.io_loop) - debugpy_stream = ZMQStream(self.debugpy_socket, self.control_thread.io_loop) + self.shell_msg_thread.start() self.control_thread.start() kernel_factory = self.kernel_class.instance kernel = kernel_factory( parent=self, session=self.session, - control_stream=control_stream, - debugpy_stream=debugpy_stream, + control_socket=self.control_socket, + debugpy_socket=self.debugpy_socket, debug_shell_socket=self.debug_shell_socket, - shell_stream=shell_stream, + shell_socket=self.shell_socket, + shell_msg_thread=self.shell_msg_thread, control_thread=self.control_thread, iopub_thread=self.iopub_thread, iopub_socket=self.iopub_socket, @@ -665,6 +682,7 @@ def init_pdb(self): @catch_config_error def initialize(self, argv=None): """Initialize the application.""" + self._stopped = False self._init_asyncio_patch() super().initialize(argv) if self.subapp is not None: @@ -704,26 +722,50 @@ def initialize(self, argv=None): def start(self): """Start the application.""" + self.io_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.io_loop) + self.io_loop.run_until_complete(self._async_start()) + while True: + try: + self.io_loop.run_until_complete(self._wait_for_stop()) + except KeyboardInterrupt: + # kernel interrupt, notify whatever is being executed + for shell_id, v in 
self.kernel.shells.items():
+                    if self.kernel.shell_msg_thread and shell_id is not None:
+                        v["interrupt"].sync_q.put(True)
+                    else:
+                        v["interrupt"].put_nowait(True)
+            else:
+                # kernel shutdown
+                return
+
+    def stop(self):
+        """Stop the kernel, thread-safe."""
+        try:
+            self._stop_queue.sync_q.put(None)
+        except AttributeError:
+            self._stopped = True
+
+    async def _async_start(self):
+        self._stop_queue = janus.Queue()
+        self.kernel._stop_queue = self._stop_queue
         if self.subapp is not None:
             return self.subapp.start()
         if self.poller is not None:
             self.poller.start()
         self.kernel.start()
-        self.io_loop = ioloop.IOLoop.current()
         if self.trio_loop:
             from ipykernel.trio_runner import TrioRunner

             tr = TrioRunner()
-            tr.initialize(self.kernel, self.io_loop)
+            tr.initialize(self.kernel, self.io_loop)  # FIXME
             try:
                 tr.run()
             except KeyboardInterrupt:
                 pass
-        else:
-            try:
-                self.io_loop.start()
-            except KeyboardInterrupt:
-                pass
+
+    async def _wait_for_stop(self):
+        await self._stop_queue.async_q.get()


 launch_new_instance = IPKernelApp.launch_instance
diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py
index 768056cfd..8b6c73bdf 100644
--- a/ipykernel/kernelbase.py
+++ b/ipykernel/kernelbase.py
@@ -4,20 +4,18 @@
 # Distributed under the terms of the Modified BSD License.

 import asyncio
-import concurrent.futures
 import inspect
 import itertools
 import logging
 import os
-import socket
 import sys
+import threading
 import time
 import typing as t
 import uuid
 import warnings
 from datetime import datetime
-from functools import partial
-from signal import SIGINT, SIGTERM, Signals, default_int_handler, signal
+from signal import SIGINT, SIGTERM, Signals

 if sys.platform != "win32":
     from signal import SIGKILL
@@ -32,12 +30,11 @@
     # jupyter_client < 5, use local now()
     now = datetime.now

+import janus
 import psutil
 import zmq
 from IPython.core.error import StdinNotImplementedError
 from jupyter_client.session import Session
-from tornado import ioloop
-from tornado.queues import Queue, QueueEmpty
 from traitlets.config.configurable import SingletonConfigurable
 from traitlets.traitlets import (
     Any,
@@ -52,16 +49,21 @@
     default,
     observe,
 )
-from zmq.eventloop.zmqstream import ZMQStream

 from ipykernel.jsonutil import json_clean

 from ._version import kernel_protocol_version
+from .shell import ShellThread


-def _accepts_cell_id(meth):
+def _accepts_arg(meth, arg: str):
     parameters = inspect.signature(meth).parameters
-    cid_param = parameters.get("cell_id")
+    cid_param = parameters.get(arg)
     return (cid_param and cid_param.kind == cid_param.KEYWORD_ONLY) or any(
         p.kind == p.VAR_KEYWORD for p in parameters.values()
     )
- """ - ) + shell_socket = Instance(zmq.asyncio.Socket, allow_none=True) implementation: str implementation_version: str banner: str - @default("shell_streams") - def _shell_streams_default(self): # pragma: no cover - warnings.warn( - "Kernel.shell_streams is deprecated in ipykernel 6.0. Use Kernel.shell_stream", - DeprecationWarning, - stacklevel=2, - ) - if self.shell_stream is not None: - return [self.shell_stream] - else: - return [] - - @observe("shell_streams") - def _shell_streams_changed(self, change): # pragma: no cover - warnings.warn( - "Kernel.shell_streams is deprecated in ipykernel 6.0. Use Kernel.shell_stream", - DeprecationWarning, - stacklevel=2, - ) - if len(change.new) > 1: - warnings.warn( - "Kernel only supports one shell stream. Additional streams will be ignored.", - RuntimeWarning, - stacklevel=2, - ) - if change.new: - self.shell_stream = change.new[0] - - control_stream = Instance(ZMQStream, allow_none=True) + control_socket = Instance(zmq.asyncio.Socket, allow_none=True) debug_shell_socket = Any() + shell_msg_thread = Any() control_thread = Any() iopub_socket = Any() iopub_thread = Any() @@ -257,11 +224,14 @@ def _parent_header(self): "abort_request", "debug_request", "usage_request", + "create_subshell_request", ] def __init__(self, **kwargs): """Initialize the kernel.""" super().__init__(**kwargs) + self.shell_socket_lock = threading.Lock() + self.shells = {None: {}} # Build dict of handlers for message types self.shell_handlers = {} for msg_type in self.msg_types: @@ -271,79 +241,7 @@ def __init__(self, **kwargs): for msg_type in self.control_msg_types: self.control_handlers[msg_type] = getattr(self, msg_type) - self.control_queue: Queue[t.Any] = Queue() - - def dispatch_control(self, msg): - self.control_queue.put_nowait(msg) - - async def poll_control_queue(self): - while True: - msg = await self.control_queue.get() - # handle tracers from _flush_control_queue - if isinstance(msg, (concurrent.futures.Future, asyncio.Future)): - msg.set_result(None) - continue - await self.process_control(msg) - - async def _flush_control_queue(self): - """Flush the control queue, wait for processing of any pending messages""" - tracer_future: t.Union[concurrent.futures.Future[object], asyncio.Future[object]] - if self.control_thread: - control_loop = self.control_thread.io_loop - # concurrent.futures.Futures are threadsafe - # and can be used to await across threads - tracer_future = concurrent.futures.Future() - awaitable_future = asyncio.wrap_future(tracer_future) - else: - control_loop = self.io_loop - tracer_future = awaitable_future = asyncio.Future() - - def _flush(): - # control_stream.flush puts messages on the queue - self.control_stream.flush() - # put Future on the queue after all of those, - # so we can wait for all queued messages to be processed - self.control_queue.put(tracer_future) - - control_loop.add_callback(_flush) - return awaitable_future - - async def process_control(self, msg): - """dispatch control requests""" - idents, msg = self.session.feed_identities(msg, copy=False) - try: - msg = self.session.deserialize(msg, content=True, copy=False) - except Exception: - self.log.error("Invalid Control Message", exc_info=True) - return - - self.log.debug("Control received: %s", msg) - - # Set the parent message for side effects. 
- self.set_parent(idents, msg, channel="control") - self._publish_status("busy", "control") - - header = msg["header"] - msg_type = header["msg_type"] - - handler = self.control_handlers.get(msg_type, None) - if handler is None: - self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type) - else: - try: - result = handler(self.control_stream, idents, msg) - if inspect.isawaitable(result): - await result - except Exception: - self.log.error("Exception in control handler:", exc_info=True) - - sys.stdout.flush() - sys.stderr.flush() - self._publish_status("idle", "control") - # flush to ensure reply is sent - self.control_stream.flush(zmq.POLLOUT) - - def should_handle(self, stream, msg, idents): + async def should_handle(self, socket, msg, idents): """Check whether a shell-channel message should be handled Allows subclasses to prevent handling of certain messages (e.g. aborted requests). @@ -352,86 +250,20 @@ def should_handle(self, stream, msg, idents): if msg_id in self.aborted: # is it safe to assume a msg_id will not be resubmitted? self.aborted.remove(msg_id) - self._send_abort_reply(stream, msg, idents) + await self._send_abort_reply(socket, msg, idents) return False return True - async def dispatch_shell(self, msg): - """dispatch shell requests""" - - # flush control queue before handling shell requests - await self._flush_control_queue() - - idents, msg = self.session.feed_identities(msg, copy=False) - try: - msg = self.session.deserialize(msg, content=True, copy=False) - except Exception: - self.log.error("Invalid Message", exc_info=True) - return - - # Set the parent message for side effects. - self.set_parent(idents, msg, channel="shell") - self._publish_status("busy", "shell") - - msg_type = msg["header"]["msg_type"] - - # Only abort execute requests - if self._aborting and msg_type == "execute_request": - self._send_abort_reply(self.shell_stream, msg, idents) - self._publish_status("idle", "shell") - # flush to ensure reply is sent before - # handling the next request - self.shell_stream.flush(zmq.POLLOUT) - return - - # Print some info about this message and leave a '--->' marker, so it's - # easier to trace visually the message chain when debugging. Each - # handler prints its message at the end. - self.log.debug("\n*** MESSAGE TYPE:%s***", msg_type) - self.log.debug(" Content: %s\n --->\n ", msg["content"]) - - if not self.should_handle(self.shell_stream, msg, idents): - return - - handler = self.shell_handlers.get(msg_type, None) - if handler is None: - self.log.warning("Unknown message type: %r", msg_type) - else: - self.log.debug("%s: %s", msg_type, msg) - try: - self.pre_handler_hook() - except Exception: - self.log.debug("Unable to signal in pre_handler_hook:", exc_info=True) - try: - result = handler(self.shell_stream, idents, msg) - if inspect.isawaitable(result): - await result - except Exception: - self.log.error("Exception in message handler:", exc_info=True) - except KeyboardInterrupt: - # Ctrl-c shouldn't crash the kernel here. 
-                self.log.error("KeyboardInterrupt caught in kernel.")
-            finally:
-                try:
-                    self.post_handler_hook()
-                except Exception:
-                    self.log.debug("Unable to signal in post_handler_hook:", exc_info=True)
-
-        sys.stdout.flush()
-        sys.stderr.flush()
-        self._publish_status("idle", "shell")
-        # flush to ensure reply is sent before
-        # handling the next request
-        self.shell_stream.flush(zmq.POLLOUT)
-
     def pre_handler_hook(self):
         """Hook to execute before calling message handler"""
-        # ensure default_int_handler during handler call
-        self.saved_sigint_handler = signal(SIGINT, default_int_handler)
+        # SIGINT is now handled per-execution in do_execute, so there is
+        # nothing to install here by default.
+        pass

     def post_handler_hook(self):
         """Hook to execute after calling message handler"""
-        signal(SIGINT, self.saved_sigint_handler)
+        # see pre_handler_hook: SIGINT handling moved into do_execute
+        pass

     def enter_eventloop(self):
         """enter eventloop"""
@@ -447,11 +279,6 @@ def advance_eventloop():
             if self.eventloop is not eventloop:
                 self.log.info("exiting eventloop %s", eventloop)
                 return
-            if self.msg_queue.qsize():
-                self.log.debug("Delaying eventloop due to waiting messages")
-                # still messages to process, make the eventloop wait
-                schedule_next()
-                return
             self.log.debug("Advancing eventloop %s", eventloop)
             try:
                 eventloop(self)
@@ -468,96 +295,174 @@ def schedule_next():
             # flush the eventloop every so often,
             # giving us a chance to handle messages in the meantime
             self.log.debug("Scheduling eventloop advance")
-            self.io_loop.call_later(0.001, advance_eventloop)
+            asyncio.get_running_loop().call_later(0.001, advance_eventloop)

         # begin polling the eventloop
         schedule_next()

-    async def do_one_iteration(self):
-        """Process a single shell message
+    _message_counter = Any(
+        help="""Monotonic counter of messages
+        """,
+    )

-        Any pending control messages will be flushed as well
+    @default("_message_counter")
+    def _message_counter_default(self):
+        return itertools.count()

-        .. versionchanged:: 5
-            This is now a coroutine
+    async def get_shell_messages(self):
+        """Get messages from the shell socket, in a separate thread if sub-shells
+        are supported, in the main thread otherwise.
         """
-        # flush messages off of shell stream into the message queue
-        self.shell_stream.flush()
-        # process at most one shell message per iteration
-        await self.process_one(wait=False)
+        while True:
+            await self.get_shell_message()

-    async def process_one(self, wait=True):
-        """Process one request
+    async def get_shell_message(self, msg=None):
+        """Get a message from the shell socket, in a separate thread if sub-shells
+        are supported, in the main thread otherwise.

-        Returns None if no message was handled.
+        Allow bypassing the socket and injecting a message, for testing.
""" - if wait: - t, dispatch, args = await self.msg_queue.get() + msg = msg or await self.shell_socket.recv_multipart() + idents, msg = self.session.feed_identities(msg, copy=True) + try: + msg = self.session.deserialize(msg, content=True, copy=True) + except BaseException: + self.log.error("Invalid Message", exc_info=True) + return + + shell_id = msg["header"].get("shell_id") + # the shell message queue for this sub-shell might not have been created yet, + # because done in another thread, it should be ok to not use a lock here + while shell_id not in self.shells or "messages" not in self.shells[shell_id]: + await asyncio.sleep(0.1) + if self.shell_msg_thread: + self.shells[shell_id]["messages"].sync_q.put((idents, msg)) else: - try: - t, dispatch, args = self.msg_queue.get_nowait() - except (asyncio.QueueEmpty, QueueEmpty): - return None - await dispatch(*args) + self.shells[shell_id]["messages"].put_nowait((idents, msg)) + + async def process_shell_messages(self, shell_id=None): + # create a message queue only for sub-shells (already created for main shell) + if shell_id is not None: + lib = janus if self.shell_msg_thread else asyncio + self.shells[shell_id]["interrupt"] = lib.Queue() + self.shells[shell_id]["messages"] = lib.Queue() + while True: + if self.shell_msg_thread: + idents, msg = await self.shells[shell_id]["messages"].async_q.get() + else: + idents, msg = await self.shells[shell_id]["messages"].get() - async def dispatch_queue(self): - """Coroutine to preserve order of message handling + # Set the parent message for side effects. + self.set_parent(idents, msg, channel="shell") + self._publish_status("busy", "shell") - Ensures that only one message is processing at a time, - even when the handler is async - """ + msg_type = msg["header"]["msg_type"] + + # Only abort execute requests + if self._aborting and msg_type == "execute_request": + await self._send_abort_reply(self.shell_socket, msg, idents) + self._publish_status("idle", "shell") + # flush to ensure reply is sent before + # handling the next request + return + + # Print some info about this message and leave a '--->' marker, so it's + # easier to trace visually the message chain when debugging. Each + # handler prints its message at the end. + self.log.debug("\n*** MESSAGE TYPE:%s***", msg_type) + self.log.debug(" Content: %s\n --->\n ", msg["content"]) + if not await self.should_handle(self.shell_socket, msg, idents): + return + + handler = self.shell_handlers.get(msg_type) + if handler is None: + self.log.warning("Unknown message type: %r", msg_type) + else: + self.log.debug("%s: %s", msg_type, msg) + try: + self.pre_handler_hook() + except Exception: + self.log.debug("Unable to signal in pre_handler_hook:", exc_info=True) + try: + result = handler(self.shell_socket, idents, msg) + if inspect.isawaitable(result): + await result + DEBUG(f"{shell_id} {result=}") + except Exception: + self.log.error("Exception in message handler:", exc_info=True) + except KeyboardInterrupt: + # Ctrl-c shouldn't crash the kernel here. 
+                    self.log.error("KeyboardInterrupt caught in kernel.")
+                finally:
+                    try:
+                        self.post_handler_hook()
+                    except Exception:
+                        self.log.debug("Unable to signal in post_handler_hook:", exc_info=True)
+
+            sys.stdout.flush()
+            sys.stderr.flush()
+            self._publish_status("idle", "shell")
+
+    async def process_control_messages(self):
         while True:
+            msg = await self.control_socket.recv_multipart()
+            idents, msg = self.session.feed_identities(msg, copy=True)
             try:
-                await self.process_one()
+                msg = self.session.deserialize(msg, content=True, copy=True)
             except Exception:
-                self.log.exception("Error in message handler")
+                self.log.error("Invalid Control Message", exc_info=True)
+                # skip the invalid message but keep the control channel alive
+                continue

-    _message_counter = Any(
-        help="""Monotonic counter of messages
-        """,
-    )
+            self.log.debug("Control received: %s", msg)

-    @default("_message_counter")
-    def _message_counter_default(self):
-        return itertools.count()
+            # Set the parent message for side effects.
+            self.set_parent(idents, msg, channel="control")
+            self._publish_status("busy", "control")

-    def schedule_dispatch(self, dispatch, *args):
-        """schedule a message for dispatch"""
-        idx = next(self._message_counter)
+            header = msg["header"]
+            msg_type = header["msg_type"]

-        self.msg_queue.put_nowait(
-            (
-                idx,
-                dispatch,
-                args,
-            )
-        )
-        # ensure the eventloop wakes up
-        self.io_loop.add_callback(lambda: None)
+            handler = self.control_handlers.get(msg_type, None)
+            if handler is None:
+                self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
+            else:
+                try:
+                    result = handler(self.control_socket, idents, msg)
+                    if inspect.isawaitable(result):
+                        await result
+                except Exception:
+                    self.log.error("Exception in control handler:", exc_info=True)

-    def start(self):
-        """register dispatchers for streams"""
-        self.io_loop = ioloop.IOLoop.current()
-        self.msg_queue: Queue[t.Any] = Queue()
-        self.io_loop.add_callback(self.dispatch_queue)
+            sys.stdout.flush()
+            sys.stderr.flush()
+            self._publish_status("idle", "control")

-        self.control_stream.on_recv(self.dispatch_control, copy=False)
+    def start(self):
+        """Process messages on shell and control channels"""
+        self.exec_done = asyncio.Event()
+        self.io_loop = asyncio.get_running_loop()
         if self.control_thread:
-            control_loop = self.control_thread.io_loop
+            self.control_thread.create_task(self.process_control_messages())
         else:
-            control_loop = self.io_loop
+            asyncio.create_task(self.process_control_messages())

-        asyncio.run_coroutine_threadsafe(self.poll_control_queue(), control_loop.asyncio_loop)
+        self.shells[None]["interrupt"] = asyncio.Queue()
+        self.shells[None]["thread"] = None
+        if self.shell_msg_thread:
+            self.shells[None]["messages"] = janus.Queue()
+            self.shell_msg_thread.create_task(self.get_shell_messages())
+        else:
+            # If sub-shells are not supported, messages are received and processed
+            # in the main thread, so use an asyncio.Queue directly; there is no
+            # multi-threading to deal with.
+            self.shells[None]["messages"] = asyncio.Queue()
+            asyncio.create_task(self.get_shell_messages())

-        self.shell_stream.on_recv(
-            partial(
-                self.schedule_dispatch,
-                self.dispatch_shell,
-            ),
-            copy=False,
-        )
+        asyncio.create_task(self.process_shell_messages())  # main shell/thread

         # publish idle status
         self._publish_status("starting", "shell")
@@ -683,7 +588,7 @@ def finish_metadata(self, parent, metadata, reply_content):
         """
         return metadata

-    async def execute_request(self, stream, ident, parent):
+    async def
execute_request(self, socket, ident, parent):
         """handle an execute_request"""
         try:
             content = parent["content"]
@@ -707,28 +612,23 @@ async def execute_request(self, stream, ident, parent):
             self.execution_count += 1
             self._publish_execute_input(code, parent, self.execution_count)

+        shell_id = parent["header"].get("shell_id")
         cell_id = (parent.get("metadata") or {}).get("cellId")

-        if _accepts_cell_id(self.do_execute):
-            reply_content = self.do_execute(
-                code,
-                silent,
-                store_history,
-                user_expressions,
-                allow_stdin,
-                cell_id=cell_id,
-            )
-        else:
-            reply_content = self.do_execute(
-                code,
-                silent,
-                store_history,
-                user_expressions,
-                allow_stdin,
-            )
+        kwargs = dict()
+        if _accepts_arg(self.do_execute, "shell_id"):
+            kwargs.update(shell_id=shell_id)
+        if _accepts_arg(self.do_execute, "cell_id"):
+            kwargs.update(cell_id=cell_id)
+
+        reply_content = self.do_execute(
+            code, silent, store_history, user_expressions, allow_stdin, **kwargs
+        )

         if inspect.isawaitable(reply_content):
             reply_content = await reply_content

         # Flush output before sending the reply.
         sys.stdout.flush()
@@ -743,19 +643,21 @@
         reply_content = json_clean(reply_content)
         metadata = self.finish_metadata(parent, metadata, reply_content)

-        reply_msg = self.session.send(
-            stream,
-            "execute_reply",
-            reply_content,
-            parent,
-            metadata=metadata,
-            ident=ident,
-        )
+        with self.shell_socket_lock:
+            reply_msg = self.session.send(
+                socket,
+                "execute_reply",
+                reply_content,
+                parent,
+                metadata=metadata,
+                ident=ident,
+            )

         self.log.debug("%s", reply_msg)

         if not silent and reply_msg["content"]["status"] == "error" and stop_on_error:
-            self._abort_queues()
+            await self._abort_queues()

     def do_execute(
         self,
@@ -766,6 +668,7 @@
         allow_stdin=False,
         *,
         cell_id=None,
+        shell_id=None,
    ):
        """Execute user code.
Must be overridden by subclasses.""" raise NotImplementedError @@ -863,11 +766,23 @@ def kernel_info(self): "help_links": self.help_links, } - async def kernel_info_request(self, stream, ident, parent): + async def create_subshell_request(self, socket, ident, parent): + shell_id = str(uuid.uuid4()) + self.shells[shell_id] = {} + self.shells[shell_id]["thread"] = subshell_thread = ShellThread(shell_id) + subshell_thread.start() + subshell_thread.create_task(self.process_shell_messages(shell_id)) + content = { + "status": "ok", + "shell_id": shell_id, + } + self.session.send(socket, "create_subshell_reply", content, parent, ident) + + async def kernel_info_request(self, socket, ident, parent): """Handle a kernel info request.""" content = {"status": "ok"} content.update(self.kernel_info) - msg = self.session.send(stream, "kernel_info_reply", content, parent, ident) + msg = self.session.send(socket, "kernel_info_reply", content, parent, ident) self.log.debug("%s", msg) async def comm_info_request(self, stream, ident, parent): @@ -907,31 +822,36 @@ def _send_interupt_children(self): except OSError: pass - async def interrupt_request(self, stream, ident, parent): + async def interrupt_request(self, socket, ident, parent): """Handle an interrupt request.""" self._send_interupt_children() content = parent["content"] - self.session.send(stream, "interrupt_reply", content, parent, ident=ident) + self.session.send(socket, "interrupt_reply", content, parent, ident=ident) return - async def shutdown_request(self, stream, ident, parent): + async def shutdown_request(self, socket, ident, parent): """Handle a shutdown request.""" content = self.do_shutdown(parent["content"]["restart"]) if inspect.isawaitable(content): content = await content - self.session.send(stream, "shutdown_reply", content, parent, ident=ident) + self.session.send(socket, "shutdown_reply", content, parent, ident=ident) # same content, but different msg_id for broadcasting on IOPub self._shutdown_message = self.session.msg("shutdown_reply", content, parent) await self._at_shutdown() - self.log.debug("Stopping control ioloop") - control_io_loop = self.control_stream.io_loop - control_io_loop.add_callback(control_io_loop.stop) - - self.log.debug("Stopping shell ioloop") - shell_io_loop = self.shell_stream.io_loop - shell_io_loop.add_callback(shell_io_loop.stop) + if self.control_thread: + self.control_thread.stop() + if self.shell_msg_thread: + self.shell_msg_thread.stop() + if self.iopub_thread: + self.iopub_thread.stop() + for shell in self.shells.values(): + subshell_thread = shell["thread"] + if subshell_thread is not None: + subshell_thread.stop() + + self._stop_queue.sync_q.put(None) def do_shutdown(self, restart): """Override in subclasses to do things when the frontend shuts down the @@ -939,7 +859,7 @@ def do_shutdown(self, restart): """ return {"status": "ok", "restart": restart} - async def is_complete_request(self, stream, ident, parent): + async def is_complete_request(self, socket, ident, parent): """Handle an is_complete request.""" content = parent["content"] code = content["code"] @@ -948,21 +868,21 @@ async def is_complete_request(self, stream, ident, parent): if inspect.isawaitable(reply_content): reply_content = await reply_content reply_content = json_clean(reply_content) - reply_msg = self.session.send(stream, "is_complete_reply", reply_content, parent, ident) + reply_msg = self.session.send(socket, "is_complete_reply", reply_content, parent, ident) self.log.debug("%s", reply_msg) def do_is_complete(self, code): 
"""Override in subclasses to find completions.""" return {"status": "unknown"} - async def debug_request(self, stream, ident, parent): + async def debug_request(self, socket, ident, parent): """Handle a debug request.""" content = parent["content"] reply_content = self.do_debug_request(content) if inspect.isawaitable(reply_content): reply_content = await reply_content reply_content = json_clean(reply_content) - reply_msg = self.session.send(stream, "debug_reply", reply_content, parent, ident) + reply_msg = self.session.send(socket, "debug_reply", reply_content, parent, ident) self.log.debug("%s", reply_msg) def get_process_metric_value(self, process, name, attribute=None): @@ -978,7 +898,7 @@ def get_process_metric_value(self, process, name, attribute=None): except BaseException: return None - async def usage_request(self, stream, ident, parent): + async def usage_request(self, socket, ident, parent): """Handle a usage request.""" reply_content = {"hostname": socket.gethostname(), "pid": os.getpid()} current_process = psutil.Process() @@ -1008,7 +928,7 @@ async def usage_request(self, stream, ident, parent): reply_content["host_cpu_percent"] = cpu_percent reply_content["cpu_count"] = psutil.cpu_count(logical=True) reply_content["host_virtual_memory"] = dict(psutil.virtual_memory()._asdict()) - reply_msg = self.session.send(stream, "usage_reply", reply_content, parent, ident) + reply_msg = self.session.send(socket, "usage_reply", reply_content, parent, ident) self.log.debug("%s", reply_msg) async def do_debug_request(self, msg): @@ -1018,7 +938,7 @@ async def do_debug_request(self, msg): # Engine methods (DEPRECATED) # --------------------------------------------------------------------------- - async def apply_request(self, stream, ident, parent): # pragma: no cover + async def apply_request(self, socket, ident, parent): # pragma: no cover """Handle an apply request.""" self.log.warning("apply_request is deprecated in kernel_base, moving to ipyparallel.") try: @@ -1040,7 +960,7 @@ async def apply_request(self, stream, ident, parent): # pragma: no cover md = self.finish_metadata(parent, md, reply_content) self.session.send( - stream, + socket, "apply_reply", reply_content, parent=parent, @@ -1057,7 +977,7 @@ def do_apply(self, content, bufs, msg_id, reply_metadata): # Control messages (DEPRECATED) # --------------------------------------------------------------------------- - async def abort_request(self, stream, ident, parent): # pragma: no cover + async def abort_request(self, socket, ident, parent): # pragma: no cover """abort a specific msg by id""" self.log.warning( "abort_request is deprecated in kernel_base. It is only part of IPython parallel" @@ -1066,23 +986,23 @@ async def abort_request(self, stream, ident, parent): # pragma: no cover if isinstance(msg_ids, str): msg_ids = [msg_ids] if not msg_ids: - self._abort_queues() + await self._abort_queues() for mid in msg_ids: self.aborted.add(str(mid)) content = dict(status="ok") reply_msg = self.session.send( - stream, "abort_reply", content=content, parent=parent, ident=ident + socket, "abort_reply", content=content, parent=parent, ident=ident ) self.log.debug("%s", reply_msg) - async def clear_request(self, stream, idents, parent): # pragma: no cover + async def clear_request(self, socket, idents, parent): # pragma: no cover """Clear our namespace.""" self.log.warning( "clear_request is deprecated in kernel_base. 
        )
         content = self.do_clear()
-        self.session.send(stream, "clear_reply", ident=idents, parent=parent, content=content)
+        self.session.send(socket, "clear_reply", ident=idents, parent=parent, content=content)
 
     def do_clear(self):
         """DEPRECATED since 4.0.3"""
@@ -1100,7 +1020,7 @@ def _topic(self, topic):
 
     _aborting = Bool(False)
 
-    def _abort_queues(self):
+    async def _abort_queues(self):
         # while this flag is true,
         # execute requests will be aborted
         self._aborting = True
@@ -1108,24 +1028,14 @@
 
-        # flush streams, so all currently waiting messages
-        # are added to the queue
-        self.shell_stream.flush()
-
-        # Callback to signal that we are done aborting
-        # dispatch functions _must_ be async
+        # Clear the abort flag once queued messages have had
+        # stop_on_error_timeout seconds to arrive
         async def stop_aborting():
+            await asyncio.sleep(self.stop_on_error_timeout)
             self.log.info("Finishing abort")
             self._aborting = False
 
-        # put the stop-aborting event on the message queue
-        # so that all messages already waiting in the queue are aborted
-        # before we reset the flag
-        schedule_stop_aborting = partial(self.schedule_dispatch, stop_aborting)
-
-        # if we have a delay, give messages this long to arrive on the queue
-        # before we stop aborting requests
-        asyncio.get_event_loop().call_later(self.stop_on_error_timeout, schedule_stop_aborting)
+        asyncio.create_task(stop_aborting())
 
-    def _send_abort_reply(self, stream, msg, idents):
+    async def _send_abort_reply(self, socket, msg, idents):
         """Send a reply to an aborted request"""
         self.log.info(f"Aborting {msg['header']['msg_id']}: {msg['header']['msg_type']}")
         reply_type = msg["header"]["msg_type"].rsplit("_", 1)[0] + "_reply"
@@ -1135,7 +1056,7 @@ def _send_abort_reply(self, stream, msg, idents):
         md.update(status)
 
         self.session.send(
-            stream,
+            socket,
             reply_type,
             metadata=md,
             content=status,
@@ -1319,4 +1240,3 @@ async def _at_shutdown(self):
             ident=self._topic("shutdown"),
         )
         self.log.debug("%s", self._shutdown_message)
-        self.control_stream.flush(zmq.POLLOUT)
diff --git a/ipykernel/shell.py b/ipykernel/shell.py
new file mode 100644
index 000000000..07a57a6d3
--- /dev/null
+++ b/ipykernel/shell.py
@@ -0,0 +1,6 @@
+from .athread import AThread
+
+
+class ShellThread(AThread):
+    def __init__(self, name: str):
+        super().__init__(name=f"Shell:{name}")
diff --git a/ipykernel/tests/conftest.py b/ipykernel/tests/conftest.py
index 45341cef6..9d09a1233 100644
--- a/ipykernel/tests/conftest.py
+++ b/ipykernel/tests/conftest.py
@@ -73,8 +73,8 @@ def destroy(self):
 
     @no_type_check
     async def test_shell_message(self, *args, **kwargs):
         msg_list = self._prep_msg(*args, **kwargs)
-        await self.dispatch_shell(msg_list)
-        self.shell_stream.flush()
+        self.shell_msg_queues.setdefault(None, asyncio.Queue())
+        await self.get_shell_message(msg_list)
         return await self._wait_for_msg()
 
     @no_type_check
@@ -91,10 +91,10 @@ def _prep_msg(self, *args, **kwargs):
         self._reply = None
         raw_msg = self.session.msg(*args, **kwargs)
         msg = self.session.serialize(raw_msg)
-        return [zmq.Message(m) for m in msg]
+        return msg
 
     async def _wait_for_msg(self):
         while not self._reply:
             await asyncio.sleep(0.1)
         _, msg = self.session.feed_identities(self._reply)
         return self.session.deserialize(msg)
diff --git a/ipykernel/tests/test_async.py b/ipykernel/tests/test_async.py
index c58a24d9d..8b80b16f6 100644
--- a/ipykernel/tests/test_async.py
+++ b/ipykernel/tests/test_async.py
@@ -52,7 +52,7 @@ def test_async_interrupt(asynclib, request):
     stream = KC.get_iopub_msg(timeout=TIMEOUT)
     # wait for the stream output to be sure kernel is in the async block
     validate_message(stream, "stream")
-    assert stream["content"]["text"] == "begin\n"
+    assert stream["content"]["text"].startswith("begin")
 
     KM.interrupt_kernel()
     reply = KC.get_shell_msg()["content"]
diff --git a/ipykernel/tests/test_debugger.py b/ipykernel/tests/test_debugger.py
index 200154cfc..dbafca893 100644
--- a/ipykernel/tests/test_debugger.py
+++ b/ipykernel/tests/test_debugger.py
@@ -92,6 +92,7 @@ def test_attach_debug(kernel_with_debug):
     reply = wait_for_debug_request(
         kernel_with_debug, "evaluate", {"expression": "'a' + 'b'", "context": "repl"}
     )
+    print(reply)
     assert reply["success"]
     assert reply["body"]["result"] == ""
diff --git a/ipykernel/tests/test_embed_kernel.py b/ipykernel/tests/test_embed_kernel.py
index ff97edfa5..96cf6a0e5 100644
--- a/ipykernel/tests/test_embed_kernel.py
+++ b/ipykernel/tests/test_embed_kernel.py
@@ -195,6 +195,7 @@ def test_embed_kernel_reentrant():
         # exit from embed_kernel
         client.execute("get_ipython().exit_now = True")
         msg = client.get_shell_msg(timeout=TIMEOUT)
+        print(f"{msg=}")
         time.sleep(0.2)
 
 
@@ -206,7 +207,7 @@ def test_embed_kernel_func():
     def trigger_stop():
         time.sleep(1)
         app = IPKernelApp.instance()
-        app.io_loop.add_callback(app.io_loop.stop)
+        app.stop()
         IPKernelApp.clear_instance()
 
     thread = threading.Thread(target=trigger_stop)
diff --git a/ipykernel/tests/test_io.py b/ipykernel/tests/test_io.py
index f7fc8b3f4..776399099 100644
--- a/ipykernel/tests/test_io.py
+++ b/ipykernel/tests/test_io.py
@@ -5,6 +5,7 @@
 
 import pytest
 import zmq
+import zmq.asyncio
 from jupyter_client.session import Session
 
 from ipykernel.iostream import MASTER, BackgroundSocket, IOPubThread, OutStream
@@ -13,7 +14,7 @@ def test_io_api():
     """Test that wrapped stdout has the same API as a normal TextIO object"""
     session = Session()
-    ctx = zmq.Context()
+    ctx = zmq.asyncio.Context()
     pub = ctx.socket(zmq.PUB)
     thread = IOPubThread(pub)
     thread.start()
@@ -45,35 +46,37 @@ def test_io_api():
 
 def test_io_isatty():
     session = Session()
-    ctx = zmq.Context()
+    ctx = zmq.asyncio.Context()
     pub = ctx.socket(zmq.PUB)
     thread = IOPubThread(pub)
     thread.start()
 
     stream = OutStream(session, thread, "stdout", isatty=True)
     assert stream.isatty()
+    thread.stop()
+    thread.close()
+    ctx.term()
 
 
 def test_io_thread():
-    ctx = zmq.Context()
+    ctx = zmq.asyncio.Context()
     pub = ctx.socket(zmq.PUB)
     thread = IOPubThread(pub)
     thread._setup_pipe_in()
-    msg = [thread._pipe_uuid, b"a"]
-    thread._handle_pipe_msg(msg)
     ctx1, pipe = thread._setup_pipe_out()
     pipe.close()
-    thread._pipe_in.close()
+    thread._pipe_in1.close()
     thread._check_mp_mode = lambda: MASTER  # type:ignore
    thread._really_send([b"hi"])
     ctx1.destroy()
-    thread.close()
+    thread.stop()
     thread.close()
     thread._really_send(None)
+    ctx.term()
 
 
 def test_background_socket():
-    ctx = zmq.Context()
+    ctx = zmq.asyncio.Context()
     pub = ctx.socket(zmq.PUB)
     thread = IOPubThread(pub)
     sock = BackgroundSocket(thread)
@@ -84,11 +89,14 @@ def test_background_socket():
     assert thread.socket.linger == 101
     assert sock.io_thread == thread
     sock.send(b"hi")
+    thread.stop()
+    thread.close()
+    ctx.term()
 
 
 def test_outstream():
     session = Session()
-    ctx = zmq.Context()
+    ctx = zmq.asyncio.Context()
     pub = ctx.socket(zmq.PUB)
     thread = IOPubThread(pub)
     thread.start()
@@ -96,7 +104,9 @@ def test_outstream():
     with warnings.catch_warnings():
         warnings.simplefilter("ignore", DeprecationWarning)
         stream = OutStream(session, pub, "stdout")
+        stream.close()
         stream = OutStream(session, thread, "stdout", pipe=object())
+        stream.close()
         stream = OutStream(session, thread, "stdout", isatty=True, echo=io.StringIO())
 
     with pytest.raises(io.UnsupportedOperation):
@@ -106,3 +116,6 @@ def test_outstream():
     stream.write("hi")
     stream.writelines(["ab", "cd"])
     assert stream.writable()
+    thread.stop()
+    thread.close()
+    # ctx.term()
diff --git a/ipykernel/tests/test_subshell.py b/ipykernel/tests/test_subshell.py
new file mode 100644
index 000000000..e6a0eb491
--- /dev/null
+++ b/ipykernel/tests/test_subshell.py
@@ -0,0 +1,66 @@
+"""Test subshell"""
+
+import time
+
+from .utils import flush_channels, get_reply, start_new_kernel
+
+KC = KM = None
+
+
+def setup_function():
+    """Start a new kernel for each test and create a client for it."""
+    global KM, KC
+    KM, KC = start_new_kernel()
+    flush_channels(KC)
+
+
+def teardown_function():
+    assert KC is not None
+    assert KM is not None
+    KC.stop_channels()
+    KM.shutdown_kernel(now=True)
+
+
+def test_subshell():
+    flush_channels(KC)
+
+    # create sub-shells
+    shell_ids = []
+    n_subshells = 5
+    for _ in range(n_subshells):
+        msg = KC.session.msg("create_subshell_request")
+        KC.control_channel.send(msg)
+        reply = get_reply(KC, msg["header"]["msg_id"], channel="control")
+        shell_ids.append(reply["content"]["shell_id"])
+
+    t0 = time.time()
+    seconds1 = 1  # main shell execution time
+    # will wait some time in main shell
+    msg1 = KC.session.msg(
+        "execute_request", {"code": f"import time; time.sleep({seconds1})", "silent": False}
+    )
+    KC.shell_channel.send(msg1)
+    msg_ids = []
+    seconds = []
+    # try running (blocking) code in parallel
+    # will wait more time in each sub-shell
+    for i, shell_id in enumerate(shell_ids):
+        seconds2 = (2 + i * 0.1) * seconds1  # sub-shell execution time
+        msg2 = KC.session.msg(
+            "execute_request", {"code": f"import time; time.sleep({seconds2})", "silent": False}
+        )
+        msg2["header"]["shell_id"] = shell_id
+        KC.shell_channel.send(msg2)
+        seconds.append(seconds2)
+        msg_ids.append(msg2["header"]["msg_id"])
+    # in any case, main shell should finish first
+    reply = get_reply(KC, msg1["header"]["msg_id"])
+    dt1 = time.time() - t0
+    # main shell execution should not take much more than seconds1
+    assert seconds1 < dt1 < seconds1 * 1.1
+    # in any case, sub-shells should finish after main shell, and in order
+    for i, msg_id in enumerate(msg_ids):
+        reply = get_reply(KC, msg_id)
+        dt2 = time.time() - t0
+        # sub-shell execution should not take much more than seconds2
+        assert seconds[i] < dt2 < seconds[i] * 1.1
diff --git a/pyproject.toml b/pyproject.toml
index 77b885ecb..fa00881d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,13 +30,13 @@ dependencies = [
     "comm>=0.1.1",
     "traitlets>=5.4.0",
     "jupyter_client>=6.1.12",
-    "tornado>=6.1",
     "matplotlib-inline>=0.1",
     'appnope;platform_system=="Darwin"',
-    "pyzmq>=17",
+    "pyzmq==25.0.0b1",
     "psutil",
     "nest_asyncio",
     "packaging",
+    "janus>=1.0.0",
 ]
 
 [project.optional-dependencies]
@@ -54,7 +54,8 @@ test = [
     "ipyparallel",
     "pre-commit",
     "pytest-asyncio",
-    "pytest-timeout"
+    "pytest-timeout",
+    "tornado>=6.1",
 ]
 cov = [
     "coverage[toml]",
@@ -168,6 +169,7 @@ filterwarnings= [
     "ignore:unclosed
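How the pieces above fit together: each sub-shell created by create_subshell_request gets its own ShellThread, which runs a private asyncio loop and accepts coroutines through a thread-safe janus queue (hence the new janus>=1.0.0 runtime dependency). A minimal sketch of the lifecycle, using only calls that appear in this patch; pump() is a hypothetical stand-in for the coroutine the kernel actually submits, process_shell_messages(shell_id):

import asyncio

from ipykernel.shell import ShellThread

async def pump():
    # placeholder workload; the kernel submits process_shell_messages(shell_id)
    while True:
        await asyncio.sleep(1)

thread = ShellThread("demo")   # the thread is named "Shell:demo"
thread.start()                 # starts the thread's own asyncio event loop
thread.create_task(pump())     # thread-safe hand-off into that loop
thread.stop()                  # signals the loop to finish; pending tasks are cancelled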
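Message routing follows from what conftest.py exercises: incoming shell messages land on per-shell asyncio queues (shell_msg_queues) keyed by the shell_id header, with None selecting the main shell, and each (sub-)shell thread drains its own queue. Neither get_shell_message() nor process_shell_messages() appears in this excerpt, so the following is an assumption-labelled reconstruction of the consuming side, not the patch's exact code:

import asyncio

async def process_shell_messages(self, shell_id=None):
    # Hypothetical reconstruction: drain this (sub-)shell's queue forever,
    # feeding each raw message through the kernel's shell dispatch path.
    queue = self.shell_msg_queues.setdefault(shell_id, asyncio.Queue())
    while True:
        msg = await queue.get()
        await self.dispatch_shell(msg)  # assumes dispatch_shell survives the refactor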
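From a client, the control channel hands out shell ids and an execute request opts in by setting the shell_id header field, exactly as test_subshell.py does. Condensed into standalone form (standard jupyter_client calls; the connection-file path is a placeholder):

from jupyter_client import BlockingKernelClient

client = BlockingKernelClient(connection_file="kernel-12345.json")  # placeholder path
client.load_connection_file()
client.start_channels()

# Ask the kernel for a new sub-shell over the control channel.
msg = client.session.msg("create_subshell_request")
client.control_channel.send(msg)
reply = client.get_control_msg(timeout=5)
shell_id = reply["content"]["shell_id"]

# Route an execute_request to that sub-shell via the shell_id header;
# omitting the field keeps the request on the main shell.
request = client.session.msg("execute_request", {"code": "1 + 1", "silent": False})
request["header"]["shell_id"] = shell_id
client.shell_channel.send(request)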
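Finally, the dependency changes mirror the new I/O model: pyzmq is pinned to a 25.0 pre-release (presumably for fixes the asyncio-based kernel needs), tornado becomes test-only, and the tests now build zmq.asyncio.Context() everywhere because sends and receives are awaited instead of flushed through an IOLoop. The pattern in miniature, as an illustrative snippet rather than code from the patch:

import asyncio

import zmq
import zmq.asyncio

async def main():
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.PUB)
    sock.bind("tcp://127.0.0.1:*")  # wildcard port, just for illustration
    await sock.send_multipart([b"topic", b"payload"])  # send is awaitable here
    sock.close()
    ctx.term()

asyncio.run(main())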