From 120866b2dae2e0e1b5a49ba4fff6d5bb728df0ff Mon Sep 17 00:00:00 2001
From: Daiyi Peng
Date: Mon, 2 Dec 2024 15:34:17 -0800
Subject: [PATCH] Migrate OpenAI models to use REST.

This ensures forward compatibility with OpenAI APIs and allows langfun to be
decoupled from the increasingly heavy OpenAI SDK.

PiperOrigin-RevId: 702105210
---
 README.md                        |   1 -
 langfun/core/llms/openai.py      | 349 ++++++++++++----------------
 langfun/core/llms/openai_test.py | 384 +++++++++++++------------------
 requirements.txt                 |   2 -
 4 files changed, 302 insertions(+), 434 deletions(-)

diff --git a/README.md b/README.md
index 38b102d..5da44a9 100644
--- a/README.md
+++ b/README.md
@@ -142,7 +142,6 @@ If you want to customize your installation, you can select specific features usi
 | llm-google          | All supported Google-powered LLMs.       |
 | llm-google-vertexai | LLMs powered by Google Cloud VertexAI    |
 | llm-google-genai    | LLMs powered by Google Generative AI API |
-| llm-openai          | LLMs powered by OpenAI                   |
 | mime                | All MIME supports.                       |
 | mime-auto           | Automatic MIME type detection.           |
 | mime-docx           | DocX format support.                     |
diff --git a/langfun/core/llms/openai.py b/langfun/core/llms/openai.py
index d12ce4e..03317e5 100644
--- a/langfun/core/llms/openai.py
+++ b/langfun/core/llms/openai.py
@@ -13,34 +13,14 @@
 # limitations under the License.
 
 """Language models from OpenAI."""
 
-import collections
-import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
 
-try:
-  import openai  # pylint: disable=g-import-not-at-top
-
-  if hasattr(openai, 'error'):
-    # For lower versions.
-    ServiceUnavailableError = openai.error.ServiceUnavailableError
-    RateLimitError = openai.error.RateLimitError
-    APITimeoutError = (
-        openai.error.APIError,
-        '.*The server had an error processing your request'
-    )
-  else:
-    # For higher versions.
-    ServiceUnavailableError = getattr(openai, 'InternalServerError')
-    RateLimitError = getattr(openai, 'RateLimitError')
-    APITimeoutError = getattr(openai, 'APITimeoutError')
-except ImportError:
-  openai = None
-
 # From https://platform.openai.com/settings/organization/limits
 _DEFAULT_TPM = 250000
@@ -289,7 +269,7 @@
         rpm=_DEFAULT_RPM,
         tpm=_DEFAULT_TPM
     ),
-    # GPT-3 instruction-tuned models
+    # GPT-3 instruction-tuned models (Deprecated)
     'text-curie-001': pg.Dict(
         in_service=False,
         rpm=_DEFAULT_RPM,
@@ -325,9 +305,9 @@
         rpm=_DEFAULT_RPM,
         tpm=_DEFAULT_TPM
     ),
-    # GPT-3 base models
+    # GPT-3 base models that are still in service.
     'babbage-002': pg.Dict(
-        in_service=False,
+        in_service=True,
         rpm=_DEFAULT_RPM,
         tpm=_DEFAULT_TPM
     ),
@@ -340,7 +320,7 @@
 
 
 @lf.use_init_args(['model'])
-class OpenAI(lf.LanguageModel):
+class OpenAI(rest.REST):
   """OpenAI model."""
 
   model: pg.typing.Annotated[
@@ -348,7 +328,9 @@ class OpenAI(lf.LanguageModel):
           pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
       ),
       'The name of the model to use.',
-  ] = 'gpt-3.5-turbo'
+  ]
+
+  api_endpoint: str = 'https://api.openai.com/v1/chat/completions'
 
   multimodal: Annotated[
       bool,
@@ -372,27 +354,45 @@ class OpenAI(lf.LanguageModel):
       ),
   ] = None
 
+  project: Annotated[
+      str | None,
+      (
+          'Project. If None, the project will be read from environment '
+          "variable 'OPENAI_PROJECT'. Based on the value, usages from "
+          "these API requests will count against the project's quota. 
" + ), + ] = None + def _on_bound(self): super()._on_bound() - self.__dict__.pop('_api_initialized', None) - if openai is None: - raise RuntimeError( - 'Please install "langfun[llm-openai]" to use OpenAI models.' - ) + self._api_key = None + self._organization = None + self._project = None - @functools.cached_property - def _api_initialized(self): + def _initialize(self): api_key = self.api_key or os.environ.get('OPENAI_API_KEY', None) if not api_key: raise ValueError( 'Please specify `api_key` during `__init__` or set environment ' 'variable `OPENAI_API_KEY` with your OpenAI API key.' ) - openai.api_key = api_key - org = self.organization or os.environ.get('OPENAI_ORGANIZATION', None) - if org: - openai.organization = org - return True + self._api_key = api_key + self._organization = self.organization or os.environ.get( + 'OPENAI_ORGANIZATION', None + ) + self._project = self.project or os.environ.get('OPENAI_PROJECT', None) + + @property + def headers(self) -> dict[str, Any]: + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self._api_key}', + } + if self._organization: + headers['OpenAI-Organization'] = self._organization + if self._project: + headers['OpenAI-Project'] = self._project + return headers @property def model_id(self) -> str: @@ -428,23 +428,16 @@ def estimate_cost( @classmethod def dir(cls): - assert openai is not None - return openai.Model.list() + return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service] - @property - def is_chat_model(self): - """Returns True if the model is a chat model.""" - return self.model.startswith(('o1', 'gpt-4', 'gpt-3.5-turbo')) - - def _get_request_args( + def _request_args( self, options: lf.LMSamplingOptions) -> dict[str, Any]: # Reference: # https://platform.openai.com/docs/api-reference/completions/create # NOTE(daiyip): options.top_k is not applicable. args = dict( + model=self.model, n=options.n, - stream=False, - timeout=self.timeout, top_logprobs=options.top_logprobs, ) if options.logprobs: @@ -453,13 +446,10 @@ def _get_request_args( raise RuntimeError('`logprobs` is not supported on {self.model!r}.') args['logprobs'] = options.logprobs - # Completion and ChatCompletion uses different parameter name for model. - args['model' if self.is_chat_model else 'engine'] = self.model - if options.temperature is not None: args['temperature'] = options.temperature if options.max_tokens is not None: - args['max_tokens'] = options.max_tokens + args['max_completion_tokens'] = options.max_tokens if options.top_p is not None: args['top_p'] = options.top_p if options.stop: @@ -468,168 +458,113 @@ def _get_request_args( args['seed'] = options.random_seed return args - def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]: - assert self._api_initialized - if self.is_chat_model: - return self._chat_complete_batch(prompts) - else: - return self._complete_batch(prompts) - - def _complete_batch( - self, prompts: list[lf.Message] - ) -> list[lf.LMSamplingResult]: - - def _open_ai_completion(prompts): - assert openai is not None - response = openai.Completion.create( - prompt=[p.text for p in prompts], - **self._get_request_args(self.sampling_options), - ) - # Parse response. 
- samples_by_index = collections.defaultdict(list) - for choice in response.choices: - samples_by_index[choice.index].append( - lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0) - ) - - n = len(samples_by_index) - estimated_cost = self.estimate_cost( - num_input_tokens=response.usage.prompt_tokens, - num_output_tokens=response.usage.completion_tokens, - ) - usage = lf.LMSamplingUsage( - prompt_tokens=response.usage.prompt_tokens // n, - completion_tokens=response.usage.completion_tokens // n, - total_tokens=response.usage.total_tokens // n, - estimated_cost=( - None if estimated_cost is None else (estimated_cost // n) - ) - ) - return [ - lf.LMSamplingResult(samples_by_index[index], usage=usage) - for index in sorted(samples_by_index.keys()) - ] - - return self._parallel_execute_with_currency_control( - _open_ai_completion, - [prompts], - retry_on_errors=( - ServiceUnavailableError, - RateLimitError, - APITimeoutError, - ), - )[0] - - def _chat_complete_batch( - self, prompts: list[lf.Message] - ) -> list[lf.LMSamplingResult]: - def _content_from_message(message: lf.Message): - if self.multimodal: - content = [] - for chunk in message.chunk(): - if isinstance(chunk, str): - item = dict(type='text', text=chunk) - elif isinstance(chunk, lf_modalities.Image): - if chunk.uri and chunk.uri.lower().startswith( - ('http:', 'https:', 'ftp:') - ): - uri = chunk.uri - else: - uri = chunk.content_uri - item = dict(type='image_url', image_url=dict(url=uri)) - else: - raise ValueError(f'Unsupported modality object: {chunk!r}.') - content.append(item) + def _content_from_message(self, message: lf.Message): + """Returns a OpenAI content object from a Langfun message.""" + def _uri_from(chunk: lf.Modality) -> str: + if chunk.uri and chunk.uri.lower().startswith( + ('http:', 'https:', 'ftp:') + ): + return chunk.uri + return chunk.content_uri + + content = [] + for chunk in message.chunk(): + if isinstance(chunk, str): + item = dict(type='text', text=chunk) + elif isinstance(chunk, lf_modalities.Image) and self.multimodal: + item = dict(type='image_url', image_url=dict(url=_uri_from(chunk))) else: - content = message.text - return content - - def _open_ai_chat_completion(prompt: lf.Message): - request_args = self._get_request_args(self.sampling_options) - # Users could use `metadata_json_schema` to pass additional - # request arguments. - json_schema = prompt.metadata.get('json_schema') - if json_schema is not None: - if not isinstance(json_schema, dict): - raise ValueError( - f'`json_schema` must be a dict, got {json_schema!r}.' - ) - if 'title' not in json_schema: - raise ValueError( - f'The root of `json_schema` must have a `title` field, ' - f'got {json_schema!r}.' - ) - request_args.update( - response_format=dict( - type='json_schema', - json_schema=dict( - schema=json_schema, - name=json_schema['title'], - strict=True, - ) - ) - ) - prompt.metadata.formatted_text = ( - prompt.text - + '\n\n [RESPONSE FORMAT (not part of prompt)]\n' - + pg.to_json_str(request_args['response_format'], json_indent=2) - ) + raise ValueError(f'Unsupported modality: {chunk!r}.') + content.append(item) + return content - # Prepare messages. - messages = [] - # Users could use `metadata_system_message` to pass system message. 
- system_message = prompt.metadata.get('system_message') - if system_message: - system_message = lf.SystemMessage.from_value(system_message) - messages.append( - dict(role='system', content=_content_from_message(system_message)) + def request( + self, + prompt: lf.Message, + sampling_options: lf.LMSamplingOptions + ) -> dict[str, Any]: + """Returns the JSON input for a message.""" + request_args = self._request_args(sampling_options) + + # Users could use `metadata_json_schema` to pass additional + # request arguments. + json_schema = prompt.metadata.get('json_schema') + if json_schema is not None: + if not isinstance(json_schema, dict): + raise ValueError( + f'`json_schema` must be a dict, got {json_schema!r}.' ) - messages.append(dict(role='user', content=_content_from_message(prompt))) - - assert openai is not None - response = openai.ChatCompletion.create(messages=messages, **request_args) - - samples = [] - for choice in response.choices: - logprobs = None - choice_logprobs = getattr(choice, 'logprobs', None) - if choice_logprobs: - logprobs = [ - ( - t.token, - t.logprob, - [(tt.token, tt.logprob) for tt in t.top_logprobs], - ) - for t in choice_logprobs.content - ] - samples.append( - lf.LMSample( - choice.message.content, - score=0.0, - logprobs=logprobs, - ) + if 'title' not in json_schema: + raise ValueError( + f'The root of `json_schema` must have a `title` field, ' + f'got {json_schema!r}.' ) - - return lf.LMSamplingResult( - samples=samples, - usage=lf.LMSamplingUsage( - prompt_tokens=response.usage.prompt_tokens, - completion_tokens=response.usage.completion_tokens, - total_tokens=response.usage.total_tokens, - estimated_cost=self.estimate_cost( - num_input_tokens=response.usage.prompt_tokens, - num_output_tokens=response.usage.completion_tokens, + request_args.update( + response_format=dict( + type='json_schema', + json_schema=dict( + schema=json_schema, + name=json_schema['title'], + strict=True, ) - ), + ) + ) + prompt.metadata.formatted_text = ( + prompt.text + + '\n\n [RESPONSE FORMAT (not part of prompt)]\n' + + pg.to_json_str(request_args['response_format'], json_indent=2) + ) + + # Prepare messages. + messages = [] + # Users could use `metadata_system_message` to pass system message. 
+ system_message = prompt.metadata.get('system_message') + if system_message: + system_message = lf.SystemMessage.from_value(system_message) + messages.append( + dict(role='system', + content=self._content_from_message(system_message)) ) + messages.append( + dict(role='user', content=self._content_from_message(prompt)) + ) + request = dict() + request.update(request_args) + request['messages'] = messages + return request - return self._parallel_execute_with_currency_control( - _open_ai_chat_completion, - prompts, - retry_on_errors=( - ServiceUnavailableError, - RateLimitError, - APITimeoutError + def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample: + # Reference: + # https://platform.openai.com/docs/api-reference/chat/object + logprobs = None + choice_logprobs = choice.get('logprobs') + if choice_logprobs: + logprobs = [ + ( + t['token'], + t['logprob'], + [(tt['token'], tt['logprob']) for tt in t['top_logprobs']], + ) + for t in choice_logprobs['content'] + ] + return lf.LMSample( + choice['message']['content'], + score=0.0, + logprobs=logprobs, + ) + + def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: + usage = json['usage'] + return lf.LMSamplingResult( + samples=[self._parse_choice(choice) for choice in json['choices']], + usage=lf.LMSamplingUsage( + prompt_tokens=usage['prompt_tokens'], + completion_tokens=usage['completion_tokens'], + total_tokens=usage['total_tokens'], + estimated_cost=self.estimate_cost( + num_input_tokens=usage['prompt_tokens'], + num_output_tokens=usage['completion_tokens'], + ) ), ) diff --git a/langfun/core/llms/openai_test.py b/langfun/core/llms/openai_test.py index 22fd5c5..1dd36c4 100644 --- a/langfun/core/llms/openai_test.py +++ b/langfun/core/llms/openai_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Tests for OpenAI models.""" +from typing import Any import unittest from unittest import mock @@ -20,86 +21,106 @@ from langfun.core import modalities as lf_modalities from langfun.core.llms import openai import pyglove as pg +import requests -def mock_completion_query(prompt, *, n=1, **kwargs): - del kwargs - choices = [] - for i, _ in enumerate(prompt): - for k in range(n): - choices.append(pg.Dict( - index=i, - text=f'Sample {k} for prompt {i}.', - logprobs=k / 10, - )) - return pg.Dict( - choices=choices, - usage=lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - ), - ) - - -def mock_chat_completion_query(messages, *, n=1, **kwargs): +def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs): + del url, kwargs + messages = json['messages'] if len(messages) > 1: system_message = f' system={messages[0]["content"]}' else: system_message = '' - if 'response_format' in kwargs: - response_format = f' format={kwargs["response_format"]["type"]}' + if 'response_format' in json: + response_format = f' format={json["response_format"]["type"]}' else: response_format = '' choices = [] - for k in range(n): - choices.append(pg.Dict( - message=pg.Dict( + for k in range(json['n']): + if json.get('logprobs'): + logprobs = dict( + content=[ + dict( + token='chosen_token', + logprob=0.5, + top_logprobs=[ + dict( + token=f'alternative_token_{i + 1}', + logprob=0.1 + ) for i in range(3) + ] + ) + ] + ) + else: + logprobs = None + + choices.append(dict( + message=dict( content=( f'Sample {k} for message.{system_message}{response_format}' ) ), - logprobs=None, + logprobs=logprobs, )) - return pg.Dict( - choices=choices, - usage=lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - ), - ) + response = requests.Response() + response.status_code = 200 + response._content = pg.to_json_str( + dict( + choices=choices, + usage=lf.LMSamplingUsage( + prompt_tokens=100, + completion_tokens=100, + total_tokens=200, + ), + ) + ).encode() + return response -def mock_chat_completion_query_vision(messages, *, n=1, **kwargs): - del kwargs +def mock_chat_completion_request_vision( + url: str, json: dict[str, Any], **kwargs +): + del url, kwargs choices = [] urls = [ c['image_url']['url'] - for c in messages[0]['content'] if c['type'] == 'image_url' + for c in json['messages'][0]['content'] if c['type'] == 'image_url' ] - for k in range(n): + for k in range(json['n']): choices.append(pg.Dict( message=pg.Dict( content=f'Sample {k} for message: {"".join(urls)}' ), logprobs=None, )) - return pg.Dict( - choices=choices, - usage=lf.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - ), - ) + response = requests.Response() + response.status_code = 200 + response._content = pg.to_json_str( + dict( + choices=choices, + usage=lf.LMSamplingUsage( + prompt_tokens=100, + completion_tokens=100, + total_tokens=200, + ), + ) + ).encode() + return response class OpenAITest(unittest.TestCase): """Tests for OpenAI language model.""" + def test_dir(self): + self.assertIn('gpt-4-turbo', openai.OpenAI.dir()) + + def test_key(self): + with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'): + openai.Gpt4()('hi') + def test_model_id(self): self.assertEqual( openai.Gpt35(api_key='test_key').model_id, 'OpenAI(text-davinci-003)') @@ -112,29 +133,9 @@ def test_resource_id(self): def test_max_concurrency(self): self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0) - def test_get_request_args(self): - 
self.assertEqual( - openai.Gpt35(api_key='test_key', timeout=90.0)._get_request_args( - lf.LMSamplingOptions( - temperature=2.0, - logprobs=True, - n=2, - max_tokens=4096, - top_p=1.0)), - dict( - engine='text-davinci-003', - logprobs=True, - top_logprobs=None, - n=2, - temperature=2.0, - max_tokens=4096, - stream=False, - timeout=90.0, - top_p=1.0, - ) - ) + def test_request_args(self): self.assertEqual( - openai.Gpt4(api_key='test_key')._get_request_args( + openai.Gpt4(api_key='test_key')._request_args( lf.LMSamplingOptions( temperature=1.0, stop=['\n'], n=1, random_seed=123 ) @@ -144,40 +145,93 @@ def test_get_request_args(self): top_logprobs=None, n=1, temperature=1.0, - stream=False, - timeout=120.0, stop=['\n'], seed=123, ), ) with self.assertRaisesRegex(RuntimeError, '`logprobs` is not supported.*'): - openai.GptO1Preview(api_key='test_key')._get_request_args( + openai.GptO1Preview(api_key='test_key')._request_args( lf.LMSamplingOptions( temperature=1.0, logprobs=True ) ) - def test_call_completion(self): - with mock.patch('openai.Completion.create') as mock_completion: - mock_completion.side_effect = mock_completion_query - lm = openai.OpenAI(api_key='test_key', model='text-davinci-003') + def test_call_chat_completion(self): + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request + lm = openai.OpenAI( + model='gpt-4', + api_key='test_key', + organization='my_org', + project='my_project' + ) self.assertEqual( lm('hello', sampling_options=lf.LMSamplingOptions(n=2)), - 'Sample 0 for prompt 0.', + 'Sample 0 for message.', ) - def test_call_chat_completion(self): - with mock.patch('openai.ChatCompletion.create') as mock_chat_completion: - mock_chat_completion.side_effect = mock_chat_completion_query - lm = openai.OpenAI(api_key='test_key', model='gpt-4') + def test_call_chat_completion_with_logprobs(self): + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request + lm = openai.OpenAI( + model='gpt-4', + api_key='test_key', + organization='my_org', + project='my_project' + ) + results = lm.sample(['hello'], logprobs=True) + self.assertEqual(len(results), 1) self.assertEqual( - lm('hello', sampling_options=lf.LMSamplingOptions(n=2)), - 'Sample 0 for message.', + results[0], + lf.LMSamplingResult( + [ + lf.LMSample( + response=lf.AIMessage( + text='Sample 0 for message.', + metadata={ + 'score': 0.0, + 'logprobs': [( + 'chosen_token', + 0.5, + [ + ('alternative_token_1', 0.1), + ('alternative_token_2', 0.1), + ('alternative_token_3', 0.1), + ], + )], + 'is_cached': False, + 'usage': lf.LMSamplingUsage( + prompt_tokens=100, + completion_tokens=100, + total_tokens=200, + estimated_cost=0.009, + ), + }, + tags=['lm-response'], + ), + logprobs=[( + 'chosen_token', + 0.5, + [ + ('alternative_token_1', 0.1), + ('alternative_token_2', 0.1), + ('alternative_token_3', 0.1), + ], + )], + ) + ], + usage=lf.LMSamplingUsage( + prompt_tokens=100, + completion_tokens=100, + total_tokens=200, + estimated_cost=0.009, + ), + ), ) def test_call_chat_completion_vision(self): - with mock.patch('openai.ChatCompletion.create') as mock_chat_completion: - mock_chat_completion.side_effect = mock_chat_completion_query_vision + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request_vision lm_1 = openai.Gpt4Turbo(api_key='test_key') lm_2 = openai.Gpt4VisionPreview(api_key='test_key') for lm in (lm_1, lm_2): @@ -191,136 +245,18 @@ def 
test_call_chat_completion_vision(self): ), 'Sample 0 for message: https://fake/image', ) - - def test_sample_completion(self): - with mock.patch('openai.Completion.create') as mock_completion: - mock_completion.side_effect = mock_completion_query - lm = openai.OpenAI(api_key='test_key', model='text-davinci-003') - results = lm.sample( - ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3) + lm_3 = openai.Gpt35Turbo(api_key='test_key') + with self.assertRaisesRegex(ValueError, 'Unsupported modality'): + lm_3( + lf.UserMessage( + 'hello <<[[image]]>>', + image=lf_modalities.Image.from_uri('https://fake/image') + ), ) - self.assertEqual(len(results), 2) - self.assertEqual( - results[0], - lf.LMSamplingResult( - [ - lf.LMSample( - lf.AIMessage( - 'Sample 0 for prompt 0.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=16, - completion_tokens=16, - total_tokens=33 - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 1 for prompt 0.', - score=0.1, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=16, - completion_tokens=16, - total_tokens=33 - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.1, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 2 for prompt 0.', - score=0.2, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=16, - completion_tokens=16, - total_tokens=33 - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.2, - logprobs=None, - ), - ], - usage=lf.LMSamplingUsage( - prompt_tokens=50, completion_tokens=50, total_tokens=100 - ), - ), - ) - self.assertEqual( - results[1], - lf.LMSamplingResult( - [ - lf.LMSample( - lf.AIMessage( - 'Sample 0 for prompt 1.', - score=0.0, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=16, - completion_tokens=16, - total_tokens=33 - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.0, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 1 for prompt 1.', - score=0.1, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=16, - completion_tokens=16, - total_tokens=33 - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.1, - logprobs=None, - ), - lf.LMSample( - lf.AIMessage( - 'Sample 2 for prompt 1.', - score=0.2, - logprobs=None, - is_cached=False, - usage=lf.LMSamplingUsage( - prompt_tokens=16, - completion_tokens=16, - total_tokens=33 - ), - tags=[lf.Message.TAG_LM_RESPONSE], - ), - score=0.2, - logprobs=None, - ), - ], - usage=lf.LMSamplingUsage( - prompt_tokens=50, completion_tokens=50, total_tokens=100 - ), - ), - ) - def test_sample_chat_completion(self): - with mock.patch('openai.ChatCompletion.create') as mock_chat_completion: - mock_chat_completion.side_effect = mock_chat_completion_query + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request openai.SUPPORTED_MODELS_AND_SETTINGS['gpt-4'].update({ 'cost_per_1k_input_tokens': 1.0, 'cost_per_1k_output_tokens': 1.0, @@ -458,8 +394,8 @@ def test_sample_chat_completion(self): ) def test_sample_with_contextual_options(self): - with mock.patch('openai.Completion.create') as mock_completion: - mock_completion.side_effect = mock_completion_query + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request lm = openai.OpenAI(api_key='test_key', model='text-davinci-003') with 
lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)): results = lm.sample(['hello']) @@ -471,7 +407,7 @@ def test_sample_with_contextual_options(self): [ lf.LMSample( lf.AIMessage( - 'Sample 0 for prompt 0.', + 'Sample 0 for message.', score=0.0, logprobs=None, is_cached=False, @@ -487,8 +423,8 @@ def test_sample_with_contextual_options(self): ), lf.LMSample( lf.AIMessage( - 'Sample 1 for prompt 0.', - score=0.1, + 'Sample 1 for message.', + score=0.0, logprobs=None, is_cached=False, usage=lf.LMSamplingUsage( @@ -498,19 +434,19 @@ def test_sample_with_contextual_options(self): ), tags=[lf.Message.TAG_LM_RESPONSE], ), - score=0.1, + score=0.0, logprobs=None, ), ], usage=lf.LMSamplingUsage( prompt_tokens=100, completion_tokens=100, total_tokens=200 ), - ), + ) ) def test_call_with_system_message(self): - with mock.patch('openai.ChatCompletion.create') as mock_chat_completion: - mock_chat_completion.side_effect = mock_chat_completion_query + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request lm = openai.OpenAI(api_key='test_key', model='gpt-4') self.assertEqual( lm( @@ -520,12 +456,12 @@ def test_call_with_system_message(self): ), sampling_options=lf.LMSamplingOptions(n=2) ), - 'Sample 0 for message. system=hi', + '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''', ) def test_call_with_json_schema(self): - with mock.patch('openai.ChatCompletion.create') as mock_chat_completion: - mock_chat_completion.side_effect = mock_chat_completion_query + with mock.patch('requests.Session.post') as mock_request: + mock_request.side_effect = mock_chat_completion_request lm = openai.OpenAI(api_key='test_key', model='gpt-4') self.assertEqual( lm( diff --git a/requirements.txt b/requirements.txt index b31335f..4577f83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,8 +10,6 @@ tqdm>=4.64.1 google-cloud-aiplatform>=1.5.0 # extras:llm-google-genai google-generativeai>=0.3.2 -# extras:llm-openai -openai>=0.27.2 # extras:mime-auto python-magic>=0.4.27
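
For reference, a minimal usage sketch of the migrated model (not part of the patch; the key, organization and project values below are placeholders). With OpenAI now built on rest.REST, calls go straight to the chat-completions endpoint over HTTPS, so no OpenAI SDK import is required:

  # Minimal sketch, assuming langfun with this change is installed.
  import langfun.core as lf
  from langfun.core.llms import openai

  lm = openai.OpenAI(
      model='gpt-4',            # any in-service key of SUPPORTED_MODELS_AND_SETTINGS
      api_key='my-api-key',     # placeholder; falls back to $OPENAI_API_KEY
      organization='my-org',    # placeholder; sent as the 'OpenAI-Organization' header
      project='my-project',     # placeholder; sent as the 'OpenAI-Project' header
  )
  print(lm('Hello!', sampling_options=lf.LMSamplingOptions(temperature=0.0, max_tokens=64)))

Because all traffic now goes through requests, the tests above stub the transport with mock.patch('requests.Session.post') instead of patching openai.ChatCompletion.create.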