Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a fake model that always returns the same canned response regardless of what the prompt was. #33

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions langfun/core/llms/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
]


@lf.use_init_args(['response'])
class StaticResponse(lf.LanguageModel):
"""Language model that always gives the same canned response."""

response: Annotated[
str,
'A canned response that will be returned regardless of the prompt.'
]

def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
return [
lf.LMSamplingResult([lf.LMSample(self.response, 1.0)])
for _ in prompts
]


@lf.use_init_args(['mapping'])
class StaticMapping(lf.LanguageModel):
"""A static mapping from prompt to response."""
Expand All @@ -45,11 +61,11 @@ def _sample(self, prompts: list[str]) -> list[lf.LMSamplingResult]:

@lf.use_init_args(['sequence'])
class StaticSequence(lf.LanguageModel):
"""A static mapping from prompt to response."""
"""A static sequence of responses to use."""

sequence: Annotated[
list[str],
'A sequence of strings as the respones.'
'A sequence of strings as the response.'
]

def _on_bound(self):
Expand Down
42 changes: 34 additions & 8 deletions langfun/core/llms/fake_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Echo LLM."""
"""Test Fake LLMs."""

import contextlib
import io
import unittest
import langfun.core as lf
from langfun.core.llms.fake import Echo
from langfun.core.llms.fake import StaticMapping
from langfun.core.llms.fake import StaticSequence
from langfun.core.llms import fake as fakelm


class EchoTest(unittest.TestCase):

def test_sample(self):
lm = Echo()
lm = fakelm.Echo()
self.assertEqual(
lm.sample(['hi']), [lf.LMSamplingResult([lf.LMSample('hi', 1.0)])]
)

def test_call(self):
string_io = io.StringIO()
lm = Echo(debug=True)
lm = fakelm.Echo(debug=True)
with contextlib.redirect_stdout(string_io):
self.assertEqual(lm('hi'), 'hi')
debug_info = string_io.getvalue()
Expand All @@ -41,10 +39,38 @@ def test_call(self):
self.assertIn('[0] LM RESPONSE', debug_info)


class StaticResponseTest(unittest.TestCase):

def test_sample(self):
canned_response = "I'm sorry, I can't help you with that."
lm = fakelm.StaticResponse(canned_response)
self.assertEqual(
lm.sample(['hi']),
[lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
)
self.assertEqual(
lm.sample(['Tell me a joke.']),
[lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
)

def test_call(self):
string_io = io.StringIO()
canned_response = "I'm sorry, I can't help you with that."
lm = fakelm.StaticResponse(canned_response, debug=True)

with contextlib.redirect_stdout(string_io):
self.assertEqual(lm('hi'), canned_response)

debug_info = string_io.getvalue()
self.assertIn('[0] LM INFO:', debug_info)
self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
self.assertIn('[0] LM RESPONSE', debug_info)


class StaticMappingTest(unittest.TestCase):

def test_sample(self):
lm = StaticMapping({
lm = fakelm.StaticMapping({
'Hi': 'Hello',
'How are you?': 'I am fine, how about you?',
}, temperature=0.5)
Expand All @@ -63,7 +89,7 @@ def test_sample(self):
class StaticSequenceTest(unittest.TestCase):

def test_sample(self):
lm = StaticSequence([
lm = fakelm.StaticSequence([
'Hello',
'I am fine, how about you?',
], temperature=0.5)
Expand Down
Loading