diff --git a/langfun/__init__.py b/langfun/__init__.py
index 66dec9ac..fe86affb 100644
--- a/langfun/__init__.py
+++ b/langfun/__init__.py
@@ -35,6 +35,8 @@ PythonCode = coding.PythonCode
 
 from langfun.core import llms
 
+lm_cache = llms.cache.lm_cache
+
 from langfun.core import memories
 
 # Placeholder for Google-internal imports.
diff --git a/langfun/core/concurrent.py b/langfun/core/concurrent.py
index bc780ae1..62055adb 100644
--- a/langfun/core/concurrent.py
+++ b/langfun/core/concurrent.py
@@ -21,6 +21,7 @@ from langfun.core import component
 
 import pyglove as pg
+from tqdm import auto as tqdm
 
 
 def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]:
@@ -234,12 +235,14 @@ def concurrent_map(
     func: Callable[[Any], Any],
     parallel_inputs: Iterable[Any],
     *,
+    executor: concurrent.futures.ThreadPoolExecutor | None = None,
     max_workers: int = 32,
     ordered: bool = False,
+    show_progress: bool = False,
     timeout: int | None = None,
     silence_on_errors: Union[
         Type[Exception], Tuple[Type[Exception], ...], None
-    ] = RetryError,
+    ] = Exception,
     retry_on_errors: Union[
         Type[Exception],
         Tuple[Type[Exception], ...],
@@ -254,10 +257,13 @@ def concurrent_map(
   Args:
     func: A user function.
     parallel_inputs: The inputs for `func` which will be processed in parallel.
+    executor: Thread pool executor for executing work items. If None, a new
+      thread pool executor will be created for the current execution.
     max_workers: The max number of workers.
     ordered: If True, the returned iterator will emit (input, output, error)
       in the order of the elements in `parallel_inputs`. Otherwise, elements
       that are finished earlier will be delivered first.
+    show_progress: If True, show progress on the console.
     timeout: The timeout in seconds for processing each input. It is the total
       processing time for each input, even multiple retries take place. If
       None, there is no timeout.
@@ -298,11 +304,13 @@ def remaining_time():
       return None
     return time.time() - start_time
 
-  with concurrent.futures.ThreadPoolExecutor(
-      max_workers=max_workers
-  ) as executor:
+  executor = executor or concurrent.futures.ThreadPoolExecutor(
+      max_workers=max_workers)
+
+  with executor:
     future_to_job = {}
     pending_futures = []
+    total = 0
     for inputs in parallel_inputs:
       job = Job(func, inputs)
       future = executor.submit(
@@ -310,17 +318,33 @@ def remaining_time():
       )
       pending_futures.append(future)
       future_to_job[future] = job
+      total += 1
+
+    progress = tqdm.tqdm(total=total) if show_progress else None
+    def update_progress(success: int, failure: int) -> None:
+      if progress is not None:
+        completed = success + failure
+        progress.update(1)
+        progress.set_description(
+            'Success: %.2f%% (%d/%d), Failure: %.2f%% (%d/%d)'
+            % (success * 100.0 / completed, success, completed,
+               failure * 100.0 / completed, failure, completed))
 
     remaining_futures = []
+    success, failure = 0, 0
     if ordered:
       for i, future in enumerate(pending_futures):
         try:
           _ = future.result(timeout=remaining_time())
           job = future_to_job[future]
-          if job.error is not None and not (
-              silence_on_errors and isinstance(job.error, silence_on_errors)
-          ):
-            raise job.error
+          if job.error is not None:
+            failure += 1
+            if not (
+                silence_on_errors and isinstance(job.error, silence_on_errors)):
+              raise job.error
+          else:
+            success += 1
+          update_progress(success, failure)
           del future_to_job[future]
           yield job.arg, job.result, job.error
         except concurrent.futures.TimeoutError:
@@ -332,10 +356,15 @@ def remaining_time():
         ):
           job = future_to_job[future]
           del future_to_job[future]
-          if job.error is not None and not (
-              silence_on_errors and isinstance(job.error, silence_on_errors)
-          ):
-            raise job.error  # pylint: disable=g-doc-exception
+          if job.error is not None:
+            failure += 1
+            if not (
+                silence_on_errors and isinstance(job.error, silence_on_errors)
+            ):
+              raise job.error  # pylint: disable=g-doc-exception
+          else:
+            success += 1
+          update_progress(success, failure)
           yield job.arg, job.result, job.error
 
       remaining_futures = future_to_job
@@ -346,3 +375,6 @@ def remaining_time():
       future.cancel()
       job.error = TimeoutError(f'Execution time exceeds {timeout} seconds.')
       yield job.arg, job.result, job.error
+
+  if progress is not None:
+    progress.close()
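A minimal usage sketch of the new `concurrent_map` parameters, assuming this patch is applied (`square` and the pool size are illustrative). Note that the default for `silence_on_errors` widens from `RetryError` to `Exception`, so errors raised by `func` are now yielded rather than re-raised unless a narrower type (or `None`) is passed:

```python
import concurrent.futures
from langfun.core import concurrent as lf_concurrent

def square(x):
  if x == 2:
    raise ValueError('boom')
  return x ** 2

# An externally provided executor is used as-is; `show_progress` drives a
# tqdm progress bar with success/failure stats.
pool = concurrent.futures.ThreadPoolExecutor(max_workers=8)
for arg, result, error in lf_concurrent.concurrent_map(
    square, range(10), executor=pool, ordered=True, show_progress=True):
  print(arg, result, error)  # Yields (2, None, ValueError(...)) instead of raising.
```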
diff --git a/langfun/core/concurrent_test.py b/langfun/core/concurrent_test.py
index 5722fc92..f7301753 100644
--- a/langfun/core/concurrent_test.py
+++ b/langfun/core/concurrent_test.py
@@ -173,7 +173,7 @@ def fun(x):
       return x**2
 
     with component.context(y=2):
-      it = concurrent.concurrent_map(fun, [1, 2, 3])
+      it = concurrent.concurrent_map(fun, [1, 2, 3], silence_on_errors=KeyError)
       self.assertEqual(next(it), (1, 1, None))
       with self.assertRaises(ValueError):
         _ = next(it)
@@ -276,7 +276,8 @@ def fun(x):
       return x**2
 
     with component.context(y=2):
-      it = concurrent.concurrent_map(fun, [1, 2, 3], ordered=True)
+      it = concurrent.concurrent_map(
+          fun, [1, 2, 3], ordered=True, silence_on_errors=KeyError)
       self.assertEqual(next(it)[1], 1)
 
       with self.assertRaises(ValueError):
@@ -302,6 +303,24 @@ def fun(x):
         ],
     )
 
+  def test_concurrent_map_with_showing_progress(self):
+    def fun(x):
+      return x
+
+    self.assertEqual(
+        [
+            (i, o)
+            for i, o, _ in concurrent.concurrent_map(
+                fun, [1, 2, 3], ordered=True, show_progress=True
+            )
+        ],
+        [
+            (1, 1),
+            (2, 2),
+            (3, 3),
+        ],
+    )
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/langfun/core/langfunc_test.py b/langfun/core/langfunc_test.py
index 82c0cc8b..dfcd4d2a 100644
--- a/langfun/core/langfunc_test.py
+++ b/langfun/core/langfunc_test.py
@@ -94,7 +94,8 @@ def test_call(self):
         "LangFunc(template_str='Hello', clean=True, returns=None, "
         'lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0, '
         'max_tokens=1024, n=1, top_k=40, top_p=None, random_seed=None), '
-        'cache=None, timeout=120.0, max_attempts=5, debug=False), '
+        'cache=ContextualAttribute(type=None, default=None), timeout=120.0, '
+        'max_attempts=5, debug=False), '
         'input_transform=None, output_transform=None)',
     )
 
diff --git a/langfun/core/language_model.py b/langfun/core/language_model.py
index ee1448ad..62dde882 100644
--- a/langfun/core/language_model.py
+++ b/langfun/core/language_model.py
@@ -54,16 +54,28 @@ class LMSamplingResult(pg.Object):
 class LMSamplingOptions(component.Component):
   """Language model sampling options."""
 
-  temperature: Annotated[float, 'Model temperature between [0, 1.0].'] = 0.0
+  temperature: Annotated[
+      float,
+      (
+          'Model temperature, which is usually between 0 and 1.0. '
+          'OpenAI models have a temperature range from 0.0 to 2.0.'
+      )
+  ] = 0.0
   max_tokens: Annotated[int, 'Per example max tokens to generate.'] = 1024
   n: Annotated[int | None, 'Max number of samples to return.'] = 1
-  top_k: Annotated[int | None, 'Top k tokens to sample the next token.'] = 40
+  top_k: Annotated[
+      int | None,
+      (
+          'Top k tokens to sample the next token. '
+          'Not applicable to OpenAI models.'
+      )
+  ] = 40
   top_p: Annotated[
       float | None,
       (
           'Only sample the next token from top N tokens whose accumulated '
-          'probability // mass <= p. Not applicable to OpenAI models and '
-          'BigBard.'
+          'probability // mass <= p. For OpenAI models, set `temperature` or '
+          '`top_p` but not both.'
      ),
   ] = None
   random_seed: Annotated[
@@ -116,7 +128,7 @@ class LanguageModel(component.Component):
       (
           'Sampling cache. If None, no cache will be used.'
       )
-  ] = None
+  ] = component.contextual(default=None)
 
   timeout: Annotated[
       float | None, 'Timeout in seconds. If None, there is no timeout.'
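Because `cache` on `LanguageModel` is now a contextual attribute, an LM that does not specify its own cache can inherit one from an enclosing context. A short sketch mirroring the new tests, assuming this patch:

```python
import langfun.core as lf
from langfun.core.llms import fake
from langfun.core.llms.cache import in_memory

lm = fake.Echo()  # No explicit cache specified.
with lf.context(cache=in_memory.InMemory()):
  print(lm.cache)  # Resolves to the InMemory cache supplied by the context.
```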
diff --git a/langfun/core/llms/cache/__init__.py b/langfun/core/llms/cache/__init__.py
index b5f4717a..c5470e99 100644
--- a/langfun/core/llms/cache/__init__.py
+++ b/langfun/core/llms/cache/__init__.py
@@ -20,6 +20,7 @@ from langfun.core.llms.cache.base import LMCacheEntry
 
 from langfun.core.llms.cache.in_memory import InMemory
+from langfun.core.llms.cache.in_memory import lm_cache
 
 # pylint: enable=g-bad-import-order
 
diff --git a/langfun/core/llms/cache/in_memory.py b/langfun/core/llms/cache/in_memory.py
index a2bded3c..88871a99 100644
--- a/langfun/core/llms/cache/in_memory.py
+++ b/langfun/core/llms/cache/in_memory.py
@@ -14,7 +14,9 @@
 """In-memory LM cache."""
 
 import collections
+import contextlib
 from typing import Annotated, Any, Iterator
+import langfun.core as lf
 from langfun.core.llms.cache import base
 import pyglove as pg
 
@@ -110,3 +112,30 @@ def save(self, path: str) -> None:
       entries = [dict(k=k, v=v) for k, v in self.items(model_id)]
       records.append(dict(model_id=model_id, entries=entries))
     pg.save(records, path)
+
+
+@contextlib.contextmanager
+def lm_cache(
+    load: str | None = None,
+    save: str | None = None,
+) -> Iterator[InMemory]:
+  """Context manager to enable caching for LMs under the context.
+
+  If LMs under the context manager have an explicitly specified cache, they
+  will use their own cache. Otherwise they will use the cache created by the
+  context manager.
+
+  Args:
+    load: If not None, JSON file to load the cache from.
+    save: If not None, JSON file to save the cache to when the context manager
+      exits.
+
+  Yields:
+    The created cache object.
+  """
+  c = InMemory(load)
+  try:
+    with lf.context(cache=c):
+      yield c
+  finally:
+    if save:
+      c.save(save)
diff --git a/langfun/core/llms/cache/in_memory_test.py b/langfun/core/llms/cache/in_memory_test.py
index 6aa6ecad..e2a2967e 100644
--- a/langfun/core/llms/cache/in_memory_test.py
+++ b/langfun/core/llms/cache/in_memory_test.py
@@ -189,5 +189,40 @@ def test_save_load(self):
     self.assertEqual(lm2('a'), 'a')
 
 
+class UseCacheTest(unittest.TestCase):
+
+  def test_lm_cache(self):
+    with in_memory.lm_cache() as c:
+      lm = fake.Echo()
+      self.assertIs(lm.cache, c)
+      lm = fake.Echo(cache=in_memory.InMemory())
+      self.assertIsNot(lm.cache, c)
+
+  def test_lm_cache_load_save(self):
+    pg.set_load_handler(pg.symbolic.default_load_handler)
+    pg.set_save_handler(pg.symbolic.default_save_handler)
+
+    cache = in_memory.InMemory()
+    lm = fake.StaticSequence(['1', '2', '3'], cache=cache)
+    self.assertEqual(lm('a'), '1')
+    self.assertEqual(lm('b'), '2')
+
+    tmp_dir = tempfile.gettempdir()
+    path1 = os.path.join(tmp_dir, 'memory1.json')
+    cache.save(path1)
+
+    path2 = os.path.join(tmp_dir, 'memory2.json')
+
+    with in_memory.lm_cache(load=path1, save=path2) as c1:
+      self.assertEqual(len(c1), 2)
+
+      lm = fake.StaticSequence(['4', '5', '6'])
+      self.assertEqual(lm('a'), '1')
+      self.assertEqual(lm('b'), '2')
+
+    with in_memory.lm_cache(load=path2, save=path2) as c2:
+      self.assertEqual(len(c2), 2)
+
+
 if __name__ == '__main__':
   unittest.main()
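A hypothetical round trip with the new `lm_cache` context manager, via the top-level `lf.lm_cache` alias added in `langfun/__init__.py` (the file path and fake LM are illustrative, and the pyglove save/load handlers may need configuring as in the tests above):

```python
import langfun as lf
from langfun.core.llms import fake

# First run: responses are recorded in the cache and saved on exit.
with lf.lm_cache(save='/tmp/lm_cache.json') as cache:
  lm = fake.StaticSequence(['response-1'])
  print(lm('hello'))  # 'response-1', recorded in `cache`.

# Later run: replay from the saved file; the static responses go unused.
with lf.lm_cache(load='/tmp/lm_cache.json'):
  lm = fake.StaticSequence(['unused'])
  print(lm('hello'))  # 'response-1', served from the loaded cache.
```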
diff --git a/langfun/core/llms/openai.py b/langfun/core/llms/openai.py
index 47cb10d8..d9b04c0d 100644
--- a/langfun/core/llms/openai.py
+++ b/langfun/core/llms/openai.py
@@ -122,7 +122,7 @@ def _get_request_args(
     args['model' if self.is_chat_model else 'engine'] = self.model
 
     if options.top_p is not None:
-      args['top_p'] = options.top_k
+      args['top_p'] = options.top_p
     return args
 
   def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
diff --git a/langfun/core/llms/openai_test.py b/langfun/core/llms/openai_test.py
index cbcac703..479fe6fa 100644
--- a/langfun/core/llms/openai_test.py
+++ b/langfun/core/llms/openai_test.py
@@ -61,6 +61,39 @@ def test_model_id(self):
     self.assertEqual(
         openai.Gpt35(api_key='test_key').model_id, 'OpenAI(text-davinci-003)')
 
+  def test_get_request_args(self):
+    self.assertEqual(
+        openai.Gpt35(api_key='test_key', timeout=90.0)._get_request_args(
+            lf.LMSamplingOptions(
+                temperature=2.0,
+                n=2,
+                max_tokens=4096,
+                top_p=1.0)),
+        dict(
+            engine='text-davinci-003',
+            n=2,
+            temperature=2.0,
+            max_tokens=4096,
+            stream=False,
+            timeout=90.0,
+            top_p=1.0,
+        )
+    )
+    self.assertEqual(
+        openai.Gpt4(api_key='test_key')._get_request_args(
+            lf.LMSamplingOptions(
+                temperature=1.0,
+                n=1)),
+        dict(
+            model='gpt-4',
+            n=1,
+            temperature=1.0,
+            max_tokens=1024,
+            stream=False,
+            timeout=120.0,
+        )
+    )
+
   def test_call_completion(self):
     with mock.patch('openai.Completion.create') as mock_completion:
       mock_completion.side_effect = mock_completion_query
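The one-line `openai.py` change fixes a copy-paste bug where `options.top_k` was sent as the `top_p` request argument. A quick check in the spirit of the new test, assuming this patch (`_get_request_args` is the private helper exercised above):

```python
import langfun.core as lf
from langfun.core.llms import openai

args = openai.Gpt35(api_key='test_key')._get_request_args(
    lf.LMSamplingOptions(top_p=0.9))
assert args['top_p'] == 0.9  # Before the fix this was 40, the top_k default.
```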
diff --git a/langfun/core/sampling.py b/langfun/core/sampling.py
index a9919847..2f0c026e 100644
--- a/langfun/core/sampling.py
+++ b/langfun/core/sampling.py
@@ -176,6 +176,7 @@ def _call_fn(example):
   for _, result, error in concurrent.concurrent_map(
       _call_fn,
       pg_sample_fn(sampling_space, num_examples=num_examples),
+      silence_on_errors=concurrent.RetryError,
       max_workers=max_workers
   ):
     if error is None:
diff --git a/requirements.txt b/requirements.txt
index e61adf17..a37358bd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,4 @@ jinja2>=3.1.2
 openai>=0.18.1
 pyglove>=0.4.3
 termcolor==1.1.0
+tqdm>=4.64.1
\ No newline at end of file