From 8422f9254278f6430a350b808b877d42e6061ec0 Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Wed, 8 May 2024 15:17:54 -0700 Subject: [PATCH] Agent support for LLMs with no function calling (#1295) * add LMStudio agent support (generic) support "work" with non-tool callable LLMs, highly dependent on system specs * add comments * enable few-shot prompting per function for OSS models * Add Agent support for Ollama models * azure, groq, koboldcpp agent support complete + WIP togetherai * WIP gemini agent support * WIP gemini blocked and will not fix for now * azure fix * merge fix * add localai agent support * azure untooled agent support * merge fix * refactor implementation of several agent provideers * update bad merge comment --------- Co-authored-by: timothycarambat --- .vscode/settings.json | 1 + .../AgentConfig/AgentLLMSelection/index.jsx | 21 +++- server/utils/agents/aibitat/index.js | 14 ++- .../agents/aibitat/providers/ai-provider.js | 3 + .../utils/agents/aibitat/providers/azure.js | 105 ++++++++++++++++ server/utils/agents/aibitat/providers/groq.js | 110 +++++++++++++++++ .../aibitat/providers/helpers/untooled.js | 3 +- .../utils/agents/aibitat/providers/index.js | 10 ++ .../agents/aibitat/providers/koboldcpp.js | 113 +++++++++++++++++ .../agents/aibitat/providers/lmstudio.js | 2 +- .../utils/agents/aibitat/providers/localai.js | 114 ++++++++++++++++++ .../agents/aibitat/providers/togetherai.js | 113 +++++++++++++++++ server/utils/agents/index.js | 42 +++++++ server/utils/helpers/customModels.js | 2 +- 14 files changed, 645 insertions(+), 8 deletions(-) create mode 100644 server/utils/agents/aibitat/providers/azure.js create mode 100644 server/utils/agents/aibitat/providers/groq.js create mode 100644 server/utils/agents/aibitat/providers/koboldcpp.js create mode 100644 server/utils/agents/aibitat/providers/localai.js create mode 100644 server/utils/agents/aibitat/providers/togetherai.js diff --git a/.vscode/settings.json b/.vscode/settings.json index f850bbb0..eecaa83f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -28,6 +28,7 @@ "openrouter", "Qdrant", "Serper", + "togetherai", "vectordbs", "Weaviate", "Zilliz" diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx index fcb12d94..400eef02 100644 --- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx @@ -5,8 +5,25 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference"; import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react"; import AgentModelSelection from "../AgentModelSelection"; -const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"]; -const WARN_PERFORMANCE = ["lmstudio", "ollama"]; +const ENABLED_PROVIDERS = [ + "openai", + "anthropic", + "lmstudio", + "ollama", + "localai", + "groq", + "azure", + "koboldcpp", + "togetherai", +]; +const WARN_PERFORMANCE = [ + "lmstudio", + "groq", + "azure", + "koboldcpp", + "ollama", + "localai", +]; const LLM_DEFAULT = { name: "Please make a selection", diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js index 9cf2170b..3413bd35 100644 --- a/server/utils/agents/aibitat/index.js +++ b/server/utils/agents/aibitat/index.js @@ -480,7 +480,7 @@ Read the following conversation. 
CHAT HISTORY ${history.map((c) => `@${c.from}: ${c.content}`).join("\n")} -Then select the next role from that is going to speak next. +Then select the next role from that is going to speak next. Only return the role. `, }, @@ -522,7 +522,7 @@ Only return the role. ? [ { role: "user", - content: `You are in a whatsapp group. Read the following conversation and then reply. + content: `You are in a whatsapp group. Read the following conversation and then reply. Do not add introduction or conclusion to your reply because this will be a continuous conversation. Don't introduce yourself. CHAT HISTORY @@ -743,6 +743,16 @@ ${this.getHistory({ to: route.to }) return new Providers.LMStudioProvider({}); case "ollama": return new Providers.OllamaProvider({ model: config.model }); + case "groq": + return new Providers.GroqProvider({ model: config.model }); + case "togetherai": + return new Providers.TogetherAIProvider({ model: config.model }); + case "azure": + return new Providers.AzureOpenAiProvider({ model: config.model }); + case "koboldcpp": + return new Providers.KoboldCPPProvider({}); + case "localai": + return new Providers.LocalAIProvider({ model: config.model }); default: throw new Error( diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js index 0e871b36..91a81ebf 100644 --- a/server/utils/agents/aibitat/providers/ai-provider.js +++ b/server/utils/agents/aibitat/providers/ai-provider.js @@ -58,6 +58,9 @@ class Provider { } } + // For some providers we may want to override the system prompt to be more verbose. + // Currently we only do this for lmstudio, but we probably will want to expand this even more + // to any Untooled LLM. static systemPrompt(provider = null) { switch (provider) { case "lmstudio": diff --git a/server/utils/agents/aibitat/providers/azure.js b/server/utils/agents/aibitat/providers/azure.js new file mode 100644 index 00000000..cdcf7618 --- /dev/null +++ b/server/utils/agents/aibitat/providers/azure.js @@ -0,0 +1,105 @@ +const { OpenAIClient, AzureKeyCredential } = require("@azure/openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the Azure OpenAI API. + */ +class AzureOpenAiProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(_config = {}) { + super(); + const client = new OpenAIClient( + process.env.AZURE_OPENAI_ENDPOINT, + new AzureKeyCredential(process.env.AZURE_OPENAI_KEY) + ); + this._client = client; + this.model = process.env.OPEN_MODEL_PREF ?? "gpt-3.5-turbo"; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client + .getChatCompletions(this.model, messages, { + temperature: 0, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("Azure OpenAI chat: No results!"); + if (result.choices.length === 0) + throw new Error("Azure OpenAI chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. 
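+   * When `functions` are supplied, tool selection is emulated through the
+   * UnTooled JSON-prompt flow (see functionCall below) rather than a native
+   * function-calling API, since this provider targets models without tool support.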
+ */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.getChatCompletions( + this.model, + this.cleanMsgs(messages), + { + temperature: 0.7, + } + ); + completion = response.choices[0].message; + } + return { result: completion.content, cost: 0 }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * Stubbed since Azure OpenAI has no public cost basis. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = AzureOpenAiProvider; diff --git a/server/utils/agents/aibitat/providers/groq.js b/server/utils/agents/aibitat/providers/groq.js new file mode 100644 index 00000000..3b87ba51 --- /dev/null +++ b/server/utils/agents/aibitat/providers/groq.js @@ -0,0 +1,110 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const { RetryError } = require("../error.js"); + +/** + * The provider for the Groq provider. + */ +class GroqProvider extends Provider { + model; + + constructor(config = {}) { + const { model = "llama3-8b-8192" } = config; + const client = new OpenAI({ + baseURL: "https://api.groq.com/openai/v1", + apiKey: process.env.GROQ_API_KEY, + maxRetries: 3, + }); + super(client); + this.model = model; + this.verbose = true; + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + const response = await this.client.chat.completions.create({ + model: this.model, + // stream: true, + messages, + ...(Array.isArray(functions) && functions?.length > 0 + ? { functions } + : {}), + }); + + // Right now, we only support one completion, + // so we just take the first one in the list + const completion = response.choices[0].message; + const cost = this.getCost(response.usage); + // treat function calls + if (completion.function_call) { + let functionArgs = {}; + try { + functionArgs = JSON.parse(completion.function_call.arguments); + } catch (error) { + // call the complete function again in case it gets a json error + return this.complete( + [ + ...messages, + { + role: "function", + name: completion.function_call.name, + function_call: completion.function_call, + content: error?.message, + }, + ], + functions + ); + } + + // console.log(completion, { functionArgs }) + return { + result: null, + functionCall: { + name: completion.function_call.name, + arguments: functionArgs, + }, + cost, + }; + } + + return { + result: completion.content, + cost, + }; + } catch (error) { + // If invalid Auth error we need to abort because no amount of waiting + // will make auth better. 
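+      // AuthenticationError extends APIError, so it must be checked first or the
+      // retryable branch below would wrap it in a RetryError and keep retrying.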
+ if (error instanceof OpenAI.AuthenticationError) throw error; + + if ( + error instanceof OpenAI.RateLimitError || + error instanceof OpenAI.InternalServerError || + error instanceof OpenAI.APIError // Also will catch AuthenticationError!!! + ) { + throw new RetryError(error.message); + } + + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since Groq has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = GroqProvider; diff --git a/server/utils/agents/aibitat/providers/helpers/untooled.js b/server/utils/agents/aibitat/providers/helpers/untooled.js index 37ecb559..11fbfec8 100644 --- a/server/utils/agents/aibitat/providers/helpers/untooled.js +++ b/server/utils/agents/aibitat/providers/helpers/untooled.js @@ -110,7 +110,7 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`; const response = await chatCb({ messages: [ { - content: `You are a program which picks the most optimal function and parameters to call. + content: `You are a program which picks the most optimal function and parameters to call. DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY. When a function is selection, respond in JSON with no additional text. When there is no relevant function to call - return with a regular chat text response. @@ -130,7 +130,6 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`; ...history, ], }); - const call = safeJsonParse(response, null); if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text. diff --git a/server/utils/agents/aibitat/providers/index.js b/server/utils/agents/aibitat/providers/index.js index fda8b513..6f8a2da0 100644 --- a/server/utils/agents/aibitat/providers/index.js +++ b/server/utils/agents/aibitat/providers/index.js @@ -2,10 +2,20 @@ const OpenAIProvider = require("./openai.js"); const AnthropicProvider = require("./anthropic.js"); const LMStudioProvider = require("./lmstudio.js"); const OllamaProvider = require("./ollama.js"); +const GroqProvider = require("./groq.js"); +const TogetherAIProvider = require("./togetherai.js"); +const AzureOpenAiProvider = require("./azure.js"); +const KoboldCPPProvider = require("./koboldcpp.js"); +const LocalAIProvider = require("./localai.js"); module.exports = { OpenAIProvider, AnthropicProvider, LMStudioProvider, OllamaProvider, + GroqProvider, + TogetherAIProvider, + AzureOpenAiProvider, + KoboldCPPProvider, + LocalAIProvider, }; diff --git a/server/utils/agents/aibitat/providers/koboldcpp.js b/server/utils/agents/aibitat/providers/koboldcpp.js new file mode 100644 index 00000000..77088263 --- /dev/null +++ b/server/utils/agents/aibitat/providers/koboldcpp.js @@ -0,0 +1,113 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the KoboldCPP provider. + */ +class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(_config = {}) { + super(); + const model = process.env.KOBOLD_CPP_MODEL_PREF ?? 
null; + const client = new OpenAI({ + baseURL: process.env.KOBOLD_CPP_BASE_PATH?.replace(/\/+$/, ""), + apiKey: null, + maxRetries: 3, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("KoboldCPP chat: No results!"); + if (result.choices.length === 0) + throw new Error("KoboldCPP chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since KoboldCPP has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = KoboldCPPProvider; diff --git a/server/utils/agents/aibitat/providers/lmstudio.js b/server/utils/agents/aibitat/providers/lmstudio.js index d3aa4346..f5c4a2e8 100644 --- a/server/utils/agents/aibitat/providers/lmstudio.js +++ b/server/utils/agents/aibitat/providers/lmstudio.js @@ -16,8 +16,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) { baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance apiKey: null, maxRetries: 3, - model, }); + this._client = client; this.model = model; this.verbose = true; diff --git a/server/utils/agents/aibitat/providers/localai.js b/server/utils/agents/aibitat/providers/localai.js new file mode 100644 index 00000000..161172c2 --- /dev/null +++ b/server/utils/agents/aibitat/providers/localai.js @@ -0,0 +1,114 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the LocalAI provider. + */ +class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + const { model = null } = config; + super(); + const client = new OpenAI({ + baseURL: process.env.LOCAL_AI_BASE_PATH, + apiKey: process.env.LOCAL_AI_API_KEY ?? 
null, + maxRetries: 3, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("LocalAI chat: No results!"); + + if (result.choices.length === 0) + throw new Error("LocalAI chat: No results length!"); + + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { result: completion.content, cost: 0 }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since LocalAI has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = LocalAiProvider; diff --git a/server/utils/agents/aibitat/providers/togetherai.js b/server/utils/agents/aibitat/providers/togetherai.js new file mode 100644 index 00000000..4ea5e11c --- /dev/null +++ b/server/utils/agents/aibitat/providers/togetherai.js @@ -0,0 +1,113 @@ +const OpenAI = require("openai"); +const Provider = require("./ai-provider.js"); +const InheritMultiple = require("./helpers/classes.js"); +const UnTooled = require("./helpers/untooled.js"); + +/** + * The provider for the TogetherAI provider. + */ +class TogetherAIProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + const { model = "mistralai/Mistral-7B-Instruct-v0.1" } = config; + super(); + const client = new OpenAI({ + baseURL: "https://api.together.xyz/v1", + apiKey: process.env.TOGETHER_AI_API_KEY, + maxRetries: 3, + }); + + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + async #handleFunctionCallChat({ messages = [] }) { + return await this.client.chat.completions + .create({ + model: this.model, + temperature: 0, + messages, + }) + .then((result) => { + if (!result.hasOwnProperty("choices")) + throw new Error("LMStudio chat: No results!"); + if (result.choices.length === 0) + throw new Error("LMStudio chat: No results length!"); + return result.choices[0].message.content; + }) + .catch((_) => { + return null; + }); + } + + /** + * Create a completion based on the received messages. 
+ * + * @param messages A list of messages to send to the API. + * @param functions + * @returns The completion. + */ + async complete(messages, functions = null) { + try { + let completion; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion = { content: text }; + } + + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const response = await this.client.chat.completions.create({ + model: this.model, + messages: this.cleanMsgs(messages), + }); + completion = response.choices[0].message; + } + + return { + result: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Get the cost of the completion. + * + * @param _usage The completion to get the cost for. + * @returns The cost of the completion. + * Stubbed since LMStudio has no cost basis. + */ + getCost(_usage) { + return 0; + } +} + +module.exports = TogetherAIProvider; diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js index e18b8b7b..768ad819 100644 --- a/server/utils/agents/index.js +++ b/server/utils/agents/index.js @@ -85,6 +85,36 @@ class AgentHandler { if (!process.env.OLLAMA_BASE_PATH) throw new Error("Ollama base path must be provided to use agents."); break; + case "groq": + if (!process.env.GROQ_API_KEY) + throw new Error("Groq API key must be provided to use agents."); + break; + case "togetherai": + if (!process.env.TOGETHER_AI_API_KEY) + throw new Error("TogetherAI API key must be provided to use agents."); + break; + case "azure": + if (!process.env.AZURE_OPENAI_ENDPOINT || !process.env.AZURE_OPENAI_KEY) + throw new Error( + "Azure OpenAI API endpoint and key must be provided to use agents." + ); + break; + case "koboldcpp": + if (!process.env.KOBOLD_CPP_BASE_PATH) + throw new Error( + "KoboldCPP must have a valid base path to use for the api." + ); + break; + case "localai": + if (!process.env.LOCAL_AI_BASE_PATH) + throw new Error( + "LocalAI must have a valid base path to use for the api." 
+ ); + break; + case "gemini": + if (!process.env.GEMINI_API_KEY) + throw new Error("Gemini API key must be provided to use agents."); + break; default: throw new Error("No provider found to power agent cluster."); } @@ -100,6 +130,18 @@ class AgentHandler { return "server-default"; case "ollama": return "llama3:latest"; + case "groq": + return "llama3-70b-8192"; + case "togetherai": + return "mistralai/Mixtral-8x7B-Instruct-v0.1"; + case "azure": + return "gpt-3.5-turbo"; + case "koboldcpp": + return null; + case "gemini": + return "gemini-pro"; + case "localai": + return null; default: return "unknown"; } diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js index ce690ae4..3743ffad 100644 --- a/server/utils/helpers/customModels.js +++ b/server/utils/helpers/customModels.js @@ -178,7 +178,7 @@ async function getKoboldCPPModels(basePath = null) { try { const { OpenAI: OpenAIApi } = require("openai"); const openai = new OpenAIApi({ - baseURL: basePath || process.env.LMSTUDIO_BASE_PATH, + baseURL: basePath || process.env.KOBOLD_CPP_BASE_PATH, apiKey: null, }); const models = await openai.models