diff --git a/README.md b/README.md index 5da44a9..aaa04ea 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,6 @@ If you want to customize your installation, you can select specific features usi | all | All Langfun features. | | llm | All supported LLMs. | | llm-google | All supported Google-powered LLMs. | -| llm-google-vertexai | LLMs powered by Google Cloud VertexAI | | llm-google-genai | LLMs powered by Google Generative AI API | | mime | All MIME supports. | | mime-auto | Automatic MIME type detection. | diff --git a/langfun/core/llms/__init__.py b/langfun/core/llms/__init__.py index 0d898a1..dd5259f 100644 --- a/langfun/core/llms/__init__.py +++ b/langfun/core/llms/__init__.py @@ -120,25 +120,19 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo from langfun.core.llms.vertexai import VertexAI -from langfun.core.llms.vertexai import VertexAIRest -from langfun.core.llms.vertexai import VertexAIRestGemini1_5 from langfun.core.llms.vertexai import VertexAIGemini1_5 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_Latest from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409 -from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_Latest from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5 from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001 from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002 from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514 from langfun.core.llms.vertexai import VertexAIGeminiPro1 from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision -from langfun.core.llms.vertexai import VertexAIPalm2 -from langfun.core.llms.vertexai import VertexAIPalm2_32K -from langfun.core.llms.vertexai import VertexAICustom +from langfun.core.llms.vertexai import VertexAIEndpoint # LLaMA C++ models. diff --git a/langfun/core/llms/vertexai.py b/langfun/core/llms/vertexai.py index e806487..7344d6a 100644 --- a/langfun/core/llms/vertexai.py +++ b/langfun/core/llms/vertexai.py @@ -28,21 +28,13 @@ from google import auth as google_auth from google.auth import credentials as credentials_lib from google.auth.transport import requests as auth_requests - import vertexai - from google.cloud.aiplatform import models as aiplatform_models - from vertexai import generative_models - from vertexai import language_models # pylint: enable=g-import-not-at-top Credentials = credentials_lib.Credentials except ImportError: google_auth = None + credentials_lib = None auth_requests = None - credentials_lib = None # pylint: disable=invalid-name - vertexai = None - generative_models = None - language_models = None - aiplatform_models = None Credentials = Any @@ -56,408 +48,72 @@ # as of 2024-10-10. 
SUPPORTED_MODELS_AND_SETTINGS = { 'gemini-1.5-pro-001': pg.Dict( - api='gemini', rpm=100, cost_per_1k_input_chars=0.0003125, cost_per_1k_output_chars=0.00125, ), 'gemini-1.5-pro-002': pg.Dict( - api='gemini', rpm=100, cost_per_1k_input_chars=0.0003125, cost_per_1k_output_chars=0.00125, ), 'gemini-1.5-flash-002': pg.Dict( - api='gemini', rpm=500, cost_per_1k_input_chars=0.00001875, cost_per_1k_output_chars=0.000075, ), 'gemini-1.5-flash-001': pg.Dict( - api='gemini', rpm=500, cost_per_1k_input_chars=0.00001875, cost_per_1k_output_chars=0.000075, ), 'gemini-1.5-pro': pg.Dict( - api='gemini', rpm=100, cost_per_1k_input_chars=0.0003125, cost_per_1k_output_chars=0.00125, ), 'gemini-1.5-flash': pg.Dict( - api='gemini', - rpm=500, - cost_per_1k_input_chars=0.00001875, - cost_per_1k_output_chars=0.000075, - ), - 'gemini-1.5-pro-latest': pg.Dict( - api='gemini', - rpm=100, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-flash-latest': pg.Dict( - api='gemini', rpm=500, cost_per_1k_input_chars=0.00001875, cost_per_1k_output_chars=0.000075, ), 'gemini-1.5-pro-preview-0514': pg.Dict( - api='gemini', rpm=50, cost_per_1k_input_chars=0.0003125, cost_per_1k_output_chars=0.00125, ), 'gemini-1.5-pro-preview-0409': pg.Dict( - api='gemini', rpm=50, cost_per_1k_input_chars=0.0003125, cost_per_1k_output_chars=0.00125, ), 'gemini-1.5-flash-preview-0514': pg.Dict( - api='gemini', rpm=200, cost_per_1k_input_chars=0.00001875, cost_per_1k_output_chars=0.000075, ), 'gemini-1.0-pro': pg.Dict( - api='gemini', rpm=300, cost_per_1k_input_chars=0.000125, cost_per_1k_output_chars=0.000375, ), 'gemini-1.0-pro-vision': pg.Dict( - api='gemini', rpm=100, cost_per_1k_input_chars=0.000125, cost_per_1k_output_chars=0.000375, ), - # PaLM APIs. - 'text-bison': pg.Dict( - api='palm', - rpm=1600 - ), - 'text-bison-32k': pg.Dict( - api='palm', - rpm=300 - ), - 'text-unicorn': pg.Dict( - api='palm', - rpm=100 - ), - # Endpoint # TODO(chengrun): Set a more appropriate rpm for endpoint. - 'custom': pg.Dict(api='endpoint', rpm=20), + 'vertexai-endpoint': pg.Dict( + rpm=20, + cost_per_1k_input_chars=0.0000125, + cost_per_1k_output_chars=0.0000375, + ), } -@lf.use_init_args(['model']) -class VertexAI(lf.LanguageModel): - """Language model served on VertexAI.""" - - model: pg.typing.Annotated[ - pg.typing.Enum( - pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys()) - ), - ( - 'Vertex AI model name. See ' - 'https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models ' - 'for details.' - ), - ] - - endpoint_name: pg.typing.Annotated[ - str | None, - 'Vertex Endpoint name or ID.', - ] - - project: Annotated[ - str | None, - ( - 'Vertex AI project ID. Or set from environment variable ' - 'VERTEXAI_PROJECT.' - ), - ] = None - - location: Annotated[ - str | None, - ( - 'Vertex AI service location. Or set from environment variable ' - 'VERTEXAI_LOCATION.' - ), - ] = None - - credentials: Annotated[ - Credentials | None, - ( - 'Credentials to use. If None, the default credentials to the ' - 'environment will be used.' - ), - ] = None - - supported_modalities: Annotated[ - list[str], - 'A list of MIME types for supported modalities' - ] = [] - - def _on_bound(self): - super()._on_bound() - self.__dict__.pop('_api_initialized', None) - if generative_models is None: - raise RuntimeError( - 'Please install "langfun[llm-google-vertex]" to use Vertex AI models.' 
- ) - - @functools.cached_property - def _api_initialized(self): - project = self.project or os.environ.get('VERTEXAI_PROJECT', None) - if not project: - raise ValueError( - 'Please specify `project` during `__init__` or set environment ' - 'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.' - ) - - location = self.location or os.environ.get('VERTEXAI_LOCATION', None) - if not location: - raise ValueError( - 'Please specify `location` during `__init__` or set environment ' - 'variable `VERTEXAI_LOCATION` with your Vertex AI service location.' - ) - - credentials = self.credentials - # Placeholder for Google-internal credentials. - assert vertexai is not None - vertexai.init(project=project, location=location, credentials=credentials) - return True - - @property - def model_id(self) -> str: - """Returns a string to identify the model.""" - return f'VertexAI({self.model})' - - @property - def resource_id(self) -> str: - """Returns a string to identify the resource for rate control.""" - return self.model_id - - @property - def max_concurrency(self) -> int: - """Returns the maximum number of concurrent requests.""" - return self.rate_to_max_concurrency( - requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm, - tokens_per_min=0, - ) - - def estimate_cost( - self, - num_input_tokens: int, - num_output_tokens: int - ) -> float | None: - """Estimate the cost based on usage.""" - cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( - 'cost_per_1k_input_chars', None - ) - cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( - 'cost_per_1k_output_chars', None - ) - if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None: - return None - return ( - cost_per_1k_input_chars * num_input_tokens - + cost_per_1k_output_chars * num_output_tokens - ) * AVGERAGE_CHARS_PER_TOEKN / 1000 - - def _generation_config( - self, prompt: lf.Message, options: lf.LMSamplingOptions - ) -> Any: # generative_models.GenerationConfig - """Creates generation config from langfun sampling options.""" - assert generative_models is not None - # Users could use `metadata_json_schema` to pass additional - # request arguments. - json_schema = prompt.metadata.get('json_schema') - response_mime_type = None - if json_schema is not None: - if not isinstance(json_schema, dict): - raise ValueError( - f'`json_schema` must be a dict, got {json_schema!r}.' 
- ) - response_mime_type = 'application/json' - prompt.metadata.formatted_text = ( - prompt.text - + '\n\n [RESPONSE FORMAT (not part of prompt)]\n' - + pg.to_json_str(json_schema, json_indent=2) - ) - - return generative_models.GenerationConfig( - temperature=options.temperature, - top_p=options.top_p, - top_k=options.top_k, - max_output_tokens=options.max_tokens, - stop_sequences=options.stop, - response_mime_type=response_mime_type, - response_schema=json_schema, - ) - - def _content_from_message( - self, prompt: lf.Message - ) -> list[str | Any]: - """Gets generation input from langfun message.""" - assert generative_models is not None - chunks = [] - - for lf_chunk in prompt.chunk(): - if isinstance(lf_chunk, str): - chunks.append(lf_chunk) - elif isinstance(lf_chunk, lf_modalities.Mime): - try: - modalities = lf_chunk.make_compatible( - self.supported_modalities + ['text/plain'] - ) - if isinstance(modalities, lf_modalities.Mime): - modalities = [modalities] - for modality in modalities: - if modality.is_text: - chunk = modality.to_text() - else: - chunk = generative_models.Part.from_data( - modality.to_bytes(), modality.mime_type - ) - chunks.append(chunk) - except lf.ModalityError as e: - raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e - else: - raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') - return chunks - - def _generation_response_to_message( - self, - response: Any, # generative_models.GenerationResponse - ) -> lf.Message: - """Parses generative response into message.""" - return lf.AIMessage(response.text) - - def _generation_endpoint_response_to_message( - self, - response: Any, # google.cloud.aiplatform.aiplatform.models.Prediction - ) -> lf.Message: - """Parses Endpoint response into message.""" - return lf.AIMessage(response.predictions[0]) - - def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]: - assert self._api_initialized, 'Vertex AI API is not initialized.' - # TODO(yifenglu): It seems this exception is due to the instability of the - # API. We should revisit this later. - retry_on_errors = [ - (Exception, 'InternalServerError'), - (Exception, 'ResourceExhausted'), - (Exception, '_InactiveRpcError'), - (Exception, 'ValueError'), - ] - - return self._parallel_execute_with_currency_control( - self._sample_single, - prompts, - retry_on_errors=retry_on_errors, - ) - - def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult: - if self.sampling_options.n > 1: - raise ValueError( - f'`n` greater than 1 is not supported: {self.sampling_options.n}.' 
- ) - api = SUPPORTED_MODELS_AND_SETTINGS[self.model].api - match api: - case 'gemini': - return self._sample_generative_model(prompt) - case 'palm': - return self._sample_text_generation_model(prompt) - case 'endpoint': - return self._sample_endpoint_model(prompt) - case _: - raise ValueError(f'Unsupported API: {api}') - - def _sample_generative_model(self, prompt: lf.Message) -> lf.LMSamplingResult: - """Samples a generative model.""" - model = _VERTEXAI_MODEL_HUB.get_generative_model(self.model) - input_content = self._content_from_message(prompt) - response = model.generate_content( - input_content, - generation_config=self._generation_config( - prompt, self.sampling_options - ), - ) - usage_metadata = response.usage_metadata - usage = lf.LMSamplingUsage( - prompt_tokens=usage_metadata.prompt_token_count, - completion_tokens=usage_metadata.candidates_token_count, - total_tokens=usage_metadata.total_token_count, - estimated_cost=self.estimate_cost( - num_input_tokens=usage_metadata.prompt_token_count, - num_output_tokens=usage_metadata.candidates_token_count, - ), - ) - return lf.LMSamplingResult( - [ - # Scoring is not supported. - lf.LMSample( - self._generation_response_to_message(response), score=0.0 - ), - ], - usage=usage, - ) - - def _sample_text_generation_model( - self, prompt: lf.Message - ) -> lf.LMSamplingResult: - """Samples a text generation model.""" - model = _VERTEXAI_MODEL_HUB.get_text_generation_model(self.model) - predict_options = dict( - temperature=self.sampling_options.temperature, - top_k=self.sampling_options.top_k, - top_p=self.sampling_options.top_p, - max_output_tokens=self.sampling_options.max_tokens, - stop_sequences=self.sampling_options.stop, - ) - response = model.predict(prompt.text, **predict_options) - return lf.LMSamplingResult([ - # Scoring is not supported. - lf.LMSample(lf.AIMessage(response.text), score=0.0) - ]) - - def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult: - """Samples a text generation model.""" - assert aiplatform_models is not None - model = aiplatform_models.Endpoint(self.endpoint_name) - # TODO(chengrun): Add support for stop_sequences. - predict_options = dict( - temperature=self.sampling_options.temperature - if self.sampling_options.temperature is not None - else 1.0, - top_k=self.sampling_options.top_k - if self.sampling_options.top_k is not None - else 32, - top_p=self.sampling_options.top_p - if self.sampling_options.top_p is not None - else 1, - max_tokens=self.sampling_options.max_tokens - if self.sampling_options.max_tokens is not None - else 8192, - ) - instances = [{'prompt': prompt.text, **predict_options}] - response = model.predict(instances=instances) - - return lf.LMSamplingResult([ - # Scoring is not supported. 
- lf.LMSample( - self._generation_endpoint_response_to_message(response), score=0.0 - ) - ]) - - @lf.use_init_args(['model']) @pg.members([('api_endpoint', pg.typing.Str().freeze(''))]) -class VertexAIRest(rest.REST): +class VertexAI(rest.REST): """Language model served on VertexAI with REST API.""" model: pg.typing.Annotated[ @@ -687,39 +343,6 @@ def _message_from_content_parts( return lf.AIMessage.from_chunks(chunks) -class _ModelHub: - """Vertex AI model hub.""" - - def __init__(self): - self._generative_model_cache = {} - self._text_generation_model_cache = {} - - def get_generative_model( - self, model_id: str - ) -> Any: # generative_models.GenerativeModel: - """Gets a generative model by model id.""" - model = self._generative_model_cache.get(model_id, None) - if model is None: - assert generative_models is not None - model = generative_models.GenerativeModel(model_id) - self._generative_model_cache[model_id] = model - return model - - def get_text_generation_model( - self, model_id: str - ) -> Any: # language_models.TextGenerationModel - """Gets a text generation model by model id.""" - model = self._text_generation_model_cache.get(model_id, None) - if model is None: - assert language_models is not None - model = language_models.TextGenerationModel.from_pretrained(model_id) - self._text_generation_model_cache[model_id] = model - return model - - -_VERTEXAI_MODEL_HUB = _ModelHub() - - _IMAGE_TYPES = [ 'image/png', 'image/jpeg', @@ -773,33 +396,19 @@ class VertexAIGemini1_5(VertexAI): # pylint: disable=invalid-name ) -class VertexAIGeminiPro1_5_Latest(VertexAIGemini1_5): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 Pro model.""" - - model = 'gemini-1.5-pro-latest' - - class VertexAIGeminiPro1_5(VertexAIGemini1_5): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Pro model.""" model = 'gemini-1.5-pro' -class VertexAIRestGemini1_5(VertexAIRest): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 model with REST API.""" - - supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation - _DOCUMENT_TYPES + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES - ) - - -class VertexAIGeminiPro1_5_002(VertexAIRestGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiPro1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Pro model.""" model = 'gemini-1.5-pro-002' -class VertexAIGeminiPro1_5_001(VertexAIRestGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiPro1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Pro model.""" model = 'gemini-1.5-pro-001' @@ -817,25 +426,19 @@ class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5): # pylint: disable=invalid-n model = 'gemini-1.5-pro-preview-0409' -class VertexAIGeminiFlash1_5_Latest(VertexAIGemini1_5): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 Flash model.""" - - model = 'gemini-1.5-flash-latest' - - class VertexAIGeminiFlash1_5(VertexAIGemini1_5): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Flash model.""" model = 'gemini-1.5-flash' -class VertexAIGeminiFlash1_5_002(VertexAIRestGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Flash model.""" model = 'gemini-1.5-flash-002' -class VertexAIGeminiFlash1_5_001(VertexAIRestGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Flash model.""" model = 
'gemini-1.5-flash-001' @@ -847,7 +450,7 @@ class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid model = 'gemini-1.5-flash-preview-0514' -class VertexAIGeminiPro1(VertexAIRest): # pylint: disable=invalid-name +class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.0 Pro model.""" model = 'gemini-1.0-pro' @@ -862,19 +465,17 @@ class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name ) -class VertexAIPalm2(VertexAI): # pylint: disable=invalid-name - """Vertex AI PaLM2 text generation model.""" - - model = 'text-bison' - - -class VertexAIPalm2_32K(VertexAI): # pylint: disable=invalid-name - """Vertex AI PaLM2 text generation model (32K context length).""" +class VertexAIEndpoint(VertexAI): # pylint: disable=invalid-name + """Vertex AI Endpoint model.""" - model = 'text-bison-32k' + model = 'vertexai-endpoint' + endpoint: Annotated[str, 'Vertex AI Endpoint ID.'] -class VertexAICustom(VertexAI): # pylint: disable=invalid-name - """Vertex AI Custom model (Endpoint).""" - - model = 'custom' + @property + def api_endpoint(self) -> str: + return ( + f'https://{self.location}-aiplatform.googleapis.com/v1/projects/' + f'{self.project}/locations/{self.location}/' + f'endpoints/{self.endpoint}:generateContent' + ) diff --git a/langfun/core/llms/vertexai_test.py b/langfun/core/llms/vertexai_test.py index de9ff56..3e9d1ca 100644 --- a/langfun/core/llms/vertexai_test.py +++ b/langfun/core/llms/vertexai_test.py @@ -13,13 +13,12 @@ # limitations under the License. """Tests for Gemini models.""" +import base64 import os from typing import Any import unittest from unittest import mock -from google.cloud.aiplatform import models as aiplatform_models -from vertexai import generative_models import langfun.core as lf from langfun.core import modalities as lf_modalities from langfun.core.llms import vertexai @@ -39,33 +38,6 @@ ) -def mock_generate_content(content, generation_config, **kwargs): - del kwargs - c = pg.Dict(generation_config.to_dict()) - return generative_models.GenerationResponse.from_dict({ - 'candidates': [ - { - 'index': 0, - 'content': { - 'role': 'model', - 'parts': [ - { - 'text': ( - f'This is a response to {content[0]} with ' - f'temperature={c.temperature}, ' - f'top_p={c.top_p}, ' - f'top_k={c.top_k}, ' - f'max_tokens={c.max_output_tokens}, ' - f'stop={"".join(c.stop_sequences)}.' - ) - }, - ], - }, - }, - ] - }) - - def mock_requests_post(url: str, json: dict[str, Any], **kwargs): del url, kwargs c = pg.Dict(json['generationConfig']) @@ -100,273 +72,7 @@ def mock_requests_post(url: str, json: dict[str, Any], **kwargs): return response -def mock_endpoint_predict(instances, **kwargs): - del kwargs - assert len(instances) == 1 - return aiplatform_models.Prediction( - predictions=[ - f"This is a response to {instances[0]['prompt']} with" - f" temperature={instances[0]['temperature']}," - f" top_p={instances[0]['top_p']}, top_k={instances[0]['top_k']}," - f" max_tokens={instances[0]['max_tokens']}." 
- ], - deployed_model_id='', - ) - - class VertexAITest(unittest.TestCase): - """Tests for Vertex model.""" - - def test_content_from_message_text_only(self): - text = 'This is a beautiful day' - model = vertexai.VertexAIGeminiPro1Vision() - chunks = model._content_from_message(lf.UserMessage(text)) - self.assertEqual(chunks, [text]) - - def test_content_from_message_mm(self): - message = lf.UserMessage( - 'This is an <<[[image]]>>, what is it?', - image=lf_modalities.Image.from_bytes(example_image), - ) - - # Non-multimodal model. - with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'): - vertexai.VertexAIPalm2()._content_from_message(message) - - model = vertexai.VertexAIGeminiPro1Vision() - chunks = model._content_from_message(message) - self.maxDiff = None - self.assertEqual([chunks[0], chunks[2]], ['This is an', ', what is it?']) - self.assertIsInstance(chunks[1], generative_models.Part) - - def test_generation_response_to_message_text_only(self): - response = generative_models.GenerationResponse.from_dict({ - 'candidates': [ - { - 'index': 0, - 'content': { - 'role': 'model', - 'parts': [ - { - 'text': 'hello world', - }, - ], - }, - }, - ], - }) - model = vertexai.VertexAIGeminiPro1Vision() - message = model._generation_response_to_message(response) - self.assertEqual(message, lf.AIMessage('hello world')) - - def test_model_hub(self): - with mock.patch( - 'vertexai.generative_models.' - 'GenerativeModel.__init__' - ) as mock_model_init: - mock_model_init.side_effect = lambda *args, **kwargs: None - model = vertexai._VERTEXAI_MODEL_HUB.get_generative_model( - 'gemini-1.0-pro' - ) - self.assertIsNotNone(model) - self.assertIs( - vertexai._VERTEXAI_MODEL_HUB.get_generative_model('gemini-1.0-pro'), - model, - ) - - with mock.patch( - 'vertexai.language_models.' - 'TextGenerationModel.from_pretrained' - ) as mock_model_init: - - class TextGenerationModel: - pass - - mock_model_init.side_effect = lambda *args, **kw: TextGenerationModel() - model = vertexai._VERTEXAI_MODEL_HUB.get_text_generation_model( - 'text-bison' - ) - self.assertIsNotNone(model) - self.assertIs( - vertexai._VERTEXAI_MODEL_HUB.get_text_generation_model('text-bison'), - model, - ) - - def test_project_and_location_check(self): - with self.assertRaisesRegex(ValueError, 'Please specify `project`'): - _ = vertexai.VertexAIGeminiPro1Vision()._api_initialized - - with self.assertRaisesRegex(ValueError, 'Please specify `location`'): - _ = vertexai.VertexAIGeminiPro1Vision(project='abc')._api_initialized - - self.assertTrue( - vertexai.VertexAIGeminiPro1Vision( - project='abc', location='us-central1' - )._api_initialized - ) - - os.environ['VERTEXAI_PROJECT'] = 'abc' - os.environ['VERTEXAI_LOCATION'] = 'us-central1' - self.assertTrue(vertexai.VertexAIGeminiPro1Vision()._api_initialized) - del os.environ['VERTEXAI_PROJECT'] - del os.environ['VERTEXAI_LOCATION'] - - def test_generation_config(self): - model = vertexai.VertexAIGeminiPro1Vision() - json_schema = { - 'type': 'object', - 'properties': { - 'name': {'type': 'string'}, - }, - 'required': ['name'], - 'title': 'Person', - } - config = model._generation_config( - lf.UserMessage('hi', json_schema=json_schema), - lf.LMSamplingOptions( - temperature=2.0, - top_p=1.0, - top_k=20, - max_tokens=1024, - stop=['\n'], - ), - ) - actual = config.to_dict() - # There is a discrepancy between the `property_ordering` in the - # Google-internal version and the open-source version. 
- actual['response_schema'].pop('property_ordering', None) - if pg.KeyPath.parse('response_schema.type_').get(actual): - actual['response_schema']['type'] = actual['response_schema'].pop('type_') - if pg.KeyPath.parse('response_schema.properties.name.type_').get(actual): - actual['response_schema']['properties']['name']['type'] = actual[ - 'response_schema']['properties']['name'].pop('type_') - - self.assertEqual( - actual, - dict( - temperature=2.0, - top_p=1.0, - top_k=20.0, - max_output_tokens=1024, - stop_sequences=['\n'], - response_mime_type='application/json', - response_schema={ - 'type': 'OBJECT', - 'properties': { - 'name': {'type': 'STRING'} - }, - 'required': ['name'], - 'title': 'Person', - } - ), - ) - with self.assertRaisesRegex( - ValueError, '`json_schema` must be a dict, got' - ): - model._generation_config( - lf.UserMessage('hi', json_schema='not a dict'), - lf.LMSamplingOptions(), - ) - - def test_call_generative_model(self): - with mock.patch( - 'vertexai.generative_models.' - 'GenerativeModel.__init__' - ) as mock_model_init: - mock_model_init.side_effect = lambda *args, **kwargs: None - - with mock.patch( - 'vertexai.generative_models.' - 'GenerativeModel.generate_content' - ) as mock_generate: - mock_generate.side_effect = mock_generate_content - - lm = vertexai.VertexAIGeminiPro1Vision( - project='abc', location='us-central1' - ) - self.assertEqual( - lm( - 'hello', - temperature=2.0, - top_p=1.0, - top_k=20, - max_tokens=1024, - stop='\n', - ).text, - ( - 'This is a response to hello with temperature=2.0, ' - 'top_p=1.0, top_k=20.0, max_tokens=1024, stop=\n.' - ), - ) - - def test_call_text_generation_model(self): - with mock.patch( - 'vertexai.language_models.' - 'TextGenerationModel.from_pretrained' - ) as mock_model_init: - - class TextGenerationModel: - - def predict(self, prompt, **kwargs): - c = pg.Dict(kwargs) - return pg.Dict( - text=( - f'This is a response to {prompt} with ' - f'temperature={c.temperature}, ' - f'top_p={c.top_p}, ' - f'top_k={c.top_k}, ' - f'max_tokens={c.max_output_tokens}, ' - f'stop={"".join(c.stop_sequences)}.' - ) - ) - - mock_model_init.side_effect = lambda *args, **kw: TextGenerationModel() - lm = vertexai.VertexAIPalm2(project='abc', location='us-central1') - self.assertEqual( - lm( - 'hello', - temperature=2.0, - top_p=1.0, - top_k=20, - max_tokens=1024, - stop='\n', - ).text, - ( - 'This is a response to hello with temperature=2.0, ' - 'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.' 
- ), - ) - - def test_call_endpoint_model(self): - with mock.patch( - 'google.cloud.aiplatform.models.Endpoint.__init__' - ) as mock_model_init: - mock_model_init.side_effect = lambda *args, **kwargs: None - with mock.patch( - 'google.cloud.aiplatform.models.Endpoint.predict' - ) as mock_model_predict: - - mock_model_predict.side_effect = mock_endpoint_predict - lm = vertexai.VertexAI( - 'custom', - endpoint_name='123', - project='abc', - location='us-central1', - ) - self.assertEqual( - lm( - 'hello', - temperature=2.0, - top_p=1.0, - top_k=20, - max_tokens=50, - ), - 'This is a response to hello with temperature=2.0, top_p=1.0,' - ' top_k=20, max_tokens=50.', - ) - - -class VertexRestfulAITest(unittest.TestCase): """Tests for Vertex model with REST API.""" def test_content_from_message_text_only(self): @@ -376,9 +82,9 @@ def test_content_from_message_text_only(self): self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]}) def test_content_from_message_mm(self): + image = lf_modalities.Image.from_bytes(example_image) message = lf.UserMessage( - 'This is an <<[[image]]>>, what is it?', - image=lf_modalities.Image.from_bytes(example_image), + 'This is an <<[[image]]>>, what is it?', image=image ) # Non-multimodal model. @@ -386,47 +92,25 @@ def test_content_from_message_mm(self): vertexai.VertexAIGeminiPro1()._content_from_message(message) model = vertexai.VertexAIGeminiPro1Vision() - chunks = model._content_from_message(message) - self.maxDiff = None - self.assertEqual([chunks[0], chunks[2]], ['This is an', ', what is it?']) - self.assertIsInstance(chunks[1], generative_models.Part) - - def test_generation_response_to_message_text_only(self): - response = generative_models.GenerationResponse.from_dict({ - 'candidates': [ - { - 'index': 0, - 'content': { - 'role': 'model', - 'parts': [ - { - 'text': 'hello world', - }, - ], + content = model._content_from_message(message) + self.assertEqual( + content, + { + 'role': 'user', + 'parts': [ + {'text': 'This is an'}, + { + 'inlineData': { + 'data': base64.b64encode(example_image).decode(), + 'mimeType': 'image/png', + } }, - }, - ], - }) - model = vertexai.VertexAIGeminiPro1Vision() - message = model._generation_response_to_message(response) - self.assertEqual(message, lf.AIMessage('hello world')) - - def test_model_hub(self): - with mock.patch( - 'vertexai.generative_models.' 
- 'GenerativeModel.__init__' - ) as mock_model_init: - mock_model_init.side_effect = lambda *args, **kwargs: None - model = vertexai._VERTEXAI_MODEL_HUB.get_generative_model( - 'gemini-1.0-pro' - ) - self.assertIsNotNone(model) - self.assertIs( - vertexai._VERTEXAI_MODEL_HUB.get_generative_model('gemini-1.0-pro'), - model, - ) + {'text': ', what is it?'}, + ], + }, + ) - @mock.patch.object(vertexai.VertexAIRest, 'credentials', new=True) + @mock.patch.object(vertexai.VertexAI, 'credentials', new=True) def test_project_and_location_check(self): with self.assertRaisesRegex(ValueError, 'Please specify `project`'): _ = vertexai.VertexAIGeminiPro1()._api_initialized @@ -497,7 +181,7 @@ def test_generation_config(self): lf.LMSamplingOptions(), ) - @mock.patch.object(vertexai.VertexAIRest, 'credentials', new=True) + @mock.patch.object(vertexai.VertexAI, 'credentials', new=True) def test_call_model(self): with mock.patch('requests.Session.post') as mock_generate: mock_generate.side_effect = mock_requests_post diff --git a/requirements.txt b/requirements.txt index 4577f83..440c8c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,8 +6,6 @@ requests>=2.31.0 termcolor==1.1.0 tqdm>=4.64.1 -# extras:llm-google-vertex -google-cloud-aiplatform>=1.5.0 # extras:llm-google-genai google-generativeai>=0.3.2
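
A minimal usage sketch of the REST-based classes this patch keeps and adds, matching the calls exercised in vertexai_test.py. Illustrative only: the project ID, location, and endpoint ID below are placeholders, and Application Default Credentials are assumed to be configured.

    from langfun.core.llms import vertexai

    # Gemini 1.5 Pro over the Vertex AI REST API. `project` and `location`
    # can also be supplied via the VERTEXAI_PROJECT / VERTEXAI_LOCATION
    # environment variables, as checked in test_project_and_location_check.
    lm = vertexai.VertexAIGeminiPro1_5_002(
        project='my-project', location='us-central1'
    )
    print(lm('Hello, who are you?').text)

    # A custom model deployed behind a Vertex AI endpoint now goes through
    # the new VertexAIEndpoint class; '1234567890' is a placeholder ID.
    custom_lm = vertexai.VertexAIEndpoint(
        endpoint='1234567890',
        project='my-project',
        location='us-central1',
    )
    print(custom_lm('Hello, who are you?').text)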