diff --git a/chatlas/_chat.py b/chatlas/_chat.py index a90a1f2..35d39b5 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -from contextlib import contextmanager from typing import ( TYPE_CHECKING, Any, @@ -432,6 +431,7 @@ def extract_data( *args: Content | str, data_model: type[BaseModel], echo: Literal["text", "all", "none"] = "none", + stream: bool = False, ) -> dict[str, Any]: """ Extract structured data from the given input. @@ -444,6 +444,8 @@ def extract_data( A Pydantic model describing the structure of the data to extract. echo Whether to echo text content, all content (i.e., tool calls), or no content. + stream + Whether to stream the response (i.e., have the response appear in chunks). Returns ------- @@ -451,16 +453,14 @@ def extract_data( The extracted data. """ - with JupyterFriendlyLive() as live: - response = ChatResponse( - self._submit_turns( - user_turn(*args), - data_model=data_model, - echo=echo, - live=live, - stream=echo != "none", - ) + response = ChatResponse( + self._submit_turns( + user_turn(*args), + data_model=data_model, + echo=echo, + stream=stream, ) + ) for _ in response: pass @@ -486,6 +486,7 @@ async def extract_data_async( *args: Content | str, data_model: type[BaseModel], echo: Literal["text", "all", "none"] = "none", + stream: bool = False, ) -> dict[str, Any]: """ Extract structured data from the given input asynchronously. @@ -498,23 +499,23 @@ async def extract_data_async( A Pydantic model describing the structure of the data to extract. echo Whether to echo text content, all content (i.e., tool calls), or no content + stream + Whether to stream the response (i.e., have the response appear in chunks). + Defaults to `False`, regardless of the value of `echo`. Returns ------- dict[str, Any] The extracted data. 
""" - - with JupyterFriendlyLive() as live: - response = ChatResponseAsync( - self._submit_turns_async( - user_turn(*args), - data_model=data_model, - echo=echo, - live=live, - stream=echo != "none", - ) + response = ChatResponseAsync( + self._submit_turns_async( + user_turn(*args), + data_model=data_model, + echo=echo, + stream=stream, ) + ) async for _ in response: pass @@ -626,18 +627,15 @@ def _chat_impl( kwargs: Optional[SubmitInputArgsT] = None, ) -> Generator[str, None, None]: user_turn_result: Turn | None = user_turn - - with JupyterFriendlyLive() as live: - while user_turn_result is not None: - for chunk in self._submit_turns( - user_turn_result, - echo=echo, - live=live, - stream=stream, - kwargs=kwargs, - ): - yield chunk - user_turn_result = self._invoke_tools() + while user_turn_result is not None: + for chunk in self._submit_turns( + user_turn_result, + echo=echo, + stream=stream, + kwargs=kwargs, + ): + yield chunk + user_turn_result = self._invoke_tools() async def _chat_impl_async( self, @@ -647,24 +645,20 @@ async def _chat_impl_async( kwargs: Optional[SubmitInputArgsT] = None, ) -> AsyncGenerator[str, None]: user_turn_result: Turn | None = user_turn - - with JupyterFriendlyLive() as live: - while user_turn_result is not None: - async for chunk in self._submit_turns_async( - user_turn_result, - echo=echo, - live=live, - stream=stream, - kwargs=kwargs, - ): - yield chunk - user_turn_result = await self._invoke_tools_async() + while user_turn_result is not None: + async for chunk in self._submit_turns_async( + user_turn_result, + echo=echo, + stream=stream, + kwargs=kwargs, + ): + yield chunk + user_turn_result = await self._invoke_tools_async() def _submit_turns( self, user_turn: Turn, echo: Literal["text", "all", "none"], - live: "rich.live.Live", stream: bool, data_model: type[BaseModel] | None = None, kwargs: Optional[SubmitInputArgsT] = None, @@ -672,7 +666,7 @@ def _submit_turns( if any(x._is_async for x in self.tools.values()): raise 
ValueError("Cannot use async tools in a synchronous chat") - emit = emitter(echo, live) + emit = emitter(echo) if echo == "all": emit_user_contents(user_turn, emit) @@ -726,12 +720,11 @@ async def _submit_turns_async( self, user_turn: Turn, echo: Literal["text", "all", "none"], - live: "rich.live.Live", stream: bool, data_model: type[BaseModel] | None = None, kwargs: Optional[SubmitInputArgsT] = None, ) -> AsyncGenerator[str, None]: - emit = emitter(echo, live) + emit = emitter(echo) if echo == "all": emit_user_contents(user_turn, emit) @@ -966,72 +959,18 @@ def consumed(self) -> bool: return self._generator.ag_frame is None -# ---------------------------------------------------------------------------- -# Jupyter-friendly rich live context managers -# ---------------------------------------------------------------------------- - - -@contextmanager -def JupyterFriendlyLive(): - """ - A special `rich.live.Live` context manager with special handling for Jupyter. - """ - import rich.live - - with JupyterFriendlyConsole() as console: - with rich.live.Live(console=console, auto_refresh=False) as live: - yield live - - -@contextmanager -def JupyterFriendlyConsole(): - import rich.console - import rich.jupyter - - # Force jupyter mode if running in Quarto (so that side-effects are captured) - is_quarto = os.getenv("QUARTO_PYTHON", None) is not None - console = rich.console.Console( - force_jupyter=True if is_quarto else None, - ) - - # Prevent rich from inserting line breaks in a Jupyter context - # (and, instead, rely on the browser to wrap text) - console.soft_wrap = console.is_jupyter - - html_format = rich.jupyter.JUPYTER_HTML_FORMAT - - # Remove the `white-space:pre;` CSS style since the LLM's response is - # (usually) already pre-formatted and essentially assumes a browser context - rich.jupyter.JUPYTER_HTML_FORMAT = html_format.replace( - "white-space:pre;", "text-wrap-mode:wrap;word-break:break-word;" - ) - yield console - - rich.jupyter.JUPYTER_HTML_FORMAT = 
html_format - - # ---------------------------------------------------------------------------- # Helpers for emitting content # ---------------------------------------------------------------------------- -def emitter( - echo: Literal["text", "all", "none"], - live: "rich.live.Live", -) -> Callable[[Content | str], None]: +def emitter(echo: Literal["text", "all", "none"]) -> Callable[[Content | str], None]: if echo == "none": return lambda _: None - from rich.markdown import Markdown - - def emit(x: Content | str): - x = str(x) - current = live.get_renderable() - if isinstance(current, Markdown): - x = current.markup + x - live.update(Markdown(x), refresh=True) + stream = StreamingMarkdown() - return emit + return lambda x: stream.update(str(x)) def emit_user_contents( @@ -1063,3 +1002,72 @@ def emit_other_contents( # a non-whitespace character before it. The is a hack to work # around that, but it's also decent for readability. emit(f" {str(content)}\n\n") + + +class StreamingMarkdown: + """ + Stream markdown content. + + This uses rich for non-notebook contexts, and IPython.display.Markdown for + notebook (+Quarto) contexts. 
+ """ + + content: str = "" + ipy_display_id: Optional[str] = None + live: "Optional[rich.live.Live]" = None + + def __init__(self): + self.content = "" + + from rich.console import Console + from rich.live import Live + + # rich seems to be pretty good at detecting a (Jupyter) notebook + # context, so utilize that, but use IPython.display.Markdown instead if + # we're in a notebook (or Quarto) since that's a much more responsive + # way to display markdown + console = Console() + if console.is_jupyter or os.getenv("QUARTO_PYTHON", None) is not None: + self.ipy_display_id = self._init_display() + else: + live = Live(auto_refresh=False, vertical_overflow="visible") + live.start() + self.live = live + + def update(self, content: str): + self.content += content + + if self.ipy_display_id is not None: + from IPython.display import Markdown, update_display + + update_display( + Markdown(self.content), + display_id=self.ipy_display_id, + ) + elif self.live is not None: + from rich.markdown import Markdown + + self.live.update(Markdown(self.content), refresh=True) + + def _init_display(self) -> str: + try: + from IPython.display import HTML, Markdown, display + except ImportError: + raise ImportError( + "The IPython package is required for displaying content in a Jupyter notebook. " + "Install it with `pip install ipython`." + ) + + display(HTML("
")) + handle = display(Markdown(""), display_id=True) + if handle is None: + raise ValueError("Failed to create display handle") + return handle.display_id + + def __del__(self): + if self.live is not None: + self.live.stop() + self.live = None + if self.ipy_display_id is not None: + # I don't think there's any more cleanup to do here? + self.ipy_display_id = None diff --git a/docs/styles.scss b/docs/styles.scss index ae94cec..f029f61 100644 --- a/docs/styles.scss +++ b/docs/styles.scss @@ -7,9 +7,10 @@ pre:has(> code) { padding: .4em; } -/* Add a border around the rich console output */ -.jp-OutputArea-output { - border: 1px solid var(--bs-gray-300); - padding: 0.5em; - border-radius: 0.25rem; +/* Add a border around the Markdown() output */ +.chatlas-markdown + .cell-output-markdown { + border: 1px solid var(--bs-gray-300); + padding: .75rem 1rem; + border-radius: .25rem; + margin-bottom: 1rem; } \ No newline at end of file