From 8422f9254278f6430a350b808b877d42e6061ec0 Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Wed, 8 May 2024 15:17:54 -0700 Subject: [PATCH 01/11] Agent support for LLMs with no function calling (#1295) * add LMStudio agent support (generic) support "work" with non-tool callable LLMs, highly dependent on system specs * add comments * enable few-shot prompting per function for OSS models * Add Agent support for Ollama models * azure, groq, koboldcpp agent support complete + WIP togetherai * WIP gemini agent support * WIP gemini blocked and will not fix for now * azure fix * merge fix * add localai agent support * azure untooled agent support * merge fix * refactor implementation of several agent provideers * update bad merge comment --------- Co-authored-by: timothycarambat --- .vscode/settings.json | 1 + .../AgentConfig/AgentLLMSelection/index.jsx | 21 +++- server/utils/agents/aibitat/index.js | 14 ++- .../agents/aibitat/providers/ai-provider.js | 3 + .../utils/agents/aibitat/providers/azure.js | 105 ++++++++++++++++ server/utils/agents/aibitat/providers/groq.js | 110 +++++++++++++++++ .../aibitat/providers/helpers/untooled.js | 3 +- .../utils/agents/aibitat/providers/index.js | 10 ++ .../agents/aibitat/providers/koboldcpp.js | 113 +++++++++++++++++ .../agents/aibitat/providers/lmstudio.js | 2 +- .../utils/agents/aibitat/providers/localai.js | 114 ++++++++++++++++++ .../agents/aibitat/providers/togetherai.js | 113 +++++++++++++++++ server/utils/agents/index.js | 42 +++++++ server/utils/helpers/customModels.js | 2 +- 14 files changed, 645 insertions(+), 8 deletions(-) create mode 100644 server/utils/agents/aibitat/providers/azure.js create mode 100644 server/utils/agents/aibitat/providers/groq.js create mode 100644 server/utils/agents/aibitat/providers/koboldcpp.js create mode 100644 server/utils/agents/aibitat/providers/localai.js create mode 100644 server/utils/agents/aibitat/providers/togetherai.js diff --git a/.vscode/settings.json b/.vscode/settings.json index f850bbb00..eecaa83fd 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -28,6 +28,7 @@ "openrouter", "Qdrant", "Serper", + "togetherai", "vectordbs", "Weaviate", "Zilliz" diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx index fcb12d94d..400eef02d 100644 --- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx @@ -5,8 +5,25 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference"; import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react"; import AgentModelSelection from "../AgentModelSelection"; -const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"]; -const WARN_PERFORMANCE = ["lmstudio", "ollama"]; +const ENABLED_PROVIDERS = [ + "openai", + "anthropic", + "lmstudio", + "ollama", + "localai", + "groq", + "azure", + "koboldcpp", + "togetherai", +]; +const WARN_PERFORMANCE = [ + "lmstudio", + "groq", + "azure", + "koboldcpp", + "ollama", + "localai", +]; const LLM_DEFAULT = { name: "Please make a selection", diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js index 9cf2170b7..3413bd359 100644 --- a/server/utils/agents/aibitat/index.js +++ b/server/utils/agents/aibitat/index.js @@ -480,7 +480,7 @@ Read the following conversation. 
CHAT HISTORY ${history.map((c) => `@${c.from}: ${c.content}`).join("\n")} -Then select the next role from that is going to speak next. +Then select the next role from that is going to speak next. Only return the role. `, }, @@ -522,7 +522,7 @@ Only return the role. ? [ { role: "user", - content: `You are in a whatsapp group. Read the following conversation and then reply. + content: `You are in a whatsapp group. Read the following conversation and then reply. Do not add introduction or conclusion to your reply because this will be a continuous conversation. Don't introduce yourself. CHAT HISTORY @@ -743,6 +743,16 @@ ${this.getHistory({ to: route.to }) return new Providers.LMStudioProvider({}); case "ollama": return new Providers.OllamaProvider({ model: config.model }); + case "groq": + return new Providers.GroqProvider({ model: config.model }); + case "togetherai": + return new Providers.TogetherAIProvider({ model: config.model }); + case "azure": + return new Providers.AzureOpenAiProvider({ model: config.model }); + case "koboldcpp": + return new Providers.KoboldCPPProvider({}); + case "localai": + return new Providers.LocalAIProvider({ model: config.model }); default: throw new Error( diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js index 0e871b36e..91a81ebfa 100644 --- a/server/utils/agents/aibitat/providers/ai-provider.js +++ b/server/utils/agents/aibitat/providers/ai-provider.js @@ -58,6 +58,9 @@ class Provider { } } + // For some providers we may want to override the system prompt to be more verbose. + // Currently we only do this for lmstudio, but we probably will want to expand this even more + // to any Untooled LLM. static systemPrompt(provider = null) { switch (provider) { case "lmstudio": diff --git a/server/utils/agents/aibitat/providers/azure.js b/server/utils/agents/aibitat/providers/azure.js new file mode 100644 index 000000000..cdcf7618b --- /dev/null +++ b/server/utils/agents/aibitat/providers/azure.js @@ -0,0 +1,105 @@ +const { OpenAIClient, AzureKeyCredential } = require("@azure/openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the Azure OpenAI API. + */ +class AzureOpenAiProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(_config = {}) { + super(); + const client = new OpenAIClient( + process.env.AZURE_OPENAI_ENDPOINT, + new AzureKeyCredential(process.env.AZURE_OPENAI_KEY) + ); + this._client = client; + this.model = process.env.OPEN_MODEL_PREF ?? "gpt-3.5-turbo"; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client + .getChatCompletions(this.model, messages, { + temperature: 0, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("Azure OpenAI chat: No results!"); + if (result.choices.length === 0) + throw new Error("Azure OpenAI chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. 
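+   * When functions are supplied, a tool call is first parsed from the
+   * response via the UnTooled helper; otherwise this falls back to a
+   * plain chat completion.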
+ */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.getChatCompletions( + this.model, + this.cleanMsgs(messages), + { + temperature: 0.7, + } + ); + completion = response.choices[0].message; + } + return { result: completion.content, cost: 0 }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * Stubbed since Azure OpenAI has no public cost basis. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = AzureOpenAiProvider; diff --git a/server/utils/agents/aibitat/providers/groq.js b/server/utils/agents/aibitat/providers/groq.js new file mode 100644 index 000000000..3b87ba510 --- /dev/null +++ b/server/utils/agents/aibitat/providers/groq.js @@ -0,0 +1,110 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const { RetryError } = require("../error.js"); + +/** + * The provider for the Groq provider. + */ +class GroqProvider extends Provider { + model; + + constructor(config = {}) { + const { model = "llama3-8b-8192" } = config; + const client = new OpenAI({ + baseURL: "https://api.groq.com/openai/v1", + apiKey: process.env.GROQ_API_KEY, + maxRetries: 3, + }); + super(client); + this.model = model; + this.verbose = true; + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + const response = await this.client.chat.completions.create({ + model: this.model, + // stream: true, + messages, + ...(Array.isArray(functions) && functions?.length > 0 + ? { functions } + : {}), + }); + + // Right now, we only support one completion, + // so we just take the first one in the list + const completion = response.choices[0].message; + const cost = this.getCost(response.usage); + // treat function calls + if (completion.function_call) { + let functionArgs = {}; + try { + functionArgs = JSON.parse(completion.function_call.arguments); + } catch (error) { + // call the complete function again in case it gets a json error + return this.complete( + [ + ...messages, + { + role: "function", + name: completion.function_call.name, + function_call: completion.function_call, + content: error?.message, + }, + ], + functions + ); + } + + // console.log(completion, { functionArgs }) + return { + result: null, + functionCall: { + name: completion.function_call.name, + arguments: functionArgs, + }, + cost, + }; + } + + return { + result: completion.content, + cost, + }; + } catch (error) { + // If invalid Auth error we need to abort because no amount of waiting + // will make auth better. 
+ if (error instanceof OpenAI.AuthenticationError) throw error; + + if ( + error instanceof OpenAI.RateLimitError || + error instanceof OpenAI.InternalServerError || + error instanceof OpenAI.APIError // Also will catch AuthenticationError!!! + ) { + throw new RetryError(error.message); + } + + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since Groq has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = GroqProvider; diff --git a/server/utils/agents/aibitat/providers/helpers/untooled.js b/server/utils/agents/aibitat/providers/helpers/untooled.js index 37ecb5599..11fbfec8b 100644 --- a/server/utils/agents/aibitat/providers/helpers/untooled.js +++ b/server/utils/agents/aibitat/providers/helpers/untooled.js @@ -110,7 +110,7 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`; const response = await chatCb({ messages: [ { - content: `You are a program which picks the most optimal function and parameters to call. + content: `You are a program which picks the most optimal function and parameters to call. DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY. When a function is selection, respond in JSON with no additional text. When there is no relevant function to call - return with a regular chat text response. @@ -130,7 +130,6 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`; ...history, ], }); - const call = safeJsonParse(response, null); if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text. diff --git a/server/utils/agents/aibitat/providers/index.js b/server/utils/agents/aibitat/providers/index.js index fda8b5136..6f8a2da0b 100644 --- a/server/utils/agents/aibitat/providers/index.js +++ b/server/utils/agents/aibitat/providers/index.js @@ -2,10 +2,20 @@ const OpenAIProvider = require("./openai.js"); const AnthropicProvider = require("./anthropic.js"); const LMStudioProvider = require("./lmstudio.js"); const OllamaProvider = require("./ollama.js"); +const GroqProvider = require("./groq.js"); +const TogetherAIProvider = require("./togetherai.js"); +const AzureOpenAiProvider = require("./azure.js"); +const KoboldCPPProvider = require("./koboldcpp.js"); +const LocalAIProvider = require("./localai.js"); module.exports = { OpenAIProvider, AnthropicProvider, LMStudioProvider, OllamaProvider, + GroqProvider, + TogetherAIProvider, + AzureOpenAiProvider, + KoboldCPPProvider, + LocalAIProvider, }; diff --git a/server/utils/agents/aibitat/providers/koboldcpp.js b/server/utils/agents/aibitat/providers/koboldcpp.js new file mode 100644 index 000000000..77088263c --- /dev/null +++ b/server/utils/agents/aibitat/providers/koboldcpp.js @@ -0,0 +1,113 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the KoboldCPP provider. + */ +class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(_config = {}) { + super(); + const model = process.env.KOBOLD_CPP_MODEL_PREF ?? 
null; + const client = new OpenAI({ + baseURL: process.env.KOBOLD_CPP_BASE_PATH?.replace(/\/+$/, ""), + apiKey: null, + maxRetries: 3, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("KoboldCPP chat: No results!"); + if (result.choices.length === 0) + throw new Error("KoboldCPP chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since KoboldCPP has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = KoboldCPPProvider; diff --git a/server/utils/agents/aibitat/providers/lmstudio.js b/server/utils/agents/aibitat/providers/lmstudio.js index d3aa4346a..f5c4a2e82 100644 --- a/server/utils/agents/aibitat/providers/lmstudio.js +++ b/server/utils/agents/aibitat/providers/lmstudio.js @@ -16,8 +16,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) { baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance apiKey: null, maxRetries: 3, - model, }); + this._client = client; this.model = model; this.verbose = true; diff --git a/server/utils/agents/aibitat/providers/localai.js b/server/utils/agents/aibitat/providers/localai.js new file mode 100644 index 000000000..161172c21 --- /dev/null +++ b/server/utils/agents/aibitat/providers/localai.js @@ -0,0 +1,114 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the LocalAI provider. + */ +class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + const { model = null } = config; + super(); + const client = new OpenAI({ + baseURL: process.env.LOCAL_AI_BASE_PATH, + apiKey: process.env.LOCAL_AI_API_KEY ?? 
null, + maxRetries: 3, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("LocalAI chat: No results!"); + + if (result.choices.length === 0) + throw new Error("LocalAI chat: No results length!"); + + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { result: completion.content, cost: 0 }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since LocalAI has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = LocalAiProvider; diff --git a/server/utils/agents/aibitat/providers/togetherai.js b/server/utils/agents/aibitat/providers/togetherai.js new file mode 100644 index 000000000..4ea5e11c2 --- /dev/null +++ b/server/utils/agents/aibitat/providers/togetherai.js @@ -0,0 +1,113 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the TogetherAI provider. + */ +class TogetherAIProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + const { model = "mistralai/Mistral-7B-Instruct-v0.1" } = config; + super(); + const client = new OpenAI({ + baseURL: "https://api.together.xyz/v1", + apiKey: process.env.TOGETHER_AI_API_KEY, + maxRetries: 3, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("LMStudio chat: No results!"); + if (result.choices.length === 0) + throw new Error("LMStudio chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. 
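+   * Tool calling is emulated through the UnTooled JSON-prompting helper
+   * rather than native function calling.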
+ * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since LMStudio has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = TogetherAIProvider; diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js index e18b8b7bb..768ad8199 100644 --- a/server/utils/agents/index.js +++ b/server/utils/agents/index.js @@ -85,6 +85,36 @@ class AgentHandler { if (!process.env.OLLAMA_BASE_PATH) throw new Error("Ollama base path must be provided to use agents."); break; + case "groq": + if (!process.env.GROQ_API_KEY) + throw new Error("Groq API key must be provided to use agents."); + break; + case "togetherai": + if (!process.env.TOGETHER_AI_API_KEY) + throw new Error("TogetherAI API key must be provided to use agents."); + break; + case "azure": + if (!process.env.AZURE_OPENAI_ENDPOINT || !process.env.AZURE_OPENAI_KEY) + throw new Error( + "Azure OpenAI API endpoint and key must be provided to use agents." + ); + break; + case "koboldcpp": + if (!process.env.KOBOLD_CPP_BASE_PATH) + throw new Error( + "KoboldCPP must have a valid base path to use for the api." + ); + break; + case "localai": + if (!process.env.LOCAL_AI_BASE_PATH) + throw new Error( + "LocalAI must have a valid base path to use for the api." 
+ ); + break; + case "gemini": + if (!process.env.GEMINI_API_KEY) + throw new Error("Gemini API key must be provided to use agents."); + break; default: throw new Error("No provider found to power agent cluster."); } @@ -100,6 +130,18 @@ class AgentHandler { return "server-default"; case "ollama": return "llama3:latest"; + case "groq": + return "llama3-70b-8192"; + case "togetherai": + return "mistralai/Mixtral-8x7B-Instruct-v0.1"; + case "azure": + return "gpt-3.5-turbo"; + case "koboldcpp": + return null; + case "gemini": + return "gemini-pro"; + case "localai": + return null; default: return "unknown"; } diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js index ce690ae47..3743ffad7 100644 --- a/server/utils/helpers/customModels.js +++ b/server/utils/helpers/customModels.js @@ -178,7 +178,7 @@ async function getKoboldCPPModels(basePath = null) { try { const { OpenAI: OpenAIApi } = require("openai"); const openai = new OpenAIApi({ - baseURL: basePath || process.env.LMSTUDIO_BASE_PATH, + baseURL: basePath || process.env.KOBOLD_CPP_BASE_PATH, apiKey: null, }); const models = await openai.models From 81bc16cc39b84c250acb8d5c3431762db9af8387 Mon Sep 17 00:00:00 2001 From: Timothy Carambat Date: Wed, 8 May 2024 16:04:18 -0700 Subject: [PATCH 02/11] More agent providers (#1316) * add OpenRouter support * add mistral agents add perplexity agents add textwebgenui agents --- .vscode/settings.json | 5 + .../AgentConfig/AgentLLMSelection/index.jsx | 12 ++ server/utils/agents/aibitat/index.js | 10 ++ .../agents/aibitat/providers/genericOpenAi.js | 115 +++++++++++++++++ server/utils/agents/aibitat/providers/groq.js | 3 + .../utils/agents/aibitat/providers/index.js | 10 ++ .../utils/agents/aibitat/providers/mistral.js | 116 +++++++++++++++++ .../agents/aibitat/providers/openrouter.js | 117 ++++++++++++++++++ .../agents/aibitat/providers/perplexity.js | 112 +++++++++++++++++ .../agents/aibitat/providers/textgenwebui.js | 112 +++++++++++++++++ server/utils/agents/index.js | 33 +++++ 11 files changed, 645 insertions(+) create mode 100644 server/utils/agents/aibitat/providers/genericOpenAi.js create mode 100644 server/utils/agents/aibitat/providers/mistral.js create mode 100644 server/utils/agents/aibitat/providers/openrouter.js create mode 100644 server/utils/agents/aibitat/providers/perplexity.js create mode 100644 server/utils/agents/aibitat/providers/textgenwebui.js diff --git a/.vscode/settings.json b/.vscode/settings.json index eecaa83fd..110c4fa6e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -15,19 +15,24 @@ "epub", "GROQ", "hljs", + "huggingface", "inferencing", + "koboldcpp", "Langchain", "lmstudio", + "localai", "mbox", "Milvus", "Mintplex", "moderations", "Ollama", + "Oobabooga", "openai", "opendocument", "openrouter", "Qdrant", "Serper", + "textgenwebui", "togetherai", "vectordbs", "Weaviate", diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx index 400eef02d..ef260dec1 100644 --- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx @@ -15,6 +15,16 @@ const ENABLED_PROVIDERS = [ "azure", "koboldcpp", "togetherai", + "openrouter", + "mistral", + "perplexity", + "textgenwebui", + // TODO: More agent support. + // "generic-openai", // Need to support text-input for agent model input for this to be enabled. 
+ // "cohere", // Has tool calling and will need to build explicit support + // "huggingface" // Can be done but already has issues with no-chat templated. Needs to be tested. + // "gemini", // Too rate limited and broken in several ways to use for agents. + // "gemini", // Too rate limited and broken in several ways to use for agents. ]; const WARN_PERFORMANCE = [ "lmstudio", @@ -23,6 +33,8 @@ const WARN_PERFORMANCE = [ "koboldcpp", "ollama", "localai", + "openrouter", + "generic-openai", ]; const LLM_DEFAULT = { diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js index 3413bd359..f21c4aa45 100644 --- a/server/utils/agents/aibitat/index.js +++ b/server/utils/agents/aibitat/index.js @@ -753,6 +753,16 @@ ${this.getHistory({ to: route.to }) return new Providers.KoboldCPPProvider({}); case "localai": return new Providers.LocalAIProvider({ model: config.model }); + case "openrouter": + return new Providers.OpenRouterProvider({ model: config.model }); + case "mistral": + return new Providers.MistralProvider({ model: config.model }); + case "generic-openai": + return new Providers.GenericOpenAiProvider({ model: config.model }); + case "perplexity": + return new Providers.PerplexityProvider({ model: config.model }); + case "textgenwebui": + return new Providers.TextWebGenUiProvider({}); default: throw new Error( diff --git a/server/utils/agents/aibitat/providers/genericOpenAi.js b/server/utils/agents/aibitat/providers/genericOpenAi.js new file mode 100644 index 000000000..3521bc7d0 --- /dev/null +++ b/server/utils/agents/aibitat/providers/genericOpenAi.js @@ -0,0 +1,115 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the Generic OpenAI provider. + * Since we cannot promise the generic provider even supports tool calling + * which is nearly 100% likely it does not, we can just wrap it in untooled + * which often is far better anyway. + */ +class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + super(); + const { model = "gpt-3.5-turbo" } = config; + const client = new OpenAI({ + baseURL: process.env.GENERIC_OPEN_AI_BASE_PATH, + apiKey: process.env.GENERIC_OPEN_AI_API_KEY ?? null, + maxRetries: 3, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("Generic OpenAI chat: No results!"); + if (result.choices.length === 0) + throw new Error("Generic OpenAI chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. 
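+   * Functions, when present, are routed through the UnTooled helper since
+   * native tool support cannot be assumed for a generic endpoint.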
+ */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = GenericOpenAiProvider; diff --git a/server/utils/agents/aibitat/providers/groq.js b/server/utils/agents/aibitat/providers/groq.js index 3b87ba510..720d87437 100644 --- a/server/utils/agents/aibitat/providers/groq.js +++ b/server/utils/agents/aibitat/providers/groq.js @@ -4,6 +4,9 @@ const { RetryError } = require("../error.js"); /** * The provider for the Groq provider. + * Using OpenAI tool calling with groq really sucks right now + * its just fast and bad. We should probably migrate this to Untooled to improve + * coherence. */ class GroqProvider extends Provider { model; diff --git a/server/utils/agents/aibitat/providers/index.js b/server/utils/agents/aibitat/providers/index.js index 6f8a2da0b..14748b2ec 100644 --- a/server/utils/agents/aibitat/providers/index.js +++ b/server/utils/agents/aibitat/providers/index.js @@ -7,6 +7,11 @@ const TogetherAIProvider = require("./togetherai.js"); const AzureOpenAiProvider = require("./azure.js"); const KoboldCPPProvider = require("./koboldcpp.js"); const LocalAIProvider = require("./localai.js"); +const OpenRouterProvider = require("./openrouter.js"); +const MistralProvider = require("./mistral.js"); +const GenericOpenAiProvider = require("./genericOpenAi.js"); +const PerplexityProvider = require("./perplexity.js"); +const TextWebGenUiProvider = require("./textgenwebui.js"); module.exports = { OpenAIProvider, @@ -18,4 +23,9 @@ module.exports = { AzureOpenAiProvider, KoboldCPPProvider, LocalAIProvider, + OpenRouterProvider, + MistralProvider, + GenericOpenAiProvider, + PerplexityProvider, + TextWebGenUiProvider, }; diff --git a/server/utils/agents/aibitat/providers/mistral.js b/server/utils/agents/aibitat/providers/mistral.js new file mode 100644 index 000000000..cdc2a5e75 --- /dev/null +++ b/server/utils/agents/aibitat/providers/mistral.js @@ -0,0 +1,116 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the Mistral provider. + * Mistral limits what models can call tools and even when using those + * the model names change and dont match docs. 
When you do have the right model + * it still fails and is not truly OpenAI compatible so its easier to just wrap + * this with Untooled which 100% works since its just text & works far more reliably + */ +class MistralProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + super(); + const { model = "mistral-medium" } = config; + const client = new OpenAI({ + baseURL: "https://api.mistral.ai/v1", + apiKey: process.env.MISTRAL_API_KEY, + maxRetries: 3, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("LMStudio chat: No results!"); + if (result.choices.length === 0) + throw new Error("LMStudio chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = MistralProvider; diff --git a/server/utils/agents/aibitat/providers/openrouter.js b/server/utils/agents/aibitat/providers/openrouter.js new file mode 100644 index 000000000..81297ae28 --- /dev/null +++ b/server/utils/agents/aibitat/providers/openrouter.js @@ -0,0 +1,117 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the OpenRouter provider. 
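+ * Wrapped in UnTooled so tool calling keeps working regardless of which
+ * underlying model OpenRouter routes the request to.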
+ */ +class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + const { model = "openrouter/auto" } = config; + super(); + const client = new OpenAI({ + baseURL: "https://openrouter.ai/api/v1", + apiKey: process.env.OPENROUTER_API_KEY, + maxRetries: 3, + defaultHeaders: { + "HTTP-Referer": "https://useanything.com", + "X-Title": "AnythingLLM", + }, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("OpenRouter chat: No results!"); + if (result.choices.length === 0) + throw new Error("OpenRouter chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since OpenRouter has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = OpenRouterProvider; diff --git a/server/utils/agents/aibitat/providers/perplexity.js b/server/utils/agents/aibitat/providers/perplexity.js new file mode 100644 index 000000000..29970fd06 --- /dev/null +++ b/server/utils/agents/aibitat/providers/perplexity.js @@ -0,0 +1,112 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the Perplexity provider. + */ +class PerplexityProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + super(); + const { model = "sonar-small-online" } = config; + const client = new OpenAI({ + baseURL: "https://api.perplexity.ai", + apiKey: process.env.PERPLEXITY_API_KEY ?? 
null, + maxRetries: 3, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("Perplexity chat: No results!"); + if (result.choices.length === 0) + throw new Error("Perplexity chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = PerplexityProvider; diff --git a/server/utils/agents/aibitat/providers/textgenwebui.js b/server/utils/agents/aibitat/providers/textgenwebui.js new file mode 100644 index 000000000..767577d42 --- /dev/null +++ b/server/utils/agents/aibitat/providers/textgenwebui.js @@ -0,0 +1,112 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the Oobabooga provider. + */ +class TextWebGenUiProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(_config = {}) { + super(); + const client = new OpenAI({ + baseURL: process.env.TEXT_GEN_WEB_UI_BASE_PATH, + apiKey: null, + maxRetries: 3, + }); + + this._client = client; + this.model = null; // text-web-gen-ui does not have a model pref. + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("Oobabooga chat: No results!"); + if (result.choices.length === 0) + throw new Error("Oobabooga chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. 
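+   * Like the other UnTooled providers, a tool call is parsed from the
+   * plain-text JSON response before falling back to a normal chat reply.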
+ */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since KoboldCPP has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = TextWebGenUiProvider; diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js index 768ad8199..7851deb7b 100644 --- a/server/utils/agents/index.js +++ b/server/utils/agents/index.js @@ -115,6 +115,29 @@ class AgentHandler { if (!process.env.GEMINI_API_KEY) throw new Error("Gemini API key must be provided to use agents."); break; + case "openrouter": + if (!process.env.OPENROUTER_API_KEY) + throw new Error("OpenRouter API key must be provided to use agents."); + break; + case "mistral": + if (!process.env.MISTRAL_API_KEY) + throw new Error("Mistral API key must be provided to use agents."); + break; + case "generic-openai": + if (!process.env.GENERIC_OPEN_AI_BASE_PATH) + throw new Error("API base path must be provided to use agents."); + break; + case "perplexity": + if (!process.env.PERPLEXITY_API_KEY) + throw new Error("Perplexity API key must be provided to use agents."); + break; + case "textgenwebui": + if (!process.env.TEXT_GEN_WEB_UI_BASE_PATH) + throw new Error( + "TextWebGenUI API base path must be provided to use agents." 
+ ); + break; + default: throw new Error("No provider found to power agent cluster."); } @@ -142,6 +165,16 @@ class AgentHandler { return "gemini-pro"; case "localai": return null; + case "openrouter": + return "openrouter/auto"; + case "mistral": + return "mistral-medium"; + case "generic-openai": + return "gpt-3.5-turbo"; + case "perplexity": + return "sonar-small-online"; + case "textgenwebui": + return null; default: return "unknown"; } From 2da0d39b5112defeee732e858f5f5fcfc25834c4 Mon Sep 17 00:00:00 2001 From: timothycarambat Date: Wed, 8 May 2024 16:04:45 -0700 Subject: [PATCH 03/11] update todo comment --- .../WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx index ef260dec1..51c115817 100644 --- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx @@ -24,7 +24,6 @@ const ENABLED_PROVIDERS = [ // "cohere", // Has tool calling and will need to build explicit support // "huggingface" // Can be done but already has issues with no-chat templated. Needs to be tested. // "gemini", // Too rate limited and broken in several ways to use for agents. - // "gemini", // Too rate limited and broken in several ways to use for agents. ]; const WARN_PERFORMANCE = [ "lmstudio", @@ -35,6 +34,7 @@ const WARN_PERFORMANCE = [ "localai", "openrouter", "generic-openai", + "textgenwebui", ]; const LLM_DEFAULT = { From 0f981abd40bb3d4904bbf996bb22273f0f92c865 Mon Sep 17 00:00:00 2001 From: timothycarambat Date: Wed, 8 May 2024 16:23:13 -0700 Subject: [PATCH 04/11] remove unused import --- frontend/src/components/Modals/Password/MultiUserAuth.jsx | 2 -- 1 file changed, 2 deletions(-) diff --git a/frontend/src/components/Modals/Password/MultiUserAuth.jsx b/frontend/src/components/Modals/Password/MultiUserAuth.jsx index a44e040c4..e4de5e67e 100644 --- a/frontend/src/components/Modals/Password/MultiUserAuth.jsx +++ b/frontend/src/components/Modals/Password/MultiUserAuth.jsx @@ -1,7 +1,6 @@ import React, { useEffect, useState } from "react"; import System from "../../../models/system"; import { AUTH_TOKEN, AUTH_USER } from "../../../utils/constants"; -import useLogo from "../../../hooks/useLogo"; import paths from "../../../utils/paths"; import showToast from "@/utils/toast"; import ModalWrapper from "@/components/ModalWrapper"; @@ -163,7 +162,6 @@ const ResetPasswordForm = ({ onSubmit }) => { export default function MultiUserAuth() { const [loading, setLoading] = useState(false); const [error, setError] = useState(null); - const { logo: _initLogo } = useLogo(); const [recoveryCodes, setRecoveryCodes] = useState([]); const [downloadComplete, setDownloadComplete] = useState(false); const [user, setUser] = useState(null); From 6eefd0d2806ee4ceec4aae98f4361eb1b2a30cbe Mon Sep 17 00:00:00 2001 From: timothycarambat Date: Fri, 10 May 2024 09:48:03 -0700 Subject: [PATCH 05/11] update STORAGE_DIR for baremetal.md resolves #1340 --- BARE_METAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/BARE_METAL.md b/BARE_METAL.md index 8d495d81b..220ef0f89 100644 --- a/BARE_METAL.md +++ b/BARE_METAL.md @@ -27,7 +27,7 @@ Here you can find the scripts and known working process to run AnythingLLM outsi 4. Ensure that the `server/.env` file has _at least_ these keys to start. 
These values will persist and this file will be automatically written and managed after your first successful boot. ``` -STORAGE_DIR="/your/absolute/path/to/server/.env" +STORAGE_DIR="/your/absolute/path/to/server/storage" ``` 5. Edit the `frontend/.env` file for the `VITE_BASE_API` to now be set to `/api`. This is documented in the .env for which one you should use. From d36c3ff8b2fcb16edcd0992a75a289ba44f1cd77 Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Fri, 10 May 2024 12:35:33 -0700 Subject: [PATCH 06/11] [FEAT] Slash templates (#1314) * WIP slash presets * WIP slash command customization CRUD + validations complete * backend slash command support * fix permission setting on new slash commands rework form submit and pattern on frontend * Add field updates for hooks, required=true to field add user<>command constraint to keep them unique enforce uniquness via teritary uid field on table for multi and non-multi user * reset migration --------- Co-authored-by: timothycarambat --- .../SlashPresets/AddPresetModal.jsx | 111 +++++++++++++ .../SlashPresets/EditPresetModal.jsx | 148 ++++++++++++++++++ .../SlashCommands/SlashPresets/index.jsx | 127 +++++++++++++++ .../PromptInput/SlashCommands/index.jsx | 4 +- frontend/src/models/system.js | 68 ++++++++ server/endpoints/system.js | 106 +++++++++++++ server/models/slashCommandsPresets.js | 105 +++++++++++++ .../20240510032311_init/migration.sql | 15 ++ server/prisma/schema.prisma | 15 ++ server/utils/chats/index.js | 26 ++- server/utils/chats/stream.js | 8 +- 11 files changed, 722 insertions(+), 11 deletions(-) create mode 100644 frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/AddPresetModal.jsx create mode 100644 frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/EditPresetModal.jsx create mode 100644 frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/index.jsx create mode 100644 server/models/slashCommandsPresets.js create mode 100644 server/prisma/migrations/20240510032311_init/migration.sql diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/AddPresetModal.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/AddPresetModal.jsx new file mode 100644 index 000000000..e5154580b --- /dev/null +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/AddPresetModal.jsx @@ -0,0 +1,111 @@ +import { useState } from "react"; +import { X } from "@phosphor-icons/react"; +import ModalWrapper from "@/components/ModalWrapper"; +import { CMD_REGEX } from "."; + +export default function AddPresetModal({ isOpen, onClose, onSave }) { + const [command, setCommand] = useState(""); + + const handleSubmit = async (e) => { + e.preventDefault(); + const form = new FormData(e.target); + const sanitizedCommand = command.replace(CMD_REGEX, ""); + const saved = await onSave({ + command: `/${sanitizedCommand}`, + prompt: form.get("prompt"), + description: form.get("description"), + }); + if (saved) setCommand(""); + }; + + const handleCommandChange = (e) => { + const value = e.target.value.replace(CMD_REGEX, ""); + setCommand(value); + }; + + return ( + +
+    <ModalWrapper isOpen={isOpen}>
+      <div className="relative w-full max-w-2xl max-h-full">
+        <div className="relative bg-main-gradient rounded-lg shadow">
+          <div className="flex items-start justify-between p-4 border-b rounded-t border-gray-500/50">
+            <h3 className="text-xl font-semibold text-white">
+              Add New Preset
+            </h3>
+            <button
+              onClick={onClose}
+              type="button"
+              className="text-gray-400 bg-transparent rounded-lg text-sm p-1.5 ml-auto inline-flex items-center"
+            >
+              <X className="text-gray-300 text-lg" />
+            </button>
+          </div>
+          <form onSubmit={handleSubmit}>
+            <div className="p-6 space-y-6 flex h-full w-full">
+              <div className="w-full flex flex-col gap-y-4">
+                <div>
+                  <label className="block mb-2 text-sm font-medium text-white">
+                    Command
+                  </label>
+                  <div className="flex items-center">
+                    <span className="text-white text-sm mr-2 font-bold">/</span>
+                    <input
+                      type="text"
+                      name="command"
+                      placeholder="your-command"
+                      value={command}
+                      onChange={handleCommandChange}
+                      required={true}
+                      className="bg-zinc-900 text-white text-sm rounded-lg block w-full p-2.5"
+                    />
+                  </div>
+                </div>
+                <div>
+                  <label className="block mb-2 text-sm font-medium text-white">
+                    Prompt
+                  </label>
+                  <textarea
+                    name="prompt"
+                    rows={5}
+                    required={true}
+                    className="bg-zinc-900 text-white text-sm rounded-lg block w-full p-2.5"
+                  />
+                </div>
+                <div>
+                  <label className="block mb-2 text-sm font-medium text-white">
+                    Description
+                  </label>
+                  <input
+                    type="text"
+                    name="description"
+                    required={true}
+                    className="bg-zinc-900 text-white text-sm rounded-lg block w-full p-2.5"
+                  />
+                </div>
+              </div>
+            </div>
+            <div className="flex w-full justify-between items-center p-3 space-x-2 border-t rounded-b border-gray-500/50">
+              <button
+                onClick={onClose}
+                type="button"
+                className="px-4 py-2 rounded-lg text-white hover:bg-stone-900 transition-all duration-300"
+              >
+                Cancel
+              </button>
+              <button
+                type="submit"
+                className="transition-all duration-300 border border-slate-200 px-4 py-2 rounded-lg text-white text-sm hover:bg-slate-200 hover:text-slate-800"
+              >
+                Save
+              </button>
+            </div>
+          </form>
+        </div>
+      </div>
+    </ModalWrapper>
+ ); +} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/EditPresetModal.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/EditPresetModal.jsx new file mode 100644 index 000000000..fdffbe609 --- /dev/null +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/EditPresetModal.jsx @@ -0,0 +1,148 @@ +import { useState } from "react"; +import { X } from "@phosphor-icons/react"; +import ModalWrapper from "@/components/ModalWrapper"; +import { CMD_REGEX } from "."; + +export default function EditPresetModal({ + isOpen, + onClose, + onSave, + onDelete, + preset, +}) { + const [command, setCommand] = useState(preset?.command?.slice(1) || ""); + const [deleting, setDeleting] = useState(false); + + const handleSubmit = (e) => { + e.preventDefault(); + const form = new FormData(e.target); + const sanitizedCommand = command.replace(CMD_REGEX, ""); + onSave({ + id: preset.id, + command: `/${sanitizedCommand}`, + prompt: form.get("prompt"), + description: form.get("description"), + }); + }; + + const handleCommandChange = (e) => { + const value = e.target.value.replace(CMD_REGEX, ""); + setCommand(value); + }; + + const handleDelete = async () => { + const confirmDelete = window.confirm( + "Are you sure you want to delete this preset?" + ); + if (!confirmDelete) return; + + setDeleting(true); + await onDelete(preset.id); + setDeleting(false); + onClose(); + }; + + return ( + +
+    <ModalWrapper isOpen={isOpen}>
+      <div className="relative w-full max-w-2xl max-h-full">
+        <div className="relative bg-main-gradient rounded-lg shadow">
+          <div className="flex items-start justify-between p-4 border-b rounded-t border-gray-500/50">
+            <h3 className="text-xl font-semibold text-white">Edit Preset</h3>
+            <button
+              onClick={onClose}
+              type="button"
+              className="text-gray-400 bg-transparent rounded-lg text-sm p-1.5 ml-auto inline-flex items-center"
+            >
+              <X className="text-gray-300 text-lg" />
+            </button>
+          </div>
+          <form onSubmit={handleSubmit}>
+            <div className="p-6 space-y-6 flex h-full w-full">
+              <div className="w-full flex flex-col gap-y-4">
+                <div>
+                  <label className="block mb-2 text-sm font-medium text-white">
+                    Command
+                  </label>
+                  <div className="flex items-center">
+                    <span className="text-white text-sm mr-2 font-bold">/</span>
+                    <input
+                      type="text"
+                      name="command"
+                      value={command}
+                      onChange={handleCommandChange}
+                      required={true}
+                      className="bg-zinc-900 text-white text-sm rounded-lg block w-full p-2.5"
+                    />
+                  </div>
+                </div>
+                <div>
+                  <label className="block mb-2 text-sm font-medium text-white">
+                    Prompt
+                  </label>
+                  <textarea
+                    name="prompt"
+                    rows={5}
+                    defaultValue={preset.prompt}
+                    required={true}
+                    className="bg-zinc-900 text-white text-sm rounded-lg block w-full p-2.5"
+                  />
+                </div>
+                <div>
+                  <label className="block mb-2 text-sm font-medium text-white">
+                    Description
+                  </label>
+                  <input
+                    type="text"
+                    name="description"
+                    defaultValue={preset.description}
+                    required={true}
+                    className="bg-zinc-900 text-white text-sm rounded-lg block w-full p-2.5"
+                  />
+                </div>
+              </div>
+            </div>
+            <div className="flex w-full justify-between items-center p-3 space-x-2 border-t rounded-b border-gray-500/50">
+              <button
+                type="button"
+                onClick={handleDelete}
+                disabled={deleting}
+                className="px-4 py-2 rounded-lg text-red-500 hover:bg-red-500/25 transition-all duration-300"
+              >
+                {deleting ? "Deleting..." : "Delete Preset"}
+              </button>
+              <div className="flex space-x-2">
+                <button
+                  onClick={onClose}
+                  type="button"
+                  className="px-4 py-2 rounded-lg text-white hover:bg-stone-900 transition-all duration-300"
+                >
+                  Cancel
+                </button>
+                <button
+                  type="submit"
+                  className="transition-all duration-300 border border-slate-200 px-4 py-2 rounded-lg text-white text-sm hover:bg-slate-200 hover:text-slate-800"
+                >
+                  Save
+                </button>
+              </div>
+            </div>
+          </form>
+        </div>
+      </div>
+    </ModalWrapper>
+  );
+}
diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/index.jsx
new file mode 100644
index 000000000..ca39b68a8
--- /dev/null
+++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/index.jsx
@@ -0,0 +1,127 @@
+import { useEffect, useState } from "react";
+import { useIsAgentSessionActive } from "@/utils/chat/agent";
+import AddPresetModal from "./AddPresetModal";
+import EditPresetModal from "./EditPresetModal";
+import { useModal } from "@/hooks/useModal";
+import System from "@/models/system";
+import { DotsThree, Plus } from "@phosphor-icons/react";
+import showToast from "@/utils/toast";
+
+export const CMD_REGEX = new RegExp(/[^a-zA-Z0-9_-]/g);
+export default function SlashPresets({ setShowing, sendCommand }) {
+  const isActiveAgentSession = useIsAgentSessionActive();
+  const {
+    isOpen: isAddModalOpen,
+    openModal: openAddModal,
+    closeModal: closeAddModal,
+  } = useModal();
+  const {
+    isOpen: isEditModalOpen,
+    openModal: openEditModal,
+    closeModal: closeEditModal,
+  } = useModal();
+  const [presets, setPresets] = useState([]);
+  const [selectedPreset, setSelectedPreset] = useState(null);
+
+  useEffect(() => {
+    fetchPresets();
+  }, []);
+  if (isActiveAgentSession) return null;
+
+  const fetchPresets = async () => {
+    const presets = await System.getSlashCommandPresets();
+    setPresets(presets);
+  };
+
+  const handleSavePreset = async (preset) => {
+    const { error } = await System.createSlashCommandPreset(preset);
+    if (!!error) {
+      showToast(error, "error");
+      return false;
+    }
+
+    fetchPresets();
+    closeAddModal();
+    return true;
+  };
+
+  const handleEditPreset = (preset) => {
+    setSelectedPreset(preset);
+    openEditModal();
+  };
+
+  const handleUpdatePreset = async (updatedPreset) => {
+    const { error } = await System.updateSlashCommandPreset(
+      updatedPreset.id,
+      updatedPreset
+    );
+
+    if (!!error) {
+      showToast(error, "error");
+      return;
+    }
+
+    fetchPresets();
+    closeEditModal();
+  };
+
+  const handleDeletePreset = async (presetId) => {
+    await System.deleteSlashCommandPreset(presetId);
+    fetchPresets();
+    closeEditModal();
+  };
+
+  return (
+    <>
+      {presets.map((preset) => (
+        <button
+          key={preset.id}
+          onClick={() => {
+            setShowing(false);
+            sendCommand(`${preset.command} `, false);
+          }}
+          className="w-full hover:cursor-pointer hover:bg-zinc-700 px-2 py-2 rounded-xl flex flex-col justify-start"
+        >
+          <div className="w-full flex-col text-left flex pointer-events-none">
+            <div className="flex items-center justify-between">
+              <div className="text-white text-sm font-bold">
+                {preset.command}
+              </div>
+              <DotsThree
+                size={24}
+                weight="bold"
+                className="text-white"
+                onClick={(e) => {
+                  e.stopPropagation();
+                  handleEditPreset(preset);
+                }}
+              />
+            </div>
+            <div className="text-white text-opacity-60 text-sm">
+              {preset.description}
+            </div>
+          </div>
+        </button>
+      ))}
+      <button
+        onClick={openAddModal}
+        className="w-full hover:cursor-pointer hover:bg-zinc-700 px-2 py-2 rounded-xl flex flex-row items-center gap-x-2"
+      >
+        <Plus size={24} weight="fill" className="text-white" />
+        <div className="text-white text-sm font-medium">Add New Preset</div>
+      </button>
+      <AddPresetModal
+        isOpen={isAddModalOpen}
+        onClose={closeAddModal}
+        onSave={handleSavePreset}
+      />
+      {selectedPreset && (
+        <EditPresetModal
+          isOpen={isEditModalOpen}
+          onClose={closeEditModal}
+          onSave={handleUpdatePreset}
+          onDelete={handleDeletePreset}
+          preset={selectedPreset}
+        />
+      )}
+    </>
+  );
+}
diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/index.jsx
index 5a606af6d..9b626372c 100644
--- a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/index.jsx
+++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/index.jsx
@@ -3,6 +3,7 @@ import SlashCommandIcon from "./icons/slash-commands-icon.svg";
 import { Tooltip } from "react-tooltip";
 import ResetCommand from "./reset";
 import EndAgentSession from "./endAgentSession";
+import SlashPresets from "./SlashPresets";
 
 export default function SlashCommandsButton({ showing, setShowSlashCommand }) {
   return (
@@ -52,10 +53,11 @@ export function SlashCommands({ showing, setShowing, sendCommand }) {
         <ResetCommand sendCommand={sendCommand} setShowing={setShowing} />
         <EndAgentSession sendCommand={sendCommand} setShowing={setShowing} />
+        <SlashPresets sendCommand={sendCommand} setShowing={setShowing} />
diff --git a/frontend/src/models/system.js b/frontend/src/models/system.js
index af532a047..e64b01199 100644
--- a/frontend/src/models/system.js
+++ b/frontend/src/models/system.js
@@ -567,6 +567,74 @@ const System = {
     });
   },
   dataConnectors: DataConnector,
+
+  getSlashCommandPresets: async function () {
+    return await fetch(`${API_BASE}/system/slash-command-presets`, {
+      method: "GET",
+      headers: baseHeaders(),
+    })
+      .then((res) => {
+        if (!res.ok) throw new Error("Could not fetch slash command presets.");
+        return res.json();
+      })
+      .then((res) => res.presets)
+      .catch((e) => {
+        console.error(e);
+        return [];
+      });
+  },
+
+  createSlashCommandPreset: async function (presetData) {
+    return await fetch(`${API_BASE}/system/slash-command-presets`, {
+      method: "POST",
+      headers: baseHeaders(),
+      body: JSON.stringify(presetData),
+    })
+      .then((res) => {
+        if (!res.ok) throw new Error("Could not create slash command preset.");
+        return res.json();
+      })
+      .then((res) => {
+        return { preset: res.preset, error: null };
+      })
+      .catch((e) => {
+        console.error(e);
+        return { preset: null, error: e.message };
+      });
+  },
+
+  updateSlashCommandPreset: async function (presetId, presetData) {
+    return await fetch(`${API_BASE}/system/slash-command-presets/${presetId}`, {
+      method: "POST",
+      headers: baseHeaders(),
+      body: JSON.stringify(presetData),
+    })
+      .then((res) => {
+        if (!res.ok) throw new Error("Could not update slash command preset.");
+        return res.json();
+      })
+      .then((res) => {
+        return { preset: res.preset, error: null };
+      })
+      .catch((e) => {
+        return { preset: null, error: "Failed to update this command." };
+      });
+  },
+
+  deleteSlashCommandPreset: async function (presetId) {
+    return await fetch(`${API_BASE}/system/slash-command-presets/${presetId}`, {
+      method: "DELETE",
+      headers: baseHeaders(),
+    })
+      .then((res) => {
+        if (!res.ok) throw new Error("Could not delete slash command preset.");
+        return true;
+      })
+      .catch((e) => {
+        console.error(e);
+        return false;
+      });
+  },
 };
 
 export default System;
diff --git a/server/endpoints/system.js b/server/endpoints/system.js
index 60d51e35f..4538ee060 100644
--- a/server/endpoints/system.js
+++ b/server/endpoints/system.js
@@ -50,6 +50,7 @@ const {
   resetPassword,
   generateRecoveryCodes,
 } = require("../utils/PasswordRecovery");
+const { SlashCommandPresets } = require("../models/slashCommandsPresets");
 
 function systemEndpoints(app) {
   if (!app) return;
@@ -1044,6 +1045,111 @@ function systemEndpoints(app) {
       response.sendStatus(500).end();
     }
   });
+
+  app.get(
+    "/system/slash-command-presets",
+    [validatedRequest, flexUserRoleValid([ROLES.all])],
+    async (request, response) => {
+      try {
+        const user = await userFromSession(request, response);
+        const userPresets = await SlashCommandPresets.getUserPresets(user?.id);
+        response.status(200).json({ presets: userPresets });
+      } catch (error) {
+        console.error("Error fetching slash command presets:", error);
+        response.status(500).json({ message: "Internal server error" });
+      }
+    }
+  );
+
+  app.post(
+    "/system/slash-command-presets",
+    [validatedRequest, flexUserRoleValid([ROLES.all])],
+    async (request, response) => {
+      try {
+        const user = await userFromSession(request, response);
+        const { command, prompt, description } = reqBody(request);
+        const presetData = {
+          command: SlashCommandPresets.formatCommand(String(command)),
+          prompt: String(prompt),
+          description: String(description),
+        };
+
+        const preset = await SlashCommandPresets.create(user?.id, presetData);
+        if (!preset) {
+          return response
.status(500) + .json({ message: "Failed to create preset" }); + } + response.status(201).json({ preset }); + } catch (error) { + console.error("Error creating slash command preset:", error); + response.status(500).json({ message: "Internal server error" }); + } + } + ); + + app.post( + "/system/slash-command-presets/:slashCommandId", + [validatedRequest, flexUserRoleValid([ROLES.all])], + async (request, response) => { + try { + const user = await userFromSession(request, response); + const { slashCommandId } = request.params; + const { command, prompt, description } = reqBody(request); + + // Validate that the user making the request owns the preset when a user session is present. + const ownsPreset = await SlashCommandPresets.get({ + userId: user?.id ?? null, + id: Number(slashCommandId), + }); + if (!ownsPreset) + return response.status(404).json({ message: "Preset not found" }); + + const updates = { + command: SlashCommandPresets.formatCommand(String(command)), + prompt: String(prompt), + description: String(description), + }; + + const preset = await SlashCommandPresets.update( + Number(slashCommandId), + updates + ); + if (!preset) return response.sendStatus(422); + response.status(200).json({ preset: { ...ownsPreset, ...updates } }); + } catch (error) { + console.error("Error updating slash command preset:", error); + response.status(500).json({ message: "Internal server error" }); + } + } + ); + + app.delete( + "/system/slash-command-presets/:slashCommandId", + [validatedRequest, flexUserRoleValid([ROLES.all])], + async (request, response) => { + try { + const { slashCommandId } = request.params; + const user = await userFromSession(request, response); + + // Validate that the user making the request owns the preset when a user session is present. + const ownsPreset = await SlashCommandPresets.get({ + userId: user?.id ?? null, + id: Number(slashCommandId), + }); + if (!ownsPreset) + return response + .status(403) + .json({ message: "Failed to delete preset" }); + + await SlashCommandPresets.delete(Number(slashCommandId)); + response.sendStatus(204); + } catch (error) { + console.error("Error deleting slash command preset:", error); + response.status(500).json({ message: "Internal server error" }); + } + } + ); } module.exports = { systemEndpoints }; diff --git a/server/models/slashCommandsPresets.js b/server/models/slashCommandsPresets.js new file mode 100644 index 000000000..4828c77d5 --- /dev/null +++ b/server/models/slashCommandsPresets.js @@ -0,0 +1,105 @@ +const { v4 } = require("uuid"); +const prisma = require("../utils/prisma"); +const CMD_REGEX = new RegExp(/[^a-zA-Z0-9_-]/g); + +const SlashCommandPresets = { + formatCommand: function (command = "") { + if (!command || command.length < 2) return `/${v4().split("-")[0]}`; + + let adjustedCmd = command.toLowerCase(); // force lowercase + if (!adjustedCmd.startsWith("/")) adjustedCmd = `/${adjustedCmd}`; // Fix if no preceding / is found.
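+    // For illustration, with hypothetical inputs: formatCommand("My Cmd!") returns "/my-cmd-",
+    // formatCommand("/Translate") returns "/translate", and a blank or single-character
+    // command falls back to a random uuid fragment such as "/a1b2c3d4".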
+ return `/${adjustedCmd.slice(1).toLowerCase().replace(CMD_REGEX, "-")}`; // replace any invalid chars with '-' + }, + + get: async function (clause = {}) { + try { + const preset = await prisma.slash_command_presets.findFirst({ + where: clause, + }); + return preset || null; + } catch (error) { + console.error(error.message); + return null; + } + }, + + where: async function (clause = {}, limit) { + try { + const presets = await prisma.slash_command_presets.findMany({ + where: clause, + take: limit || undefined, + }); + return presets; + } catch (error) { + console.error(error.message); + return []; + } + }, + + // The command + userId combination must be unique. + create: async function (userId = null, presetData = {}) { + try { + const preset = await prisma.slash_command_presets.create({ + data: { + ...presetData, + // This field (uid) is either the user_id or 0 (for non-multi-user mode). + // The uid field is what enforces the @@unique([uid, command]) constraint, since + // the real relational field (userId) must stay nullable and unique indexes never + // collide on NULL values, so this 'dummy' field gives us something non-null + // to constrain against that works within prisma and sqlite. + uid: userId ? Number(userId) : 0, + userId: userId ? Number(userId) : null, + }, + }); + return preset; + } catch (error) { + console.error("Failed to create preset", error.message); + return null; + } + }, + + getUserPresets: async function (userId = null) { + try { + return ( + await prisma.slash_command_presets.findMany({ + where: { userId: !!userId ? Number(userId) : null }, + orderBy: { createdAt: "asc" }, + }) + )?.map((preset) => ({ + id: preset.id, + command: preset.command, + prompt: preset.prompt, + description: preset.description, + })); + } catch (error) { + console.error("Failed to get user presets", error.message); + return []; + } + }, + + update: async function (presetId = null, presetData = {}) { + try { + const preset = await prisma.slash_command_presets.update({ + where: { id: Number(presetId) }, + data: presetData, + }); + return preset; + } catch (error) { + console.error("Failed to update preset", error.message); + return null; + } + }, + + delete: async function (presetId = null) { + try { + await prisma.slash_command_presets.delete({ + where: { id: Number(presetId) }, + }); + return true; + } catch (error) { + console.error("Failed to delete preset", error.message); + return false; + } + }, +}; + +module.exports.SlashCommandPresets = SlashCommandPresets; diff --git a/server/prisma/migrations/20240510032311_init/migration.sql b/server/prisma/migrations/20240510032311_init/migration.sql new file mode 100644 index 000000000..3b82efb88 --- /dev/null +++ b/server/prisma/migrations/20240510032311_init/migration.sql @@ -0,0 +1,15 @@ +-- CreateTable +CREATE TABLE "slash_command_presets" ( + "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + "command" TEXT NOT NULL, + "prompt" TEXT NOT NULL, + "description" TEXT NOT NULL, + "uid" INTEGER NOT NULL DEFAULT 0, + "userId" INTEGER, + "createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + "lastUpdatedAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT "slash_command_presets_userId_fkey" FOREIGN KEY ("userId") REFERENCES "users" ("id") ON DELETE CASCADE ON UPDATE CASCADE +); + +-- CreateIndex
+CREATE UNIQUE INDEX "slash_command_presets_uid_command_key" ON "slash_command_presets"("uid", "command"); diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma index b830de9b7..0ded65be6 100644 --- a/server/prisma/schema.prisma +++ b/server/prisma/schema.prisma @@ -73,6 +73,7
@@ model users { recovery_codes recovery_codes[] password_reset_tokens password_reset_tokens[] workspace_agent_invocations workspace_agent_invocations[] + slash_command_presets slash_command_presets[] } model recovery_codes { @@ -260,3 +261,17 @@ model event_logs { @@index([event]) } + +model slash_command_presets { + id Int @id @default(autoincrement()) + command String + prompt String + description String + uid Int @default(0) // 0 is null user + userId Int? + createdAt DateTime @default(now()) + lastUpdatedAt DateTime @default(now()) + user users? @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@unique([uid, command]) +} diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js index 76f98e0df..55e8fbe5f 100644 --- a/server/utils/chats/index.js +++ b/server/utils/chats/index.js @@ -4,14 +4,28 @@ const { resetMemory } = require("./commands/reset"); const { getVectorDbClass, getLLMProvider } = require("../helpers"); const { convertToPromptHistory } = require("../helpers/chat/responses"); const { DocumentManager } = require("../DocumentManager"); +const { SlashCommandPresets } = require("../../models/slashCommandsPresets"); const VALID_COMMANDS = { "/reset": resetMemory, }; -function grepCommand(message) { +async function grepCommand(message, user = null) { + const userPresets = await SlashCommandPresets.getUserPresets(user?.id); const availableCommands = Object.keys(VALID_COMMANDS); + // Check if the message starts with any preset command + const foundPreset = userPresets.find((p) => message.startsWith(p.command)); + if (!!foundPreset) { + // Replace the preset command with the corresponding prompt + const updatedMessage = message.replace( + foundPreset.command, + foundPreset.prompt + ); + return updatedMessage; + } + + // Check if the message starts with any built-in command for (let i = 0; i < availableCommands.length; i++) { const cmd = availableCommands[i]; const re = new RegExp(`^(${cmd})`, "i"); @@ -20,7 +34,7 @@ function grepCommand(message) { } } - return null; + return message; } async function chatWithWorkspace( @@ -31,10 +45,10 @@ async function chatWithWorkspace( thread = null ) { const uuid = uuidv4(); - const command = grepCommand(message); + const updatedMessage = await grepCommand(message, user); - if (!!command && Object.keys(VALID_COMMANDS).includes(command)) { - return await VALID_COMMANDS[command](workspace, message, uuid, user); + if (Object.keys(VALID_COMMANDS).includes(updatedMessage)) { + return await VALID_COMMANDS[updatedMessage](workspace, message, uuid, user); } const LLMConnector = getLLMProvider({ @@ -164,7 +178,7 @@ async function chatWithWorkspace( const messages = await LLMConnector.compressMessages( { systemPrompt: chatPrompt(workspace), - userPrompt: message, + userPrompt: updatedMessage, contextTexts, chatHistory, }, diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js index ba4dea163..ec8fdbfac 100644 --- a/server/utils/chats/stream.js +++ b/server/utils/chats/stream.js @@ -23,10 +23,10 @@ async function streamChatWithWorkspace( thread = null ) { const uuid = uuidv4(); - const command = grepCommand(message); + const updatedMessage = await grepCommand(message, user); - if (!!command && Object.keys(VALID_COMMANDS).includes(command)) { - const data = await VALID_COMMANDS[command]( + if (Object.keys(VALID_COMMANDS).includes(updatedMessage)) { + const data = await VALID_COMMANDS[updatedMessage]( workspace, message, uuid, @@ -185,7 +185,7 @@ async function streamChatWithWorkspace( const messages = await 
LLMConnector.compressMessages( { systemPrompt: chatPrompt(workspace), - userPrompt: message, + userPrompt: updatedMessage, contextTexts, chatHistory, }, From 8f068b80d72e5c02af195ad7776fab2a6d924785 Mon Sep 17 00:00:00 2001 From: timothycarambat Date: Fri, 10 May 2024 14:09:25 -0700 Subject: [PATCH 07/11] chat history performance improvements with `memo` --- .../ChatHistory/HistoricalMessage/index.jsx | 20 +++++++++++++++---- .../ChatContainer/PromptInput/index.jsx | 18 +++++++++++++++-- .../WorkspaceChat/ChatContainer/index.jsx | 18 ++++++++++++----- 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx index d9efd98cc..e6ebaf0de 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx @@ -23,9 +23,8 @@ const HistoricalMessage = ({ return (
{ + return ( + (prevProps.message === nextProps.message) && + (prevProps.isLastMessage === nextProps.isLastMessage) && + (prevProps.chatId === nextProps.chatId) + ); + } +); \ No newline at end of file diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx index 859f84174..0b28ac587 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx @@ -12,20 +12,33 @@ import AvailableAgentsButton, { useAvailableAgents, } from "./AgentMenu"; import TextSizeButton from "./TextSizeMenu"; + +export const PROMPT_INPUT_EVENT = 'set_prompt_input'; export default function PromptInput({ - message, submit, onChange, inputDisabled, buttonDisabled, sendCommand, }) { + const [promptInput, setPromptInput] = useState(''); const { showAgents, setShowAgents } = useAvailableAgents(); const { showSlashCommand, setShowSlashCommand } = useSlashCommands(); const formRef = useRef(null); const textareaRef = useRef(null); const [_, setFocused] = useState(false); + // To prevent too many re-renders we remotely listen for updates from the parent + // via an event cycle. Otherwise, using message as a prop leads to a re-render every + // change on the input. + function handlePromptUpdate(e) { setPromptInput(e?.detail ?? ''); } + useEffect(() => { + if (!!window) window.addEventListener(PROMPT_INPUT_EVENT, handlePromptUpdate); + return () => ( + window?.removeEventListener(PROMPT_INPUT_EVENT, handlePromptUpdate) + ) + }, []); + useEffect(() => { if (!inputDisabled && textareaRef.current) { textareaRef.current.focus(); @@ -102,6 +115,7 @@ export default function PromptInput({ watchForSlash(e); watchForAt(e); adjustTextArea(e); + setPromptInput(e.target.value) }} onKeyDown={captureEnter} required={true} @@ -111,7 +125,7 @@ export default function PromptInput({ setFocused(false); adjustTextArea(e); }} - value={message} + value={promptInput} className="cursor-text max-h-[100px] md:min-h-[40px] mx-2 md:mx-0 py-2 w-full text-[16px] md:text-md text-white bg-transparent placeholder:text-white/60 resize-none active:outline-none focus:outline-none flex-grow" placeholder={"Send a message"} /> diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx index 7d2850bdc..07608a7fe 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx @@ -1,6 +1,6 @@ import { useState, useEffect } from "react"; import ChatHistory from "./ChatHistory"; -import PromptInput from "./PromptInput"; +import PromptInput, { PROMPT_INPUT_EVENT } from "./PromptInput"; import Workspace from "@/models/workspace"; import handleChat, { ABORT_STREAM_EVENT } from "@/utils/chat"; import { isMobile } from "react-device-detect"; @@ -20,10 +20,19 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { const [chatHistory, setChatHistory] = useState(knownHistory); const [socketId, setSocketId] = useState(null); const [websocket, setWebsocket] = useState(null); + + // Maintain state of message from whatever is in PromptInput const handleMessageChange = (event) => { setMessage(event.target.value); }; + // Emit an update to the sate of the prompt input without directly + // passing a prop in so that it does not re-render constantly. 
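+  // For illustration: the matching listener lives in PromptInput, which calls
+  // window.addEventListener(PROMPT_INPUT_EVENT, handlePromptUpdate) on mount, so the
+  // textarea value stays in sync without `message` being passed down as a
+  // re-render-triggering prop.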
+ function setMessageEmit(messageContent = '') { + setMessage(messageContent); + window.dispatchEvent(new CustomEvent(PROMPT_INPUT_EVENT, { detail: messageContent })) + } + const handleSubmit = async (event) => { event.preventDefault(); if (!message || message === "") return false; @@ -41,14 +50,14 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { ]; setChatHistory(prevChatHistory); - setMessage(""); + setMessageEmit(""); setLoadingResponse(true); }; const sendCommand = async (command, submit = false) => { if (!command || command === "") return false; if (!submit) { - setMessage(command); + setMessageEmit(command); return; } @@ -65,7 +74,7 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { ]; setChatHistory(prevChatHistory); - setMessage(""); + setMessageEmit(""); setLoadingResponse(true); }; @@ -208,7 +217,6 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { sendCommand={sendCommand} /> Date: Fri, 10 May 2024 14:10:07 -0700 Subject: [PATCH 08/11] linting --- .../ChatHistory/HistoricalMessage/index.jsx | 13 +++++++------ .../ChatContainer/PromptInput/index.jsx | 18 ++++++++++-------- .../WorkspaceChat/ChatContainer/index.jsx | 6 ++++-- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx index e6ebaf0de..0371d64e5 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx @@ -23,8 +23,9 @@ const HistoricalMessage = ({ return (
{ return ( - (prevProps.message === nextProps.message) && - (prevProps.isLastMessage === nextProps.isLastMessage) && - (prevProps.chatId === nextProps.chatId) + prevProps.message === nextProps.message && + prevProps.isLastMessage === nextProps.isLastMessage && + prevProps.chatId === nextProps.chatId ); } -); \ No newline at end of file +); diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx index 0b28ac587..98ad11f8f 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx @@ -13,7 +13,7 @@ import AvailableAgentsButton, { } from "./AgentMenu"; import TextSizeButton from "./TextSizeMenu"; -export const PROMPT_INPUT_EVENT = 'set_prompt_input'; +export const PROMPT_INPUT_EVENT = "set_prompt_input"; export default function PromptInput({ submit, onChange, @@ -21,7 +21,7 @@ export default function PromptInput({ buttonDisabled, sendCommand, }) { - const [promptInput, setPromptInput] = useState(''); + const [promptInput, setPromptInput] = useState(""); const { showAgents, setShowAgents } = useAvailableAgents(); const { showSlashCommand, setShowSlashCommand } = useSlashCommands(); const formRef = useRef(null); @@ -31,12 +31,14 @@ export default function PromptInput({ // To prevent too many re-renders we remotely listen for updates from the parent // via an event cycle. Otherwise, using message as a prop leads to a re-render every // change on the input. - function handlePromptUpdate(e) { setPromptInput(e?.detail ?? ''); } + function handlePromptUpdate(e) { + setPromptInput(e?.detail ?? ""); + } useEffect(() => { - if (!!window) window.addEventListener(PROMPT_INPUT_EVENT, handlePromptUpdate); - return () => ( - window?.removeEventListener(PROMPT_INPUT_EVENT, handlePromptUpdate) - ) + if (!!window) + window.addEventListener(PROMPT_INPUT_EVENT, handlePromptUpdate); + return () => + window?.removeEventListener(PROMPT_INPUT_EVENT, handlePromptUpdate); }, []); useEffect(() => { @@ -115,7 +117,7 @@ export default function PromptInput({ watchForSlash(e); watchForAt(e); adjustTextArea(e); - setPromptInput(e.target.value) + setPromptInput(e.target.value); }} onKeyDown={captureEnter} required={true} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx index 07608a7fe..b3cc0d942 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx @@ -28,9 +28,11 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { // Emit an update to the sate of the prompt input without directly // passing a prop in so that it does not re-render constantly. 
- function setMessageEmit(messageContent = '') { + function setMessageEmit(messageContent = "") { setMessage(messageContent); - window.dispatchEvent(new CustomEvent(PROMPT_INPUT_EVENT, { detail: messageContent })) + window.dispatchEvent( + new CustomEvent(PROMPT_INPUT_EVENT, { detail: messageContent }) + ); } const handleSubmit = async (event) => { From 734c5a9e964bf7d7c839b9be5878d9291a48c7b5 Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Fri, 10 May 2024 14:47:29 -0700 Subject: [PATCH 09/11] [FEAT] Implement regenerate response button (#1341) * implement regenerate response button * wip on rerenders * remove function that was duplicate * update delete-chats function --------- Co-authored-by: timothycarambat --- .../HistoricalMessage/Actions/index.jsx | 40 ++++++++++++++- .../ChatHistory/HistoricalMessage/index.jsx | 4 ++ .../ChatContainer/ChatHistory/index.jsx | 9 +++- .../WorkspaceChat/ChatContainer/index.jsx | 50 ++++++++++++++----- frontend/src/models/workspace.js | 16 ++++++ server/endpoints/workspaces.js | 31 ++++++++++++ 6 files changed, 135 insertions(+), 15 deletions(-) diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/Actions/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/Actions/index.jsx index 23914963f..b7e540cb6 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/Actions/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/Actions/index.jsx @@ -5,11 +5,19 @@ import { ClipboardText, ThumbsUp, ThumbsDown, + ArrowsClockwise, } from "@phosphor-icons/react"; import { Tooltip } from "react-tooltip"; import Workspace from "@/models/workspace"; -const Actions = ({ message, feedbackScore, chatId, slug }) => { +const Actions = ({ + message, + feedbackScore, + chatId, + slug, + isLastMessage, + regenerateMessage, +}) => { const [selectedFeedback, setSelectedFeedback] = useState(feedbackScore); const handleFeedback = async (newFeedback) => { @@ -22,6 +30,14 @@ const Actions = ({ message, feedbackScore, chatId, slug }) => { return (
+ {isLastMessage && + !message?.includes("Workspace chat memory was reset!") && ( + + )} {chatId && ( <> + + +
+ ); +} + export default memo(Actions); diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx index 0371d64e5..5f4e6c672 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx @@ -19,6 +19,8 @@ const HistoricalMessage = ({ error = false, feedbackScore = null, chatId = null, + isLastMessage = false, + regenerateMessage, }) => { return (
)} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx index c0eb5bf4c..3c2c47a05 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx @@ -8,7 +8,12 @@ import debounce from "lodash.debounce"; import useUser from "@/hooks/useUser"; import Chartable from "./Chartable"; -export default function ChatHistory({ history = [], workspace, sendCommand }) { +export default function ChatHistory({ + history = [], + workspace, + sendCommand, + regenerateAssistantMessage, +}) { const { user } = useUser(); const { showing, showModal, hideModal } = useManageWorkspaceModal(); const [isAtBottom, setIsAtBottom] = useState(true); @@ -165,6 +170,8 @@ export default function ChatHistory({ history = [], workspace, sendCommand }) { feedbackScore={props.feedbackScore} chatId={props.chatId} error={props.error} + regenerateMessage={regenerateAssistantMessage} + isLastMessage={isLastBotReply} /> ); })} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx index b3cc0d942..494ee57d9 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx @@ -26,7 +26,7 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { setMessage(event.target.value); }; - // Emit an update to the sate of the prompt input without directly + // Emit an update to the state of the prompt input without directly // passing a prop in so that it does not re-render constantly. function setMessageEmit(messageContent = "") { setMessage(messageContent); @@ -56,24 +56,47 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { setLoadingResponse(true); }; - const sendCommand = async (command, submit = false) => { + const regenerateAssistantMessage = (chatId) => { + const updatedHistory = chatHistory.slice(0, -1); + const lastUserMessage = updatedHistory.slice(-1)[0]; + Workspace.deleteChats(workspace.slug, [chatId]) + .then(() => sendCommand(lastUserMessage.content, true, updatedHistory)) + .catch((e) => console.error(e)); + }; + + const sendCommand = async (command, submit = false, history = []) => { if (!command || command === "") return false; if (!submit) { setMessageEmit(command); return; } - const prevChatHistory = [ - ...chatHistory, - { content: command, role: "user" }, - { - content: "", - role: "assistant", - pending: true, - userMessage: command, - animate: true, - }, - ]; + let prevChatHistory; + if (history.length > 0) { + // use pre-determined history chain. 
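+      // For illustration: on the regenerate path, `history` comes from
+      // regenerateAssistantMessage as chatHistory minus the assistant reply being
+      // regenerated, so it already ends with the user prompt that is re-sent here.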
+ prevChatHistory = [ + ...history, + { + content: "", + role: "assistant", + pending: true, + userMessage: command, + animate: true, + }, + ]; + } else { + prevChatHistory = [ + ...chatHistory, + { content: command, role: "user" }, + { + content: "", + role: "assistant", + pending: true, + userMessage: command, + animate: true, + }, + ]; + } setChatHistory(prevChatHistory); setMessageEmit(""); @@ -217,6 +240,7 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { history={chatHistory} workspace={workspace} sendCommand={sendCommand} + regenerateAssistantMessage={regenerateAssistantMessage} /> false); return result; }, + + deleteChats: async function (slug = "", chatIds = []) { + return await fetch(`${API_BASE}/workspace/${slug}/delete-chats`, { + method: "DELETE", + headers: baseHeaders(), + body: JSON.stringify({ chatIds }), + }) + .then((res) => { + if (res.ok) return true; + throw new Error("Failed to delete chats."); + }) + .catch((e) => { + console.log(e); + return false; + }); + }, streamChat: async function ({ slug }, message, handleChat) { const ctrl = new AbortController(); diff --git a/server/endpoints/workspaces.js b/server/endpoints/workspaces.js index e9df2613c..f85c213fc 100644 --- a/server/endpoints/workspaces.js +++ b/server/endpoints/workspaces.js @@ -372,6 +372,37 @@ function workspaceEndpoints(app) { } ); + app.delete( + "/workspace/:slug/delete-chats", + [validatedRequest, flexUserRoleValid([ROLES.all]), validWorkspaceSlug], + async (request, response) => { + try { + const { chatIds = [] } = reqBody(request); + const user = await userFromSession(request, response); + const workspace = response.locals.workspace; + + if (!workspace || !Array.isArray(chatIds)) { + response.sendStatus(400).end(); + return; + } + + // This works for both workspace and threads. + // we simplify this by just looking at workspace<>user overlap + // since they are all on the same table. + await WorkspaceChats.delete({ + id: { in: chatIds.map((id) => Number(id)) }, + user_id: user?.id ?? null, + workspaceId: workspace.id, + }); + + response.sendStatus(200).end(); + } catch (e) { + console.log(e.message, e); + response.sendStatus(500).end(); + } + } + ); + app.post( "/workspace/:slug/chat-feedback/:chatId", [validatedRequest, flexUserRoleValid([ROLES.all]), validWorkspaceSlug], From 0a6a9e40c13452dfa1c6a1b8ef97e4b2c2710375 Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Fri, 10 May 2024 14:49:02 -0700 Subject: [PATCH 10/11] [FIX] Add max tokens field to generic OpenAI LLM connector (#1345) * add max tokens field to generic openai llm connector * add max_tokens property to generic openai agent provider --- .../LLMSelection/GenericOpenAiOptions/index.jsx | 15 +++++++++++++++ server/models/systemSettings.js | 1 + server/utils/AiProviders/genericOpenAi/index.js | 3 +++ .../agents/aibitat/providers/genericOpenAi.js | 2 ++ server/utils/helpers/updateENV.js | 4 ++++ 5 files changed, 25 insertions(+) diff --git a/frontend/src/components/LLMSelection/GenericOpenAiOptions/index.jsx b/frontend/src/components/LLMSelection/GenericOpenAiOptions/index.jsx index 8f5a00b66..ac143e94a 100644 --- a/frontend/src/components/LLMSelection/GenericOpenAiOptions/index.jsx +++ b/frontend/src/components/LLMSelection/GenericOpenAiOptions/index.jsx @@ -61,6 +61,21 @@ export default function GenericOpenAiOptions({ settings }) { autoComplete="off" />
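+          {/* New, optional cap on response tokens; saved via the GenericOpenAiMaxTokens
+              setting and read server-side from GENERIC_OPEN_AI_MAX_TOKENS (defaults to 1024). */}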
+
+ + +
); } diff --git a/server/models/systemSettings.js b/server/models/systemSettings.js index 9ac41db0c..21d7af217 100644 --- a/server/models/systemSettings.js +++ b/server/models/systemSettings.js @@ -373,6 +373,7 @@ const SystemSettings = { GenericOpenAiModelPref: process.env.GENERIC_OPEN_AI_MODEL_PREF, GenericOpenAiTokenLimit: process.env.GENERIC_OPEN_AI_MODEL_TOKEN_LIMIT, GenericOpenAiKey: !!process.env.GENERIC_OPEN_AI_API_KEY, + GenericOpenAiMaxTokens: process.env.GENERIC_OPEN_AI_MAX_TOKENS, // Cohere API Keys CohereApiKey: !!process.env.COHERE_API_KEY, diff --git a/server/utils/AiProviders/genericOpenAi/index.js b/server/utils/AiProviders/genericOpenAi/index.js index 8c171b679..686d4c677 100644 --- a/server/utils/AiProviders/genericOpenAi/index.js +++ b/server/utils/AiProviders/genericOpenAi/index.js @@ -18,6 +18,7 @@ class GenericOpenAiLLM { }); this.model = modelPreference ?? process.env.GENERIC_OPEN_AI_MODEL_PREF ?? null; + this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS ?? 1024; if (!this.model) throw new Error("GenericOpenAI must have a valid model set."); this.limits = { @@ -94,6 +95,7 @@ class GenericOpenAiLLM { model: this.model, messages, temperature, + max_tokens: this.maxTokens, }) .catch((e) => { throw new Error(e.response.data.error.message); @@ -110,6 +112,7 @@ class GenericOpenAiLLM { stream: true, messages, temperature, + max_tokens: this.maxTokens, }); return streamRequest; } diff --git a/server/utils/agents/aibitat/providers/genericOpenAi.js b/server/utils/agents/aibitat/providers/genericOpenAi.js index 3521bc7d0..e41476d2a 100644 --- a/server/utils/agents/aibitat/providers/genericOpenAi.js +++ b/server/utils/agents/aibitat/providers/genericOpenAi.js @@ -24,6 +24,7 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) { this._client = client; this.model = model; this.verbose = true; + this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS ?? 1024; } get client() { @@ -36,6 +37,7 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) { model: this.model, temperature: 0, messages, + max_tokens: this.maxTokens, }) .then((result) => { if (!result.hasOwnProperty("choices")) diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index 11be3db80..39223c334 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -173,6 +173,10 @@ const KEY_MAPPING = { envKey: "GENERIC_OPEN_AI_API_KEY", checks: [], }, + GenericOpenAiMaxTokens: { + envKey: "GENERIC_OPEN_AI_MAX_TOKENS", + checks: [nonZero], + }, EmbeddingEngine: { envKey: "EMBEDDING_ENGINE", From 5eec5cbb9e39ffe814c7317f23992a029001746c Mon Sep 17 00:00:00 2001 From: timothycarambat Date: Fri, 10 May 2024 15:02:59 -0700 Subject: [PATCH 11/11] update agent modal --- .../ChatContainer/PromptInput/AgentMenu/index.jsx | 4 ---- 1 file changed, 4 deletions(-) diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/AgentMenu/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/AgentMenu/index.jsx index 17b071126..bc98b45b4 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/AgentMenu/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/AgentMenu/index.jsx @@ -161,10 +161,6 @@ function FirstTimeAgentUser() { Now you can use agents for real-time web search and scraping, saving documents to your browser, summarizing documents, and more. -
-
- Currently, agents only work with OpenAI as your agent LLM. All - LLM providers will be supported in the future.

This feature is currently early access and fully custom agents