From e72fa8b370212394f9d52d7dc629b568de6538c9 Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Fri, 21 Jun 2024 16:27:02 -0700 Subject: [PATCH] [FEAT] Generic OpenAI embedding provider (#1664) * implement generic openai embedding provider * linting * comment & description update for generic openai embedding provider * fix privacy for generic --------- Co-authored-by: timothycarambat --- docker/.env.example | 6 ++ .../GenericOpenAiOptions/index.jsx | 74 +++++++++++++++ .../EmbeddingPreference/index.jsx | 11 +++ .../Steps/DataHandling/index.jsx | 7 ++ server/.env.example | 6 ++ server/models/systemSettings.js | 2 + .../EmbeddingEngines/genericOpenAi/index.js | 95 +++++++++++++++++++ server/utils/helpers/index.js | 5 + server/utils/helpers/updateENV.js | 7 ++ 9 files changed, 213 insertions(+) create mode 100644 frontend/src/components/EmbeddingSelection/GenericOpenAiOptions/index.jsx create mode 100644 server/utils/EmbeddingEngines/genericOpenAi/index.js diff --git a/docker/.env.example b/docker/.env.example index f682f8bf..38b98088 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -136,6 +136,12 @@ GID='1000' # LITE_LLM_BASE_PATH='http://127.0.0.1:4000' # LITE_LLM_API_KEY='sk-123abc' +# EMBEDDING_ENGINE='generic-openai' +# EMBEDDING_MODEL_PREF='text-embedding-ada-002' +# EMBEDDING_MODEL_MAX_CHUNK_LENGTH=8192 +# EMBEDDING_BASE_PATH='http://127.0.0.1:4000' +# GENERIC_OPEN_AI_EMBEDDING_API_KEY='sk-123abc' + ########################################### ######## Vector Database Selection ######## ########################################### diff --git a/frontend/src/components/EmbeddingSelection/GenericOpenAiOptions/index.jsx b/frontend/src/components/EmbeddingSelection/GenericOpenAiOptions/index.jsx new file mode 100644 index 00000000..8d4870f0 --- /dev/null +++ b/frontend/src/components/EmbeddingSelection/GenericOpenAiOptions/index.jsx @@ -0,0 +1,74 @@ +export default function GenericOpenAiEmbeddingOptions({ settings }) { + return ( +
+
+
+ + +
+
+ + +
+
+ + e.target.blur()} + defaultValue={settings?.EmbeddingModelMaxChunkLength} + required={false} + autoComplete="off" + /> +
+
+
+
+
+ +
+ +
+
+
+ ); +} diff --git a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx index 2563aaad..ec8c2b8b 100644 --- a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx +++ b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx @@ -12,6 +12,7 @@ import LMStudioLogo from "@/media/llmprovider/lmstudio.png"; import CohereLogo from "@/media/llmprovider/cohere.png"; import VoyageAiLogo from "@/media/embeddingprovider/voyageai.png"; import LiteLLMLogo from "@/media/llmprovider/litellm.png"; +import GenericOpenAiLogo from "@/media/llmprovider/generic-openai.png"; import PreLoader from "@/components/Preloader"; import ChangeWarningModal from "@/components/ChangeWarning"; @@ -24,6 +25,7 @@ import LMStudioEmbeddingOptions from "@/components/EmbeddingSelection/LMStudioOp import CohereEmbeddingOptions from "@/components/EmbeddingSelection/CohereOptions"; import VoyageAiOptions from "@/components/EmbeddingSelection/VoyageAiOptions"; import LiteLLMOptions from "@/components/EmbeddingSelection/LiteLLMOptions"; +import GenericOpenAiEmbeddingOptions from "@/components/EmbeddingSelection/GenericOpenAiOptions"; import EmbedderItem from "@/components/EmbeddingSelection/EmbedderItem"; import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react"; @@ -98,6 +100,15 @@ const EMBEDDERS = [ options: (settings) => , description: "Run powerful embedding models from LiteLLM.", }, + { + name: "Generic OpenAI", + value: "generic-openai", + logo: GenericOpenAiLogo, + options: (settings) => ( + + ), + description: "Run embedding models from any OpenAI compatible API service.", + }, ]; export default function GeneralEmbeddingPreference() { diff --git a/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx b/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx index b4fa666f..1b3bf360 100644 --- a/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx +++ b/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx @@ -308,6 +308,13 @@ export const EMBEDDING_ENGINE_PRIVACY = { ], logo: LiteLLMLogo, }, + "generic-openai": { + name: "Generic OpenAI compatible service", + description: [ + "Data is shared according to the terms of service applicable with your generic endpoint provider.", + ], + logo: GenericOpenAiLogo, + }, }; export default function DataHandling({ setHeader, setForwardBtn, setBackBtn }) { diff --git a/server/.env.example b/server/.env.example index 145e00da..22bd557e 100644 --- a/server/.env.example +++ b/server/.env.example @@ -133,6 +133,12 @@ SIG_SALT='salt' # Please generate random string at least 32 chars long. # LITE_LLM_BASE_PATH='http://127.0.0.1:4000' # LITE_LLM_API_KEY='sk-123abc' +# EMBEDDING_ENGINE='generic-openai' +# EMBEDDING_MODEL_PREF='text-embedding-ada-002' +# EMBEDDING_MODEL_MAX_CHUNK_LENGTH=8192 +# EMBEDDING_BASE_PATH='http://127.0.0.1:4000' +# GENERIC_OPEN_AI_EMBEDDING_API_KEY='sk-123abc' + ########################################### ######## Vector Database Selection ######## ########################################### diff --git a/server/models/systemSettings.js b/server/models/systemSettings.js index eae75d9c..3f44f722 100644 --- a/server/models/systemSettings.js +++ b/server/models/systemSettings.js @@ -149,6 +149,8 @@ const SystemSettings = { EmbeddingModelPref: process.env.EMBEDDING_MODEL_PREF, EmbeddingModelMaxChunkLength: process.env.EMBEDDING_MODEL_MAX_CHUNK_LENGTH, + GenericOpenAiEmbeddingApiKey: + !!process.env.GENERIC_OPEN_AI_EMBEDDING_API_KEY, // -------------------------------------------------------- // VectorDB Provider Selection Settings & Configs diff --git a/server/utils/EmbeddingEngines/genericOpenAi/index.js b/server/utils/EmbeddingEngines/genericOpenAi/index.js new file mode 100644 index 00000000..d3ec3072 --- /dev/null +++ b/server/utils/EmbeddingEngines/genericOpenAi/index.js @@ -0,0 +1,95 @@ +const { toChunks } = require("../../helpers"); + +class GenericOpenAiEmbedder { + constructor() { + if (!process.env.EMBEDDING_BASE_PATH) + throw new Error( + "GenericOpenAI must have a valid base path to use for the api." + ); + const { OpenAI: OpenAIApi } = require("openai"); + this.basePath = process.env.EMBEDDING_BASE_PATH; + this.openai = new OpenAIApi({ + baseURL: this.basePath, + apiKey: process.env.GENERIC_OPEN_AI_EMBEDDING_API_KEY ?? null, + }); + this.model = process.env.EMBEDDING_MODEL_PREF ?? null; + + // Limit of how many strings we can process in a single pass to stay with resource or network limits + this.maxConcurrentChunks = 500; + + // Refer to your specific model and provider you use this class with to determine a valid maxChunkLength + this.embeddingMaxChunkLength = 8_191; + } + + async embedTextInput(textInput) { + const result = await this.embedChunks( + Array.isArray(textInput) ? textInput : [textInput] + ); + return result?.[0] || []; + } + + async embedChunks(textChunks = []) { + // Because there is a hard POST limit on how many chunks can be sent at once to OpenAI (~8mb) + // we concurrently execute each max batch of text chunks possible. + // Refer to constructor maxConcurrentChunks for more info. + const embeddingRequests = []; + for (const chunk of toChunks(textChunks, this.maxConcurrentChunks)) { + embeddingRequests.push( + new Promise((resolve) => { + this.openai.embeddings + .create({ + model: this.model, + input: chunk, + }) + .then((result) => { + resolve({ data: result?.data, error: null }); + }) + .catch((e) => { + e.type = + e?.response?.data?.error?.code || + e?.response?.status || + "failed_to_embed"; + e.message = e?.response?.data?.error?.message || e.message; + resolve({ data: [], error: e }); + }); + }) + ); + } + + const { data = [], error = null } = await Promise.all( + embeddingRequests + ).then((results) => { + // If any errors were returned from OpenAI abort the entire sequence because the embeddings + // will be incomplete. + const errors = results + .filter((res) => !!res.error) + .map((res) => res.error) + .flat(); + if (errors.length > 0) { + let uniqueErrors = new Set(); + errors.map((error) => + uniqueErrors.add(`[${error.type}]: ${error.message}`) + ); + + return { + data: [], + error: Array.from(uniqueErrors).join(", "), + }; + } + return { + data: results.map((res) => res?.data || []).flat(), + error: null, + }; + }); + + if (!!error) throw new Error(`GenericOpenAI Failed to embed: ${error}`); + return data.length > 0 && + data.every((embd) => embd.hasOwnProperty("embedding")) + ? data.map((embd) => embd.embedding) + : null; + } +} + +module.exports = { + GenericOpenAiEmbedder, +}; diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js index 8f0df126..302ec958 100644 --- a/server/utils/helpers/index.js +++ b/server/utils/helpers/index.js @@ -131,6 +131,11 @@ function getEmbeddingEngineSelection() { case "litellm": const { LiteLLMEmbedder } = require("../EmbeddingEngines/liteLLM"); return new LiteLLMEmbedder(); + case "generic-openai": + const { + GenericOpenAiEmbedder, + } = require("../EmbeddingEngines/genericOpenAi"); + return new GenericOpenAiEmbedder(); default: return new NativeEmbedder(); } diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index 6abd6408..c2cfc1aa 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -221,6 +221,12 @@ const KEY_MAPPING = { checks: [nonZero], }, + // Generic OpenAI Embedding Settings + GenericOpenAiEmbeddingApiKey: { + envKey: "GENERIC_OPEN_AI_EMBEDDING_API_KEY", + checks: [], + }, + // Vector Database Selection Settings VectorDB: { envKey: "VECTOR_DB", @@ -587,6 +593,7 @@ function supportedEmbeddingModel(input = "") { "cohere", "voyageai", "litellm", + "generic-openai", ]; return supported.includes(input) ? null