From 8b5d1fbbe68203cd127d6b4b789ebe4b3c60c18d Mon Sep 17 00:00:00 2001
From: Langfun Authors
Date: Fri, 6 Oct 2023 15:26:13 -0700
Subject: [PATCH] Add a fake model that always returns the same canned
 response regardless of what the prompt was.

While this behavior is technically achievable via the `StaticMapping` model,
we can reduce the burden on the user to have to create the mapping of all
`(prompt: canned_response)` pairs by making it into its own model.

PiperOrigin-RevId: 571449156
---
 langfun/core/llms/fake.py      | 20 ++++++++++++++--
 langfun/core/llms/fake_test.py | 42 +++++++++++++++++++++++++++-------
 2 files changed, 52 insertions(+), 10 deletions(-)

diff --git a/langfun/core/llms/fake.py b/langfun/core/llms/fake.py
index eb1e54fe..8def19c4 100644
--- a/langfun/core/llms/fake.py
+++ b/langfun/core/llms/fake.py
@@ -27,6 +27,22 @@ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
     ]
 
 
+@lf.use_init_args(['response'])
+class StaticResponse(lf.LanguageModel):
+  """Language model that always gives the same canned response."""
+
+  response: Annotated[
+      str,
+      'A canned response that will be returned regardless of the prompt.'
+  ]
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    return [
+        lf.LMSamplingResult([lf.LMSample(self.response, 1.0)])
+        for _ in prompts
+    ]
+
+
 @lf.use_init_args(['mapping'])
 class StaticMapping(lf.LanguageModel):
   """A static mapping from prompt to response."""
@@ -45,11 +61,11 @@ def _sample(self, prompts: list[str]) -> list[lf.LMSamplingResult]:
 
 @lf.use_init_args(['sequence'])
 class StaticSequence(lf.LanguageModel):
-  """A static mapping from prompt to response."""
+  """A static sequence of responses to use."""
 
   sequence: Annotated[
       list[str],
-      'A sequence of strings as the respones.'
+      'A sequence of strings as the response.'
   ]
 
   def _on_bound(self):
diff --git a/langfun/core/llms/fake_test.py b/langfun/core/llms/fake_test.py
index 0de1af7c..38eef03f 100644
--- a/langfun/core/llms/fake_test.py
+++ b/langfun/core/llms/fake_test.py
@@ -11,28 +11,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Test Echo LLM."""
+"""Test Fake LLMs."""
 
 import contextlib
 import io
 import unittest
 import langfun.core as lf
-from langfun.core.llms.fake import Echo
-from langfun.core.llms.fake import StaticMapping
-from langfun.core.llms.fake import StaticSequence
+from langfun.core.llms import fake as fakelm
 
 
 class EchoTest(unittest.TestCase):
 
   def test_sample(self):
-    lm = Echo()
+    lm = fakelm.Echo()
     self.assertEqual(
         lm.sample(['hi']), [lf.LMSamplingResult([lf.LMSample('hi', 1.0)])]
     )
 
   def test_call(self):
     string_io = io.StringIO()
-    lm = Echo(debug=True)
+    lm = fakelm.Echo(debug=True)
     with contextlib.redirect_stdout(string_io):
       self.assertEqual(lm('hi'), 'hi')
     debug_info = string_io.getvalue()
@@ -41,10 +39,38 @@ def test_call(self):
     self.assertIn('[0] LM RESPONSE', debug_info)
 
 
+class StaticResponseTest(unittest.TestCase):
+
+  def test_sample(self):
+    canned_response = "I'm sorry, I can't help you with that."
+    lm = fakelm.StaticResponse(canned_response)
+    self.assertEqual(
+        lm.sample(['hi']),
+        [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
+    )
+    self.assertEqual(
+        lm.sample(['Tell me a joke.']),
+        [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
+    )
+
+  def test_call(self):
+    string_io = io.StringIO()
+    canned_response = "I'm sorry, I can't help you with that."
+    lm = fakelm.StaticResponse(canned_response, debug=True)
+
+    with contextlib.redirect_stdout(string_io):
+      self.assertEqual(lm('hi'), canned_response)
+
+    debug_info = string_io.getvalue()
+    self.assertIn('[0] LM INFO:', debug_info)
+    self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
+    self.assertIn('[0] LM RESPONSE', debug_info)
+
+
 class StaticMappingTest(unittest.TestCase):
 
   def test_sample(self):
-    lm = StaticMapping({
+    lm = fakelm.StaticMapping({
         'Hi': 'Hello',
         'How are you?': 'I am fine, how about you?',
     }, temperature=0.5)
@@ -63,7 +89,7 @@ def test_sample(self):
 class StaticSequenceTest(unittest.TestCase):
 
   def test_sample(self):
-    lm = StaticSequence([
+    lm = fakelm.StaticSequence([
         'Hello',
         'I am fine, how about you?',
     ], temperature=0.5)
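
For illustration (not part of the patch itself), a minimal usage sketch of the
new `StaticResponse` model next to the existing `StaticMapping`, based only on
the constructors and assertions exercised in fake_test.py above; the canned
response string is the one used in the tests:

    import langfun.core as lf
    from langfun.core.llms import fake as fakelm

    canned_response = "I'm sorry, I can't help you with that."

    # StaticResponse: one canned reply, regardless of the prompt.
    lm = fakelm.StaticResponse(canned_response)
    assert lm('hi') == canned_response
    assert lm('Tell me a joke.') == canned_response

    # StaticMapping: getting the same behavior requires enumerating every
    # expected prompt as a (prompt: canned_response) pair up front.
    lm = fakelm.StaticMapping({
        'hi': canned_response,
        'Tell me a joke.': canned_response,
    })
    assert lm.sample(['hi']) == [
        lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])
    ]

The sketch shows the motivation stated in the commit message: with
`StaticResponse`, the caller no longer has to build the full prompt-to-response
mapping just to return a single fixed reply.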