From 94b58249a37a21b1c08deaa2d1edfdecbb6deb18 Mon Sep 17 00:00:00 2001 From: Timothy Carambat Date: Fri, 5 Apr 2024 10:58:36 -0700 Subject: [PATCH] Enable per-workspace provider/model combination (#1042) * Enable per-workspace provider/model combination * cleanup * remove resetWorkspaceChatModels and wipeWorkspaceModelPreference to prevent workspace from resetting model * add space --------- Co-authored-by: shatfield4 --- .../LLMSelection/AnthropicAiOptions/index.jsx | 76 +++--- .../LLMSelection/GeminiLLMOptions/index.jsx | 40 +-- .../LLMSelection/GroqAiOptions/index.jsx | 40 +-- .../LLMSelection/LMStudioOptions/index.jsx | 38 +-- .../LLMSelection/LocalAiOptions/index.jsx | 46 ++-- .../LLMSelection/MistralOptions/index.jsx | 4 +- .../LLMSelection/OllamaLLMOptions/index.jsx | 38 +-- .../LLMSelection/OpenAiOptions/index.jsx | 4 +- .../LLMSelection/OpenRouterOptions/index.jsx | 6 +- .../LLMSelection/PerplexityOptions/index.jsx | 4 +- .../LLMSelection/TogetherAiOptions/index.jsx | 6 +- frontend/src/hooks/useGetProvidersModels.js | 2 +- .../GeneralSettings/LLMPreference/index.jsx | 241 ++++++++++-------- .../ChatSettings/ChatModelSelection/index.jsx | 17 +- .../WorkspaceLLMItem/index.jsx | 151 +++++++++++ .../WorkspaceLLMSelection/index.jsx | 159 ++++++++++++ .../WorkspaceSettings/ChatSettings/index.jsx | 72 +++--- server/endpoints/workspaces.js | 4 +- server/models/systemSettings.js | 196 +++++++------- server/models/workspace.js | 91 ++++--- .../20240405015034_init/migration.sql | 2 + server/prisma/schema.prisma | 1 + server/utils/chats/embed.js | 4 +- server/utils/chats/index.js | 5 +- server/utils/chats/stream.js | 5 +- server/utils/helpers/index.js | 35 +-- server/utils/helpers/updateENV.js | 10 - 27 files changed, 836 insertions(+), 461 deletions(-) create mode 100644 frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/WorkspaceLLMItem/index.jsx create mode 100644 frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx create mode 100644 server/prisma/migrations/20240405015034_init/migration.sql diff --git a/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx b/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx index e8c288d6..9fe283ff 100644 --- a/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx +++ b/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx @@ -1,26 +1,6 @@ -import { Info } from "@phosphor-icons/react"; -import paths from "@/utils/paths"; - -export default function AnthropicAiOptions({ settings, showAlert = false }) { +export default function AnthropicAiOptions({ settings }) { return (
- {showAlert && ( -
-
- -

- Anthropic as your LLM requires you to set an embedding service to - use. -

-
- - Manage embedding → - -
- )}
-
- - -
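+      {/* credentialsOnly is passed when these options render inside the
+          workspace LLM setup modal (see SetupProvider in WorkspaceLLMItem),
+          where only credentials are collected and the model picker is hidden. */}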
+ {!settings?.credentialsOnly && ( +
+ + +
+ )}
); diff --git a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx index 3b53ccc1..a46e5132 100644 --- a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx +++ b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx @@ -18,25 +18,27 @@ export default function GeminiLLMOptions({ settings }) { /> -
- - -
+ {!settings?.credentialsOnly && ( +
+ + +
+ )} ); diff --git a/frontend/src/components/LLMSelection/GroqAiOptions/index.jsx b/frontend/src/components/LLMSelection/GroqAiOptions/index.jsx index cc6fbbcc..c85f0f1e 100644 --- a/frontend/src/components/LLMSelection/GroqAiOptions/index.jsx +++ b/frontend/src/components/LLMSelection/GroqAiOptions/index.jsx @@ -17,25 +17,27 @@ export default function GroqAiOptions({ settings }) { /> -
- - -
+ {!settings?.credentialsOnly && ( +
+ + +
+ )} ); } diff --git a/frontend/src/components/LLMSelection/LMStudioOptions/index.jsx b/frontend/src/components/LLMSelection/LMStudioOptions/index.jsx index 200c77a6..c94a99d7 100644 --- a/frontend/src/components/LLMSelection/LMStudioOptions/index.jsx +++ b/frontend/src/components/LLMSelection/LMStudioOptions/index.jsx @@ -46,23 +46,27 @@ export default function LMStudioOptions({ settings, showAlert = false }) { onBlur={() => setBasePath(basePathValue)} /> - -
- - e.target.blur()} - defaultValue={settings?.LMStudioTokenLimit} - required={true} - autoComplete="off" - /> -
+ {!settings?.credentialsOnly && ( + <> + +
+ + e.target.blur()} + defaultValue={settings?.LMStudioTokenLimit} + required={true} + autoComplete="off" + /> +
+ + )} ); diff --git a/frontend/src/components/LLMSelection/LocalAiOptions/index.jsx b/frontend/src/components/LLMSelection/LocalAiOptions/index.jsx index 91e38670..36b2f258 100644 --- a/frontend/src/components/LLMSelection/LocalAiOptions/index.jsx +++ b/frontend/src/components/LLMSelection/LocalAiOptions/index.jsx @@ -46,27 +46,31 @@ export default function LocalAiOptions({ settings, showAlert = false }) { onBlur={() => setBasePath(basePathValue)} /> - -
- - e.target.blur()} - defaultValue={settings?.LocalAiTokenLimit} - required={true} - autoComplete="off" - /> -
+ {!settings?.credentialsOnly && ( + <> + +
+ + e.target.blur()} + defaultValue={settings?.LocalAiTokenLimit} + required={true} + autoComplete="off" + /> +
+ + )}
diff --git a/frontend/src/components/LLMSelection/MistralOptions/index.jsx b/frontend/src/components/LLMSelection/MistralOptions/index.jsx index a143436e..4daadcff 100644 --- a/frontend/src/components/LLMSelection/MistralOptions/index.jsx +++ b/frontend/src/components/LLMSelection/MistralOptions/index.jsx @@ -24,7 +24,9 @@ export default function MistralOptions({ settings }) { onBlur={() => setMistralKey(inputValue)} />
- + {!settings?.credentialsOnly && ( + + )}
); } diff --git a/frontend/src/components/LLMSelection/OllamaLLMOptions/index.jsx b/frontend/src/components/LLMSelection/OllamaLLMOptions/index.jsx index ddfd7a81..b08f2944 100644 --- a/frontend/src/components/LLMSelection/OllamaLLMOptions/index.jsx +++ b/frontend/src/components/LLMSelection/OllamaLLMOptions/index.jsx @@ -27,23 +27,27 @@ export default function OllamaLLMOptions({ settings }) { onBlur={() => setBasePath(basePathValue)} /> - -
- - e.target.blur()} - defaultValue={settings?.OllamaLLMTokenLimit} - required={true} - autoComplete="off" - /> -
+ {!settings?.credentialsOnly && ( + <> + +
+ + e.target.blur()} + defaultValue={settings?.OllamaLLMTokenLimit} + required={true} + autoComplete="off" + /> +
+ + )} ); diff --git a/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx b/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx index 1e349309..c5ec337d 100644 --- a/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx +++ b/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx @@ -24,7 +24,9 @@ export default function OpenAiOptions({ settings }) { onBlur={() => setOpenAIKey(inputValue)} /> - + {!settings?.credentialsOnly && ( + + )} ); } diff --git a/frontend/src/components/LLMSelection/OpenRouterOptions/index.jsx b/frontend/src/components/LLMSelection/OpenRouterOptions/index.jsx index ff2a1d8f..94ae320a 100644 --- a/frontend/src/components/LLMSelection/OpenRouterOptions/index.jsx +++ b/frontend/src/components/LLMSelection/OpenRouterOptions/index.jsx @@ -19,7 +19,9 @@ export default function OpenRouterOptions({ settings }) { spellCheck={false} /> - + {!settings?.credentialsOnly && ( + + )} ); } @@ -84,7 +86,7 @@ function OpenRouterModelSelection({ settings }) { diff --git a/frontend/src/components/LLMSelection/PerplexityOptions/index.jsx b/frontend/src/components/LLMSelection/PerplexityOptions/index.jsx index 6c452249..9b53cd19 100644 --- a/frontend/src/components/LLMSelection/PerplexityOptions/index.jsx +++ b/frontend/src/components/LLMSelection/PerplexityOptions/index.jsx @@ -19,7 +19,9 @@ export default function PerplexityOptions({ settings }) { spellCheck={false} /> - + {!settings?.credentialsOnly && ( + + )} ); } diff --git a/frontend/src/components/LLMSelection/TogetherAiOptions/index.jsx b/frontend/src/components/LLMSelection/TogetherAiOptions/index.jsx index 2c816339..a0eefc83 100644 --- a/frontend/src/components/LLMSelection/TogetherAiOptions/index.jsx +++ b/frontend/src/components/LLMSelection/TogetherAiOptions/index.jsx @@ -19,7 +19,9 @@ export default function TogetherAiOptions({ settings }) { spellCheck={false} /> - + {!settings?.credentialsOnly && ( + + )} ); } @@ -84,7 +86,7 @@ function TogetherAiModelSelection({ settings }) { diff --git a/frontend/src/hooks/useGetProvidersModels.js b/frontend/src/hooks/useGetProvidersModels.js index f578c929..95df82a3 100644 --- a/frontend/src/hooks/useGetProvidersModels.js +++ b/frontend/src/hooks/useGetProvidersModels.js @@ -2,7 +2,7 @@ import System from "@/models/system"; import { useEffect, useState } from "react"; // Providers which cannot use this feature for workspace<>model selection -export const DISABLED_PROVIDERS = ["azure", "lmstudio"]; +export const DISABLED_PROVIDERS = ["azure", "lmstudio", "native"]; const PROVIDER_DEFAULT_MODELS = { openai: [ "gpt-3.5-turbo", diff --git a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx index b9525c92..ccc6508b 100644 --- a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx +++ b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx @@ -36,6 +36,130 @@ import GroqAiOptions from "@/components/LLMSelection/GroqAiOptions"; import LLMItem from "@/components/LLMSelection/LLMItem"; import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react"; +export const AVAILABLE_LLM_PROVIDERS = [ + { + name: "OpenAI", + value: "openai", + logo: OpenAiLogo, + options: (settings) => , + description: "The standard option for most non-commercial use.", + requiredConfig: ["OpenAiKey"], + }, + { + name: "Azure OpenAI", + value: "azure", + logo: AzureOpenAiLogo, + options: (settings) => , + description: "The enterprise option of OpenAI hosted on Azure services.", + requiredConfig: 
["AzureOpenAiEndpoint"], + }, + { + name: "Anthropic", + value: "anthropic", + logo: AnthropicLogo, + options: (settings) => , + description: "A friendly AI Assistant hosted by Anthropic.", + requiredConfig: ["AnthropicApiKey"], + }, + { + name: "Gemini", + value: "gemini", + logo: GeminiLogo, + options: (settings) => , + description: "Google's largest and most capable AI model", + requiredConfig: ["GeminiLLMApiKey"], + }, + { + name: "HuggingFace", + value: "huggingface", + logo: HuggingFaceLogo, + options: (settings) => , + description: + "Access 150,000+ open-source LLMs and the world's AI community", + requiredConfig: [ + "HuggingFaceLLMEndpoint", + "HuggingFaceLLMAccessToken", + "HuggingFaceLLMTokenLimit", + ], + }, + { + name: "Ollama", + value: "ollama", + logo: OllamaLogo, + options: (settings) => , + description: "Run LLMs locally on your own machine.", + requiredConfig: ["OllamaLLMBasePath"], + }, + { + name: "LM Studio", + value: "lmstudio", + logo: LMStudioLogo, + options: (settings) => , + description: + "Discover, download, and run thousands of cutting edge LLMs in a few clicks.", + requiredConfig: ["LMStudioBasePath"], + }, + { + name: "Local AI", + value: "localai", + logo: LocalAiLogo, + options: (settings) => , + description: "Run LLMs locally on your own machine.", + requiredConfig: ["LocalAiApiKey", "LocalAiBasePath", "LocalAiTokenLimit"], + }, + { + name: "Together AI", + value: "togetherai", + logo: TogetherAILogo, + options: (settings) => , + description: "Run open source models from Together AI.", + requiredConfig: ["TogetherAiApiKey"], + }, + { + name: "Mistral", + value: "mistral", + logo: MistralLogo, + options: (settings) => , + description: "Run open source models from Mistral AI.", + requiredConfig: ["MistralApiKey"], + }, + { + name: "Perplexity AI", + value: "perplexity", + logo: PerplexityLogo, + options: (settings) => , + description: + "Run powerful and internet-connected models hosted by Perplexity AI.", + requiredConfig: ["PerplexityApiKey"], + }, + { + name: "OpenRouter", + value: "openrouter", + logo: OpenRouterLogo, + options: (settings) => , + description: "A unified interface for LLMs.", + requiredConfig: ["OpenRouterApiKey"], + }, + { + name: "Groq", + value: "groq", + logo: GroqLogo, + options: (settings) => , + description: + "The fastest LLM inferencing available for real-time AI applications.", + requiredConfig: ["GroqApiKey"], + }, + { + name: "Native", + value: "native", + logo: AnythingLLMIcon, + options: (settings) => , + description: + "Use a downloaded custom Llama model for chatting on this AnythingLLM instance.", + requiredConfig: [], + }, +]; + export default function GeneralLLMPreference() { const [saving, setSaving] = useState(false); const [hasChanges, setHasChanges] = useState(false); @@ -94,120 +218,15 @@ export default function GeneralLLMPreference() { }, []); useEffect(() => { - const filtered = LLMS.filter((llm) => + const filtered = AVAILABLE_LLM_PROVIDERS.filter((llm) => llm.name.toLowerCase().includes(searchQuery.toLowerCase()) ); setFilteredLLMs(filtered); }, [searchQuery, selectedLLM]); - const LLMS = [ - { - name: "OpenAI", - value: "openai", - logo: OpenAiLogo, - options: , - description: "The standard option for most non-commercial use.", - }, - { - name: "Azure OpenAI", - value: "azure", - logo: AzureOpenAiLogo, - options: , - description: "The enterprise option of OpenAI hosted on Azure services.", - }, - { - name: "Anthropic", - value: "anthropic", - logo: AnthropicLogo, - options: , - description: "A friendly AI 
Assistant hosted by Anthropic.", - }, - { - name: "Gemini", - value: "gemini", - logo: GeminiLogo, - options: , - description: "Google's largest and most capable AI model", - }, - { - name: "HuggingFace", - value: "huggingface", - logo: HuggingFaceLogo, - options: , - description: - "Access 150,000+ open-source LLMs and the world's AI community", - }, - { - name: "Ollama", - value: "ollama", - logo: OllamaLogo, - options: , - description: "Run LLMs locally on your own machine.", - }, - { - name: "LM Studio", - value: "lmstudio", - logo: LMStudioLogo, - options: , - description: - "Discover, download, and run thousands of cutting edge LLMs in a few clicks.", - }, - { - name: "Local AI", - value: "localai", - logo: LocalAiLogo, - options: , - description: "Run LLMs locally on your own machine.", - }, - { - name: "Together AI", - value: "togetherai", - logo: TogetherAILogo, - options: , - description: "Run open source models from Together AI.", - }, - { - name: "Mistral", - value: "mistral", - logo: MistralLogo, - options: , - description: "Run open source models from Mistral AI.", - }, - { - name: "Perplexity AI", - value: "perplexity", - logo: PerplexityLogo, - options: , - description: - "Run powerful and internet-connected models hosted by Perplexity AI.", - }, - { - name: "OpenRouter", - value: "openrouter", - logo: OpenRouterLogo, - options: , - description: "A unified interface for LLMs.", - }, - { - name: "Groq", - value: "groq", - logo: GroqLogo, - options: , - description: - "The fastest LLM inferencing available for real-time AI applications.", - }, - { - name: "Native", - value: "native", - logo: AnythingLLMIcon, - options: , - description: - "Use a downloaded custom Llama model for chatting on this AnythingLLM instance.", - }, - ]; - - const selectedLLMObject = LLMS.find((llm) => llm.value === selectedLLM); - + const selectedLLMObject = AVAILABLE_LLM_PROVIDERS.find( + (llm) => llm.value === selectedLLM + ); return (
@@ -339,7 +358,9 @@ export default function GeneralLLMPreference() { className="mt-4 flex flex-col gap-y-1" > {selectedLLM && - LLMS.find((llm) => llm.value === selectedLLM)?.options} + AVAILABLE_LLM_PROVIDERS.find( + (llm) => llm.value === selectedLLM + )?.options?.(settings)}
diff --git a/frontend/src/pages/WorkspaceSettings/ChatSettings/ChatModelSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/ChatSettings/ChatModelSelection/index.jsx index 3ef7bb7a..9ed42429 100644 --- a/frontend/src/pages/WorkspaceSettings/ChatSettings/ChatModelSelection/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/ChatSettings/ChatModelSelection/index.jsx @@ -3,21 +3,20 @@ import useGetProviderModels, { } from "@/hooks/useGetProvidersModels"; export default function ChatModelSelection({ - settings, + provider, workspace, setHasChanges, }) { - const { defaultModels, customModels, loading } = useGetProviderModels( - settings?.LLMProvider - ); - if (DISABLED_PROVIDERS.includes(settings?.LLMProvider)) return null; + const { defaultModels, customModels, loading } = + useGetProviderModels(provider); + if (DISABLED_PROVIDERS.includes(provider)) return null; if (loading) { return (

The specific chat model that will be used for this workspace. If @@ -42,8 +41,7 @@ export default function ChatModelSelection({

            The specific chat model that will be used for this workspace. If
            empty, will use the system LLM preference.
          </p>
        </div>
        <select
          name="chatModel"
          defaultValue={workspace?.chatModel || ""}
          onChange={() => {
            setHasChanges(true);
          }}
          className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
        >
-          <option disabled={true} selected={workspace?.chatModel === ""}>
-            System default
-          </option>
          {defaultModels.length > 0 && (
            <optgroup label="General models">
              {defaultModels.map((model) => {

diff --git a/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/WorkspaceLLMItem/index.jsx b/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/WorkspaceLLMItem/index.jsx
new file mode 100644
index 00000000..872d2a42
--- /dev/null
+++ b/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/WorkspaceLLMItem/index.jsx
@@ -0,0 +1,151 @@
+// This component differs from the main LLMItem in that it shows whether a provider
+// is "ready for use" and, if not, will hijack the click handler to show a modal
+// of the provider options that must be saved to continue.
+import { createPortal } from "react-dom";
+import ModalWrapper from "@/components/ModalWrapper";
+import { useModal } from "@/hooks/useModal";
+import { X } from "@phosphor-icons/react";
+import System from "@/models/system";
+import showToast from "@/utils/toast";
+
+export default function WorkspaceLLM({
+  llm,
+  availableLLMs,
+  settings,
+  checked,
+  onClick,
+}) {
+  const { isOpen, openModal, closeModal } = useModal();
+  const { name, value, logo, description } = llm;
+
+  function handleProviderSelection() {
+    // Determine if the provider needs additional setup because its minimum
+    // required keys are not yet set in settings.
+    const requiresAdditionalSetup = (llm.requiredConfig || []).some(
+      (key) => !settings[key]
+    );
+    if (requiresAdditionalSetup) {
+      openModal();
+      return;
+    }
+    onClick(value);
+  }
+
+  return (
+    <>
+      <div

+ +
+ {`${name} +
+
{name}
+
{description}
+
+
+
+ + + ); +} + +function SetupProvider({ + availableLLMs, + isOpen, + provider, + closeModal, + postSubmit, +}) { + if (!isOpen) return null; + const LLMOption = availableLLMs.find((llm) => llm.value === provider); + if (!LLMOption) return null; + + async function handleUpdate(e) { + e.preventDefault(); + e.stopPropagation(); + const data = {}; + const form = new FormData(e.target); + for (var [key, value] of form.entries()) data[key] = value; + const { error } = await System.updateSystem(data); + if (error) { + showToast(`Failed to save ${LLMOption.name} settings: ${error}`, "error"); + return; + } + + closeModal(); + postSubmit(); + return false; + } + + // Cannot do nested forms, it will cause all sorts of issues, so we portal this out + // to the parent container form so we don't have nested forms. + return createPortal( + +
+
+
+

+ Setup {LLMOption.name} +

+ +
+ +
+
+

+ To use {LLMOption.name} as this workspace's LLM you need to set + it up first. +

+
{LLMOption.options({ credentialsOnly: true })}
+
+
+ + +
+
+
+
+
, + document.getElementById("workspace-chat-settings-container") + ); +} diff --git a/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx new file mode 100644 index 00000000..07e35596 --- /dev/null +++ b/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx @@ -0,0 +1,159 @@ +import React, { useEffect, useRef, useState } from "react"; +import AnythingLLMIcon from "@/media/logo/anything-llm-icon.png"; +import WorkspaceLLMItem from "./WorkspaceLLMItem"; +import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference"; +import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react"; +import ChatModelSelection from "../ChatModelSelection"; + +const DISABLED_PROVIDERS = ["azure", "lmstudio", "native"]; +const LLM_DEFAULT = { + name: "System default", + value: "default", + logo: AnythingLLMIcon, + options: () => , + description: "Use the system LLM preference for this workspace.", + requiredConfig: [], +}; + +export default function WorkspaceLLMSelection({ + settings, + workspace, + setHasChanges, +}) { + const [filteredLLMs, setFilteredLLMs] = useState([]); + const [selectedLLM, setSelectedLLM] = useState( + workspace?.chatProvider ?? "default" + ); + const [searchQuery, setSearchQuery] = useState(""); + const [searchMenuOpen, setSearchMenuOpen] = useState(false); + const searchInputRef = useRef(null); + const LLMS = [LLM_DEFAULT, ...AVAILABLE_LLM_PROVIDERS].filter( + (llm) => !DISABLED_PROVIDERS.includes(llm.value) + ); + + function updateLLMChoice(selection) { + console.log({ selection }); + setSearchQuery(""); + setSelectedLLM(selection); + setSearchMenuOpen(false); + setHasChanges(true); + } + + function handleXButton() { + if (searchQuery.length > 0) { + setSearchQuery(""); + if (searchInputRef.current) searchInputRef.current.value = ""; + } else { + setSearchMenuOpen(!searchMenuOpen); + } + } + + useEffect(() => { + const filtered = LLMS.filter((llm) => + llm.name.toLowerCase().includes(searchQuery.toLowerCase()) + ); + setFilteredLLMs(filtered); + }, [LLMS, searchQuery, selectedLLM]); + + const selectedLLMObject = LLMS.find((llm) => llm.value === selectedLLM); + return ( +
+
+ +

+ The specific LLM provider & model that will be used for this + workspace. By default, it uses the system LLM provider and settings. +

+
+ +
+ + {searchMenuOpen && ( +
setSearchMenuOpen(false)} + /> + )} + {searchMenuOpen ? ( +
+
+
+ + setSearchQuery(e.target.value)} + ref={searchInputRef} + onKeyDown={(e) => { + if (e.key === "Enter") e.preventDefault(); + }} + /> + +
+
+ {filteredLLMs.map((llm) => { + return ( + updateLLMChoice(llm.value)} + /> + ); + })} +
+
+
+ ) : ( + + )} +
+ {selectedLLM !== "default" && ( +
+ +
+ )} +
+ ); +} diff --git a/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx b/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx index 3004b871..a6bab2c3 100644 --- a/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx @@ -3,11 +3,11 @@ import Workspace from "@/models/workspace"; import showToast from "@/utils/toast"; import { castToType } from "@/utils/types"; import { useEffect, useRef, useState } from "react"; -import ChatModelSelection from "./ChatModelSelection"; import ChatHistorySettings from "./ChatHistorySettings"; import ChatPromptSettings from "./ChatPromptSettings"; import ChatTemperatureSettings from "./ChatTemperatureSettings"; import ChatModeSelection from "./ChatModeSelection"; +import WorkspaceLLMSelection from "./WorkspaceLLMSelection"; export default function ChatSettings({ workspace }) { const [settings, setSettings] = useState({}); @@ -44,35 +44,45 @@ export default function ChatSettings({ workspace }) { if (!workspace) return null; return ( -
- - - - - - {hasChanges && ( - - )} - +
+
+ + + + + + {hasChanges && ( + + )} + +
); } diff --git a/server/endpoints/workspaces.js b/server/endpoints/workspaces.js index da9e2ad9..1c87dc36 100644 --- a/server/endpoints/workspaces.js +++ b/server/endpoints/workspaces.js @@ -508,7 +508,7 @@ function workspaceEndpoints(app) { if (fs.existsSync(oldPfpPath)) fs.unlinkSync(oldPfpPath); } - const { workspace, message } = await Workspace.update( + const { workspace, message } = await Workspace._update( workspaceRecord.id, { pfpFilename: uploadedFileName, @@ -547,7 +547,7 @@ function workspaceEndpoints(app) { if (fs.existsSync(oldPfpPath)) fs.unlinkSync(oldPfpPath); } - const { workspace, message } = await Workspace.update( + const { workspace, message } = await Workspace._update( workspaceRecord.id, { pfpFilename: null, diff --git a/server/models/systemSettings.js b/server/models/systemSettings.js index c4529ad9..080a01f0 100644 --- a/server/models/systemSettings.js +++ b/server/models/systemSettings.js @@ -57,103 +57,13 @@ const SystemSettings = { // VectorDB Provider Selection Settings & Configs // -------------------------------------------------------- VectorDB: vectorDB, - // Pinecone DB Keys - PineConeKey: !!process.env.PINECONE_API_KEY, - PineConeIndex: process.env.PINECONE_INDEX, - - // Chroma DB Keys - ChromaEndpoint: process.env.CHROMA_ENDPOINT, - ChromaApiHeader: process.env.CHROMA_API_HEADER, - ChromaApiKey: !!process.env.CHROMA_API_KEY, - - // Weaviate DB Keys - WeaviateEndpoint: process.env.WEAVIATE_ENDPOINT, - WeaviateApiKey: process.env.WEAVIATE_API_KEY, - - // QDrant DB Keys - QdrantEndpoint: process.env.QDRANT_ENDPOINT, - QdrantApiKey: process.env.QDRANT_API_KEY, - - // Milvus DB Keys - MilvusAddress: process.env.MILVUS_ADDRESS, - MilvusUsername: process.env.MILVUS_USERNAME, - MilvusPassword: !!process.env.MILVUS_PASSWORD, - - // Zilliz DB Keys - ZillizEndpoint: process.env.ZILLIZ_ENDPOINT, - ZillizApiToken: process.env.ZILLIZ_API_TOKEN, - - // AstraDB Keys - AstraDBApplicationToken: process?.env?.ASTRA_DB_APPLICATION_TOKEN, - AstraDBEndpoint: process?.env?.ASTRA_DB_ENDPOINT, + ...this.vectorDBPreferenceKeys(), // -------------------------------------------------------- // LLM Provider Selection Settings & Configs // -------------------------------------------------------- LLMProvider: llmProvider, - // OpenAI Keys - OpenAiKey: !!process.env.OPEN_AI_KEY, - OpenAiModelPref: process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo", - - // Azure + OpenAI Keys - AzureOpenAiEndpoint: process.env.AZURE_OPENAI_ENDPOINT, - AzureOpenAiKey: !!process.env.AZURE_OPENAI_KEY, - AzureOpenAiModelPref: process.env.OPEN_MODEL_PREF, - AzureOpenAiEmbeddingModelPref: process.env.EMBEDDING_MODEL_PREF, - AzureOpenAiTokenLimit: process.env.AZURE_OPENAI_TOKEN_LIMIT || 4096, - - // Anthropic Keys - AnthropicApiKey: !!process.env.ANTHROPIC_API_KEY, - AnthropicModelPref: process.env.ANTHROPIC_MODEL_PREF || "claude-2", - - // Gemini Keys - GeminiLLMApiKey: !!process.env.GEMINI_API_KEY, - GeminiLLMModelPref: process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro", - - // LMStudio Keys - LMStudioBasePath: process.env.LMSTUDIO_BASE_PATH, - LMStudioTokenLimit: process.env.LMSTUDIO_MODEL_TOKEN_LIMIT, - LMStudioModelPref: process.env.LMSTUDIO_MODEL_PREF, - - // LocalAI Keys - LocalAiApiKey: !!process.env.LOCAL_AI_API_KEY, - LocalAiBasePath: process.env.LOCAL_AI_BASE_PATH, - LocalAiModelPref: process.env.LOCAL_AI_MODEL_PREF, - LocalAiTokenLimit: process.env.LOCAL_AI_MODEL_TOKEN_LIMIT, - - // Ollama LLM Keys - OllamaLLMBasePath: process.env.OLLAMA_BASE_PATH, - OllamaLLMModelPref: process.env.OLLAMA_MODEL_PREF, - 
OllamaLLMTokenLimit: process.env.OLLAMA_MODEL_TOKEN_LIMIT, - - // TogetherAI Keys - TogetherAiApiKey: !!process.env.TOGETHER_AI_API_KEY, - TogetherAiModelPref: process.env.TOGETHER_AI_MODEL_PREF, - - // Perplexity AI Keys - PerplexityApiKey: !!process.env.PERPLEXITY_API_KEY, - PerplexityModelPref: process.env.PERPLEXITY_MODEL_PREF, - - // OpenRouter Keys - OpenRouterApiKey: !!process.env.OPENROUTER_API_KEY, - OpenRouterModelPref: process.env.OPENROUTER_MODEL_PREF, - - // Mistral AI (API) Keys - MistralApiKey: !!process.env.MISTRAL_API_KEY, - MistralModelPref: process.env.MISTRAL_MODEL_PREF, - - // Groq AI API Keys - GroqApiKey: !!process.env.GROQ_API_KEY, - GroqModelPref: process.env.GROQ_MODEL_PREF, - - // Native LLM Keys - NativeLLMModelPref: process.env.NATIVE_LLM_MODEL_PREF, - NativeLLMTokenLimit: process.env.NATIVE_LLM_MODEL_TOKEN_LIMIT, - - // HuggingFace Dedicated Inference - HuggingFaceLLMEndpoint: process.env.HUGGING_FACE_LLM_ENDPOINT, - HuggingFaceLLMAccessToken: !!process.env.HUGGING_FACE_LLM_API_KEY, - HuggingFaceLLMTokenLimit: process.env.HUGGING_FACE_LLM_TOKEN_LIMIT, + ...this.llmPreferenceKeys(), // -------------------------------------------------------- // Whisper (Audio transcription) Selection Settings & Configs @@ -273,6 +183,108 @@ const SystemSettings = { return false; } }, + + vectorDBPreferenceKeys: function () { + return { + // Pinecone DB Keys + PineConeKey: !!process.env.PINECONE_API_KEY, + PineConeIndex: process.env.PINECONE_INDEX, + + // Chroma DB Keys + ChromaEndpoint: process.env.CHROMA_ENDPOINT, + ChromaApiHeader: process.env.CHROMA_API_HEADER, + ChromaApiKey: !!process.env.CHROMA_API_KEY, + + // Weaviate DB Keys + WeaviateEndpoint: process.env.WEAVIATE_ENDPOINT, + WeaviateApiKey: process.env.WEAVIATE_API_KEY, + + // QDrant DB Keys + QdrantEndpoint: process.env.QDRANT_ENDPOINT, + QdrantApiKey: process.env.QDRANT_API_KEY, + + // Milvus DB Keys + MilvusAddress: process.env.MILVUS_ADDRESS, + MilvusUsername: process.env.MILVUS_USERNAME, + MilvusPassword: !!process.env.MILVUS_PASSWORD, + + // Zilliz DB Keys + ZillizEndpoint: process.env.ZILLIZ_ENDPOINT, + ZillizApiToken: process.env.ZILLIZ_API_TOKEN, + + // AstraDB Keys + AstraDBApplicationToken: process?.env?.ASTRA_DB_APPLICATION_TOKEN, + AstraDBEndpoint: process?.env?.ASTRA_DB_ENDPOINT, + }; + }, + + llmPreferenceKeys: function () { + return { + // OpenAI Keys + OpenAiKey: !!process.env.OPEN_AI_KEY, + OpenAiModelPref: process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo", + + // Azure + OpenAI Keys + AzureOpenAiEndpoint: process.env.AZURE_OPENAI_ENDPOINT, + AzureOpenAiKey: !!process.env.AZURE_OPENAI_KEY, + AzureOpenAiModelPref: process.env.OPEN_MODEL_PREF, + AzureOpenAiEmbeddingModelPref: process.env.EMBEDDING_MODEL_PREF, + AzureOpenAiTokenLimit: process.env.AZURE_OPENAI_TOKEN_LIMIT || 4096, + + // Anthropic Keys + AnthropicApiKey: !!process.env.ANTHROPIC_API_KEY, + AnthropicModelPref: process.env.ANTHROPIC_MODEL_PREF || "claude-2", + + // Gemini Keys + GeminiLLMApiKey: !!process.env.GEMINI_API_KEY, + GeminiLLMModelPref: process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro", + + // LMStudio Keys + LMStudioBasePath: process.env.LMSTUDIO_BASE_PATH, + LMStudioTokenLimit: process.env.LMSTUDIO_MODEL_TOKEN_LIMIT, + LMStudioModelPref: process.env.LMSTUDIO_MODEL_PREF, + + // LocalAI Keys + LocalAiApiKey: !!process.env.LOCAL_AI_API_KEY, + LocalAiBasePath: process.env.LOCAL_AI_BASE_PATH, + LocalAiModelPref: process.env.LOCAL_AI_MODEL_PREF, + LocalAiTokenLimit: process.env.LOCAL_AI_MODEL_TOKEN_LIMIT, + + // Ollama LLM Keys + 
OllamaLLMBasePath: process.env.OLLAMA_BASE_PATH, + OllamaLLMModelPref: process.env.OLLAMA_MODEL_PREF, + OllamaLLMTokenLimit: process.env.OLLAMA_MODEL_TOKEN_LIMIT, + + // TogetherAI Keys + TogetherAiApiKey: !!process.env.TOGETHER_AI_API_KEY, + TogetherAiModelPref: process.env.TOGETHER_AI_MODEL_PREF, + + // Perplexity AI Keys + PerplexityApiKey: !!process.env.PERPLEXITY_API_KEY, + PerplexityModelPref: process.env.PERPLEXITY_MODEL_PREF, + + // OpenRouter Keys + OpenRouterApiKey: !!process.env.OPENROUTER_API_KEY, + OpenRouterModelPref: process.env.OPENROUTER_MODEL_PREF, + + // Mistral AI (API) Keys + MistralApiKey: !!process.env.MISTRAL_API_KEY, + MistralModelPref: process.env.MISTRAL_MODEL_PREF, + + // Groq AI API Keys + GroqApiKey: !!process.env.GROQ_API_KEY, + GroqModelPref: process.env.GROQ_MODEL_PREF, + + // Native LLM Keys + NativeLLMModelPref: process.env.NATIVE_LLM_MODEL_PREF, + NativeLLMTokenLimit: process.env.NATIVE_LLM_MODEL_TOKEN_LIMIT, + + // HuggingFace Dedicated Inference + HuggingFaceLLMEndpoint: process.env.HUGGING_FACE_LLM_ENDPOINT, + HuggingFaceLLMAccessToken: !!process.env.HUGGING_FACE_LLM_API_KEY, + HuggingFaceLLMTokenLimit: process.env.HUGGING_FACE_LLM_TOKEN_LIMIT, + }; + }, }; module.exports.SystemSettings = SystemSettings; diff --git a/server/models/workspace.js b/server/models/workspace.js index f061ca20..b905c199 100644 --- a/server/models/workspace.js +++ b/server/models/workspace.js @@ -19,6 +19,7 @@ const Workspace = { "lastUpdatedAt", "openAiPrompt", "similarityThreshold", + "chatProvider", "chatModel", "topN", "chatMode", @@ -52,19 +53,42 @@ const Workspace = { } }, - update: async function (id = null, data = {}) { + update: async function (id = null, updates = {}) { if (!id) throw new Error("No workspace id provided for update"); - const validKeys = Object.keys(data).filter((key) => + const validFields = Object.keys(updates).filter((key) => this.writable.includes(key) ); - if (validKeys.length === 0) + + Object.entries(updates).forEach(([key]) => { + if (validFields.includes(key)) return; + delete updates[key]; + }); + + if (Object.keys(updates).length === 0) return { workspace: { id }, message: "No valid fields to update!" }; + // If the user unset the chatProvider we will need + // to then clear the chatModel as well to prevent confusion during + // LLM loading. + if (updates?.chatProvider === "default") { + updates.chatProvider = null; + updates.chatModel = null; + } + + return this._update(id, updates); + }, + + // Explicit update of settings + key validations. + // Only use this method when directly setting a key value + // that takes no user input for the keys being modified. + _update: async function (id = null, data = {}) { + if (!id) throw new Error("No workspace id provided for update"); + try { const workspace = await prisma.workspaces.update({ where: { id }, - data, // TODO: strict validation on writables here. 
+ data, }); return { workspace, message: null }; } catch (error) { @@ -229,47 +253,40 @@ const Workspace = { } }, - resetWorkspaceChatModels: async () => { - try { - await prisma.workspaces.updateMany({ - data: { - chatModel: null, - }, - }); - return { success: true, error: null }; - } catch (error) { - console.error("Error resetting workspace chat models:", error.message); - return { success: false, error: error.message }; - } - }, - trackChange: async function (prevData, newData, user) { try { - const { Telemetry } = require("./telemetry"); - const { EventLogs } = require("./eventLogs"); - if ( - !newData?.openAiPrompt || - newData?.openAiPrompt === this.defaultPrompt || - newData?.openAiPrompt === prevData?.openAiPrompt - ) - return; - - await Telemetry.sendTelemetry("workspace_prompt_changed"); - await EventLogs.logEvent( - "workspace_prompt_changed", - { - workspaceName: prevData?.name, - prevSystemPrompt: prevData?.openAiPrompt || this.defaultPrompt, - newSystemPrompt: newData?.openAiPrompt, - }, - user?.id - ); + await this._trackWorkspacePromptChange(prevData, newData, user); return; } catch (error) { console.error("Error tracking workspace change:", error.message); return; } }, + + // We are only tracking this change to determine the need to a prompt library or + // prompt assistant feature. If this is something you would like to see - tell us on GitHub! + _trackWorkspacePromptChange: async function (prevData, newData, user) { + const { Telemetry } = require("./telemetry"); + const { EventLogs } = require("./eventLogs"); + if ( + !newData?.openAiPrompt || + newData?.openAiPrompt === this.defaultPrompt || + newData?.openAiPrompt === prevData?.openAiPrompt + ) + return; + + await Telemetry.sendTelemetry("workspace_prompt_changed"); + await EventLogs.logEvent( + "workspace_prompt_changed", + { + workspaceName: prevData?.name, + prevSystemPrompt: prevData?.openAiPrompt || this.defaultPrompt, + newSystemPrompt: newData?.openAiPrompt, + }, + user?.id + ); + return; + }, }; module.exports = { Workspace }; diff --git a/server/prisma/migrations/20240405015034_init/migration.sql b/server/prisma/migrations/20240405015034_init/migration.sql new file mode 100644 index 00000000..54a39d94 --- /dev/null +++ b/server/prisma/migrations/20240405015034_init/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "workspaces" ADD COLUMN "chatProvider" TEXT; diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma index fbb5f61d..1e589b0f 100644 --- a/server/prisma/schema.prisma +++ b/server/prisma/schema.prisma @@ -98,6 +98,7 @@ model workspaces { lastUpdatedAt DateTime @default(now()) openAiPrompt String? similarityThreshold Float? @default(0.25) + chatProvider String? chatModel String? topN Int? @default(4) chatMode String? @default("chat") diff --git a/server/utils/chats/embed.js b/server/utils/chats/embed.js index f748a3a5..497b2c8e 100644 --- a/server/utils/chats/embed.js +++ b/server/utils/chats/embed.js @@ -28,7 +28,9 @@ async function streamChatWithForEmbed( embed.workspace.openAiTemp = parseFloat(temperatureOverride); const uuid = uuidv4(); - const LLMConnector = getLLMProvider(chatModel ?? embed.workspace?.chatModel); + const LLMConnector = getLLMProvider({ + model: chatModel ?? 
embed.workspace?.chatModel, + }); const VectorDb = getVectorDbClass(); const { safe, reasons = [] } = await LLMConnector.isSafe(message); if (!safe) { diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js index 10df9983..7e40b9a8 100644 --- a/server/utils/chats/index.js +++ b/server/utils/chats/index.js @@ -37,7 +37,10 @@ async function chatWithWorkspace( return await VALID_COMMANDS[command](workspace, message, uuid, user); } - const LLMConnector = getLLMProvider(workspace?.chatModel); + const LLMConnector = getLLMProvider({ + provider: workspace?.chatProvider, + model: workspace?.chatModel, + }); const VectorDb = getVectorDbClass(); const { safe, reasons = [] } = await LLMConnector.isSafe(message); if (!safe) { diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js index f1a335bc..0ec969eb 100644 --- a/server/utils/chats/stream.js +++ b/server/utils/chats/stream.js @@ -35,7 +35,10 @@ async function streamChatWithWorkspace( return; } - const LLMConnector = getLLMProvider(workspace?.chatModel); + const LLMConnector = getLLMProvider({ + provider: workspace?.chatProvider, + model: workspace?.chatModel, + }); const VectorDb = getVectorDbClass(); const { safe, reasons = [] } = await LLMConnector.isSafe(message); if (!safe) { diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js index 78360972..a441bf82 100644 --- a/server/utils/helpers/index.js +++ b/server/utils/helpers/index.js @@ -30,52 +30,53 @@ function getVectorDbClass() { } } -function getLLMProvider(modelPreference = null) { - const vectorSelection = process.env.LLM_PROVIDER || "openai"; +function getLLMProvider({ provider = null, model = null } = {}) { + const LLMSelection = provider ?? process.env.LLM_PROVIDER ?? "openai"; const embedder = getEmbeddingEngineSelection(); - switch (vectorSelection) { + + switch (LLMSelection) { case "openai": const { OpenAiLLM } = require("../AiProviders/openAi"); - return new OpenAiLLM(embedder, modelPreference); + return new OpenAiLLM(embedder, model); case "azure": const { AzureOpenAiLLM } = require("../AiProviders/azureOpenAi"); - return new AzureOpenAiLLM(embedder, modelPreference); + return new AzureOpenAiLLM(embedder, model); case "anthropic": const { AnthropicLLM } = require("../AiProviders/anthropic"); - return new AnthropicLLM(embedder, modelPreference); + return new AnthropicLLM(embedder, model); case "gemini": const { GeminiLLM } = require("../AiProviders/gemini"); - return new GeminiLLM(embedder, modelPreference); + return new GeminiLLM(embedder, model); case "lmstudio": const { LMStudioLLM } = require("../AiProviders/lmStudio"); - return new LMStudioLLM(embedder, modelPreference); + return new LMStudioLLM(embedder, model); case "localai": const { LocalAiLLM } = require("../AiProviders/localAi"); - return new LocalAiLLM(embedder, modelPreference); + return new LocalAiLLM(embedder, model); case "ollama": const { OllamaAILLM } = require("../AiProviders/ollama"); - return new OllamaAILLM(embedder, modelPreference); + return new OllamaAILLM(embedder, model); case "togetherai": const { TogetherAiLLM } = require("../AiProviders/togetherAi"); - return new TogetherAiLLM(embedder, modelPreference); + return new TogetherAiLLM(embedder, model); case "perplexity": const { PerplexityLLM } = require("../AiProviders/perplexity"); - return new PerplexityLLM(embedder, modelPreference); + return new PerplexityLLM(embedder, model); case "openrouter": const { OpenRouterLLM } = require("../AiProviders/openRouter"); - return new OpenRouterLLM(embedder, 
modelPreference); + return new OpenRouterLLM(embedder, model); case "mistral": const { MistralLLM } = require("../AiProviders/mistral"); - return new MistralLLM(embedder, modelPreference); + return new MistralLLM(embedder, model); case "native": const { NativeLLM } = require("../AiProviders/native"); - return new NativeLLM(embedder, modelPreference); + return new NativeLLM(embedder, model); case "huggingface": const { HuggingFaceLLM } = require("../AiProviders/huggingface"); - return new HuggingFaceLLM(embedder, modelPreference); + return new HuggingFaceLLM(embedder, model); case "groq": const { GroqLLM } = require("../AiProviders/groq"); - return new GroqLLM(embedder, modelPreference); + return new GroqLLM(embedder, model); default: throw new Error("ENV: No LLM_PROVIDER value found in environment!"); } diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index 12c45af2..a026fe33 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -2,7 +2,6 @@ const KEY_MAPPING = { LLMProvider: { envKey: "LLM_PROVIDER", checks: [isNotEmpty, supportedLLM], - postUpdate: [wipeWorkspaceModelPreference], }, // OpenAI Settings OpenAiKey: { @@ -493,15 +492,6 @@ function validHuggingFaceEndpoint(input = "") { : null; } -// If the LLMProvider has changed we need to reset all workspace model preferences to -// null since the provider<>model name combination will be invalid for whatever the new -// provider is. -async function wipeWorkspaceModelPreference(key, prev, next) { - if (prev === next) return; - const { Workspace } = require("../../models/workspace"); - await Workspace.resetWorkspaceChatModels(); -} - // This will force update .env variables which for any which reason were not able to be parsed or // read from an ENV file as this seems to be a complicating step for many so allowing people to write // to the process will at least alleviate that issue. It does not perform comprehensive validity checks or sanity checks
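
Note (illustrative, not part of the diff): taken together, these changes make per-chat
LLM resolution workspace-first. A minimal sketch of the lookup order, using only names
introduced in this patch — the require path is shortened for illustration:

    const { getLLMProvider } = require("./utils/helpers");

    // Resolution order after this patch:
    //   1. workspace.chatProvider / workspace.chatModel, saved per workspace
    //   2. process.env.LLM_PROVIDER when the workspace keeps "System default"
    //   3. "openai", the final fallback inside getLLMProvider()
    const LLMConnector = getLLMProvider({
      provider: workspace?.chatProvider, // null when the workspace uses the system default
      model: workspace?.chatModel, // cleared by Workspace.update() whenever chatProvider resets
    });

Because Workspace.update() rewrites a saved chatProvider of "default" to null and clears
chatModel along with it, a stale provider<>model pairing can never reach getLLMProvider();
this is also why the blanket resetWorkspaceChatModels/wipeWorkspaceModelPreference hooks
removed above are no longer needed.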