Skip to content

Commit

Permalink
LanguageModel interface cleanup.
Browse files Browse the repository at this point in the history
- Remove `error` in LMResult. Error will always be raised when retry failed.
- Correct the `__call__` return annotation from str to `lf.Message`.

PiperOrigin-RevId: 567643634
  • Loading branch information
daiyip authored and langfun authors committed Sep 22, 2023
1 parent 8e2d7a4 commit 5394ad7
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 23 deletions.
31 changes: 12 additions & 19 deletions langfun/core/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,6 @@ class LMSamplingResult(pg.Object):
),
] = []

error: Annotated[
Exception | None,
(
'Error information if sampling request failed. If Not None, '
'`samples` will be an empty list.'
),
] = None


class LMSamplingOptions(component.Component):
"""Language model sampling options."""
Expand Down Expand Up @@ -200,14 +192,15 @@ def _sample_with_cache_lookup(
results[i] = r.clone()

# Sample non-cache-hit prompts.
requested_results = self._sample(requests)
assert len(requested_results) == len(requests), (
requests, requested_results)
if requests:
requested_results = self._sample(requests)
assert len(requested_results) == len(requests), (
requests, requested_results)

# Combine cached results and newly requested results.
for i, (prompt, result) in enumerate(zip(requests, requested_results)):
results[request_to_result_index[i]] = result
self.cache.put(self, prompt, result)
# Combine cached results and newly requested results.
for i, (prompt, result) in enumerate(zip(requests, requested_results)):
results[request_to_result_index[i]] = result
self.cache.put(self, prompt, result.clone())

return results # pytype: disable=bad-return-type

Expand All @@ -218,7 +211,10 @@ def _sample(
) -> list[LMSamplingResult]:
"""Subclass should override."""

def __call__(self, prompt: message_lib.Message, **kwargs) -> str:
def __call__(
self,
prompt: message_lib.Message,
**kwargs) -> message_lib.Message:
"""Returns the first candidate."""
with component.context(override_attrs=True, **kwargs):
sampling_options = self.sampling_options
Expand Down Expand Up @@ -252,7 +248,4 @@ def __call__(self, prompt: message_lib.Message, **kwargs) -> str:
),
color='blue',
)

if result.error:
raise result.error
return response
12 changes: 12 additions & 0 deletions langfun/core/language_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,18 @@ def test_using_cache(self):
self.assertEqual(cache.cache_hit, 2)
self.assertEqual(cache.num_records, 3)

lm = MockModel(cache=cache,
top_k=1,
failures_before_attempt=1,
max_attempts=1)
try:
lm.sample(['a']),
except concurrent.RetryError:
pass

lm = MockModel(cache=cache, top_k=1)
self.assertEqual(lm('a'), 'a')

def test_retry(self):
lm = MockModel(
failures_before_attempt=1, top_k=1,
Expand Down
3 changes: 2 additions & 1 deletion langfun/core/llms/cache/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def _put(self, key: Any, entry: base.LMCacheEntry) -> None:
"""Puts a LM cache entry associated with the key."""
_CACHE_MEMORY[key] = entry

def reset(self) -> None:
@classmethod
def reset(cls) -> None:
"""Resets the cache."""
_CACHE_MEMORY.clear()

Expand Down
6 changes: 3 additions & 3 deletions langfun/core/llms/cache/in_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class InMemoryLMCacheTest(unittest.TestCase):

def test_basics(self):
in_memory.InMemory().reset()
in_memory.InMemory.reset()
lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory())
self.assertEqual(lm('a'), '1')
self.assertEqual(lm('a'), '1')
Expand All @@ -32,15 +32,15 @@ def test_basics(self):
self.assertEqual(lm('c'), '3')

def test_ttl(self):
in_memory.InMemory().reset()
in_memory.InMemory.reset()
lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory(ttl=1))
self.assertEqual(lm('a'), '1')
self.assertEqual(lm('a'), '1')
time.sleep(2)
self.assertEqual(lm('a'), '2')

def test_different_sampling_options(self):
in_memory.InMemory().reset()
in_memory.InMemory.reset()
lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory())
self.assertEqual(lm('a'), '1')
self.assertEqual(lm('a'), '1')
Expand Down

0 comments on commit 5394ad7

Please sign in to comment.