diff --git a/langfun/core/__init__.py b/langfun/core/__init__.py
index ee3e154a..2387742c 100644
--- a/langfun/core/__init__.py
+++ b/langfun/core/__init__.py
@@ -98,6 +98,7 @@
 from langfun.core.language_model import LMSample
 from langfun.core.language_model import LMSamplingOptions
 from langfun.core.language_model import LMSamplingResult
+from langfun.core.language_model import LMCache
 
 # Components for building agents.
 from langfun.core.memory import Memory
diff --git a/langfun/core/langfunc_test.py b/langfun/core/langfunc_test.py
index c1423fa7..82c0cc8b 100644
--- a/langfun/core/langfunc_test.py
+++ b/langfun/core/langfunc_test.py
@@ -94,8 +94,8 @@ def test_call(self):
         "LangFunc(template_str='Hello', clean=True, returns=None, "
         'lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0, '
         'max_tokens=1024, n=1, top_k=40, top_p=None, random_seed=None), '
-        'timeout=120.0, max_attempts=5, debug=False), input_transform=None, '
-        'output_transform=None)',
+        'cache=None, timeout=120.0, max_attempts=5, debug=False), '
+        'input_transform=None, output_transform=None)',
     )
 
     l = LangFunc('Hello')
diff --git a/langfun/core/language_model.py b/langfun/core/language_model.py
index 9a9fe110..cf0d2d7b 100644
--- a/langfun/core/language_model.py
+++ b/langfun/core/language_model.py
@@ -15,7 +15,7 @@
 
 import abc
 import time
-from typing import Annotated
+from typing import Annotated, Any
 from langfun.core import component
 from langfun.core import console
 from langfun.core import message as message_lib
@@ -78,6 +78,34 @@ class LMSamplingOptions(component.Component):
       int | None, 'A fixed random seed used during model inference.'
   ] = None
 
+  def cache_key(self) -> tuple[Any, ...]:
+    """Returns a tuple of current values as cache key."""
+    return (
+        self.temperature,
+        self.max_tokens,
+        self.n,
+        self.top_k,
+        self.top_p,
+        self.random_seed
+    )
+
+
+class LMCache(pg.Object):
+  """Interface for LM cache."""
+
+  @abc.abstractmethod
+  def get(self,
+          lm: 'LanguageModel',
+          prompt: message_lib.Message) -> LMSamplingResult | None:
+    """Gets the cached result of a prompt generated by a language model."""
+
+  @abc.abstractmethod
+  def put(self,
+          lm: 'LanguageModel',
+          prompt: message_lib.Message,
+          result: LMSamplingResult) -> None:
+    """Puts the result of a prompt generated by a language model in cache."""
+
 
 class LanguageModel(component.Component):
   """Interface of a language model.
@@ -91,6 +119,13 @@
   sampling_options: LMSamplingOptions = LMSamplingOptions()
 
+  cache: Annotated[
+      LMCache | None,
+      (
+          'Sampling cache. If None, no cache will be used.'
+      )
+  ] = None
+
   timeout: Annotated[
       float | None, 'Timeout in seconds. If None, there is no timeout.'
   ] = 120.0
@@ -130,15 +165,51 @@ def _on_bound(self):
     super()._on_bound()
     self._call_counter = 0
 
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.__class__.__name__
+
   def sample(self,
              prompts: list[str | message_lib.Message],
              **kwargs) -> list[LMSamplingResult]:
     """Samples one or multiple prompts."""
+    prompts = [message_lib.UserMessage.from_value(p) for p in prompts]
+
     with component.context(override_attrs=True, **kwargs):
-      return self._sample([
-          message_lib.UserMessage.from_value(p)
-          for p in prompts
-      ])
+      if self.cache is None:
+        return self._sample(prompts)
+      else:
+        return self._sample_with_cache_lookup(prompts)
+
+  def _sample_with_cache_lookup(
+      self, prompts: list[str | message_lib.Message]) -> list[LMSamplingResult]:
+    """Sample with cache lookup."""
+    assert self.cache is not None
+
+    results = [None] * len(prompts)
+    requests, request_to_result_index = [], {}
+
+    # Perform cache lookup and figure out sampling requests to make.
+    for i, prompt in enumerate(prompts):
+      r = self.cache.get(self, prompt)
+      if r is None:
+        request_to_result_index[len(requests)] = i
+        requests.append(prompt)
+      else:
+        results[i] = r.clone()
+
+    # Sample non-cache-hit prompts.
+    requested_results = self._sample(requests)
+    assert len(requested_results) == len(requests), (
+        requests, requested_results)
+
+    # Combine cached results and newly requested results.
+    for i, (prompt, result) in enumerate(zip(requests, requested_results)):
+      results[request_to_result_index[i]] = result
+      self.cache.put(self, prompt, result)
+
+    return results  # pytype: disable=bad-return-type
 
   @abc.abstractmethod
   def _sample(
diff --git a/langfun/core/language_model_test.py b/langfun/core/language_model_test.py
index 8bed2305..5e8c765d 100644
--- a/langfun/core/language_model_test.py
+++ b/langfun/core/language_model_test.py
@@ -52,11 +52,49 @@ def fake_sample(prompts):
     )(prompts)
 
+
+class SimpleCache(lm_lib.LMCache):
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._cache = {}
+    self.cache_hit = 0
+
+  def get(self, lm, prompt):
+    del lm
+    r = self._cache.get(prompt.text)
+    if r is not None:
+      self.cache_hit += 1
+    return r
+
+  def put(self, lm, prompt, result):
+    self._cache[prompt.text] = result
+
+  @property
+  def num_records(self):
+    return len(self._cache)
+
+
+class LMSamplingOptionsTest(unittest.TestCase):
+  """Tests for LMSamplingOptions."""
+
+  def test_cache_key(self):
+    options = lm_lib.LMSamplingOptions()
+    key1 = options.cache_key()
+    self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+    with options.override(temperature=1.0, max_tokens=256):
+      key2 = options.cache_key()
+      self.assertEqual(key2, (1.0, 256, 1, 40, None, None))
+
+    # Make sure key1 does not change upon override.
+    self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+
 
 class LanguageModelTest(unittest.TestCase):
   """Tests for LanguageModel."""
 
   def test_init(self):
     lm = MockModel(1, temperature=0.5, top_k=2, max_attempts=2)
+    self.assertEqual(lm.model_id, 'MockModel')
     self.assertEqual(lm.failures_before_attempt, 1)
     self.assertEqual(lm.sampling_options.temperature, 0.5)
     self.assertEqual(lm.sampling_options.top_k, 2)
@@ -117,6 +155,31 @@ def test_call(self):
     # Test override individual flags within sampling_options.
     self.assertEqual(lm('foo', top_k=2), 'foo' * 2)
 
+  def test_using_cache(self):
+    cache = SimpleCache()
+    lm = MockModel(cache=cache, top_k=1)
+    self.assertEqual(
+        lm.sample(prompts=['foo', 'bar']),
+        [
+            lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=0.0)]),
+            lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=0.0)]),
+        ])
+
+    self.assertEqual(cache.cache_hit, 0)
+    self.assertEqual(cache.num_records, 2)
+    self.assertEqual(
+        lm.sample(prompts=['foo', 'baz'], temperature=1.0),
+        [
+            lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=0.0)]),
+            lm_lib.LMSamplingResult([lm_lib.LMSample('baz', score=1.0)]),
+        ])
+    self.assertEqual(cache.cache_hit, 1)
+    self.assertEqual(cache.num_records, 3)
+
+    self.assertEqual(lm('baz', temperature=1.0), 'baz')
+    self.assertEqual(cache.cache_hit, 2)
+    self.assertEqual(cache.num_records, 3)
+
   def test_retry(self):
     lm = MockModel(
         failures_before_attempt=1, top_k=1,
diff --git a/langfun/core/llms/__init__.py b/langfun/core/llms/__init__.py
index 5a821ada..1e7c4dbc 100644
--- a/langfun/core/llms/__init__.py
+++ b/langfun/core/llms/__init__.py
@@ -36,5 +36,8 @@
 # Placeholder for Google-internal imports.
 
+# Include cache as sub-module.
+from langfun.core.llms import cache
+
 # pylint: enable=g-bad-import-order
 # pylint: enable=g-importing-member
diff --git a/langfun/core/llms/cache/__init__.py b/langfun/core/llms/cache/__init__.py
new file mode 100644
index 00000000..b5f4717a
--- /dev/null
+++ b/langfun/core/llms/cache/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""langfun LLM cache implementations."""
+
+# pylint: disable=g-importing-member
+# pylint: disable=g-bad-import-order
+
+from langfun.core.llms.cache.base import LMCacheBase
+from langfun.core.llms.cache.base import LMCacheEntry
+
+from langfun.core.llms.cache.in_memory import InMemory
+
+
+# pylint: enable=g-bad-import-order
+# pylint: enable=g-importing-member
diff --git a/langfun/core/llms/cache/base.py b/langfun/core/llms/cache/base.py
new file mode 100644
index 00000000..2d67069d
--- /dev/null
+++ b/langfun/core/llms/cache/base.py
@@ -0,0 +1,86 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""LM cache base.""" + +import abc +import dataclasses +import datetime +from typing import Annotated, Any, Callable +import langfun.core as lf + + +@dataclasses.dataclass(frozen=True) +class LMCacheEntry: + result: lf.LMSamplingResult + expire: datetime.datetime | None + + +class LMCacheBase(lf.LMCache): + """The common LMCache base.""" + + key: Annotated[ + Callable[[lf.LanguageModel, lf.Message], Any] | None, + ( + 'A callable ojbect used for computing the key (hashable structure) ' + 'from the language model used and input prompt. If None, a default ' + 'key will be used, which are sensitive to the model id, sampling ' + 'options and the input prompt.' + ) + ] = None + + ttl: Annotated[ + int | None, + ( + 'Time-to-live in seconds.' + ) + ] = None + + def _on_bound(self): + super()._on_bound() + self._key = self.key or default_key + + def get(self, + lm: lf.LanguageModel, + prompt: lf.Message) -> lf.LMSamplingResult | None: + """Gets the cached result of a prompt generated by a language model.""" + entry = self._get(self._key(lm, prompt)) + if entry is None: + return None + if entry.expire is not None and entry.expire < datetime.datetime.now(): + return None + return entry.result + + def put(self, + lm: lf.LanguageModel, + prompt: lf.Message, + result: lf.LMSamplingResult) -> None: + """Puts the result of a prompt generated by a language model in cache.""" + expire = None + if self.ttl: + expire = datetime.datetime.now() + datetime.timedelta(seconds=self.ttl) + entry = LMCacheEntry(result, expire) + self._put(self._key(lm, prompt), entry) + + @abc.abstractmethod + def _get(self, key: Any) -> LMCacheEntry | None: + """Returns a LM cache entry associated with the key.""" + + @abc.abstractmethod + def _put(self, key: Any, entry: LMCacheEntry) -> None: + """Puts a LM cache entry associated with the key.""" + + +def default_key(lm: lf.LanguageModel, prompt: lf.Message) -> Any: + """Default key for LM cache.""" + return (lm.model_id, lm.sampling_options.cache_key(), prompt.text) diff --git a/langfun/core/llms/cache/in_memory.py b/langfun/core/llms/cache/in_memory.py new file mode 100644 index 00000000..2d75bf93 --- /dev/null +++ b/langfun/core/llms/cache/in_memory.py @@ -0,0 +1,40 @@ +# Copyright 2023 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""In-memory LM cache.""" + +from typing import Any +from langfun.core.llms.cache import base + + +class InMemory(base.LMCacheBase): + """In memory cache.""" + + def _get(self, key: Any) -> base.LMCacheEntry | None: + """Returns a LM cache entry associated with the key.""" + return _CACHE_MEMORY.get(key, None) + + def _put(self, key: Any, entry: base.LMCacheEntry) -> None: + """Puts a LM cache entry associated with the key.""" + _CACHE_MEMORY[key] = entry + + def reset(self) -> None: + """Resets the cache.""" + _CACHE_MEMORY.clear() + + +# NOTE(daiyip): We install a process-level cache store, so different InMemory() +# object could access the same memory. 
+# language models, since the `model_id` of the language model is included as a
+# part of the cache key.
+_CACHE_MEMORY = {}
diff --git a/langfun/core/llms/cache/in_memory_test.py b/langfun/core/llms/cache/in_memory_test.py
new file mode 100644
index 00000000..c340bc92
--- /dev/null
+++ b/langfun/core/llms/cache/in_memory_test.py
@@ -0,0 +1,61 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the in-memory LM cache."""
+
+import time
+import unittest
+
+from langfun.core.llms import fake
+from langfun.core.llms.cache import in_memory
+
+
+class InMemoryLMCacheTest(unittest.TestCase):
+
+  def test_basics(self):
+    in_memory.InMemory().reset()
+    lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory())
+    self.assertEqual(lm('a'), '1')
+    self.assertEqual(lm('a'), '1')
+    self.assertEqual(lm('b'), '2')
+    self.assertEqual(lm('a'), '1')
+    self.assertEqual(lm('c'), '3')
+
+  def test_ttl(self):
+    in_memory.InMemory().reset()
+    lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory(ttl=1))
+    self.assertEqual(lm('a'), '1')
+    self.assertEqual(lm('a'), '1')
+    time.sleep(2)
+    self.assertEqual(lm('a'), '2')
+
+  def test_different_sampling_options(self):
+    in_memory.InMemory().reset()
+    lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory())
+    self.assertEqual(lm('a'), '1')
+    self.assertEqual(lm('a'), '1')
+    self.assertEqual(lm('a', temperature=1.0), '2')
+
+  def test_different_model(self):
+    lm1 = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory())
+    lm2 = fake.Echo(cache=in_memory.InMemory())
+
+    self.assertEqual(lm1('a'), '1')
+    self.assertEqual(lm2('a'), 'a')
+    self.assertEqual(lm1('a'), '1')
+    self.assertEqual(lm1('b'), '2')
+    self.assertEqual(lm2('b'), 'b')
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/langfun/core/llms/openai.py b/langfun/core/llms/openai.py
index e3dcaf19..0ea13196 100644
--- a/langfun/core/llms/openai.py
+++ b/langfun/core/llms/openai.py
@@ -15,7 +15,7 @@
 
 import collections
 import os
-from typing import Annotated, Any
+from typing import Annotated, Any, Literal
 import langfun.core as lf
 import openai
 import pyglove as pg
@@ -40,7 +40,7 @@ class OpenAI(lf.LanguageModel):
   """OpenAI model."""
 
   model: pg.typing.Annotated[
-      pg.typing.Enum[
+      Literal[
           'gpt-4',
           'gpt-4-32k',
           'gpt-3.5-turbo',
@@ -91,6 +91,11 @@ def _on_bound(self):
     if org:
       openai.organization = org
 
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return f'OpenAI({self.model})'
+
   @classmethod
   def dir(cls):
     return openai.Model.list()
diff --git a/langfun/core/llms/openai_test.py b/langfun/core/llms/openai_test.py
index e1668722..cbcac703 100644
--- a/langfun/core/llms/openai_test.py
+++ b/langfun/core/llms/openai_test.py
@@ -57,6 +57,10 @@ def mock_chat_completion_query(messages, *, n=1, **kwargs):
 class OpenaiTest(unittest.TestCase):
   """Tests for OpenAI language model."""
 
+  def test_model_id(self):
+    self.assertEqual(
+        openai.Gpt35(api_key='test_key').model_id, 'OpenAI(text-davinci-003)')
+
   def test_call_completion(self):
     with mock.patch('openai.Completion.create') as mock_completion:
       mock_completion.side_effect = mock_completion_query
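
Usage note (a minimal sketch, not part of the diff): based on the in_memory_test.py added above, the new cache plumbing could be exercised as follows. The names fake.StaticSequence and in_memory.InMemory come from this change; the expected outputs mirror the test's assertions.

    from langfun.core.llms import fake
    from langfun.core.llms.cache import in_memory

    # Attach a process-level in-memory cache to a deterministic fake LM.
    lm = fake.StaticSequence(['1', '2'], cache=in_memory.InMemory())

    print(lm('hello'))  # '1' -- sampled from the model and stored in the cache.
    print(lm('hello'))  # '1' -- served from the cache; no new sample consumed.
    print(lm('world'))  # '2' -- different prompt, so a new sample and cache entry.

Passing ttl=1 to InMemory (as test_ttl does) would expire entries after one second, per LMCacheBase.ttl.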