diff --git a/README.md b/README.md index 9ed7cc60..44e0557f 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ Some cool features of AnythingLLM - [OpenAI](https://openai.com) - [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service) - [Anthropic ClaudeV2](https://www.anthropic.com/) +- [Google Gemini Pro](https://ai.google.dev/) - [LM Studio (all models)](https://lmstudio.ai) - [LocalAi (all models)](https://localai.io/) diff --git a/docker/.env.example b/docker/.env.example index 8bbdd1dd..cc9fa06f 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -11,6 +11,10 @@ GID='1000' # OPEN_AI_KEY= # OPEN_MODEL_PREF='gpt-3.5-turbo' +# LLM_PROVIDER='gemini' +# GEMINI_API_KEY= +# GEMINI_LLM_MODEL_PREF='gemini-pro' + # LLM_PROVIDER='azure' # AZURE_OPENAI_ENDPOINT= # AZURE_OPENAI_KEY= diff --git a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx new file mode 100644 index 00000000..4d09e043 --- /dev/null +++ b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx @@ -0,0 +1,43 @@ +export default function GeminiLLMOptions({ settings }) { + return ( +
<div className="w-full flex flex-col">
+      <div className="w-full flex items-center gap-4">
+        <div className="flex flex-col w-60">
+          <label className="text-white text-sm font-semibold block mb-4">
+            Google AI API Key
+          </label>
+          <input
+            type="password"
+            name="GeminiLLMApiKey"
+            className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5"
+            placeholder="Google Gemini API Key"
+            defaultValue={settings?.GeminiLLMApiKey ? "*".repeat(20) : ""}
+            required={true}
+            autoComplete="off"
+            spellCheck={false}
+          />
+        </div>
+
+        <div className="flex flex-col w-60">
+          <label className="text-white text-sm font-semibold block mb-4">
+            Chat Model Selection
+          </label>
+          <select
+            name="GeminiLLMModelPref"
+            defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"}
+            required={true}
+            className="bg-zinc-900 border border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
+          >
+            {["gemini-pro"].map((model) => {
+              return (
+                <option key={model} value={model}>
+                  {model}
+                </option>
+              );
+            })}
+          </select>
+        </div>
+      </div>
+    </div>
+ ); +} diff --git a/frontend/src/media/llmprovider/gemini.png b/frontend/src/media/llmprovider/gemini.png new file mode 100644 index 00000000..aa81cfd8 Binary files /dev/null and b/frontend/src/media/llmprovider/gemini.png differ diff --git a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx index 1abf3a4b..d6224906 100644 --- a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx +++ b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx @@ -46,10 +46,10 @@ export default function GeneralEmbeddingPreference() { const { error } = await System.updateSystem(settingsData); if (error) { - showToast(`Failed to save LLM settings: ${error}`, "error"); + showToast(`Failed to save embedding settings: ${error}`, "error"); setHasChanges(true); } else { - showToast("LLM preferences saved successfully.", "success"); + showToast("Embedding preferences saved successfully.", "success"); setHasChanges(false); } setSaving(false); @@ -132,7 +132,7 @@ export default function GeneralEmbeddingPreference() {
Embedding Providers
-
+
-
+
{embeddingChoice === "native" && } {embeddingChoice === "openai" && ( diff --git a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx index 1c18d1ff..a0169fe1 100644 --- a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx +++ b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx @@ -7,6 +7,7 @@ import AnythingLLMIcon from "@/media/logo/anything-llm-icon.png"; import OpenAiLogo from "@/media/llmprovider/openai.png"; import AzureOpenAiLogo from "@/media/llmprovider/azure.png"; import AnthropicLogo from "@/media/llmprovider/anthropic.png"; +import GeminiLogo from "@/media/llmprovider/gemini.png"; import LMStudioLogo from "@/media/llmprovider/lmstudio.png"; import LocalAiLogo from "@/media/llmprovider/localai.png"; import PreLoader from "@/components/Preloader"; @@ -17,6 +18,7 @@ import AnthropicAiOptions from "@/components/LLMSelection/AnthropicAiOptions"; import LMStudioOptions from "@/components/LLMSelection/LMStudioOptions"; import LocalAiOptions from "@/components/LLMSelection/LocalAiOptions"; import NativeLLMOptions from "@/components/LLMSelection/NativeLLMOptions"; +import GeminiLLMOptions from "@/components/LLMSelection/GeminiLLMOptions"; export default function GeneralLLMPreference() { const [saving, setSaving] = useState(false); @@ -105,13 +107,13 @@ export default function GeneralLLMPreference() {
LLM Providers
-
+
+ )} + {llmChoice === "gemini" && ( + <GeminiLLMOptions settings={settings} /> + )} {llmChoice === "lmstudio" && ( )} diff --git a/frontend/src/pages/GeneralSettings/VectorDatabase/index.jsx index 1635fef8..2ddf1d5a 100644 --- a/frontend/src/pages/GeneralSettings/VectorDatabase/index.jsx +++ b/frontend/src/pages/GeneralSettings/VectorDatabase/index.jsx @@ -55,10 +55,10 @@ export default function GeneralVectorDatabase() { const { error } = await System.updateSystem(settingsData); if (error) { - showToast(`Failed to save LLM settings: ${error}`, "error"); + showToast(`Failed to save vector database settings: ${error}`, "error"); setHasChanges(true); } else { - showToast("LLM preferences saved successfully.", "success"); + showToast("Vector database preferences saved successfully.", "success"); setHasChanges(false); } setSaving(false); diff --git a/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/DataHandling/index.jsx index 98a1671c..cd63d74d 100644 --- a/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/DataHandling/index.jsx +++ b/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/DataHandling/index.jsx @@ -4,6 +4,7 @@ import AnythingLLMIcon from "@/media/logo/anything-llm-icon.png"; import OpenAiLogo from "@/media/llmprovider/openai.png"; import AzureOpenAiLogo from "@/media/llmprovider/azure.png"; import AnthropicLogo from "@/media/llmprovider/anthropic.png"; +import GeminiLogo from "@/media/llmprovider/gemini.png"; import LMStudioLogo from "@/media/llmprovider/lmstudio.png"; import LocalAiLogo from "@/media/llmprovider/localai.png"; import ChromaLogo from "@/media/vectordbs/chroma.png"; @@ -38,6 +39,14 @@ const LLM_SELECTION_PRIVACY = { ], logo: AnthropicLogo, }, + gemini: { + name: "Google Gemini", + description: [ + "Your chats are de-identified and used in training", + "Your prompts and document text used in response creation are visible to Google", + ], + logo: GeminiLogo, + }, lmstudio: { name: "LMStudio", description: [ diff --git a/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/EmbeddingSelection/index.jsx index 1f44c463..98e1262a 100644 --- a/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/EmbeddingSelection/index.jsx +++ b/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/EmbeddingSelection/index.jsx @@ -76,7 +76,7 @@ function EmbeddingSelection({ nextStep, prevStep, currentStep }) { name="OpenAI" value="openai" link="openai.com" - description="The standard option for most non-commercial use. Provides both chat and embedding." + description="The standard option for most non-commercial use." checked={embeddingChoice === "openai"} image={OpenAiLogo} onClick={updateChoice} /> @@ -85,7 +85,7 @@ function EmbeddingSelection({ nextStep, prevStep, currentStep }) { name="Azure OpenAI" value="azure" link="azure.microsoft.com" - description="The enterprise option of OpenAI hosted on Azure services. Provides both chat and embedding." + description="The enterprise option of OpenAI hosted on Azure services." 
checked={embeddingChoice === "azure"} image={AzureOpenAiLogo} onClick={updateChoice} diff --git a/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/LLMSelection/index.jsx index bb87486b..f877e31d 100644 --- a/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/LLMSelection/index.jsx +++ b/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/LLMSelection/index.jsx @@ -3,6 +3,7 @@ import AnythingLLMIcon from "@/media/logo/anything-llm-icon.png"; import OpenAiLogo from "@/media/llmprovider/openai.png"; import AzureOpenAiLogo from "@/media/llmprovider/azure.png"; import AnthropicLogo from "@/media/llmprovider/anthropic.png"; +import GeminiLogo from "@/media/llmprovider/gemini.png"; import LMStudioLogo from "@/media/llmprovider/lmstudio.png"; import LocalAiLogo from "@/media/llmprovider/localai.png"; import System from "@/models/system"; @@ -14,6 +15,7 @@ import AnthropicAiOptions from "@/components/LLMSelection/AnthropicAiOptions"; import LMStudioOptions from "@/components/LLMSelection/LMStudioOptions"; import LocalAiOptions from "@/components/LLMSelection/LocalAiOptions"; import NativeLLMOptions from "@/components/LLMSelection/NativeLLMOptions"; +import GeminiLLMOptions from "@/components/LLMSelection/GeminiLLMOptions"; function LLMSelection({ nextStep, prevStep, currentStep }) { const [llmChoice, setLLMChoice] = useState("openai"); @@ -71,7 +73,7 @@ function LLMSelection({ nextStep, prevStep, currentStep }) { name="OpenAI" value="openai" link="openai.com" - description="The standard option for most non-commercial use. Provides both chat and embedding." + description="The standard option for most non-commercial use." checked={llmChoice === "openai"} image={OpenAiLogo} onClick={updateLLMChoice} /> @@ -80,7 +82,7 @@ function LLMSelection({ nextStep, prevStep, currentStep }) { name="Azure OpenAI" value="azure" link="azure.microsoft.com" - description="The enterprise option of OpenAI hosted on Azure services. Provides both chat and embedding." + description="The enterprise option of OpenAI hosted on Azure services." checked={llmChoice === "azure"} image={AzureOpenAiLogo} onClick={updateLLMChoice} /> @@ -94,6 +96,15 @@ function LLMSelection({ nextStep, prevStep, currentStep }) { image={AnthropicLogo} onClick={updateLLMChoice} /> + )} + {llmChoice === "gemini" && <GeminiLLMOptions settings={settings} />} {llmChoice === "lmstudio" && ( )} diff --git a/server/.env.example index a4bc9fe5..f73e0e08 100644 --- a/server/.env.example +++ b/server/.env.example @@ -8,6 +8,10 @@ JWT_SECRET="my-random-string-for-seeding" # Please generate random string at lea # OPEN_AI_KEY= # OPEN_MODEL_PREF='gpt-3.5-turbo' +# LLM_PROVIDER='gemini' +# GEMINI_API_KEY= +# GEMINI_LLM_MODEL_PREF='gemini-pro' + # LLM_PROVIDER='azure' # AZURE_OPENAI_ENDPOINT= # AZURE_OPENAI_KEY= diff --git a/server/models/systemSettings.js index 068359bb..b5dfeb70 100644 --- a/server/models/systemSettings.js +++ b/server/models/systemSettings.js @@ -87,6 +87,20 @@ const SystemSettings = { } : {}), + ...(llmProvider === "gemini" + ? { + GeminiLLMApiKey: !!process.env.GEMINI_API_KEY, + GeminiLLMModelPref: + process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro", + + // For embedding credentials when Gemini is selected. 
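+ // Gemini ships no embedding engine in this integration (the GeminiLLM constructor throws without one), so these keys let the UI configure a separate embedder.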
+ OpenAiKey: !!process.env.OPEN_AI_KEY, + AzureOpenAiEndpoint: process.env.AZURE_OPENAI_ENDPOINT, + AzureOpenAiKey: !!process.env.AZURE_OPENAI_KEY, + AzureOpenAiEmbeddingModelPref: process.env.EMBEDDING_MODEL_PREF, + } + : {}), + ...(llmProvider === "lmstudio" ? { LMStudioBasePath: process.env.LMSTUDIO_BASE_PATH, diff --git a/server/package.json b/server/package.json index 1100adbc..4f84327a 100644 --- a/server/package.json +++ b/server/package.json @@ -22,6 +22,7 @@ "dependencies": { "@anthropic-ai/sdk": "^0.8.1", "@azure/openai": "^1.0.0-beta.3", + "@google/generative-ai": "^0.1.3", "@googleapis/youtube": "^9.0.0", "@pinecone-database/pinecone": "^0.1.6", "@prisma/client": "5.3.0", @@ -65,4 +66,4 @@ "nodemon": "^2.0.22", "prettier": "^2.4.1" } -} \ No newline at end of file +} diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js new file mode 100644 index 00000000..d0a76c55 --- /dev/null +++ b/server/utils/AiProviders/gemini/index.js @@ -0,0 +1,200 @@ +const { v4 } = require("uuid"); +const { chatPrompt } = require("../../chats"); + +class GeminiLLM { + constructor(embedder = null) { + if (!process.env.GEMINI_API_KEY) + throw new Error("No Gemini API key was set."); + + // Docs: https://ai.google.dev/tutorials/node_quickstart + const { GoogleGenerativeAI } = require("@google/generative-ai"); + const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY); + this.model = process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro"; + this.gemini = genAI.getGenerativeModel({ model: this.model }); + this.limits = { + history: this.promptWindowLimit() * 0.15, + system: this.promptWindowLimit() * 0.15, + user: this.promptWindowLimit() * 0.7, + }; + + if (!embedder) + throw new Error( + "INVALID GEMINI LLM SETUP. No embedding engine has been set. Go to instance settings and set up an embedding interface to use Gemini as your LLM." + ); + this.embedder = embedder; + this.answerKey = v4().split("-")[0]; + } + + streamingEnabled() { + return "streamChat" in this && "streamGetChatCompletion" in this; + } + + promptWindowLimit() { + switch (this.model) { + case "gemini-pro": + return 30_720; + default: + return 30_720; // assume a gemini-pro model + } + } + + isValidChatCompletionModel(modelName = "") { + const validModels = ["gemini-pro"]; + return validModels.includes(modelName); + } + + // Moderation cannot be done with Gemini. + // Not implemented so must be stubbed + async isSafe(_input = "") { + return { safe: true, reasons: [] }; + } + + constructPrompt({ + systemPrompt = "", + contextTexts = [], + chatHistory = [], + userPrompt = "", + }) { + const prompt = { + role: "system", + content: `${systemPrompt} +Context: + ${contextTexts + .map((text, i) => { + return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; + }) + .join("")}`, + }; + return [ + prompt, + { role: "assistant", content: "Okay." }, + ...chatHistory, + { role: "USER_PROMPT", content: userPrompt }, + ]; + } + + // This will take an OpenAi format message array and only pluck valid roles from it. + formatMessages(messages = []) { + // Gemini roles are either user || model. 
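+ // (the Gemini chat API has no "system" role, so a system message is sent as a "user" turn below)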
+ // All "content" is relabeled to "parts". + return messages + .map((message) => { + if (message.role === "system") + return { role: "user", parts: message.content }; + if (message.role === "user") + return { role: "user", parts: message.content }; + if (message.role === "assistant") + return { role: "model", parts: message.content }; + return null; + }) + .filter((msg) => !!msg); + } + + async sendChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) { + if (!this.isValidChatCompletionModel(this.model)) + throw new Error( + `Gemini chat: ${this.model} is not valid for chat completion!` + ); + + const compressedHistory = await this.compressMessages( + { + systemPrompt: chatPrompt(workspace), + chatHistory, + }, + rawHistory + ); + + const chatThread = this.gemini.startChat({ + history: this.formatMessages(compressedHistory), + }); + const result = await chatThread.sendMessage(prompt); + const response = result.response; + const responseText = response.text(); + + if (!responseText) throw new Error("Gemini: No response could be parsed."); + + return responseText; + } + + async getChatCompletion(messages = [], _opts = {}) { + if (!this.isValidChatCompletionModel(this.model)) + throw new Error( + `Gemini chat: ${this.model} is not valid for chat completion!` + ); + + // Pluck the USER_PROMPT pseudo-message set by constructPrompt; formatMessages drops it, so everything else becomes chat history. + const prompt = messages.find( + (chat) => chat.role === "USER_PROMPT" + )?.content; + const chatThread = this.gemini.startChat({ + history: this.formatMessages(messages), + }); + const result = await chatThread.sendMessage(prompt); + const response = result.response; + const responseText = response.text(); + + if (!responseText) throw new Error("Gemini: No response could be parsed."); + + return responseText; + } + + async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) { + if (!this.isValidChatCompletionModel(this.model)) + throw new Error( + `Gemini chat: ${this.model} is not valid for chat completion!` + ); + + const compressedHistory = await this.compressMessages( + { + systemPrompt: chatPrompt(workspace), + chatHistory, + }, + rawHistory + ); + + const chatThread = this.gemini.startChat({ + history: this.formatMessages(compressedHistory), + }); + const responseStream = await chatThread.sendMessageStream(prompt); + if (!responseStream.stream) + throw new Error("Could not stream response from Gemini."); + + return { type: "geminiStream", ...responseStream }; + } + + async streamGetChatCompletion(messages = [], _opts = {}) { + if (!this.isValidChatCompletionModel(this.model)) + throw new Error( + `Gemini chat: ${this.model} is not valid for chat completion!` + ); + + const prompt = messages.find( + (chat) => chat.role === "USER_PROMPT" + )?.content; + const chatThread = this.gemini.startChat({ + history: this.formatMessages(messages), + }); + const responseStream = await chatThread.sendMessageStream(prompt); + if (!responseStream.stream) + throw new Error("Could not stream response from Gemini."); + + return { type: "geminiStream", ...responseStream }; + } + + async compressMessages(promptArgs = {}, rawHistory = []) { + const { messageArrayCompressor } = require("../../helpers/chat"); + const messageArray = this.constructPrompt(promptArgs); + return await messageArrayCompressor(this, messageArray, rawHistory); + } + + // Simple wrappers around the dynamic embedder to normalize the interface across all LLM implementations + async embedTextInput(textInput) { + return await this.embedder.embedTextInput(textInput); + } + async embedChunks(textChunks = []) { + return await 
this.embedder.embedChunks(textChunks); + } +} + +module.exports = { + GeminiLLM, +}; diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js index 4eb9cf02..5bdb7a1f 100644 --- a/server/utils/chats/stream.js +++ b/server/utils/chats/stream.js @@ -202,6 +202,35 @@ async function streamEmptyEmbeddingChat({ function handleStreamResponses(response, stream, responseProps) { const { uuid = uuidv4(), sources = [] } = responseProps; + // Gemini likes to return a stream asyncIterator which will + // be a totally different object than other models. + if (stream?.type === "geminiStream") { + return new Promise(async (resolve) => { + let fullText = ""; + for await (const chunk of stream.stream) { + fullText += chunk.text(); + writeResponseChunk(response, { + uuid, + sources: [], + type: "textResponseChunk", + textResponse: chunk.text(), + close: false, + error: false, + }); + } + + writeResponseChunk(response, { + uuid, + sources, + type: "textResponseChunk", + textResponse: "", + close: true, + error: false, + }); + resolve(fullText); + }); + } + // If stream is not a regular OpenAI Stream (like if using native model) // we can just iterate the stream content instead. if (!stream.hasOwnProperty("data")) { diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js index 3b7f4ccc..115df400 100644 --- a/server/utils/helpers/index.js +++ b/server/utils/helpers/index.js @@ -34,6 +34,9 @@ function getLLMProvider() { case "anthropic": const { AnthropicLLM } = require("../AiProviders/anthropic"); return new AnthropicLLM(embedder); + case "gemini": + const { GeminiLLM } = require("../AiProviders/gemini"); + return new GeminiLLM(embedder); case "lmstudio": const { LMStudioLLM } = require("../AiProviders/lmStudio"); return new LMStudioLLM(embedder); diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index 3a8ea55d..fe4f4f5c 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -44,6 +44,15 @@ const KEY_MAPPING = { checks: [isNotEmpty, validAnthropicModel], }, + GeminiLLMApiKey: { + envKey: "GEMINI_API_KEY", + checks: [isNotEmpty], + }, + GeminiLLMModelPref: { + envKey: "GEMINI_LLM_MODEL_PREF", + checks: [isNotEmpty, validGeminiModel], + }, + // LMStudio Settings LMStudioBasePath: { envKey: "LMSTUDIO_BASE_PATH", @@ -204,12 +213,20 @@ function supportedLLM(input = "") { "openai", "azure", "anthropic", + "gemini", "lmstudio", "localai", "native", ].includes(input); } +function validGeminiModel(input = "") { + const validModels = ["gemini-pro"]; + return validModels.includes(input) + ? null + : `Invalid Model type. 
Must be one of ${validModels.join(", ")}.`; +} + function validAnthropicModel(input = "") { const validModels = ["claude-2", "claude-instant-1"]; return validModels.includes(input) diff --git a/server/yarn.lock b/server/yarn.lock index caffe137..f9a621f6 100644 --- a/server/yarn.lock +++ b/server/yarn.lock @@ -140,6 +140,11 @@ resolved "https://registry.yarnpkg.com/@gar/promisify/-/promisify-1.1.3.tgz#555193ab2e3bb3b6adc3d551c9c030d9e860daf6" integrity sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw== +"@google/generative-ai@^0.1.3": + version "0.1.3" + resolved "https://registry.yarnpkg.com/@google/generative-ai/-/generative-ai-0.1.3.tgz#8e529d4d86c85b64d297b4abf1a653d613a09a9f" + integrity sha512-Cm4uJX1sKarpm1mje/MiOIinM7zdUUrQp/5/qGPAgznbdd/B9zup5ehT6c1qGqycFcSopTA1J1HpqHS5kJR8hQ== + "@googleapis/youtube@^9.0.0": version "9.0.0" resolved "https://registry.yarnpkg.com/@googleapis/youtube/-/youtube-9.0.0.tgz#e45f6f5f7eac198c6391782b94b3ca54bacf0b63"