diff --git a/langfun/core/llms/fake.py b/langfun/core/llms/fake.py index eb1e54fe..65618793 100644 --- a/langfun/core/llms/fake.py +++ b/langfun/core/llms/fake.py @@ -27,6 +27,19 @@ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]: ] +@lf.use_init_args(['canned_response']) +class StaticCannedResponse(lf.LanguageModel): + """Language model that always gives the same canned response.""" + + canned_response: Annotated[str, 'A canned response.'] + + def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]: + return [ + lf.LMSamplingResult([lf.LMSample(self.canned_response, 1.0)]) + for _ in prompts + ] + + @lf.use_init_args(['mapping']) class StaticMapping(lf.LanguageModel): """A static mapping from prompt to response.""" diff --git a/langfun/core/llms/fake_test.py b/langfun/core/llms/fake_test.py index 0de1af7c..484362e2 100644 --- a/langfun/core/llms/fake_test.py +++ b/langfun/core/llms/fake_test.py @@ -17,22 +17,20 @@ import io import unittest import langfun.core as lf -from langfun.core.llms.fake import Echo -from langfun.core.llms.fake import StaticMapping -from langfun.core.llms.fake import StaticSequence +from langfun.core.llms import fake as fakelm class EchoTest(unittest.TestCase): def test_sample(self): - lm = Echo() + lm = fakelm.Echo() self.assertEqual( lm.sample(['hi']), [lf.LMSamplingResult([lf.LMSample('hi', 1.0)])] ) def test_call(self): string_io = io.StringIO() - lm = Echo(debug=True) + lm = fakelm.Echo(debug=True) with contextlib.redirect_stdout(string_io): self.assertEqual(lm('hi'), 'hi') debug_info = string_io.getvalue() @@ -41,10 +39,38 @@ def test_call(self): self.assertIn('[0] LM RESPONSE', debug_info) +class StaticCannedResponseTest(unittest.TestCase): + + def test_sample(self): + canned_response = "I'm sorry, I can't help you with that." + lm = fakelm.StaticCannedResponse(canned_response) + self.assertEqual( + lm.sample(['hi']), + [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])], + ) + self.assertEqual( + lm.sample(['Tell me a joke.']), + [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])], + ) + + def test_call(self): + string_io = io.StringIO() + canned_response = "I'm sorry, I can't help you with that." + lm = fakelm.StaticCannedResponse(canned_response, debug=True) + + with contextlib.redirect_stdout(string_io): + self.assertEqual(lm('hi'), canned_response) + + debug_info = string_io.getvalue() + self.assertIn('[0] LM INFO:', debug_info) + self.assertIn('[0] PROMPT SENT TO LM:', debug_info) + self.assertIn('[0] LM RESPONSE', debug_info) + + class StaticMappingTest(unittest.TestCase): def test_sample(self): - lm = StaticMapping({ + lm = fakelm.StaticMapping({ 'Hi': 'Hello', 'How are you?': 'I am fine, how about you?', }, temperature=0.5) @@ -63,7 +89,7 @@ def test_sample(self): class StaticSequenceTest(unittest.TestCase): def test_sample(self): - lm = StaticSequence([ + lm = fakelm.StaticSequence([ 'Hello', 'I am fine, how about you?', ], temperature=0.5)