diff --git a/docker/.env.example b/docker/.env.example index 6368a190..174a9d69 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -128,6 +128,12 @@ GID='1000' # VOYAGEAI_API_KEY= # EMBEDDING_MODEL_PREF='voyage-large-2-instruct' +# EMBEDDING_ENGINE='litellm' +# EMBEDDING_MODEL_PREF='text-embedding-ada-002' +# EMBEDDING_MODEL_MAX_CHUNK_LENGTH=8192 +# LITE_LLM_BASE_PATH='http://127.0.0.1:4000' +# LITE_LLM_API_KEY='sk-123abc' + ########################################### ######## Vector Database Selection ######## ########################################### diff --git a/frontend/src/components/EmbeddingSelection/LiteLLMOptions/index.jsx b/frontend/src/components/EmbeddingSelection/LiteLLMOptions/index.jsx new file mode 100644 index 00000000..d5586c88 --- /dev/null +++ b/frontend/src/components/EmbeddingSelection/LiteLLMOptions/index.jsx @@ -0,0 +1,186 @@ +import { useEffect, useState } from "react"; +import System from "@/models/system"; +import { Warning } from "@phosphor-icons/react"; +import { Tooltip } from "react-tooltip"; + +export default function LiteLLMOptions({ settings }) { + const [basePathValue, setBasePathValue] = useState(settings?.LiteLLMBasePath); + const [basePath, setBasePath] = useState(settings?.LiteLLMBasePath); + const [apiKeyValue, setApiKeyValue] = useState(settings?.LiteLLMAPIKey); + const [apiKey, setApiKey] = useState(settings?.LiteLLMAPIKey); + + return ( +
+
+
+ + setBasePathValue(e.target.value)} + onBlur={() => setBasePath(basePathValue)} + /> +
+ +
+ + e.target.blur()} + defaultValue={settings?.EmbeddingModelMaxChunkLength} + required={false} + autoComplete="off" + /> +
+
+
+
+
+ +
+ setApiKeyValue(e.target.value)} + onBlur={() => setApiKey(apiKeyValue)} + /> +
+
+
+ ); +} + +function LiteLLMModelSelection({ settings, basePath = null, apiKey = null }) { + const [customModels, setCustomModels] = useState([]); + const [loading, setLoading] = useState(true); + + useEffect(() => { + async function findCustomModels() { + if (!basePath) { + setCustomModels([]); + setLoading(false); + return; + } + setLoading(true); + const { models } = await System.customModels( + "litellm", + typeof apiKey === "boolean" ? null : apiKey, + basePath + ); + setCustomModels(models || []); + setLoading(false); + } + findCustomModels(); + }, [basePath, apiKey]); + + if (loading || customModels.length == 0) { + return ( +
+ + +
+ ); + } + + return ( +
+
+ + +
+ +
+ ); +} + +function EmbeddingModelTooltip() { + return ( +
+ + +

+ Be sure to select a valid embedding model. Chat models are not + embedding models. See{" "} + + this page + {" "} + for more information. +

+
+
+ ); +} diff --git a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx index 5a0f51c1..4d032dc0 100644 --- a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx +++ b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx @@ -11,6 +11,7 @@ import OllamaLogo from "@/media/llmprovider/ollama.png"; import LMStudioLogo from "@/media/llmprovider/lmstudio.png"; import CohereLogo from "@/media/llmprovider/cohere.png"; import VoyageAiLogo from "@/media/embeddingprovider/voyageai.png"; +import LiteLLMLogo from "@/media/llmprovider/litellm.png"; import PreLoader from "@/components/Preloader"; import ChangeWarningModal from "@/components/ChangeWarning"; @@ -22,6 +23,7 @@ import OllamaEmbeddingOptions from "@/components/EmbeddingSelection/OllamaOption import LMStudioEmbeddingOptions from "@/components/EmbeddingSelection/LMStudioOptions"; import CohereEmbeddingOptions from "@/components/EmbeddingSelection/CohereOptions"; import VoyageAiOptions from "@/components/EmbeddingSelection/VoyageAiOptions"; +import LiteLLMOptions from "@/components/EmbeddingSelection/LiteLLMOptions"; import EmbedderItem from "@/components/EmbeddingSelection/EmbedderItem"; import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react"; @@ -88,6 +90,13 @@ const EMBEDDERS = [ options: (settings) => , description: "Run powerful embedding models from Voyage AI.", }, + { + name: "LiteLLM", + value: "litellm", + logo: LiteLLMLogo, + options: (settings) => , + description: "Run powerful embedding models from LiteLLM.", + }, ]; export default function GeneralEmbeddingPreference() { diff --git a/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx b/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx index 35358636..b4fa666f 100644 --- a/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx +++ b/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx @@ -301,6 
/**
 * Embedding engine that sends text to a LiteLLM server through the
 * OpenAI-compatible client from the `openai` package.
 *
 * Configuration is read from the environment when an instance is created:
 *  - LITE_LLM_BASE_PATH   (required) base URL handed to the client.
 *  - LITE_LLM_API_KEY     (optional) API key; `null` is passed when unset.
 *  - EMBEDDING_MODEL_PREF (optional) model name, defaults to
 *    "text-embedding-ada-002".
 */
class LiteLLMEmbedder {
  /**
   * @throws {Error} When LITE_LLM_BASE_PATH is not set.
   */
  constructor() {
    const { OpenAI: OpenAIApi } = require("openai");
    if (!process.env.LITE_LLM_BASE_PATH)
      throw new Error(
        "LiteLLM must have a valid base path to use for the api."
      );
    this.basePath = process.env.LITE_LLM_BASE_PATH;
    this.openai = new OpenAIApi({
      baseURL: this.basePath,
      apiKey: process.env.LITE_LLM_API_KEY ?? null,
    });
    this.model = process.env.EMBEDDING_MODEL_PREF || "text-embedding-ada-002";

    // Limit of how many strings we can process in a single pass to stay within resource or network limits
    this.maxConcurrentChunks = 500;
    this.embeddingMaxChunkLength = maximumChunkLength();
  }

  /**
   * Embed a single input and return its vector.
   *
   * @param {string|string[]} textInput - One string, or an array of strings
   *   (only the first resulting vector is returned).
   * @returns {Promise<number[]>} The first embedding, or `[]` when nothing
   *   usable came back.
   * @throws {Error} Propagates the error raised by `embedChunks`.
   */
  async embedTextInput(textInput) {
    const result = await this.embedChunks(
      Array.isArray(textInput) ? textInput : [textInput]
    );
    return result?.[0] || [];
  }

  /**
   * Embed many strings, batching them into groups of `maxConcurrentChunks`
   * and running all batches concurrently.
   *
   * @param {string[]} textChunks - Strings to embed.
   * @returns {Promise<number[][]|null>} One vector per input, in input order,
   *   or `null` when no data was returned or any returned item lacks an
   *   `embedding` property.
   * @throws {Error} "LiteLLM Failed to embed: ..." listing each distinct
   *   failure when at least one batch failed.
   */
  async embedChunks(textChunks = []) {
    // Normalise a failure into `{ type, message }` without mutating the caught
    // error. The axios-style `response.*` fields are checked first (original
    // behaviour); the fields the OpenAI client puts directly on its errors
    // (`code`, `status`) are used as a fallback so failures are no longer all
    // reported as "failed_to_embed".
    const describeError = (e) => ({
      type:
        e?.response?.data?.error?.code ||
        e?.response?.status ||
        e?.code ||
        e?.status ||
        "failed_to_embed",
      message: e?.response?.data?.error?.message || e?.message || String(e),
    });

    // Each batch resolves to `{ data, error }` and never rejects, so one
    // failed batch cannot hide the errors of the others.
    const embedBatch = async (batch) => {
      try {
        const result = await this.openai.embeddings.create({
          model: this.model,
          input: batch,
        });
        return { data: result?.data, error: null };
      } catch (e) {
        return { data: [], error: describeError(e) };
      }
    };

    // Because there is a hard POST limit on how many chunks can be sent at once to LiteLLM (~8mb)
    // we concurrently execute each max batch of text chunks possible.
    // Refer to constructor maxConcurrentChunks for more info.
    const embeddingRequests = [];
    for (const batch of toChunks(textChunks, this.maxConcurrentChunks)) {
      embeddingRequests.push(embedBatch(batch));
    }
    const results = await Promise.all(embeddingRequests);

    // If any batch failed, abort the entire sequence because the embeddings
    // would be incomplete. Identical failures are reported once.
    const failures = results
      .filter((res) => !!res.error)
      .map((res) => res.error);
    if (failures.length > 0) {
      const uniqueErrors = new Set(
        failures.map(({ type, message }) => `[${type}]: ${message}`)
      );
      throw new Error(
        `LiteLLM Failed to embed: ${Array.from(uniqueErrors).join(", ")}`
      );
    }

    const data = results.flatMap((res) => res?.data || []);
    const hasEmbedding = (embd) =>
      Object.prototype.hasOwnProperty.call(embd ?? {}, "embedding");
    return data.length > 0 && data.every(hasEmbedding)
      ? data.map((embd) => embd.embedding)
      : null;
  }
}