From 92ebbdfc175e2ee63e047accadcd10b9f3513657 Mon Sep 17 00:00:00 2001
From: Langfun Authors
Date: Mon, 18 Nov 2024 16:52:55 -0800
Subject: [PATCH] REST API implementation for VertexAI Gemini models.

PiperOrigin-RevId: 697802489
---
 langfun/core/llms/__init__.py      |   3 +
 langfun/core/llms/rest.py          |   2 +-
 langfun/core/llms/vertexai.py      | 273 ++++++++++++++++++++++++++++-
 langfun/core/llms/vertexai_test.py | 212 +++++++++++++++++++++-
 4 files changed, 474 insertions(+), 16 deletions(-)
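
Usage sketch (reviewer note; text between the diffstat and the first diff is
not applied by `git am`): a minimal example of how the new REST-based classes
are expected to be called, assuming Application Default Credentials are
configured. `my-project` and `us-central1` are placeholder values.

    from langfun.core.llms import vertexai

    # VertexAIGeminiPro1_5_002 is rebased onto VertexRestfulAI by this patch,
    # so the call below goes through the REST `generateContent` endpoint.
    lm = vertexai.VertexAIGeminiPro1_5_002(
        project='my-project',    # Placeholder project ID.
        location='us-central1',  # Placeholder location.
    )
    response = lm('Hello, Gemini!', temperature=0.7, max_tokens=256)
    print(response.text)
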
diff --git a/langfun/core/llms/__init__.py b/langfun/core/llms/__init__.py
index dd60f20..54ec904 100644
--- a/langfun/core/llms/__init__.py
+++ b/langfun/core/llms/__init__.py
@@ -120,6 +120,8 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
 
 from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.vertexai import VertexRestfulAI
+from langfun.core.llms.vertexai import VertexRestfulAIGemini1_5
 from langfun.core.llms.vertexai import VertexAIGemini1_5
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_Latest
@@ -134,6 +136,7 @@ from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514
 from langfun.core.llms.vertexai import VertexAIGeminiPro1
 from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
+from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision_001
 from langfun.core.llms.vertexai import VertexAIPalm2
 from langfun.core.llms.vertexai import VertexAIPalm2_32K
 from langfun.core.llms.vertexai import VertexAICustom
diff --git a/langfun/core/llms/rest.py b/langfun/core/llms/rest.py
index 11273a0..03eaf72 100644
--- a/langfun/core/llms/rest.py
+++ b/langfun/core/llms/rest.py
@@ -26,7 +26,7 @@ class REST(lf.LanguageModel):
 
   api_endpoint: Annotated[
       str,
       'The endpoint of the REST API.'
-  ]
+  ] = ''
 
   request: Annotated[
       Callable[[lf.Message, lf.LMSamplingOptions], dict[str, Any]],
diff --git a/langfun/core/llms/vertexai.py b/langfun/core/llms/vertexai.py
index 5a66d0b..d58b98a 100644
--- a/langfun/core/llms/vertexai.py
+++ b/langfun/core/llms/vertexai.py
@@ -19,11 +19,15 @@
+import base64
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
 
 try:
   # pylint: disable=g-import-not-at-top
+  from google import auth as google_auth
   from google.auth import credentials as credentials_lib
+  from google.auth.transport import requests as auth_requests
   import vertexai
   from google.cloud.aiplatform import models as aiplatform_models
   from vertexai import generative_models
@@ -32,6 +35,8 @@
   Credentials = credentials_lib.Credentials
 except ImportError:
+  google_auth = None
+  auth_requests = None
   credentials_lib = None  # pylint: disable=invalid-name
   vertexai = None
   generative_models = None
@@ -127,6 +132,12 @@
         cost_per_1k_input_chars=0.000125,
         cost_per_1k_output_chars=0.000375,
     ),
+    'gemini-1.0-pro-vision-001': pg.Dict(
+        api='gemini',
+        rpm=100,
+        cost_per_1k_input_chars=0.000125,
+        cost_per_1k_output_chars=0.000375,
+    ),
     # PaLM APIs.
     'text-bison': pg.Dict(
         api='palm',
@@ -449,6 +460,239 @@ def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
     ])
 
 
+@lf.use_init_args(['model'])
+class VertexRestfulAI(rest.REST):
+  """Language model served on VertexAI with REST API."""
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      (
+          'Vertex AI model name with REST API support. See '
+          'https://cloud.google.com/vertex-ai/generative-ai/docs/'
+          'model-reference/inference#supported-models'
+          ' for details.'
+      ),
+  ]
+
+  project: Annotated[
+      str | None,
+      (
+          'Vertex AI project ID. Or set from environment variable '
+          'VERTEXAI_PROJECT.'
+      ),
+  ] = None
+
+  location: Annotated[
+      str | None,
+      (
+          'Vertex AI service location. Or set from environment variable '
+          'VERTEXAI_LOCATION.'
+      ),
+  ] = None
+
+  credentials: Annotated[
+      Credentials | None,
+      (
+          'Credentials to use. If None, the default credentials to the '
+          'environment will be used.'
+      ),
+  ] = None
+
+  supported_modalities: Annotated[
+      list[str],
+      'A list of MIME types for supported modalities.'
+  ] = []
+
+  def _on_bound(self):
+    super()._on_bound()
+    if google_auth is None:
+      raise ValueError(
+          'Please install "langfun[llm-google-vertex]" to use Vertex AI '
+          'models.'
+      )
+    self._project = None
+    self._credentials = None
+
+  def _initialize(self):
+    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
+    if not project:
+      raise ValueError(
+          'Please specify `project` during `__init__` or set environment '
+          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
+      )
+
+    location = self.location or os.environ.get('VERTEXAI_LOCATION', None)
+    if not location:
+      raise ValueError(
+          'Please specify `location` during `__init__` or set environment '
+          'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
+      )
+
+    self._project = project
+    credentials = self.credentials
+    if credentials is None:
+      # Use default credentials. `google_auth.default` returns a
+      # (credentials, project_id) tuple; only the credentials are needed.
+      credentials, _ = google_auth.default(
+          scopes=['https://www.googleapis.com/auth/cloud-platform']
+      )
+    self._credentials = credentials
+
+  @property
+  def max_concurrency(self) -> int:
+    """Returns the maximum number of concurrent requests."""
+    return self.rate_to_max_concurrency(
+        requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
+        tokens_per_min=0,
+    )
+
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimates the cost based on usage."""
+    cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_input_chars', None
+    )
+    cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_output_chars', None
+    )
+    if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
+      return None
+    return (
+        cost_per_1k_input_chars * num_input_tokens
+        + cost_per_1k_output_chars * num_output_tokens
+    ) * AVGERAGE_CHARS_PER_TOEKN / 1000
+
+  @functools.cached_property
+  def _session(self):
+    assert self._api_initialized
+    assert self._credentials is not None
+    assert auth_requests is not None
+    s = auth_requests.AuthorizedSession(self._credentials)
+    s.headers.update(self.headers or {})
+    return s
+
+  @property
+  def headers(self):
+    return {
+        'Content-Type': 'application/json; charset=utf-8',
+    }
+
+  @property
+  def api_endpoint(self) -> str:
+    return (
+        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self.project}/locations/{self.location}/publishers/google/'
+        f'models/{self.model}:generateContent'
+    )
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    request = dict(
+        generationConfig=self._generation_config(prompt, sampling_options)
+    )
+    request['contents'] = [self._content_from_message(prompt)]
+    return request
+
+  def _generation_config(
+      self, prompt: lf.Message, options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns a dict as generation config for prompt and LMSamplingOptions."""
+    config = dict(
+        temperature=options.temperature,
+        maxOutputTokens=options.max_tokens,
+        candidateCount=options.n,
+        topK=options.top_k,
+        topP=options.top_p,
+        stopSequences=options.stop,
+        seed=options.random_seed,
+        responseLogprobs=options.logprobs,
+        logprobs=options.top_logprobs,
+    )
+
+    if json_schema := prompt.metadata.get('json_schema'):
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      json_schema = pg.to_json(json_schema)
+      config['responseSchema'] = json_schema
+      config['responseMimeType'] = 'application/json'
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(json_schema, json_indent=2)
+      )
+    return config
+
+  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
+    """Gets generation content from langfun message."""
+    parts = []
+    for lf_chunk in prompt.chunk():
+      if isinstance(lf_chunk, str):
+        parts.append({'text': lf_chunk})
+      elif isinstance(lf_chunk, lf_modalities.Mime):
+        try:
+          modalities = lf_chunk.make_compatible(
+              self.supported_modalities + ['text/plain']
+          )
+          if isinstance(modalities, lf_modalities.Mime):
+            modalities = [modalities]
+          for modality in modalities:
+            if modality.is_text:
+              parts.append({'text': modality.to_text()})
+            else:
+              # The REST API expects inline bytes to be sent base64-encoded
+              # under `inlineData`.
+              parts.append({
+                  'inlineData': {
+                      'data': base64.b64encode(modality.to_bytes()).decode(),
+                      'mimeType': modality.mime_type,
+                  }
+              })
+        except lf.ModalityError as e:
+          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
+      else:
+        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
+    return dict(role='user', parts=parts)
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    messages = [
+        self._message_from_content_parts(candidate['content']['parts'])
+        for candidate in json['candidates']
+    ]
+    usage = json['usageMetadata']
+    input_tokens = usage['promptTokenCount']
+    output_tokens = usage['candidatesTokenCount']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message) for message in messages],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=input_tokens,
+                num_output_tokens=output_tokens,
+            ),
+        ),
+    )
+
+  def _message_from_content_parts(
+      self, parts: list[dict[str, Any]]
+  ) -> lf.Message:
+    """Converts Vertex AI's content parts protocol to message."""
+    chunks = []
+    for part in parts:
+      if text_part := part.get('text'):
+        chunks.append(text_part)
+      elif inline_part := part.get('inlineData'):
+        # `inlineData` carries base64-encoded bytes; decode them and rebuild
+        # a modality object of the matching MIME type.
+        chunks.append(
+            lf_modalities.Mime.class_from_mime_type(
+                inline_part['mimeType']
+            ).from_bytes(base64.b64decode(inline_part['data']))
+        )
+      else:
+        raise ValueError(f'Unsupported part: {part}')
+    return lf.AIMessage.from_chunks(chunks)
+
+
 class _ModelHub:
   """Vertex AI model hub."""
 
@@ -547,13 +791,21 @@ class VertexAIGeminiPro1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   model = 'gemini-1.5-pro'
 
 
-class VertexAIGeminiPro1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexRestfulAIGemini1_5(VertexRestfulAI):  # pylint: disable=invalid-name
+  """Vertex AI Gemini 1.5 model with REST API."""
+
+  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
+      _DOCUMENT_TYPES + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
+  )
+
+
+class VertexAIGeminiPro1_5_002(VertexRestfulAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro-002'
 
 
-class VertexAIGeminiPro1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1_5_001(VertexRestfulAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro-001'
 
@@ -583,13 +835,13 @@ class VertexAIGeminiFlash1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   model = 'gemini-1.5-flash'
 
 
-class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5_002(VertexRestfulAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash-002'
 
 
-class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5_001(VertexRestfulAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash-001'
 
@@ -601,14 +853,14 @@ class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5):  # pylint: disable=invalid
   model = 'gemini-1.5-flash-preview-0514'
 
 
-class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1(VertexRestfulAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.0 Pro model."""
 
   model = 'gemini-1.0-pro'
 
 
 class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.0 Pro model."""
+  """Vertex AI Gemini 1.0 Pro Vision model."""
 
   model = 'gemini-1.0-pro-vision'
   supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
@@ -616,6 +868,15 @@ class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
   )
 
 
+class VertexAIGeminiPro1Vision_001(VertexRestfulAI):  # pylint: disable=invalid-name
+  """Vertex AI Gemini 1.0 Pro Vision model with REST API."""
+
+  model = 'gemini-1.0-pro-vision-001'
+  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
+      _IMAGE_TYPES + _VIDEO_TYPES
+  )
+
+
 class VertexAIPalm2(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI PaLM2 text generation model."""
 
diff --git a/langfun/core/llms/vertexai_test.py b/langfun/core/llms/vertexai_test.py
index fbb2cbf..8a04d2f 100644
--- a/langfun/core/llms/vertexai_test.py
+++ b/langfun/core/llms/vertexai_test.py
@@ -14,6 +14,7 @@
 """Tests for Gemini models."""
 
 import os
+from typing import Any
 import unittest
 from unittest import mock
 
@@ -23,6 +24,7 @@
 from langfun.core import modalities as lf_modalities
 from langfun.core.llms import vertexai
 import pyglove as pg
+import requests
 
 
 example_image = (
@@ -64,6 +66,40 @@ def mock_generate_content(content, generation_config, **kwargs):
   })
 
 
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  c = pg.Dict(json['generationConfig'])
+  content = json['contents'][0]['parts'][0]['text']
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'candidates': [
+          {
+              'content': {
+                  'role': 'model',
+                  'parts': [
+                      {
+                          'text': (
+                              f'This is a response to {content} with '
+                              f'temperature={c.temperature}, '
+                              f'top_p={c.topP}, '
+                              f'top_k={c.topK}, '
+                              f'max_tokens={c.maxOutputTokens}, '
+                              f'stop={"".join(c.stopSequences)}.'
+                          )
+                      },
+                  ],
+              },
+          },
+      ],
+      'usageMetadata': {
+          'promptTokenCount': 3,
+          'candidatesTokenCount': 4,
+      }
+  }).encode()
+  return response
+
+
 def mock_endpoint_predict(instances, **kwargs):
   del kwargs
   assert len(instances) == 1
@@ -83,7 +119,7 @@ class VertexAITest(unittest.TestCase):
 
   def test_content_from_message_text_only(self):
     text = 'This is a beautiful day'
-    model = vertexai.VertexAIGeminiPro1()
+    model = vertexai.VertexAIGeminiPro1Vision()
     chunks = model._content_from_message(lf.UserMessage(text))
     self.assertEqual(chunks, [text])
 
@@ -95,7 +131,7 @@ def test_content_from_message_mm(self):
 
     # Non-multimodal model.
     with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
-      vertexai.VertexAIGeminiPro1()._content_from_message(message)
+      vertexai.VertexAIPalm2()._content_from_message(message)
 
     model = vertexai.VertexAIGeminiPro1Vision()
     chunks = model._content_from_message(message)
@@ -119,7 +155,7 @@ def test_generation_response_to_message_text_only(self):
             },
         ],
     })
-    model = vertexai.VertexAIGeminiPro1()
+    model = vertexai.VertexAIGeminiPro1Vision()
     message = model._generation_response_to_message(response)
     self.assertEqual(message, lf.AIMessage('hello world'))
 
@@ -158,25 +194,25 @@ class TextGenerationModel:
 
   def test_project_and_location_check(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
-      _ = vertexai.VertexAIGeminiPro1()._api_initialized
+      _ = vertexai.VertexAIGeminiPro1Vision()._api_initialized
 
     with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
-      _ = vertexai.VertexAIGeminiPro1(project='abc')._api_initialized
+      _ = vertexai.VertexAIGeminiPro1Vision(project='abc')._api_initialized
 
     self.assertTrue(
-        vertexai.VertexAIGeminiPro1(
+        vertexai.VertexAIGeminiPro1Vision(
            project='abc', location='us-central1'
         )._api_initialized
     )
 
     os.environ['VERTEXAI_PROJECT'] = 'abc'
     os.environ['VERTEXAI_LOCATION'] = 'us-central1'
-    self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)
+    self.assertTrue(vertexai.VertexAIGeminiPro1Vision()._api_initialized)
     del os.environ['VERTEXAI_PROJECT']
     del os.environ['VERTEXAI_LOCATION']
 
   def test_generation_config(self):
-    model = vertexai.VertexAIGeminiPro1()
+    model = vertexai.VertexAIGeminiPro1Vision()
     json_schema = {
         'type': 'object',
         'properties': {
@@ -245,7 +281,9 @@ def test_call_generative_model(self):
     ) as mock_generate:
       mock_generate.side_effect = mock_generate_content
 
-      lm = vertexai.VertexAIGeminiPro1(project='abc', location='us-central1')
+      lm = vertexai.VertexAIGeminiPro1Vision(
+          project='abc', location='us-central1'
+      )
       self.assertEqual(
           lm(
               'hello',
@@ -328,5 +366,161 @@ def test_call_endpoint_model(self):
     )
 
 
+class VertexRestfulAITest(unittest.TestCase):
+  """Tests for Vertex AI models with REST API."""
+
+  def test_content_from_message_text_only(self):
+    text = 'This is a beautiful day'
+    model = vertexai.VertexAIGeminiPro1_5_002()
+    chunks = model._content_from_message(lf.UserMessage(text))
+    self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
+
+  def test_content_from_message_mm(self):
+    message = lf.UserMessage(
+        'This is an <<[[image]]>>, what is it?',
+        image=lf_modalities.Image.from_bytes(example_image),
+    )
+
+    # Non-multimodal model.
+    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
+      vertexai.VertexAIGeminiPro1()._content_from_message(message)
+
+    model = vertexai.VertexAIGeminiPro1Vision()
+    chunks = model._content_from_message(message)
+    self.maxDiff = None
+    self.assertEqual([chunks[0], chunks[2]], ['This is an', ', what is it?'])
+    self.assertIsInstance(chunks[1], generative_models.Part)
+
+  def test_generation_response_to_message_text_only(self):
+    response = generative_models.GenerationResponse.from_dict({
+        'candidates': [
+            {
+                'index': 0,
+                'content': {
+                    'role': 'model',
+                    'parts': [
+                        {
+                            'text': 'hello world',
+                        },
+                    ],
+                },
+            },
+        ],
+    })
+    model = vertexai.VertexAIGeminiPro1Vision()
+    message = model._generation_response_to_message(response)
+    self.assertEqual(message, lf.AIMessage('hello world'))
+
+  def test_model_hub(self):
+    with mock.patch(
+        'vertexai.generative_models.'
+        'GenerativeModel.__init__'
+    ) as mock_model_init:
+      mock_model_init.side_effect = lambda *args, **kwargs: None
+      model = vertexai._VERTEXAI_MODEL_HUB.get_generative_model(
+          'gemini-1.0-pro'
+      )
+      self.assertIsNotNone(model)
+      self.assertIs(
+          vertexai._VERTEXAI_MODEL_HUB.get_generative_model('gemini-1.0-pro'),
+          model,
+      )
+
+  def test_project_and_location_check(self):
+    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
+      _ = vertexai.VertexAIGeminiPro1()._api_initialized
+
+    with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
+      _ = vertexai.VertexAIGeminiPro1(project='abc')._api_initialized
+
+    self.assertTrue(
+        vertexai.VertexAIGeminiPro1(
+            project='abc', location='us-central1'
+        )._api_initialized
+    )
+
+    os.environ['VERTEXAI_PROJECT'] = 'abc'
+    os.environ['VERTEXAI_LOCATION'] = 'us-central1'
+    self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)
+    del os.environ['VERTEXAI_PROJECT']
+    del os.environ['VERTEXAI_LOCATION']
+
+  def test_generation_config(self):
+    model = vertexai.VertexAIGeminiPro1()
+    json_schema = {
+        'type': 'object',
+        'properties': {
+            'name': {'type': 'string'},
+        },
+        'required': ['name'],
+        'title': 'Person',
+    }
+    actual = model._generation_config(
+        lf.UserMessage('hi', json_schema=json_schema),
+        lf.LMSamplingOptions(
+            temperature=2.0,
+            top_p=1.0,
+            top_k=20,
+            max_tokens=1024,
+            stop=['\n'],
+        ),
+    )
+    self.assertEqual(
+        actual,
+        dict(
+            candidateCount=1,
+            temperature=2.0,
+            topP=1.0,
+            topK=20,
+            maxOutputTokens=1024,
+            stopSequences=['\n'],
+            responseLogprobs=False,
+            logprobs=None,
+            seed=None,
+            responseMimeType='application/json',
+            responseSchema={
+                'type': 'object',
+                'properties': {
+                    'name': {'type': 'string'}
+                },
+                'required': ['name'],
+                'title': 'Person',
+            }
+        ),
+    )
+    with self.assertRaisesRegex(
+        ValueError, '`json_schema` must be a dict, got'
+    ):
+      model._generation_config(
+          lf.UserMessage('hi', json_schema='not a dict'),
+          lf.LMSamplingOptions(),
+      )
+
+  def test_call_model(self):
+    with mock.patch('requests.Session.post') as mock_generate:
+      mock_generate.side_effect = mock_requests_post
+
+      lm = vertexai.VertexAIGeminiPro1_5_002(
+          project='abc', location='us-central1'
+      )
+      r = lm(
+          'hello',
+          temperature=2.0,
+          top_p=1.0,
+          top_k=20,
+          max_tokens=1024,
+          stop='\n',
+      )
+      self.assertEqual(
+          r.text,
+          (
+              'This is a response to hello with temperature=2.0, '
+              'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
+          ),
+      )
+      self.assertEqual(r.metadata.usage.prompt_tokens, 3)
+      self.assertEqual(r.metadata.usage.completion_tokens, 4)
+
+
 if __name__ == '__main__':
   unittest.main()
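
Note for reviewers (outside the patch): the new REST path also supports
constrained decoding. `_generation_config` reads `json_schema` from the
message metadata and maps it to `responseSchema` plus
`responseMimeType='application/json'`, as exercised by
`test_generation_config`. A minimal sketch, assuming Application Default
Credentials and placeholder project/location values; the printed output is
illustrative only.

    import langfun.core as lf
    from langfun.core.llms import vertexai

    lm = vertexai.VertexAIGeminiPro1_5_002(
        project='my-project', location='us-central1'  # Placeholders.
    )
    # Attaching `json_schema` to the message metadata makes the model return
    # JSON conforming to the schema, mirroring how the unit tests build
    # prompts with metadata.
    message = lf.UserMessage(
        'Extract the person mentioned in: "Alice went home."',
        json_schema={
            'type': 'object',
            'properties': {'name': {'type': 'string'}},
            'required': ['name'],
            'title': 'Person',
        },
    )
    print(lm(message).text)  # Illustrative output: {"name": "Alice"}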