Skip to content

Commit

Permalink
Introducing lf.lm_cache and enhancing lf.concurrent_map.
Browse files Browse the repository at this point in the history
- Introduces context manager `lf.lm_cache` for enabling cache with automatic saving.
- Enhances `lf.concurrent_map`
  - use `Exception` as the default value for argument `silence_on_errors`.
  - support argument `show_progress`, which will display progresses under console/colab notebook based on tqtm.
- Fixed an issue in assigning `top_p` for OpenAI models.

PiperOrigin-RevId: 567890684
  • Loading branch information
daiyip authored and langfun authors committed Sep 25, 2023
1 parent dedee0e commit bda6cc6
Show file tree
Hide file tree
Showing 12 changed files with 187 additions and 21 deletions.
2 changes: 2 additions & 0 deletions langfun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
PythonCode = coding.PythonCode

from langfun.core import llms
lm_cache = llms.cache.lm_cache

from langfun.core import memories

# Placeholder for Google-internal imports.
Expand Down
56 changes: 44 additions & 12 deletions langfun/core/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from langfun.core import component
import pyglove as pg
from tqdm import auto as tqdm


def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]:
Expand Down Expand Up @@ -234,12 +235,14 @@ def concurrent_map(
func: Callable[[Any], Any],
parallel_inputs: Iterable[Any],
*,
executor: concurrent.futures.ThreadPoolExecutor | None = None,
max_workers: int = 32,
ordered: bool = False,
show_progress: bool = False,
timeout: int | None = None,
silence_on_errors: Union[
Type[Exception], Tuple[Type[Exception], ...], None
] = RetryError,
] = Exception,
retry_on_errors: Union[
Type[Exception],
Tuple[Type[Exception], ...],
Expand All @@ -254,10 +257,13 @@ def concurrent_map(
Args:
func: A user function.
parallel_inputs: The inputs for `func` which will be processed in parallel.
executor: Thread pool exeutor to pool work items. If None, a new thread pool
executor will be created for current execution.
max_workers: The max number of workers.
ordered: If True, the returned iterator will emit (input, output, error) in
the order of the elements in `parallel_inputs`. Otherwise, elements that
are finished earlier will be delivered first.
show_progress: If True, show progress on console.
timeout: The timeout in seconds for processing each input. It is the total
processing time for each input, even multiple retries take place. If None,
there is no timeout.
Expand Down Expand Up @@ -298,29 +304,47 @@ def remaining_time():
return None
return time.time() - start_time

with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
executor = executor or concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers)

with executor:
future_to_job = {}
pending_futures = []
total = 0
for inputs in parallel_inputs:
job = Job(func, inputs)
future = executor.submit(
with_context_access(job),
)
pending_futures.append(future)
future_to_job[future] = job
total += 1

progress = tqdm.tqdm(total=total) if show_progress else None
def update_progress(success: int, failure: int) -> None:
if progress is not None:
completed = success + failure
progress.update(1)
progress.set_description(
'Success: %.2f%% (%d/%d), Failure: %.2f%% (%d/%d)'
% (success * 100.0 / completed, success, completed,
failure * 100.0 / completed, failure, completed))

remaining_futures = []
success, failure = 0, 0
if ordered:
for i, future in enumerate(pending_futures):
try:
_ = future.result(timeout=remaining_time())
job = future_to_job[future]
if job.error is not None and not (
silence_on_errors and isinstance(job.error, silence_on_errors)
):
raise job.error
if job.error is not None:
failure += 1
if not (
silence_on_errors and isinstance(job.error, silence_on_errors)):
raise job.error
else:
success += 1
update_progress(success, failure)
del future_to_job[future]
yield job.arg, job.result, job.error
except concurrent.futures.TimeoutError:
Expand All @@ -332,10 +356,15 @@ def remaining_time():
):
job = future_to_job[future]
del future_to_job[future]
if job.error is not None and not (
silence_on_errors and isinstance(job.error, silence_on_errors)
):
raise job.error # pylint: disable=g-doc-exception
if job.error is not None:
failure += 1
if not (
silence_on_errors and isinstance(job.error, silence_on_errors)
):
raise job.error # pylint: disable=g-doc-exception
else:
success += 1
update_progress(success, failure)
yield job.arg, job.result, job.error
remaining_futures = future_to_job

Expand All @@ -346,3 +375,6 @@ def remaining_time():
future.cancel()
job.error = TimeoutError(f'Execution time exceeds {timeout} seconds.')
yield job.arg, job.result, job.error

if progress is not None:
progress.close()
23 changes: 21 additions & 2 deletions langfun/core/concurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def fun(x):
return x**2

with component.context(y=2):
it = concurrent.concurrent_map(fun, [1, 2, 3])
it = concurrent.concurrent_map(fun, [1, 2, 3], silence_on_errors=KeyError)
self.assertEqual(next(it), (1, 1, None))
with self.assertRaises(ValueError):
_ = next(it)
Expand Down Expand Up @@ -276,7 +276,8 @@ def fun(x):
return x**2

with component.context(y=2):
it = concurrent.concurrent_map(fun, [1, 2, 3], ordered=True)
it = concurrent.concurrent_map(
fun, [1, 2, 3], ordered=True, silence_on_errors=KeyError)
self.assertEqual(next(it)[1], 1)

with self.assertRaises(ValueError):
Expand All @@ -302,6 +303,24 @@ def fun(x):
],
)

def test_concurrent_map_with_showing_progress(self):
def fun(x):
return x

self.assertEqual(
[
(i, o)
for i, o, _ in concurrent.concurrent_map(
fun, [1, 2, 3], ordered=True, show_progress=True
)
],
[
(1, 1),
(2, 2),
(3, 3),
],
)


if __name__ == '__main__':
unittest.main()
3 changes: 2 additions & 1 deletion langfun/core/langfunc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def test_call(self):
"LangFunc(template_str='Hello', clean=True, returns=None, "
'lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0, '
'max_tokens=1024, n=1, top_k=40, top_p=None, random_seed=None), '
'cache=None, timeout=120.0, max_attempts=5, debug=False), '
'cache=ContextualAttribute(type=None, default=None), timeout=120.0, '
'max_attempts=5, debug=False), '
'input_transform=None, output_transform=None)',
)

Expand Down
22 changes: 17 additions & 5 deletions langfun/core/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,28 @@ class LMSamplingResult(pg.Object):
class LMSamplingOptions(component.Component):
"""Language model sampling options."""

temperature: Annotated[float, 'Model temperature between [0, 1.0].'] = 0.0
temperature: Annotated[
float,
(
'Model temperature, which is usually between 0 and 1.0. '
'OpenAI models have temperature range from 0.0 to 2.0.'
)
] = 0.0
max_tokens: Annotated[int, 'Per example max tokens to generate.'] = 1024
n: Annotated[int | None, 'Max number of samples to return.'] = 1
top_k: Annotated[int | None, 'Top k tokens to sample the next token.'] = 40
top_k: Annotated[
int | None,
(
'Top k tokens to sample the next token. '
'Not applicable to OpenAI models.'
)
] = 40
top_p: Annotated[
float | None,
(
'Only sample the next token from top N tokens whose accumulated '
'probability // mass <= p. Not applicable to OpenAI models and '
'BigBard.'
'probability // mass <= p. For OpenAI models, set `temperature` or '
'`top_p` but not both.'
),
] = None
random_seed: Annotated[
Expand Down Expand Up @@ -116,7 +128,7 @@ class LanguageModel(component.Component):
(
'Sampling cache. If None, no cache will be used.'
)
] = None
] = component.contextual(default=None)

timeout: Annotated[
float | None, 'Timeout in seconds. If None, there is no timeout.'
Expand Down
1 change: 1 addition & 0 deletions langfun/core/llms/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from langfun.core.llms.cache.base import LMCacheEntry

from langfun.core.llms.cache.in_memory import InMemory
from langfun.core.llms.cache.in_memory import lm_cache


# pylint: enable=g-bad-import-order
Expand Down
29 changes: 29 additions & 0 deletions langfun/core/llms/cache/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
"""In-memory LM cache."""

import collections
import contextlib
from typing import Annotated, Any, Iterator
import langfun.core as lf
from langfun.core.llms.cache import base
import pyglove as pg

Expand Down Expand Up @@ -110,3 +112,30 @@ def save(self, path: str) -> None:
entries = [dict(k=k, v=v) for k, v in self.items(model_id)]
records.append(dict(model_id=model_id, entries=entries))
pg.save(records, path)


@contextlib.contextmanager
def lm_cache(
load: str | None = None,
save: str | None = None,
) -> Iterator[InMemory]:
"""Context manager to enable cache for LMs under the context.
If LMs under the context manager have explicitly specified cache, they will
use their own cache. Otherwise they will use the cache created by the context
manager.
Args:
load: If not None, JSON file to load the cache.
save: If not None, JSON file to save the cache when context manager exits.
Yields:
A cache object created.
"""
c = InMemory(load)
try:
with lf.context(cache=c):
yield c
finally:
if save:
c.save(save)
35 changes: 35 additions & 0 deletions langfun/core/llms/cache/in_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,40 @@ def test_save_load(self):
self.assertEqual(lm2('a'), 'a')


class UseCacheTest(unittest.TestCase):

def test_lm_cache(self):
with in_memory.lm_cache() as c:
lm = fake.Echo()
self.assertIs(lm.cache, c)
lm = fake.Echo(cache=in_memory.InMemory())
self.assertIsNot(lm.cache, c)

def test_lm_cache_load_save(self):
pg.set_load_handler(pg.symbolic.default_load_handler)
pg.set_save_handler(pg.symbolic.default_save_handler)

cache = in_memory.InMemory()
lm = fake.StaticSequence(['1', '2', '3'], cache=cache)
self.assertEqual(lm('a'), '1')
self.assertEqual(lm('b'), '2')

tmp_dir = tempfile.gettempdir()
path1 = os.path.join(tmp_dir, 'memory1.json')
cache.save(path1)

path2 = os.path.join(tmp_dir, 'memory2.json')

with in_memory.lm_cache(load=path1, save=path2) as c1:
self.assertEqual(len(c1), 2)

lm = fake.StaticSequence(['4', '5', '6'])
self.assertEqual(lm('a'), '1')
self.assertEqual(lm('b'), '2')

with in_memory.lm_cache(load=path2, save=path2) as c2:
self.assertEqual(len(c2), 2)


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion langfun/core/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _get_request_args(
args['model' if self.is_chat_model else 'engine'] = self.model

if options.top_p is not None:
args['top_p'] = options.top_k
args['top_p'] = options.top_p
return args

def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
Expand Down
33 changes: 33 additions & 0 deletions langfun/core/llms/openai_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,39 @@ def test_model_id(self):
self.assertEqual(
openai.Gpt35(api_key='test_key').model_id, 'OpenAI(text-davinci-003)')

def test_get_request_args(self):
self.assertEqual(
openai.Gpt35(api_key='test_key', timeout=90.0)._get_request_args(
lf.LMSamplingOptions(
temperature=2.0,
n=2,
max_tokens=4096,
top_p=1.0)),
dict(
engine='text-davinci-003',
n=2,
temperature=2.0,
max_tokens=4096,
stream=False,
timeout=90.0,
top_p=1.0,
)
)
self.assertEqual(
openai.Gpt4(api_key='test_key')._get_request_args(
lf.LMSamplingOptions(
temperature=1.0,
n=1)),
dict(
model='gpt-4',
n=1,
temperature=1.0,
max_tokens=1024,
stream=False,
timeout=120.0,
)
)

def test_call_completion(self):
with mock.patch('openai.Completion.create') as mock_completion:
mock_completion.side_effect = mock_completion_query
Expand Down
1 change: 1 addition & 0 deletions langfun/core/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def _call_fn(example):
for _, result, error in concurrent.concurrent_map(
_call_fn,
pg_sample_fn(sampling_space, num_examples=num_examples),
silence_on_errors=concurrent.RetryError,
max_workers=max_workers
):
if error is None:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ jinja2>=3.1.2
openai>=0.18.1
pyglove>=0.4.3
termcolor==1.1.0
tqdm>=4.64.1

0 comments on commit bda6cc6

Please sign in to comment.