diff --git a/src/renderer/src/hooks/useProvider.ts b/src/renderer/src/hooks/useProvider.ts index 960d2929d0..082a252002 100644 --- a/src/renderer/src/hooks/useProvider.ts +++ b/src/renderer/src/hooks/useProvider.ts @@ -1,7 +1,7 @@ import { createSelector } from '@reduxjs/toolkit' import { CHERRYAI_PROVIDER } from '@renderer/config/providers' import { getDefaultProvider } from '@renderer/services/AssistantService' -import { useAppDispatch, useAppSelector } from '@renderer/store' +import { type RootState, useAppDispatch, useAppSelector } from '@renderer/store' import { addModel, addProvider, @@ -14,6 +14,7 @@ import { import type { Assistant, Model, Provider } from '@renderer/types' import { isSystemProvider } from '@renderer/types' import { withoutTrailingSlash } from '@renderer/utils/api' +import { useMemo } from 'react' import { useDefaultModel } from './useAssistant' @@ -28,13 +29,27 @@ function normalizeProvider(provider: T): T { } } -const selectEnabledProviders = createSelector( - (state) => state.llm.providers, - (providers) => - providers - .map(normalizeProvider) - .filter((p) => p.enabled) - .concat(CHERRYAI_PROVIDER) +const selectProviders = (state: RootState) => state.llm.providers + +const selectEnabledProviders = createSelector(selectProviders, (providers) => + providers + .map(normalizeProvider) + .filter((p) => p.enabled) + .concat(CHERRYAI_PROVIDER) +) + +const selectSystemProviders = createSelector(selectProviders, (providers) => + providers.filter((p) => isSystemProvider(p)).map(normalizeProvider) +) + +const selectUserProviders = createSelector(selectProviders, (providers) => + providers.filter((p) => !isSystemProvider(p)).map(normalizeProvider) +) + +const selectAllProviders = createSelector(selectProviders, (providers) => providers.map(normalizeProvider)) + +const selectAllProvidersWithCherryAI = createSelector(selectProviders, (providers) => + [...providers, CHERRYAI_PROVIDER].map(normalizeProvider) ) export function useProviders() { @@ -51,25 +66,20 @@ export function useProviders() { } export function useSystemProviders() { - return useAppSelector((state) => state.llm.providers.filter((p) => isSystemProvider(p)).map(normalizeProvider)) + return useAppSelector(selectSystemProviders) } export function useUserProviders() { - return useAppSelector((state) => state.llm.providers.filter((p) => !isSystemProvider(p)).map(normalizeProvider)) + return useAppSelector(selectUserProviders) } export function useAllProviders() { - return useAppSelector((state) => state.llm.providers.map(normalizeProvider)) + return useAppSelector(selectAllProviders) } export function useProvider(id: string) { - const provider = - useAppSelector((state) => - state.llm.providers - .concat([CHERRYAI_PROVIDER]) - .map(normalizeProvider) - .find((p) => p.id === id) - ) || getDefaultProvider() + const allProviders = useAppSelector(selectAllProvidersWithCherryAI) + const provider = useMemo(() => allProviders.find((p) => p.id === id) || getDefaultProvider(), [allProviders, id]) const dispatch = useAppDispatch() return {