diff --git a/apps/shinkai-desktop/src/pages/add-ai.tsx b/apps/shinkai-desktop/src/pages/add-ai.tsx index 403f36c16..6ed7bccd3 100644 --- a/apps/shinkai-desktop/src/pages/add-ai.tsx +++ b/apps/shinkai-desktop/src/pages/add-ai.tsx @@ -14,7 +14,6 @@ import { import { useAddLLMProvider } from '@shinkai_network/shinkai-node-state/v2/mutations/addLLMProvider/useAddLLMProvider'; import { Button, - ErrorMessage, Form, FormControl, FormField, @@ -30,6 +29,7 @@ import { TextField, } from '@shinkai_network/shinkai-ui'; import { cn } from '@shinkai_network/shinkai-ui/utils'; +import { Loader2 } from 'lucide-react'; import { useEffect, useState } from 'react'; import { useForm } from 'react-hook-form'; import { useNavigate } from 'react-router-dom'; @@ -114,12 +114,7 @@ const AddAIPage = () => { } }, [addAgentForm, preSelectedAiProvider]); - const { - mutateAsync: addLLMProvider, - isPending, - isError, - error, - } = useAddLLMProvider({ + const { mutateAsync: addLLMProvider, isPending } = useAddLLMProvider({ onSuccess: (_, variables) => { navigate('/inboxes', { state: { @@ -247,11 +242,37 @@ const AddAIPage = () => { toolkit_permissions: [], model, }, + enableTest: false, + }); + }; + const handleTestAndSave = async (data: AddAgentFormSchema) => { + if (!auth) return; + let model = getModelObject(data.model, data.modelType); + if (isCustomModelMode && data.modelCustom && data.modelTypeCustom) { + model = getModelObject(data.modelCustom, data.modelTypeCustom); + } else if (isCustomModelType && data.modelTypeCustom) { + model = getModelObject(data.model, data.modelTypeCustom); + } + await addLLMProvider({ + nodeAddress: auth?.node_address ?? '', + token: auth?.api_v2_key ?? '', + agent: { + allowed_message_senders: [], + api_key: data.apikey, + external_url: data.externalUrl, + full_identity_name: `${auth.shinkai_identity}/${auth.profile}/agent/${data.agentName}`, + id: data.agentName, + perform_locally: false, + storage_bucket_permissions: [], + toolkit_permissions: [], + model, + }, + enableTest: true, }); }; return ( - +
{ /> - {isError && } - - + {isPending ? ( +
+ +
+ ) : ( +
+ + +
+ )}
diff --git a/apps/shinkai-desktop/src/pages/ais.tsx b/apps/shinkai-desktop/src/pages/ais.tsx index be68dc8d8..e975e3886 100644 --- a/apps/shinkai-desktop/src/pages/ais.tsx +++ b/apps/shinkai-desktop/src/pages/ais.tsx @@ -68,10 +68,10 @@ const AIsPage = () => { const onAddAgentClick = () => { if (isLocalShinkaiNodeIsUse) { - navigate('/local-ais'); return; } - navigate('/add-ai'); + navigate('/local-ais'); + // navigate('/add-ai'); }; return ( diff --git a/libs/shinkai-message-ts/src/api/jobs/index.ts b/libs/shinkai-message-ts/src/api/jobs/index.ts index aad57ca8a..dbc39a112 100644 --- a/libs/shinkai-message-ts/src/api/jobs/index.ts +++ b/libs/shinkai-message-ts/src/api/jobs/index.ts @@ -241,6 +241,21 @@ export const addLLMProvider = async ( ); return response.data as AddLLMProviderResponse; }; +export const testLLMProvider = async ( + nodeAddress: string, + bearerToken: string, + payload: AddLLMProviderRequest, +) => { + const response = await httpClient.post( + urlJoin(nodeAddress, '/v2/test_llm_provider'), + { ...payload, model: getModelString(payload.model) }, + { + headers: { Authorization: `Bearer ${bearerToken}` }, + responseType: 'json', + }, + ); + return response.data as AddLLMProviderResponse; +}; export const updateLLMProvider = async ( nodeAddress: string, diff --git a/libs/shinkai-node-state/src/v2/mutations/addLLMProvider/index.ts b/libs/shinkai-node-state/src/v2/mutations/addLLMProvider/index.ts index 6eda8ebde..2d987c72f 100644 --- a/libs/shinkai-node-state/src/v2/mutations/addLLMProvider/index.ts +++ b/libs/shinkai-node-state/src/v2/mutations/addLLMProvider/index.ts @@ -1,4 +1,7 @@ -import { addLLMProvider as addLLMProviderAPI } from '@shinkai_network/shinkai-message-ts/api/jobs/index'; +import { + addLLMProvider as addLLMProviderAPI, + testLLMProvider, +} from '@shinkai_network/shinkai-message-ts/api/jobs/index'; import { AddLLMProviderInput } from './types'; @@ -6,7 +9,11 @@ export const addLLMProvider = async ({ nodeAddress, token, agent, + enableTest, }: AddLLMProviderInput) => { + if (!agent.model.Ollama && enableTest) { + await testLLMProvider(nodeAddress, token, agent); + } const data = await addLLMProviderAPI(nodeAddress, token, agent); return data; }; diff --git a/libs/shinkai-node-state/src/v2/mutations/addLLMProvider/types.ts b/libs/shinkai-node-state/src/v2/mutations/addLLMProvider/types.ts index 3f256f190..c687fb66e 100644 --- a/libs/shinkai-node-state/src/v2/mutations/addLLMProvider/types.ts +++ b/libs/shinkai-node-state/src/v2/mutations/addLLMProvider/types.ts @@ -7,5 +7,6 @@ import { export type AddLLMProviderInput = Token & { nodeAddress: string; agent: SerializedLLMProvider; + enableTest?: boolean; }; export type AddLLMProviderOutput = AddLLMProviderResponse; diff --git a/libs/shinkai-node-state/src/v2/mutations/testLLMProvider/index.ts b/libs/shinkai-node-state/src/v2/mutations/testLLMProvider/index.ts new file mode 100644 index 000000000..55825fb94 --- /dev/null +++ b/libs/shinkai-node-state/src/v2/mutations/testLLMProvider/index.ts @@ -0,0 +1,19 @@ +import { + addLLMProvider as addLLMProviderAPI, + testLLMProvider, +} from '@shinkai_network/shinkai-message-ts/api/jobs/index'; + +import { AddLLMProviderInput } from './types'; + +export const addLLMProvider = async ({ + nodeAddress, + token, + agent, +}: AddLLMProviderInput) => { + if (!agent.model.Ollama) { + await testLLMProvider(nodeAddress, token, agent); + } + + const data = await addLLMProviderAPI(nodeAddress, token, agent); + return data; +}; diff --git a/libs/shinkai-node-state/src/v2/mutations/testLLMProvider/types.ts b/libs/shinkai-node-state/src/v2/mutations/testLLMProvider/types.ts new file mode 100644 index 000000000..3f256f190 --- /dev/null +++ b/libs/shinkai-node-state/src/v2/mutations/testLLMProvider/types.ts @@ -0,0 +1,11 @@ +import { Token } from '@shinkai_network/shinkai-message-ts/api/general/types'; +import { + AddLLMProviderResponse, + SerializedLLMProvider, +} from '@shinkai_network/shinkai-message-ts/api/jobs/types'; + +export type AddLLMProviderInput = Token & { + nodeAddress: string; + agent: SerializedLLMProvider; +}; +export type AddLLMProviderOutput = AddLLMProviderResponse; diff --git a/libs/shinkai-node-state/src/v2/mutations/testLLMProvider/useAddLLMProvider.ts b/libs/shinkai-node-state/src/v2/mutations/testLLMProvider/useAddLLMProvider.ts new file mode 100644 index 000000000..749108de8 --- /dev/null +++ b/libs/shinkai-node-state/src/v2/mutations/testLLMProvider/useAddLLMProvider.ts @@ -0,0 +1,19 @@ +import type { UseMutationOptions } from '@tanstack/react-query'; +import { useMutation } from '@tanstack/react-query'; + +import { APIError } from '../../types'; +import { addLLMProvider } from '.'; +import { AddLLMProviderInput, AddLLMProviderOutput } from './types'; + +type Options = UseMutationOptions< + AddLLMProviderOutput, + APIError, + AddLLMProviderInput +>; + +export const useAddLLMProvider = (options?: Options) => { + return useMutation({ + mutationFn: addLLMProvider, + ...options, + }); +};