Skip to content

Commit

Permalink
Use IPython.display.Markdown() instead of rich in notebook setting
Browse files Browse the repository at this point in the history
  • Loading branch information
cpsievert committed Nov 22, 2024
1 parent 5a6f1a3 commit d40e11a
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 110 deletions.
218 changes: 113 additions & 105 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import os
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -432,6 +431,7 @@ def extract_data(
*args: Content | str,
data_model: type[BaseModel],
echo: Literal["text", "all", "none"] = "none",
stream: bool = False,
) -> dict[str, Any]:
"""
Extract structured data from the given input.
Expand All @@ -444,23 +444,23 @@ def extract_data(
A Pydantic model describing the structure of the data to extract.
echo
Whether to echo text content, all content (i.e., tool calls), or no content.
stream
Whether to stream the response (i.e., have the response appear in chunks).
Returns
-------
dict[str, Any]
The extracted data.
"""

with JupyterFriendlyLive() as live:
response = ChatResponse(
self._submit_turns(
user_turn(*args),
data_model=data_model,
echo=echo,
live=live,
stream=echo != "none",
)
response = ChatResponse(
self._submit_turns(
user_turn(*args),
data_model=data_model,
echo=echo,
stream=stream,
)
)

for _ in response:
pass
Expand All @@ -486,6 +486,7 @@ async def extract_data_async(
*args: Content | str,
data_model: type[BaseModel],
echo: Literal["text", "all", "none"] = "none",
stream: bool = False,
) -> dict[str, Any]:
"""
Extract structured data from the given input asynchronously.
Expand All @@ -498,23 +499,23 @@ async def extract_data_async(
A Pydantic model describing the structure of the data to extract.
echo
Whether to echo text content, all content (i.e., tool calls), or no content
stream
Whether to stream the response (i.e., have the response appear in chunks).
Defaults to `False`.
Returns
-------
dict[str, Any]
The extracted data.
"""

with JupyterFriendlyLive() as live:
response = ChatResponseAsync(
self._submit_turns_async(
user_turn(*args),
data_model=data_model,
echo=echo,
live=live,
stream=echo != "none",
)
response = ChatResponseAsync(
self._submit_turns_async(
user_turn(*args),
data_model=data_model,
echo=echo,
stream=stream,
)
)

async for _ in response:
pass
Expand Down Expand Up @@ -626,18 +627,15 @@ def _chat_impl(
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str, None, None]:
user_turn_result: Turn | None = user_turn

with JupyterFriendlyLive() as live:
while user_turn_result is not None:
for chunk in self._submit_turns(
user_turn_result,
echo=echo,
live=live,
stream=stream,
kwargs=kwargs,
):
yield chunk
user_turn_result = self._invoke_tools()
while user_turn_result is not None:
for chunk in self._submit_turns(
user_turn_result,
echo=echo,
stream=stream,
kwargs=kwargs,
):
yield chunk
user_turn_result = self._invoke_tools()

async def _chat_impl_async(
self,
Expand All @@ -647,32 +645,28 @@ async def _chat_impl_async(
kwargs: Optional[SubmitInputArgsT] = None,
) -> AsyncGenerator[str, None]:
user_turn_result: Turn | None = user_turn

with JupyterFriendlyLive() as live:
while user_turn_result is not None:
async for chunk in self._submit_turns_async(
user_turn_result,
echo=echo,
live=live,
stream=stream,
kwargs=kwargs,
):
yield chunk
user_turn_result = await self._invoke_tools_async()
while user_turn_result is not None:
async for chunk in self._submit_turns_async(
user_turn_result,
echo=echo,
stream=stream,
kwargs=kwargs,
):
yield chunk
user_turn_result = await self._invoke_tools_async()

def _submit_turns(
self,
user_turn: Turn,
echo: Literal["text", "all", "none"],
live: "rich.live.Live",
stream: bool,
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str, None, None]:
if any(x._is_async for x in self.tools.values()):
raise ValueError("Cannot use async tools in a synchronous chat")

emit = emitter(echo, live)
emit = emitter(echo)

if echo == "all":
emit_user_contents(user_turn, emit)
Expand Down Expand Up @@ -726,12 +720,11 @@ async def _submit_turns_async(
self,
user_turn: Turn,
echo: Literal["text", "all", "none"],
live: "rich.live.Live",
stream: bool,
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> AsyncGenerator[str, None]:
emit = emitter(echo, live)
emit = emitter(echo)

if echo == "all":
emit_user_contents(user_turn, emit)
Expand Down Expand Up @@ -966,72 +959,18 @@ def consumed(self) -> bool:
return self._generator.ag_frame is None


# ----------------------------------------------------------------------------
# Jupyter-friendly rich live context managers
# ----------------------------------------------------------------------------


@contextmanager
def JupyterFriendlyLive():
    """
    Context manager yielding a `rich.live.Live` display whose console has the
    Jupyter-specific tweaks from `JupyterFriendlyConsole` applied.
    """
    import rich.live

    with JupyterFriendlyConsole() as console:
        live_display = rich.live.Live(console=console, auto_refresh=False)
        with live_display:
            yield live_display


@contextmanager
def JupyterFriendlyConsole():
    """
    Yield a `rich.console.Console` tuned for Jupyter/Quarto rendering.

    While the context is active, the module-global
    `rich.jupyter.JUPYTER_HTML_FORMAT` is patched so LLM output wraps in the
    browser instead of being rendered as pre-formatted text; the original
    value is restored on exit (even if the managed body raises).
    """
    import rich.console
    import rich.jupyter

    # Force jupyter mode if running in Quarto (so that side-effects are captured)
    is_quarto = os.getenv("QUARTO_PYTHON", None) is not None
    console = rich.console.Console(
        force_jupyter=True if is_quarto else None,
    )

    # Prevent rich from inserting line breaks in a Jupyter context
    # (and, instead, rely on the browser to wrap text)
    console.soft_wrap = console.is_jupyter

    html_format = rich.jupyter.JUPYTER_HTML_FORMAT

    # Remove the `white-space:pre;` CSS style since the LLM's response is
    # (usually) already pre-formatted and essentially assumes a browser context
    rich.jupyter.JUPYTER_HTML_FORMAT = html_format.replace(
        "white-space:pre;", "text-wrap-mode:wrap;word-break:break-word;"
    )
    try:
        yield console
    finally:
        # BUGFIX: restore the global format even when the body raises;
        # previously an exception would leave the patched format in place
        # for unrelated rich output.
        rich.jupyter.JUPYTER_HTML_FORMAT = html_format


# ----------------------------------------------------------------------------
# Helpers for emitting content
# ----------------------------------------------------------------------------


def emitter(echo: Literal["text", "all", "none"]) -> Callable[[Content | str], None]:
    """
    Return a callable that displays streamed content according to `echo`.

    Parameters
    ----------
    echo
        If "none", the returned callable is a no-op. Otherwise ("text" or
        "all"), content is accumulated and rendered as markdown via a single
        `StreamingMarkdown` instance shared across calls.

    Returns
    -------
    Callable[[Content | str], None]
        A function that takes a content chunk and renders it (or ignores it).
    """
    # NOTE: removed stale residue referencing an undefined `live` variable
    # (leftover from the old rich.live-based implementation), which made the
    # non-"none" path raise NameError and left dead code after `return`.
    if echo == "none":
        return lambda _: None

    stream = StreamingMarkdown()

    return lambda x: stream.update(str(x))


def emit_user_contents(
Expand Down Expand Up @@ -1063,3 +1002,72 @@ def emit_other_contents(
# a non-whitespace character before it. The   is a hack to work
# around that, but it's also decent for readability.
emit(f" {str(content)}\n\n")


class StreamingMarkdown:
    """
    Stream markdown content.

    This uses rich for non-notebook contexts, and IPython.display.Markdown for
    notebook (+Quarto) contexts.
    """

    # Accumulated markdown text emitted so far (re-rendered in full on update).
    content: str = ""
    # IPython display handle id when rendering in a notebook/Quarto context.
    ipy_display_id: Optional[str] = None
    # rich Live renderer when rendering in a terminal context.
    live: "Optional[rich.live.Live]" = None

    def __init__(self) -> None:
        self.content = ""

        from rich.console import Console
        from rich.live import Live

        # rich seems to be pretty good at detecting a (Jupyter) notebook
        # context, so utilize that, but use IPython.display.Markdown instead if
        # we're in a notebook (or Quarto) since that's a much more responsive
        # way to display markdown
        console = Console()
        if console.is_jupyter or os.getenv("QUARTO_PYTHON", None) is not None:
            self.ipy_display_id = self._init_display()
        else:
            live = Live(auto_refresh=False, vertical_overflow="visible")
            live.start()
            self.live = live

    def update(self, content: str) -> None:
        """Append `content` to the buffer and re-render all markdown so far."""
        self.content += content

        if self.ipy_display_id is not None:
            from IPython.display import Markdown, update_display

            # Re-render the existing display handle in place (no new output cell).
            update_display(
                Markdown(self.content),
                display_id=self.ipy_display_id,
            )
        elif self.live is not None:
            from rich.markdown import Markdown

            self.live.update(Markdown(self.content), refresh=True)

    def _init_display(self) -> str:
        """Create the IPython display placeholder and return its display id.

        Raises
        ------
        ImportError
            If IPython is not installed.
        ValueError
            If IPython fails to return a display handle.
        """
        try:
            from IPython.display import HTML, Markdown, display
        except ImportError:
            raise ImportError(
                "The IPython package is required for displaying content in a Jupyter notebook. "
                "Install it with `pip install ipython`."
            )

        # Marker div so docs CSS can target the markdown output that follows it.
        display(HTML("<div class='chatlas-markdown'></div>"))
        handle = display(Markdown(""), display_id=True)
        if handle is None:
            raise ValueError("Failed to create display handle")
        return handle.display_id

    def __del__(self):
        # NOTE(review): __del__-based cleanup is best-effort — the Live display
        # may be stopped late or never if the interpreter is shutting down;
        # consider an explicit close()/context-manager API.
        if self.live is not None:
            self.live.stop()
            self.live = None
        if self.ipy_display_id is not None:
            # I don't think there's any more cleanup to do here?
            self.ipy_display_id = None
11 changes: 6 additions & 5 deletions docs/styles.scss
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ pre:has(> code) {
padding: .4em;
}

/* Add a border around the rich console output */
.jp-OutputArea-output {
border: 1px solid var(--bs-gray-300);
padding: 0.5em;
border-radius: 0.25rem;
/* Add a border around the Markdown() output */
.chatlas-markdown + .cell-output-markdown {
border: 1px solid var(--bs-gray-300);
padding: .75rem 1rem;
border-radius: .25rem;
margin-bottom: 1rem;
}

0 comments on commit d40e11a

Please sign in to comment.