Agent support for LLMs with no function calling (#1295)

* add LMStudio agent support (generic); "works" with non-tool-callable LLMs, highly dependent on system specs

* add comments

* enable few-shot prompting per function for OSS models (see the illustrative sketch below the commit details)

* Add Agent support for Ollama models

* azure, groq, koboldcpp agent support complete + WIP togetherai

* WIP gemini agent support

* WIP gemini support is blocked; will not fix for now

* azure fix

* merge fix

* add localai agent support

* azure untooled agent support

* merge fix

* refactor implementation of several agent providers

* update bad merge comment

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
Sean Hatfield 2024-05-08 15:17:54 -07:00 committed by GitHub
parent b2b41db110
commit 8422f92542
14 changed files with 645 additions and 8 deletions
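
The "few-shot prompting per function" bullet above is the core trick for running agents on models without native function calling: a function definition can carry worked examples that are replayed as prior chat turns so the model learns the exact JSON reply shape expected of it. A minimal sketch of the idea in Node; the `examples` field and the `buildFewShotHistory` helper are assumed names for illustration, not necessarily the shapes this commit uses:

// Illustrative sketch only: per-function few-shot examples for an untooled LLM.
const webBrowsingPlugin = {
  name: "web-browsing",
  description: "Searches the web and returns a summary of the results.",
  parameters: {
    type: "object",
    properties: {
      query: { type: "string", description: "The search query to run." },
    },
  },
  // Assumed field: worked examples of prompt -> expected JSON tool call.
  examples: [
    {
      prompt: "What is the weather in Tokyo right now?",
      call: JSON.stringify({
        name: "web-browsing",
        arguments: { query: "current weather Tokyo" },
      }),
    },
  ],
};

// Replay the examples as mock user/assistant turns ahead of the real chat history.
function buildFewShotHistory(plugin) {
  return (plugin.examples ?? []).flatMap((example) => [
    { role: "user", content: example.prompt },
    { role: "assistant", content: example.call },
  ]);
}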

View File

@@ -28,6 +28,7 @@
"openrouter",
"Qdrant",
"Serper",
"togetherai",
"vectordbs",
"Weaviate",
"Zilliz"

View File

@@ -5,8 +5,25 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference";
import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react";
import AgentModelSelection from "../AgentModelSelection";
const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"];
const WARN_PERFORMANCE = ["lmstudio", "ollama"];
const ENABLED_PROVIDERS = [
"openai",
"anthropic",
"lmstudio",
"ollama",
"localai",
"groq",
"azure",
"koboldcpp",
"togetherai",
];
const WARN_PERFORMANCE = [
"lmstudio",
"groq",
"azure",
"koboldcpp",
"ollama",
"localai",
];
const LLM_DEFAULT = {
name: "Please make a selection",

View File

@@ -480,7 +480,7 @@ Read the following conversation.
CHAT HISTORY
${history.map((c) => `@${c.from}: ${c.content}`).join("\n")}
Then select the next role that is going to speak next.
Only return the role.
`,
},
@@ -522,7 +522,7 @@ Only return the role.
? [
{
role: "user",
content: `You are in a whatsapp group. Read the following conversation and then reply.
Do not add introduction or conclusion to your reply because this will be a continuous conversation. Don't introduce yourself.
CHAT HISTORY
@@ -743,6 +743,16 @@ ${this.getHistory({ to: route.to })
return new Providers.LMStudioProvider({});
case "ollama":
return new Providers.OllamaProvider({ model: config.model });
case "groq":
return new Providers.GroqProvider({ model: config.model });
case "togetherai":
return new Providers.TogetherAIProvider({ model: config.model });
case "azure":
return new Providers.AzureOpenAiProvider({ model: config.model });
case "koboldcpp":
return new Providers.KoboldCPPProvider({});
case "localai":
return new Providers.LocalAIProvider({ model: config.model });
default:
throw new Error(

View File

@@ -58,6 +58,9 @@ class Provider {
}
}
// For some providers we may want to override the system prompt to be more verbose.
// Currently we only do this for lmstudio, but we will probably want to expand this
// to any Untooled LLM.
static systemPrompt(provider = null) {
switch (provider) {
case "lmstudio":

View File

@@ -0,0 +1,105 @@
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
/**
* The provider for the Azure OpenAI API.
*/
class AzureOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(_config = {}) {
super();
const client = new OpenAIClient(
process.env.AZURE_OPENAI_ENDPOINT,
new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
);
this._client = client;
this.model = process.env.OPEN_MODEL_PREF ?? "gpt-3.5-turbo";
this.verbose = true;
}
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client
.getChatCompletions(this.model, messages, {
temperature: 0,
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("Azure OpenAI chat: No results!");
if (result.choices.length === 0)
throw new Error("Azure OpenAI chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
messages,
functions,
this.#handleFunctionCallChat.bind(this)
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.getChatCompletions(
this.model,
this.cleanMsgs(messages),
{
temperature: 0.7,
}
);
completion = response.choices[0].message;
}
return { result: completion.content, cost: 0 };
} catch (error) {
throw error;
}
}
/**
* Get the cost of the completion.
* Stubbed since Azure OpenAI has no public cost basis.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
*/
getCost(_usage) {
return 0;
}
}
module.exports = AzureOpenAiProvider;
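
A hedged usage sketch of the `complete()` contract these untooled providers share: a populated `functionCall` means the agent should run that tool, while a populated `result` is a plain text answer. The function definition passed in below is made up for illustration:

// Illustrative usage only; assumes AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY are set.
const AzureOpenAiProvider = require("./azure.js");

async function demo() {
  const provider = new AzureOpenAiProvider();
  const { result, functionCall } = await provider.complete(
    [{ role: "user", content: "What is the weather in Tokyo?" }],
    [
      {
        name: "web-browsing",
        description: "Search the web for current information.",
        parameters: {
          type: "object",
          properties: { query: { type: "string" } },
        },
      },
    ]
  );

  if (functionCall) {
    console.log(`Agent wants to run ${functionCall.name}`, functionCall.arguments);
  } else {
    console.log("Plain chat answer:", result);
  }
}

demo().catch(console.error);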

View File

@@ -0,0 +1,110 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const { RetryError } = require("../error.js");
/**
* The provider for the Groq API.
*/
class GroqProvider extends Provider {
model;
constructor(config = {}) {
const { model = "llama3-8b-8192" } = config;
const client = new OpenAI({
baseURL: "https://api.groq.com/openai/v1",
apiKey: process.env.GROQ_API_KEY,
maxRetries: 3,
});
super(client);
this.model = model;
this.verbose = true;
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
const response = await this.client.chat.completions.create({
model: this.model,
// stream: true,
messages,
...(Array.isArray(functions) && functions?.length > 0
? { functions }
: {}),
});
// Right now, we only support one completion,
// so we just take the first one in the list
const completion = response.choices[0].message;
const cost = this.getCost(response.usage);
// treat function calls
if (completion.function_call) {
let functionArgs = {};
try {
functionArgs = JSON.parse(completion.function_call.arguments);
} catch (error) {
// call the complete function again in case it gets a json error
return this.complete(
[
...messages,
{
role: "function",
name: completion.function_call.name,
function_call: completion.function_call,
content: error?.message,
},
],
functions
);
}
// console.log(completion, { functionArgs })
return {
result: null,
functionCall: {
name: completion.function_call.name,
arguments: functionArgs,
},
cost,
};
}
return {
result: completion.content,
cost,
};
} catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
// will make auth better.
if (error instanceof OpenAI.AuthenticationError) throw error;
if (
error instanceof OpenAI.RateLimitError ||
error instanceof OpenAI.InternalServerError ||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
) {
throw new RetryError(error.message);
}
throw error;
}
}
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* Stubbed since Groq has no cost basis.
*/
getCost(_usage) {
return 0;
}
}
module.exports = GroqProvider;

View File

@@ -110,7 +110,7 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
const response = await chatCb({
messages: [
{
content: `You are a program which picks the most optimal function and parameters to call.
YOU DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY.
When a function is selected, respond in JSON with no additional text.
When there is no relevant function to call, return a regular chat text response.
@@ -130,7 +130,6 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
...history,
],
});
const call = safeJsonParse(response, null);
if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text.
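
Since this hunk shows only the tail of the flow, here is a hedged sketch of the untooled tool-selection loop it belongs to: list the candidate functions in a system prompt, run a plain chat completion, then try to parse the reply as a JSON tool call and fall back to text. Helper names other than `safeJsonParse` are assumptions, and `safeJsonParse` is stubbed locally for the sketch:

// Hedged sketch of the untooled tool-selection flow in helpers/untooled.js.
function safeJsonParse(text, fallback = null) {
  try {
    return JSON.parse(text);
  } catch {
    return fallback;
  }
}

// Turn the available function definitions into the "pick a function" system prompt.
function buildToolPickerPrompt(functions) {
  const defs = functions
    .map(
      (def) =>
        `${def.name}: ${def.description}\n${JSON.stringify(def.parameters.properties, null, 4)}`
    )
    .join("\n");
  return `You are a program which picks the most optimal function and parameters to call.\n${defs}`;
}

async function functionCall(messages, functions, chatCb) {
  const history = messages.filter((m) => ["user", "assistant"].includes(m.role));
  const response = await chatCb({
    messages: [{ role: "system", content: buildToolPickerPrompt(functions) }, ...history],
  });

  const call = safeJsonParse(response, null);
  if (call === null) return { toolCall: null, text: response }; // plain text answer

  // Only accept a call that names a function we actually offered.
  const known = functions.some((def) => def.name === call.name);
  return known ? { toolCall: call, text: null } : { toolCall: null, text: response };
}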

View File

@@ -2,10 +2,20 @@ const OpenAIProvider = require("./openai.js");
const AnthropicProvider = require("./anthropic.js");
const LMStudioProvider = require("./lmstudio.js");
const OllamaProvider = require("./ollama.js");
const GroqProvider = require("./groq.js");
const TogetherAIProvider = require("./togetherai.js");
const AzureOpenAiProvider = require("./azure.js");
const KoboldCPPProvider = require("./koboldcpp.js");
const LocalAIProvider = require("./localai.js");
module.exports = {
OpenAIProvider,
AnthropicProvider,
LMStudioProvider,
OllamaProvider,
GroqProvider,
TogetherAIProvider,
AzureOpenAiProvider,
KoboldCPPProvider,
LocalAIProvider,
};

View File

@@ -0,0 +1,113 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
/**
* The provider for the KoboldCPP API.
*/
class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(_config = {}) {
super();
const model = process.env.KOBOLD_CPP_MODEL_PREF ?? null;
const client = new OpenAI({
baseURL: process.env.KOBOLD_CPP_BASE_PATH?.replace(/\/+$/, ""),
apiKey: null,
maxRetries: 3,
});
this._client = client;
this.model = model;
this.verbose = true;
}
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client.chat.completions
.create({
model: this.model,
temperature: 0,
messages,
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("KoboldCPP chat: No results!");
if (result.choices.length === 0)
throw new Error("KoboldCPP chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
messages,
functions,
this.#handleFunctionCallChat.bind(this)
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat.completions.create({
model: this.model,
messages: this.cleanMsgs(messages),
});
completion = response.choices[0].message;
}
return {
result: completion.content,
cost: 0,
};
} catch (error) {
throw error;
}
}
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* Stubbed since KoboldCPP has no cost basis.
*/
getCost(_usage) {
return 0;
}
}
module.exports = KoboldCPPProvider;

View File

@@ -16,8 +16,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance
apiKey: null,
maxRetries: 3,
model,
});
this._client = client;
this.model = model;
this.verbose = true;

View File

@@ -0,0 +1,114 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
/**
* The provider for the LocalAI API.
*/
class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(config = {}) {
const { model = null } = config;
super();
const client = new OpenAI({
baseURL: process.env.LOCAL_AI_BASE_PATH,
apiKey: process.env.LOCAL_AI_API_KEY ?? null,
maxRetries: 3,
});
this._client = client;
this.model = model;
this.verbose = true;
}
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client.chat.completions
.create({
model: this.model,
temperature: 0,
messages,
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("LocalAI chat: No results!");
if (result.choices.length === 0)
throw new Error("LocalAI chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
messages,
functions,
this.#handleFunctionCallChat.bind(this)
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat.completions.create({
model: this.model,
messages: this.cleanMsgs(messages),
});
completion = response.choices[0].message;
}
return { result: completion.content, cost: 0 };
} catch (error) {
throw error;
}
}
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* Stubbed since LocalAI has no cost basis.
*/
getCost(_usage) {
return 0;
}
}
module.exports = LocalAiProvider;

View File

@@ -0,0 +1,113 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
/**
* The provider for the TogetherAI API.
*/
class TogetherAIProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(config = {}) {
const { model = "mistralai/Mistral-7B-Instruct-v0.1" } = config;
super();
const client = new OpenAI({
baseURL: "https://api.together.xyz/v1",
apiKey: process.env.TOGETHER_AI_API_KEY,
maxRetries: 3,
});
this._client = client;
this.model = model;
this.verbose = true;
}
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client.chat.completions
.create({
model: this.model,
temperature: 0,
messages,
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("LMStudio chat: No results!");
if (result.choices.length === 0)
throw new Error("LMStudio chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
messages,
functions,
this.#handleFunctionCallChat.bind(this)
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat.completions.create({
model: this.model,
messages: this.cleanMsgs(messages),
});
completion = response.choices[0].message;
}
return {
result: completion.content,
cost: 0,
};
} catch (error) {
throw error;
}
}
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* Stubbed since TogetherAI has no cost basis.
*/
getCost(_usage) {
return 0;
}
}
module.exports = TogetherAIProvider;

View File

@@ -85,6 +85,36 @@ class AgentHandler {
if (!process.env.OLLAMA_BASE_PATH)
throw new Error("Ollama base path must be provided to use agents.");
break;
case "groq":
if (!process.env.GROQ_API_KEY)
throw new Error("Groq API key must be provided to use agents.");
break;
case "togetherai":
if (!process.env.TOGETHER_AI_API_KEY)
throw new Error("TogetherAI API key must be provided to use agents.");
break;
case "azure":
if (!process.env.AZURE_OPENAI_ENDPOINT || !process.env.AZURE_OPENAI_KEY)
throw new Error(
"Azure OpenAI API endpoint and key must be provided to use agents."
);
break;
case "koboldcpp":
if (!process.env.KOBOLD_CPP_BASE_PATH)
throw new Error(
"KoboldCPP must have a valid base path to use for the api."
);
break;
case "localai":
if (!process.env.LOCAL_AI_BASE_PATH)
throw new Error(
"LocalAI must have a valid base path to use for the api."
);
break;
case "gemini":
if (!process.env.GEMINI_API_KEY)
throw new Error("Gemini API key must be provided to use agents.");
break;
default:
throw new Error("No provider found to power agent cluster.");
}
@@ -100,6 +130,18 @@ class AgentHandler {
return "server-default";
case "ollama":
return "llama3:latest";
case "groq":
return "llama3-70b-8192";
case "togetherai":
return "mistralai/Mixtral-8x7B-Instruct-v0.1";
case "azure":
return "gpt-3.5-turbo";
case "koboldcpp":
return null;
case "gemini":
return "gemini-pro";
case "localai":
return null;
default:
return "unknown";
}
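
Taken together, the checks above map to the following environment variables for the new agent providers. Every value below is a placeholder, and the local base-path URLs are only the typical defaults for those servers, not values shipped by this commit:

# Placeholders only: set the variables for whichever agent provider you use.
GROQ_API_KEY=...
TOGETHER_AI_API_KEY=...
AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com
AZURE_OPENAI_KEY=...
KOBOLD_CPP_BASE_PATH=http://127.0.0.1:5001/v1
LOCAL_AI_BASE_PATH=http://127.0.0.1:8080/v1
GEMINI_API_KEY=...   # gemini agent support is still WIP per the commit notes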

View File

@@ -178,7 +178,7 @@ async function getKoboldCPPModels(basePath = null) {
try {
const { OpenAI: OpenAIApi } = require("openai");
const openai = new OpenAIApi({
baseURL: basePath || process.env.LMSTUDIO_BASE_PATH,
baseURL: basePath || process.env.KOBOLD_CPP_BASE_PATH,
apiKey: null,
});
const models = await openai.models