Skip to content

Commit

Permalink
LMUsageSummary enhancements.
Browse files Browse the repository at this point in the history
1) Adjust LMUsage estimate aggregation. Previously the aggregated estimate is None when any model does not support estimating, now the aggregated estimate will be the sum of cost for models that support estimating.
2) We also update the tooltip for LMUsageSummary badge to show breakdown of usage for different models.

PiperOrigin-RevId: 696999410
  • Loading branch information
daiyip authored and langfun authors committed Nov 15, 2024
1 parent bf12a48 commit bb8256a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
29 changes: 21 additions & 8 deletions langfun/core/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,18 @@ def average_estimated_cost(self) -> float | None:
def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
if other is None:
return self
if self.estimated_cost is None:
estimated_cost = other.estimated_cost
elif other.estimated_cost is None:
estimated_cost = self.estimated_cost
else:
estimated_cost = self.estimated_cost + other.estimated_cost
return LMSamplingUsage(
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
completion_tokens=self.completion_tokens + other.completion_tokens,
total_tokens=self.total_tokens + other.total_tokens,
num_requests=self.num_requests + other.num_requests,
estimated_cost=(
self.estimated_cost + other.estimated_cost # pylint: disable=g-long-ternary
if (self.estimated_cost is not None
and other.estimated_cost is not None)
else None
)
estimated_cost=estimated_cost,
)

def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
Expand Down Expand Up @@ -956,7 +957,9 @@ def _update_view(self):
if self._usage_badge is not None:
self._usage_badge.update(
self._badge_text(),
tooltip=pg.format(self.total, verbose=False),
tooltip=pg.format(
self, verbose=False, custom_format=self._tooltip_format
),
styles=dict(color=self._badge_color()),
)

Expand All @@ -978,6 +981,14 @@ def _badge_color(self) -> str | None:
green = int(255 * (1 - normalized_value))
return f'rgb({red}, {green}, 0)'

def _tooltip_format(self, v, root_indent):
del root_indent
if isinstance(v, int):
return f'{v:,}'
if isinstance(v, float):
return f'{v:,.3f}'
return None

def _html_tree_view(
self,
*,
Expand All @@ -993,7 +1004,9 @@ def _html_tree_view(
if usage_badge is None:
usage_badge = pg.views.html.controls.Badge(
self._badge_text(),
tooltip=pg.format(self.total, verbose=False),
tooltip=pg.format(
self, custom_format=self._tooltip_format, verbose=False
),
css_classes=['usage-summary'],
styles=dict(color=self._badge_color()),
interactive=True,
Expand Down
7 changes: 7 additions & 0 deletions langfun/core/language_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,13 @@ def test_add(self):
self.assertEqual(usage1 + usage2, usage1 + usage2)
self.assertIs(usage1 + None, usage1)
self.assertIs(None + usage1, usage1)
usage3 = lm_lib.LMSamplingUsage(100, 200, 300, 4, None)
self.assertEqual(
usage1 + usage3, lm_lib.LMSamplingUsage(200, 400, 600, 8, 5.0)
)
self.assertEqual(
usage3 + usage1, lm_lib.LMSamplingUsage(200, 400, 600, 8, 5.0)
)

def test_usage_not_available(self):
usage_not_available = lm_lib.UsageNotAvailable()
Expand Down

0 comments on commit bb8256a

Please sign in to comment.