Mirror of https://github.com/Mintplex-Labs/anything-llm.git
Per workspace model selection (#582)

* WIP model selection per workspace (migrations and OpenAI model saves properly)
* Revert OpenAiOption
* Add support for per-workspace models for Anthropic, LocalAI, Ollama, OpenAI, and TogetherAI
* Remove unneeded comments
* Update logic for when LLMProvider is reset; reset AI provider files with master
* Remove frontend/API reset of workspace chat and move logic to updateENV; add postUpdate callbacks to envs
* Set preferred model for chat on class instantiation
* Remove extra param
* Linting
* Remove unused var
* Refactor chat model selection on workspace
* Linting
* Add fallback for base path to LocalAI models

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
Parent: bf503ee0e9
Commit: 90df37582b
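At a glance, the change adds a nullable chatModel column to each workspace and threads it through getLLMProvider() into every provider constructor, falling back to the existing env-based defaults when it is null. A minimal, self-contained sketch of that flow (stand-in provider map and class, not the real AiProviders code):

// Sketch only: a stand-in registry instead of the real ../AiProviders modules.
const PROVIDERS = {
  openai: class {
    constructor(modelPreference = null) {
      // Workspace preference wins, then the env default, then a hard-coded fallback.
      this.model =
        modelPreference || process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo";
    }
  },
};

function getLLMProvider(modelPreference = null) {
  const selection = process.env.LLM_PROVIDER || "openai";
  const Provider = PROVIDERS[selection];
  if (!Provider) throw new Error(`Unknown LLM_PROVIDER: ${selection}`);
  return new Provider(modelPreference);
}

// Chat handlers pass the workspace's saved model (if any) through;
// chatModel === null means "use the system default".
const workspace = { chatModel: "gpt-4" };
const connector = getLLMProvider(workspace?.chatModel);
console.log(connector.model); // "gpt-4"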
@@ -0,0 +1,120 @@
import useGetProviderModels, {
  DISABLED_PROVIDERS,
} from "./useGetProviderModels";

export default function ChatModelSelection({
  settings,
  workspace,
  setHasChanges,
}) {
  const { defaultModels, customModels, loading } = useGetProviderModels(
    settings?.LLMProvider
  );
  if (DISABLED_PROVIDERS.includes(settings?.LLMProvider)) return null;

  if (loading) {
    return (
      <div>
        <div className="flex flex-col">
          <label
            htmlFor="name"
            className="block text-sm font-medium text-white"
          >
            Chat model
          </label>
          <p className="text-white text-opacity-60 text-xs font-medium py-1.5">
            The specific chat model that will be used for this workspace. If
            empty, will use the system LLM preference.
          </p>
        </div>
        <select
          name="chatModel"
          required={true}
          disabled={true}
          className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
        >
          <option disabled={true} selected={true}>
            -- waiting for models --
          </option>
        </select>
      </div>
    );
  }

  return (
    <div>
      <div className="flex flex-col">
        <label htmlFor="name" className="block text-sm font-medium text-white">
          Chat model{" "}
          <span className="font-normal">({settings?.LLMProvider})</span>
        </label>
        <p className="text-white text-opacity-60 text-xs font-medium py-1.5">
          The specific chat model that will be used for this workspace. If
          empty, will use the system LLM preference.
        </p>
      </div>

      <select
        name="chatModel"
        required={true}
        onChange={() => {
          setHasChanges(true);
        }}
        className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
      >
        <option disabled={true} selected={workspace?.chatModel === null}>
          System default
        </option>
        {defaultModels.length > 0 && (
          <optgroup label="General models">
            {defaultModels.map((model) => {
              return (
                <option
                  key={model}
                  value={model}
                  selected={workspace?.chatModel === model}
                >
                  {model}
                </option>
              );
            })}
          </optgroup>
        )}
        {Array.isArray(customModels) && customModels.length > 0 && (
          <optgroup label="Custom models">
            {customModels.map((model) => {
              return (
                <option
                  key={model.id}
                  value={model.id}
                  selected={workspace?.chatModel === model.id}
                >
                  {model.id}
                </option>
              );
            })}
          </optgroup>
        )}
        {/* For providers like TogetherAi where we partition model by creator entity. */}
        {!Array.isArray(customModels) &&
          Object.keys(customModels).length > 0 && (
            <>
              {Object.entries(customModels).map(([organization, models]) => (
                <optgroup key={organization} label={organization}>
                  {models.map((model) => (
                    <option
                      key={model.id}
                      value={model.id}
                      selected={workspace?.chatModel === model.id}
                    >
                      {model.name}
                    </option>
                  ))}
                </optgroup>
              ))}
            </>
          )}
      </select>
    </div>
  );
}
@@ -0,0 +1,49 @@
import System from "@/models/system";
import { useEffect, useState } from "react";

// Providers which cannot use this feature for workspace<>model selection
export const DISABLED_PROVIDERS = ["azure", "lmstudio"];
const PROVIDER_DEFAULT_MODELS = {
  openai: ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-32k"],
  gemini: ["gemini-pro"],
  anthropic: ["claude-2", "claude-instant-1"],
  azure: [],
  lmstudio: [],
  localai: [],
  ollama: [],
  togetherai: [],
  native: [],
};

// For togetherAi, which has a large model list - we subgroup the options
// by their creator organization (eg: Meta, Mistral, etc)
// which makes selection easier to read.
function groupModels(models) {
  return models.reduce((acc, model) => {
    acc[model.organization] = acc[model.organization] || [];
    acc[model.organization].push(model);
    return acc;
  }, {});
}

export default function useGetProviderModels(provider = null) {
  const [defaultModels, setDefaultModels] = useState([]);
  const [customModels, setCustomModels] = useState([]);
  const [loading, setLoading] = useState(true);

  useEffect(() => {
    async function fetchProviderModels() {
      if (!provider) return;
      const { models = [] } = await System.customModels(provider);
      if (PROVIDER_DEFAULT_MODELS.hasOwnProperty(provider))
        setDefaultModels(PROVIDER_DEFAULT_MODELS[provider]);
      provider === "togetherai"
        ? setCustomModels(groupModels(models))
        : setCustomModels(models);
      setLoading(false);
    }
    fetchProviderModels();
  }, [provider]);

  return { defaultModels, customModels, loading };
}
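For TogetherAI the hook hands customModels back as an object keyed by organization rather than a flat array, which is what the third <optgroup> branch in the component above iterates over. A quick illustration of that shape, using made-up entries (the real list comes from System.customModels("togetherai")):

// Illustrative input; real entries come from the backend model list.
const models = [
  { id: "meta/llama-2-70b-chat", name: "Llama 2 Chat 70B", organization: "Meta" },
  { id: "mistralai/mistral-7b-instruct", name: "Mistral 7B Instruct", organization: "Mistral AI" },
  { id: "meta/llama-2-13b-chat", name: "Llama 2 Chat 13B", organization: "Meta" },
];

// Same grouping as groupModels() above: bucket by `organization`.
const grouped = models.reduce((acc, model) => {
  acc[model.organization] = acc[model.organization] || [];
  acc[model.organization].push(model);
  return acc;
}, {});

console.log(Object.keys(grouped)); // ["Meta", "Mistral AI"]
console.log(grouped["Meta"].map((m) => m.id)); // ["meta/llama-2-70b-chat", "meta/llama-2-13b-chat"]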
@@ -6,6 +6,7 @@ import System from "../../../../models/system";
import PreLoader from "../../../Preloader";
import { useParams } from "react-router-dom";
import showToast from "../../../../utils/toast";
+import ChatModelPreference from "./ChatModelPreference";

// Ensure that a type is correct before sending the body
// to the backend.
@@ -26,7 +27,7 @@ function castToType(key, value) {
  return definitions[key].cast(value);
}

-export default function WorkspaceSettings({ active, workspace }) {
+export default function WorkspaceSettings({ active, workspace, settings }) {
  const { slug } = useParams();
  const formEl = useRef(null);
  const [saving, setSaving] = useState(false);
@@ -99,6 +100,11 @@ export default function WorkspaceSettings({ active, workspace }) {
        <div className="flex">
          <div className="flex flex-col gap-y-4 w-1/2">
            <div className="w-3/4 flex flex-col gap-y-4">
+              <ChatModelPreference
+                settings={settings}
+                workspace={workspace}
+                setHasChanges={setHasChanges}
+              />
              <div>
                <div className="flex flex-col">
                  <label
@@ -117,6 +117,7 @@ const ManageWorkspace = ({ hideModal = noop, providedSlug = null }) => {
          <WorkspaceSettings
            active={selectedTab === "settings"} // To force reload live sub-components like VectorCount
            workspace={workspace}
+            settings={settings}
          />
        </div>
      </Suspense>
@@ -30,19 +30,17 @@ export default function GeneralLLMPreference() {
  const [hasChanges, setHasChanges] = useState(false);
  const [settings, setSettings] = useState(null);
  const [loading, setLoading] = useState(true);

  const [searchQuery, setSearchQuery] = useState("");
  const [filteredLLMs, setFilteredLLMs] = useState([]);
  const [selectedLLM, setSelectedLLM] = useState(null);

  const isHosted = window.location.hostname.includes("useanything.com");

  const handleSubmit = async (e) => {
    e.preventDefault();
    const form = e.target;
-    const data = {};
+    const data = { LLMProvider: selectedLLM };
    const formData = new FormData(form);
-    data.LLMProvider = selectedLLM;

    for (var [key, value] of formData.entries()) data[key] = value;
    const { error } = await System.updateSystem(data);
    setSaving(true);
@@ -139,7 +139,7 @@ function apiSystemEndpoints(app) {
      */
      try {
        const body = reqBody(request);
-        const { newValues, error } = updateENV(body);
+        const { newValues, error } = await updateENV(body);
        if (process.env.NODE_ENV === "production") await dumpENV();
        response.status(200).json({ newValues, error });
      } catch (e) {
@@ -290,7 +290,7 @@ function systemEndpoints(app) {
      }

      const body = reqBody(request);
-      const { newValues, error } = updateENV(body);
+      const { newValues, error } = await updateENV(body);
      if (process.env.NODE_ENV === "production") await dumpENV();
      response.status(200).json({ newValues, error });
    } catch (e) {
@@ -312,7 +312,7 @@ function systemEndpoints(app) {
      }

      const { usePassword, newPassword } = reqBody(request);
-      const { error } = updateENV(
+      const { error } = await updateENV(
        {
          AuthToken: usePassword ? newPassword : "",
          JWTSecret: usePassword ? v4() : "",
@@ -355,7 +355,7 @@ function systemEndpoints(app) {
        message_limit: 25,
      });

-      updateENV(
+      await updateENV(
        {
          AuthToken: "",
          JWTSecret: process.env.JWT_SECRET || v4(),
@@ -14,6 +14,7 @@ const Workspace = {
    "lastUpdatedAt",
    "openAiPrompt",
    "similarityThreshold",
+    "chatModel",
  ],

  new: async function (name = null, creatorId = null) {
@@ -191,6 +192,20 @@ const Workspace = {
      return { success: false, error: error.message };
    }
  },

+  resetWorkspaceChatModels: async () => {
+    try {
+      await prisma.workspaces.updateMany({
+        data: {
+          chatModel: null,
+        },
+      });
+      return { success: true, error: null };
+    } catch (error) {
+      console.error("Error resetting workspace chat models:", error.message);
+      return { success: false, error: error.message };
+    }
+  },
};

module.exports = { Workspace };
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "workspaces" ADD COLUMN "chatModel" TEXT;
@@ -93,6 +93,7 @@ model workspaces {
  lastUpdatedAt       DateTime              @default(now())
  openAiPrompt        String?
  similarityThreshold Float?                @default(0.25)
+ chatModel           String?
  workspace_users     workspace_users[]
  documents           workspace_documents[]
}
@@ -2,7 +2,7 @@ const { v4 } = require("uuid");
const { chatPrompt } = require("../../chats");

class AnthropicLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
    if (!process.env.ANTHROPIC_API_KEY)
      throw new Error("No Anthropic API key was set.");

@@ -12,7 +12,8 @@ class AnthropicLLM {
      apiKey: process.env.ANTHROPIC_API_KEY,
    });
    this.anthropic = anthropic;
-    this.model = process.env.ANTHROPIC_MODEL_PREF || "claude-2";
+    this.model =
+      modelPreference || process.env.ANTHROPIC_MODEL_PREF || "claude-2";
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
@@ -2,7 +2,7 @@ const { AzureOpenAiEmbedder } = require("../../EmbeddingEngines/azureOpenAi");
const { chatPrompt } = require("../../chats");

class AzureOpenAiLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, _modelPreference = null) {
    const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
    if (!process.env.AZURE_OPENAI_ENDPOINT)
      throw new Error("No Azure API endpoint was set.");
@@ -1,14 +1,15 @@
const { chatPrompt } = require("../../chats");

class GeminiLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
    if (!process.env.GEMINI_API_KEY)
      throw new Error("No Gemini API key was set.");

    // Docs: https://ai.google.dev/tutorials/node_quickstart
    const { GoogleGenerativeAI } = require("@google/generative-ai");
    const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY);
-    this.model = process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
+    this.model =
+      modelPreference || process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
    this.gemini = genAI.getGenerativeModel({ model: this.model });
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
@@ -2,7 +2,7 @@ const { chatPrompt } = require("../../chats");

// hybrid of openAi LLM chat completion for LMStudio
class LMStudioLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, _modelPreference = null) {
    if (!process.env.LMSTUDIO_BASE_PATH)
      throw new Error("No LMStudio API Base Path was set.");

@@ -12,7 +12,7 @@ class LMStudioLLM {
    });
    this.lmstudio = new OpenAIApi(config);
    // When using LMStudios inference server - the model param is not required so
-    // we can stub it here.
+    // we can stub it here. LMStudio can only run one model at a time.
    this.model = "model-placeholder";
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
@@ -1,7 +1,7 @@
const { chatPrompt } = require("../../chats");

class LocalAiLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
    if (!process.env.LOCAL_AI_BASE_PATH)
      throw new Error("No LocalAI Base Path was set.");

@@ -15,7 +15,7 @@ class LocalAiLLM {
        : {}),
    });
    this.openai = new OpenAIApi(config);
-    this.model = process.env.LOCAL_AI_MODEL_PREF;
+    this.model = modelPreference || process.env.LOCAL_AI_MODEL_PREF;
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
@@ -10,11 +10,11 @@ const ChatLlamaCpp = (...args) =>
  );

class NativeLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
    if (!process.env.NATIVE_LLM_MODEL_PREF)
      throw new Error("No local Llama model was set.");

-    this.model = process.env.NATIVE_LLM_MODEL_PREF || null;
+    this.model = modelPreference || process.env.NATIVE_LLM_MODEL_PREF || null;
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
@@ -3,12 +3,12 @@ const { StringOutputParser } = require("langchain/schema/output_parser");

// Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md
class OllamaAILLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
    if (!process.env.OLLAMA_BASE_PATH)
      throw new Error("No Ollama Base Path was set.");

    this.basePath = process.env.OLLAMA_BASE_PATH;
-    this.model = process.env.OLLAMA_MODEL_PREF;
+    this.model = modelPreference || process.env.OLLAMA_MODEL_PREF;
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
@@ -2,7 +2,7 @@ const { OpenAiEmbedder } = require("../../EmbeddingEngines/openAi");
const { chatPrompt } = require("../../chats");

class OpenAiLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
    const { Configuration, OpenAIApi } = require("openai");
    if (!process.env.OPEN_AI_KEY) throw new Error("No OpenAI API key was set.");

@@ -10,7 +10,8 @@ class OpenAiLLM {
      apiKey: process.env.OPEN_AI_KEY,
    });
    this.openai = new OpenAIApi(config);
-    this.model = process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo";
+    this.model =
+      modelPreference || process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo";
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
@@ -6,7 +6,7 @@ function togetherAiModels() {
}

class TogetherAiLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
    const { Configuration, OpenAIApi } = require("openai");
    if (!process.env.TOGETHER_AI_API_KEY)
      throw new Error("No TogetherAI API key was set.");
@@ -16,7 +16,7 @@ class TogetherAiLLM {
      apiKey: process.env.TOGETHER_AI_API_KEY,
    });
    this.openai = new OpenAIApi(config);
-    this.model = process.env.TOGETHER_AI_MODEL_PREF;
+    this.model = modelPreference || process.env.TOGETHER_AI_MODEL_PREF;
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
@@ -71,7 +71,7 @@ async function chatWithWorkspace(
    return await VALID_COMMANDS[command](workspace, message, uuid, user);
  }

-  const LLMConnector = getLLMProvider();
+  const LLMConnector = getLLMProvider(workspace?.chatModel);
  const VectorDb = getVectorDbClass();
  const { safe, reasons = [] } = await LLMConnector.isSafe(message);
  if (!safe) {
@@ -30,7 +30,7 @@ async function streamChatWithWorkspace(
    return;
  }

-  const LLMConnector = getLLMProvider();
+  const LLMConnector = getLLMProvider(workspace?.chatModel);
  const VectorDb = getVectorDbClass();
  const { safe, reasons = [] } = await LLMConnector.isSafe(message);
  if (!safe) {
@@ -17,7 +17,7 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
    case "localai":
      return await localAIModels(basePath, apiKey);
    case "ollama":
-      return await ollamaAIModels(basePath, apiKey);
+      return await ollamaAIModels(basePath);
    case "togetherai":
      return await getTogetherAiModels();
    case "native-llm":
@@ -53,7 +53,7 @@ async function openAiModels(apiKey = null) {
async function localAIModels(basePath = null, apiKey = null) {
  const { Configuration, OpenAIApi } = require("openai");
  const config = new Configuration({
-    basePath,
+    basePath: basePath || process.env.LOCAL_AI_BASE_PATH,
    apiKey: apiKey || process.env.LOCAL_AI_API_KEY,
  });
  const openai = new OpenAIApi(config);
@@ -70,13 +70,14 @@ async function localAIModels(basePath = null, apiKey = null) {
  return { models, error: null };
}

-async function ollamaAIModels(basePath = null, _apiKey = null) {
+async function ollamaAIModels(basePath = null) {
  let url;
  try {
-    new URL(basePath);
-    if (basePath.split("").slice(-1)?.[0] === "/")
+    let urlPath = basePath ?? process.env.OLLAMA_BASE_PATH;
+    new URL(urlPath);
+    if (urlPath.split("").slice(-1)?.[0] === "/")
      throw new Error("BasePath Cannot end in /!");
-    url = basePath;
+    url = urlPath;
  } catch {
    return { models: [], error: "Not a valid URL." };
  }
@@ -24,37 +24,37 @@ function getVectorDbClass() {
  }
}

-function getLLMProvider() {
+function getLLMProvider(modelPreference = null) {
  const vectorSelection = process.env.LLM_PROVIDER || "openai";
  const embedder = getEmbeddingEngineSelection();
  switch (vectorSelection) {
    case "openai":
      const { OpenAiLLM } = require("../AiProviders/openAi");
-      return new OpenAiLLM(embedder);
+      return new OpenAiLLM(embedder, modelPreference);
    case "azure":
      const { AzureOpenAiLLM } = require("../AiProviders/azureOpenAi");
-      return new AzureOpenAiLLM(embedder);
+      return new AzureOpenAiLLM(embedder, modelPreference);
    case "anthropic":
      const { AnthropicLLM } = require("../AiProviders/anthropic");
-      return new AnthropicLLM(embedder);
+      return new AnthropicLLM(embedder, modelPreference);
    case "gemini":
      const { GeminiLLM } = require("../AiProviders/gemini");
-      return new GeminiLLM(embedder);
+      return new GeminiLLM(embedder, modelPreference);
    case "lmstudio":
      const { LMStudioLLM } = require("../AiProviders/lmStudio");
-      return new LMStudioLLM(embedder);
+      return new LMStudioLLM(embedder, modelPreference);
    case "localai":
      const { LocalAiLLM } = require("../AiProviders/localAi");
-      return new LocalAiLLM(embedder);
+      return new LocalAiLLM(embedder, modelPreference);
    case "ollama":
      const { OllamaAILLM } = require("../AiProviders/ollama");
-      return new OllamaAILLM(embedder);
+      return new OllamaAILLM(embedder, modelPreference);
    case "togetherai":
      const { TogetherAiLLM } = require("../AiProviders/togetherAi");
-      return new TogetherAiLLM(embedder);
+      return new TogetherAiLLM(embedder, modelPreference);
    case "native":
      const { NativeLLM } = require("../AiProviders/native");
-      return new NativeLLM(embedder);
+      return new NativeLLM(embedder, modelPreference);
    default:
      throw new Error("ENV: No LLM_PROVIDER value found in environment!");
  }
@@ -2,6 +2,7 @@ const KEY_MAPPING = {
  LLMProvider: {
    envKey: "LLM_PROVIDER",
    checks: [isNotEmpty, supportedLLM],
+    postUpdate: [wipeWorkspaceModelPreference],
  },
  // OpenAI Settings
  OpenAiKey: {
@@ -362,11 +363,20 @@ function validDockerizedUrl(input = "") {
  return null;
}

+// If the LLMProvider has changed we need to reset all workspace model preferences to
+// null since the provider<>model name combination will be invalid for whatever the new
+// provider is.
+async function wipeWorkspaceModelPreference(key, prev, next) {
+  if (prev === next) return;
+  const { Workspace } = require("../../models/workspace");
+  await Workspace.resetWorkspaceChatModels();
+}
+
// This will force update .env variables which for any which reason were not able to be parsed or
// read from an ENV file as this seems to be a complicating step for many so allowing people to write
// to the process will at least alleviate that issue. It does not perform comprehensive validity checks or sanity checks
// and is simply for debugging when the .env not found issue many come across.
-function updateENV(newENVs = {}, force = false) {
+async function updateENV(newENVs = {}, force = false) {
  let error = "";
  const validKeys = Object.keys(KEY_MAPPING);
  const ENV_KEYS = Object.keys(newENVs).filter(
@@ -374,21 +384,25 @@ function updateENV(newENVs = {}, force = false) {
  );
  const newValues = {};

-  ENV_KEYS.forEach((key) => {
-    const { envKey, checks } = KEY_MAPPING[key];
-    const value = newENVs[key];
+  for (const key of ENV_KEYS) {
+    const { envKey, checks, postUpdate = [] } = KEY_MAPPING[key];
+    const prevValue = process.env[envKey];
+    const nextValue = newENVs[key];
    const errors = checks
-      .map((validityCheck) => validityCheck(value, force))
+      .map((validityCheck) => validityCheck(nextValue, force))
      .filter((err) => typeof err === "string");

    if (errors.length > 0) {
      error += errors.join("\n");
-      return;
+      break;
    }

-    newValues[key] = value;
-    process.env[envKey] = value;
-  });
+    newValues[key] = nextValue;
+    process.env[envKey] = nextValue;
+
+    for (const postUpdateFunc of postUpdate)
+      await postUpdateFunc(key, prevValue, nextValue);
+  }

  return { newValues, error: error?.length > 0 ? error : false };
}
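Taken together, a caller that switches providers now looks roughly like this (a minimal sketch; the require path is illustrative and the real callers are the system endpoints shown earlier):

const { updateENV } = require("./updateENV"); // path is an assumption for the sketch

(async () => {
  // Changing LLM_PROVIDER runs the key's postUpdate callbacks, so
  // wipeWorkspaceModelPreference() nulls out workspaces.chatModel for every
  // workspace and no stale model name from the old provider can linger.
  const { newValues, error } = await updateENV({ LLMProvider: "ollama" });
  if (error) console.error(error);
  else console.log(newValues); // { LLMProvider: "ollama" }
})();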