Skip to content

Commit

Permalink
Enable Google authentication for VertexAI Anthropic models.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694354376
  • Loading branch information
daiyip authored and langfun authors committed Nov 8, 2024
1 parent 3f6313c commit 7689d97
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 35 deletions.
64 changes: 43 additions & 21 deletions langfun/core/llms/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Language models from Anthropic."""

import base64
import functools
import os
from typing import Annotated, Any, Literal

Expand All @@ -23,6 +24,20 @@
import pyglove as pg


try:
# pylint: disable=g-import-not-at-top
from google import auth as google_auth
from google.auth import credentials as credentials_lib
from google.auth.transport import requests as auth_requests
Credentials = credentials_lib.Credentials
# pylint: enable=g-import-not-at-top
except ImportError:
google_auth = None
auth_requests = None
credentials_lib = None
Credentials = None # pylint: disable=invalid-name


SUPPORTED_MODELS_AND_SETTINGS = {
# See https://docs.anthropic.com/claude/docs/models-overview
# Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
Expand Down Expand Up @@ -360,15 +375,6 @@ class ClaudeInstant(Anthropic):
class VertexAIAnthropic(Anthropic):
"""Anthropic models on VertexAI."""

access_token: Annotated[
str | None,
(
'Google Cloud access token. If None, it will be read from '
'environment variable \'GCP_ACCESS_TOKEN\'. '
'Get it by running `gcloud auth print-access-token`.'
)
] = None

project: Annotated[
str | None,
'Google Cloud project ID.',
Expand All @@ -379,12 +385,24 @@ class VertexAIAnthropic(Anthropic):
'GCP location with Anthropic models hosted.'
] = 'us-east5'

credentials: Annotated[
Credentials | None, # pytype: disable=invalid-annotation
(
'Credentials to use. If None, the default credentials '
'to the environment will be used.'
),
] = None

api_version = 'vertex-2023-10-16'

def _on_bound(self):
super()._on_bound()
if google_auth is None:
raise ValueError(
'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
)
self._project = None
self._access_token = None
self._credentials = None

def _initialize(self):
project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
Expand All @@ -394,22 +412,26 @@ def _initialize(self):
'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
)
self._project = project

access_token = self.access_token or os.environ.get(
'GCP_ACCESS_TOKEN', None
)
if not access_token:
raise ValueError(
'Please specify `access_token` during `__init__` or set environment '
'variable `GCP_ACCESS_TOKEN` with the output of '
'`gcloud auth print-access-token`.'
credentials = self.credentials
if credentials is None:
# Use default credentials.
credentials = google_auth.default(
scopes=['https://www.googleapis.com/auth/cloud-platform']
)
self._access_token = access_token
self._credentials = credentials

@functools.cached_property
def _session(self):
assert self._api_initialized
assert self._credentials is not None
assert auth_requests is not None
s = auth_requests.AuthorizedSession(self._credentials)
s.headers.update(self.headers or {})
return s

@property
def headers(self):
return {
'Authorization': f'Bearer {self._access_token}',
'Content-Type': 'application/json; charset=utf-8',
}

Expand Down
25 changes: 11 additions & 14 deletions langfun/core/llms/anthropic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from typing import Any
import unittest
from unittest import mock

from google.auth import exceptions
from langfun.core import language_model
from langfun.core import message as lf_message
from langfun.core import modalities as lf_modalities
Expand Down Expand Up @@ -192,14 +194,16 @@ def test_basics(self):
lm = anthropic.VertexAIClaude3_5_Sonnet_20241022()
lm('hi')

with self.assertRaisesRegex(ValueError, 'Please specify `access_token`'):
lm = anthropic.VertexAIClaude3_5_Sonnet_20241022(project='langfun')
lm('hi')
model = anthropic.VertexAIClaude3_5_Sonnet_20241022(project='langfun')

# NOTE(daiyip): For OSS users, default credentials are not available unless
# users have already set up their GCP project. Therefore we ignore the
# exception here.
try:
model._initialize()
except exceptions.DefaultCredentialsError:
pass

model = anthropic.VertexAIClaude3_5_Sonnet_20241022(
project='langfun', access_token='my_token'
)
model._initialize()
self.assertEqual(
model.api_endpoint,
(
Expand All @@ -208,13 +212,6 @@ def test_basics(self):
'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
)
)
self.assertEqual(
model.headers,
{
'Authorization': 'Bearer my_token',
'Content-Type': 'application/json; charset=utf-8',
}
)
request = model.request(
lf_message.UserMessage('hi'),
language_model.LMSamplingOptions(temperature=0.0),
Expand Down

0 comments on commit 7689d97

Please sign in to comment.