Skip to content

Commit

Permalink
Add a fake model that always returns the same canned response regardless of what the prompt was.
Browse files Browse the repository at this point in the history

While this behavior is technically achievable via the `StaticMapping` model, making it a dedicated model spares the user from having to build the full mapping of `(prompt: canned_response)` pairs.

PiperOrigin-RevId: 571449156
  • Loading branch information
Langfun Authors committed Oct 6, 2023
1 parent dbe61af commit 23b1dcf
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 7 deletions.
13 changes: 13 additions & 0 deletions langfun/core/llms/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
]


@lf.use_init_args(['canned_response'])
class StaticCannedResponse(lf.LanguageModel):
  """Language model that replies to every prompt with one fixed response.

  Useful for testing: unlike `StaticMapping`, no per-prompt mapping is
  needed — every prompt receives `canned_response`.
  """

  # The fixed response returned for every prompt.
  canned_response: Annotated[str, 'A canned response.']

  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
    """Returns one single-sample result (the canned response) per prompt."""
    results = []
    for _ in prompts:
      # A fresh sample object per prompt, each with probability 1.0.
      results.append(
          lf.LMSamplingResult([lf.LMSample(self.canned_response, 1.0)])
      )
    return results


@lf.use_init_args(['mapping'])
class StaticMapping(lf.LanguageModel):
"""A static mapping from prompt to response."""
Expand Down
40 changes: 33 additions & 7 deletions langfun/core/llms/fake_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,20 @@
import io
import unittest
import langfun.core as lf
from langfun.core.llms.fake import Echo
from langfun.core.llms.fake import StaticMapping
from langfun.core.llms.fake import StaticSequence
from langfun.core.llms import fake as fakelm


class EchoTest(unittest.TestCase):

def test_sample(self):
  """Echo should return each prompt verbatim as its single sample."""
  lm = fakelm.Echo()
  expected = [lf.LMSamplingResult([lf.LMSample('hi', 1.0)])]
  self.assertEqual(lm.sample(['hi']), expected)

def test_call(self):
string_io = io.StringIO()
lm = Echo(debug=True)
lm = fakelm.Echo(debug=True)
with contextlib.redirect_stdout(string_io):
self.assertEqual(lm('hi'), 'hi')
debug_info = string_io.getvalue()
Expand All @@ -41,10 +39,38 @@ def test_call(self):
self.assertIn('[0] LM RESPONSE', debug_info)


class StaticCannedResponseTest(unittest.TestCase):
  """Tests for `fakelm.StaticCannedResponse`."""

  def test_sample(self):
    """Every prompt, regardless of content, yields the canned response."""
    canned_response = "I'm sorry, I can't help you with that."
    lm = fakelm.StaticCannedResponse(canned_response)
    expected = [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])]
    # Two unrelated prompts produce identical results.
    self.assertEqual(lm.sample(['hi']), expected)
    self.assertEqual(lm.sample(['Tell me a joke.']), expected)

  def test_call(self):
    """Calling the model returns the canned response and emits debug info."""
    canned_response = "I'm sorry, I can't help you with that."
    lm = fakelm.StaticCannedResponse(canned_response, debug=True)

    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
      self.assertEqual(lm('hi'), canned_response)

    # Debug mode should print model info, the prompt, and the response.
    debug_info = buf.getvalue()
    for fragment in (
        '[0] LM INFO:',
        '[0] PROMPT SENT TO LM:',
        '[0] LM RESPONSE',
    ):
      self.assertIn(fragment, debug_info)


class StaticMappingTest(unittest.TestCase):

def test_sample(self):
lm = StaticMapping({
lm = fakelm.StaticMapping({
'Hi': 'Hello',
'How are you?': 'I am fine, how about you?',
}, temperature=0.5)
Expand All @@ -63,7 +89,7 @@ def test_sample(self):
class StaticSequenceTest(unittest.TestCase):

def test_sample(self):
lm = StaticSequence([
lm = fakelm.StaticSequence([
'Hello',
'I am fine, how about you?',
], temperature=0.5)
Expand Down

0 comments on commit 23b1dcf

Please sign in to comment.