mirror of
https://github.com/Mintplex-Labs/anything-llm.git
synced 2024-10-04 01:40:12 +02:00
Per workspace model selection (#582)
* WIP model selection per workspace (migrations and openai saves properly * revert OpenAiOption * add support for models per workspace for anthropic, localAi, ollama, openAi, and togetherAi * remove unneeded comments * update logic for when LLMProvider is reset, reset Ai provider files with master * remove frontend/api reset of workspace chat and move logic to updateENV add postUpdate callbacks to envs * set preferred model for chat on class instantiation * remove extra param * linting * remove unused var * refactor chat model selection on workspace * linting * add fallback for base path to localai models --------- Co-authored-by: timothycarambat <rambat1010@gmail.com>
This commit is contained in:
parent
bf503ee0e9
commit
90df37582b
@ -0,0 +1,120 @@
|
|||||||
|
import useGetProviderModels, {
|
||||||
|
DISABLED_PROVIDERS,
|
||||||
|
} from "./useGetProviderModels";
|
||||||
|
|
||||||
|
export default function ChatModelSelection({
|
||||||
|
settings,
|
||||||
|
workspace,
|
||||||
|
setHasChanges,
|
||||||
|
}) {
|
||||||
|
const { defaultModels, customModels, loading } = useGetProviderModels(
|
||||||
|
settings?.LLMProvider
|
||||||
|
);
|
||||||
|
if (DISABLED_PROVIDERS.includes(settings?.LLMProvider)) return null;
|
||||||
|
|
||||||
|
if (loading) {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<div className="flex flex-col">
|
||||||
|
<label
|
||||||
|
htmlFor="name"
|
||||||
|
className="block text-sm font-medium text-white"
|
||||||
|
>
|
||||||
|
Chat model
|
||||||
|
</label>
|
||||||
|
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
|
||||||
|
The specific chat model that will be used for this workspace. If
|
||||||
|
empty, will use the system LLM preference.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<select
|
||||||
|
name="chatModel"
|
||||||
|
required={true}
|
||||||
|
disabled={true}
|
||||||
|
className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
|
||||||
|
>
|
||||||
|
<option disabled={true} selected={true}>
|
||||||
|
-- waiting for models --
|
||||||
|
</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<div className="flex flex-col">
|
||||||
|
<label htmlFor="name" className="block text-sm font-medium text-white">
|
||||||
|
Chat model{" "}
|
||||||
|
<span className="font-normal">({settings?.LLMProvider})</span>
|
||||||
|
</label>
|
||||||
|
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
|
||||||
|
The specific chat model that will be used for this workspace. If
|
||||||
|
empty, will use the system LLM preference.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<select
|
||||||
|
name="chatModel"
|
||||||
|
required={true}
|
||||||
|
onChange={() => {
|
||||||
|
setHasChanges(true);
|
||||||
|
}}
|
||||||
|
className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
|
||||||
|
>
|
||||||
|
<option disabled={true} selected={workspace?.chatModel === null}>
|
||||||
|
System default
|
||||||
|
</option>
|
||||||
|
{defaultModels.length > 0 && (
|
||||||
|
<optgroup label="General models">
|
||||||
|
{defaultModels.map((model) => {
|
||||||
|
return (
|
||||||
|
<option
|
||||||
|
key={model}
|
||||||
|
value={model}
|
||||||
|
selected={workspace?.chatModel === model}
|
||||||
|
>
|
||||||
|
{model}
|
||||||
|
</option>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</optgroup>
|
||||||
|
)}
|
||||||
|
{Array.isArray(customModels) && customModels.length > 0 && (
|
||||||
|
<optgroup label="Custom models">
|
||||||
|
{customModels.map((model) => {
|
||||||
|
return (
|
||||||
|
<option
|
||||||
|
key={model.id}
|
||||||
|
value={model.id}
|
||||||
|
selected={workspace?.chatModel === model.id}
|
||||||
|
>
|
||||||
|
{model.id}
|
||||||
|
</option>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</optgroup>
|
||||||
|
)}
|
||||||
|
{/* For providers like TogetherAi where we partition model by creator entity. */}
|
||||||
|
{!Array.isArray(customModels) &&
|
||||||
|
Object.keys(customModels).length > 0 && (
|
||||||
|
<>
|
||||||
|
{Object.entries(customModels).map(([organization, models]) => (
|
||||||
|
<optgroup key={organization} label={organization}>
|
||||||
|
{models.map((model) => (
|
||||||
|
<option
|
||||||
|
key={model.id}
|
||||||
|
value={model.id}
|
||||||
|
selected={workspace?.chatModel === model.id}
|
||||||
|
>
|
||||||
|
{model.name}
|
||||||
|
</option>
|
||||||
|
))}
|
||||||
|
</optgroup>
|
||||||
|
))}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,49 @@
|
|||||||
|
import System from "@/models/system";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
|
||||||
|
// Providers which cannot use this feature for workspace<>model selection.
export const DISABLED_PROVIDERS = ["azure", "lmstudio"];

// Static, well-known model lists keyed by provider slug. Providers with an
// empty list rely entirely on custom/discovered models fetched at runtime.
const PROVIDER_DEFAULT_MODELS = {
  openai: ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-32k"],
  gemini: ["gemini-pro"],
  anthropic: ["claude-2", "claude-instant-1"],
  azure: [],
  lmstudio: [],
  localai: [],
  ollama: [],
  togetherai: [],
  native: [],
};
|
||||||
|
|
||||||
|
// For togetherAi, which has a large model list - we subgroup the options
// by their creator organization (eg: Meta, Mistral, etc)
// which makes selection easier to read.
function groupModels(models) {
  const grouped = {};
  for (const model of models) {
    if (!grouped[model.organization]) grouped[model.organization] = [];
    grouped[model.organization].push(model);
  }
  return grouped;
}
|
||||||
|
|
||||||
|
// Hook that resolves the selectable chat models for a given LLM provider.
//
// Returns:
//   defaultModels - static well-known models for the provider (may be []).
//   customModels  - models discovered at runtime via System.customModels;
//                   for "togetherai" this is an object grouped by creator
//                   organization, otherwise a flat array.
//   loading       - true until the first fetch for `provider` completes.
export default function useGetProviderModels(provider = null) {
  const [defaultModels, setDefaultModels] = useState([]);
  const [customModels, setCustomModels] = useState([]);
  const [loading, setLoading] = useState(true);

  useEffect(() => {
    // Guard against a stale response: if `provider` changes while a fetch is
    // in flight, the cleanup flips `cancelled` so the outdated result cannot
    // clobber state for the newly selected provider.
    let cancelled = false;
    async function fetchProviderModels() {
      if (!provider) return;
      const { models = [] } = await System.customModels(provider);
      if (cancelled) return;
      // Fall back to [] for unknown providers so a previous provider's
      // default models never linger after switching.
      setDefaultModels(
        PROVIDER_DEFAULT_MODELS.hasOwnProperty(provider)
          ? PROVIDER_DEFAULT_MODELS[provider]
          : []
      );
      provider === "togetherai"
        ? setCustomModels(groupModels(models))
        : setCustomModels(models);
      setLoading(false);
    }
    fetchProviderModels();
    return () => {
      cancelled = true;
    };
  }, [provider]);

  return { defaultModels, customModels, loading };
}
|
@ -6,6 +6,7 @@ import System from "../../../../models/system";
|
|||||||
import PreLoader from "../../../Preloader";
|
import PreLoader from "../../../Preloader";
|
||||||
import { useParams } from "react-router-dom";
|
import { useParams } from "react-router-dom";
|
||||||
import showToast from "../../../../utils/toast";
|
import showToast from "../../../../utils/toast";
|
||||||
|
import ChatModelPreference from "./ChatModelPreference";
|
||||||
|
|
||||||
// Ensure that a type is correct before sending the body
|
// Ensure that a type is correct before sending the body
|
||||||
// to the backend.
|
// to the backend.
|
||||||
@ -26,7 +27,7 @@ function castToType(key, value) {
|
|||||||
return definitions[key].cast(value);
|
return definitions[key].cast(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function WorkspaceSettings({ active, workspace }) {
|
export default function WorkspaceSettings({ active, workspace, settings }) {
|
||||||
const { slug } = useParams();
|
const { slug } = useParams();
|
||||||
const formEl = useRef(null);
|
const formEl = useRef(null);
|
||||||
const [saving, setSaving] = useState(false);
|
const [saving, setSaving] = useState(false);
|
||||||
@ -99,6 +100,11 @@ export default function WorkspaceSettings({ active, workspace }) {
|
|||||||
<div className="flex">
|
<div className="flex">
|
||||||
<div className="flex flex-col gap-y-4 w-1/2">
|
<div className="flex flex-col gap-y-4 w-1/2">
|
||||||
<div className="w-3/4 flex flex-col gap-y-4">
|
<div className="w-3/4 flex flex-col gap-y-4">
|
||||||
|
<ChatModelPreference
|
||||||
|
settings={settings}
|
||||||
|
workspace={workspace}
|
||||||
|
setHasChanges={setHasChanges}
|
||||||
|
/>
|
||||||
<div>
|
<div>
|
||||||
<div className="flex flex-col">
|
<div className="flex flex-col">
|
||||||
<label
|
<label
|
||||||
|
@ -117,6 +117,7 @@ const ManageWorkspace = ({ hideModal = noop, providedSlug = null }) => {
|
|||||||
<WorkspaceSettings
|
<WorkspaceSettings
|
||||||
active={selectedTab === "settings"} // To force reload live sub-components like VectorCount
|
active={selectedTab === "settings"} // To force reload live sub-components like VectorCount
|
||||||
workspace={workspace}
|
workspace={workspace}
|
||||||
|
settings={settings}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</Suspense>
|
</Suspense>
|
||||||
|
@ -30,19 +30,17 @@ export default function GeneralLLMPreference() {
|
|||||||
const [hasChanges, setHasChanges] = useState(false);
|
const [hasChanges, setHasChanges] = useState(false);
|
||||||
const [settings, setSettings] = useState(null);
|
const [settings, setSettings] = useState(null);
|
||||||
const [loading, setLoading] = useState(true);
|
const [loading, setLoading] = useState(true);
|
||||||
|
|
||||||
const [searchQuery, setSearchQuery] = useState("");
|
const [searchQuery, setSearchQuery] = useState("");
|
||||||
const [filteredLLMs, setFilteredLLMs] = useState([]);
|
const [filteredLLMs, setFilteredLLMs] = useState([]);
|
||||||
const [selectedLLM, setSelectedLLM] = useState(null);
|
const [selectedLLM, setSelectedLLM] = useState(null);
|
||||||
|
|
||||||
const isHosted = window.location.hostname.includes("useanything.com");
|
const isHosted = window.location.hostname.includes("useanything.com");
|
||||||
|
|
||||||
const handleSubmit = async (e) => {
|
const handleSubmit = async (e) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
const form = e.target;
|
const form = e.target;
|
||||||
const data = {};
|
const data = { LLMProvider: selectedLLM };
|
||||||
const formData = new FormData(form);
|
const formData = new FormData(form);
|
||||||
data.LLMProvider = selectedLLM;
|
|
||||||
for (var [key, value] of formData.entries()) data[key] = value;
|
for (var [key, value] of formData.entries()) data[key] = value;
|
||||||
const { error } = await System.updateSystem(data);
|
const { error } = await System.updateSystem(data);
|
||||||
setSaving(true);
|
setSaving(true);
|
||||||
|
@ -139,7 +139,7 @@ function apiSystemEndpoints(app) {
|
|||||||
*/
|
*/
|
||||||
try {
|
try {
|
||||||
const body = reqBody(request);
|
const body = reqBody(request);
|
||||||
const { newValues, error } = updateENV(body);
|
const { newValues, error } = await updateENV(body);
|
||||||
if (process.env.NODE_ENV === "production") await dumpENV();
|
if (process.env.NODE_ENV === "production") await dumpENV();
|
||||||
response.status(200).json({ newValues, error });
|
response.status(200).json({ newValues, error });
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
|
@ -290,7 +290,7 @@ function systemEndpoints(app) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const body = reqBody(request);
|
const body = reqBody(request);
|
||||||
const { newValues, error } = updateENV(body);
|
const { newValues, error } = await updateENV(body);
|
||||||
if (process.env.NODE_ENV === "production") await dumpENV();
|
if (process.env.NODE_ENV === "production") await dumpENV();
|
||||||
response.status(200).json({ newValues, error });
|
response.status(200).json({ newValues, error });
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
@ -312,7 +312,7 @@ function systemEndpoints(app) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const { usePassword, newPassword } = reqBody(request);
|
const { usePassword, newPassword } = reqBody(request);
|
||||||
const { error } = updateENV(
|
const { error } = await updateENV(
|
||||||
{
|
{
|
||||||
AuthToken: usePassword ? newPassword : "",
|
AuthToken: usePassword ? newPassword : "",
|
||||||
JWTSecret: usePassword ? v4() : "",
|
JWTSecret: usePassword ? v4() : "",
|
||||||
@ -355,7 +355,7 @@ function systemEndpoints(app) {
|
|||||||
message_limit: 25,
|
message_limit: 25,
|
||||||
});
|
});
|
||||||
|
|
||||||
updateENV(
|
await updateENV(
|
||||||
{
|
{
|
||||||
AuthToken: "",
|
AuthToken: "",
|
||||||
JWTSecret: process.env.JWT_SECRET || v4(),
|
JWTSecret: process.env.JWT_SECRET || v4(),
|
||||||
|
@ -14,6 +14,7 @@ const Workspace = {
|
|||||||
"lastUpdatedAt",
|
"lastUpdatedAt",
|
||||||
"openAiPrompt",
|
"openAiPrompt",
|
||||||
"similarityThreshold",
|
"similarityThreshold",
|
||||||
|
"chatModel",
|
||||||
],
|
],
|
||||||
|
|
||||||
new: async function (name = null, creatorId = null) {
|
new: async function (name = null, creatorId = null) {
|
||||||
@ -191,6 +192,20 @@ const Workspace = {
|
|||||||
return { success: false, error: error.message };
|
return { success: false, error: error.message };
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
resetWorkspaceChatModels: async () => {
|
||||||
|
try {
|
||||||
|
await prisma.workspaces.updateMany({
|
||||||
|
data: {
|
||||||
|
chatModel: null,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
return { success: true, error: null };
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error resetting workspace chat models:", error.message);
|
||||||
|
return { success: false, error: error.message };
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
module.exports = { Workspace };
|
module.exports = { Workspace };
|
||||||
|
@ -0,0 +1,2 @@
|
|||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "workspaces" ADD COLUMN "chatModel" TEXT;
|
@ -93,6 +93,7 @@ model workspaces {
|
|||||||
lastUpdatedAt DateTime @default(now())
|
lastUpdatedAt DateTime @default(now())
|
||||||
openAiPrompt String?
|
openAiPrompt String?
|
||||||
similarityThreshold Float? @default(0.25)
|
similarityThreshold Float? @default(0.25)
|
||||||
|
chatModel String?
|
||||||
workspace_users workspace_users[]
|
workspace_users workspace_users[]
|
||||||
documents workspace_documents[]
|
documents workspace_documents[]
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ const { v4 } = require("uuid");
|
|||||||
const { chatPrompt } = require("../../chats");
|
const { chatPrompt } = require("../../chats");
|
||||||
|
|
||||||
class AnthropicLLM {
|
class AnthropicLLM {
|
||||||
constructor(embedder = null) {
|
constructor(embedder = null, modelPreference = null) {
|
||||||
if (!process.env.ANTHROPIC_API_KEY)
|
if (!process.env.ANTHROPIC_API_KEY)
|
||||||
throw new Error("No Anthropic API key was set.");
|
throw new Error("No Anthropic API key was set.");
|
||||||
|
|
||||||
@ -12,7 +12,8 @@ class AnthropicLLM {
|
|||||||
apiKey: process.env.ANTHROPIC_API_KEY,
|
apiKey: process.env.ANTHROPIC_API_KEY,
|
||||||
});
|
});
|
||||||
this.anthropic = anthropic;
|
this.anthropic = anthropic;
|
||||||
this.model = process.env.ANTHROPIC_MODEL_PREF || "claude-2";
|
this.model =
|
||||||
|
modelPreference || process.env.ANTHROPIC_MODEL_PREF || "claude-2";
|
||||||
this.limits = {
|
this.limits = {
|
||||||
history: this.promptWindowLimit() * 0.15,
|
history: this.promptWindowLimit() * 0.15,
|
||||||
system: this.promptWindowLimit() * 0.15,
|
system: this.promptWindowLimit() * 0.15,
|
||||||
|
@ -2,7 +2,7 @@ const { AzureOpenAiEmbedder } = require("../../EmbeddingEngines/azureOpenAi");
|
|||||||
const { chatPrompt } = require("../../chats");
|
const { chatPrompt } = require("../../chats");
|
||||||
|
|
||||||
class AzureOpenAiLLM {
|
class AzureOpenAiLLM {
|
||||||
constructor(embedder = null) {
|
constructor(embedder = null, _modelPreference = null) {
|
||||||
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
|
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
|
||||||
if (!process.env.AZURE_OPENAI_ENDPOINT)
|
if (!process.env.AZURE_OPENAI_ENDPOINT)
|
||||||
throw new Error("No Azure API endpoint was set.");
|
throw new Error("No Azure API endpoint was set.");
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
const { chatPrompt } = require("../../chats");
|
const { chatPrompt } = require("../../chats");
|
||||||
|
|
||||||
class GeminiLLM {
|
class GeminiLLM {
|
||||||
constructor(embedder = null) {
|
constructor(embedder = null, modelPreference = null) {
|
||||||
if (!process.env.GEMINI_API_KEY)
|
if (!process.env.GEMINI_API_KEY)
|
||||||
throw new Error("No Gemini API key was set.");
|
throw new Error("No Gemini API key was set.");
|
||||||
|
|
||||||
// Docs: https://ai.google.dev/tutorials/node_quickstart
|
// Docs: https://ai.google.dev/tutorials/node_quickstart
|
||||||
const { GoogleGenerativeAI } = require("@google/generative-ai");
|
const { GoogleGenerativeAI } = require("@google/generative-ai");
|
||||||
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY);
|
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY);
|
||||||
this.model = process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
|
this.model =
|
||||||
|
modelPreference || process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
|
||||||
this.gemini = genAI.getGenerativeModel({ model: this.model });
|
this.gemini = genAI.getGenerativeModel({ model: this.model });
|
||||||
this.limits = {
|
this.limits = {
|
||||||
history: this.promptWindowLimit() * 0.15,
|
history: this.promptWindowLimit() * 0.15,
|
||||||
|
@ -2,7 +2,7 @@ const { chatPrompt } = require("../../chats");
|
|||||||
|
|
||||||
// hybrid of openAi LLM chat completion for LMStudio
|
// hybrid of openAi LLM chat completion for LMStudio
|
||||||
class LMStudioLLM {
|
class LMStudioLLM {
|
||||||
constructor(embedder = null) {
|
constructor(embedder = null, _modelPreference = null) {
|
||||||
if (!process.env.LMSTUDIO_BASE_PATH)
|
if (!process.env.LMSTUDIO_BASE_PATH)
|
||||||
throw new Error("No LMStudio API Base Path was set.");
|
throw new Error("No LMStudio API Base Path was set.");
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ class LMStudioLLM {
|
|||||||
});
|
});
|
||||||
this.lmstudio = new OpenAIApi(config);
|
this.lmstudio = new OpenAIApi(config);
|
||||||
// When using LMStudios inference server - the model param is not required so
|
// When using LMStudios inference server - the model param is not required so
|
||||||
// we can stub it here.
|
// we can stub it here. LMStudio can only run one model at a time.
|
||||||
this.model = "model-placeholder";
|
this.model = "model-placeholder";
|
||||||
this.limits = {
|
this.limits = {
|
||||||
history: this.promptWindowLimit() * 0.15,
|
history: this.promptWindowLimit() * 0.15,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
const { chatPrompt } = require("../../chats");
|
const { chatPrompt } = require("../../chats");
|
||||||
|
|
||||||
class LocalAiLLM {
|
class LocalAiLLM {
|
||||||
constructor(embedder = null) {
|
constructor(embedder = null, modelPreference = null) {
|
||||||
if (!process.env.LOCAL_AI_BASE_PATH)
|
if (!process.env.LOCAL_AI_BASE_PATH)
|
||||||
throw new Error("No LocalAI Base Path was set.");
|
throw new Error("No LocalAI Base Path was set.");
|
||||||
|
|
||||||
@ -15,7 +15,7 @@ class LocalAiLLM {
|
|||||||
: {}),
|
: {}),
|
||||||
});
|
});
|
||||||
this.openai = new OpenAIApi(config);
|
this.openai = new OpenAIApi(config);
|
||||||
this.model = process.env.LOCAL_AI_MODEL_PREF;
|
this.model = modelPreference || process.env.LOCAL_AI_MODEL_PREF;
|
||||||
this.limits = {
|
this.limits = {
|
||||||
history: this.promptWindowLimit() * 0.15,
|
history: this.promptWindowLimit() * 0.15,
|
||||||
system: this.promptWindowLimit() * 0.15,
|
system: this.promptWindowLimit() * 0.15,
|
||||||
|
@ -10,11 +10,11 @@ const ChatLlamaCpp = (...args) =>
|
|||||||
);
|
);
|
||||||
|
|
||||||
class NativeLLM {
|
class NativeLLM {
|
||||||
constructor(embedder = null) {
|
constructor(embedder = null, modelPreference = null) {
|
||||||
if (!process.env.NATIVE_LLM_MODEL_PREF)
|
if (!process.env.NATIVE_LLM_MODEL_PREF)
|
||||||
throw new Error("No local Llama model was set.");
|
throw new Error("No local Llama model was set.");
|
||||||
|
|
||||||
this.model = process.env.NATIVE_LLM_MODEL_PREF || null;
|
this.model = modelPreference || process.env.NATIVE_LLM_MODEL_PREF || null;
|
||||||
this.limits = {
|
this.limits = {
|
||||||
history: this.promptWindowLimit() * 0.15,
|
history: this.promptWindowLimit() * 0.15,
|
||||||
system: this.promptWindowLimit() * 0.15,
|
system: this.promptWindowLimit() * 0.15,
|
||||||
|
@ -3,12 +3,12 @@ const { StringOutputParser } = require("langchain/schema/output_parser");
|
|||||||
|
|
||||||
// Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md
|
// Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md
|
||||||
class OllamaAILLM {
|
class OllamaAILLM {
|
||||||
constructor(embedder = null) {
|
constructor(embedder = null, modelPreference = null) {
|
||||||
if (!process.env.OLLAMA_BASE_PATH)
|
if (!process.env.OLLAMA_BASE_PATH)
|
||||||
throw new Error("No Ollama Base Path was set.");
|
throw new Error("No Ollama Base Path was set.");
|
||||||
|
|
||||||
this.basePath = process.env.OLLAMA_BASE_PATH;
|
this.basePath = process.env.OLLAMA_BASE_PATH;
|
||||||
this.model = process.env.OLLAMA_MODEL_PREF;
|
this.model = modelPreference || process.env.OLLAMA_MODEL_PREF;
|
||||||
this.limits = {
|
this.limits = {
|
||||||
history: this.promptWindowLimit() * 0.15,
|
history: this.promptWindowLimit() * 0.15,
|
||||||
system: this.promptWindowLimit() * 0.15,
|
system: this.promptWindowLimit() * 0.15,
|
||||||
|
@ -2,7 +2,7 @@ const { OpenAiEmbedder } = require("../../EmbeddingEngines/openAi");
|
|||||||
const { chatPrompt } = require("../../chats");
|
const { chatPrompt } = require("../../chats");
|
||||||
|
|
||||||
class OpenAiLLM {
|
class OpenAiLLM {
|
||||||
constructor(embedder = null) {
|
constructor(embedder = null, modelPreference = null) {
|
||||||
const { Configuration, OpenAIApi } = require("openai");
|
const { Configuration, OpenAIApi } = require("openai");
|
||||||
if (!process.env.OPEN_AI_KEY) throw new Error("No OpenAI API key was set.");
|
if (!process.env.OPEN_AI_KEY) throw new Error("No OpenAI API key was set.");
|
||||||
|
|
||||||
@ -10,7 +10,8 @@ class OpenAiLLM {
|
|||||||
apiKey: process.env.OPEN_AI_KEY,
|
apiKey: process.env.OPEN_AI_KEY,
|
||||||
});
|
});
|
||||||
this.openai = new OpenAIApi(config);
|
this.openai = new OpenAIApi(config);
|
||||||
this.model = process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo";
|
this.model =
|
||||||
|
modelPreference || process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo";
|
||||||
this.limits = {
|
this.limits = {
|
||||||
history: this.promptWindowLimit() * 0.15,
|
history: this.promptWindowLimit() * 0.15,
|
||||||
system: this.promptWindowLimit() * 0.15,
|
system: this.promptWindowLimit() * 0.15,
|
||||||
|
@ -6,7 +6,7 @@ function togetherAiModels() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class TogetherAiLLM {
|
class TogetherAiLLM {
|
||||||
constructor(embedder = null) {
|
constructor(embedder = null, modelPreference = null) {
|
||||||
const { Configuration, OpenAIApi } = require("openai");
|
const { Configuration, OpenAIApi } = require("openai");
|
||||||
if (!process.env.TOGETHER_AI_API_KEY)
|
if (!process.env.TOGETHER_AI_API_KEY)
|
||||||
throw new Error("No TogetherAI API key was set.");
|
throw new Error("No TogetherAI API key was set.");
|
||||||
@ -16,7 +16,7 @@ class TogetherAiLLM {
|
|||||||
apiKey: process.env.TOGETHER_AI_API_KEY,
|
apiKey: process.env.TOGETHER_AI_API_KEY,
|
||||||
});
|
});
|
||||||
this.openai = new OpenAIApi(config);
|
this.openai = new OpenAIApi(config);
|
||||||
this.model = process.env.TOGETHER_AI_MODEL_PREF;
|
this.model = modelPreference || process.env.TOGETHER_AI_MODEL_PREF;
|
||||||
this.limits = {
|
this.limits = {
|
||||||
history: this.promptWindowLimit() * 0.15,
|
history: this.promptWindowLimit() * 0.15,
|
||||||
system: this.promptWindowLimit() * 0.15,
|
system: this.promptWindowLimit() * 0.15,
|
||||||
|
@ -71,7 +71,7 @@ async function chatWithWorkspace(
|
|||||||
return await VALID_COMMANDS[command](workspace, message, uuid, user);
|
return await VALID_COMMANDS[command](workspace, message, uuid, user);
|
||||||
}
|
}
|
||||||
|
|
||||||
const LLMConnector = getLLMProvider();
|
const LLMConnector = getLLMProvider(workspace?.chatModel);
|
||||||
const VectorDb = getVectorDbClass();
|
const VectorDb = getVectorDbClass();
|
||||||
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
|
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
|
||||||
if (!safe) {
|
if (!safe) {
|
||||||
|
@ -30,7 +30,7 @@ async function streamChatWithWorkspace(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const LLMConnector = getLLMProvider();
|
const LLMConnector = getLLMProvider(workspace?.chatModel);
|
||||||
const VectorDb = getVectorDbClass();
|
const VectorDb = getVectorDbClass();
|
||||||
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
|
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
|
||||||
if (!safe) {
|
if (!safe) {
|
||||||
|
@ -17,7 +17,7 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
|
|||||||
case "localai":
|
case "localai":
|
||||||
return await localAIModels(basePath, apiKey);
|
return await localAIModels(basePath, apiKey);
|
||||||
case "ollama":
|
case "ollama":
|
||||||
return await ollamaAIModels(basePath, apiKey);
|
return await ollamaAIModels(basePath);
|
||||||
case "togetherai":
|
case "togetherai":
|
||||||
return await getTogetherAiModels();
|
return await getTogetherAiModels();
|
||||||
case "native-llm":
|
case "native-llm":
|
||||||
@ -53,7 +53,7 @@ async function openAiModels(apiKey = null) {
|
|||||||
async function localAIModels(basePath = null, apiKey = null) {
|
async function localAIModels(basePath = null, apiKey = null) {
|
||||||
const { Configuration, OpenAIApi } = require("openai");
|
const { Configuration, OpenAIApi } = require("openai");
|
||||||
const config = new Configuration({
|
const config = new Configuration({
|
||||||
basePath,
|
basePath: basePath || process.env.LOCAL_AI_BASE_PATH,
|
||||||
apiKey: apiKey || process.env.LOCAL_AI_API_KEY,
|
apiKey: apiKey || process.env.LOCAL_AI_API_KEY,
|
||||||
});
|
});
|
||||||
const openai = new OpenAIApi(config);
|
const openai = new OpenAIApi(config);
|
||||||
@ -70,13 +70,14 @@ async function localAIModels(basePath = null, apiKey = null) {
|
|||||||
return { models, error: null };
|
return { models, error: null };
|
||||||
}
|
}
|
||||||
|
|
||||||
async function ollamaAIModels(basePath = null, _apiKey = null) {
|
async function ollamaAIModels(basePath = null) {
|
||||||
let url;
|
let url;
|
||||||
try {
|
try {
|
||||||
new URL(basePath);
|
let urlPath = basePath ?? process.env.OLLAMA_BASE_PATH;
|
||||||
if (basePath.split("").slice(-1)?.[0] === "/")
|
new URL(urlPath);
|
||||||
|
if (urlPath.split("").slice(-1)?.[0] === "/")
|
||||||
throw new Error("BasePath Cannot end in /!");
|
throw new Error("BasePath Cannot end in /!");
|
||||||
url = basePath;
|
url = urlPath;
|
||||||
} catch {
|
} catch {
|
||||||
return { models: [], error: "Not a valid URL." };
|
return { models: [], error: "Not a valid URL." };
|
||||||
}
|
}
|
||||||
|
@ -24,37 +24,37 @@ function getVectorDbClass() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function getLLMProvider() {
|
function getLLMProvider(modelPreference = null) {
|
||||||
const vectorSelection = process.env.LLM_PROVIDER || "openai";
|
const vectorSelection = process.env.LLM_PROVIDER || "openai";
|
||||||
const embedder = getEmbeddingEngineSelection();
|
const embedder = getEmbeddingEngineSelection();
|
||||||
switch (vectorSelection) {
|
switch (vectorSelection) {
|
||||||
case "openai":
|
case "openai":
|
||||||
const { OpenAiLLM } = require("../AiProviders/openAi");
|
const { OpenAiLLM } = require("../AiProviders/openAi");
|
||||||
return new OpenAiLLM(embedder);
|
return new OpenAiLLM(embedder, modelPreference);
|
||||||
case "azure":
|
case "azure":
|
||||||
const { AzureOpenAiLLM } = require("../AiProviders/azureOpenAi");
|
const { AzureOpenAiLLM } = require("../AiProviders/azureOpenAi");
|
||||||
return new AzureOpenAiLLM(embedder);
|
return new AzureOpenAiLLM(embedder, modelPreference);
|
||||||
case "anthropic":
|
case "anthropic":
|
||||||
const { AnthropicLLM } = require("../AiProviders/anthropic");
|
const { AnthropicLLM } = require("../AiProviders/anthropic");
|
||||||
return new AnthropicLLM(embedder);
|
return new AnthropicLLM(embedder, modelPreference);
|
||||||
case "gemini":
|
case "gemini":
|
||||||
const { GeminiLLM } = require("../AiProviders/gemini");
|
const { GeminiLLM } = require("../AiProviders/gemini");
|
||||||
return new GeminiLLM(embedder);
|
return new GeminiLLM(embedder, modelPreference);
|
||||||
case "lmstudio":
|
case "lmstudio":
|
||||||
const { LMStudioLLM } = require("../AiProviders/lmStudio");
|
const { LMStudioLLM } = require("../AiProviders/lmStudio");
|
||||||
return new LMStudioLLM(embedder);
|
return new LMStudioLLM(embedder, modelPreference);
|
||||||
case "localai":
|
case "localai":
|
||||||
const { LocalAiLLM } = require("../AiProviders/localAi");
|
const { LocalAiLLM } = require("../AiProviders/localAi");
|
||||||
return new LocalAiLLM(embedder);
|
return new LocalAiLLM(embedder, modelPreference);
|
||||||
case "ollama":
|
case "ollama":
|
||||||
const { OllamaAILLM } = require("../AiProviders/ollama");
|
const { OllamaAILLM } = require("../AiProviders/ollama");
|
||||||
return new OllamaAILLM(embedder);
|
return new OllamaAILLM(embedder, modelPreference);
|
||||||
case "togetherai":
|
case "togetherai":
|
||||||
const { TogetherAiLLM } = require("../AiProviders/togetherAi");
|
const { TogetherAiLLM } = require("../AiProviders/togetherAi");
|
||||||
return new TogetherAiLLM(embedder);
|
return new TogetherAiLLM(embedder, modelPreference);
|
||||||
case "native":
|
case "native":
|
||||||
const { NativeLLM } = require("../AiProviders/native");
|
const { NativeLLM } = require("../AiProviders/native");
|
||||||
return new NativeLLM(embedder);
|
return new NativeLLM(embedder, modelPreference);
|
||||||
default:
|
default:
|
||||||
throw new Error("ENV: No LLM_PROVIDER value found in environment!");
|
throw new Error("ENV: No LLM_PROVIDER value found in environment!");
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ const KEY_MAPPING = {
|
|||||||
LLMProvider: {
|
LLMProvider: {
|
||||||
envKey: "LLM_PROVIDER",
|
envKey: "LLM_PROVIDER",
|
||||||
checks: [isNotEmpty, supportedLLM],
|
checks: [isNotEmpty, supportedLLM],
|
||||||
|
postUpdate: [wipeWorkspaceModelPreference],
|
||||||
},
|
},
|
||||||
// OpenAI Settings
|
// OpenAI Settings
|
||||||
OpenAiKey: {
|
OpenAiKey: {
|
||||||
@ -362,11 +363,20 @@ function validDockerizedUrl(input = "") {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the LLMProvider has changed we need to reset all workspace model preferences to
|
||||||
|
// null since the provider<>model name combination will be invalid for whatever the new
|
||||||
|
// provider is.
|
||||||
|
async function wipeWorkspaceModelPreference(key, prev, next) {
|
||||||
|
if (prev === next) return;
|
||||||
|
const { Workspace } = require("../../models/workspace");
|
||||||
|
await Workspace.resetWorkspaceChatModels();
|
||||||
|
}
|
||||||
|
|
||||||
// This will force update .env variables which for any which reason were not able to be parsed or
|
// This will force update .env variables which for any which reason were not able to be parsed or
|
||||||
// read from an ENV file as this seems to be a complicating step for many so allowing people to write
|
// read from an ENV file as this seems to be a complicating step for many so allowing people to write
|
||||||
// to the process will at least alleviate that issue. It does not perform comprehensive validity checks or sanity checks
|
// to the process will at least alleviate that issue. It does not perform comprehensive validity checks or sanity checks
|
||||||
// and is simply for debugging when the .env not found issue many come across.
|
// and is simply for debugging when the .env not found issue many come across.
|
||||||
function updateENV(newENVs = {}, force = false) {
|
async function updateENV(newENVs = {}, force = false) {
|
||||||
let error = "";
|
let error = "";
|
||||||
const validKeys = Object.keys(KEY_MAPPING);
|
const validKeys = Object.keys(KEY_MAPPING);
|
||||||
const ENV_KEYS = Object.keys(newENVs).filter(
|
const ENV_KEYS = Object.keys(newENVs).filter(
|
||||||
@ -374,21 +384,25 @@ function updateENV(newENVs = {}, force = false) {
|
|||||||
);
|
);
|
||||||
const newValues = {};
|
const newValues = {};
|
||||||
|
|
||||||
ENV_KEYS.forEach((key) => {
|
for (const key of ENV_KEYS) {
|
||||||
const { envKey, checks } = KEY_MAPPING[key];
|
const { envKey, checks, postUpdate = [] } = KEY_MAPPING[key];
|
||||||
const value = newENVs[key];
|
const prevValue = process.env[envKey];
|
||||||
|
const nextValue = newENVs[key];
|
||||||
const errors = checks
|
const errors = checks
|
||||||
.map((validityCheck) => validityCheck(value, force))
|
.map((validityCheck) => validityCheck(nextValue, force))
|
||||||
.filter((err) => typeof err === "string");
|
.filter((err) => typeof err === "string");
|
||||||
|
|
||||||
if (errors.length > 0) {
|
if (errors.length > 0) {
|
||||||
error += errors.join("\n");
|
error += errors.join("\n");
|
||||||
return;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
newValues[key] = value;
|
newValues[key] = nextValue;
|
||||||
process.env[envKey] = value;
|
process.env[envKey] = nextValue;
|
||||||
});
|
|
||||||
|
for (const postUpdateFunc of postUpdate)
|
||||||
|
await postUpdateFunc(key, prevValue, nextValue);
|
||||||
|
}
|
||||||
|
|
||||||
return { newValues, error: error?.length > 0 ? error : false };
|
return { newValues, error: error?.length > 0 ? error : false };
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user