From fbc42cdbc1b2d1dafcbd0985e913ebfe11e82706 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Thu, 10 Oct 2024 19:32:41 -0700 Subject: [PATCH] Polish HTML views for common Langfun objects. PiperOrigin-RevId: 684661406 --- langfun/core/logging.py | 187 ++++++++++--- langfun/core/logging_test.py | 33 +++ langfun/core/message.py | 308 +++++++++++++++++----- langfun/core/message_test.py | 115 ++++---- langfun/core/modalities/audio.py | 2 +- langfun/core/modalities/audio_test.py | 2 +- langfun/core/modalities/image.py | 2 +- langfun/core/modalities/image_test.py | 12 +- langfun/core/modalities/mime.py | 42 ++- langfun/core/modalities/mime_test.py | 39 +++ langfun/core/modalities/ms_office.py | 7 +- langfun/core/modalities/ms_office_test.py | 2 +- langfun/core/modalities/pdf_test.py | 2 +- langfun/core/modalities/video.py | 2 +- langfun/core/modalities/video_test.py | 4 +- langfun/core/structured/mapping.py | 38 +++ langfun/core/structured/mapping_test.py | 55 ++++ langfun/core/structured/schema.py | 34 +++ langfun/core/template.py | 108 +++++++- langfun/core/template_test.py | 37 +++ 20 files changed, 853 insertions(+), 178 deletions(-) diff --git a/langfun/core/logging.py b/langfun/core/logging.py index 7385807..6c6628f 100644 --- a/langfun/core/logging.py +++ b/langfun/core/logging.py @@ -13,16 +13,13 @@ # limitations under the License. """Langfun event logging.""" -from collections.abc import Iterator import contextlib import datetime -import io import typing -from typing import Any, Literal +from typing import Any, Iterator, Literal, Sequence from langfun.core import component from langfun.core import console -from langfun.core import repr_utils import pyglove as pg @@ -56,49 +53,153 @@ class LogEntry(pg.Object): def should_output(self, min_log_level: LogLevel) -> bool: return _LOG_LEVELS.index(self.level) >= _LOG_LEVELS.index(min_log_level) - def _repr_html_(self) -> str: - s = io.StringIO() - padding_left = 50 * self.indent - s.write(f'
') - s.write(self._message_display) - if self.metadata: - s.write(repr_utils.html_repr(self.metadata)) - s.write('
') - return s.getvalue() - - @property - def _message_text_bgcolor(self) -> str: - match self.level: - case 'debug': - return '#EEEEEE' - case 'info': - return '#A3E4D7' - case 'warning': - return '#F8C471' - case 'error': - return '#F5C6CB' - case 'fatal': - return '#F19CBB' - case _: - raise ValueError(f'Unknown log level: {self.level}') - - @property - def _time_display(self) -> str: - display_text = self.time.strftime('%H:%M:%S') - alt_text = self.time.strftime('%Y-%m-%d %H:%M:%S.%f') - return ( - '{display_text}' + def _html_tree_view_summary( + self, + view: pg.views.HtmlTreeView, + title: str | pg.Html | None = None, + max_str_len_for_summary: int = pg.View.PresetArgValue(80), # pytype: disable=annotation-type-mismatch + **kwargs + ) -> str: + if len(self.message) > max_str_len_for_summary: + message = self.message[:max_str_len_for_summary] + '...' + else: + message = self.message + + s = pg.Html( + pg.Html.element( + 'span', + [self.time.strftime('%H:%M:%S')], + css_class=['log-time'] + ), + pg.Html.element( + 'span', + [pg.Html.escape(message)], + css_class=['log-summary'], + ), + ) + return view.summary( + self, + title=title or s, + max_str_len_for_summary=max_str_len_for_summary, + **kwargs, ) - @property - def _message_display(self) -> str: - return repr_utils.html_round_text( - self._time_display + ' ' + self.message, - background_color=self._message_text_bgcolor, + # pytype: disable=annotation-type-mismatch + def _html_tree_view_content( + self, + view: pg.views.HtmlTreeView, + root_path: pg.KeyPath, + collapse_log_metadata_level: int = pg.View.PresetArgValue(0), + max_str_len_for_summary: int = pg.View.PresetArgValue(80), + **kwargs + ) -> pg.Html: + # pytype: enable=annotation-type-mismatch + def render_message_text(): + if len(self.message) < max_str_len_for_summary: + return None + return pg.Html.element( + 'span', + [pg.Html.escape(self.message)], + css_class=['log-text'], + ) + + def render_metadata(): + if not self.metadata: + return None + return pg.Html.element( + 'div', + [ + view.render( + self.metadata, + name='metadata', + root_path=root_path + 'metadata', + parent=self, + collapse_level=( + root_path.depth + collapse_log_metadata_level + 1 + ) + ) + ], + css_class=['log-metadata'], + ) + + return pg.Html.element( + 'div', + [ + render_message_text(), + render_metadata(), + ], + css_class=['complex_value'], ) + def _html_style(self) -> list[str]: + return super()._html_style() + [ + """ + .log-time { + color: #222; + font-size: 12px; + padding-right: 10px; + } + .log-summary { + font-weight: normal; + font-style: italic; + padding: 4px; + } + .log-debug > summary > .summary_title::before { + content: '🛠️ ' + } + .log-info > summary > .summary_title::before { + content: '💡 ' + } + .log-warning > summary > .summary_title::before { + content: '❗ ' + } + .log-error > summary > .summary_title::before { + content: '❌ ' + } + .log-fatal > summary > .summary_title::before { + content: '💀 ' + } + .log-text { + display: block; + color: black; + font-style: italic; + padding: 20px; + border-radius: 5px; + background: rgba(255, 255, 255, 0.5); + white-space: pre-wrap; + } + details.log-entry { + margin: 0px 0px 10px; + border: 0px; + } + div.log-metadata { + margin: 10px 0px 0px 0px; + } + .log-metadata > details { + background-color: rgba(255, 255, 255, 0.5); + border: 1px solid transparent; + } + .log-debug { + background-color: #EEEEEE + } + .log-warning { + background-color: #F8C471 + } + .log-info { + background-color: #A3E4D7 + } + .log-error { + background-color: #F5C6CB + } + .log-fatal { + background-color: #F19CBB + } + """ + ] + + def _html_element_class(self) -> Sequence[str] | None: + return super()._html_element_class() + [f'log-{self.level}'] + def log(level: LogLevel, message: str, diff --git a/langfun/core/logging_test.py b/langfun/core/logging_test.py index 6330d00..17bdeea 100644 --- a/langfun/core/logging_test.py +++ b/langfun/core/logging_test.py @@ -13,6 +13,8 @@ # limitations under the License. """Tests for langfun.core.logging.""" +import datetime +import inspect import unittest from langfun.core import logging @@ -52,6 +54,37 @@ def assert_color(entry, color): assert_color(logging.error('hi', indent=2, x=1, y=2), '#F5C6CB') assert_color(logging.fatal('hi', indent=2, x=1, y=2), '#F19CBB') + def assert_html_content(self, html, expected): + expected = inspect.cleandoc(expected).strip() + actual = html.content.strip() + if actual != expected: + print(actual) + self.assertEqual(actual, expected) + + def test_html(self): + time = datetime.datetime(2024, 10, 10, 12, 30, 45) + self.assert_html_content( + logging.LogEntry( + level='info', message='5 + 2 > 3', + time=time, metadata={} + ).to_html(enable_summary_tooltip=False), + """ +
12:30:455 + 2 > 3
+ """ + ) + self.assert_html_content( + logging.LogEntry( + level='error', message='This is a longer message: 5 + 2 > 3', + time=time, metadata=dict(x=1, y=2) + ).to_html( + max_str_len_for_summary=10, + enable_summary_tooltip=False, + collapse_log_metadata_level=1 + ), + """ +
12:30:45This is a ...
This is a longer message: 5 + 2 > 3
+ """ + ) if __name__ == '__main__': unittest.main() diff --git a/langfun/core/message.py b/langfun/core/message.py index f2b0fd6..fd7b083 100644 --- a/langfun/core/message.py +++ b/langfun/core/message.py @@ -14,13 +14,11 @@ """Messages that are exchanged between users and agents.""" import contextlib -import html import io -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any, Optional, Sequence, Union from langfun.core import modality from langfun.core import natural_language -from langfun.core import repr_utils import pyglove as pg @@ -406,6 +404,11 @@ def tag(self, tag: str) -> None: with pg.notify_on_change(False): self.tags.append(tag) + def has_tag(self, tag: str | tuple[str, ...]) -> bool: + if isinstance(tag, str): + return tag in self.tags + return any(t in self.tags for t in tag) + # # Message source chain. # @@ -503,79 +506,244 @@ def __getattr__(self, key: str) -> Any: v = self.metadata[key] return v.value if isinstance(v, pg.Ref) else v - def _repr_html_(self): - return self.to_html().content - - def to_html( + # pytype: disable=annotation-type-mismatch + def _html_tree_view_content( self, - include_message_type: bool = True - ) -> repr_utils.Html: - """Returns the HTML representation of the message.""" - s = io.StringIO() - s.write('
') - # Title bar. - if include_message_type: - s.write( - repr_utils.html_round_text( - self.__class__.__name__, - text_color='white', - background_color=self._text_color(), + *, + view: pg.views.HtmlTreeView, + root_path: pg.KeyPath, + source_tag: str | Sequence[str] | None = pg.View.PresetArgValue( + ('lm-input', 'lm-output') + ), + include_message_metadata: bool = pg.View.PresetArgValue(True), + collapse_modalities_in_text: bool = pg.View.PresetArgValue(True), + collapse_llm_usage: bool = pg.View.PresetArgValue(False), + collapse_message_result_level: int = pg.View.PresetArgValue(1), + collapse_message_metadata_level: int = pg.View.PresetArgValue(0), + collapse_source_message_level: int = pg.View.PresetArgValue(1), + **kwargs, + ) -> pg.Html: + # pytype: enable=annotation-type-mismatch + """Returns the HTML representation of the message. + + Args: + view: The HTML tree view. + root_path: The root path of the message. + source_tag: tags to filter source messages. If None, the entire + source chain will be included. + include_message_metadata: Whether to include the metadata of the message. + collapse_modalities_in_text: Whether to collapse the modalities in the + message text. + collapse_llm_usage: Whether to collapse the usage in the message. + collapse_message_result_level: The level to collapse the result in the + message. + collapse_message_metadata_level: The level to collapse the metadata in the + message. + collapse_source_message_level: The level to collapse the source in the + message. + **kwargs: Other keyword arguments. + + Returns: + The HTML representation of the message content. + """ + def render_tags(): + return pg.Html.element( + 'div', + [pg.Html.element('span', [tag]) for tag in self.tags], + css_class=['message-tags'], + ) + + def render_message_text(): + maybe_reformatted = self.get('formatted_text') + referred_chunks = {} + s = pg.Html('
') + for chunk in self.chunk(maybe_reformatted): + if isinstance(chunk, str): + s.write(s.escape(chunk)) + else: + assert isinstance(chunk, modality.Modality), chunk + child_path = root_path + 'metadata' + chunk.referred_name + s.write( + pg.Html.element( + 'div', + [ + view.render( + chunk, + name=chunk.referred_name, + root_path=child_path, + collapse_level=child_path.depth + ( + 0 if collapse_modalities_in_text else 1 + ) + ) + ], + css_class=['modality-in-text'], + ) ) + referred_chunks[chunk.referred_name] = chunk + s.write('
') + return s + + def render_result(): + if 'result' not in self.metadata: + return None + child_path = root_path + 'metadata' + 'result' + return pg.Html.element( + 'div', + [ + view.render( + self.result, + name='result', + root_path=child_path, + collapse_level=( + child_path.depth + collapse_message_result_level + ) + ) + ], + css_class=['message-result'], + ) + + def render_usage(): + if 'usage' not in self.metadata: + return None + child_path = root_path + 'metadata' + 'usage' + return pg.Html.element( + 'div', + [ + view.render( + self.usage, + name='llm usage', + root_path=child_path, + collapse_level=child_path.depth + ( + 0 if collapse_llm_usage else 1 + ) + ) + ], + css_class=['message-usage'], + ) + + def render_source_message(): + source = self.source + while (source is not None + and source_tag is not None + and not source.has_tag(source_tag)): + source = source.source + if source is not None: + return view.render( + self.source, + name='source', + root_path=root_path + 'source', + include_metadata=include_message_metadata, + collapse_level=( + root_path.depth + 1 + collapse_source_message_level + ), + collapse_source_level=max(0, collapse_source_message_level - 1), + collapse_modalities=collapse_modalities_in_text, + collapse_usage=collapse_llm_usage, + collapse_metadata_level=collapse_message_metadata_level, + collapse_result_level=collapse_message_result_level, + ) + return None + + def render_metadata(): + if not include_message_metadata: + return None + child_path = root_path + 'metadata' + return pg.Html.element( + 'div', + [ + view.render( + self.metadata, + css_class=['message-metadata'], + name='metadata', + root_path=child_path, + collapse_level=( + child_path.depth + collapse_message_metadata_level + ) + ) + ], + css_class=['message-metadata'], ) - s.write('
') - # Body. - s.write( - f'' + return pg.Html.element( + 'div', + [ + render_tags(), + render_message_text(), + render_result(), + render_usage(), + render_metadata(), + render_source_message(), + ], + css_class=['complex_value'], ) - # NOTE(daiyip): LLM may reformat the text from the input, therefore - # we proritize the formatted text if it's available. - maybe_reformatted = self.get('formatted_text') - referred_chunks = {} - for chunk in self.chunk(maybe_reformatted): - if isinstance(chunk, str): - s.write(html.escape(chunk)) - else: - assert isinstance(chunk, modality.Modality), chunk - s.write(' ') - s.write(repr_utils.html_round_text( - chunk.referred_name, - text_color='black', - background_color='#f7dc6f' - )) - s.write(' ') - referred_chunks[chunk.referred_name] = chunk - s.write('') - - def item_color(k, v): - if isinstance(v, modality.Modality): - return ('black', '#f7dc6f', None, None) # Light yellow - elif k == 'result': - return ('white', 'purple', 'purple', None) # Blue. - elif k in ('usage',): - return ('white', '#e74c3c', None, None) # Red. - else: - return ('white', '#17202a', None, None) # Dark gray - - # TODO(daiyip): Revisit the logic in deciding what metadata keys to - # expose to the user. - if referred_chunks: - s.write(repr_utils.html_repr(referred_chunks, item_color)) - - if 'lm-response' in self.tags: - s.write(repr_utils.html_repr(self.metadata, item_color)) - s.write('
') - return repr_utils.Html(s.getvalue()) - - def _text_color(self) -> str: - match self.__class__.__name__: - case 'UserMessage': - return 'green' - case 'AIMessage': - return 'blue' - case _: - return 'black' + def _html_style(self) -> list[str]: + return super()._html_style() + [ + """ + /* Langfun Message styles.*/ + [class^="message-"] > details { + margin: 0px 0px 5px 0px; + border: 1px solid #EEE; + } + details.lf-message > summary > .summary_title::after { + content: ' 💬'; + } + details.pyglove.ai-message { + border: 1px solid blue; + color: blue; + } + details.pyglove.user-message { + border: 1px solid green; + color: green; + } + .message-tags { + margin: 5px 0px 5px 0px; + font-size: .8em; + } + .message-tags > span { + border-radius: 5px; + background-color: #CCC; + padding: 3px; + margin: 0px 2px 0px 2px; + color: white; + } + .message-text { + padding: 20px; + margin: 10px 5px 10px 5px; + font-style: italic; + font-size: 1.1em; + white-space: pre-wrap; + border: 1px solid #EEE; + border-radius: 5px; + background-color: #EEE; + } + .modality-in-text { + display: inline-block; + } + .modality-in-text > details { + display: inline-block; + font-size: 0.8em; + border: 0; + background-color: #A6F1A6; + margin: 0px 5px 0px 5px; + } + .message-result { + color: purple; + } + .message-usage { + color: orange; + } + .message-usage .object_key.str { + border: 1px solid orange; + background-color: orange; + color: white; + } + """ + ] + + def _html_element_class(self) -> list[str]: + return super()._html_element_class() + ['lf-message'] + # # Messages of different roles. diff --git a/langfun/core/message_test.py b/langfun/core/message_test.py index 4f7d349..ada8f48 100644 --- a/langfun/core/message_test.py +++ b/langfun/core/message_test.py @@ -15,6 +15,7 @@ import inspect import unittest +from langfun.core import language_model from langfun.core import message from langfun.core import modality import pyglove as pg @@ -26,9 +27,6 @@ class CustomModality(modality.Modality): def to_bytes(self): return self.content.encode() - def _repr_html_(self): - return f'
CustomModality: {self.content}
' - class MessageTest(unittest.TestCase): @@ -39,11 +37,19 @@ class A(pg.Object): d = pg.Dict(x=A()) - m = message.UserMessage('hi', metadata=dict(x=1), x=pg.Ref(d.x), y=2) + m = message.UserMessage( + 'hi', + metadata=dict(x=1), x=pg.Ref(d.x), + y=2, + tags=['lm-input'] + ) self.assertEqual(m.metadata, {'x': pg.Ref(d.x), 'y': 2}) self.assertEqual(m.sender, 'User') self.assertIs(m.x, d.x) self.assertEqual(m.y, 2) + self.assertTrue(m.has_tag('lm-input')) + self.assertTrue(m.has_tag(('lm-input', ''))) + self.assertFalse(m.has_tag('lm-response')) with self.assertRaises(AttributeError): _ = m.z @@ -332,50 +338,69 @@ def test_chunking(self): ) ) - def test_html(self): - m = message.UserMessage( - 'hi, this is a <<[[img1]]>> and <<[[x.img2]]>>', - img1=CustomModality('foo'), - x=dict(img2=CustomModality('bar')), + def assert_html_content(self, html, expected): + expected = inspect.cleandoc(expected).strip() + actual = html.content.strip() + if actual != expected: + print(actual) + self.assertEqual(actual, expected) + + def test_html_user_message(self): + self.assert_html_content( + message.UserMessage( + 'what is a
' + ).to_html(enable_summary_tooltip=False), + """ +
UserMessage(...)
what is a <div>
+ """ ) - self.assertEqual( - m._repr_html_(), - ( - '
UserMessage
hi, this is a img1 and x.img2 
img1
CustomModality: foo
x.img2' - '
CustomModality: bar
' - '
' - ) + self.assert_html_content( + message.UserMessage( + 'what is this <<[[image]]>>', + tags=['lm-input'], + image=CustomModality('bird') + ).to_html(enable_summary_tooltip=False, include_message_metadata=False), + """ +
UserMessage(...)
lm-input
what is this
image
CustomModality(...)
contentmetadata.image.content
'bird'
+ """ ) - self.assertIn( - 'background-color: blue', - message.AIMessage('hi').to_html().content, + + def test_html_ai_message(self): + image = CustomModality('foo') + user_message = message.UserMessage( + 'What is in this image? <<[[image]]>> this is a test', + metadata=dict(image=image), + source=message.UserMessage('User input'), + tags=['lm-input'] ) - self.assertIn( - 'background-color: black', - message.SystemMessage('hi').to_html().content, + ai_message = message.AIMessage( + 'My name is Gemini', + metadata=dict( + result=pg.Dict(x=1, y=2, z=pg.Dict(a=[12, 323])), + usage=language_model.LMSamplingUsage(10, 2, 12) + ), + tags=['lm-response', 'lm-output'], + source=user_message, + ) + self.assert_html_content( + ai_message.to_html(enable_summary_tooltip=False), + """ +
AIMessage(...)
lm-responselm-output
My name is Gemini
result
Dict(...)
xmetadata.result.x
1
ymetadata.result.y
2
zmetadata.result.z
Dict(...)
ametadata.result.z.a
List(...)
0metadata.result.z.a[0]
12
1metadata.result.z.a[1]
323
llm usage
LMSamplingUsage(...)
prompt_tokensmetadata.usage.prompt_tokens
10
completion_tokensmetadata.usage.completion_tokens
2
total_tokensmetadata.usage.total_tokens
12
num_requestsmetadata.usage.num_requests
1
source
UserMessage(...)
lm-input
What is in this image?
image
CustomModality(...)
contentsource.metadata.image.content
'foo'
this is a test
+ """ + ) + self.assert_html_content( + ai_message.to_html( + enable_summary_tooltip=False, + collapse_modalities_in_text=False, + collapse_llm_usage=True, + collapse_message_result_level=0, + collapse_message_metadata_level=0, + collapse_source_message_level=0, + source_tag=None, + ), + """ +
AIMessage(...)
lm-responselm-output
My name is Gemini
result
Dict(...)
xmetadata.result.x
1
ymetadata.result.y
2
zmetadata.result.z
Dict(...)
ametadata.result.z.a
List(...)
0metadata.result.z.a[0]
12
1metadata.result.z.a[1]
323
llm usage
LMSamplingUsage(...)
prompt_tokensmetadata.usage.prompt_tokens
10
completion_tokensmetadata.usage.completion_tokens
2
total_tokensmetadata.usage.total_tokens
12
num_requestsmetadata.usage.num_requests
1
source
UserMessage(...)
lm-input
What is in this image?
image
CustomModality(...)
contentsource.metadata.image.content
'foo'
this is a test
source
UserMessage(...)
User input
+ """ ) diff --git a/langfun/core/modalities/audio.py b/langfun/core/modalities/audio.py index 2576a53..9b06002 100644 --- a/langfun/core/modalities/audio.py +++ b/langfun/core/modalities/audio.py @@ -26,5 +26,5 @@ class Audio(mime.Mime): def audio_format(self) -> str: return self.mime_type.removeprefix(self.MIME_PREFIX + '/') - def _html(self, uri: str) -> str: + def _mime_control_for(self, uri: str) -> str: return f'' diff --git a/langfun/core/modalities/audio_test.py b/langfun/core/modalities/audio_test.py index a61ead3..3ce08db 100644 --- a/langfun/core/modalities/audio_test.py +++ b/langfun/core/modalities/audio_test.py @@ -53,7 +53,7 @@ def test_audio_file(self): self.assertEqual(audio.audio_format, 'x-wav') self.assertEqual(audio.mime_type, 'audio/x-wav') self.assertEqual( - audio._repr_html_(), + audio._raw_html(), '', ) self.assertEqual(audio.to_bytes(), content_bytes) diff --git a/langfun/core/modalities/image.py b/langfun/core/modalities/image.py index 42d93b3..68239e3 100644 --- a/langfun/core/modalities/image.py +++ b/langfun/core/modalities/image.py @@ -41,7 +41,7 @@ class Image(mime.Mime): def image_format(self) -> str: return self.mime_type.removeprefix(self.MIME_PREFIX + '/') - def _html(self, uri: str) -> str: + def _mime_control_for(self, uri: str) -> str: return f'' @functools.cached_property diff --git a/langfun/core/modalities/image_test.py b/langfun/core/modalities/image_test.py index 4536d4d..48b9de3 100644 --- a/langfun/core/modalities/image_test.py +++ b/langfun/core/modalities/image_test.py @@ -45,7 +45,7 @@ class ImageTest(unittest.TestCase): def test_from_bytes(self): image = image_lib.Image.from_bytes(image_content) self.assertEqual(image.image_format, 'png') - self.assertIn('data:image/png;base64,', image._repr_html_()) + self.assertIn('data:image/png;base64,', image._raw_html()) self.assertEqual(image.to_bytes(), image_content) with self.assertRaisesRegex( lf.ModalityError, '.* cannot be converted to text' @@ -67,7 +67,10 @@ def test_from_uri(self): with mock.patch('requests.get') as mock_requests_get: mock_requests_get.side_effect = mock_request self.assertEqual(image.image_format, 'png') - self.assertEqual(image._repr_html_(), '') + self.assertEqual( + image._raw_html(), + '' + ) self.assertEqual(image.to_bytes(), image_content) def test_from_uri_base_cls(self): @@ -76,7 +79,10 @@ def test_from_uri_base_cls(self): image = mime_lib.Mime.from_uri('http://mock/web/a.png') self.assertIsInstance(image, image_lib.Image) self.assertEqual(image.image_format, 'png') - self.assertEqual(image._repr_html_(), '') + self.assertEqual( + image._raw_html(), + '' + ) self.assertEqual(image.to_bytes(), image_content) def test_image_size(self): diff --git a/langfun/core/modalities/mime.py b/langfun/core/modalities/mime.py index 3a85633..55a5d43 100644 --- a/langfun/core/modalities/mime.py +++ b/langfun/core/modalities/mime.py @@ -178,14 +178,50 @@ def download(cls, uri: str) -> bytes | str: assert content is not None return content - def _repr_html_(self) -> str: + def _html_tree_view_content( + self, + **kwargs) -> str: + return self._raw_html() + + def _html_tree_view_render( + self, + view: pg.views.HtmlTreeView, + raw_mime_content: bool = pg.View.PresetArgValue(False), # pytype: disable=annotation-type-mismatch + display_modality_when_hover: bool = pg.View.PresetArgValue(False), # pytype: disable=annotation-type-mismatch + **kwargs + ): + if raw_mime_content: + return pg.Html(self._raw_html()) + else: + if display_modality_when_hover: + kwargs.update( + display_modality_when_hover=True, + enable_summary_tooltip=True, + ) + return super()._html_tree_view_render(view=view, **kwargs) + + def _html_tree_view_tooltip( + self, + *, + view: pg.views.HtmlTreeView, + content: pg.Html | str | None = None, + display_modality_when_hover: bool = pg.View.PresetArgValue(False), # pytype: disable=annotation-type-mismatch + **kwargs + ): + if content is None and display_modality_when_hover: + content = self._raw_html() + return super()._html_tree_view_tooltip( + view=view, content=content, **kwargs + ) + + def _raw_html(self) -> str: if self.uri and self.uri.lower().startswith(('http:', 'https:', 'ftp:')): uri = self.uri else: uri = self.content_uri - return self._html(uri) + return self._mime_control_for(uri) - def _html(self, uri) -> str: + def _mime_control_for(self, uri) -> str: return f'' diff --git a/langfun/core/modalities/mime_test.py b/langfun/core/modalities/mime_test.py index e0f39b7..369678d 100644 --- a/langfun/core/modalities/mime_test.py +++ b/langfun/core/modalities/mime_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """MIME tests.""" +import inspect import unittest from unittest import mock @@ -77,6 +78,44 @@ def test_from_uri(self): self.assertEqual(content.to_bytes(), 'bar') self.assertEqual(content.mime_type, 'text/plain') + def assert_html_content(self, html, expected): + expected = inspect.cleandoc(expected).strip() + actual = html.content.strip() + if actual != expected: + print(actual) + self.assertEqual(actual, expected) + + def test_html(self): + self.assert_html_content( + mime.Custom('text/plain', b'foo').to_html( + enable_summary_tooltip=False, + enable_key_tooltip=False, + ), + """ +
Custom(...)
+ """ + ) + self.assert_html_content( + mime.Custom('text/plain', b'foo').to_html( + enable_summary_tooltip=False, + enable_key_tooltip=False, + raw_mime_content=True, + ), + """ + + """ + ) + self.assert_html_content( + mime.Custom('text/plain', b'foo').to_html( + enable_summary_tooltip=False, + enable_key_tooltip=False, + display_modality_when_hover=True, + ), + """ +
Custom(...)
+ """ + ) + if __name__ == '__main__': unittest.main() diff --git a/langfun/core/modalities/ms_office.py b/langfun/core/modalities/ms_office.py index 8bff65e..92bfa25 100644 --- a/langfun/core/modalities/ms_office.py +++ b/langfun/core/modalities/ms_office.py @@ -29,7 +29,7 @@ class Xlsx(mime.Mime): 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' ) - def to_html(self) -> str: + def _raw_html(self) -> str: try: import pandas as pd # pylint: disable=g-import-not-at-top import openpyxl # pylint: disable=g-import-not-at-top, unused-import @@ -40,9 +40,6 @@ def to_html(self) -> str: 'Please install "langfun[mime-xlsx]" to enable XLSX support.' ) from e - def _repr_html_(self) -> str: - return self.to_html() - def _is_compatible(self, mime_types: Iterable[str]) -> bool: return bool(set(mime_types).intersection([ 'text/html', @@ -52,7 +49,7 @@ def _is_compatible(self, mime_types: Iterable[str]) -> bool: def _make_compatible(self, mime_types: Iterable[str]) -> mime.Mime: """Returns the MimeType of the converted file.""" del mime_types - return mime.Mime(uri=self.uri, content=self.to_html()) + return mime.Mime(uri=self.uri, content=self._raw_html()) class Docx(mime.Mime): diff --git a/langfun/core/modalities/ms_office_test.py b/langfun/core/modalities/ms_office_test.py index 53c050d..349a18d 100644 --- a/langfun/core/modalities/ms_office_test.py +++ b/langfun/core/modalities/ms_office_test.py @@ -347,7 +347,7 @@ def test_from_uri(self): ), ) self.assertEqual(content.to_bytes(), xlsx_bytes) - self.assertEqual(content.to_html(), expected_xlsx_html) + self.assertEqual(content._raw_html(), expected_xlsx_html) class PptxTest(unittest.TestCase): diff --git a/langfun/core/modalities/pdf_test.py b/langfun/core/modalities/pdf_test.py index b65a9e3..1d32d24 100644 --- a/langfun/core/modalities/pdf_test.py +++ b/langfun/core/modalities/pdf_test.py @@ -49,7 +49,7 @@ def test_repr_html(self): pdf = pdf_lib.PDF.from_bytes(pdf_bytes) self.assertIn( ' ' diff --git a/langfun/core/modalities/video_test.py b/langfun/core/modalities/video_test.py index 815bf07..057f7fb 100644 --- a/langfun/core/modalities/video_test.py +++ b/langfun/core/modalities/video_test.py @@ -38,7 +38,7 @@ def test_video_content(self): video = video_lib.Video.from_bytes(mp4_bytes) self.assertEqual(video.mime_type, 'video/mp4') self.assertEqual(video.video_format, 'mp4') - self.assertIn('data:video/mp4;base64,', video._repr_html_()) + self.assertIn('data:video/mp4;base64,', video._raw_html()) self.assertEqual(video.to_bytes(), mp4_bytes) def test_bad_video(self): @@ -56,7 +56,7 @@ def test_video_file(self): self.assertEqual(video.video_format, 'mp4') self.assertEqual(video.mime_type, 'video/mp4') self.assertEqual( - video._repr_html_(), + video._raw_html(), '', ) self.assertEqual(video.to_bytes(), mp4_bytes) diff --git a/langfun/core/structured/mapping.py b/langfun/core/structured/mapping.py index daf93ad..176b021 100644 --- a/langfun/core/structured/mapping.py +++ b/langfun/core/structured/mapping.py @@ -183,6 +183,44 @@ def natural_language_format(self) -> str: result.write(lf.colored(str(self.metadata), color='cyan')) return result.getvalue().strip() + def _html_tree_view_content( + self, + *, + parent: Any, + view: pg.views.HtmlTreeView, + root_path: pg.KeyPath, + **kwargs, + ): + def render_value(value, **kwargs): + if isinstance(value, lf.Template): + # Make a shallow copy to make sure modalities are rooted by + # the input. + value = value.clone().render() + return view.render(value, **kwargs) + + exclude_keys = [] + if not self.context: + exclude_keys.append('context') + if not self.schema: + exclude_keys.append('schema') + if not self.metadata: + exclude_keys.append('metadata') + + kwargs.pop('special_keys', None) + kwargs.pop('exclude_keys', None) + return view.complex_value( + self.sym_init_args, + parent=self, + root_path=root_path, + render_value_fn=render_value, + special_keys=['input', 'output', 'context', 'schema', 'metadata'], + exclude_keys=exclude_keys, + **kwargs + ) + + def _html_tree_view_collapse_level(self) -> int: + return 2 + class Mapping(lf.LangFunc): """Base class for mapping. diff --git a/langfun/core/structured/mapping_test.py b/langfun/core/structured/mapping_test.py index 3a16c5a..3d23d96 100644 --- a/langfun/core/structured/mapping_test.py +++ b/langfun/core/structured/mapping_test.py @@ -14,6 +14,7 @@ """Tests for structured mapping example.""" import inspect +from typing import Any import unittest import langfun.core as lf @@ -164,6 +165,60 @@ def test_serialization(self): pg.eq(pg.from_json_str(example.to_json_str()), example) ) + def assert_html_content(self, html, expected): + expected = inspect.cleandoc(expected).strip() + actual = html.content.strip() + if actual != expected: + print(actual) + self.assertEqual(actual, expected) + + def test_html(self): + + class Answer(pg.Object): + answer: int + + class Addition(lf.Template): + """Template Addition. + + {{x}} + {{y}} = ? + """ + x: Any + y: Any + + example = mapping.MappingExample( + input=Addition(x=1, y=2), + schema=Answer, + context='compute 1 + 1', + output=Answer(answer=3), + metadata={'foo': 'bar'}, + ) + self.assert_html_content( + example.to_html( + enable_summary_tooltip=False, include_message_metadata=False + ), + """ +
MappingExample(...)
input
UserMessage(...)
rendered
1 + 2 = ?
output
Answer(...)
answeroutput.answer
3
context
'compute 1 + 1'
'compute 1 + 1'
schema
Schema(...)
Answer + + ```python + class Answer: + answer: int + ```
metadata
Dict(...)
foometadata.foo
'bar'
+ """ + ) + + example = mapping.MappingExample( + input=Addition(x=1, y=2), + output=Answer(answer=3), + ) + self.assert_html_content( + example.to_html( + enable_summary_tooltip=False, include_message_metadata=False + ), + """ +
MappingExample(...)
input
UserMessage(...)
rendered
1 + 2 = ?
output
Answer(...)
answeroutput.answer
3
+ """ + ) + if __name__ == '__main__': unittest.main() diff --git a/langfun/core/structured/schema.py b/langfun/core/structured/schema.py index f855cbb..92bccef 100644 --- a/langfun/core/structured/schema.py +++ b/langfun/core/structured/schema.py @@ -189,6 +189,40 @@ def from_value(cls, value) -> 'Schema': return value return cls(parse_value_spec(value)) + def _html_tree_view_content( + self, + *, + view: pg.views.HtmlTreeView, + root_path: pg.KeyPath, + **kwargs, + ): + return pg.Html.element( + 'div', + [self.schema_str(protocol='python')], + css_class=['lf-schema-definition'] + ).add_style( + """ + .lf-schema-definition { + color: blue; + margin: 5px; + white-space: pre-wrap; + } + """ + ) + + def _html_tree_view_tooltip( + self, + *, + view: pg.views.HtmlTreeView, + content: pg.Html | str | None = None, + **kwargs, + ): + return view.tooltip( + self, + content=content or pg.Html.escape(self.schema_str(protocol='python')), + **kwargs + ) + def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]: """Returns a list of top level value specs from a symbolic value.""" diff --git a/langfun/core/template.py b/langfun/core/template.py index c3dd426..b377bcc 100644 --- a/langfun/core/template.py +++ b/langfun/core/template.py @@ -17,7 +17,7 @@ import dataclasses import functools import inspect -from typing import Annotated, Any, Callable, Iterator, Set, Tuple, Type, Union +from typing import Annotated, Any, Callable, Iterator, Sequence, Set, Tuple, Type, Union import jinja2 from jinja2 import meta as jinja2_meta @@ -526,6 +526,112 @@ def from_value( return lfun return cls(template_str='{{input}}', input=value, **kwargs) + def _html_tree_view_content( + self, + *, + view: pg.views.HtmlTreeView, + root_path: pg.KeyPath, + **kwargs, + ): + def render_template_str(): + return pg.Html.element( + 'div', + [ + pg.Html.element('span', [self.template_str]) + ], + css_class=['template-str'], + ) + + def render_fields(): + def render_value_fn(value, *, root_path, **kwargs): + if isinstance(value, component.ContextualAttribute): + inferred = self.sym_inferred(root_path.key, pg.MISSING_VALUE) + if inferred != pg.MISSING_VALUE: + return pg.Html.element( + 'div', + [ + view.render(inferred, root_path=root_path, **kwargs) + ], + css_class=['inferred-value'], + ) + else: + return pg.Html.element( + 'span', + ['(external)'], + css_class=['contextual-variable'], + ) + return view.render( + value, root_path=root_path, **kwargs + ) + return pg.Html.element( + 'fieldset', + [ + pg.Html.element('legend', ['Template Variables']), + view.complex_value( + self.sym_init_args, + name='fields', + root_path=root_path, + render_value_fn=render_value_fn, + exclude_keys=['template_str', 'clean'], + parent=self, + collapse_level=root_path.depth + 1, + ), + ], + css_class=['template-fields'], + ) + + return pg.Html.element( + 'div', + [ + render_template_str(), + render_fields(), + ], + css_class=['complex_value'], + ) + + def _html_style(self) -> list[str]: + return super()._html_style() + [ + """ + /* Langfun Template styles. */ + .template-str { + padding: 10px; + margin: 10px 5px 10px 5px; + font-style: italic; + font-size: 1.1em; + white-space: pre-wrap; + border: 1px solid #EEE; + border-radius: 5px; + background-color: #EEE; + color: #cc2986; + } + .template-fields { + margin: 0px 0px 5px 0px; + border: 1px solid #EEE; + padding: 5px; + } + .template-fields > legend { + font-size: 0.8em; + margin: 5px 0px 5px 0px; + } + .inferred-value::after { + content: ' (inferred)'; + color: gray; + font-style: italic; + } + .contextual-variable { + margin: 0px 0px 0px 5px; + font-style: italic; + color: gray; + } + """ + ] + + # Additional CSS class to add to the root
element. + def _html_element_class(self) -> Sequence[str] | None: + return [ + pg.object_utils.camel_to_snake(self.__class__.__name__, '-'), + 'lf-template' + ] # Register converter from str to LangFunc, therefore we can always # pass strs to attributes that accept LangFunc. diff --git a/langfun/core/template_test.py b/langfun/core/template_test.py index 30082a2..d490f48 100644 --- a/langfun/core/template_test.py +++ b/langfun/core/template_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Template test.""" import inspect +from typing import Any import unittest from langfun.core import component @@ -552,5 +553,41 @@ def on_event(self, event: TemplateRenderEvent): self.assertEqual(render_stacks, [[l]]) +class HtmlTest(unittest.TestCase): + + def assert_html_content(self, html, expected): + expected = inspect.cleandoc(expected).strip() + actual = html.content.strip() + if actual != expected: + print(actual) + self.assertEqual(actual, expected) + + def test_html(self): + + class Foo(Template): + """Template Foo. + + {{x}} + {{y}} = ? + """ + x: Any + y: Any + + class Bar(Template): + """Template Bar. + + {{y}} + {{z}} + """ + y: Any + + self.assert_html_content( + Foo(x=Bar('{{y}} + {{z}}'), y=1).to_html( + enable_summary_tooltip=False, + ), + """ +
Foo(...)
{{x}} + {{y}} = ?
Template Variables
xx
Bar(...)
{{y}} + {{z}}
Template Variables
yx.y
1
zx.z
(external)
yy
1
+ """ + ) + + if __name__ == '__main__': unittest.main()