Skip to content

Commit

Permalink
Add stop to lf.LMSamplingOptions and enable it for OpenAI models.
Browse files: browse the repository at this point in the history.
PiperOrigin-RevId: 591402329
  • Loading branch information
daiyip authored and langfun authors committed Dec 16, 2023
1 parent 0c68bf7 commit 3b90e92
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 12 deletions.
6 changes: 3 additions & 3 deletions langfun/core/eval/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_basics(self):
self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
self.assertEqual(s.hash, s.clone().hash)
# Test persistent hash.
self.assertEqual(s.hash, '0c857e07')
self.assertEqual(s.hash, 'c76d4fe6')
self.assertEqual(
s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
)
Expand Down Expand Up @@ -323,7 +323,7 @@ def test_search_space(self):
s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
)
# Test persistent hash.
self.assertEqual(s.hash, 'cbb0adcf')
self.assertEqual(s.hash, 'e987475a')

summary = s.run(verbose=True)
self.assertEqual(len(summary.evaluations), 2)
Expand Down Expand Up @@ -451,7 +451,7 @@ def test_run(self):
],
)
# Test for persistent hash.
self.assertEqual(s.hash, 'ec4758d3')
self.assertEqual(s.hash, 'bb86a963')
s.run()
expected = {
s.children[0].id: dict(
Expand Down
10 changes: 5 additions & 5 deletions langfun/core/langfunc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def test_call(self):
print(repr(l))
self.assertEqual(
repr(l),
"LangFunc(template_str='Hello', clean=True, "
'lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0, '
'max_tokens=1024, n=1, top_k=40, top_p=None, random_seed=None), '
'cache=None, timeout=120.0, max_attempts=5, retry_interval=(5, 60), '
'exponential_backoff=True, debug=False))',
"LangFunc(template_str='Hello', clean=True,"
' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0,'
' max_tokens=1024, n=1, top_k=40, top_p=None, stop=None,'
' random_seed=None), cache=None, timeout=120.0, max_attempts=5,'
' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
)

l = LangFunc('Hello')
Expand Down
9 changes: 9 additions & 0 deletions langfun/core/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ class LMSamplingOptions(component.Component):
'`top_p` but not both.'
),
] = None
stop: Annotated[
list[str] | None,
(
'A list of stop sequences that prevent LLMs from outputting '
'more tokens. For example, when `stop` is set to ["User:", "Model:"] '
'LLMs will stop to emit more tokens when `User:` or '
'`Model:` is reached.'
),
] = None
random_seed: Annotated[
int | None, 'A fixed random seed used during model inference.'
] = None
Expand Down
2 changes: 2 additions & 0 deletions langfun/core/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ def _get_request_args(

if options.top_p is not None:
args['top_p'] = options.top_p
if options.stop:
args['stop'] = options.stop
return args

def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
Expand Down
8 changes: 4 additions & 4 deletions langfun/core/llms/openai_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,17 @@ def test_get_request_args(self):
)
self.assertEqual(
openai.Gpt4(api_key='test_key')._get_request_args(
lf.LMSamplingOptions(
temperature=1.0,
n=1)),
lf.LMSamplingOptions(temperature=1.0, stop=['\n'], n=1)
),
dict(
model='gpt-4',
n=1,
temperature=1.0,
max_tokens=1024,
stream=False,
timeout=120.0,
)
stop=['\n'],
),
)

def test_call_completion(self):
Expand Down

0 comments on commit 3b90e92

Please sign in to comment.