From c61970ac63af2196230c5ae9bf78e6333a81da1d Mon Sep 17 00:00:00 2001 From: Peli de Halleux Date: Mon, 12 Aug 2024 09:51:30 -0700 Subject: [PATCH] expire azure token as needed (#615) --- packages/cli/src/azuretoken.ts | 14 ++++++++++++-- packages/cli/src/nodehost.ts | 11 +++++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/packages/cli/src/azuretoken.ts b/packages/cli/src/azuretoken.ts index 79d95473cb..d71b4cc22b 100644 --- a/packages/cli/src/azuretoken.ts +++ b/packages/cli/src/azuretoken.ts @@ -1,10 +1,20 @@ import { AZURE_OPENAI_TOKEN_SCOPES } from "../../core/src/constants" -export async function createAzureToken(signal: AbortSignal): Promise { +export interface AuthenticationToken { + token: string + expiresOnTimestamp: number +} + +export async function createAzureToken( + signal: AbortSignal +): Promise { const { DefaultAzureCredential } = await import("@azure/identity") const azureToken = await new DefaultAzureCredential().getToken( AZURE_OPENAI_TOKEN_SCOPES.slice(), { abortSignal: signal } ) - return azureToken.token + return { + token: azureToken.token, + expiresOnTimestamp: azureToken.expiresOnTimestamp, + } } diff --git a/packages/cli/src/nodehost.ts b/packages/cli/src/nodehost.ts index 7fbdd8817a..563c1f7aac 100644 --- a/packages/cli/src/nodehost.ts +++ b/packages/cli/src/nodehost.ts @@ -41,7 +41,7 @@ import { import { AbortSignalOptions, TraceOptions } from "../../core/src/trace" import { logVerbose, unique } from "../../core/src/util" import { parseModelIdentifier } from "../../core/src/models" -import { createAzureToken } from "./azuretoken" +import { AuthenticationToken, createAzureToken } from "./azuretoken" import { LanguageModel } from "../../core/src/chat" import { errorMessage } from "../../core/src/error" @@ -144,7 +144,7 @@ export class NodeHost implements RuntimeHost { } clientLanguageModel: LanguageModel - private _azureToken: string + private _azureToken: AuthenticationToken async getLanguageModelConfiguration( modelId: string, options?: { token?: boolean } & AbortSignalOptions & TraceOptions @@ -158,10 +158,13 @@ export class NodeHost implements RuntimeHost { !tok.token && tok.provider === MODEL_PROVIDER_AZURE ) { - if (!this._azureToken) + if ( + !this._azureToken || + this._azureToken.expiresOnTimestamp >= Date.now() + ) this._azureToken = await createAzureToken(signal) if (!this._azureToken) throw new Error("Azure token not available") - tok.token = "Bearer " + this._azureToken + tok.token = "Bearer " + this._azureToken.token } if (!tok && this.clientLanguageModel) { return {