-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Add interface `lf.LMCache` - Implement `lf.llms.cache.InMemory`. - Support custom `key` function and `ttl`. NOTE: cache key is determined by (model_id, sampling_options, prompt). Usage: ```python lm = lf.llms.Gpt35(cache=lf.llms.cache.InMemory()) print(lf.LangFunc('Intro to the U.S.A', lm=lm)) ``` PiperOrigin-RevId: 567410319
- Loading branch information
Showing
11 changed files
with
369 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright 2023 The Langfun Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""langfun LLM cache implementations.""" | ||
|
||
# pylint: disable=g-importing-member | ||
# pylint: disable=g-bad-import-order | ||
|
||
from langfun.core.llms.cache.base import LMCacheBase | ||
from langfun.core.llms.cache.base import LMCacheEntry | ||
|
||
from langfun.core.llms.cache.in_memory import InMemory | ||
|
||
|
||
# pylint: enable=g-bad-import-order | ||
# pylint: enable=g-importing-member |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright 2023 The Langfun Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""LM cache base.""" | ||
|
||
import abc | ||
import dataclasses | ||
import datetime | ||
from typing import Annotated, Any, Callable | ||
import langfun.core as lf | ||
|
||
|
||
@dataclasses.dataclass(frozen=True)
class LMCacheEntry:
  """An immutable cache entry: an LM sampling result plus its expiration time."""

  # The sampling result returned by the language model for a prompt.
  result: lf.LMSamplingResult
  # Absolute wall-clock expiration time; None means the entry never expires.
  expire: datetime.datetime | None
|
||
|
||
class LMCacheBase(lf.LMCache):
  """The common LMCache base.

  Subclasses implement the storage layer via `_get` / `_put`; this base class
  handles key computation and TTL-based expiration on top of it.
  """

  key: Annotated[
      Callable[[lf.LanguageModel, lf.Message], Any] | None,
      (
          # Fixed typo ("ojbect") and grammar ("which are") in the help text.
          'A callable object used for computing the key (hashable structure) '
          'from the language model used and input prompt. If None, a default '
          'key will be used, which is sensitive to the model id, sampling '
          'options and the input prompt.'
      )
  ] = None

  ttl: Annotated[
      int | None,
      (
          'Time-to-live in seconds.'
      )
  ] = None

  def _on_bound(self):
    """Resolves the effective key function once at binding time."""
    super()._on_bound()
    self._key = self.key or default_key

  def get(self,
          lm: lf.LanguageModel,
          prompt: lf.Message) -> lf.LMSamplingResult | None:
    """Gets the cached result of a prompt generated by a language model.

    Returns None on a cache miss or when the entry has expired.
    NOTE(review): expired entries are not evicted here, so with a TTL the
    underlying store can accumulate stale entries — confirm this is acceptable.
    """
    entry = self._get(self._key(lm, prompt))
    if entry is None:
      return None
    # Uses naive local time; `put` creates `expire` the same way, so the
    # comparison is consistent within a process.
    if entry.expire is not None and entry.expire < datetime.datetime.now():
      return None
    return entry.result

  def put(self,
          lm: lf.LanguageModel,
          prompt: lf.Message,
          result: lf.LMSamplingResult) -> None:
    """Puts the result of a prompt generated by a language model in cache."""
    expire = None
    if self.ttl:
      expire = datetime.datetime.now() + datetime.timedelta(seconds=self.ttl)
    entry = LMCacheEntry(result, expire)
    self._put(self._key(lm, prompt), entry)

  @abc.abstractmethod
  def _get(self, key: Any) -> LMCacheEntry | None:
    """Returns a LM cache entry associated with the key."""

  @abc.abstractmethod
  def _put(self, key: Any, entry: LMCacheEntry) -> None:
    """Puts a LM cache entry associated with the key."""
|
||
|
||
def default_key(lm: lf.LanguageModel, prompt: lf.Message) -> Any:
  """Computes the default cache key.

  The key is a tuple of (model id, sampling-options cache key, prompt text),
  so distinct models or sampling settings never collide in the cache.
  """
  sampling_key = lm.sampling_options.cache_key()
  return (lm.model_id, sampling_key, prompt.text)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright 2023 The Langfun Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""In-memory LM cache.""" | ||
|
||
from typing import Any | ||
from langfun.core.llms.cache import base | ||
|
||
|
||
class InMemory(base.LMCacheBase):
  """In memory cache.

  All `InMemory` instances read and write a single process-level store
  (`_CACHE_MEMORY`), so separate instances share cached results.
  """

  def _get(self, key: Any) -> base.LMCacheEntry | None:
    """Looks up the process-level store; returns None on a cache miss."""
    try:
      return _CACHE_MEMORY[key]
    except KeyError:
      return None

  def _put(self, key: Any, entry: base.LMCacheEntry) -> None:
    """Stores `entry` under `key`, overwriting any existing entry."""
    _CACHE_MEMORY[key] = entry

  def reset(self) -> None:
    """Removes every entry from the shared process-level store."""
    _CACHE_MEMORY.clear()
|
||
|
||
# NOTE(daiyip): We install a process-level cache store, so different InMemory()
# objects can access the same memory. This is not a problem across different
# language models, since the `model_id` of the language model is included as a
# part of the cache key.
_CACHE_MEMORY: dict[Any, base.LMCacheEntry] = {}
Oops, something went wrong.