diff --git a/langfun/core/llms/cache/base.py b/langfun/core/llms/cache/base.py
index 2d67069d..3f92653c 100644
--- a/langfun/core/llms/cache/base.py
+++ b/langfun/core/llms/cache/base.py
@@ -22,8 +22,9 @@

 @dataclasses.dataclass(frozen=True)
 class LMCacheEntry:
+  """LM cache entry."""
   result: lf.LMSamplingResult
-  expire: datetime.datetime | None
+  expire: datetime.datetime | None = None


 class LMCacheBase(lf.LMCache):
@@ -34,8 +35,8 @@ class LMCacheBase(lf.LMCache):
       (
           'A callable ojbect used for computing the key (hashable structure) '
           'from the language model used and input prompt. If None, a default '
-          'key will be used, which are sensitive to the model id, sampling '
-          'options and the input prompt.'
+          'key will be used, which is sensitive to the sampling options and '
+          'input prompt.'
       )
   ] = None

@@ -54,7 +55,7 @@ def get(self,
           lm: lf.LanguageModel,
           prompt: lf.Message) -> lf.LMSamplingResult | None:
     """Gets the cached result of a prompt generated by a language model."""
-    entry = self._get(self._key(lm, prompt))
+    entry = self._get(lm.model_id, self._key(lm, prompt))
     if entry is None:
       return None
     if entry.expire is not None and entry.expire < datetime.datetime.now():
@@ -70,17 +71,17 @@
     if self.ttl:
       expire = datetime.datetime.now() + datetime.timedelta(seconds=self.ttl)
     entry = LMCacheEntry(result, expire)
-    self._put(self._key(lm, prompt), entry)
+    self._put(lm.model_id, self._key(lm, prompt), entry)

   @abc.abstractmethod
-  def _get(self, key: Any) -> LMCacheEntry | None:
+  def _get(self, model_id: str, key: str) -> LMCacheEntry | None:
     """Returns a LM cache entry associated with the key."""

   @abc.abstractmethod
-  def _put(self, key: Any, entry: LMCacheEntry) -> None:
+  def _put(self, model_id: str, key: str, entry: LMCacheEntry) -> None:
     """Puts a LM cache entry associated with the key."""


 def default_key(lm: lf.LanguageModel, prompt: lf.Message) -> Any:
   """Default key for LM cache."""
-  return (lm.model_id, lm.sampling_options.cache_key(), prompt.text)
+  return (prompt.text, lm.sampling_options.cache_key())
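
Note: with this change the model id travels as a separate namespace argument to `_get`/`_put` instead of being baked into the key tuple. A minimal sketch of a conforming subclass under the new contract (the `DictCache` name and flat-dict layout are hypothetical, not part of this change; it assumes only the `LMCacheBase` interface shown above):

    from langfun.core.llms.cache import base


    class DictCache(base.LMCacheBase):
      """Hypothetical cache keeping one flat dict keyed by (model_id, key)."""

      def _on_bound(self):
        super()._on_bound()
        self._store = {}

      def _get(self, model_id: str, key: str) -> base.LMCacheEntry | None:
        # `key` is whatever self._key produced; with `default_key` above it
        # is (prompt.text, lm.sampling_options.cache_key()).
        return self._store.get((model_id, key))

      def _put(self, model_id: str, key: str, entry: base.LMCacheEntry) -> None:
        self._store[(model_id, key)] = entry

Separating `model_id` from the key is what lets `InMemory` below offer per-model views such as `keys('Echo')` and `reset('Echo')` without parsing composite keys.
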
diff --git a/langfun/core/llms/cache/in_memory.py b/langfun/core/llms/cache/in_memory.py
index fd784de5..a2bded3c 100644
--- a/langfun/core/llms/cache/in_memory.py
+++ b/langfun/core/llms/cache/in_memory.py
@@ -13,29 +13,100 @@
 # limitations under the License.

 """In-memory LM cache."""

-from typing import Any
+import collections
+from typing import Annotated, Any, Iterator

 from langfun.core.llms.cache import base
+import pyglove as pg


+@pg.use_init_args(['filename', 'ttl', 'key'])
 class InMemory(base.LMCacheBase):
   """In memory cache."""

-  def _get(self, key: Any) -> base.LMCacheEntry | None:
+  filename: Annotated[
+      str | None,
+      (
+          'File name to load and save in memory cache.'
+      )
+  ] = None
+
+  def _on_bound(self) -> None:
+    super()._on_bound()
+    self._cache = collections.defaultdict(dict)
+
+    if self.filename is not None:
+      records = pg.load(self.filename)
+      for record in records:
+        model_cache = {}
+        for entry in record.entries:
+          model_cache[entry.k] = entry.v
+        self._cache[record.model_id] = model_cache
+
+  def model_ids(self) -> list[str]:
+    """Returns the model ids of cached queries."""
+    return list(self._cache.keys())
+
+  def __len__(self) -> int:
+    """Returns the number of entries in the cache."""
+    return sum(len(v) for v in self._cache.values())
+
+  def keys(self, model_id: str | None = None) -> Iterator[str]:
+    """Returns the cached keys for a model."""
+    if model_id is None:
+      for model_cache in self._cache.values():
+        for k in model_cache.keys():
+          yield k
+    else:
+      for k in self._cache[model_id].keys():
+        yield k
+
+  def values(self, model_id: str | None = None) -> Iterator[base.LMCacheEntry]:
+    """Returns the cached entries for a model."""
+    if model_id is None:
+      for model_cache in self._cache.values():
+        for v in model_cache.values():
+          yield v
+    else:
+      for v in self._cache[model_id].values():
+        yield v
+
+  def items(
+      self,
+      model_id: str | None = None
+  ) -> Iterator[tuple[str, base.LMCacheEntry]]:
+    """Returns the cached items for a model."""
+    if model_id is None:
+      for model_cache in self._cache.values():
+        for k, v in model_cache.items():
+          yield k, v
+    else:
+      for k, v in self._cache[model_id].items():
+        yield k, v
+
+  def _get(self, model_id: str, key: Any) -> base.LMCacheEntry | None:
     """Returns a LM cache entry associated with the key."""
-    return _CACHE_MEMORY.get(key, None)
+    return self._cache[model_id].get(key, None)

-  def _put(self, key: Any, entry: base.LMCacheEntry) -> None:
+  def _put(self, model_id: str, key: Any, entry: base.LMCacheEntry) -> None:
     """Puts a LM cache entry associated with the key."""
-    _CACHE_MEMORY[key] = entry
+    self._cache[model_id][key] = entry

-  @classmethod
-  def reset(cls) -> None:
+  def reset(self, model_id: str | None = None) -> None:
     """Resets the cache."""
-    _CACHE_MEMORY.clear()
+    if model_id is not None:
+      self._cache[model_id].clear()
+    else:
+      self._cache.clear()
+
+  def _sym_clone(self, deep: bool, memo: Any = None) -> 'InMemory':
+    v = super()._sym_clone(deep, memo)
+    v._cache = self._cache  # pylint: disable=protected-access
+    return v

-# NOTE(daiyip): We install a process-level cache store, so different InMemory()
-# object could access the same memory. This is not a problem across different
-# language models, since the `model_id` of the language model is included as a
-# part of the cache key.
-_CACHE_MEMORY = {}
+  def save(self, path: str) -> None:
+    """Saves the in-memory cache."""
+    records = []
+    for model_id in self.model_ids():
+      entries = [dict(k=k, v=v) for k, v in self.items(model_id)]
+      records.append(dict(model_id=model_id, entries=entries))
+    pg.save(records, path)
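
A usage sketch of the reworked `InMemory` (hypothetical prompt and file path; `test_save_load` below shows the `pg.set_load_handler`/`pg.set_save_handler` setup that save/load relies on):

    from langfun.core.llms import fake
    from langfun.core.llms.cache import in_memory

    cache = in_memory.InMemory()
    lm = fake.Echo(cache=cache)
    lm('hi')   # Miss: queries the model and fills self._cache['Echo'].
    lm('hi')   # Hit: answered from the cache.

    cache.model_ids()   # -> ['Echo']
    len(cache)          # -> 1

    # save() writes one record per model:
    # dict(model_id=..., entries=[dict(k=..., v=...)]).
    cache.save('/tmp/lm_cache.json')
    # Thanks to @pg.use_init_args, `filename` is the first positional
    # argument, so a fresh instance rehydrates itself in _on_bound().
    warm = in_memory.InMemory('/tmp/lm_cache.json')
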
"""Tests for language model.""" +import copy +import os +import tempfile import time import unittest +import langfun.core as lf from langfun.core.llms import fake +from langfun.core.llms.cache import base from langfun.core.llms.cache import in_memory +import pyglove as pg + class InMemoryLMCacheTest(unittest.TestCase): def test_basics(self): - in_memory.InMemory.reset() - lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory()) + cache = in_memory.InMemory() + lm = fake.StaticSequence(['1', '2', '3'], cache=cache) self.assertEqual(lm('a'), '1') self.assertEqual(lm('a'), '1') self.assertEqual(lm('b'), '2') self.assertEqual(lm('a'), '1') self.assertEqual(lm('c'), '3') + self.assertEqual(cache.model_ids(), ['StaticSequence']) + self.assertEqual( + list(cache.keys()), + [ + ('a', (0.0, 1024, 1, 40, None, None)), + ('b', (0.0, 1024, 1, 40, None, None)), + ('c', (0.0, 1024, 1, 40, None, None)), + ]) + self.assertEqual( + list(cache.keys('StaticSequence')), + [ + ('a', (0.0, 1024, 1, 40, None, None)), + ('b', (0.0, 1024, 1, 40, None, None)), + ('c', (0.0, 1024, 1, 40, None, None)), + ]) + + def cache_entry(response_text): + return base.LMCacheEntry( + lf.LMSamplingResult([ + lf.LMSample(lf.AIMessage(response_text), score=1.0) + ]) + ) + + self.assertEqual( + list(cache.values()), + [ + cache_entry('1'), + cache_entry('2'), + cache_entry('3'), + ]) + self.assertEqual( + list(cache.values('StaticSequence')), + [ + cache_entry('1'), + cache_entry('2'), + cache_entry('3'), + ]) + self.assertEqual( + list(cache.items()), + [ + ( + ('a', (0.0, 1024, 1, 40, None, None)), + cache_entry('1'), + ), + ( + ('b', (0.0, 1024, 1, 40, None, None)), + cache_entry('2'), + ), + ( + ('c', (0.0, 1024, 1, 40, None, None)), + cache_entry('3'), + ) + ] + ) + self.assertEqual( + list(cache.items('StaticSequence')), + [ + ( + ('a', (0.0, 1024, 1, 40, None, None)), + cache_entry('1'), + ), + ( + ('b', (0.0, 1024, 1, 40, None, None)), + cache_entry('2'), + ), + ( + ('c', (0.0, 1024, 1, 40, None, None)), + cache_entry('3'), + ) + ] + ) + + # Test clone/copy semantics. 
+ self.assertIs(cache.clone()._cache, cache._cache) + self.assertIs(cache.clone(deep=True)._cache, cache._cache) + self.assertIs(copy.copy(cache)._cache, cache._cache) + self.assertIs(copy.deepcopy(cache)._cache, cache._cache) def test_ttl(self): - in_memory.InMemory.reset() lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory(ttl=1)) self.assertEqual(lm('a'), '1') self.assertEqual(lm('a'), '1') @@ -40,15 +123,22 @@ def test_ttl(self): self.assertEqual(lm('a'), '2') def test_different_sampling_options(self): - in_memory.InMemory.reset() - lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory()) + cache = in_memory.InMemory() + lm = fake.StaticSequence(['1', '2', '3'], cache=cache) self.assertEqual(lm('a'), '1') self.assertEqual(lm('a'), '1') self.assertEqual(lm('a', temperature=1.0), '2') + self.assertEqual( + list(cache.keys()), + [ + ('a', (0.0, 1024, 1, 40, None, None)), + ('a', (1.0, 1024, 1, 40, None, None)) + ]) def test_different_model(self): - lm1 = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory()) - lm2 = fake.Echo(cache=in_memory.InMemory()) + cache = in_memory.InMemory() + lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache) + lm2 = fake.Echo(cache=cache) self.assertEqual(lm1('a'), '1') self.assertEqual(lm2('a'), 'a') @@ -56,6 +146,48 @@ def test_different_model(self): self.assertEqual(lm1('b'), '2') self.assertEqual(lm2('b'), 'b') + self.assertEqual( + list(cache.keys('StaticSequence')), + [ + ('a', (0.0, 1024, 1, 40, None, None)), + ('b', (0.0, 1024, 1, 40, None, None)), + ]) + self.assertEqual( + list(cache.keys('Echo')), + [ + ('a', (0.0, 1024, 1, 40, None, None)), + ('b', (0.0, 1024, 1, 40, None, None)), + ]) + self.assertEqual(len(cache), 4) + cache.reset('Echo') + self.assertEqual(list(cache.keys('Echo')), []) + cache.reset() + self.assertEqual(list(cache.keys()), []) + + def test_save_load(self): + pg.set_load_handler(pg.symbolic.default_load_handler) + pg.set_save_handler(pg.symbolic.default_save_handler) + + cache = in_memory.InMemory() + lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache) + lm2 = fake.Echo(cache=cache) + + self.assertEqual(lm1('a'), '1') + self.assertEqual(lm2('a'), 'a') + + tmp_dir = tempfile.gettempdir() + path = os.path.join(tmp_dir, 'memory.json') + cache.save(path) + + cache2 = in_memory.InMemory(path) + self.assertEqual(cache2._cache, cache._cache) + + lm1 = fake.StaticSequence(['x', 'y'], cache=cache2) + lm2 = fake.Echo(cache=cache2) + + self.assertEqual(lm1('a'), '1') + self.assertEqual(lm2('a'), 'a') + if __name__ == '__main__': unittest.main()
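
One behavioral subtlety the new clone/copy assertions pin down: `_sym_clone` shares `self._cache` between an `InMemory` instance and any symbolic clone or copy of it, so copied language models keep hitting the same store rather than starting cold. A small sketch of the aliasing (hypothetical prompt; assumes `copy.deepcopy` routes through `_sym_clone`, as `test_basics` asserts):

    import copy

    from langfun.core.llms import fake
    from langfun.core.llms.cache import in_memory

    cache = in_memory.InMemory()
    mirror = copy.deepcopy(cache)    # Aliases the store, does not copy it.
    fake.Echo(cache=mirror)('ping')  # A write through the copy...
    # ...is visible through the original.
    assert list(cache.keys('Echo')) == list(mirror.keys('Echo'))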