From a4431f2caad973ff53b7f616393dd1ed8ecea34c Mon Sep 17 00:00:00 2001 From: BrandonStudio <55647556+BrandonStudio@users.noreply.github.com> Date: Sun, 10 Nov 2024 01:14:46 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20disregard=20remoteModelCa?= =?UTF-8?q?rds=20when=20showModelFetcher=20is=20disabled=20(#4644)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/store/user/slices/modelList/action.ts | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/store/user/slices/modelList/action.ts b/src/store/user/slices/modelList/action.ts index eb505402bcae..b1b1cb3670ff 100644 --- a/src/store/user/slices/modelList/action.ts +++ b/src/store/user/slices/modelList/action.ts @@ -5,8 +5,8 @@ import type { StateCreator } from 'zustand/vanilla'; import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders'; import { ModelProvider } from '@/libs/agent-runtime'; import { UserStore } from '@/store/user'; -import { ChatModelCard } from '@/types/llm'; -import { +import type { ChatModelCard, ModelProviderCard } from '@/types/llm'; +import type { GlobalLLMProviderKey, UserKeyVaults, UserModelProviderConfig, @@ -79,20 +79,21 @@ export const createModelListSlice: StateCreator< * 3 - default model cards */ - // eslint-disable-next-line unicorn/consistent-function-scoping - const mergeModels = (provider: GlobalLLMProviderKey, defaultChatModels: ChatModelCard[]) => { + const mergeModels = (providerKey: GlobalLLMProviderKey, providerCard: ModelProviderCard) => { // if the chat model is config in the server side, use the server side model cards - const serverChatModels = modelProviderSelectors.serverProviderModelCards(provider)(get()); - const remoteChatModels = modelProviderSelectors.remoteProviderModelCards(provider)(get()); + const serverChatModels = modelProviderSelectors.serverProviderModelCards(providerKey)(get()); + const remoteChatModels = providerCard.modelList?.showModelFetcher + ? modelProviderSelectors.remoteProviderModelCards(providerKey)(get()) + : undefined; - return serverChatModels ?? remoteChatModels ?? defaultChatModels; + return serverChatModels ?? remoteChatModels ?? providerCard.chatModels; }; const defaultModelProviderList = produce(DEFAULT_MODEL_PROVIDER_LIST, (draft) => { - Object.values(ModelProvider).forEach((id) =>{ - const provider = draft.find((d) => d.id === id); - if (provider) provider.chatModels = mergeModels(id as any, provider.chatModels); - }) + Object.values(ModelProvider).forEach((id) => { + const provider = draft.find((d) => d.id === id); + if (provider) provider.chatModels = mergeModels(id as any, provider); + }); }); set({ defaultModelProviderList }, false, `refreshDefaultModelList - ${params?.trigger}`);