
Refactor token handling and update headers (#787)
* ✨ refactor token handling and update headers

* refactor: πŸ”„ update API version for chat completions

* πŸ”’ Comment out unsupported Azure AI token logic

* differenciate gptvsllama

* ♻️ Update token scopes for Azure AI constants

* ✨ Update Azure model deployment and candidates
pelikhan authored Oct 21, 2024
1 parent 25e80a8 commit 071ce9c
Showing 9 changed files with 207 additions and 97 deletions.
115 changes: 107 additions & 8 deletions docs/src/content/docs/getting-started/configuration.mdx
@@ -271,14 +271,14 @@ to try the Azure OpenAI service.

<li>

-Open your [Azure OpenAI resource](https://portal.azure.com)
+Open your Azure OpenAI resource in the [Azure Portal](https://portal.azure.com)

</li>
<li>

-Navigate to **Access Control**, then **View My Access**. Make sure your
+Navigate to **Access Control (IAM)**, then **View My Access**. Make sure your
user or service principal has the **Cognitive Services OpenAI User/Contributor** role.
-If you get a `401` error, it's typically here that you will fix it.
+If you get a `401` error, click on **Add**, **Add role assignment** and add the **Cognitive Services OpenAI User** role to your user.

</li>
<li>
@@ -370,9 +370,107 @@ The rest of the steps are the same: Find the deployment name and use it in your

The `azure_serverless` provider supports models from the Azure AI model catalog that can be deployed as [a serverless API](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless-availability) with pay-as-you-go billing.
This kind of deployment provides a way to consume models
-as an API without hosting them on your subscription,
-while keeping the enterprise security and compliance that organizations need.
-This deployment option doesn't require quota from your subscription.
+as an API without hosting them on your subscription, while keeping the enterprise security and compliance that organizations need.

Note that the OpenAI models, like gpt-4o..., are deployed to `.openai.azure.com` endpoints,
while the Azure AI models are deployed to `.models.ai.azure.com` endpoints.
They are configured slightly differently.
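
This distinction matters for authentication: the commit keys its token logic off the endpoint host. A minimal sketch of that check (the `isOpenAIEndpoint` helper name is illustrative and not part of the codebase; the regex mirrors the one the commit adds to `nodehost.ts`):

```typescript
// Distinguish Azure OpenAI endpoints from Azure AI model endpoints by host.
// The regex matches the test used in nodehost.ts in this commit.
function isOpenAIEndpoint(base: string): boolean {
    return /\.openai\.azure\.com/i.test(base)
}

console.log(isOpenAIEndpoint("https://contoso.openai.azure.com")) // true
console.log(isOpenAIEndpoint("https://mistral.models.ai.azure.com")) // false
```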

### Managed Identity (Entra ID)

<Steps>

<ol>

<li>

Open your **Azure AI Project** resource in the [Azure Portal](https://portal.azure.com)

</li>
<li>

Navigate to **Access Control (IAM)**, then **View My Access**. Make sure your
user or service principal has the **Azure AI Developer** role.
If you get a `401` error, click on **Add**, **Add role assignment** and add the **Azure AI Developer** role to your user.

</li>

<li>

Open a terminal and **login** with [Azure CLI](https://learn.microsoft.com/en-us/javascript/api/overview/azure/identity-readme?view=azure-node-latest#authenticate-via-the-azure-cli).

```sh
az login
```

</li>

<li>

Go to https://ai.azure.com/ and open the **Deployments** page.

</li>

<li>

Deploy a **base model** from the catalog.
You can use the `Deployment Options` -> `Serverless API` option to deploy a model as a serverless API.

</li>

</ol>

</Steps>

The OpenAI models (gpt-4o, ...) are deployed to `.openai.azure.com` endpoints,
while the other models are deployed to `.models.ai.azure.com` endpoints.

### `.models.ai.azure.com` endpoints

<Steps>

<ol>

<li>

Configure the **Endpoint Target URL** as the `AZURE_INFERENCE_ENDPOINT`.

```txt title=".env"
AZURE_INFERENCE_ENDPOINT=https://...models.ai.azure.com
```

</li>

<li>

Navigate to **Deployments**, make sure your LLM is deployed, and copy the deployment name; you will need it in the script.

</li>

<li>

Update the `model` field in the `script` function to match the model deployment name in your Azure resource.

```js 'model: "azure_serverless:deployment-info-name"'
script({
model: "azure_serverless:deployment-info-name",
...
})
```

</li>

</ol>

</Steps>

#### Support for multiple inference deployments

For non-OpenAI models deployed on `.models.ai.azure.com`,
you can keep the same `AZURE_INFERENCE_ENDPOINT` and GenAIScript will automatically update the endpoint
with the deployment name.

For OpenAI models deployed on `.openai.azure.com`, you can also keep the same deployment name.
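
One way to picture that endpoint rewrite is a small helper that swaps the first host label for the deployment name. This is a sketch under that assumption only; `deploymentUrl` is a hypothetical name and this is not GenAIScript's actual implementation:

```typescript
// Hypothetical sketch: derive a per-deployment URL from a shared
// AZURE_INFERENCE_ENDPOINT by replacing the leading host label with
// the deployment name. Not the actual GenAIScript code.
function deploymentUrl(endpoint: string, deployment: string): string {
    const url = new URL(endpoint)
    const labels = url.hostname.split(".")
    labels[0] = deployment // swap in the deployment's host label
    url.hostname = labels.join(".")
    return url.toString()
}
```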

### API Key

@@ -426,14 +426,15 @@ GENAISCRIPT_DEFAULT_SMALL_MODEL=azure_serverless:<deploymentid>

:::

-### Support for multiple inference deployments
+#### Support for multiple inference deployments

You can update the `AZURE_INFERENCE_CREDENTIAL` with a list of `deploymentid=key` pairs to support multiple deployments (each deployment has a different key).

```txt title=".env"
AZURE_INFERENCE_CREDENTIAL="
model1=key1
-model2=key2model3=key3
+model2=key2
+model3=key3
"
```
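
A sketch of how such a multi-line value could be split into per-deployment keys (the `parseCredentials` helper below is an assumption for illustration, not the library's actual parser):

```typescript
// Parse an AZURE_INFERENCE_CREDENTIAL value holding multiple
// `deploymentid=key` pairs, one per line. Illustrative only.
function parseCredentials(value: string): Record<string, string> {
    const result: Record<string, string> = {}
    for (const line of value.split("\n")) {
        const m = /^\s*([^=\s]+)\s*=\s*(\S+)\s*$/.exec(line)
        if (m) result[m[1]] = m[2]
    }
    return result
}

const creds = parseCredentials(`
model1=key1
model2=key2
model3=key3
`)
// creds.model2 now resolves to "key2"
```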

14 changes: 6 additions & 8 deletions packages/cli/src/azuretoken.ts
@@ -1,7 +1,4 @@
-import {
-    AZURE_OPENAI_TOKEN_EXPIRATION,
-    AZURE_OPENAI_TOKEN_SCOPES,
-} from "../../core/src/constants"
+import { AZURE_TOKEN_EXPIRATION } from "../../core/src/constants"
import { logVerbose } from "../../core/src/util"

/**
@@ -41,15 +38,16 @@ export function isAzureTokenExpired(token: AuthenticationToken) {
* Logs the expiration time of the token for debugging or informational purposes.
*/
export async function createAzureToken(
-    signal: AbortSignal
+    scopes: readonly string[],
+    abortSignal: AbortSignal
): Promise<AuthenticationToken> {
// Dynamically import DefaultAzureCredential from the Azure SDK
const { DefaultAzureCredential } = await import("@azure/identity")

// Obtain the Azure token using the DefaultAzureCredential
const azureToken = await new DefaultAzureCredential().getToken(
-        AZURE_OPENAI_TOKEN_SCOPES.slice(),
-        { abortSignal: signal }
+        scopes.slice(),
+        { abortSignal }
)

// Prepare the result token object with the token and expiration timestamp
@@ -58,7 +56,7 @@ export async function createAzureToken(
// Use provided expiration timestamp or default to a constant expiration time
expiresOnTimestamp: azureToken.expiresOnTimestamp
? azureToken.expiresOnTimestamp
-            : Date.now() + AZURE_OPENAI_TOKEN_EXPIRATION,
+            : Date.now() + AZURE_TOKEN_EXPIRATION,
}

// Log the expiration time of the token
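The guard around token reuse can be sketched as follows (this mirrors the intent of `isAzureTokenExpired` in `azuretoken.ts`; the exact comparison in the real code may differ):

```typescript
// Minimal sketch of the expiry check that decides whether a cached
// token can be reused or a fresh one must be fetched.
interface AuthenticationToken {
    token: string
    expiresOnTimestamp: number
}

function isTokenExpired(token: AuthenticationToken | undefined): boolean {
    // a missing token counts as expired, forcing a fresh fetch
    return !token || token.expiresOnTimestamp <= Date.now()
}
```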
58 changes: 41 additions & 17 deletions packages/cli/src/nodehost.ts
@@ -26,6 +26,9 @@ import {
TOOL_ID,
DEFAULT_EMBEDDINGS_MODEL,
DEFAULT_SMALL_MODEL,
+    AZURE_OPENAI_TOKEN_SCOPES,
+    MODEL_PROVIDER_AZURE_SERVERLESS,
+    AZURE_AI_INFERENCE_TOKEN_SCOPES,
} from "../../core/src/constants"
import { tryReadText } from "../../core/src/fs"
import {
@@ -87,8 +90,8 @@ class ModelManager implements ModelService {
const res = await fetch(`${conn.base}/api/pull`, {
method: "POST",
headers: {
"user-agent": TOOL_ID,
"content-type": "application/json",
"User-Agent": TOOL_ID,
"Content-Type": "application/json",
},
body: JSON.stringify({ name: model, stream: false }, null, 2),
})
@@ -159,7 +162,8 @@ export class NodeHost implements RuntimeHost {
}
clientLanguageModel: LanguageModel

-    private _azureToken: AuthenticationToken
+    private _azureOpenAIToken: AuthenticationToken
+    private _azureServerlessToken: AuthenticationToken
async getLanguageModelConfiguration(
modelId: string,
options?: { token?: boolean } & AbortSignalOptions & TraceOptions
@@ -168,25 +172,45 @@
await this.parseDefaults()
const tok = await parseTokenFromEnv(process.env, modelId)
if (!askToken && tok?.token) tok.token = "***"
-        if (
-            askToken &&
-            tok &&
-            !tok.token &&
-            tok.provider === MODEL_PROVIDER_AZURE // MODEL_PROVIDER_AZURE_SERVERLESS does not support Entra yet
-        ) {
-            if (isAzureTokenExpired(this._azureToken)) {
-                logVerbose(
-                    `fetching azure token (${this._azureToken?.expiresOnTimestamp >= Date.now() ? `expired ${new Date(this._azureToken.expiresOnTimestamp).toLocaleString()}` : "not available"})`
-                )
-                this._azureToken = await createAzureToken(signal)
-            }
-            if (!this._azureToken) throw new Error("Azure token not available")
-            tok.token = "Bearer " + this._azureToken.token
-        }
+        if (askToken && tok && !tok.token) {
+            if (
+                tok.provider === MODEL_PROVIDER_AZURE ||
+                (tok.provider === MODEL_PROVIDER_AZURE_SERVERLESS &&
+                    /\.openai\.azure\.com/i.test(tok.base))
+            ) {
+                if (isAzureTokenExpired(this._azureOpenAIToken)) {
+                    logVerbose(
+                        `fetching Azure OpenAI token ${this._azureOpenAIToken?.expiresOnTimestamp >= Date.now() ? `(expired ${new Date(this._azureOpenAIToken.expiresOnTimestamp).toLocaleString()})` : ""}`
+                    )
+                    this._azureOpenAIToken = await createAzureToken(
+                        AZURE_OPENAI_TOKEN_SCOPES,
+                        signal
+                    )
+                }
+                if (!this._azureOpenAIToken)
+                    throw new Error("Azure OpenAI token not available")
+                tok.token = "Bearer " + this._azureOpenAIToken.token
+            } else if (tok.provider === MODEL_PROVIDER_AZURE_SERVERLESS) {
+                if (isAzureTokenExpired(this._azureServerlessToken)) {
+                    logVerbose(
+                        `fetching Azure AI Inference token ${this._azureServerlessToken?.expiresOnTimestamp >= Date.now() ? `(expired ${new Date(this._azureServerlessToken.expiresOnTimestamp).toLocaleString()})` : ""}`
+                    )
+                    this._azureServerlessToken = await createAzureToken(
+                        AZURE_AI_INFERENCE_TOKEN_SCOPES,
+                        signal
+                    )
+                }
+                if (!this._azureServerlessToken)
+                    throw new Error("Azure AI Inference token not available")
+                tok.token = "Bearer " + this._azureServerlessToken.token
+            }
+        }
if (!tok) {
const { provider } = parseModelIdentifier(modelId)
if (provider === MODEL_PROVIDER_AZURE)
throw new Error("Azure end point not configured")
throw new Error("Azure OpenAI end point not configured")
else if (provider === MODEL_PROVIDER_AZURE_SERVERLESS)
throw new Error("Azure AI Inference end point not configured")
}
if (!tok && this.clientLanguageModel) {
return <LanguageModelConfiguration>{
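The branching above boils down to picking a scope set per provider and endpoint. A condensed sketch (the scope strings here are placeholders, not the real values of the constants in `core/src/constants`):

```typescript
type Provider = "azure" | "azure_serverless"

// Placeholder scope values; the real ones live in core/src/constants.
const OPENAI_SCOPES = ["<azure-openai-scope>"]
const INFERENCE_SCOPES = ["<azure-ai-inference-scope>"]

function scopesFor(provider: Provider, base: string): readonly string[] {
    // Azure OpenAI deployments (and serverless deployments hosted on
    // *.openai.azure.com) use the OpenAI scopes; other serverless
    // deployments use the Azure AI Inference scopes.
    if (provider === "azure" || /\.openai\.azure\.com/i.test(base))
        return OPENAI_SCOPES
    return INFERENCE_SCOPES
}
```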
8 changes: 4 additions & 4 deletions packages/core/src/aici.ts
@@ -273,8 +273,8 @@ const AICIChatCompletion: ChatCompletionHandler = async (
const r = await fetchRetry(url, {
headers: {
"api-key": connection.token,
"user-agent": TOOL_ID,
"content-type": "application/json",
"User-Agent": TOOL_ID,
"Content-Type": "application/json",
...(headers || {}),
},
body,
@@ -426,8 +426,8 @@ async function listModels(cfg: LanguageModelConfiguration) {
method: "GET",
headers: {
"api-key": token,
"user-agent": TOOL_ID,
accept: "application/json",
"User-Agent": TOOL_ID,
Accept: "application/json",
},
})
if (res.status !== 200) return []
