Skip to content

Commit

Permalink
Improve lf.QueryInvocaton.
Browse files Browse the repository at this point in the history
1) Keep raw LLM response with `lm_response` field.
2) No longer store `output`, instead restoring it from `lm_response`.
2) Add custom HTML rendering.

PiperOrigin-RevId: 704442848
  • Loading branch information
daiyip authored and langfun authors committed Dec 9, 2024
1 parent f5c1e61 commit 2d6c09b
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 7 deletions.
2 changes: 1 addition & 1 deletion langfun/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ def _html_tree_view_css_styles(cls) -> list[str]:
.modality-in-text {
display: inline-block;
}
.modality-in-text > details {
.modality-in-text > details.pyglove {
display: inline-block;
font-size: 0.8em;
border: 0;
Expand Down
2 changes: 1 addition & 1 deletion langfun/core/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def test_html_style(self):
.modality-in-text {
display: inline-block;
}
.modality-in-text > details {
.modality-in-text > details.pyglove {
display: inline-block;
font-size: 0.8em;
border: 0;
Expand Down
110 changes: 105 additions & 5 deletions langfun/core/structured/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ def _result(message: lf.Message):
schema_lib.Schema.from_value(schema)
if schema not in (None, str) else None
),
output=pg.Ref(_result(output_message)),
lm=pg.Ref(lm),
examples=pg.Ref(examples) if examples else [],
lm_response=lf.AIMessage(output_message.text),
usage_summary=usage_summary,
)
for i, (tracker, include_child_scopes) in enumerate(trackers):
Expand Down Expand Up @@ -357,7 +357,7 @@ def _reward(self, input, expected_output, metadata): # pylint: disable=redefine
return _reward


class QueryInvocation(pg.Object):
class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
"""A class to represent the invocation of `lf.query`."""

input: Annotated[
Expand All @@ -368,9 +368,9 @@ class QueryInvocation(pg.Object):
schema_lib.schema_spec(noneable=True),
'Schema of `lf.query`.'
]
output: Annotated[
Any,
'Mapping output of `lf.query`.'
lm_response: Annotated[
lf.Message,
'Raw LM response.'
]
lm: Annotated[
lf.LanguageModel,
Expand All @@ -385,6 +385,106 @@ class QueryInvocation(pg.Object):
'Usage summary for `lf.query`.'
]

@functools.cached_property
def lm_request(self) -> lf.Message:
return query_prompt(self.input, self.schema)

@functools.cached_property
def output(self) -> Any:
return query_output(self.lm_response, self.schema)

def _on_bound(self):
super()._on_bound()
self.__dict__.pop('lm_request', None)
self.__dict__.pop('output', None)

def _html_tree_view_summary(
self,
*,
view: pg.views.HtmlTreeView,
**kwargs: Any
) -> pg.Html | None:
return view.summary(
value=self,
title=pg.Html.element(
'div',
[
pg.views.html.controls.Label(
'lf.query',
css_classes=['query-invocation-type-name']
),
pg.views.html.controls.Badge(
f'lm={self.lm.model_id}',
pg.format(
self.lm,
verbose=False,
python_format=True,
hide_default_values=True
),
css_classes=['query-invocation-lm']
),
self.usage_summary.to_html(extra_flags=dict(as_badge=True))
],
css_classes=['query-invocation-title']
),
enable_summary_tooltip=False
)

def _html_tree_view_content(
self,
*,
view: pg.views.HtmlTreeView,
**kwargs: Any
) -> pg.Html:
return pg.views.html.controls.TabControl([
pg.views.html.controls.Tab(
'input',
pg.view(self.input, collapse_level=None),
),
pg.views.html.controls.Tab(
'schema',
pg.view(self.schema),
),
pg.views.html.controls.Tab(
'output',
pg.view(self.output, collapse_level=None),
),
pg.views.html.controls.Tab(
'lm_request',
pg.view(
self.lm_request,
extra_flags=dict(include_message_metadata=False),
),
),
pg.views.html.controls.Tab(
'lm_response',
pg.view(
self.lm_response,
extra_flags=dict(include_message_metadata=False)
),
),
], tab_position='top').to_html()

@classmethod
def _html_tree_view_css_styles(cls) -> list[str]:
return super()._html_tree_view_css_styles() + [
"""
.query-invocation-title {
display: inline-block;
font-weight: normal;
}
.query-invocation-type-name {
font-style: italic;
color: #888;
}
.query-invocation-lm.badge {
margin-left: 5px;
margin-right: 5px;
background-color: #fff0d6;
}
"""
]


@contextlib.contextmanager
def track_queries(
Expand Down
12 changes: 12 additions & 0 deletions langfun/core/structured/prompting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,18 @@ def test_query(self):
)


class QueryInvocationTest(unittest.TestCase):

def test_to_html(self):
lm = fake.StaticSequence([
'Activity(description="hi")',
])
with prompting.track_queries() as queries:
prompting.query('foo', Activity, lm=lm)

self.assertIn('schema', queries[0].to_html_str())


class TrackQueriesTest(unittest.TestCase):

def test_include_child_scopes(self):
Expand Down

0 comments on commit 2d6c09b

Please sign in to comment.