From 0f8b55f785070e71019ed0b3fd438929531a701e Mon Sep 17 00:00:00 2001 From: msaaddev Date: Wed, 23 Oct 2024 17:10:34 +0200 Subject: [PATCH 1/2] =?UTF-8?q?=F0=9F=93=A6=20NEW:=20Groq=20tools=20suppor?= =?UTF-8?q?t?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/baseai/src/data/models.ts | 42 +++++++++++++++---- packages/baseai/src/dev/llms/call-groq.ts | 2 + .../src/dev/providers/groq/chatComplete.ts | 12 ++++++ .../src/dev/utils/add-tools-to-params.ts | 19 ++++----- .../baseai/src/dev/utils/has-tool-support.ts | 18 +------- 5 files changed, 58 insertions(+), 35 deletions(-) diff --git a/packages/baseai/src/data/models.ts b/packages/baseai/src/data/models.ts index 41759877..dc5334a6 100644 --- a/packages/baseai/src/data/models.ts +++ b/packages/baseai/src/data/models.ts @@ -421,43 +421,71 @@ export const modelsByProvider: ModelsByProviderInclCosts = { id: 'llama-3.1-70b-versatile', provider: GROQ, promptCost: 0.59, - completionCost: 0.79 + completionCost: 0.79, + toolSupport: { + toolChoice: true, + parallelToolCalls: true + } }, { id: 'llama-3.1-8b-instant', provider: GROQ, promptCost: 0.59, - completionCost: 0.79 + completionCost: 0.79, + toolSupport: { + toolChoice: true, + parallelToolCalls: true + } }, { id: 'llama3-70b-8192', provider: GROQ, promptCost: 0.59, - completionCost: 0.79 + completionCost: 0.79, + toolSupport: { + toolChoice: true, + parallelToolCalls: true + } }, { id: 'llama3-8b-8192', provider: GROQ, promptCost: 0.05, - completionCost: 0.1 + completionCost: 0.1, + toolSupport: { + toolChoice: true, + parallelToolCalls: true + } }, { id: 'mixtral-8x7b-32768', provider: GROQ, promptCost: 0.27, - completionCost: 0.27 + completionCost: 0.27, + toolSupport: { + toolChoice: true, + parallelToolCalls: false + } }, { id: 'gemma2-9b-it', provider: GROQ, promptCost: 0.2, - completionCost: 0.2 + completionCost: 0.2, + toolSupport: { + toolChoice: true, + parallelToolCalls: false + } }, { id: 'gemma-7b-it', provider: GROQ, promptCost: 0.07, - completionCost: 0.07 + completionCost: 0.07, + toolSupport: { + toolChoice: true, + parallelToolCalls: false + } } ], [GOOGLE]: [ diff --git a/packages/baseai/src/dev/llms/call-groq.ts b/packages/baseai/src/dev/llms/call-groq.ts index 78089721..92baf4e1 100644 --- a/packages/baseai/src/dev/llms/call-groq.ts +++ b/packages/baseai/src/dev/llms/call-groq.ts @@ -5,6 +5,7 @@ import transformToProviderRequest from '../utils/provider-handlers/transfrom-to- import { applyJsonModeIfEnabled, handleLlmError } from './utils'; import type { ModelParams } from 'types/providers'; import type { Message } from 'types/pipe'; +import { addToolsToParams } from '../utils/add-tools-to-params'; export async function callGroq({ pipe, @@ -24,6 +25,7 @@ export async function callGroq({ baseURL: 'https://api.groq.com/openai/v1' }); applyJsonModeIfEnabled(modelParams, pipe); + addToolsToParams(modelParams, pipe); // Transform params according to provider's format const transformedRequestParams = transformToProviderRequest({ diff --git a/packages/baseai/src/dev/providers/groq/chatComplete.ts b/packages/baseai/src/dev/providers/groq/chatComplete.ts index f5d125d3..fa52f2ca 100644 --- a/packages/baseai/src/dev/providers/groq/chatComplete.ts +++ b/packages/baseai/src/dev/providers/groq/chatComplete.ts @@ -38,5 +38,17 @@ export const GroqChatCompleteConfig: ProviderConfig = { default: 1, max: 1, min: 1 + }, + parallel_tool_calls: { + param: 'parallel_tool_calls', + default: false + }, + tool_choice: { + param: 'tool_choice', + default: 'none' + }, + tools: { + param: 'tools', + default: [] } }; diff --git a/packages/baseai/src/dev/utils/add-tools-to-params.ts b/packages/baseai/src/dev/utils/add-tools-to-params.ts index 6cbef242..e4d144e8 100644 --- a/packages/baseai/src/dev/utils/add-tools-to-params.ts +++ b/packages/baseai/src/dev/utils/add-tools-to-params.ts @@ -1,22 +1,19 @@ -import { getSupportedToolSettings, hasToolSupport } from './has-tool-support'; +import { hasModelToolSupport } from './has-tool-support'; import type { ModelParams } from 'types/providers'; export function addToolsToParams(modelParams: ModelParams, pipe: any) { if (!pipe.functions.length) return; // Check if the model supports tool calls - const hasToolCallSupport = hasToolSupport({ - modelName: pipe.model.name, - provider: pipe.model.provider - }); + const { hasToolChoiceSupport, hasParallelToolCallSupport } = + hasModelToolSupport({ + modelName: pipe.model.name, + provider: pipe.model.provider + }); - if (hasToolCallSupport) { - const { hasParallelToolCallSupport, hasToolChoiceSupport } = - getSupportedToolSettings({ - modelName: pipe.model.name, - provider: pipe.model.provider - }); + const hasToolSupport = hasToolChoiceSupport || hasParallelToolCallSupport; + if (hasToolSupport) { if (hasParallelToolCallSupport) { modelParams.parallel_tool_calls = pipe.model.parallel_tool_calls; } diff --git a/packages/baseai/src/dev/utils/has-tool-support.ts b/packages/baseai/src/dev/utils/has-tool-support.ts index b7537a56..04c0f0b9 100644 --- a/packages/baseai/src/dev/utils/has-tool-support.ts +++ b/packages/baseai/src/dev/utils/has-tool-support.ts @@ -1,6 +1,6 @@ import { modelsByProvider } from '@/data/models'; -export function hasToolSupport({ +export function hasModelToolSupport({ provider, modelName }: { @@ -10,23 +10,7 @@ export function hasToolSupport({ const toolSupportedModels = modelsByProvider[provider].filter( model => model.toolSupport ); - const hasToolCallSupport = toolSupportedModels - .flatMap(model => model.id) - .includes(modelName); - return hasToolCallSupport; -} - -export function getSupportedToolSettings({ - provider, - modelName -}: { - modelName: string; - provider: string; -}) { - const toolSupportedModels = modelsByProvider[provider].filter( - model => model.toolSupport - ); const providerModel = toolSupportedModels.find( model => model.id === modelName ); From 42e7b1333aa73e9de267f1531a5b5f183054da09 Mon Sep 17 00:00:00 2001 From: msaaddev Date: Wed, 23 Oct 2024 20:19:58 +0200 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dev/utils/provider-handlers/provider-response-handler.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/baseai/src/dev/utils/provider-handlers/provider-response-handler.ts b/packages/baseai/src/dev/utils/provider-handlers/provider-response-handler.ts index 0a2a7dc0..8f0edb63 100644 --- a/packages/baseai/src/dev/utils/provider-handlers/provider-response-handler.ts +++ b/packages/baseai/src/dev/utils/provider-handlers/provider-response-handler.ts @@ -1,10 +1,10 @@ -import type { ModelParams } from '@/types/providers'; import { handleNonStreamingMode, handleStreamingMode } from './response-handler-utils'; import Providers from '@/dev/providers'; import { dlog } from '../dlog'; +import type { ModelParams } from 'types/providers'; /** * Handles various types of responses based on the specified parameters