From 90df37582bcccba53282420eb61c8038c4699609 Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Wed, 17 Jan 2024 12:59:25 -0800 Subject: [PATCH] Per workspace model selection (#582) * WIP model selection per workspace (migrations and openai saves properly * revert OpenAiOption * add support for models per workspace for anthropic, localAi, ollama, openAi, and togetherAi * remove unneeded comments * update logic for when LLMProvider is reset, reset Ai provider files with master * remove frontend/api reset of workspace chat and move logic to updateENV add postUpdate callbacks to envs * set preferred model for chat on class instantiation * remove extra param * linting * remove unused var * refactor chat model selection on workspace * linting * add fallback for base path to localai models --------- Co-authored-by: timothycarambat --- .../Settings/ChatModelPreference/index.jsx | 120 ++++++++++++++++++ .../useGetProviderModels.js | 49 +++++++ .../Modals/MangeWorkspace/Settings/index.jsx | 8 +- .../Modals/MangeWorkspace/index.jsx | 1 + .../GeneralSettings/LLMPreference/index.jsx | 6 +- server/endpoints/api/system/index.js | 2 +- server/endpoints/system.js | 6 +- server/models/workspace.js | 15 +++ .../20240113013409_init/migration.sql | 2 + server/prisma/schema.prisma | 1 + server/utils/AiProviders/anthropic/index.js | 5 +- server/utils/AiProviders/azureOpenAi/index.js | 2 +- server/utils/AiProviders/gemini/index.js | 5 +- server/utils/AiProviders/lmStudio/index.js | 4 +- server/utils/AiProviders/localAi/index.js | 4 +- server/utils/AiProviders/native/index.js | 4 +- server/utils/AiProviders/ollama/index.js | 4 +- server/utils/AiProviders/openAi/index.js | 5 +- server/utils/AiProviders/togetherAi/index.js | 4 +- server/utils/chats/index.js | 2 +- server/utils/chats/stream.js | 2 +- server/utils/helpers/customModels.js | 13 +- server/utils/helpers/index.js | 20 +-- server/utils/helpers/updateENV.js | 32 +++-- 24 files changed, 263 insertions(+), 53 deletions(-) create mode 100644 frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/index.jsx create mode 100644 frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/useGetProviderModels.js create mode 100644 server/prisma/migrations/20240113013409_init/migration.sql diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/index.jsx new file mode 100644 index 00000000..ea03c09a --- /dev/null +++ b/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/index.jsx @@ -0,0 +1,120 @@ +import useGetProviderModels, { + DISABLED_PROVIDERS, +} from "./useGetProviderModels"; + +export default function ChatModelSelection({ + settings, + workspace, + setHasChanges, +}) { + const { defaultModels, customModels, loading } = useGetProviderModels( + settings?.LLMProvider + ); + if (DISABLED_PROVIDERS.includes(settings?.LLMProvider)) return null; + + if (loading) { + return ( +
+
+ +

+ The specific chat model that will be used for this workspace. If + empty, will use the system LLM preference. +

+
+ +
+ ); + } + + return ( +
+
+ +

+ The specific chat model that will be used for this workspace. If + empty, will use the system LLM preference. +

+
+ + +
+ ); +} diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/useGetProviderModels.js b/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/useGetProviderModels.js new file mode 100644 index 00000000..eae1b4ad --- /dev/null +++ b/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/useGetProviderModels.js @@ -0,0 +1,49 @@ +import System from "@/models/system"; +import { useEffect, useState } from "react"; + +// Providers which cannot use this feature for workspace<>model selection +export const DISABLED_PROVIDERS = ["azure", "lmstudio"]; +const PROVIDER_DEFAULT_MODELS = { + openai: ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-32k"], + gemini: ["gemini-pro"], + anthropic: ["claude-2", "claude-instant-1"], + azure: [], + lmstudio: [], + localai: [], + ollama: [], + togetherai: [], + native: [], +}; + +// For togetherAi, which has a large model list - we subgroup the options +// by their creator organization (eg: Meta, Mistral, etc) +// which makes selection easier to read. +function groupModels(models) { + return models.reduce((acc, model) => { + acc[model.organization] = acc[model.organization] || []; + acc[model.organization].push(model); + return acc; + }, {}); +} + +export default function useGetProviderModels(provider = null) { + const [defaultModels, setDefaultModels] = useState([]); + const [customModels, setCustomModels] = useState([]); + const [loading, setLoading] = useState(true); + + useEffect(() => { + async function fetchProviderModels() { + if (!provider) return; + const { models = [] } = await System.customModels(provider); + if (PROVIDER_DEFAULT_MODELS.hasOwnProperty(provider)) + setDefaultModels(PROVIDER_DEFAULT_MODELS[provider]); + provider === "togetherai" + ? setCustomModels(groupModels(models)) + : setCustomModels(models); + setLoading(false); + } + fetchProviderModels(); + }, [provider]); + + return { defaultModels, customModels, loading }; +} diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx index 2fce91e1..a3089d68 100644 --- a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx +++ b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx @@ -6,6 +6,7 @@ import System from "../../../../models/system"; import PreLoader from "../../../Preloader"; import { useParams } from "react-router-dom"; import showToast from "../../../../utils/toast"; +import ChatModelPreference from "./ChatModelPreference"; // Ensure that a type is correct before sending the body // to the backend. @@ -26,7 +27,7 @@ function castToType(key, value) { return definitions[key].cast(value); } -export default function WorkspaceSettings({ active, workspace }) { +export default function WorkspaceSettings({ active, workspace, settings }) { const { slug } = useParams(); const formEl = useRef(null); const [saving, setSaving] = useState(false); @@ -99,6 +100,11 @@ export default function WorkspaceSettings({ active, workspace }) {
+
diff --git a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx index 28771622..bd6ae511 100644 --- a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx +++ b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx @@ -30,19 +30,17 @@ export default function GeneralLLMPreference() { const [hasChanges, setHasChanges] = useState(false); const [settings, setSettings] = useState(null); const [loading, setLoading] = useState(true); - const [searchQuery, setSearchQuery] = useState(""); const [filteredLLMs, setFilteredLLMs] = useState([]); const [selectedLLM, setSelectedLLM] = useState(null); - const isHosted = window.location.hostname.includes("useanything.com"); const handleSubmit = async (e) => { e.preventDefault(); const form = e.target; - const data = {}; + const data = { LLMProvider: selectedLLM }; const formData = new FormData(form); - data.LLMProvider = selectedLLM; + for (var [key, value] of formData.entries()) data[key] = value; const { error } = await System.updateSystem(data); setSaving(true); diff --git a/server/endpoints/api/system/index.js b/server/endpoints/api/system/index.js index 3548c306..b18019b1 100644 --- a/server/endpoints/api/system/index.js +++ b/server/endpoints/api/system/index.js @@ -139,7 +139,7 @@ function apiSystemEndpoints(app) { */ try { const body = reqBody(request); - const { newValues, error } = updateENV(body); + const { newValues, error } = await updateENV(body); if (process.env.NODE_ENV === "production") await dumpENV(); response.status(200).json({ newValues, error }); } catch (e) { diff --git a/server/endpoints/system.js b/server/endpoints/system.js index 15db895a..e699cf84 100644 --- a/server/endpoints/system.js +++ b/server/endpoints/system.js @@ -290,7 +290,7 @@ function systemEndpoints(app) { } const body = reqBody(request); - const { newValues, error } = updateENV(body); + const { newValues, error } = await updateENV(body); if (process.env.NODE_ENV === "production") await dumpENV(); response.status(200).json({ newValues, error }); } catch (e) { @@ -312,7 +312,7 @@ function systemEndpoints(app) { } const { usePassword, newPassword } = reqBody(request); - const { error } = updateENV( + const { error } = await updateENV( { AuthToken: usePassword ? newPassword : "", JWTSecret: usePassword ? v4() : "", @@ -355,7 +355,7 @@ function systemEndpoints(app) { message_limit: 25, }); - updateENV( + await updateENV( { AuthToken: "", JWTSecret: process.env.JWT_SECRET || v4(), diff --git a/server/models/workspace.js b/server/models/workspace.js index 9139c25e..6de8053e 100644 --- a/server/models/workspace.js +++ b/server/models/workspace.js @@ -14,6 +14,7 @@ const Workspace = { "lastUpdatedAt", "openAiPrompt", "similarityThreshold", + "chatModel", ], new: async function (name = null, creatorId = null) { @@ -191,6 +192,20 @@ const Workspace = { return { success: false, error: error.message }; } }, + + resetWorkspaceChatModels: async () => { + try { + await prisma.workspaces.updateMany({ + data: { + chatModel: null, + }, + }); + return { success: true, error: null }; + } catch (error) { + console.error("Error resetting workspace chat models:", error.message); + return { success: false, error: error.message }; + } + }, }; module.exports = { Workspace }; diff --git a/server/prisma/migrations/20240113013409_init/migration.sql b/server/prisma/migrations/20240113013409_init/migration.sql new file mode 100644 index 00000000..09b9448e --- /dev/null +++ b/server/prisma/migrations/20240113013409_init/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "workspaces" ADD COLUMN "chatModel" TEXT; diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma index e9aa8a8a..2f632a46 100644 --- a/server/prisma/schema.prisma +++ b/server/prisma/schema.prisma @@ -93,6 +93,7 @@ model workspaces { lastUpdatedAt DateTime @default(now()) openAiPrompt String? similarityThreshold Float? @default(0.25) + chatModel String? workspace_users workspace_users[] documents workspace_documents[] } diff --git a/server/utils/AiProviders/anthropic/index.js b/server/utils/AiProviders/anthropic/index.js index 70933323..17f2abc4 100644 --- a/server/utils/AiProviders/anthropic/index.js +++ b/server/utils/AiProviders/anthropic/index.js @@ -2,7 +2,7 @@ const { v4 } = require("uuid"); const { chatPrompt } = require("../../chats"); class AnthropicLLM { - constructor(embedder = null) { + constructor(embedder = null, modelPreference = null) { if (!process.env.ANTHROPIC_API_KEY) throw new Error("No Anthropic API key was set."); @@ -12,7 +12,8 @@ class AnthropicLLM { apiKey: process.env.ANTHROPIC_API_KEY, }); this.anthropic = anthropic; - this.model = process.env.ANTHROPIC_MODEL_PREF || "claude-2"; + this.model = + modelPreference || process.env.ANTHROPIC_MODEL_PREF || "claude-2"; this.limits = { history: this.promptWindowLimit() * 0.15, system: this.promptWindowLimit() * 0.15, diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js index 185dac02..f59fc51f 100644 --- a/server/utils/AiProviders/azureOpenAi/index.js +++ b/server/utils/AiProviders/azureOpenAi/index.js @@ -2,7 +2,7 @@ const { AzureOpenAiEmbedder } = require("../../EmbeddingEngines/azureOpenAi"); const { chatPrompt } = require("../../chats"); class AzureOpenAiLLM { - constructor(embedder = null) { + constructor(embedder = null, _modelPreference = null) { const { OpenAIClient, AzureKeyCredential } = require("@azure/openai"); if (!process.env.AZURE_OPENAI_ENDPOINT) throw new Error("No Azure API endpoint was set."); diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js index 03388e3e..348c8f5e 100644 --- a/server/utils/AiProviders/gemini/index.js +++ b/server/utils/AiProviders/gemini/index.js @@ -1,14 +1,15 @@ const { chatPrompt } = require("../../chats"); class GeminiLLM { - constructor(embedder = null) { + constructor(embedder = null, modelPreference = null) { if (!process.env.GEMINI_API_KEY) throw new Error("No Gemini API key was set."); // Docs: https://ai.google.dev/tutorials/node_quickstart const { GoogleGenerativeAI } = require("@google/generative-ai"); const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY); - this.model = process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro"; + this.model = + modelPreference || process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro"; this.gemini = genAI.getGenerativeModel({ model: this.model }); this.limits = { history: this.promptWindowLimit() * 0.15, diff --git a/server/utils/AiProviders/lmStudio/index.js b/server/utils/AiProviders/lmStudio/index.js index 28c107df..61480803 100644 --- a/server/utils/AiProviders/lmStudio/index.js +++ b/server/utils/AiProviders/lmStudio/index.js @@ -2,7 +2,7 @@ const { chatPrompt } = require("../../chats"); // hybrid of openAi LLM chat completion for LMStudio class LMStudioLLM { - constructor(embedder = null) { + constructor(embedder = null, _modelPreference = null) { if (!process.env.LMSTUDIO_BASE_PATH) throw new Error("No LMStudio API Base Path was set."); @@ -12,7 +12,7 @@ class LMStudioLLM { }); this.lmstudio = new OpenAIApi(config); // When using LMStudios inference server - the model param is not required so - // we can stub it here. + // we can stub it here. LMStudio can only run one model at a time. this.model = "model-placeholder"; this.limits = { history: this.promptWindowLimit() * 0.15, diff --git a/server/utils/AiProviders/localAi/index.js b/server/utils/AiProviders/localAi/index.js index 84954c99..6623ac88 100644 --- a/server/utils/AiProviders/localAi/index.js +++ b/server/utils/AiProviders/localAi/index.js @@ -1,7 +1,7 @@ const { chatPrompt } = require("../../chats"); class LocalAiLLM { - constructor(embedder = null) { + constructor(embedder = null, modelPreference = null) { if (!process.env.LOCAL_AI_BASE_PATH) throw new Error("No LocalAI Base Path was set."); @@ -15,7 +15,7 @@ class LocalAiLLM { : {}), }); this.openai = new OpenAIApi(config); - this.model = process.env.LOCAL_AI_MODEL_PREF; + this.model = modelPreference || process.env.LOCAL_AI_MODEL_PREF; this.limits = { history: this.promptWindowLimit() * 0.15, system: this.promptWindowLimit() * 0.15, diff --git a/server/utils/AiProviders/native/index.js b/server/utils/AiProviders/native/index.js index faac4fa0..66cc84d0 100644 --- a/server/utils/AiProviders/native/index.js +++ b/server/utils/AiProviders/native/index.js @@ -10,11 +10,11 @@ const ChatLlamaCpp = (...args) => ); class NativeLLM { - constructor(embedder = null) { + constructor(embedder = null, modelPreference = null) { if (!process.env.NATIVE_LLM_MODEL_PREF) throw new Error("No local Llama model was set."); - this.model = process.env.NATIVE_LLM_MODEL_PREF || null; + this.model = modelPreference || process.env.NATIVE_LLM_MODEL_PREF || null; this.limits = { history: this.promptWindowLimit() * 0.15, system: this.promptWindowLimit() * 0.15, diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js index 55205c23..fce96f36 100644 --- a/server/utils/AiProviders/ollama/index.js +++ b/server/utils/AiProviders/ollama/index.js @@ -3,12 +3,12 @@ const { StringOutputParser } = require("langchain/schema/output_parser"); // Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md class OllamaAILLM { - constructor(embedder = null) { + constructor(embedder = null, modelPreference = null) { if (!process.env.OLLAMA_BASE_PATH) throw new Error("No Ollama Base Path was set."); this.basePath = process.env.OLLAMA_BASE_PATH; - this.model = process.env.OLLAMA_MODEL_PREF; + this.model = modelPreference || process.env.OLLAMA_MODEL_PREF; this.limits = { history: this.promptWindowLimit() * 0.15, system: this.promptWindowLimit() * 0.15, diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js index ccc7ba0e..038d201d 100644 --- a/server/utils/AiProviders/openAi/index.js +++ b/server/utils/AiProviders/openAi/index.js @@ -2,7 +2,7 @@ const { OpenAiEmbedder } = require("../../EmbeddingEngines/openAi"); const { chatPrompt } = require("../../chats"); class OpenAiLLM { - constructor(embedder = null) { + constructor(embedder = null, modelPreference = null) { const { Configuration, OpenAIApi } = require("openai"); if (!process.env.OPEN_AI_KEY) throw new Error("No OpenAI API key was set."); @@ -10,7 +10,8 @@ class OpenAiLLM { apiKey: process.env.OPEN_AI_KEY, }); this.openai = new OpenAIApi(config); - this.model = process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo"; + this.model = + modelPreference || process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo"; this.limits = { history: this.promptWindowLimit() * 0.15, system: this.promptWindowLimit() * 0.15, diff --git a/server/utils/AiProviders/togetherAi/index.js b/server/utils/AiProviders/togetherAi/index.js index df64c413..44061dd0 100644 --- a/server/utils/AiProviders/togetherAi/index.js +++ b/server/utils/AiProviders/togetherAi/index.js @@ -6,7 +6,7 @@ function togetherAiModels() { } class TogetherAiLLM { - constructor(embedder = null) { + constructor(embedder = null, modelPreference = null) { const { Configuration, OpenAIApi } = require("openai"); if (!process.env.TOGETHER_AI_API_KEY) throw new Error("No TogetherAI API key was set."); @@ -16,7 +16,7 @@ class TogetherAiLLM { apiKey: process.env.TOGETHER_AI_API_KEY, }); this.openai = new OpenAIApi(config); - this.model = process.env.TOGETHER_AI_MODEL_PREF; + this.model = modelPreference || process.env.TOGETHER_AI_MODEL_PREF; this.limits = { history: this.promptWindowLimit() * 0.15, system: this.promptWindowLimit() * 0.15, diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js index 7fdb4734..d63de47d 100644 --- a/server/utils/chats/index.js +++ b/server/utils/chats/index.js @@ -71,7 +71,7 @@ async function chatWithWorkspace( return await VALID_COMMANDS[command](workspace, message, uuid, user); } - const LLMConnector = getLLMProvider(); + const LLMConnector = getLLMProvider(workspace?.chatModel); const VectorDb = getVectorDbClass(); const { safe, reasons = [] } = await LLMConnector.isSafe(message); if (!safe) { diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js index 04bb72b9..ceea8d7d 100644 --- a/server/utils/chats/stream.js +++ b/server/utils/chats/stream.js @@ -30,7 +30,7 @@ async function streamChatWithWorkspace( return; } - const LLMConnector = getLLMProvider(); + const LLMConnector = getLLMProvider(workspace?.chatModel); const VectorDb = getVectorDbClass(); const { safe, reasons = [] } = await LLMConnector.isSafe(message); if (!safe) { diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js index 54976895..87fe976e 100644 --- a/server/utils/helpers/customModels.js +++ b/server/utils/helpers/customModels.js @@ -17,7 +17,7 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) { case "localai": return await localAIModels(basePath, apiKey); case "ollama": - return await ollamaAIModels(basePath, apiKey); + return await ollamaAIModels(basePath); case "togetherai": return await getTogetherAiModels(); case "native-llm": @@ -53,7 +53,7 @@ async function openAiModels(apiKey = null) { async function localAIModels(basePath = null, apiKey = null) { const { Configuration, OpenAIApi } = require("openai"); const config = new Configuration({ - basePath, + basePath: basePath || process.env.LOCAL_AI_BASE_PATH, apiKey: apiKey || process.env.LOCAL_AI_API_KEY, }); const openai = new OpenAIApi(config); @@ -70,13 +70,14 @@ async function localAIModels(basePath = null, apiKey = null) { return { models, error: null }; } -async function ollamaAIModels(basePath = null, _apiKey = null) { +async function ollamaAIModels(basePath = null) { let url; try { - new URL(basePath); - if (basePath.split("").slice(-1)?.[0] === "/") + let urlPath = basePath ?? process.env.OLLAMA_BASE_PATH; + new URL(urlPath); + if (urlPath.split("").slice(-1)?.[0] === "/") throw new Error("BasePath Cannot end in /!"); - url = basePath; + url = urlPath; } catch { return { models: [], error: "Not a valid URL." }; } diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js index 1685acc1..2b1f3dac 100644 --- a/server/utils/helpers/index.js +++ b/server/utils/helpers/index.js @@ -24,37 +24,37 @@ function getVectorDbClass() { } } -function getLLMProvider() { +function getLLMProvider(modelPreference = null) { const vectorSelection = process.env.LLM_PROVIDER || "openai"; const embedder = getEmbeddingEngineSelection(); switch (vectorSelection) { case "openai": const { OpenAiLLM } = require("../AiProviders/openAi"); - return new OpenAiLLM(embedder); + return new OpenAiLLM(embedder, modelPreference); case "azure": const { AzureOpenAiLLM } = require("../AiProviders/azureOpenAi"); - return new AzureOpenAiLLM(embedder); + return new AzureOpenAiLLM(embedder, modelPreference); case "anthropic": const { AnthropicLLM } = require("../AiProviders/anthropic"); - return new AnthropicLLM(embedder); + return new AnthropicLLM(embedder, modelPreference); case "gemini": const { GeminiLLM } = require("../AiProviders/gemini"); - return new GeminiLLM(embedder); + return new GeminiLLM(embedder, modelPreference); case "lmstudio": const { LMStudioLLM } = require("../AiProviders/lmStudio"); - return new LMStudioLLM(embedder); + return new LMStudioLLM(embedder, modelPreference); case "localai": const { LocalAiLLM } = require("../AiProviders/localAi"); - return new LocalAiLLM(embedder); + return new LocalAiLLM(embedder, modelPreference); case "ollama": const { OllamaAILLM } = require("../AiProviders/ollama"); - return new OllamaAILLM(embedder); + return new OllamaAILLM(embedder, modelPreference); case "togetherai": const { TogetherAiLLM } = require("../AiProviders/togetherAi"); - return new TogetherAiLLM(embedder); + return new TogetherAiLLM(embedder, modelPreference); case "native": const { NativeLLM } = require("../AiProviders/native"); - return new NativeLLM(embedder); + return new NativeLLM(embedder, modelPreference); default: throw new Error("ENV: No LLM_PROVIDER value found in environment!"); } diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index c699cf2d..5c43da51 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -2,6 +2,7 @@ const KEY_MAPPING = { LLMProvider: { envKey: "LLM_PROVIDER", checks: [isNotEmpty, supportedLLM], + postUpdate: [wipeWorkspaceModelPreference], }, // OpenAI Settings OpenAiKey: { @@ -362,11 +363,20 @@ function validDockerizedUrl(input = "") { return null; } +// If the LLMProvider has changed we need to reset all workspace model preferences to +// null since the provider<>model name combination will be invalid for whatever the new +// provider is. +async function wipeWorkspaceModelPreference(key, prev, next) { + if (prev === next) return; + const { Workspace } = require("../../models/workspace"); + await Workspace.resetWorkspaceChatModels(); +} + // This will force update .env variables which for any which reason were not able to be parsed or // read from an ENV file as this seems to be a complicating step for many so allowing people to write // to the process will at least alleviate that issue. It does not perform comprehensive validity checks or sanity checks // and is simply for debugging when the .env not found issue many come across. -function updateENV(newENVs = {}, force = false) { +async function updateENV(newENVs = {}, force = false) { let error = ""; const validKeys = Object.keys(KEY_MAPPING); const ENV_KEYS = Object.keys(newENVs).filter( @@ -374,21 +384,25 @@ function updateENV(newENVs = {}, force = false) { ); const newValues = {}; - ENV_KEYS.forEach((key) => { - const { envKey, checks } = KEY_MAPPING[key]; - const value = newENVs[key]; + for (const key of ENV_KEYS) { + const { envKey, checks, postUpdate = [] } = KEY_MAPPING[key]; + const prevValue = process.env[envKey]; + const nextValue = newENVs[key]; const errors = checks - .map((validityCheck) => validityCheck(value, force)) + .map((validityCheck) => validityCheck(nextValue, force)) .filter((err) => typeof err === "string"); if (errors.length > 0) { error += errors.join("\n"); - return; + break; } - newValues[key] = value; - process.env[envKey] = value; - }); + newValues[key] = nextValue; + process.env[envKey] = nextValue; + + for (const postUpdateFunc of postUpdate) + await postUpdateFunc(key, prevValue, nextValue); + } return { newValues, error: error?.length > 0 ? error : false }; }