Support cost estimation for LLM calls.
Usage:

```
import pyglove as pg
import langfun as lf

class Answer(pg.Object):
  result: str

with lf.track_usages() as usage:
  lf.query('compute 1 + 1', Answer, lm=lf.llms.VertexAIGeminiPro1_5_002())
  lf.query('compute 1 + 1', Answer, lm=lf.llms.Gpt4o())

print(usage)
```
Sample output:
```
UsageSummary(
  cached = AggregatedUsage(
    total = LMSamplingUsage(
      prompt_tokens = 0,
      completion_tokens = 0,
      total_tokens = 0,
      num_requests = 0,
      estimated_cost = 0.0
    ),
    breakdown = {}
  ),
  uncached = AggregatedUsage(
    total = LMSamplingUsage(
      prompt_tokens = 239,
      completion_tokens = 28,
      total_tokens = 267,
      num_requests = 2,
      estimated_cost = 0.00064
    ),
    breakdown = {
      VertexAI(gemini-1.5-pro-002) = LMSamplingUsage(
        prompt_tokens = 126,
        completion_tokens = 16,
        total_tokens = 142,
        num_requests = 1,
        estimated_cost = 0.0002375
      ),
      OpenAI(gpt-4o) = LMSamplingUsage(
        prompt_tokens = 113,
        completion_tokens = 12,
        total_tokens = 125,
        num_requests = 1,
        estimated_cost = 0.0004025
      )
    }
  )
)
```
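The returned object is a `UsageSummary`, so individual fields can also be read programmatically. A minimal sketch continuing the example above (the breakdown keys follow the `model_id` strings shown in the sample output; exact values will vary per run):
```
# Total estimated cost (USD) of all uncached calls in the scope.
print(usage.uncached.total.estimated_cost)

# Per-model breakdown, keyed by model id.
gpt4o = usage.uncached.breakdown['OpenAI(gpt-4o)']
print(gpt4o.total_tokens, gpt4o.average_total_tokens)

# Calls served from the cache are aggregated separately, with estimated_cost
# rewritten to 0.0.
print(usage.cached.total.num_requests)
```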
We also add a few Groq models.
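They plug into the same workflow; for instance (the Groq class name below is an assumption for illustration only — see the Groq model definitions in this change for the names actually added):
```
# NOTE: `GroqLlama3_8B` is an assumed class name used purely for illustration.
with lf.track_usages() as usage:
  lf.query('compute 1 + 1', Answer, lm=lf.llms.GroqLlama3_8B())
print(usage)
```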

PiperOrigin-RevId: 684927276
daiyip authored and langfun authors committed Oct 11, 2024
1 parent 8e29ae3 commit 0cc6486
Showing 18 changed files with 1,020 additions and 149 deletions.
1 change: 1 addition & 0 deletions langfun/core/__init__.py
@@ -103,6 +103,7 @@
from langfun.core.language_model import LMSamplingOptions
from langfun.core.language_model import LMSamplingUsage
from langfun.core.language_model import UsageNotAvailable
from langfun.core.language_model import UsageSummary
from langfun.core.language_model import LMSamplingResult
from langfun.core.language_model import LMScoringResult
from langfun.core.language_model import LMCache
1 change: 1 addition & 0 deletions langfun/core/eval/base_test.py
@@ -194,6 +194,7 @@ def test_dryrun(self):
cache_seed=0,
score=1.0,
logprobs=None,
is_cached=False,
usage=lf.LMSamplingUsage(387, 24, 411),
tags=['lm-response', 'lm-output', 'transformed'],
),
4 changes: 2 additions & 2 deletions langfun/core/langfunc_test.py
@@ -89,7 +89,7 @@ def test_call(self):
self.assertEqual(
r,
message.AIMessage(
'Hello!!!', score=0.0, logprobs=None,
'Hello!!!', score=0.0, logprobs=None, is_cached=False,
usage=language_model.UsageNotAvailable()
)
)
@@ -120,7 +120,7 @@ def test_call(self):
self.assertEqual(
r,
message.AIMessage(
'Hello!!!', score=0.0, logprobs=None,
'Hello!!!', score=0.0, logprobs=None, is_cached=False,
usage=language_model.UsageNotAvailable()
)
)
164 changes: 140 additions & 24 deletions langfun/core/language_model.py
@@ -19,7 +19,7 @@
import enum
import threading
import time
from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union
from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
from langfun.core import component
from langfun.core import concurrent
from langfun.core import console
@@ -86,25 +86,75 @@ class LMSamplingUsage(pg.Object):
completion_tokens: int
total_tokens: int
num_requests: int = 1
estimated_cost: Annotated[
float | None,
(
      'Estimated cost in US dollars. If None, cost estimation is not '
      'supported on the model being queried.'
)
] = None

def __bool__(self) -> bool:
return self.num_requests > 0

@property
def average_prompt_tokens(self) -> int:
"""Returns the average prompt tokens per request."""
return self.prompt_tokens // self.num_requests

@property
def average_completion_tokens(self) -> int:
"""Returns the average completion tokens per request."""
return self.completion_tokens // self.num_requests

@property
def average_total_tokens(self) -> int:
"""Returns the average total tokens per request."""
return self.total_tokens // self.num_requests

def __add__(self, other: 'LMSamplingUsage') -> 'LMSamplingUsage':
@property
def average_estimated_cost(self) -> float | None:
"""Returns the average estimated cost per request."""
if self.estimated_cost is None:
return None
return self.estimated_cost / self.num_requests

def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
if other is None:
return self
return LMSamplingUsage(
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
completion_tokens=self.completion_tokens + other.completion_tokens,
total_tokens=self.total_tokens + other.total_tokens,
num_requests=self.num_requests + other.num_requests,
estimated_cost=(
self.estimated_cost + other.estimated_cost # pylint: disable=g-long-ternary
if (self.estimated_cost is not None
and other.estimated_cost is not None)
else None
)
)

def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
return self + other


class UsageNotAvailable(LMSamplingUsage):
"""Usage information not available."""
prompt_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
completion_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
total_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
num_requests: pg.typing.Int(1).freeze() # pytype: disable=invalid-annotation
estimated_cost: pg.typing.Float(default=None, is_noneable=True).freeze() # pytype: disable=invalid-annotation

def __bool__(self) -> bool:
return False
def __add__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
if other is None:
return self
return UsageNotAvailable(
num_requests=self.num_requests + other.num_requests
)

def __radd__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
return self + other


class LMSamplingResult(pg.Object):
@@ -123,6 +173,11 @@ class LMSamplingResult(pg.Object):
'Usage information. Currently only OpenAI models are supported.',
] = UsageNotAvailable()

is_cached: Annotated[
bool,
'Whether the result is from cache or not.'
] = False


class LMSamplingOptions(component.Component):
"""Language model sampling options."""
@@ -425,27 +480,31 @@ def sample(
response = sample.response
response.metadata.score = sample.score
response.metadata.logprobs = sample.logprobs
response.metadata.is_cached = result.is_cached

# NOTE(daiyip): Current usage is computed at per-result level,
# which is accurate when n=1. For n > 1, we average the usage across
# multiple samples.
usage = result.usage
if len(result.samples) == 1 or not usage:
if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable):
response.metadata.usage = usage
else:
n = len(result.samples)
response.metadata.usage = LMSamplingUsage(
prompt_tokens=usage.prompt_tokens // n,
completion_tokens=usage.completion_tokens // n,
total_tokens=usage.total_tokens // n,
estimated_cost=(
usage.estimated_cost / n if usage.estimated_cost else None
)
)

# Track usage.
trackers = component.context_value('__usage_trackers__', [])
if trackers:
model_id = self.model_id
for tracker in trackers:
tracker.track(model_id, usage)
tracker.track(model_id, usage, result.is_cached)

# Track the prompt for corresponding response.
response.source = prompt
@@ -474,7 +533,9 @@ def _sample_with_cache_lookup(
request_to_result_index[len(requests)] = i
requests.append(prompt)
else:
results[i] = r.clone()
result = r.clone()
assert result.is_cached, result
results[i] = result

# Sample non-cache-hit prompts.
if requests:
@@ -491,8 +552,12 @@
sample.response.set('cache_seed', cache_seed)

if cache_seed is not None:
self.cache.put(self, prompt, result.clone(), seed=cache_seed)

self.cache.put(
self,
prompt,
result.clone(override=dict(is_cached=True)),
seed=cache_seed
)
return results # pytype: disable=bad-return-type

@abc.abstractmethod
@@ -800,30 +865,81 @@ def rate_to_max_concurrency(
return DEFAULT_MAX_CONCURRENCY # Default of 1


class UsageSummary(pg.Object):
"""Usage sumary."""

class AggregatedUsage(pg.Object):
"""Aggregated usage."""

total: LMSamplingUsage = LMSamplingUsage(0, 0, 0, 0, 0.0)
breakdown: dict[str, LMSamplingUsage] = {}

def __bool__(self) -> bool:
"""Returns True if the usage is non-empty."""
return bool(self.breakdown)

def add(
self,
model_id: str,
usage: LMSamplingUsage,
) -> None:
"""Adds an entry to the breakdown."""
aggregated = self.breakdown.get(model_id, None)
with pg.notify_on_change(False):
self.breakdown[model_id] = usage + aggregated
self.rebind(total=self.total + usage, skip_notification=True)

@property
def total(self) -> LMSamplingUsage:
return self.cached.total + self.uncached.total

def update(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
"""Updates the usage summary."""
if is_cached:
usage.rebind(estimated_cost=0.0, skip_notification=True)
self.cached.add(model_id, usage)
else:
self.uncached.add(model_id, usage)


pg.members(
dict(
cached=(
pg.typing.Object(
UsageSummary.AggregatedUsage,
default=UsageSummary.AggregatedUsage()
),
'Aggregated usages for cached LLM calls.'
),
uncached=(
pg.typing.Object(
UsageSummary.AggregatedUsage,
default=UsageSummary.AggregatedUsage()
),
'Aggregated usages for uncached LLM calls.'
),
)
)(UsageSummary)


class _UsageTracker:
"""Usage tracker."""

def __init__(self, model_ids: set[str] | None):
self.model_ids = model_ids
self.usage_summary = UsageSummary()
self._lock = threading.Lock()
self.usages = {
m: LMSamplingUsage(0, 0, 0, 0) for m in model_ids
} if model_ids else {}

def track(self, model_id: str, usage: LMSamplingUsage):
if self.model_ids is not None and model_id not in self.model_ids:
return
with self._lock:
if not isinstance(usage, UsageNotAvailable) and model_id in self.usages:
self.usages[model_id] += usage
else:
self.usages[model_id] = usage

def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
if self.model_ids is None or model_id in self.model_ids:
with self._lock:
self.usage_summary.update(model_id, usage, is_cached)


@contextlib.contextmanager
def track_usages(
*lm: Union[str, LanguageModel]
) -> Iterator[dict[str, LMSamplingUsage]]:
) -> Iterator[UsageSummary]:
"""Context manager to track the usages of all language models in scope.
`lf.track_usages` works with threads spawned by `lf.concurrent_map` and
@@ -854,6 +970,6 @@ def track_usages(
tracker = _UsageTracker(set(model_ids) if model_ids else None)
with component.context(__usage_trackers__=trackers + [tracker]):
try:
yield tracker.usages
yield tracker.usage_summary
finally:
pass
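
For reference, a small sketch of how the extended `LMSamplingUsage` arithmetic above behaves; in user code these objects normally come from `lf.track_usages` rather than being constructed by hand:
```
import langfun as lf

a = lf.LMSamplingUsage(
    prompt_tokens=100, completion_tokens=10, total_tokens=110,
    num_requests=1, estimated_cost=0.0002)
b = lf.LMSamplingUsage(
    prompt_tokens=50, completion_tokens=5, total_tokens=55,
    num_requests=1)  # estimated_cost defaults to None.

c = a + b
print(c.num_requests)                  # 2
print(c.estimated_cost)                # None: summed only when both sides have a cost.
print((a + a).average_estimated_cost)  # 0.0002: total cost / num_requests.

# `None` on the left-hand side is handled via __radd__, so aggregation loops
# can start from None.
print((None + a).total_tokens)         # 110
```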