From c889c4be978b3a898870a2e9a9a65a9349425593 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Sat, 23 Sep 2023 12:13:01 -0700 Subject: [PATCH] Introducing `lf.use_cache` and enhancing `lf.concurrent_map`. - Introduces context manager `lf.use_cache` for enabling cache with automatic saving. - Enhances `lf.concurrent_map` - use `Exception` as the default value for argument `silence_on_errors`. - support argument `show_progress`, which will display progresses under console/colab notebook based on tqtm. - Fixed an issue in assigning `top_p` for OpenAI models. PiperOrigin-RevId: 567890684 --- langfun/__init__.py | 2 + langfun/core/concurrent.py | 56 ++++++++++++++++++----- langfun/core/concurrent_test.py | 23 +++++++++- langfun/core/langfunc_test.py | 3 +- langfun/core/language_model.py | 22 +++++++-- langfun/core/llms/cache/__init__.py | 1 + langfun/core/llms/cache/in_memory.py | 29 ++++++++++++ langfun/core/llms/cache/in_memory_test.py | 35 ++++++++++++++ langfun/core/llms/openai.py | 2 +- langfun/core/sampling.py | 1 + requirements.txt | 1 + 11 files changed, 154 insertions(+), 21 deletions(-) diff --git a/langfun/__init__.py b/langfun/__init__.py index 66dec9ac..ce2c4178 100644 --- a/langfun/__init__.py +++ b/langfun/__init__.py @@ -35,6 +35,8 @@ PythonCode = coding.PythonCode from langfun.core import llms +use_cache = llms.cache.use_cache + from langfun.core import memories # Placeholder for Google-internal imports. diff --git a/langfun/core/concurrent.py b/langfun/core/concurrent.py index bc780ae1..62055adb 100644 --- a/langfun/core/concurrent.py +++ b/langfun/core/concurrent.py @@ -21,6 +21,7 @@ from langfun.core import component import pyglove as pg +from tqdm import auto as tqdm def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]: @@ -234,12 +235,14 @@ def concurrent_map( func: Callable[[Any], Any], parallel_inputs: Iterable[Any], *, + executor: concurrent.futures.ThreadPoolExecutor | None = None, max_workers: int = 32, ordered: bool = False, + show_progress: bool = False, timeout: int | None = None, silence_on_errors: Union[ Type[Exception], Tuple[Type[Exception], ...], None - ] = RetryError, + ] = Exception, retry_on_errors: Union[ Type[Exception], Tuple[Type[Exception], ...], @@ -254,10 +257,13 @@ def concurrent_map( Args: func: A user function. parallel_inputs: The inputs for `func` which will be processed in parallel. + executor: Thread pool exeutor to pool work items. If None, a new thread pool + executor will be created for current execution. max_workers: The max number of workers. ordered: If True, the returned iterator will emit (input, output, error) in the order of the elements in `parallel_inputs`. Otherwise, elements that are finished earlier will be delivered first. + show_progress: If True, show progress on console. timeout: The timeout in seconds for processing each input. It is the total processing time for each input, even multiple retries take place. If None, there is no timeout. @@ -298,11 +304,13 @@ def remaining_time(): return None return time.time() - start_time - with concurrent.futures.ThreadPoolExecutor( - max_workers=max_workers - ) as executor: + executor = executor or concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) + + with executor: future_to_job = {} pending_futures = [] + total = 0 for inputs in parallel_inputs: job = Job(func, inputs) future = executor.submit( @@ -310,17 +318,33 @@ def remaining_time(): ) pending_futures.append(future) future_to_job[future] = job + total += 1 + + progress = tqdm.tqdm(total=total) if show_progress else None + def update_progress(success: int, failure: int) -> None: + if progress is not None: + completed = success + failure + progress.update(1) + progress.set_description( + 'Success: %.2f%% (%d/%d), Failure: %.2f%% (%d/%d)' + % (success * 100.0 / completed, success, completed, + failure * 100.0 / completed, failure, completed)) remaining_futures = [] + success, failure = 0, 0 if ordered: for i, future in enumerate(pending_futures): try: _ = future.result(timeout=remaining_time()) job = future_to_job[future] - if job.error is not None and not ( - silence_on_errors and isinstance(job.error, silence_on_errors) - ): - raise job.error + if job.error is not None: + failure += 1 + if not ( + silence_on_errors and isinstance(job.error, silence_on_errors)): + raise job.error + else: + success += 1 + update_progress(success, failure) del future_to_job[future] yield job.arg, job.result, job.error except concurrent.futures.TimeoutError: @@ -332,10 +356,15 @@ def remaining_time(): ): job = future_to_job[future] del future_to_job[future] - if job.error is not None and not ( - silence_on_errors and isinstance(job.error, silence_on_errors) - ): - raise job.error # pylint: disable=g-doc-exception + if job.error is not None: + failure += 1 + if not ( + silence_on_errors and isinstance(job.error, silence_on_errors) + ): + raise job.error # pylint: disable=g-doc-exception + else: + success += 1 + update_progress(success, failure) yield job.arg, job.result, job.error remaining_futures = future_to_job @@ -346,3 +375,6 @@ def remaining_time(): future.cancel() job.error = TimeoutError(f'Execution time exceeds {timeout} seconds.') yield job.arg, job.result, job.error + + if progress is not None: + progress.close() diff --git a/langfun/core/concurrent_test.py b/langfun/core/concurrent_test.py index 5722fc92..f7301753 100644 --- a/langfun/core/concurrent_test.py +++ b/langfun/core/concurrent_test.py @@ -173,7 +173,7 @@ def fun(x): return x**2 with component.context(y=2): - it = concurrent.concurrent_map(fun, [1, 2, 3]) + it = concurrent.concurrent_map(fun, [1, 2, 3], silence_on_errors=KeyError) self.assertEqual(next(it), (1, 1, None)) with self.assertRaises(ValueError): _ = next(it) @@ -276,7 +276,8 @@ def fun(x): return x**2 with component.context(y=2): - it = concurrent.concurrent_map(fun, [1, 2, 3], ordered=True) + it = concurrent.concurrent_map( + fun, [1, 2, 3], ordered=True, silence_on_errors=KeyError) self.assertEqual(next(it)[1], 1) with self.assertRaises(ValueError): @@ -302,6 +303,24 @@ def fun(x): ], ) + def test_concurrent_map_with_showing_progress(self): + def fun(x): + return x + + self.assertEqual( + [ + (i, o) + for i, o, _ in concurrent.concurrent_map( + fun, [1, 2, 3], ordered=True, show_progress=True + ) + ], + [ + (1, 1), + (2, 2), + (3, 3), + ], + ) + if __name__ == '__main__': unittest.main() diff --git a/langfun/core/langfunc_test.py b/langfun/core/langfunc_test.py index 82c0cc8b..dfcd4d2a 100644 --- a/langfun/core/langfunc_test.py +++ b/langfun/core/langfunc_test.py @@ -94,7 +94,8 @@ def test_call(self): "LangFunc(template_str='Hello', clean=True, returns=None, " 'lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0, ' 'max_tokens=1024, n=1, top_k=40, top_p=None, random_seed=None), ' - 'cache=None, timeout=120.0, max_attempts=5, debug=False), ' + 'cache=ContextualAttribute(type=None, default=None), timeout=120.0, ' + 'max_attempts=5, debug=False), ' 'input_transform=None, output_transform=None)', ) diff --git a/langfun/core/language_model.py b/langfun/core/language_model.py index ee1448ad..62dde882 100644 --- a/langfun/core/language_model.py +++ b/langfun/core/language_model.py @@ -54,16 +54,28 @@ class LMSamplingResult(pg.Object): class LMSamplingOptions(component.Component): """Language model sampling options.""" - temperature: Annotated[float, 'Model temperature between [0, 1.0].'] = 0.0 + temperature: Annotated[ + float, + ( + 'Model temperature, which is usually between 0 and 1.0. ' + 'OpenAI models have temperature range from 0.0 to 2.0.' + ) + ] = 0.0 max_tokens: Annotated[int, 'Per example max tokens to generate.'] = 1024 n: Annotated[int | None, 'Max number of samples to return.'] = 1 - top_k: Annotated[int | None, 'Top k tokens to sample the next token.'] = 40 + top_k: Annotated[ + int | None, + ( + 'Top k tokens to sample the next token. ' + 'Not applicable to OpenAI models.' + ) + ] = 40 top_p: Annotated[ float | None, ( 'Only sample the next token from top N tokens whose accumulated ' - 'probability // mass <= p. Not applicable to OpenAI models and ' - 'BigBard.' + 'probability // mass <= p. For OpenAI models, set `temperature` or ' + '`top_p` but not both.' ), ] = None random_seed: Annotated[ @@ -116,7 +128,7 @@ class LanguageModel(component.Component): ( 'Sampling cache. If None, no cache will be used.' ) - ] = None + ] = component.contextual(default=None) timeout: Annotated[ float | None, 'Timeout in seconds. If None, there is no timeout.' diff --git a/langfun/core/llms/cache/__init__.py b/langfun/core/llms/cache/__init__.py index b5f4717a..27ad0d37 100644 --- a/langfun/core/llms/cache/__init__.py +++ b/langfun/core/llms/cache/__init__.py @@ -20,6 +20,7 @@ from langfun.core.llms.cache.base import LMCacheEntry from langfun.core.llms.cache.in_memory import InMemory +from langfun.core.llms.cache.in_memory import use_cache # pylint: enable=g-bad-import-order diff --git a/langfun/core/llms/cache/in_memory.py b/langfun/core/llms/cache/in_memory.py index a2bded3c..941a5b15 100644 --- a/langfun/core/llms/cache/in_memory.py +++ b/langfun/core/llms/cache/in_memory.py @@ -14,7 +14,9 @@ """In-memory LM cache.""" import collections +import contextlib from typing import Annotated, Any, Iterator +import langfun.core as lf from langfun.core.llms.cache import base import pyglove as pg @@ -110,3 +112,30 @@ def save(self, path: str) -> None: entries = [dict(k=k, v=v) for k, v in self.items(model_id)] records.append(dict(model_id=model_id, entries=entries)) pg.save(records, path) + + +@contextlib.contextmanager +def use_cache( + load: str | None = None, + save: str | None = None, +) -> Iterator[InMemory]: + """Context manager to enable cache for LMs under the context. + + If LMs under the context manager have explicitly specified cache, they will + use their own cache. Otherwise they will use the cache created by the context + manager. + + Args: + load: If not None, JSON file to load the cache. + save: If not None, JSON file to save the cache when context manager exits. + + Yields: + A cache object created. + """ + c = InMemory(load) + try: + with lf.context(cache=c): + yield c + finally: + if save: + c.save(save) diff --git a/langfun/core/llms/cache/in_memory_test.py b/langfun/core/llms/cache/in_memory_test.py index 6aa6ecad..47efc386 100644 --- a/langfun/core/llms/cache/in_memory_test.py +++ b/langfun/core/llms/cache/in_memory_test.py @@ -189,5 +189,40 @@ def test_save_load(self): self.assertEqual(lm2('a'), 'a') +class UseCacheTest(unittest.TestCase): + + def test_use_cache(self): + with in_memory.use_cache() as c: + lm = fake.Echo() + self.assertIs(lm.cache, c) + lm = fake.Echo(cache=in_memory.InMemory()) + self.assertIsNot(lm.cache, c) + + def test_use_cache_load_save(self): + pg.set_load_handler(pg.symbolic.default_load_handler) + pg.set_save_handler(pg.symbolic.default_save_handler) + + cache = in_memory.InMemory() + lm = fake.StaticSequence(['1', '2', '3'], cache=cache) + self.assertEqual(lm('a'), '1') + self.assertEqual(lm('b'), '2') + + tmp_dir = tempfile.gettempdir() + path1 = os.path.join(tmp_dir, 'memory1.json') + cache.save(path1) + + path2 = os.path.join(tmp_dir, 'memory2.json') + + with in_memory.use_cache(load=path1, save=path2) as c1: + self.assertEqual(len(c1), 2) + + lm = fake.StaticSequence(['4', '5', '6']) + self.assertEqual(lm('a'), '1') + self.assertEqual(lm('b'), '2') + + with in_memory.use_cache(load=path2, save=path2) as c2: + self.assertEqual(len(c2), 2) + + if __name__ == '__main__': unittest.main() diff --git a/langfun/core/llms/openai.py b/langfun/core/llms/openai.py index 47cb10d8..d9b04c0d 100644 --- a/langfun/core/llms/openai.py +++ b/langfun/core/llms/openai.py @@ -122,7 +122,7 @@ def _get_request_args( args['model' if self.is_chat_model else 'engine'] = self.model if options.top_p is not None: - args['top_p'] = options.top_k + args['top_p'] = options.top_p return args def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]: diff --git a/langfun/core/sampling.py b/langfun/core/sampling.py index a9919847..2f0c026e 100644 --- a/langfun/core/sampling.py +++ b/langfun/core/sampling.py @@ -176,6 +176,7 @@ def _call_fn(example): for _, result, error in concurrent.concurrent_map( _call_fn, pg_sample_fn(sampling_space, num_examples=num_examples), + silence_on_errors=concurrent.RetryError, max_workers=max_workers ): if error is None: diff --git a/requirements.txt b/requirements.txt index e61adf17..a37358bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ jinja2>=3.1.2 openai>=0.18.1 pyglove>=0.4.3 termcolor==1.1.0 +tqdm>=4.64.1 \ No newline at end of file