Per workspace model selection (#582)

* WIP model selection per workspace (migrations and openai saves properly)

* revert OpenAiOption

* add support for models per workspace for anthropic, localAi, ollama, openAi, and togetherAi

* remove unneeded comments

* update logic for when LLMProvider is reset; reset AI provider files with master

* remove frontend/api reset of workspace chat and move logic to updateENV
add postUpdate callbacks to envs

* set preferred model for chat on class instantiation

* remove extra param

* linting

* remove unused var

* refactor chat model selection on workspace

* linting

* add fallback for base path to localai models

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
Sean Hatfield 2024-01-17 12:59:25 -08:00 committed by GitHub
parent bf503ee0e9
commit 90df37582b
24 changed files with 263 additions and 53 deletions

View File

@ -0,0 +1,120 @@
import useGetProviderModels, {
DISABLED_PROVIDERS,
} from "./useGetProviderModels";
export default function ChatModelSelection({
settings,
workspace,
setHasChanges,
}) {
const { defaultModels, customModels, loading } = useGetProviderModels(
settings?.LLMProvider
);
if (DISABLED_PROVIDERS.includes(settings?.LLMProvider)) return null;
if (loading) {
return (
<div>
<div className="flex flex-col">
<label
htmlFor="name"
className="block text-sm font-medium text-white"
>
Chat model
</label>
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
The specific chat model that will be used for this workspace. If
empty, will use the system LLM preference.
</p>
</div>
<select
name="chatModel"
required={true}
disabled={true}
className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
>
<option disabled={true} selected={true}>
-- waiting for models --
</option>
</select>
</div>
);
}
return (
<div>
<div className="flex flex-col">
<label htmlFor="name" className="block text-sm font-medium text-white">
Chat model{" "}
<span className="font-normal">({settings?.LLMProvider})</span>
</label>
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
The specific chat model that will be used for this workspace. If
empty, will use the system LLM preference.
</p>
</div>
<select
name="chatModel"
required={true}
onChange={() => {
setHasChanges(true);
}}
className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
>
<option disabled={true} selected={workspace?.chatModel === null}>
System default
</option>
{defaultModels.length > 0 && (
<optgroup label="General models">
{defaultModels.map((model) => {
return (
<option
key={model}
value={model}
selected={workspace?.chatModel === model}
>
{model}
</option>
);
})}
</optgroup>
)}
{Array.isArray(customModels) && customModels.length > 0 && (
<optgroup label="Custom models">
{customModels.map((model) => {
return (
<option
key={model.id}
value={model.id}
selected={workspace?.chatModel === model.id}
>
{model.id}
</option>
);
})}
</optgroup>
)}
{/* For providers like TogetherAi where we partition model by creator entity. */}
{!Array.isArray(customModels) &&
Object.keys(customModels).length > 0 && (
<>
{Object.entries(customModels).map(([organization, models]) => (
<optgroup key={organization} label={organization}>
{models.map((model) => (
<option
key={model.id}
value={model.id}
selected={workspace?.chatModel === model.id}
>
{model.name}
</option>
))}
</optgroup>
))}
</>
)}
</select>
</div>
);
}

View File

@ -0,0 +1,49 @@
import System from "@/models/system";
import { useEffect, useState } from "react";
// Providers which cannot use this feature for workspace<>model selection
export const DISABLED_PROVIDERS = ["azure", "lmstudio"];
const PROVIDER_DEFAULT_MODELS = {
openai: ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-32k"],
gemini: ["gemini-pro"],
anthropic: ["claude-2", "claude-instant-1"],
azure: [],
lmstudio: [],
localai: [],
ollama: [],
togetherai: [],
native: [],
};
// For togetherAi, which has a large model list - we subgroup the options
// by their creator organization (eg: Meta, Mistral, etc)
// which makes selection easier to read.
function groupModels(models) {
return models.reduce((acc, model) => {
acc[model.organization] = acc[model.organization] || [];
acc[model.organization].push(model);
return acc;
}, {});
}
export default function useGetProviderModels(provider = null) {
const [defaultModels, setDefaultModels] = useState([]);
const [customModels, setCustomModels] = useState([]);
const [loading, setLoading] = useState(true);
useEffect(() => {
async function fetchProviderModels() {
if (!provider) return;
const { models = [] } = await System.customModels(provider);
if (PROVIDER_DEFAULT_MODELS.hasOwnProperty(provider))
setDefaultModels(PROVIDER_DEFAULT_MODELS[provider]);
provider === "togetherai"
? setCustomModels(groupModels(models))
: setCustomModels(models);
setLoading(false);
}
fetchProviderModels();
}, [provider]);
return { defaultModels, customModels, loading };
}
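
For reference, a small standalone sketch of how groupModels turns the flat list returned by System.customModels into the organization-keyed object that drives the Array.isArray(customModels) branch in ChatModelSelection. The model entries below are made-up examples, not real TogetherAI listings:

// Standalone illustration of the TogetherAI grouping behavior;
// the model entries are hypothetical examples, not actual API output.
function groupModels(models) {
  return models.reduce((acc, model) => {
    acc[model.organization] = acc[model.organization] || [];
    acc[model.organization].push(model);
    return acc;
  }, {});
}

const sampleModels = [
  { id: "meta/llama-2-7b-chat", name: "Llama 2 7B Chat", organization: "Meta" },
  { id: "mistral/mistral-7b-instruct", name: "Mistral 7B Instruct", organization: "Mistral AI" },
];

console.log(groupModels(sampleModels));
// => { Meta: [ { id: "meta/llama-2-7b-chat", ... } ], "Mistral AI": [ { ... } ] }

For every other provider, customModels stays a plain array, which is why the component falls back to the flat "Custom models" optgroup in that case.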

View File

@ -6,6 +6,7 @@ import System from "../../../../models/system";
import PreLoader from "../../../Preloader";
import { useParams } from "react-router-dom";
import showToast from "../../../../utils/toast";
import ChatModelPreference from "./ChatModelPreference";
// Ensure that a type is correct before sending the body
// to the backend.
@ -26,7 +27,7 @@ function castToType(key, value) {
return definitions[key].cast(value);
}
export default function WorkspaceSettings({ active, workspace }) {
export default function WorkspaceSettings({ active, workspace, settings }) {
const { slug } = useParams();
const formEl = useRef(null);
const [saving, setSaving] = useState(false);
@ -99,6 +100,11 @@ export default function WorkspaceSettings({ active, workspace }) {
<div className="flex">
<div className="flex flex-col gap-y-4 w-1/2">
<div className="w-3/4 flex flex-col gap-y-4">
<ChatModelPreference
settings={settings}
workspace={workspace}
setHasChanges={setHasChanges}
/>
<div>
<div className="flex flex-col">
<label

View File

@ -117,6 +117,7 @@ const ManageWorkspace = ({ hideModal = noop, providedSlug = null }) => {
<WorkspaceSettings
active={selectedTab === "settings"} // To force reload live sub-components like VectorCount
workspace={workspace}
settings={settings}
/>
</div>
</Suspense>

View File

@ -30,19 +30,17 @@ export default function GeneralLLMPreference() {
const [hasChanges, setHasChanges] = useState(false);
const [settings, setSettings] = useState(null);
const [loading, setLoading] = useState(true);
const [searchQuery, setSearchQuery] = useState("");
const [filteredLLMs, setFilteredLLMs] = useState([]);
const [selectedLLM, setSelectedLLM] = useState(null);
const isHosted = window.location.hostname.includes("useanything.com");
const handleSubmit = async (e) => {
e.preventDefault();
const form = e.target;
const data = {};
const data = { LLMProvider: selectedLLM };
const formData = new FormData(form);
data.LLMProvider = selectedLLM;
for (var [key, value] of formData.entries()) data[key] = value;
const { error } = await System.updateSystem(data);
setSaving(true);

View File

@ -139,7 +139,7 @@ function apiSystemEndpoints(app) {
*/
try {
const body = reqBody(request);
const { newValues, error } = updateENV(body);
const { newValues, error } = await updateENV(body);
if (process.env.NODE_ENV === "production") await dumpENV();
response.status(200).json({ newValues, error });
} catch (e) {

View File

@ -290,7 +290,7 @@ function systemEndpoints(app) {
}
const body = reqBody(request);
const { newValues, error } = updateENV(body);
const { newValues, error } = await updateENV(body);
if (process.env.NODE_ENV === "production") await dumpENV();
response.status(200).json({ newValues, error });
} catch (e) {
@ -312,7 +312,7 @@ function systemEndpoints(app) {
}
const { usePassword, newPassword } = reqBody(request);
const { error } = updateENV(
const { error } = await updateENV(
{
AuthToken: usePassword ? newPassword : "",
JWTSecret: usePassword ? v4() : "",
@ -355,7 +355,7 @@ function systemEndpoints(app) {
message_limit: 25,
});
updateENV(
await updateENV(
{
AuthToken: "",
JWTSecret: process.env.JWT_SECRET || v4(),

View File

@ -14,6 +14,7 @@ const Workspace = {
"lastUpdatedAt",
"openAiPrompt",
"similarityThreshold",
"chatModel",
],
new: async function (name = null, creatorId = null) {
@ -191,6 +192,20 @@ const Workspace = {
return { success: false, error: error.message };
}
},
resetWorkspaceChatModels: async () => {
try {
await prisma.workspaces.updateMany({
data: {
chatModel: null,
},
});
return { success: true, error: null };
} catch (error) {
console.error("Error resetting workspace chat models:", error.message);
return { success: false, error: error.message };
}
},
};
module.exports = { Workspace };

View File

@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "workspaces" ADD COLUMN "chatModel" TEXT;

View File

@ -93,6 +93,7 @@ model workspaces {
lastUpdatedAt DateTime @default(now())
openAiPrompt String?
similarityThreshold Float? @default(0.25)
chatModel String?
workspace_users workspace_users[]
documents workspace_documents[]
}

View File

@ -2,7 +2,7 @@ const { v4 } = require("uuid");
const { chatPrompt } = require("../../chats");
class AnthropicLLM {
constructor(embedder = null) {
constructor(embedder = null, modelPreference = null) {
if (!process.env.ANTHROPIC_API_KEY)
throw new Error("No Anthropic API key was set.");
@ -12,7 +12,8 @@ class AnthropicLLM {
apiKey: process.env.ANTHROPIC_API_KEY,
});
this.anthropic = anthropic;
this.model = process.env.ANTHROPIC_MODEL_PREF || "claude-2";
this.model =
modelPreference || process.env.ANTHROPIC_MODEL_PREF || "claude-2";
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,

View File

@ -2,7 +2,7 @@ const { AzureOpenAiEmbedder } = require("../../EmbeddingEngines/azureOpenAi");
const { chatPrompt } = require("../../chats");
class AzureOpenAiLLM {
constructor(embedder = null) {
constructor(embedder = null, _modelPreference = null) {
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
if (!process.env.AZURE_OPENAI_ENDPOINT)
throw new Error("No Azure API endpoint was set.");

View File

@ -1,14 +1,15 @@
const { chatPrompt } = require("../../chats");
class GeminiLLM {
constructor(embedder = null) {
constructor(embedder = null, modelPreference = null) {
if (!process.env.GEMINI_API_KEY)
throw new Error("No Gemini API key was set.");
// Docs: https://ai.google.dev/tutorials/node_quickstart
const { GoogleGenerativeAI } = require("@google/generative-ai");
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY);
this.model = process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
this.model =
modelPreference || process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
this.gemini = genAI.getGenerativeModel({ model: this.model });
this.limits = {
history: this.promptWindowLimit() * 0.15,

View File

@ -2,7 +2,7 @@ const { chatPrompt } = require("../../chats");
// hybrid of openAi LLM chat completion for LMStudio
class LMStudioLLM {
constructor(embedder = null) {
constructor(embedder = null, _modelPreference = null) {
if (!process.env.LMSTUDIO_BASE_PATH)
throw new Error("No LMStudio API Base Path was set.");
@ -12,7 +12,7 @@ class LMStudioLLM {
});
this.lmstudio = new OpenAIApi(config);
// When using LMStudios inference server - the model param is not required so
// we can stub it here.
// we can stub it here. LMStudio can only run one model at a time.
this.model = "model-placeholder";
this.limits = {
history: this.promptWindowLimit() * 0.15,

View File

@ -1,7 +1,7 @@
const { chatPrompt } = require("../../chats");
class LocalAiLLM {
constructor(embedder = null) {
constructor(embedder = null, modelPreference = null) {
if (!process.env.LOCAL_AI_BASE_PATH)
throw new Error("No LocalAI Base Path was set.");
@ -15,7 +15,7 @@ class LocalAiLLM {
: {}),
});
this.openai = new OpenAIApi(config);
this.model = process.env.LOCAL_AI_MODEL_PREF;
this.model = modelPreference || process.env.LOCAL_AI_MODEL_PREF;
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,

View File

@ -10,11 +10,11 @@ const ChatLlamaCpp = (...args) =>
);
class NativeLLM {
constructor(embedder = null) {
constructor(embedder = null, modelPreference = null) {
if (!process.env.NATIVE_LLM_MODEL_PREF)
throw new Error("No local Llama model was set.");
this.model = process.env.NATIVE_LLM_MODEL_PREF || null;
this.model = modelPreference || process.env.NATIVE_LLM_MODEL_PREF || null;
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,

View File

@ -3,12 +3,12 @@ const { StringOutputParser } = require("langchain/schema/output_parser");
// Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md
class OllamaAILLM {
constructor(embedder = null) {
constructor(embedder = null, modelPreference = null) {
if (!process.env.OLLAMA_BASE_PATH)
throw new Error("No Ollama Base Path was set.");
this.basePath = process.env.OLLAMA_BASE_PATH;
this.model = process.env.OLLAMA_MODEL_PREF;
this.model = modelPreference || process.env.OLLAMA_MODEL_PREF;
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,

View File

@ -2,7 +2,7 @@ const { OpenAiEmbedder } = require("../../EmbeddingEngines/openAi");
const { chatPrompt } = require("../../chats");
class OpenAiLLM {
constructor(embedder = null) {
constructor(embedder = null, modelPreference = null) {
const { Configuration, OpenAIApi } = require("openai");
if (!process.env.OPEN_AI_KEY) throw new Error("No OpenAI API key was set.");
@ -10,7 +10,8 @@ class OpenAiLLM {
apiKey: process.env.OPEN_AI_KEY,
});
this.openai = new OpenAIApi(config);
this.model = process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo";
this.model =
modelPreference || process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo";
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,

View File

@ -6,7 +6,7 @@ function togetherAiModels() {
}
class TogetherAiLLM {
constructor(embedder = null) {
constructor(embedder = null, modelPreference = null) {
const { Configuration, OpenAIApi } = require("openai");
if (!process.env.TOGETHER_AI_API_KEY)
throw new Error("No TogetherAI API key was set.");
@ -16,7 +16,7 @@ class TogetherAiLLM {
apiKey: process.env.TOGETHER_AI_API_KEY,
});
this.openai = new OpenAIApi(config);
this.model = process.env.TOGETHER_AI_MODEL_PREF;
this.model = modelPreference || process.env.TOGETHER_AI_MODEL_PREF;
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,

View File

@ -71,7 +71,7 @@ async function chatWithWorkspace(
return await VALID_COMMANDS[command](workspace, message, uuid, user);
}
const LLMConnector = getLLMProvider();
const LLMConnector = getLLMProvider(workspace?.chatModel);
const VectorDb = getVectorDbClass();
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
if (!safe) {

View File

@ -30,7 +30,7 @@ async function streamChatWithWorkspace(
return;
}
const LLMConnector = getLLMProvider();
const LLMConnector = getLLMProvider(workspace?.chatModel);
const VectorDb = getVectorDbClass();
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
if (!safe) {
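
Taken together with the provider constructor changes above, the commit establishes a simple precedence chain for which model a workspace chat uses. A minimal standalone sketch of that chain (this helper is an illustration only, not a function added by the commit; the OpenAI names are used as the example):

// Illustration of the model-selection precedence this commit introduces:
// workspace.chatModel first, then the provider's *_MODEL_PREF env var,
// then the provider's hardcoded default. Not actual repo code.
function resolveChatModel(workspace, envPref, hardcodedDefault) {
  return workspace?.chatModel || envPref || hardcodedDefault;
}

console.log(resolveChatModel({ chatModel: "gpt-4-32k" }, "gpt-4", "gpt-3.5-turbo")); // "gpt-4-32k"
console.log(resolveChatModel({ chatModel: null }, "gpt-4", "gpt-3.5-turbo"));        // "gpt-4"
console.log(resolveChatModel({}, undefined, "gpt-3.5-turbo"));                       // "gpt-3.5-turbo"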

View File

@ -17,7 +17,7 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
case "localai":
return await localAIModels(basePath, apiKey);
case "ollama":
return await ollamaAIModels(basePath, apiKey);
return await ollamaAIModels(basePath);
case "togetherai":
return await getTogetherAiModels();
case "native-llm":
@ -53,7 +53,7 @@ async function openAiModels(apiKey = null) {
async function localAIModels(basePath = null, apiKey = null) {
const { Configuration, OpenAIApi } = require("openai");
const config = new Configuration({
basePath,
basePath: basePath || process.env.LOCAL_AI_BASE_PATH,
apiKey: apiKey || process.env.LOCAL_AI_API_KEY,
});
const openai = new OpenAIApi(config);
@ -70,13 +70,14 @@ async function localAIModels(basePath = null, apiKey = null) {
return { models, error: null };
}
async function ollamaAIModels(basePath = null, _apiKey = null) {
async function ollamaAIModels(basePath = null) {
let url;
try {
new URL(basePath);
if (basePath.split("").slice(-1)?.[0] === "/")
let urlPath = basePath ?? process.env.OLLAMA_BASE_PATH;
new URL(urlPath);
if (urlPath.split("").slice(-1)?.[0] === "/")
throw new Error("BasePath Cannot end in /!");
url = basePath;
url = urlPath;
} catch {
return { models: [], error: "Not a valid URL." };
}
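
A quick self-contained check of the base-path fallback and trailing-slash validation added here (the URLs are hypothetical examples, not project defaults):

// Standalone sketch of the Ollama base-path resolution shown above;
// example URLs are hypothetical.
function resolveOllamaBasePath(basePath = null) {
  try {
    const urlPath = basePath ?? process.env.OLLAMA_BASE_PATH;
    new URL(urlPath); // throws on anything that is not an absolute URL
    if (urlPath.split("").slice(-1)?.[0] === "/")
      throw new Error("BasePath Cannot end in /!");
    return { url: urlPath, error: null };
  } catch {
    return { url: null, error: "Not a valid URL." };
  }
}

console.log(resolveOllamaBasePath("http://127.0.0.1:11434"));  // { url: "http://127.0.0.1:11434", error: null }
console.log(resolveOllamaBasePath("http://127.0.0.1:11434/")); // { url: null, error: "Not a valid URL." }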

View File

@ -24,37 +24,37 @@ function getVectorDbClass() {
}
}
function getLLMProvider() {
function getLLMProvider(modelPreference = null) {
const vectorSelection = process.env.LLM_PROVIDER || "openai";
const embedder = getEmbeddingEngineSelection();
switch (vectorSelection) {
case "openai":
const { OpenAiLLM } = require("../AiProviders/openAi");
return new OpenAiLLM(embedder);
return new OpenAiLLM(embedder, modelPreference);
case "azure":
const { AzureOpenAiLLM } = require("../AiProviders/azureOpenAi");
return new AzureOpenAiLLM(embedder);
return new AzureOpenAiLLM(embedder, modelPreference);
case "anthropic":
const { AnthropicLLM } = require("../AiProviders/anthropic");
return new AnthropicLLM(embedder);
return new AnthropicLLM(embedder, modelPreference);
case "gemini":
const { GeminiLLM } = require("../AiProviders/gemini");
return new GeminiLLM(embedder);
return new GeminiLLM(embedder, modelPreference);
case "lmstudio":
const { LMStudioLLM } = require("../AiProviders/lmStudio");
return new LMStudioLLM(embedder);
return new LMStudioLLM(embedder, modelPreference);
case "localai":
const { LocalAiLLM } = require("../AiProviders/localAi");
return new LocalAiLLM(embedder);
return new LocalAiLLM(embedder, modelPreference);
case "ollama":
const { OllamaAILLM } = require("../AiProviders/ollama");
return new OllamaAILLM(embedder);
return new OllamaAILLM(embedder, modelPreference);
case "togetherai":
const { TogetherAiLLM } = require("../AiProviders/togetherAi");
return new TogetherAiLLM(embedder);
return new TogetherAiLLM(embedder, modelPreference);
case "native":
const { NativeLLM } = require("../AiProviders/native");
return new NativeLLM(embedder);
return new NativeLLM(embedder, modelPreference);
default:
throw new Error("ENV: No LLM_PROVIDER value found in environment!");
}

View File

@ -2,6 +2,7 @@ const KEY_MAPPING = {
LLMProvider: {
envKey: "LLM_PROVIDER",
checks: [isNotEmpty, supportedLLM],
postUpdate: [wipeWorkspaceModelPreference],
},
// OpenAI Settings
OpenAiKey: {
@ -362,11 +363,20 @@ function validDockerizedUrl(input = "") {
return null;
}
// If the LLMProvider has changed we need to reset all workspace model preferences to
// null since the provider<>model name combination will be invalid for whatever the new
// provider is.
async function wipeWorkspaceModelPreference(key, prev, next) {
if (prev === next) return;
const { Workspace } = require("../../models/workspace");
await Workspace.resetWorkspaceChatModels();
}
// This will force update .env variables which for any which reason were not able to be parsed or
// read from an ENV file as this seems to be a complicating step for many so allowing people to write
// to the process will at least alleviate that issue. It does not perform comprehensive validity checks or sanity checks
// and is simply for debugging when the .env not found issue many come across.
function updateENV(newENVs = {}, force = false) {
async function updateENV(newENVs = {}, force = false) {
let error = "";
const validKeys = Object.keys(KEY_MAPPING);
const ENV_KEYS = Object.keys(newENVs).filter(
@ -374,21 +384,25 @@ function updateENV(newENVs = {}, force = false) {
);
const newValues = {};
ENV_KEYS.forEach((key) => {
const { envKey, checks } = KEY_MAPPING[key];
const value = newENVs[key];
for (const key of ENV_KEYS) {
const { envKey, checks, postUpdate = [] } = KEY_MAPPING[key];
const prevValue = process.env[envKey];
const nextValue = newENVs[key];
const errors = checks
.map((validityCheck) => validityCheck(value, force))
.map((validityCheck) => validityCheck(nextValue, force))
.filter((err) => typeof err === "string");
if (errors.length > 0) {
error += errors.join("\n");
return;
break;
}
newValues[key] = value;
process.env[envKey] = value;
});
newValues[key] = nextValue;
process.env[envKey] = nextValue;
for (const postUpdateFunc of postUpdate)
await postUpdateFunc(key, prevValue, nextValue);
}
return { newValues, error: error?.length > 0 ? error : false };
}
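
The postUpdate hooks added above receive (key, previousValue, nextValue) and run after the env var has been written. A minimal standalone sketch of that contract follows; the callback body and helper name here are illustrative stand-ins, while the real hook in this commit is wipeWorkspaceModelPreference on LLMProvider:

// Standalone sketch of the postUpdate callback contract used by updateENV;
// the callback body and applyEnvUpdate helper are illustrative, not repo code.
const EXAMPLE_KEY_MAPPING = {
  LLMProvider: {
    envKey: "LLM_PROVIDER",
    checks: [],
    postUpdate: [
      async (key, prev, next) => {
        if (prev === next) return; // nothing to clean up when the value is unchanged
        console.log(`${key} changed from ${prev} to ${next}; reset dependent state here`);
      },
    ],
  },
};

async function applyEnvUpdate(key, nextValue) {
  const { envKey, postUpdate = [] } = EXAMPLE_KEY_MAPPING[key];
  const prevValue = process.env[envKey];
  process.env[envKey] = nextValue;
  for (const postUpdateFunc of postUpdate)
    await postUpdateFunc(key, prevValue, nextValue);
}

applyEnvUpdate("LLMProvider", "ollama");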