diff --git a/langfun/core/language_model.py b/langfun/core/language_model.py index cf0d2d7..ee1448a 100644 --- a/langfun/core/language_model.py +++ b/langfun/core/language_model.py @@ -50,14 +50,6 @@ class LMSamplingResult(pg.Object): ), ] = [] - error: Annotated[ - Exception | None, - ( - 'Error information if sampling request failed. If Not None, ' - '`samples` will be an empty list.' - ), - ] = None - class LMSamplingOptions(component.Component): """Language model sampling options.""" @@ -200,14 +192,15 @@ def _sample_with_cache_lookup( results[i] = r.clone() # Sample non-cache-hit prompts. - requested_results = self._sample(requests) - assert len(requested_results) == len(requests), ( - requests, requested_results) + if requests: + requested_results = self._sample(requests) + assert len(requested_results) == len(requests), ( + requests, requested_results) - # Combine cached results and newly requested results. - for i, (prompt, result) in enumerate(zip(requests, requested_results)): - results[request_to_result_index[i]] = result - self.cache.put(self, prompt, result) + # Combine cached results and newly requested results. + for i, (prompt, result) in enumerate(zip(requests, requested_results)): + results[request_to_result_index[i]] = result + self.cache.put(self, prompt, result.clone()) return results # pytype: disable=bad-return-type @@ -218,7 +211,10 @@ def _sample( ) -> list[LMSamplingResult]: """Subclass should override.""" - def __call__(self, prompt: message_lib.Message, **kwargs) -> str: + def __call__( + self, + prompt: message_lib.Message, + **kwargs) -> message_lib.Message: """Returns the first candidate.""" with component.context(override_attrs=True, **kwargs): sampling_options = self.sampling_options @@ -252,7 +248,4 @@ def __call__(self, prompt: message_lib.Message, **kwargs) -> str: ), color='blue', ) - - if result.error: - raise result.error return response diff --git a/langfun/core/language_model_test.py b/langfun/core/language_model_test.py index 5e8c765..5249991 100644 --- a/langfun/core/language_model_test.py +++ b/langfun/core/language_model_test.py @@ -180,6 +180,18 @@ def test_using_cache(self): self.assertEqual(cache.cache_hit, 2) self.assertEqual(cache.num_records, 3) + lm = MockModel(cache=cache, + top_k=1, + failures_before_attempt=1, + max_attempts=1) + try: + lm.sample(['a']), + except concurrent.RetryError: + pass + + lm = MockModel(cache=cache, top_k=1) + self.assertEqual(lm('a'), 'a') + def test_retry(self): lm = MockModel( failures_before_attempt=1, top_k=1, diff --git a/langfun/core/llms/cache/in_memory.py b/langfun/core/llms/cache/in_memory.py index 2d75bf9..fd784de 100644 --- a/langfun/core/llms/cache/in_memory.py +++ b/langfun/core/llms/cache/in_memory.py @@ -28,7 +28,8 @@ def _put(self, key: Any, entry: base.LMCacheEntry) -> None: """Puts a LM cache entry associated with the key.""" _CACHE_MEMORY[key] = entry - def reset(self) -> None: + @classmethod + def reset(cls) -> None: """Resets the cache.""" _CACHE_MEMORY.clear() diff --git a/langfun/core/llms/cache/in_memory_test.py b/langfun/core/llms/cache/in_memory_test.py index c340bc9..81b9672 100644 --- a/langfun/core/llms/cache/in_memory_test.py +++ b/langfun/core/llms/cache/in_memory_test.py @@ -23,7 +23,7 @@ class InMemoryLMCacheTest(unittest.TestCase): def test_basics(self): - in_memory.InMemory().reset() + in_memory.InMemory.reset() lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory()) self.assertEqual(lm('a'), '1') self.assertEqual(lm('a'), '1') @@ -32,7 +32,7 @@ def test_basics(self): self.assertEqual(lm('c'), '3') def test_ttl(self): - in_memory.InMemory().reset() + in_memory.InMemory.reset() lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory(ttl=1)) self.assertEqual(lm('a'), '1') self.assertEqual(lm('a'), '1') @@ -40,7 +40,7 @@ def test_ttl(self): self.assertEqual(lm('a'), '2') def test_different_sampling_options(self): - in_memory.InMemory().reset() + in_memory.InMemory.reset() lm = fake.StaticSequence(['1', '2', '3'], cache=in_memory.InMemory()) self.assertEqual(lm('a'), '1') self.assertEqual(lm('a'), '1')