diff --git a/langfun/core/__init__.py b/langfun/core/__init__.py index 34fd4c8..7fd97af 100644 --- a/langfun/core/__init__.py +++ b/langfun/core/__init__.py @@ -103,6 +103,7 @@ from langfun.core.language_model import LMSamplingOptions from langfun.core.language_model import LMSamplingUsage from langfun.core.language_model import UsageNotAvailable +from langfun.core.language_model import UsageSummary from langfun.core.language_model import LMSamplingResult from langfun.core.language_model import LMScoringResult from langfun.core.language_model import LMCache diff --git a/langfun/core/eval/base_test.py b/langfun/core/eval/base_test.py index 0902fd5..3efcc82 100644 --- a/langfun/core/eval/base_test.py +++ b/langfun/core/eval/base_test.py @@ -194,6 +194,7 @@ def test_dryrun(self): cache_seed=0, score=1.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage(387, 24, 411), tags=['lm-response', 'lm-output', 'transformed'], ), diff --git a/langfun/core/langfunc_test.py b/langfun/core/langfunc_test.py index 3d5285e..19e1ea5 100644 --- a/langfun/core/langfunc_test.py +++ b/langfun/core/langfunc_test.py @@ -89,7 +89,7 @@ def test_call(self): self.assertEqual( r, message.AIMessage( - 'Hello!!!', score=0.0, logprobs=None, + 'Hello!!!', score=0.0, logprobs=None, is_cached=False, usage=language_model.UsageNotAvailable() ) ) @@ -120,7 +120,7 @@ def test_call(self): self.assertEqual( r, message.AIMessage( - 'Hello!!!', score=0.0, logprobs=None, + 'Hello!!!', score=0.0, logprobs=None, is_cached=False, usage=language_model.UsageNotAvailable() ) ) diff --git a/langfun/core/language_model.py b/langfun/core/language_model.py index 998b3bb..26f7467 100644 --- a/langfun/core/language_model.py +++ b/langfun/core/language_model.py @@ -19,7 +19,7 @@ import enum import threading import time -from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union +from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union from langfun.core import component from langfun.core import concurrent from langfun.core import console @@ -86,25 +86,75 @@ class LMSamplingUsage(pg.Object): completion_tokens: int total_tokens: int num_requests: int = 1 + estimated_cost: Annotated[ + float | None, + ( + 'Estimated cost in US dollars. If None, cost estimating is not ' + 'suppported on the model being queried.' + ) + ] = None + + def __bool__(self) -> bool: + return self.num_requests > 0 + + @property + def average_prompt_tokens(self) -> int: + """Returns the average prompt tokens per request.""" + return self.prompt_tokens // self.num_requests + + @property + def average_completion_tokens(self) -> int: + """Returns the average completion tokens per request.""" + return self.completion_tokens // self.num_requests + + @property + def average_total_tokens(self) -> int: + """Returns the average total tokens per request.""" + return self.total_tokens // self.num_requests - def __add__(self, other: 'LMSamplingUsage') -> 'LMSamplingUsage': + @property + def average_estimated_cost(self) -> float | None: + """Returns the average estimated cost per request.""" + if self.estimated_cost is None: + return None + return self.estimated_cost / self.num_requests + + def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage': + if other is None: + return self return LMSamplingUsage( prompt_tokens=self.prompt_tokens + other.prompt_tokens, completion_tokens=self.completion_tokens + other.completion_tokens, total_tokens=self.total_tokens + other.total_tokens, num_requests=self.num_requests + other.num_requests, + estimated_cost=( + self.estimated_cost + other.estimated_cost # pylint: disable=g-long-ternary + if (self.estimated_cost is not None + and other.estimated_cost is not None) + else None + ) ) + def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage': + return self + other + class UsageNotAvailable(LMSamplingUsage): """Usage information not available.""" prompt_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation completion_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation total_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation - num_requests: pg.typing.Int(1).freeze() # pytype: disable=invalid-annotation + estimated_cost: pg.typing.Float(default=None, is_noneable=True).freeze() # pytype: disable=invalid-annotation - def __bool__(self) -> bool: - return False + def __add__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable': + if other is None: + return self + return UsageNotAvailable( + num_requests=self.num_requests + other.num_requests + ) + + def __radd__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable': + return self + other class LMSamplingResult(pg.Object): @@ -123,6 +173,11 @@ class LMSamplingResult(pg.Object): 'Usage information. Currently only OpenAI models are supported.', ] = UsageNotAvailable() + is_cached: Annotated[ + bool, + 'Whether the result is from cache or not.' + ] = False + class LMSamplingOptions(component.Component): """Language model sampling options.""" @@ -425,12 +480,13 @@ def sample( response = sample.response response.metadata.score = sample.score response.metadata.logprobs = sample.logprobs + response.metadata.is_cached = result.is_cached # NOTE(daiyip): Current usage is computed at per-result level, # which is accurate when n=1. For n > 1, we average the usage across # multiple samples. usage = result.usage - if len(result.samples) == 1 or not usage: + if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable): response.metadata.usage = usage else: n = len(result.samples) @@ -438,6 +494,9 @@ def sample( prompt_tokens=usage.prompt_tokens // n, completion_tokens=usage.completion_tokens // n, total_tokens=usage.total_tokens // n, + estimated_cost=( + usage.estimated_cost / n if usage.estimated_cost else None + ) ) # Track usage. @@ -445,7 +504,7 @@ def sample( if trackers: model_id = self.model_id for tracker in trackers: - tracker.track(model_id, usage) + tracker.track(model_id, usage, result.is_cached) # Track the prompt for corresponding response. response.source = prompt @@ -474,7 +533,9 @@ def _sample_with_cache_lookup( request_to_result_index[len(requests)] = i requests.append(prompt) else: - results[i] = r.clone() + result = r.clone() + assert result.is_cached, result + results[i] = result # Sample non-cache-hit prompts. if requests: @@ -491,8 +552,12 @@ def _sample_with_cache_lookup( sample.response.set('cache_seed', cache_seed) if cache_seed is not None: - self.cache.put(self, prompt, result.clone(), seed=cache_seed) - + self.cache.put( + self, + prompt, + result.clone(override=dict(is_cached=True)), + seed=cache_seed + ) return results # pytype: disable=bad-return-type @abc.abstractmethod @@ -800,30 +865,81 @@ def rate_to_max_concurrency( return DEFAULT_MAX_CONCURRENCY # Default of 1 +class UsageSummary(pg.Object): + """Usage sumary.""" + + class AggregatedUsage(pg.Object): + """Aggregated usage.""" + + total: LMSamplingUsage = LMSamplingUsage(0, 0, 0, 0, 0.0) + breakdown: dict[str, LMSamplingUsage] = {} + + def __bool__(self) -> bool: + """Returns True if the usage is non-empty.""" + return bool(self.breakdown) + + def add( + self, + model_id: str, + usage: LMSamplingUsage, + ) -> None: + """Adds an entry to the breakdown.""" + aggregated = self.breakdown.get(model_id, None) + with pg.notify_on_change(False): + self.breakdown[model_id] = usage + aggregated + self.rebind(total=self.total + usage, skip_notification=True) + + @property + def total(self) -> LMSamplingUsage: + return self.cached.total + self.uncached.total + + def update(self, model_id: str, usage: LMSamplingUsage, is_cached: bool): + """Updates the usage summary.""" + if is_cached: + usage.rebind(estimated_cost=0.0, skip_notification=True) + self.cached.add(model_id, usage) + else: + self.uncached.add(model_id, usage) + + +pg.members( + dict( + cached=( + pg.typing.Object( + UsageSummary.AggregatedUsage, + default=UsageSummary.AggregatedUsage() + ), + 'Aggregated usages for cached LLM calls.' + ), + uncached=( + pg.typing.Object( + UsageSummary.AggregatedUsage, + default=UsageSummary.AggregatedUsage() + ), + 'Aggregated usages for uncached LLM calls.' + ), + ) +)(UsageSummary) + + class _UsageTracker: """Usage tracker.""" def __init__(self, model_ids: set[str] | None): self.model_ids = model_ids + self.usage_summary = UsageSummary() self._lock = threading.Lock() - self.usages = { - m: LMSamplingUsage(0, 0, 0, 0) for m in model_ids - } if model_ids else {} - - def track(self, model_id: str, usage: LMSamplingUsage): - if self.model_ids is not None and model_id not in self.model_ids: - return - with self._lock: - if not isinstance(usage, UsageNotAvailable) and model_id in self.usages: - self.usages[model_id] += usage - else: - self.usages[model_id] = usage + + def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool): + if self.model_ids is None or model_id in self.model_ids: + with self._lock: + self.usage_summary.update(model_id, usage, is_cached) @contextlib.contextmanager def track_usages( *lm: Union[str, LanguageModel] -) -> Iterator[dict[str, LMSamplingUsage]]: +) -> Iterator[UsageSummary]: """Context manager to track the usages of all language models in scope. `lf.track_usages` works with threads spawned by `lf.concurrent_map` and @@ -854,6 +970,6 @@ def track_usages( tracker = _UsageTracker(set(model_ids) if model_ids else None) with component.context(__usage_trackers__=trackers + [tracker]): try: - yield tracker.usages + yield tracker.usage_summary finally: pass diff --git a/langfun/core/language_model_test.py b/langfun/core/language_model_test.py index 359ca28..73af1f3 100644 --- a/langfun/core/language_model_test.py +++ b/langfun/core/language_model_test.py @@ -49,6 +49,7 @@ def fake_sample(prompts): prompt_tokens=100, completion_tokens=100, total_tokens=200, + estimated_cost=1.0, ), ) for prompt in prompts @@ -128,14 +129,15 @@ def test_sample(self): 'foo', score=-1.0, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=-1.0, logprobs=None, ) ], - usage=lm_lib.LMSamplingUsage(100, 100, 200), + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), ), lm_lib.LMSamplingResult( [ @@ -144,14 +146,15 @@ def test_sample(self): 'bar', score=-1.0, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=-1.0, logprobs=None, ) ], - usage=lm_lib.LMSamplingUsage(100, 100, 200), + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), ), ], ) @@ -169,14 +172,15 @@ def test_sample(self): 'foo' * 2, score=0.5, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=0.5, logprobs=None, ), ], - usage=lm_lib.LMSamplingUsage(100, 100, 200), + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), ), lm_lib.LMSamplingResult( [ @@ -185,7 +189,8 @@ def test_sample(self): 'bar' * 2, score=0.5, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=0.5, @@ -193,7 +198,8 @@ def test_sample(self): ), ], usage=lm_lib.LMSamplingUsage( - prompt_tokens=100, completion_tokens=100, total_tokens=200 + prompt_tokens=100, completion_tokens=100, total_tokens=200, + num_requests=1, estimated_cost=1.0, ), ), ] @@ -209,14 +215,15 @@ def test_sample(self): 'foo', score=1.0, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=1.0, logprobs=None, ), ], - usage=lm_lib.LMSamplingUsage(100, 100, 200), + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), ), lm_lib.LMSamplingResult( [ @@ -225,7 +232,8 @@ def test_sample(self): 'bar', score=1.0, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=1.0, @@ -233,7 +241,8 @@ def test_sample(self): ), ], usage=lm_lib.LMSamplingUsage( - prompt_tokens=100, completion_tokens=100, total_tokens=200 + prompt_tokens=100, completion_tokens=100, total_tokens=200, + num_requests=1, estimated_cost=1.0, ), ), ] @@ -248,14 +257,15 @@ def test_sample(self): 'foo' * 2, score=0.7, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=0.7, logprobs=None, ), ], - usage=lm_lib.LMSamplingUsage(100, 100, 200), + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), ), lm_lib.LMSamplingResult( [ @@ -264,7 +274,8 @@ def test_sample(self): 'bar' * 2, score=0.7, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=0.7, @@ -272,7 +283,8 @@ def test_sample(self): ), ], usage=lm_lib.LMSamplingUsage( - prompt_tokens=100, completion_tokens=100, total_tokens=200 + prompt_tokens=100, completion_tokens=100, total_tokens=200, + num_requests=1, estimated_cost=1.0, ), ), ] @@ -284,7 +296,9 @@ def test_call(self): self.assertEqual(response.text, 'foo') self.assertEqual(response.score, -1.0) self.assertIsNone(response.logprobs) - self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200)) + self.assertEqual( + response.usage, lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0) + ) # Test override sampling_options. self.assertEqual( @@ -307,14 +321,17 @@ def test_using_cache(self): cache_seed=0, score=-1.0, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage( + 100, 100, 200, 1, 1.0 + ), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=-1.0, logprobs=None, ) ], - usage=lm_lib.LMSamplingUsage(100, 100, 200), + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), ), lm_lib.LMSamplingResult( [ @@ -324,14 +341,15 @@ def test_using_cache(self): cache_seed=0, score=-1.0, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=-1.0, logprobs=None, ) ], - usage=lm_lib.LMSamplingUsage(100, 100, 200), + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), ), ], ) @@ -339,7 +357,9 @@ def test_using_cache(self): self.assertEqual(cache.stats.num_hits, 0) self.assertEqual(cache.stats.num_updates, 2) - self.assertEqual(lm('foo'), 'foo') + result = lm('foo') + self.assertEqual(result, 'foo') + self.assertTrue(result.metadata.is_cached) self.assertEqual(lm('bar'), 'bar') self.assertEqual(cache.stats.num_queries, 4) self.assertEqual(cache.stats.num_hits, 2) @@ -361,14 +381,15 @@ def test_using_cache(self): cache_seed=0, score=1.0, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=1.0, logprobs=None, ) ], - usage=lm_lib.LMSamplingUsage(100, 100, 200), + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), ), lm_lib.LMSamplingResult( [ @@ -378,14 +399,15 @@ def test_using_cache(self): cache_seed=0, score=1.0, logprobs=None, - usage=lm_lib.LMSamplingUsage(100, 100, 200), + is_cached=False, + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), tags=[message_lib.Message.TAG_LM_RESPONSE], ), score=1.0, logprobs=None, ) ], - usage=lm_lib.LMSamplingUsage(100, 100, 200), + usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), ), ], ) @@ -663,20 +685,128 @@ def call_lm(prompt): lm2('hi') list(concurrent.concurrent_map(call_lm, ['hi', 'hello'])) - self.assertEqual(usages2, { - 'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1), + print(usages2) + self.assertEqual(usages2.uncached.breakdown, { + 'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), + }) + self.assertFalse(usages2.cached) + self.assertEqual(usages3.uncached.breakdown, { + 'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0), }) - self.assertEqual(usages3, { - 'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4), + self.assertFalse(usages3.cached) + self.assertEqual(usages4.uncached.breakdown, { + 'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0), + 'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), }) - self.assertEqual(usages4, { - 'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4), - 'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1), + self.assertFalse(usages4.cached) + self.assertEqual(usages1.uncached.breakdown, { + 'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5, 5, 5.0), + 'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), }) - self.assertEqual(usages1, { - 'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5, 5), - 'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1), + self.assertFalse(usages1.cached) + self.assertEqual( + usages1.total, + lm_lib.LMSamplingUsage(100 * 6, 100 * 6, 200 * 6, 6, 6.0), + ) + + cache = in_memory.InMemory() + lm = MockModel(cache=cache, name='model1') + with lm_lib.track_usages() as usages1: + _ = lm('hi') + self.assertEqual(usages1.uncached.breakdown, { + 'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0), }) + self.assertFalse(usages1.cached) + with lm_lib.track_usages() as usages2: + _ = lm('hi') + self.assertEqual(usages2.cached.breakdown, { + 'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 0.0), + }) + self.assertFalse(usages2.uncached) + + +class LMSamplingUsageTest(unittest.TestCase): + + def test_basics(self): + usage = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0) + self.assertEqual(usage.num_requests, 4) + self.assertEqual(usage.prompt_tokens, 100) + self.assertEqual(usage.completion_tokens, 200) + self.assertEqual(usage.total_tokens, 300) + self.assertEqual(usage.estimated_cost, 5.0) + self.assertEqual(usage.average_prompt_tokens, 25) + self.assertEqual(usage.average_completion_tokens, 50) + self.assertEqual(usage.average_total_tokens, 75) + self.assertEqual(usage.average_estimated_cost, 1.25) + + def test_add(self): + usage1 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0) + usage2 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0) + self.assertEqual(usage1 + usage2, usage1 + usage2) + self.assertIs(usage1 + None, usage1) + self.assertIs(None + usage1, usage1) + + def test_usage_not_available(self): + usage_not_available = lm_lib.UsageNotAvailable() + self.assertEqual(usage_not_available.prompt_tokens, 0) + self.assertEqual(usage_not_available.completion_tokens, 0) + self.assertEqual(usage_not_available.total_tokens, 0) + self.assertEqual(usage_not_available.average_prompt_tokens, 0) + self.assertEqual(usage_not_available.average_completion_tokens, 0) + self.assertEqual(usage_not_available.average_total_tokens, 0) + self.assertIsNone(usage_not_available.average_estimated_cost) + self.assertTrue(usage_not_available) + self.assertEqual( + usage_not_available + lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0), + lm_lib.UsageNotAvailable(num_requests=5) + ) + self.assertEqual( + lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0) + usage_not_available, + lm_lib.UsageNotAvailable(num_requests=5) + ) + self.assertIs(None + usage_not_available, usage_not_available) + self.assertIs(usage_not_available + None, usage_not_available) + + +class UsageSummaryTest(unittest.TestCase): + + def test_basics(self): + usage_summary = lm_lib.UsageSummary() + self.assertFalse(usage_summary.total) + self.assertFalse(usage_summary.cached) + self.assertFalse(usage_summary.uncached) + + # Add uncached. + usage_summary.update( + 'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False + ) + self.assertEqual( + usage_summary.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0) + ) + self.assertEqual( + usage_summary.uncached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0) + ) + # Add cached. + self.assertFalse(usage_summary.cached) + usage_summary.update( + 'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), True + ) + self.assertEqual( + usage_summary.total, lm_lib.LMSamplingUsage(2, 4, 6, 2, 5.0) + ) + self.assertEqual( + usage_summary.cached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 0.0) + ) + # Add UsageNotAvailable. + usage_summary.update( + 'model1', lm_lib.UsageNotAvailable(num_requests=1), False + ) + self.assertEqual( + usage_summary.total, lm_lib.UsageNotAvailable(num_requests=3) + ) + self.assertEqual( + usage_summary.uncached.total, lm_lib.UsageNotAvailable(num_requests=2) + ) if __name__ == '__main__': diff --git a/langfun/core/llms/__init__.py b/langfun/core/llms/__init__.py index 823a7ee..cb99daf 100644 --- a/langfun/core/llms/__init__.py +++ b/langfun/core/llms/__init__.py @@ -95,11 +95,18 @@ from langfun.core.llms.anthropic import Claude3Haiku from langfun.core.llms.groq import Groq +from langfun.core.llms.groq import GroqLlama3_2_3B +from langfun.core.llms.groq import GroqLlama3_2_1B +from langfun.core.llms.groq import GroqLlama3_1_70B +from langfun.core.llms.groq import GroqLlama3_1_8B from langfun.core.llms.groq import GroqLlama3_70B from langfun.core.llms.groq import GroqLlama3_8B from langfun.core.llms.groq import GroqLlama2_70B from langfun.core.llms.groq import GroqMistral_8x7B -from langfun.core.llms.groq import GroqGemma7B_IT +from langfun.core.llms.groq import GroqGemma2_9B_IT +from langfun.core.llms.groq import GroqGemma_7B_IT +from langfun.core.llms.groq import GroqWhisper_Large_v3 +from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo from langfun.core.llms.vertexai import VertexAI from langfun.core.llms.vertexai import VertexAIGemini1_5 diff --git a/langfun/core/llms/anthropic.py b/langfun/core/llms/anthropic.py index 5a42edc..e0339f1 100644 --- a/langfun/core/llms/anthropic.py +++ b/langfun/core/llms/anthropic.py @@ -28,15 +28,57 @@ # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits # RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 estimated # as RPM/TPM of the largest-available model (Claude-3-Opus). + # Price in US dollars at https://www.anthropic.com/pricing + # as of 2024-10-10. 'claude-3-5-sonnet-20240620': pg.Dict( - max_tokens=4096, rpm=4000, tpm=400000 + max_tokens=4096, + rpm=4000, + tpm=400000, + cost_per_1k_input_tokens=0.003, + cost_per_1k_output_tokens=0.015, + ), + 'claude-3-opus-20240229': pg.Dict( + max_tokens=4096, + rpm=4000, + tpm=400000, + cost_per_1k_input_tokens=0.015, + cost_per_1k_output_tokens=0.075, + ), + 'claude-3-sonnet-20240229': pg.Dict( + max_tokens=4096, + rpm=4000, + tpm=400000, + cost_per_1k_input_tokens=0.003, + cost_per_1k_output_tokens=0.015, + ), + 'claude-3-haiku-20240307': pg.Dict( + max_tokens=4096, + rpm=4000, + tpm=400000, + cost_per_1k_input_tokens=0.00025, + cost_per_1k_output_tokens=0.00125, + ), + 'claude-2.1': pg.Dict( + max_tokens=4096, + rpm=4000, + tpm=400000, + cost_per_1k_input_tokens=0.008, + cost_per_1k_output_tokens=0.024, + ), + 'claude-2.0': pg.Dict( + max_tokens=4096, + rpm=4000, + tpm=400000, + cost_per_1k_input_tokens=0.008, + cost_per_1k_output_tokens=0.024, + ), + 'claude-instant-1.2': pg.Dict( + max_tokens=4096, + rpm=4000, + tpm=400000, + cost_per_1k_input_tokens=0.0008, + cost_per_1k_output_tokens=0.0024, ), - 'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000), - 'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000), - 'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000), - 'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000), - 'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000), - 'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000), } @@ -107,6 +149,25 @@ def max_concurrency(self) -> int: requests_per_min=rpm, tokens_per_min=tpm ) + def estimate_cost( + self, + num_input_tokens: int, + num_output_tokens: int + ) -> float | None: + """Estimate the cost based on usage.""" + cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( + 'cost_per_1k_input_tokens', None + ) + cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( + 'cost_per_1k_output_tokens', None + ) + if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None: + return None + return ( + cost_per_1k_input_tokens * num_input_tokens + + cost_per_1k_output_tokens * num_output_tokens + ) / 1000 + def request( self, prompt: lf.Message, @@ -181,6 +242,10 @@ def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: prompt_tokens=input_tokens, completion_tokens=output_tokens, total_tokens=input_tokens + output_tokens, + estimated_cost=self.estimate_cost( + num_input_tokens=input_tokens, + num_output_tokens=output_tokens, + ), ), ) diff --git a/langfun/core/llms/cache/in_memory_test.py b/langfun/core/llms/cache/in_memory_test.py index 820d1c3..fb35f94 100644 --- a/langfun/core/llms/cache/in_memory_test.py +++ b/langfun/core/llms/cache/in_memory_test.py @@ -66,14 +66,15 @@ def cache_entry(response_text, cache_seed=0): [ lf.LMSample( lf.AIMessage(response_text, cache_seed=cache_seed), - score=1.0 + score=1.0, ) ], usage=lf.LMSamplingUsage( 1, len(response_text), len(response_text) + 1, - ) + ), + is_cached=True, ) ) diff --git a/langfun/core/llms/fake_test.py b/langfun/core/llms/fake_test.py index a5416e0..08850f8 100644 --- a/langfun/core/llms/fake_test.py +++ b/langfun/core/llms/fake_test.py @@ -34,6 +34,7 @@ def test_sample(self): 'hi', score=1.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage(2, 2, 4), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -85,6 +86,7 @@ def test_sample(self): canned_response, score=1.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage(2, 38, 40), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -106,6 +108,7 @@ def test_sample(self): canned_response, score=1.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage(15, 38, 53), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -150,6 +153,7 @@ def test_sample(self): 'Hello', score=1.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage(2, 5, 7), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -166,6 +170,7 @@ def test_sample(self): 'I am fine, how about you?', score=1.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage(12, 25, 37), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -199,6 +204,7 @@ def test_sample(self): 'Hello', score=1.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage(2, 5, 7), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -215,6 +221,7 @@ def test_sample(self): 'I am fine, how about you?', score=1.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage(12, 25, 37), tags=[lf.Message.TAG_LM_RESPONSE], ), diff --git a/langfun/core/llms/groq.py b/langfun/core/llms/groq.py index 2fd90d6..6610bb1 100644 --- a/langfun/core/llms/groq.py +++ b/langfun/core/llms/groq.py @@ -24,11 +24,73 @@ SUPPORTED_MODELS_AND_SETTINGS = { # Refer https://console.groq.com/docs/models - 'llama3-8b-8192': pg.Dict(max_tokens=8192, max_concurrency=16), - 'llama3-70b-8192': pg.Dict(max_tokens=8192, max_concurrency=16), - 'llama2-70b-4096': pg.Dict(max_tokens=4096, max_concurrency=16), - 'mixtral-8x7b-32768': pg.Dict(max_tokens=32768, max_concurrency=16), - 'gemma-7b-it': pg.Dict(max_tokens=8192, max_concurrency=16), + # Price in US dollars at https://groq.com/pricing/ as of 2024-10-10. + 'llama-3.2-3b-preview': pg.Dict( + max_tokens=8192, + max_concurrency=64, + cost_per_1k_input_tokens=0.00006, + cost_per_1k_output_tokens=0.00006, + ), + 'llama-3.2-1b-preview': pg.Dict( + max_tokens=8192, + max_concurrency=64, + cost_per_1k_input_tokens=0.00004, + cost_per_1k_output_tokens=0.00004, + ), + 'llama-3.1-70b-versatile': pg.Dict( + max_tokens=8192, + max_concurrency=16, + cost_per_1k_input_tokens=0.00059, + cost_per_1k_output_tokens=0.00079, + ), + 'llama-3.1-8b-instant': pg.Dict( + max_tokens=8192, + max_concurrency=32, + cost_per_1k_input_tokens=0.00005, + cost_per_1k_output_tokens=0.00008, + ), + 'llama3-70b-8192': pg.Dict( + max_tokens=8192, + max_concurrency=16, + cost_per_1k_input_tokens=0.00059, + cost_per_1k_output_tokens=0.00079, + ), + 'llama3-8b-8192': pg.Dict( + max_tokens=8192, + max_concurrency=32, + cost_per_1k_input_tokens=0.00005, + cost_per_1k_output_tokens=0.00008, + ), + 'llama2-70b-4096': pg.Dict( + max_tokens=4096, + max_concurrency=16, + ), + 'mixtral-8x7b-32768': pg.Dict( + max_tokens=32768, + max_concurrency=16, + cost_per_1k_input_tokens=0.00024, + cost_per_1k_output_tokens=0.00024, + ), + 'gemma2-9b-it': pg.Dict( + max_tokens=8192, + max_concurrency=32, + cost_per_1k_input_tokens=0.0002, + cost_per_1k_output_tokens=0.0002, + ), + 'gemma-7b-it': pg.Dict( + max_tokens=8192, + max_concurrency=32, + cost_per_1k_input_tokens=0.00007, + cost_per_1k_output_tokens=0.00007, + ), + 'whisper-large-v3': pg.Dict( + max_tokens=8192, + max_concurrency=16, + ), + 'whisper-large-v3-turbo': pg.Dict( + max_tokens=8192, + max_concurrency=16, + ) } @@ -89,6 +151,25 @@ def model_id(self) -> str: def max_concurrency(self) -> int: return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency + def estimate_cost( + self, + num_input_tokens: int, + num_output_tokens: int + ) -> float | None: + """Estimate the cost based on usage.""" + cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( + 'cost_per_1k_input_tokens', None + ) + cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( + 'cost_per_1k_output_tokens', None + ) + if cost_per_1k_input_tokens is None or cost_per_1k_output_tokens is None: + return None + return ( + cost_per_1k_input_tokens * num_input_tokens + + cost_per_1k_output_tokens * num_output_tokens + ) / 1000 + def request( self, prompt: lf.Message, @@ -156,6 +237,10 @@ def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: prompt_tokens=usage['prompt_tokens'], completion_tokens=usage['completion_tokens'], total_tokens=usage['total_tokens'], + estimated_cost=self.estimate_cost( + num_input_tokens=usage['prompt_tokens'], + num_output_tokens=usage['completion_tokens'], + ), ), ) @@ -170,6 +255,24 @@ def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message: ) +class GroqLlama3_2_3B(Groq): # pylint: disable=invalid-name + """Llama3.2-3B with 8K context window. + + See: https://huggingface.co/meta-llama/Llama-3.2-3B + """ + + model = 'llama-3.2-3b-preview' + + +class GroqLlama3_2_1B(Groq): # pylint: disable=invalid-name + """Llama3.2-1B with 8K context window. + + See: https://huggingface.co/meta-llama/Llama-3.2-1B + """ + + model = 'llama-3.2-3b-preview' + + class GroqLlama3_8B(Groq): # pylint: disable=invalid-name """Llama3-8B with 8K context window. @@ -179,6 +282,24 @@ class GroqLlama3_8B(Groq): # pylint: disable=invalid-name model = 'llama3-8b-8192' +class GroqLlama3_1_70B(Groq): # pylint: disable=invalid-name + """Llama3.1-70B with 8K context window. + + See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md # pylint: disable=line-too-long + """ + + model = 'llama-3.1-70b-versatile' + + +class GroqLlama3_1_8B(Groq): # pylint: disable=invalid-name + """Llama3.1-8B with 8K context window. + + See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md # pylint: disable=line-too-long + """ + + model = 'llama-3.1-8b-instant' + + class GroqLlama3_70B(Groq): # pylint: disable=invalid-name """Llama3-70B with 8K context window. @@ -206,10 +327,37 @@ class GroqMistral_8x7B(Groq): # pylint: disable=invalid-name model = 'mixtral-8x7b-32768' -class GroqGemma7B_IT(Groq): # pylint: disable=invalid-name +class GroqGemma2_9B_IT(Groq): # pylint: disable=invalid-name + """Gemma2 9B with 8K context window. + + See: https://huggingface.co/google/gemma-2-9b-it + """ + + model = 'gemma2-9b-it' + + +class GroqGemma_7B_IT(Groq): # pylint: disable=invalid-name """Gemma 7B with 8K context window. See: https://huggingface.co/google/gemma-1.1-7b-it """ model = 'gemma-7b-it' + + +class GroqWhisper_Large_v3(Groq): # pylint: disable=invalid-name + """Whisper Large V3 with 8K context window. + + See: https://huggingface.co/openai/whisper-large-v3 + """ + + model = 'whisper-large-v3' + + +class GroqWhisper_Large_v3Turbo(Groq): # pylint: disable=invalid-name + """Whisper Large V3 Turbo with 8K context window. + + See: https://huggingface.co/openai/whisper-large-v3-turbo + """ + + model = 'whisper-large-v3-turbo' diff --git a/langfun/core/llms/openai.py b/langfun/core/llms/openai.py index dbaec7a..47c22c4 100644 --- a/langfun/core/llms/openai.py +++ b/langfun/core/llms/openai.py @@ -50,57 +50,285 @@ # Models from https://platform.openai.com/docs/models # RPM is from https://platform.openai.com/docs/guides/rate-limits # o1 (preview) models. - 'o1-preview': pg.Dict(rpm=10000, tpm=5000000), - 'o1-preview-2024-09-12': pg.Dict(rpm=10000, tpm=5000000), - 'o1-mini': pg.Dict(rpm=10000, tpm=5000000), - 'o1-mini-2024-09-12': pg.Dict(rpm=10000, tpm=5000000), + # Pricing in US dollars, from https://openai.com/api/pricing/ + # as of 2024-10-10. + 'o1-preview': pg.Dict( + in_service=True, + rpm=10000, + tpm=5000000, + cost_per_1k_input_tokens=0.015, + cost_per_1k_output_tokens=0.06, + ), + 'o1-preview-2024-09-12': pg.Dict( + in_service=True, + rpm=10000, + tpm=5000000, + cost_per_1k_input_tokens=0.015, + cost_per_1k_output_tokens=0.06, + ), + 'o1-mini': pg.Dict( + in_service=True, + rpm=10000, + tpm=5000000, + cost_per_1k_input_tokens=0.003, + cost_per_1k_output_tokens=0.012, + ), + 'o1-mini-2024-09-12': pg.Dict( + in_service=True, + rpm=10000, + tpm=5000000, + cost_per_1k_input_tokens=0.003, + cost_per_1k_output_tokens=0.012, + ), # GPT-4o models - 'gpt-4o-mini': pg.Dict(rpm=10000, tpm=5000000), - 'gpt-4o-mini-2024-07-18': pg.Dict(rpm=10000, tpm=5000000), - 'gpt-4o': pg.Dict(rpm=10000, tpm=5000000), - 'gpt-4o-2024-08-06': pg.Dict(rpm=10000, tpm=5000000), - 'gpt-4o-2024-05-13': pg.Dict(rpm=10000, tpm=5000000), + 'gpt-4o-mini': pg.Dict( + in_service=True, + rpm=10000, + tpm=5000000, + cost_per_1k_input_tokens=0.00015, + cost_per_1k_output_tokens=0.0006, + ), + 'gpt-4o-mini-2024-07-18': pg.Dict( + in_service=True, + rpm=10000, + tpm=5000000, + cost_per_1k_input_tokens=0.00015, + cost_per_1k_output_tokens=0.0006, + ), + 'gpt-4o': pg.Dict( + in_service=True, + rpm=10000, + tpm=5000000, + cost_per_1k_input_tokens=0.0025, + cost_per_1k_output_tokens=0.01, + ), + 'gpt-4o-2024-08-06': pg.Dict( + in_service=True, + rpm=10000, + tpm=5000000, + cost_per_1k_input_tokens=0.0025, + cost_per_1k_output_tokens=0.01, + ), + 'gpt-4o-2024-05-13': pg.Dict( + in_service=True, + rpm=10000, + tpm=5000000, + cost_per_1k_input_tokens=0.005, + cost_per_1k_output_tokens=0.015, + ), # GPT-4-Turbo models - 'gpt-4-turbo': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-4-turbo-2024-04-09': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-4-turbo-preview': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-4-0125-preview': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-4-1106-preview': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-4-vision-preview': pg.Dict(rpm=10000, tpm=2000000), + 'gpt-4-turbo': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.01, + cost_per_1k_output_tokens=0.03, + ), + 'gpt-4-turbo-2024-04-09': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.01, + cost_per_1k_output_tokens=0.03, + ), + 'gpt-4-turbo-preview': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.01, + cost_per_1k_output_tokens=0.03, + ), + 'gpt-4-0125-preview': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.01, + cost_per_1k_output_tokens=0.03, + ), + 'gpt-4-1106-preview': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.01, + cost_per_1k_output_tokens=0.03, + ), + 'gpt-4-vision-preview': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.01, + cost_per_1k_output_tokens=0.03, + ), 'gpt-4-1106-vision-preview': pg.Dict( - rpm=10000, tpm=2000000 + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.01, + cost_per_1k_output_tokens=0.03, ), # GPT-4 models - 'gpt-4': pg.Dict(rpm=10000, tpm=300000), - 'gpt-4-0613': pg.Dict(rpm=10000, tpm=300000), - 'gpt-4-0314': pg.Dict(rpm=10000, tpm=300000), - 'gpt-4-32k': pg.Dict(rpm=10000, tpm=300000), - 'gpt-4-32k-0613': pg.Dict(rpm=10000, tpm=300000), - 'gpt-4-32k-0314': pg.Dict(rpm=10000, tpm=300000), + 'gpt-4': pg.Dict( + in_service=True, + rpm=10000, + tpm=300000, + cost_per_1k_input_tokens=0.03, + cost_per_1k_output_tokens=0.06, + ), + 'gpt-4-0613': pg.Dict( + in_service=False, + rpm=10000, + tpm=300000, + cost_per_1k_input_tokens=0.03, + cost_per_1k_output_tokens=0.06, + ), + 'gpt-4-0314': pg.Dict( + in_service=False, + rpm=10000, + tpm=300000, + cost_per_1k_input_tokens=0.03, + cost_per_1k_output_tokens=0.06, + ), + 'gpt-4-32k': pg.Dict( + in_service=True, + rpm=10000, + tpm=300000, + cost_per_1k_input_tokens=0.06, + cost_per_1k_output_tokens=0.12, + ), + 'gpt-4-32k-0613': pg.Dict( + in_service=False, + rpm=10000, + tpm=300000, + cost_per_1k_input_tokens=0.06, + cost_per_1k_output_tokens=0.12, + ), + 'gpt-4-32k-0314': pg.Dict( + in_service=False, + rpm=10000, + tpm=300000, + cost_per_1k_input_tokens=0.06, + cost_per_1k_output_tokens=0.12, + ), # GPT-3.5-Turbo models - 'gpt-3.5-turbo': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-3.5-turbo-0125': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-3.5-turbo-1106': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-3.5-turbo-0613': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-3.5-turbo-0301': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-3.5-turbo-16k': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-3.5-turbo-16k-0613': pg.Dict(rpm=10000, tpm=2000000), - 'gpt-3.5-turbo-16k-0301': pg.Dict(rpm=10000, tpm=2000000), + 'gpt-3.5-turbo': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.0005, + cost_per_1k_output_tokens=0.0015, + ), + 'gpt-3.5-turbo-0125': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.0005, + cost_per_1k_output_tokens=0.0015, + ), + 'gpt-3.5-turbo-1106': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.001, + cost_per_1k_output_tokens=0.002, + ), + 'gpt-3.5-turbo-0613': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.0015, + cost_per_1k_output_tokens=0.002, + ), + 'gpt-3.5-turbo-0301': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.0015, + cost_per_1k_output_tokens=0.002, + ), + 'gpt-3.5-turbo-16k': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.003, + cost_per_1k_output_tokens=0.004, + ), + 'gpt-3.5-turbo-16k-0613': pg.Dict( + in_service=True, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.003, + cost_per_1k_output_tokens=0.004, + ), + 'gpt-3.5-turbo-16k-0301': pg.Dict( + in_service=False, + rpm=10000, + tpm=2000000, + cost_per_1k_input_tokens=0.003, + cost_per_1k_output_tokens=0.004, + ), # GPT-3.5 models - 'text-davinci-003': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), - 'text-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), - 'code-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), + 'text-davinci-003': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM + ), + 'text-davinci-002': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM + ), + 'code-davinci-002': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM + ), # GPT-3 instruction-tuned models - 'text-curie-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), - 'text-babbage-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), - 'text-ada-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), - 'davinci': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), - 'curie': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), - 'babbage': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), - 'ada': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), + 'text-curie-001': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM + ), + 'text-babbage-001': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM, + ), + 'text-ada-001': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM, + ), + 'davinci': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM, + ), + 'curie': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM + ), + 'babbage': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM + ), + 'ada': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM + ), # GPT-3 base models - 'babbage-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), - 'davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM), + 'babbage-002': pg.Dict( + in_service=False, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM + ), + 'davinci-002': pg.Dict( + in_service=True, + rpm=_DEFAULT_RPM, + tpm=_DEFAULT_TPM + ), } @@ -172,6 +400,25 @@ def max_concurrency(self) -> int: requests_per_min=rpm, tokens_per_min=tpm ) + def estimate_cost( + self, + num_input_tokens: int, + num_output_tokens: int + ) -> float | None: + """Estimate the cost based on usage.""" + cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( + 'cost_per_1k_input_tokens', None + ) + cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( + 'cost_per_1k_output_tokens', None + ) + if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None: + return None + return ( + cost_per_1k_input_tokens * num_input_tokens + + cost_per_1k_output_tokens * num_output_tokens + ) / 1000 + @classmethod def dir(cls): assert openai is not None @@ -239,10 +486,17 @@ def _open_ai_completion(prompts): ) n = len(samples_by_index) + estimated_cost = self.estimate_cost( + num_input_tokens=response.usage.prompt_tokens, + num_output_tokens=response.usage.completion_tokens, + ) usage = lf.LMSamplingUsage( prompt_tokens=response.usage.prompt_tokens // n, completion_tokens=response.usage.completion_tokens // n, total_tokens=response.usage.total_tokens // n, + estimated_cost=( + None if estimated_cost is None else (estimated_cost // n) + ) ) return [ lf.LMSamplingResult(samples_by_index[index], usage=usage) @@ -350,6 +604,10 @@ def _open_ai_chat_completion(prompt: lf.Message): prompt_tokens=response.usage.prompt_tokens, completion_tokens=response.usage.completion_tokens, total_tokens=response.usage.total_tokens, + estimated_cost=self.estimate_cost( + num_input_tokens=response.usage.prompt_tokens, + num_output_tokens=response.usage.completion_tokens, + ) ), ) diff --git a/langfun/core/llms/openai_test.py b/langfun/core/llms/openai_test.py index 7e13caa..22fd5c5 100644 --- a/langfun/core/llms/openai_test.py +++ b/langfun/core/llms/openai_test.py @@ -210,6 +210,7 @@ def test_sample_completion(self): 'Sample 0 for prompt 0.', score=0.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=16, completion_tokens=16, @@ -225,6 +226,7 @@ def test_sample_completion(self): 'Sample 1 for prompt 0.', score=0.1, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=16, completion_tokens=16, @@ -240,6 +242,7 @@ def test_sample_completion(self): 'Sample 2 for prompt 0.', score=0.2, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=16, completion_tokens=16, @@ -265,6 +268,7 @@ def test_sample_completion(self): 'Sample 0 for prompt 1.', score=0.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=16, completion_tokens=16, @@ -280,6 +284,7 @@ def test_sample_completion(self): 'Sample 1 for prompt 1.', score=0.1, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=16, completion_tokens=16, @@ -295,6 +300,7 @@ def test_sample_completion(self): 'Sample 2 for prompt 1.', score=0.2, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=16, completion_tokens=16, @@ -315,12 +321,17 @@ def test_sample_completion(self): def test_sample_chat_completion(self): with mock.patch('openai.ChatCompletion.create') as mock_chat_completion: mock_chat_completion.side_effect = mock_chat_completion_query + openai.SUPPORTED_MODELS_AND_SETTINGS['gpt-4'].update({ + 'cost_per_1k_input_tokens': 1.0, + 'cost_per_1k_output_tokens': 1.0, + }) lm = openai.OpenAI(api_key='test_key', model='gpt-4') results = lm.sample( ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3) ) self.assertEqual(len(results), 2) + print(results[0]) self.assertEqual( results[0], lf.LMSamplingResult( @@ -330,10 +341,12 @@ def test_sample_chat_completion(self): 'Sample 0 for message.', score=0.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=33, completion_tokens=33, - total_tokens=66 + total_tokens=66, + estimated_cost=0.2 / 3, ), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -345,10 +358,12 @@ def test_sample_chat_completion(self): 'Sample 1 for message.', score=0.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=33, completion_tokens=33, - total_tokens=66 + total_tokens=66, + estimated_cost=0.2 / 3, ), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -360,10 +375,12 @@ def test_sample_chat_completion(self): 'Sample 2 for message.', score=0.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=33, completion_tokens=33, - total_tokens=66 + total_tokens=66, + estimated_cost=0.2 / 3, ), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -372,7 +389,8 @@ def test_sample_chat_completion(self): ), ], usage=lf.LMSamplingUsage( - prompt_tokens=100, completion_tokens=100, total_tokens=200 + prompt_tokens=100, completion_tokens=100, total_tokens=200, + estimated_cost=0.2, ), ), ) @@ -385,10 +403,12 @@ def test_sample_chat_completion(self): 'Sample 0 for message.', score=0.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=33, completion_tokens=33, - total_tokens=66 + total_tokens=66, + estimated_cost=0.2 / 3, ), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -400,10 +420,12 @@ def test_sample_chat_completion(self): 'Sample 1 for message.', score=0.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=33, completion_tokens=33, - total_tokens=66 + total_tokens=66, + estimated_cost=0.2 / 3, ), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -415,10 +437,12 @@ def test_sample_chat_completion(self): 'Sample 2 for message.', score=0.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=33, completion_tokens=33, - total_tokens=66 + total_tokens=66, + estimated_cost=0.2 / 3, ), tags=[lf.Message.TAG_LM_RESPONSE], ), @@ -427,7 +451,8 @@ def test_sample_chat_completion(self): ), ], usage=lf.LMSamplingUsage( - prompt_tokens=100, completion_tokens=100, total_tokens=200 + prompt_tokens=100, completion_tokens=100, total_tokens=200, + estimated_cost=0.2, ), ), ) @@ -449,6 +474,7 @@ def test_sample_with_contextual_options(self): 'Sample 0 for prompt 0.', score=0.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=50, completion_tokens=50, @@ -464,6 +490,7 @@ def test_sample_with_contextual_options(self): 'Sample 1 for prompt 0.', score=0.1, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage( prompt_tokens=50, completion_tokens=50, diff --git a/langfun/core/llms/vertexai.py b/langfun/core/llms/vertexai.py index 3274292..042bf09 100644 --- a/langfun/core/llms/vertexai.py +++ b/langfun/core/llms/vertexai.py @@ -40,24 +40,106 @@ Credentials = Any +# https://cloud.google.com/vertex-ai/generative-ai/pricing +# describes that the average number of characters per token is about 4. +AVGERAGE_CHARS_PER_TOEKN = 4 + + +# Price in US dollars, +# from https://cloud.google.com/vertex-ai/generative-ai/pricing +# as of 2024-10-10. SUPPORTED_MODELS_AND_SETTINGS = { - 'gemini-1.5-pro-001': pg.Dict(api='gemini', rpm=500), - 'gemini-1.5-pro-002': pg.Dict(api='gemini', rpm=500), - 'gemini-1.5-flash-002': pg.Dict(api='gemini', rpm=500), - 'gemini-1.5-flash-001': pg.Dict(api='gemini', rpm=500), - 'gemini-1.5-pro': pg.Dict(api='gemini', rpm=500), - 'gemini-1.5-flash': pg.Dict(api='gemini', rpm=500), - 'gemini-1.5-pro-latest': pg.Dict(api='gemini', rpm=500), - 'gemini-1.5-flash-latest': pg.Dict(api='gemini', rpm=500), - 'gemini-1.5-pro-preview-0514': pg.Dict(api='gemini', rpm=50), - 'gemini-1.5-pro-preview-0409': pg.Dict(api='gemini', rpm=50), - 'gemini-1.5-flash-preview-0514': pg.Dict(api='gemini', rpm=200), - 'gemini-1.0-pro': pg.Dict(api='gemini', rpm=300), - 'gemini-1.0-pro-vision': pg.Dict(api='gemini', rpm=100), + 'gemini-1.5-pro-001': pg.Dict( + api='gemini', + rpm=500, + cost_per_1k_input_chars=0.0003125, + cost_per_1k_output_chars=0.00125, + ), + 'gemini-1.5-pro-002': pg.Dict( + api='gemini', + rpm=500, + cost_per_1k_input_chars=0.0003125, + cost_per_1k_output_chars=0.00125, + ), + 'gemini-1.5-flash-002': pg.Dict( + api='gemini', + rpm=500, + cost_per_1k_input_chars=0.00001875, + cost_per_1k_output_chars=0.000075, + ), + 'gemini-1.5-flash-001': pg.Dict( + api='gemini', + rpm=500, + cost_per_1k_input_chars=0.00001875, + cost_per_1k_output_chars=0.000075, + ), + 'gemini-1.5-pro': pg.Dict( + api='gemini', + rpm=500, + cost_per_1k_input_chars=0.0003125, + cost_per_1k_output_chars=0.00125, + ), + 'gemini-1.5-flash': pg.Dict( + api='gemini', + rpm=500, + cost_per_1k_input_chars=0.00001875, + cost_per_1k_output_chars=0.000075, + ), + 'gemini-1.5-pro-latest': pg.Dict( + api='gemini', + rpm=500, + cost_per_1k_input_chars=0.0003125, + cost_per_1k_output_chars=0.00125, + ), + 'gemini-1.5-flash-latest': pg.Dict( + api='gemini', + rpm=500, + cost_per_1k_input_chars=0.00001875, + cost_per_1k_output_chars=0.000075, + ), + 'gemini-1.5-pro-preview-0514': pg.Dict( + api='gemini', + rpm=50, + cost_per_1k_input_chars=0.0003125, + cost_per_1k_output_chars=0.00125, + ), + 'gemini-1.5-pro-preview-0409': pg.Dict( + api='gemini', + rpm=50, + cost_per_1k_input_chars=0.0003125, + cost_per_1k_output_chars=0.00125, + ), + 'gemini-1.5-flash-preview-0514': pg.Dict( + api='gemini', + rpm=200, + cost_per_1k_input_chars=0.00001875, + cost_per_1k_output_chars=0.000075, + ), + 'gemini-1.0-pro': pg.Dict( + api='gemini', + rpm=300, + cost_per_1k_input_chars=0.000125, + cost_per_1k_output_chars=0.000375, + ), + 'gemini-1.0-pro-vision': pg.Dict( + api='gemini', + rpm=100, + cost_per_1k_input_chars=0.000125, + cost_per_1k_output_chars=0.000375, + ), # PaLM APIs. - 'text-bison': pg.Dict(api='palm', rpm=1600), - 'text-bison-32k': pg.Dict(api='palm', rpm=300), - 'text-unicorn': pg.Dict(api='palm', rpm=100), + 'text-bison': pg.Dict( + api='palm', + rpm=1600 + ), + 'text-bison-32k': pg.Dict( + api='palm', + rpm=300 + ), + 'text-unicorn': pg.Dict( + api='palm', + rpm=100 + ), # Endpoint # TODO(chengrun): Set a more appropriate rpm for endpoint. 'custom': pg.Dict(api='endpoint', rpm=20), @@ -161,6 +243,25 @@ def max_concurrency(self) -> int: tokens_per_min=0, ) + def estimate_cost( + self, + num_input_tokens: int, + num_output_tokens: int + ) -> float | None: + """Estimate the cost based on usage.""" + cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( + 'cost_per_1k_input_chars', None + ) + cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( + 'cost_per_1k_output_chars', None + ) + if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None: + return None + return ( + cost_per_1k_input_chars * num_input_tokens + + cost_per_1k_output_chars * num_output_tokens + ) * AVGERAGE_CHARS_PER_TOEKN / 1000 + def _generation_config( self, prompt: lf.Message, options: lf.LMSamplingOptions ) -> Any: # generative_models.GenerationConfig @@ -285,6 +386,10 @@ def _sample_generative_model(self, prompt: lf.Message) -> lf.LMSamplingResult: prompt_tokens=usage_metadata.prompt_token_count, completion_tokens=usage_metadata.candidates_token_count, total_tokens=usage_metadata.total_token_count, + estimated_cost=self.estimate_cost( + num_input_tokens=usage_metadata.prompt_token_count, + num_output_tokens=usage_metadata.candidates_token_count, + ), ) return lf.LMSamplingResult( [ diff --git a/langfun/core/message_test.py b/langfun/core/message_test.py index ada8f48..0177d4f 100644 --- a/langfun/core/message_test.py +++ b/langfun/core/message_test.py @@ -385,7 +385,7 @@ def test_html_ai_message(self): self.assert_html_content( ai_message.to_html(enable_summary_tooltip=False), """ - AIMessage(...)lm-responselm-outputMy name is GeminiresultDict(...)xmetadata.result.x1ymetadata.result.y2zmetadata.result.zDict(...)ametadata.result.z.aList(...)0metadata.result.z.a[0]121metadata.result.z.a[1]323llm usageLMSamplingUsage(...)prompt_tokensmetadata.usage.prompt_tokens10completion_tokensmetadata.usage.completion_tokens2total_tokensmetadata.usage.total_tokens12num_requestsmetadata.usage.num_requests1metadataDict(...)resultmetadata.resultDict(...)xmetadata.result.x1ymetadata.result.y2zmetadata.result.zDict(...)ametadata.result.z.aList(...)0metadata.result.z.a[0]121metadata.result.z.a[1]323usagemetadata.usageLMSamplingUsage(...)prompt_tokensmetadata.usage.prompt_tokens10completion_tokensmetadata.usage.completion_tokens2total_tokensmetadata.usage.total_tokens12num_requestsmetadata.usage.num_requests1sourceUserMessage(...)lm-inputWhat is in this image?imageCustomModality(...)contentsource.metadata.image.content'foo'this is a testmetadataDict(...)imagesource.metadata.imageCustomModality(...)contentsource.metadata.image.content'foo' + AIMessage(...)lm-responselm-outputMy name is GeminiresultDict(...)xmetadata.result.x1ymetadata.result.y2zmetadata.result.zDict(...)ametadata.result.z.aList(...)0metadata.result.z.a[0]121metadata.result.z.a[1]323llm usageLMSamplingUsage(...)prompt_tokensmetadata.usage.prompt_tokens10completion_tokensmetadata.usage.completion_tokens2total_tokensmetadata.usage.total_tokens12num_requestsmetadata.usage.num_requests1estimated_costmetadata.usage.estimated_costNonemetadataDict(...)resultmetadata.resultDict(...)xmetadata.result.x1ymetadata.result.y2zmetadata.result.zDict(...)ametadata.result.z.aList(...)0metadata.result.z.a[0]121metadata.result.z.a[1]323usagemetadata.usageLMSamplingUsage(...)prompt_tokensmetadata.usage.prompt_tokens10completion_tokensmetadata.usage.completion_tokens2total_tokensmetadata.usage.total_tokens12num_requestsmetadata.usage.num_requests1estimated_costmetadata.usage.estimated_costNonesourceUserMessage(...)lm-inputWhat is in this image?imageCustomModality(...)contentsource.metadata.image.content'foo'this is a testmetadataDict(...)imagesource.metadata.imageCustomModality(...)contentsource.metadata.image.content'foo' """ ) self.assert_html_content( @@ -399,7 +399,7 @@ def test_html_ai_message(self): source_tag=None, ), """ - AIMessage(...)lm-responselm-outputMy name is GeminiresultDict(...)xmetadata.result.x1ymetadata.result.y2zmetadata.result.zDict(...)ametadata.result.z.aList(...)0metadata.result.z.a[0]121metadata.result.z.a[1]323llm usageLMSamplingUsage(...)prompt_tokensmetadata.usage.prompt_tokens10completion_tokensmetadata.usage.completion_tokens2total_tokensmetadata.usage.total_tokens12num_requestsmetadata.usage.num_requests1metadataDict(...)resultmetadata.resultDict(...)xmetadata.result.x1ymetadata.result.y2zmetadata.result.zDict(...)ametadata.result.z.aList(...)0metadata.result.z.a[0]121metadata.result.z.a[1]323usagemetadata.usageLMSamplingUsage(...)prompt_tokensmetadata.usage.prompt_tokens10completion_tokensmetadata.usage.completion_tokens2total_tokensmetadata.usage.total_tokens12num_requestsmetadata.usage.num_requests1sourceUserMessage(...)lm-inputWhat is in this image?imageCustomModality(...)contentsource.metadata.image.content'foo'this is a testmetadataDict(...)imagesource.metadata.imageCustomModality(...)contentsource.metadata.image.content'foo'sourceUserMessage(...)User inputmetadataDict(...) + AIMessage(...)lm-responselm-outputMy name is GeminiresultDict(...)xmetadata.result.x1ymetadata.result.y2zmetadata.result.zDict(...)ametadata.result.z.aList(...)0metadata.result.z.a[0]121metadata.result.z.a[1]323llm usageLMSamplingUsage(...)prompt_tokensmetadata.usage.prompt_tokens10completion_tokensmetadata.usage.completion_tokens2total_tokensmetadata.usage.total_tokens12num_requestsmetadata.usage.num_requests1estimated_costmetadata.usage.estimated_costNonemetadataDict(...)resultmetadata.resultDict(...)xmetadata.result.x1ymetadata.result.y2zmetadata.result.zDict(...)ametadata.result.z.aList(...)0metadata.result.z.a[0]121metadata.result.z.a[1]323usagemetadata.usageLMSamplingUsage(...)prompt_tokensmetadata.usage.prompt_tokens10completion_tokensmetadata.usage.completion_tokens2total_tokensmetadata.usage.total_tokens12num_requestsmetadata.usage.num_requests1estimated_costmetadata.usage.estimated_costNonesourceUserMessage(...)lm-inputWhat is in this image?imageCustomModality(...)contentsource.metadata.image.content'foo'this is a testmetadataDict(...)imagesource.metadata.imageCustomModality(...)contentsource.metadata.image.content'foo'sourceUserMessage(...)User inputmetadataDict(...) """ ) diff --git a/langfun/core/structured/completion_test.py b/langfun/core/structured/completion_test.py index a0cdf35..26a44f4 100644 --- a/langfun/core/structured/completion_test.py +++ b/langfun/core/structured/completion_test.py @@ -581,6 +581,7 @@ def test_returns_message(self): text='Activity(description="foo")', result=Activity(description='foo'), score=1.0, + is_cached=False, logprobs=None, usage=lf.LMSamplingUsage(553, 27, 580), tags=['lm-response', 'lm-output', 'transformed'] diff --git a/langfun/core/structured/parsing_test.py b/langfun/core/structured/parsing_test.py index 783d631..2af696b 100644 --- a/langfun/core/structured/parsing_test.py +++ b/langfun/core/structured/parsing_test.py @@ -285,7 +285,7 @@ def test_parse(self): self.assertEqual( r, lf.AIMessage( - '1', score=1.0, result=1, logprobs=None, + '1', score=1.0, result=1, logprobs=None, is_cached=False, usage=lf.LMSamplingUsage(652, 1, 653), tags=['lm-response', 'lm-output', 'transformed'] ), @@ -645,6 +645,7 @@ def test_call_with_returning_message(self): result=3, score=1.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage(315, 1, 316), tags=['lm-response', 'lm-output', 'transformed'] ), diff --git a/langfun/core/structured/prompting_test.py b/langfun/core/structured/prompting_test.py index f6b2de3..6d7eae8 100644 --- a/langfun/core/structured/prompting_test.py +++ b/langfun/core/structured/prompting_test.py @@ -82,6 +82,7 @@ def test_call(self): result=1, score=1.0, logprobs=None, + is_cached=False, usage=lf.LMSamplingUsage(323, 1, 324), tags=['lm-response', 'lm-output', 'transformed'], ), diff --git a/langfun/core/templates/selfplay_test.py b/langfun/core/templates/selfplay_test.py index f24d252..c5e7c22 100644 --- a/langfun/core/templates/selfplay_test.py +++ b/langfun/core/templates/selfplay_test.py @@ -59,7 +59,8 @@ def test_play(self): self.assertEqual( g(), lf.AIMessage( - '10', score=0.0, logprobs=None, usage=lf.UsageNotAvailable() + '10', score=0.0, logprobs=None, is_cached=False, + usage=lf.UsageNotAvailable() ) ) @@ -72,7 +73,8 @@ def test_play_with_num_turns(self): self.assertEqual( g(), lf.AIMessage( - '2', score=0.0, logprobs=None, usage=lf.UsageNotAvailable() + '2', score=0.0, logprobs=None, is_cached=False, + usage=lf.UsageNotAvailable() ) )