From 071ce9caf14dc22bd290fd4f7ac12fe3844558da Mon Sep 17 00:00:00 2001
From: Peli de Halleux
Date: Mon, 21 Oct 2024 13:17:01 -0700
Subject: [PATCH] Refactor token handling and update headers (#787)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* ✨ refactor token handling and update headers
* refactor: 🔄 update API version for chat completions
* 🔒 Comment out unsupported Azure AI token logic
* differentiate gpt vs llama
* ♻️ Update token scopes for Azure AI constants
* ✨ Update Azure model deployment and candidates
---
 .../docs/getting-started/configuration.mdx | 115 ++++++++++++++++--
 packages/cli/src/azuretoken.ts             |  14 +--
 packages/cli/src/nodehost.ts               |  58 ++++++---
 packages/core/src/aici.ts                  |   8 +-
 packages/core/src/connection.ts            |  49 ++------
 packages/core/src/constants.ts             |  11 +-
 packages/core/src/fetch.ts                 |   3 +-
 packages/core/src/host.ts                  |   4 +-
 packages/core/src/openai.ts                |  42 ++++---
 9 files changed, 207 insertions(+), 97 deletions(-)

diff --git a/docs/src/content/docs/getting-started/configuration.mdx b/docs/src/content/docs/getting-started/configuration.mdx
index 1a4346369f..3ac9cfdbfb 100644
--- a/docs/src/content/docs/getting-started/configuration.mdx
+++ b/docs/src/content/docs/getting-started/configuration.mdx
@@ -271,14 +271,14 @@ to try the Azure OpenAI service.
-Open your [Azure OpenAI resource](https://portal.azure.com)
+Open your Azure OpenAI resource in the [Azure Portal](https://portal.azure.com)
-Navigate to **Access Control**, then **View My Access**. Make sure your
+Navigate to **Access Control (IAM)**, then **View My Access**. Make sure your
 user or service principal has the **Cognitive Services OpenAI User/Contributor** role.
-If you get a `401` error, it's typically here that you will fix it.
+If you get a `401` error, click on **Add**, **Add role assignment** and add the **Cognitive Services OpenAI User** role to your user.
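If you prefer to script the role assignment, the equivalent Azure CLI command is sketched below; the assignee and scope values are placeholders you would substitute:

```sh
az role assignment create \
  --role "Cognitive Services OpenAI User" \
  --assignee "<user-or-service-principal-id>" \
  --scope "/subscriptions/<subscription-id>/resourceGroups/<resource-group>/providers/Microsoft.CognitiveServices/accounts/<resource-name>"
```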
@@ -370,9 +370,107 @@ The rest of the steps are the same: Find the deployment name and use it in your
 
 The `azure_serverless` provider supports models from the Azure AI model catalog that can be deployed as
 [a serverless API](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless-availability)
 with pay-as-you-go billing.
 This kind of deployment provides a way to consume models
-as an API without hosting them on your subscription,
-while keeping the enterprise security and compliance that organizations need.
-This deployment option doesn't require quota from your subscription.
+as an API without hosting them on your subscription, while keeping the enterprise security and compliance that organizations need.
+
+Note that the OpenAI models, like gpt-4o..., are deployed to `.openai.azure.com` endpoints,
+while the Azure AI models are deployed to `.models.ai.azure.com` endpoints.
+They are configured slightly differently.
+
+### Managed Identity (Entra ID)
+
+<Steps>
+
+1. Open your **Azure AI Project** resource in the [Azure Portal](https://portal.azure.com)
+
+2. Navigate to **Access Control (IAM)**, then **View My Access**. Make sure your
+   user or service principal has the **Azure AI Developer** role.
+   If you get a `401` error, click on **Add**, **Add role assignment** and add the **Azure AI Developer** role to your user.
+
+3. Open a terminal and **login** with [Azure CLI](https://learn.microsoft.com/en-us/javascript/api/overview/azure/identity-readme?view=azure-node-latest#authenticate-via-the-azure-cli).
+
+   ```sh
+   az login
+   ```
+
+4. Open https://ai.azure.com/ and open the **Deployments** page.
+
+5. Deploy a **base model** from the catalog.
+   You can use the `Deployment Options` -> `Serverless API` option to deploy a model as a serverless API.
+
+</Steps>
+
+The OpenAI models (gpt-4o, ...) are deployed to `.openai.azure.com` endpoints,
+the other models are deployed to `.models.ai.azure.com` endpoints.
+
+### `.models.ai.azure.com` endpoints
+
+<Steps>
+
+1. Configure the **Endpoint Target URL** as the `AZURE_INFERENCE_ENDPOINT`.
+
+   ```txt title=".env"
+   AZURE_INFERENCE_ENDPOINT=https://...models.ai.azure.com
+   ```
+
+2. Navigate to **deployments** and make sure that you have your LLM deployed. Copy the Deployment Info name; you will need it in the script.
+
+3. Update the `model` field in the `script` function to match the model deployment name in your Azure resource.
+
+   ```js 'model: "azure_serverless:deployment-info-name"'
+   script({
+       model: "azure_serverless:deployment-info-name",
+       ...
+   })
+   ```
+
+</Steps>
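To summarize the two configurations side by side, here is a hypothetical `.env` showing both endpoint shapes; the resource, deployment, and region names are placeholders:

```txt title=".env"
# OpenAI models hosted on an Azure OpenAI resource
AZURE_OPENAI_API_ENDPOINT=https://contoso.openai.azure.com
# other catalog models deployed as serverless APIs
AZURE_INFERENCE_ENDPOINT=https://my-model-deployment.eastus2.models.ai.azure.com
```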
+#### Support for multiple inference deployments
+
+For non-OpenAI models deployed on `.models.ai.azure.com`,
+you can keep the same `AZURE_INFERENCE_ENDPOINT` and GenAIScript will automatically update the endpoint
+with the deployment id name.
+
+For OpenAI models deployed on `.openai.azure.com`, you can also keep the same deployment name.
 
 ### API Key
 
@@ -426,14 +524,15 @@ GENAISCRIPT_DEFAULT_SMALL_MODEL=azure_serverless:
 
 :::
 
-### Support for multiple inference deployements
+#### Support for multiple inference deployments
 
 You can update the `AZURE_INFERENCE_CREDENTIAL`
 with a list of `deploymentid=key` pairs to support multiple deployments (each deployment has a different key).
 
 ```txt title=".env"
 AZURE_INFERENCE_CREDENTIAL="
 model1=key1
-model2=key2model3=key3
+model2=key2
+model3=key3
 "
 ```

diff --git a/packages/cli/src/azuretoken.ts b/packages/cli/src/azuretoken.ts
index 15a6cb122f..cf2a148411 100644
--- a/packages/cli/src/azuretoken.ts
+++ b/packages/cli/src/azuretoken.ts
@@ -1,7 +1,4 @@
-import {
-    AZURE_OPENAI_TOKEN_EXPIRATION,
-    AZURE_OPENAI_TOKEN_SCOPES,
-} from "../../core/src/constants"
+import { AZURE_TOKEN_EXPIRATION } from "../../core/src/constants"
 import { logVerbose } from "../../core/src/util"
 
 /**
@@ -41,15 +38,16 @@ export function isAzureTokenExpired(token: AuthenticationToken) {
  * Logs the expiration time of the token for debugging or informational purposes.
  */
 export async function createAzureToken(
-    signal: AbortSignal
+    scopes: readonly string[],
+    abortSignal: AbortSignal
 ): Promise<AuthenticationToken> {
     // Dynamically import DefaultAzureCredential from the Azure SDK
     const { DefaultAzureCredential } = await import("@azure/identity")
 
     // Obtain the Azure token using the DefaultAzureCredential
     const azureToken = await new DefaultAzureCredential().getToken(
-        AZURE_OPENAI_TOKEN_SCOPES.slice(),
-        { abortSignal: signal }
+        scopes.slice(),
+        { abortSignal }
     )
 
     // Prepare the result token object with the token and expiration timestamp
@@ -58,7 +56,7 @@
         // Use provided expiration timestamp or default to a constant expiration time
         expiresOnTimestamp: azureToken.expiresOnTimestamp
             ? azureToken.expiresOnTimestamp
-            : Date.now() + AZURE_OPENAI_TOKEN_EXPIRATION,
+            : Date.now() + AZURE_TOKEN_EXPIRATION,
     }
 
     // Log the expiration time of the token
diff --git a/packages/cli/src/nodehost.ts b/packages/cli/src/nodehost.ts
index 8fcc0bcdf4..3c42f72d60 100644
--- a/packages/cli/src/nodehost.ts
+++ b/packages/cli/src/nodehost.ts
@@ -26,6 +26,9 @@ import {
     TOOL_ID,
     DEFAULT_EMBEDDINGS_MODEL,
     DEFAULT_SMALL_MODEL,
+    AZURE_OPENAI_TOKEN_SCOPES,
+    MODEL_PROVIDER_AZURE_SERVERLESS,
+    AZURE_AI_INFERENCE_TOKEN_SCOPES,
 } from "../../core/src/constants"
 import { tryReadText } from "../../core/src/fs"
 import {
@@ -87,8 +90,8 @@ class ModelManager implements ModelService {
             const res = await fetch(`${conn.base}/api/pull`, {
                 method: "POST",
                 headers: {
-                    "user-agent": TOOL_ID,
-                    "content-type": "application/json",
+                    "User-Agent": TOOL_ID,
+                    "Content-Type": "application/json",
                 },
                 body: JSON.stringify({ name: model, stream: false }, null, 2),
             })
@@ -159,7 +162,8 @@ export class NodeHost implements RuntimeHost {
     }
 
     clientLanguageModel: LanguageModel
-    private _azureToken: AuthenticationToken
+    private _azureOpenAIToken: AuthenticationToken
+    private _azureServerlessToken: AuthenticationToken
 
     async getLanguageModelConfiguration(
         modelId: string,
         options?: { token?: boolean } & AbortSignalOptions & TraceOptions
@@ -168,25 +172,45 @@
         await this.parseDefaults()
         const tok = await parseTokenFromEnv(process.env, modelId)
         if (!askToken && tok?.token) tok.token = "***"
-        if (
-            askToken &&
-            tok &&
-            !tok.token &&
-            tok.provider === MODEL_PROVIDER_AZURE // MODEL_PROVIDER_AZURE_SERVERLESS does not support Entra yet
-        ) {
-            if (isAzureTokenExpired(this._azureToken)) {
-                logVerbose(
-                    `fetching azure token (${this._azureToken?.expiresOnTimestamp >= Date.now() ? `expired ${new Date(this._azureToken.expiresOnTimestamp).toLocaleString()}` : "not available"})`
-                )
-                this._azureToken = await createAzureToken(signal)
+        if (askToken && tok && !tok.token) {
+            if (
+                tok.provider === MODEL_PROVIDER_AZURE ||
+                (tok.provider === MODEL_PROVIDER_AZURE_SERVERLESS &&
+                    /\.openai\.azure\.com/i.test(tok.base))
+            ) {
+                if (isAzureTokenExpired(this._azureOpenAIToken)) {
+                    logVerbose(
+                        `fetching Azure OpenAI token ${this._azureOpenAIToken?.expiresOnTimestamp >= Date.now() ? `(expired ${new Date(this._azureOpenAIToken.expiresOnTimestamp).toLocaleString()})` : ""}`
+                    )
+                    this._azureOpenAIToken = await createAzureToken(
+                        AZURE_OPENAI_TOKEN_SCOPES,
+                        signal
+                    )
+                }
+                if (!this._azureOpenAIToken)
+                    throw new Error("Azure OpenAI token not available")
+                tok.token = "Bearer " + this._azureOpenAIToken.token
+            } else if (tok.provider === MODEL_PROVIDER_AZURE_SERVERLESS) {
+                if (isAzureTokenExpired(this._azureServerlessToken)) {
+                    logVerbose(
+                        `fetching Azure AI Inference token ${this._azureServerlessToken?.expiresOnTimestamp >= Date.now() ? `(expired ${new Date(this._azureServerlessToken.expiresOnTimestamp).toLocaleString()})` : ""}`
+                    )
+                    this._azureServerlessToken = await createAzureToken(
+                        AZURE_AI_INFERENCE_TOKEN_SCOPES,
+                        signal
+                    )
+                }
+                if (!this._azureServerlessToken)
+                    throw new Error("Azure AI Inference token not available")
+                tok.token = "Bearer " + this._azureServerlessToken.token
             }
-            if (!this._azureToken) throw new Error("Azure token not available")
-            tok.token = "Bearer " + this._azureToken.token
         }
         if (!tok) {
             const { provider } = parseModelIdentifier(modelId)
             if (provider === MODEL_PROVIDER_AZURE)
-                throw new Error("Azure end point not configured")
+                throw new Error("Azure OpenAI endpoint not configured")
+            else if (provider === MODEL_PROVIDER_AZURE_SERVERLESS)
+                throw new Error("Azure AI Inference endpoint not configured")
         }
         if (!tok && this.clientLanguageModel) {
             return {
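The new branching in `getLanguageModelConfiguration` reduces to picking a scope set from the provider and endpoint host, then passing it to the refactored `createAzureToken`. A condensed sketch (the helper names are hypothetical; the patch inlines this logic):

```ts
import {
    AZURE_OPENAI_TOKEN_SCOPES,
    AZURE_AI_INFERENCE_TOKEN_SCOPES,
    MODEL_PROVIDER_AZURE,
    MODEL_PROVIDER_AZURE_SERVERLESS,
} from "../../core/src/constants"
import { createAzureToken } from "./azuretoken"

// Azure OpenAI resources (including serverless configs pointing at
// *.openai.azure.com) use the Cognitive Services scope; serverless
// *.models.ai.azure.com deployments use the Azure ML scope.
function scopesFor(provider: string, base: string): readonly string[] {
    const isAzureOpenAI =
        provider === MODEL_PROVIDER_AZURE ||
        (provider === MODEL_PROVIDER_AZURE_SERVERLESS &&
            /\.openai\.azure\.com/i.test(base))
    return isAzureOpenAI
        ? AZURE_OPENAI_TOKEN_SCOPES // https://cognitiveservices.azure.com/.default
        : AZURE_AI_INFERENCE_TOKEN_SCOPES // https://ml.azure.com/.default
}

// usage sketch with the refactored createAzureToken(scopes, signal)
async function tokenFor(provider: string, base: string, signal: AbortSignal) {
    return await createAzureToken(scopesFor(provider, base), signal)
}
```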
diff --git a/packages/core/src/aici.ts b/packages/core/src/aici.ts
index 6eacf95e30..1d78e58778 100644
--- a/packages/core/src/aici.ts
+++ b/packages/core/src/aici.ts
@@ -273,8 +273,8 @@ const AICIChatCompletion: ChatCompletionHandler = async (
     const r = await fetchRetry(url, {
         headers: {
             "api-key": connection.token,
-            "user-agent": TOOL_ID,
-            "content-type": "application/json",
+            "User-Agent": TOOL_ID,
+            "Content-Type": "application/json",
             ...(headers || {}),
         },
         body,
@@ -426,8 +426,8 @@ async function listModels(cfg: LanguageModelConfiguration) {
         method: "GET",
         headers: {
             "api-key": token,
-            "user-agent": TOOL_ID,
-            accept: "application/json",
+            "User-Agent": TOOL_ID,
+            Accept: "application/json",
         },
     })
     if (res.status !== 200) return []
diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts
index cdc05f9fbe..9086ef5ed0 100644
--- a/packages/core/src/connection.ts
+++ b/packages/core/src/connection.ts
@@ -1,15 +1,7 @@
 import {
     ANTHROPIC_API_BASE,
+    AZURE_AI_INFERENCE_VERSION,
     AZURE_OPENAI_API_VERSION,
-    DEFAULT_TEMPERATURE,
-    DOCS_CONFIGURATION_AICI_URL,
-    DOCS_CONFIGURATION_AZURE_OPENAI_URL,
-    DOCS_CONFIGURATION_GITHUB_URL,
-    DOCS_CONFIGURATION_LITELLM_URL,
-    DOCS_CONFIGURATION_LLAMAFILE_URL,
-    DOCS_CONFIGURATION_LOCALAI_URL,
-    DOCS_CONFIGURATION_OLLAMA_URL,
-    DOCS_CONFIGURATION_OPENAI_URL,
     DOT_ENV_FILENAME,
     GITHUB_MODELS_BASE,
     LITELLM_API_BASE,
@@ -106,9 +98,6 @@ export async function parseTokenFromEnv(
                 token,
                 source: "env: OPENAI_API_...",
                 version,
-                curlHeaders: {
-                    Authorization: `Bearer $OPENAI_API_KEY`,
-                },
             }
         }
     }
@@ -129,9 +118,6 @@
             type,
             token,
             source: `env: ${tokenVar}`,
-            curlHeaders: {
-                Authorization: `Bearer $${tokenVar}`,
-            },
         }
     }
@@ -175,30 +161,27 @@
                 ? "env: AZURE_OPENAI_API_..."
                 : "env: AZURE_OPENAI_API_... + Entra ID",
             version,
-            curlHeaders: tokenVar
-                ? {
-                      "api-key": `$${tokenVar}`,
-                  }
-                : undefined,
         }
     }
 
     if (provider === MODEL_PROVIDER_AZURE_SERVERLESS) {
+        // https://github.com/Azure/azure-sdk-for-js/tree/@azure-rest/ai-inference_1.0.0-beta.2/sdk/ai/ai-inference-rest
         const tokenVar = "AZURE_INFERENCE_CREDENTIAL"
         const token = env[tokenVar]?.trim()
-        const base = trimTrailingSlash(env.AZURE_INFERENCE_ENDPOINT)
+        let base = trimTrailingSlash(env.AZURE_INFERENCE_ENDPOINT)
         if (!token && !base) return undefined
         if (token === PLACEHOLDER_API_KEY)
             throw new Error("AZURE_INFERENCE_CREDENTIAL not configured")
         if (!base) throw new Error("AZURE_INFERENCE_ENDPOINT missing")
         if (base === PLACEHOLDER_API_BASE)
             throw new Error("AZURE_INFERENCE_ENDPOINT not configured")
+        base = cleanAzureBase(base)
         if (!URL.canParse(base))
             throw new Error("AZURE_INFERENCE_ENDPOINT must be a valid URL")
         const version = env.AZURE_INFERENCE_API_VERSION
-        if (version && version !== AZURE_OPENAI_API_VERSION)
+        if (version && version !== AZURE_AI_INFERENCE_VERSION)
             throw new Error(
-                `AZURE_INFERENCE_ENDPOINT must be '${AZURE_OPENAI_API_VERSION}'`
+                `AZURE_INFERENCE_API_VERSION must be '${AZURE_AI_INFERENCE_VERSION}'`
             )
         return {
             provider,
@@ -206,25 +189,22 @@
             base,
             token,
             type: "azure_serverless",
-            source: "env: AZURE_INFERENCE_...",
+            source: token
+                ? "env: AZURE_INFERENCE_..."
+                : "env: AZURE_INFERENCE_... + Entra ID",
             version,
-            curlHeaders: tokenVar
-                ? {
-                      "api-key": `$${tokenVar}`,
-                  }
-                : undefined,
         }
     }
 
     if (provider === MODEL_PROVIDER_ANTHROPIC) {
-        const token = env.ANTHROPIC_API_KEY?.trim()
+        const modelKey = "ANTHROPIC_API_KEY"
+        const token = env[modelKey]?.trim()
         if (token === undefined || token === PLACEHOLDER_API_KEY)
             throw new Error("ANTHROPIC_API_KEY not configured")
         const base =
             trimTrailingSlash(env.ANTHROPIC_API_BASE) || ANTHROPIC_API_BASE
         const version = env.ANTHROPIC_API_VERSION || undefined
         const source = "env: ANTHROPIC_API_..."
-        const modelKey = "ANTHROPIC_API_KEY"
 
         return {
             provider,
@@ -263,11 +243,6 @@
             type,
             version,
             source,
-            curlHeaders: token
-                ? {
-                      Authorization: `Bearer $${modelKey}`,
-                  }
-                : undefined,
         }
     }
@@ -317,7 +292,7 @@
     return undefined
 
     function cleanAzureBase(b: string) {
-        if (!b) return b
+        if (!b || !/\.openai\.azure\.com/i.test(b)) return b
         b =
             trimTrailingSlash(b.replace(/\/openai\/deployments.*$/, "")) +
             `/openai/deployments`

diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts
index e8334c9acf..07f98d00bf 100644
--- a/packages/core/src/constants.ts
+++ b/packages/core/src/constants.ts
@@ -11,9 +11,12 @@ export const MAX_TOOL_CALLS = 10000
 export const AZURE_OPENAI_API_VERSION = "2024-06-01"
 export const AZURE_OPENAI_TOKEN_SCOPES = Object.freeze([
     "https://cognitiveservices.azure.com/.default",
-    "offline_access",
 ])
-export const AZURE_OPENAI_TOKEN_EXPIRATION = 59 * 60_000 // 59 minutes
+export const AZURE_AI_INFERENCE_VERSION = "2024-08-01-preview"
+export const AZURE_AI_INFERENCE_TOKEN_SCOPES = Object.freeze([
+    "https://ml.azure.com/.default",
+])
+export const AZURE_TOKEN_EXPIRATION = 59 * 60_000 // 59 minutes
 
 export const TOOL_ID = "genaiscript"
 export const GENAISCRIPT_FOLDER = "." + TOOL_ID
@@ -52,14 +55,16 @@ export const LARGE_MODEL_ID = "large"
 export const DEFAULT_MODEL = "openai:gpt-4o"
 export const DEFAULT_MODEL_CANDIDATES = [
     "azure:gpt-4o",
+    "azure_serverless:gpt-4o",
     DEFAULT_MODEL,
     "github:gpt-4o",
-    "client:gpt-4",
     "anthropic:claude-2",
+    "client:gpt-4",
 ]
 export const DEFAULT_SMALL_MODEL = "openai:gpt-4o-mini"
 export const DEFAULT_SMALL_MODEL_CANDIDATES = [
     "azure:gpt-4o-mini",
+    "azure_serverless:gpt-4o-mini",
     DEFAULT_SMALL_MODEL,
     "github:gpt-4o-mini",
     "client:gpt-4-mini",
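The configuration docs above describe a multi-deployment credential format (`deploymentid=key` pairs, one per line). The parser for it is not shown in this diff, but a minimal sketch could look like this (hypothetical helper, not part of the patch):

```ts
// parse 'AZURE_INFERENCE_CREDENTIAL="model1=key1\nmodel2=key2"' into a map
// of deployment id -> API key
function parseInferenceCredentials(raw: string): Record<string, string> {
    const keys: Record<string, string> = {}
    for (const line of raw.split(/\r?\n/)) {
        const m = /^\s*([^=\s]+)\s*=\s*(\S+)\s*$/.exec(line)
        if (m) keys[m[1]] = m[2]
    }
    return keys
}
```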
diff --git a/packages/core/src/fetch.ts b/packages/core/src/fetch.ts
index 72609c0001..2706ad1475 100644
--- a/packages/core/src/fetch.ts
+++ b/packages/core/src/fetch.ts
@@ -14,7 +14,7 @@ import { readText } from "./fs"
 
 /**
  * Creates a fetch function with retry logic.
- * 
+ *
  * This function wraps the `crossFetch` with retry capabilities based
  * on provided options. It allows configuring the number of retries,
  * delay between retries, and specific HTTP status codes to retry on.
@@ -161,7 +161,6 @@ export function traceFetchPost(
             : "***") // Mask other authorization headers
     )
     const cmd = `curl ${url} \\
--H "Content-Type: application/json" \\
 ${Object.entries(headers)
     .map(([k, v]) => `-H "${k}: ${v}"`)
     .join("\\\n")} \\

diff --git a/packages/core/src/host.ts b/packages/core/src/host.ts
index c694cc0d6f..c03773680f 100644
--- a/packages/core/src/host.ts
+++ b/packages/core/src/host.ts
@@ -1,4 +1,3 @@
-import { Embeddings } from "openai/resources/embeddings.mjs"
 import { CancellationToken } from "./cancellation"
 import { LanguageModel } from "./chat"
 import { Progress } from "./progress"
@@ -33,9 +32,8 @@ export interface LanguageModelConfiguration {
     model: string
     base: string
     token: string
-    curlHeaders?: Record<string, string>
-    type?: OpenAIAPIType
     source?: string
+    type?: OpenAIAPIType
     aici?: boolean
     version?: string
 }
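The `openai.ts` hunk below rewrites the deployment segment of a serverless endpoint URL so that a single `AZURE_INFERENCE_ENDPOINT` can serve several deployments. A worked example with hypothetical values:

```ts
// the configured endpoint names one deployment...
const base = "https://my-first-model.eastus2.models.ai.azure.com"
// ...but the requested model id ("mistral-large" here) replaces it,
// keeping the region segment intact
const url = base.replace(
    /^https?:\/\/(?<deployment>[^.]+)\.(?<region>[^.]+)\.models\.ai\.azure\.com/i,
    (m, deployment, region) => `https://mistral-large.${region}.models.ai.azure.com`
)
// url === "https://mistral-large.eastus2.models.ai.azure.com"
```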
diff --git a/packages/core/src/openai.ts b/packages/core/src/openai.ts
index a50d74b9b0..38c7dc59f6 100644
--- a/packages/core/src/openai.ts
+++ b/packages/core/src/openai.ts
@@ -1,6 +1,7 @@
 import { logVerbose, normalizeInt, trimTrailingSlash } from "./util"
 import { LanguageModelConfiguration, host } from "./host"
 import {
+    AZURE_AI_INFERENCE_VERSION,
     AZURE_OPENAI_API_VERSION,
     MODEL_PROVIDER_OPENAI,
     TOOL_ID,
@@ -33,7 +34,7 @@ export function getConfigHeaders(cfg: LanguageModelConfiguration) {
     }
     const res: Record<string, string> = {
         // openai
-        authorization: /^Bearer /.test(cfg.token)
+        Authorization: /^Bearer /.test(cfg.token)
             ? token
             : token &&
                 (type === "openai" ||
@@ -46,7 +47,7 @@
             token && !/^Bearer /.test(token) && type === "azure"
                 ? token
                 : undefined,
-        "user-agent": TOOL_ID,
+        "User-Agent": TOOL_ID,
     }
     for (const [k, v] of Object.entries(res)) if (v === undefined) delete res[k]
     return res
@@ -69,7 +70,7 @@ export const OpenAIChatCompletion: ChatCompletionHandler = async (
         cancellationToken,
         inner,
     } = options
-    const { headers, ...rest } = requestOptions || {}
+    const { headers = {}, ...rest } = requestOptions || {}
     const { token, source, ...cfgNoToken } = cfg
     const { model } = parseModelIdentifier(req.model)
     const encoder = await resolveTokenEncoder(model)
@@ -138,12 +139,22 @@
             model.replace(/\./g, "") +
             `/chat/completions?api-version=${AZURE_OPENAI_API_VERSION}`
     } else if (cfg.type === "azure_serverless") {
-        url =
-            trimTrailingSlash(cfg.base).replace(
-                /^https?:\/\/(?<deployment>[^\.]+)\.(?<region>[^\.]+)\.models\.ai\.azure\.com/i,
-                (m, deployment, region) =>
-                    `https://${r2.model}.${region}.models.ai.azure.com`
-            ) + `/chat/completions`
+        if (/\.models\.ai\.azure\.com/i.test(cfg.base))
+            url =
+                trimTrailingSlash(cfg.base).replace(
+                    /^https?:\/\/(?<deployment>[^\.]+)\.(?<region>[^\.]+)\.models\.ai\.azure\.com/i,
+                    (m, deployment, region) =>
+                        `https://${r2.model}.${region}.models.ai.azure.com`
+                ) +
+                `/chat/completions?api-version=${AZURE_AI_INFERENCE_VERSION}`
+        else if (/\.openai\.azure\.com/i.test(cfg.base))
+            url =
+                trimTrailingSlash(cfg.base) +
+                "/" +
+                model.replace(/\./g, "") +
+                `/chat/completions?api-version=${AZURE_AI_INFERENCE_VERSION}`
+        // https://learn.microsoft.com/en-us/azure/machine-learning/reference-model-inference-api?view=azureml-api-2&tabs=javascript#extensibility
+        ;(headers as any)["extra-parameters"] = "drop"
         delete r2.model
     } else throw new Error(`api type ${cfg.type} not supported`)
@@ -159,16 +170,17 @@
     })
     trace.dispatchChange()
 
-    traceFetchPost(trace, url, cfg.curlHeaders, postReq)
+    const fetchHeaders: HeadersInit = {
+        ...getConfigHeaders(cfg),
+        "Content-Type": "application/json",
+        ...(headers || {}),
+    }
+    traceFetchPost(trace, url, fetchHeaders as any, postReq)
     const body = JSON.stringify(postReq)
     let r: Response
     try {
         r = await fetchRetry(url, {
-            headers: {
-                ...getConfigHeaders(cfg),
-                "content-type": "application/json",
-                ...(headers || {}),
-            },
+            headers: fetchHeaders,
             body,
             method: "POST",
             signal: toSignal(cancellationToken),