diff --git a/.vscode/settings.json b/.vscode/settings.json index 8405c5281..14efd3fae 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -53,6 +53,7 @@ "uuidv", "vectordbs", "Weaviate", + "XAILLM", "Zilliz" ], "eslint.experimental.useFlatConfig": true, diff --git a/docker/.env.example b/docker/.env.example index a6cabe655..7bb07ebef 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -109,6 +109,10 @@ GID='1000' # APIPIE_LLM_API_KEY='sk-123abc' # APIPIE_LLM_MODEL_PREF='openrouter/llama-3.1-8b-instruct' +# LLM_PROVIDER='xai' +# XAI_LLM_API_KEY='xai-your-api-key-here' +# XAI_LLM_MODEL_PREF='grok-beta' + ########################################### ######## Embedding API SElECTION ########## ########################################### diff --git a/frontend/src/components/LLMSelection/XAiLLMOptions/index.jsx b/frontend/src/components/LLMSelection/XAiLLMOptions/index.jsx new file mode 100644 index 000000000..d760a8ba4 --- /dev/null +++ b/frontend/src/components/LLMSelection/XAiLLMOptions/index.jsx @@ -0,0 +1,114 @@ +import { useState, useEffect } from "react"; +import System from "@/models/system"; + +export default function XAILLMOptions({ settings }) { + const [inputValue, setInputValue] = useState(settings?.XAIApiKey); + const [apiKey, setApiKey] = useState(settings?.XAIApiKey); + + return ( +
+
+    <div className="w-full flex items-center gap-[36px] mt-1.5">
+      <div className="flex flex-col w-60">
+        <label className="text-white text-sm font-semibold block mb-3">
+          xAI API Key
+        </label>
+        <input
+          type="password"
+          name="XAIApiKey"
+          className="bg-theme-settings-input-bg text-white text-sm rounded-lg outline-none block w-full p-2.5"
+          placeholder="xAI API Key"
+          defaultValue={settings?.XAIApiKey ? "*".repeat(20) : ""}
+          required={true}
+          autoComplete="off"
+          spellCheck={false}
+          onChange={(e) => setInputValue(e.target.value)}
+          onBlur={() => setApiKey(inputValue)}
+        />
+      </div>
+
+      {!settings?.credentialsOnly && (
+        <XAIModelSelection settings={settings} apiKey={apiKey} />
+      )}
+    </div>
+ ); +} + +function XAIModelSelection({ apiKey, settings }) { + const [customModels, setCustomModels] = useState([]); + const [loading, setLoading] = useState(true); + + useEffect(() => { + async function findCustomModels() { + if (!apiKey) { + setCustomModels([]); + setLoading(true); + return; + } + + try { + setLoading(true); + const { models } = await System.customModels("xai", apiKey); + setCustomModels(models || []); + } catch (error) { + console.error("Failed to fetch custom models:", error); + setCustomModels([]); + } finally { + setLoading(false); + } + } + findCustomModels(); + }, [apiKey]); + + if (loading) { + return ( +
+      <div className="flex flex-col w-60">
+        <label className="text-white text-sm font-semibold block mb-3">
+          Chat Model Selection
+        </label>
+        <select
+          name="XAIModelPref"
+          disabled={true}
+          className="bg-theme-settings-input-bg text-white text-sm rounded-lg block w-full p-2.5"
+        >
+          <option disabled={true} selected={true}>
+            -- loading available models --
+          </option>
+        </select>
+        <p className="text-xs leading-[18px] font-base text-white text-opacity-60 mt-2">
+          Enter a valid API key to view all available models for your account.
+        </p>
+      </div>
+ ); + } + + return ( +
+      <div className="flex flex-col w-60">
+        <label className="text-white text-sm font-semibold block mb-3">
+          Chat Model Selection
+        </label>
+        <select
+          name="XAIModelPref"
+          required={true}
+          className="bg-theme-settings-input-bg text-white text-sm rounded-lg block w-full p-2.5"
+        >
+          {customModels.length > 0 && (
+            <optgroup label="Available models">
+              {customModels.map((model) => (
+                <option
+                  key={model.id}
+                  value={model.id}
+                  selected={settings?.XAIModelPref === model.id}
+                >
+                  {model.id}
+                </option>
+              ))}
+            </optgroup>
+          )}
+        </select>
+        <p className="text-xs leading-[18px] font-base text-white text-opacity-60 mt-2">
+          Select the xAI model you want to use for your conversations.
+        </p>
+      </div>
+ ); +} diff --git a/frontend/src/hooks/useGetProvidersModels.js b/frontend/src/hooks/useGetProvidersModels.js index ece31c2b5..a493438c7 100644 --- a/frontend/src/hooks/useGetProvidersModels.js +++ b/frontend/src/hooks/useGetProvidersModels.js @@ -49,6 +49,7 @@ const PROVIDER_DEFAULT_MODELS = { textgenwebui: [], "generic-openai": [], bedrock: [], + xai: ["grok-beta"], }; // For providers with large model lists (e.g. togetherAi) - we subgroup the options diff --git a/frontend/src/media/llmprovider/xai.png b/frontend/src/media/llmprovider/xai.png new file mode 100644 index 000000000..93106761e Binary files /dev/null and b/frontend/src/media/llmprovider/xai.png differ diff --git a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx index d471dc358..e7b06e172 100644 --- a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx +++ b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx @@ -27,6 +27,7 @@ import LiteLLMLogo from "@/media/llmprovider/litellm.png"; import AWSBedrockLogo from "@/media/llmprovider/bedrock.png"; import DeepSeekLogo from "@/media/llmprovider/deepseek.png"; import APIPieLogo from "@/media/llmprovider/apipie.png"; +import XAILogo from "@/media/llmprovider/xai.png"; import PreLoader from "@/components/Preloader"; import OpenAiOptions from "@/components/LLMSelection/OpenAiOptions"; @@ -52,6 +53,7 @@ import LiteLLMOptions from "@/components/LLMSelection/LiteLLMOptions"; import AWSBedrockLLMOptions from "@/components/LLMSelection/AwsBedrockLLMOptions"; import DeepSeekOptions from "@/components/LLMSelection/DeepSeekOptions"; import ApiPieLLMOptions from "@/components/LLMSelection/ApiPieOptions"; +import XAILLMOptions from "@/components/LLMSelection/XAiLLMOptions"; import LLMItem from "@/components/LLMSelection/LLMItem"; import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react"; @@ -258,6 +260,15 @@ export const AVAILABLE_LLM_PROVIDERS = [ "GenericOpenAiKey", ], }, + { + name: "xAI", + value: "xai", + logo: XAILogo, + options: (settings) => , + description: "Run xAI's powerful LLMs like Grok-2 and more.", + requiredConfig: ["XAIApiKey", "XAIModelPref"], + }, + { name: "Native", value: "native", diff --git a/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx b/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx index e3b4e2ee8..33750cba2 100644 --- a/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx +++ b/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx @@ -22,6 +22,7 @@ import LiteLLMLogo from "@/media/llmprovider/litellm.png"; import AWSBedrockLogo from "@/media/llmprovider/bedrock.png"; import DeepSeekLogo from "@/media/llmprovider/deepseek.png"; import APIPieLogo from "@/media/llmprovider/apipie.png"; +import XAILogo from "@/media/llmprovider/xai.png"; import CohereLogo from "@/media/llmprovider/cohere.png"; import ZillizLogo from "@/media/vectordbs/zilliz.png"; @@ -210,6 +211,13 @@ export const LLM_SELECTION_PRIVACY = { ], logo: APIPieLogo, }, + xai: { + name: "xAI", + description: [ + "Your model and chat contents are visible to xAI in accordance with their terms of service.", + ], + logo: XAILogo, + }, }; export const VECTOR_DB_PRIVACY = { diff --git a/frontend/src/pages/OnboardingFlow/Steps/LLMPreference/index.jsx b/frontend/src/pages/OnboardingFlow/Steps/LLMPreference/index.jsx index 1b69369f5..cc17acfd3 100644 --- a/frontend/src/pages/OnboardingFlow/Steps/LLMPreference/index.jsx +++ 
b/frontend/src/pages/OnboardingFlow/Steps/LLMPreference/index.jsx @@ -22,6 +22,7 @@ import LiteLLMLogo from "@/media/llmprovider/litellm.png"; import AWSBedrockLogo from "@/media/llmprovider/bedrock.png"; import DeepSeekLogo from "@/media/llmprovider/deepseek.png"; import APIPieLogo from "@/media/llmprovider/apipie.png"; +import XAILogo from "@/media/llmprovider/xai.png"; import CohereLogo from "@/media/llmprovider/cohere.png"; import OpenAiOptions from "@/components/LLMSelection/OpenAiOptions"; @@ -47,6 +48,7 @@ import LiteLLMOptions from "@/components/LLMSelection/LiteLLMOptions"; import AWSBedrockLLMOptions from "@/components/LLMSelection/AwsBedrockLLMOptions"; import DeepSeekOptions from "@/components/LLMSelection/DeepSeekOptions"; import ApiPieLLMOptions from "@/components/LLMSelection/ApiPieOptions"; +import XAILLMOptions from "@/components/LLMSelection/XAiLLMOptions"; import LLMItem from "@/components/LLMSelection/LLMItem"; import System from "@/models/system"; @@ -219,6 +221,13 @@ const LLMS = [ options: (settings) => , description: "Run powerful foundation models privately with AWS Bedrock.", }, + { + name: "xAI", + value: "xai", + logo: XAILogo, + options: (settings) => , + description: "Run xAI's powerful LLMs like Grok-2 and more.", + }, { name: "Native", value: "native", diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx index d0b0b4893..c59a77e71 100644 --- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx @@ -26,6 +26,7 @@ const ENABLED_PROVIDERS = [ "deepseek", "litellm", "apipie", + "xai", // TODO: More agent support. // "cohere", // Has tool calling and will need to build explicit support // "huggingface" // Can be done but already has issues with no-chat templated. Needs to be tested. diff --git a/server/.env.example b/server/.env.example index f2d16b310..9c513f62f 100644 --- a/server/.env.example +++ b/server/.env.example @@ -99,6 +99,10 @@ SIG_SALT='salt' # Please generate random string at least 32 chars long. 
# APIPIE_LLM_API_KEY='sk-123abc' # APIPIE_LLM_MODEL_PREF='openrouter/llama-3.1-8b-instruct' +# LLM_PROVIDER='xai' +# XAI_LLM_API_KEY='xai-your-api-key-here' +# XAI_LLM_MODEL_PREF='grok-beta' + ########################################### ######## Embedding API SElECTION ########## ########################################### diff --git a/server/models/systemSettings.js b/server/models/systemSettings.js index e5de59376..55569be07 100644 --- a/server/models/systemSettings.js +++ b/server/models/systemSettings.js @@ -525,6 +525,10 @@ const SystemSettings = { // APIPie LLM API Keys ApipieLLMApiKey: !!process.env.APIPIE_LLM_API_KEY, ApipieLLMModelPref: process.env.APIPIE_LLM_MODEL_PREF, + + // xAI LLM API Keys + XAIApiKey: !!process.env.XAI_LLM_API_KEY, + XAIModelPref: process.env.XAI_LLM_MODEL_PREF, }; }, diff --git a/server/utils/AiProviders/modelMap.js b/server/utils/AiProviders/modelMap.js index 84e480b31..390278f37 100644 --- a/server/utils/AiProviders/modelMap.js +++ b/server/utils/AiProviders/modelMap.js @@ -61,6 +61,9 @@ const MODEL_MAP = { "deepseek-chat": 128_000, "deepseek-coder": 128_000, }, + xai: { + "grok-beta": 131_072, + }, }; module.exports = { MODEL_MAP }; diff --git a/server/utils/AiProviders/xai/index.js b/server/utils/AiProviders/xai/index.js new file mode 100644 index 000000000..7a25760df --- /dev/null +++ b/server/utils/AiProviders/xai/index.js @@ -0,0 +1,168 @@ +const { NativeEmbedder } = require("../../EmbeddingEngines/native"); +const { + handleDefaultStreamResponseV2, +} = require("../../helpers/chat/responses"); +const { MODEL_MAP } = require("../modelMap"); + +class XAiLLM { + constructor(embedder = null, modelPreference = null) { + if (!process.env.XAI_LLM_API_KEY) + throw new Error("No xAI API key was set."); + const { OpenAI: OpenAIApi } = require("openai"); + + this.openai = new OpenAIApi({ + baseURL: "https://api.x.ai/v1", + apiKey: process.env.XAI_LLM_API_KEY, + }); + this.model = + modelPreference || process.env.XAI_LLM_MODEL_PREF || "grok-beta"; + this.limits = { + history: this.promptWindowLimit() * 0.15, + system: this.promptWindowLimit() * 0.15, + user: this.promptWindowLimit() * 0.7, + }; + + this.embedder = embedder ?? new NativeEmbedder(); + this.defaultTemp = 0.7; + } + + #appendContext(contextTexts = []) { + if (!contextTexts || !contextTexts.length) return ""; + return ( + "\nContext:\n" + + contextTexts + .map((text, i) => { + return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; + }) + .join("") + ); + } + + streamingEnabled() { + return "streamGetChatCompletion" in this; + } + + static promptWindowLimit(modelName) { + return MODEL_MAP.xai[modelName] ?? 131_072; + } + + promptWindowLimit() { + return MODEL_MAP.xai[this.model] ?? 131_072; + } + + isValidChatCompletionModel(modelName = "") { + switch (modelName) { + case "grok-beta": + return true; + default: + return false; + } + } + + /** + * Generates appropriate content array for a message + attachments. + * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}} + * @returns {string|object[]} + */ + #generateContent({ userPrompt, attachments = [] }) { + if (!attachments.length) { + return userPrompt; + } + + const content = [{ type: "text", text: userPrompt }]; + for (let attachment of attachments) { + content.push({ + type: "image_url", + image_url: { + url: attachment.contentString, + detail: "high", + }, + }); + } + return content.flat(); + } + + /** + * Construct the user prompt for this model. 
+ * @param {{attachments: import("../../helpers").Attachment[]}} param0 + * @returns + */ + constructPrompt({ + systemPrompt = "", + contextTexts = [], + chatHistory = [], + userPrompt = "", + attachments = [], // This is the specific attachment for only this prompt + }) { + const prompt = { + role: "system", + content: `${systemPrompt}${this.#appendContext(contextTexts)}`, + }; + return [ + prompt, + ...chatHistory, + { + role: "user", + content: this.#generateContent({ userPrompt, attachments }), + }, + ]; + } + + async getChatCompletion(messages = null, { temperature = 0.7 }) { + if (!this.isValidChatCompletionModel(this.model)) + throw new Error( + `xAI chat: ${this.model} is not valid for chat completion!` + ); + + const result = await this.openai.chat.completions + .create({ + model: this.model, + messages, + temperature, + }) + .catch((e) => { + throw new Error(e.message); + }); + + if (!result.hasOwnProperty("choices") || result.choices.length === 0) + return null; + return result.choices[0].message.content; + } + + async streamGetChatCompletion(messages = null, { temperature = 0.7 }) { + if (!this.isValidChatCompletionModel(this.model)) + throw new Error( + `xAI chat: ${this.model} is not valid for chat completion!` + ); + + const streamRequest = await this.openai.chat.completions.create({ + model: this.model, + stream: true, + messages, + temperature, + }); + return streamRequest; + } + + handleStream(response, stream, responseProps) { + return handleDefaultStreamResponseV2(response, stream, responseProps); + } + + // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations + async embedTextInput(textInput) { + return await this.embedder.embedTextInput(textInput); + } + async embedChunks(textChunks = []) { + return await this.embedder.embedChunks(textChunks); + } + + async compressMessages(promptArgs = {}, rawHistory = []) { + const { messageArrayCompressor } = require("../../helpers/chat"); + const messageArray = this.constructPrompt(promptArgs); + return await messageArrayCompressor(this, messageArray, rawHistory); + } +} + +module.exports = { + XAiLLM, +}; diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js index 51dc57553..24f027cff 100644 --- a/server/utils/agents/aibitat/index.js +++ b/server/utils/agents/aibitat/index.js @@ -789,6 +789,8 @@ ${this.getHistory({ to: route.to }) return new Providers.LiteLLMProvider({ model: config.model }); case "apipie": return new Providers.ApiPieProvider({ model: config.model }); + case "xai": + return new Providers.XAIProvider({ model: config.model }); default: throw new Error( diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js index afaefa1c9..c9925d1cd 100644 --- a/server/utils/agents/aibitat/providers/ai-provider.js +++ b/server/utils/agents/aibitat/providers/ai-provider.js @@ -146,6 +146,14 @@ class Provider { apiKey: process.env.DEEPSEEK_API_KEY ?? null, ...config, }); + case "xai": + return new ChatOpenAI({ + configuration: { + baseURL: "https://api.x.ai/v1", + }, + apiKey: process.env.XAI_LLM_API_KEY ?? 
null, + ...config, + }); // OSS Model Runners // case "anythingllm_ollama": diff --git a/server/utils/agents/aibitat/providers/index.js b/server/utils/agents/aibitat/providers/index.js index f5ae66420..47e2d8716 100644 --- a/server/utils/agents/aibitat/providers/index.js +++ b/server/utils/agents/aibitat/providers/index.js @@ -17,6 +17,7 @@ const FireworksAIProvider = require("./fireworksai.js"); const DeepSeekProvider = require("./deepseek.js"); const LiteLLMProvider = require("./litellm.js"); const ApiPieProvider = require("./apipie.js"); +const XAIProvider = require("./xai.js"); module.exports = { OpenAIProvider, @@ -38,4 +39,5 @@ module.exports = { FireworksAIProvider, LiteLLMProvider, ApiPieProvider, + XAIProvider, }; diff --git a/server/utils/agents/aibitat/providers/xai.js b/server/utils/agents/aibitat/providers/xai.js new file mode 100644 index 000000000..9461d865f --- /dev/null +++ b/server/utils/agents/aibitat/providers/xai.js @@ -0,0 +1,116 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The agent provider for the xAI provider. + */ +class XAIProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + const { model = "grok-beta" } = config; + super(); + const client = new OpenAI({ + baseURL: "https://api.x.ai/v1", + apiKey: process.env.XAI_LLM_API_KEY, + maxRetries: 3, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("xAI chat: No results!"); + if (result.choices.length === 0) + throw new Error("xAI chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent + // from calling the exact same function over and over in a loop within a single chat exchange + // _but_ we should enable it to call previously used tools in a new chat interaction. 
+ this.deduplicator.reset("runs"); + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = XAIProvider; diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js index 98caea5cd..fd7d06e8b 100644 --- a/server/utils/agents/index.js +++ b/server/utils/agents/index.js @@ -169,6 +169,10 @@ class AgentHandler { if (!process.env.APIPIE_LLM_API_KEY) throw new Error("ApiPie API Key must be provided to use agents."); break; + case "xai": + if (!process.env.XAI_LLM_API_KEY) + throw new Error("xAI API Key must be provided to use agents."); + break; default: throw new Error( @@ -228,6 +232,8 @@ class AgentHandler { return process.env.LITE_LLM_MODEL_PREF ?? null; case "apipie": return process.env.APIPIE_LLM_MODEL_PREF ?? null; + case "xai": + return process.env.XAI_LLM_MODEL_PREF ?? "grok-beta"; default: return null; } diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js index 086144bfe..7ccbf13c7 100644 --- a/server/utils/helpers/customModels.js +++ b/server/utils/helpers/customModels.js @@ -21,6 +21,7 @@ const SUPPORT_CUSTOM_MODELS = [ "groq", "deepseek", "apipie", + "xai", ]; async function getCustomModels(provider = "", apiKey = null, basePath = null) { @@ -60,6 +61,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) { return await getDeepSeekModels(apiKey); case "apipie": return await getAPIPieModels(apiKey); + case "xai": + return await getXAIModels(apiKey); default: return { models: [], error: "Invalid provider for custom models" }; } @@ -466,6 +469,36 @@ async function getDeepSeekModels(apiKey = null) { return { models, error: null }; } +async function getXAIModels(_apiKey = null) { + const { OpenAI: OpenAIApi } = require("openai"); + const apiKey = + _apiKey === true + ? process.env.XAI_LLM_API_KEY + : _apiKey || process.env.XAI_LLM_API_KEY || null; + const openai = new OpenAIApi({ + baseURL: "https://api.x.ai/v1", + apiKey, + }); + const models = await openai.models + .list() + .then((results) => results.data) + .catch((e) => { + console.error(`XAI:listModels`, e.message); + return [ + { + created: 1725148800, + id: "grok-beta", + object: "model", + owned_by: "xai", + }, + ]; + }); + + // Api Key was successful so lets save it for future uses + if (models.length > 0 && !!apiKey) process.env.XAI_LLM_API_KEY = apiKey; + return { models, error: null }; +} + module.exports = { getCustomModels, }; diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js index f3f19fb9d..84f971cc6 100644 --- a/server/utils/helpers/index.js +++ b/server/utils/helpers/index.js @@ -165,6 +165,9 @@ function getLLMProvider({ provider = null, model = null } = {}) { case "apipie": const { ApiPieLLM } = require("../AiProviders/apipie"); return new ApiPieLLM(embedder, model); + case "xai": + const { XAiLLM } = require("../AiProviders/xai"); + return new XAiLLM(embedder, model); default: throw new Error( `ENV: No valid LLM_PROVIDER value found in environment! 
Using ${process.env.LLM_PROVIDER}` @@ -294,6 +297,9 @@ function getLLMProviderClass({ provider = null } = {}) { case "apipie": const { ApiPieLLM } = require("../AiProviders/apipie"); return ApiPieLLM; + case "xai": + const { XAiLLM } = require("../AiProviders/xai"); + return XAiLLM; default: return null; } diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index 202ffcd99..d705fb730 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -539,6 +539,16 @@ const KEY_MAPPING = { envKey: "APIPIE_LLM_MODEL_PREF", checks: [isNotEmpty], }, + + // xAI Options + XAIApiKey: { + envKey: "XAI_LLM_API_KEY", + checks: [isNotEmpty], + }, + XAIModelPref: { + envKey: "XAI_LLM_MODEL_PREF", + checks: [isNotEmpty], + }, }; function isNotEmpty(input = "") { @@ -643,6 +653,7 @@ function supportedLLM(input = "") { "bedrock", "deepseek", "apipie", + "xai", ].includes(input); return validSelection ? null : `${input} is not a valid LLM provider.`; }
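
Every server-side touch point in this change goes through xAI's OpenAI-compatible endpoint (https://api.x.ai/v1), so a key and the grok-beta default can be sanity-checked outside the app before configuring it in the UI. The script below is a minimal sketch, not part of the change set: it assumes XAI_LLM_API_KEY (and optionally XAI_LLM_MODEL_PREF) are exported and that the `openai` npm package already used by the server is installed.

// check-xai.js — smoke test for the provider wiring introduced above.
// Assumes XAI_LLM_API_KEY is set; model falls back to the diff's default, grok-beta.
const { OpenAI } = require("openai");

async function main() {
  const client = new OpenAI({
    baseURL: "https://api.x.ai/v1", // same base URL used by XAiLLM and XAIProvider
    apiKey: process.env.XAI_LLM_API_KEY,
  });

  // Mirrors getXAIModels() in server/utils/helpers/customModels.js
  const { data: models } = await client.models.list();
  console.log("Available models:", models.map((m) => m.id));

  // Mirrors XAiLLM.getChatCompletion() with the class's default temperature of 0.7
  const result = await client.chat.completions.create({
    model: process.env.XAI_LLM_MODEL_PREF || "grok-beta",
    temperature: 0.7,
    messages: [{ role: "user", content: "Reply with a single short sentence." }],
  });
  console.log(result.choices[0].message.content);
}

main().catch((e) => {
  console.error("xAI smoke test failed:", e.message);
  process.exit(1);
});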