diff --git a/src/types/aiApiClient/index.ts b/src/types/aiApiClient/index.ts index c4f0221..b847967 100644 --- a/src/types/aiApiClient/index.ts +++ b/src/types/aiApiClient/index.ts @@ -10,5 +10,7 @@ export * from './types/TelemetryData'; export * from './types/ApiClient'; export * from './enums/commandSource'; export * from './utils/Completion/Completion'; +export * from './utils/sendPromptToLLMUtils'; +export * from './utils/typeGuards'; export * from './Completions/AiCompletion'; export * from './Completions/InlineCompletions'; diff --git a/src/types/aiApiClient/utils/sendPromptToLLMUtils.ts b/src/types/aiApiClient/utils/sendPromptToLLMUtils.ts new file mode 100644 index 0000000..c210a57 --- /dev/null +++ b/src/types/aiApiClient/utils/sendPromptToLLMUtils.ts @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2024, salesforce.com, inc. + * All rights reserved. + * Licensed under the BSD 3-Clause license. + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + */ +import { typeGuards, ChatStream } from './typeGuards'; + +/** + * @description: parses available text from the generation + * and determines if the stream is finished. + */ +export function processGeneration(chunk: unknown): { + done: boolean; + text: string; +} { + const processedGeneration = { + done: false, + text: '' + }; + + try { + if (typeGuards.isChatStream(chunk)) { + processedGeneration.done = isTerminationEvent(chunk); + processedGeneration.text = getText(chunk); + return processedGeneration; + } else { + processedGeneration.done = true; + return processedGeneration; + } + } catch (error) { + processedGeneration.done = true; + return processedGeneration; + } +} + +/** + * Check if this is the last chunk we should process. + * Returns true if <|endofprompt|> is in the text and/or + * the finish_reason parameter is populated. + */ +function isTerminationEvent(chunk: ChatStream): boolean { + try { + const isTerminationTokenInResponse = + chunk.data.generations[0].text.includes('<|endofprompt|>'); + const doesEventContainFinishReason = + chunk.data.generations[0].parameters?.finish_reason; + + return isTerminationTokenInResponse || !!doesEventContainFinishReason; + } catch (error) { + console.log(error, 'Error determining isTerminationEvent'); + return true; + } +} + +/** + * Parse through the chunk to get the text of the generation and + * remove <|endofprompt|> token from rawMessage + */ +function getText(chunk: ChatStream): string { + let text = ''; + try { + const generationText = chunk.data.generations[0].text; + text += generationText; + text = text.replace('<|endofprompt|>', ''); + return text; + } catch (error) { + console.log(error, 'Error getting stream text'); + return text; + } +} diff --git a/src/types/aiApiClient/utils/typeGuards.ts b/src/types/aiApiClient/utils/typeGuards.ts new file mode 100644 index 0000000..7110d94 --- /dev/null +++ b/src/types/aiApiClient/utils/typeGuards.ts @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024, salesforce.com, inc. + * All rights reserved. + * Licensed under the BSD 3-Clause license. + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + */ + +export const typeGuards = { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + isChatStream(obj: any): obj is ChatStream { + return ( + obj && + obj.event === 'generation' && + typeof obj.data === 'object' && + 'generations' in obj.data && + Array.isArray(obj.data.generations) && + obj.data.generations.length > 0 && + typeof obj.data.generations[0] === 'object' && + 'text' in obj.data.generations[0] + ); + } +}; + +export interface ChatStream { + event: 'generation'; + data: { + id: string; + generations: { + id: string; + text: string; + parameters?: { + token_logprobs: number; + token_id: number; + finish_reason?: string; + }; + generation_safety_score: number; + generation_content_quality: unknown; + }[]; + prompt: string | null; + input_safety_score: number | null; + input_bias_score: number | null; + parameters: unknown; + }; +}