From df0237dda0b9f66a25ef80d1b78849eb68e6f0ea Mon Sep 17 00:00:00 2001 From: dosco <832235+dosco@users.noreply.github.com> Date: Sun, 3 Nov 2024 00:53:04 -0700 Subject: [PATCH] fix: gemini function calling --- src/ax/ai/google-gemini/api.ts | 54 +++++++++++++++++--------------- src/ax/ai/google-gemini/types.ts | 4 +-- src/examples/food-search.ts | 15 +++++---- 3 files changed, 37 insertions(+), 36 deletions(-) diff --git a/src/ax/ai/google-gemini/api.ts b/src/ax/ai/google-gemini/api.ts index 804f693..691a44a 100644 --- a/src/ax/ai/google-gemini/api.ts +++ b/src/ax/ai/google-gemini/api.ts @@ -208,38 +208,40 @@ export class AxAIGoogleGemini extends AxBaseAI< } case 'assistant': { - if ('content' in msg && typeof msg.content === 'string') { - const parts: Extract< - AxAIGoogleGeminiChatRequest['contents'][0], - { role: 'model' } - >['parts'] = [{ text: msg.content }]; + let parts: Extract< + AxAIGoogleGeminiChatRequest['contents'][0], + { role: 'model' } + >['parts'] = []; + + if (msg.functionCalls) { + parts = msg.functionCalls.map((f) => { + const args = + typeof f.function.params === 'string' + ? JSON.parse(f.function.params) + : f.function.params; + return { + functionCall: { + name: f.function.name, + args: args + } + }; + }); + + if (!parts) { + throw new Error('Function call is empty'); + } + return { role: 'model' as const, parts }; } - let parts: Extract< - AxAIGoogleGeminiChatRequest['contents'][0], - { role: 'model' } - >['parts'] = []; - - if ('functionCalls' in msg) { - parts = - msg.functionCalls?.map((f) => { - const args = - typeof f.function.params === 'string' - ? JSON.parse(f.function.params) - : f.function.params; - return { - functionCall: { - name: f.function.name, - args: args - } - }; - }) ?? []; + if (!msg.content) { + throw new Error('Assistant content is empty'); } + parts = [{ text: msg.content }]; return { role: 'model' as const, parts @@ -276,11 +278,11 @@ export class AxAIGoogleGemini extends AxBaseAI< let tools: AxAIGoogleGeminiChatRequest['tools'] | undefined = []; if (req.functions && req.functions.length > 0) { - tools.push({ functionDeclarations: req.functions }); + tools.push({ function_declarations: req.functions }); } if (this.options?.codeExecution) { - tools.push({ codeExecution: {} }); + tools.push({ code_execution: {} }); } if (tools.length === 0) { diff --git a/src/ax/ai/google-gemini/types.ts b/src/ax/ai/google-gemini/types.ts index dfb4bdd..9944974 100644 --- a/src/ax/ai/google-gemini/types.ts +++ b/src/ax/ai/google-gemini/types.ts @@ -79,8 +79,8 @@ export type AxAIGoogleGeminiToolFunctionDeclaration = { }; export type AxAIGoogleGeminiTool = { - functionDeclarations?: AxAIGoogleGeminiToolFunctionDeclaration[]; - codeExecution?: object; + function_declarations?: AxAIGoogleGeminiToolFunctionDeclaration[]; + code_execution?: object; }; export type AxAIGoogleGeminiToolConfig = { diff --git a/src/examples/food-search.ts b/src/examples/food-search.ts index 6ea6570..0d0c204 100644 --- a/src/examples/food-search.ts +++ b/src/examples/food-search.ts @@ -140,10 +140,14 @@ const functions: AxFunction[] = [ } ]; +// const ai = new AxAI({ +// name: 'openai', +// apiKey: process.env.OPENAI_APIKEY as string +// }); + const ai = new AxAI({ - name: 'openai', - apiKey: process.env.OPENAI_APIKEY as string - // config: { model: AxAIOpenAIModel.GPT4OMini } + name: 'google-gemini', + apiKey: process.env.GOOGLE_APIKEY as string }); // ai.setOptions({ debug: true }); @@ -158,11 +162,6 @@ const ai = new AxAI({ // apiKey: process.env.COHERE_APIKEY as string // }); -// const ai = new AxAI({ -// name: 'google-gemini', -// apiKey: process.env.GOOGLE_APIKEY as string -// }); - // const ai = new AxAI({ // name: 'anthropic', // apiKey: process.env.ANTHROPIC_APIKEY as string