Agent support for LLMs with no function calling (#1295)

* add LMStudio agent support (generic); "works" with LLMs that cannot call tools natively, highly dependent on system specs

* add comments

* enable few-shot prompting per function for OSS models

* Add Agent support for Ollama models

* azure, groq, koboldcpp agent support complete + WIP togetherai

* WIP gemini agent support

* WIP gemini agent support is blocked; will not fix for now

* azure fix

* merge fix

* add localai agent support

* azure untooled agent support

* merge fix

* refactor implementation of several agent providers

* update bad merge comment

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
Sean Hatfield 2024-05-08 15:17:54 -07:00 committed by GitHub
parent b2b41db110
commit 8422f92542
GPG Key ID: B5690EEEBB952194
14 changed files with 645 additions and 8 deletions
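
The common thread in these changes is an "UnTooled" code path for providers whose models cannot do native function calling: the available function definitions are serialized into the prompt, the model is asked to reply with JSON when it wants to use a tool, and a reply that fails to parse as JSON is treated as a normal text answer. A rough sketch of that flow follows (illustrative only, not code from this commit; the helper name buildToolPrompt and the exact prompt wording are invented for the example):

// Illustrative sketch of the UnTooled flow shared by the new providers.
// Not part of this commit; helper names and prompt wording are hypothetical.
function buildToolPrompt(functions) {
  return functions
    .map(
      (def) =>
        `${def.name}: ${def.description}\n` +
        `Parameters:\n${JSON.stringify(def.parameters.properties, null, 4)}`
    )
    .join("\n\n");
}

async function untooledFunctionCall(messages, functions, chatCb) {
  // chatCb is the provider-specific chat handler, e.g. #handleFunctionCallChat.
  const response = await chatCb({
    messages: [
      {
        role: "system",
        content: `When you need a tool, answer ONLY with JSON {"name": ..., "arguments": ...}. Available tools:\n${buildToolPrompt(functions)}`,
      },
      ...messages,
    ],
  });

  let call = null;
  try {
    call = JSON.parse(response); // expected shape: { name, arguments }
  } catch {
    call = null; // not JSON, so the model answered in plain text
  }
  if (call === null) return { toolCall: null, text: response };
  return { toolCall: call, text: null };
}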


@@ -28,6 +28,7 @@
   "openrouter",
   "Qdrant",
   "Serper",
+  "togetherai",
   "vectordbs",
   "Weaviate",
   "Zilliz"


@@ -5,8 +5,25 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference";
 import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react";
 import AgentModelSelection from "../AgentModelSelection";
 
-const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"];
-const WARN_PERFORMANCE = ["lmstudio", "ollama"];
+const ENABLED_PROVIDERS = [
+  "openai",
+  "anthropic",
+  "lmstudio",
+  "ollama",
+  "localai",
+  "groq",
+  "azure",
+  "koboldcpp",
+  "togetherai",
+];
+const WARN_PERFORMANCE = [
+  "lmstudio",
+  "groq",
+  "azure",
+  "koboldcpp",
+  "ollama",
+  "localai",
+];
 
 const LLM_DEFAULT = {
   name: "Please make a selection",


@@ -743,6 +743,16 @@ ${this.getHistory({ to: route.to })
         return new Providers.LMStudioProvider({});
       case "ollama":
         return new Providers.OllamaProvider({ model: config.model });
+      case "groq":
+        return new Providers.GroqProvider({ model: config.model });
+      case "togetherai":
+        return new Providers.TogetherAIProvider({ model: config.model });
+      case "azure":
+        return new Providers.AzureOpenAiProvider({ model: config.model });
+      case "koboldcpp":
+        return new Providers.KoboldCPPProvider({});
+      case "localai":
+        return new Providers.LocalAIProvider({ model: config.model });
       default:
         throw new Error(


@@ -58,6 +58,9 @@ class Provider {
     }
   }
 
+  // For some providers we may want to override the system prompt to be more verbose.
+  // Currently we only do this for lmstudio, but we probably will want to expand this even more
+  // to any Untooled LLM.
   static systemPrompt(provider = null) {
     switch (provider) {
       case "lmstudio":


@@ -0,0 +1,105 @@
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");

/**
 * The provider for the Azure OpenAI API.
 */
class AzureOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
  model;

  constructor(_config = {}) {
    super();
    const client = new OpenAIClient(
      process.env.AZURE_OPENAI_ENDPOINT,
      new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
    );
    this._client = client;
    this.model = process.env.OPEN_MODEL_PREF ?? "gpt-3.5-turbo";
    this.verbose = true;
  }

  get client() {
    return this._client;
  }

  async #handleFunctionCallChat({ messages = [] }) {
    return await this.client
      .getChatCompletions(this.model, messages, {
        temperature: 0,
      })
      .then((result) => {
        if (!result.hasOwnProperty("choices"))
          throw new Error("Azure OpenAI chat: No results!");
        if (result.choices.length === 0)
          throw new Error("Azure OpenAI chat: No results length!");
        return result.choices[0].message.content;
      })
      .catch((_) => {
        return null;
      });
  }

  /**
   * Create a completion based on the received messages.
   *
   * @param messages A list of messages to send to the API.
   * @param functions
   * @returns The completion.
   */
  async complete(messages, functions = null) {
    try {
      let completion;
      if (functions.length > 0) {
        const { toolCall, text } = await this.functionCall(
          messages,
          functions,
          this.#handleFunctionCallChat.bind(this)
        );

        if (toolCall !== null) {
          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
          return {
            result: null,
            functionCall: {
              name: toolCall.name,
              arguments: toolCall.arguments,
            },
            cost: 0,
          };
        }
        completion = { content: text };
      }

      if (!completion?.content) {
        this.providerLog(
          "Will assume chat completion without tool call inputs."
        );
        const response = await this.client.getChatCompletions(
          this.model,
          this.cleanMsgs(messages),
          {
            temperature: 0.7,
          }
        );
        completion = response.choices[0].message;
      }

      return { result: completion.content, cost: 0 };
    } catch (error) {
      throw error;
    }
  }

  /**
   * Get the cost of the completion.
   * Stubbed since Azure OpenAI has no public cost basis.
   *
   * @param _usage The completion to get the cost for.
   * @returns The cost of the completion.
   */
  getCost(_usage) {
    return 0;
  }
}

module.exports = AzureOpenAiProvider;
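
For context, a minimal usage sketch of the interface every new provider exposes (not part of the commit; the tool definition below is hypothetical and the Azure environment variables must already be set):

const AzureOpenAiProvider = require("./azure.js");

// Hypothetical tool definition in the shape the agent framework passes in.
const functions = [
  {
    name: "web-browsing",
    description: "Search the web and return a short answer.",
    parameters: {
      type: "object",
      properties: { query: { type: "string" } },
    },
  },
];

(async () => {
  const provider = new AzureOpenAiProvider();
  const { result, functionCall } = await provider.complete(
    [{ role: "user", content: "What is AnythingLLM?" }],
    functions
  );
  if (functionCall) {
    // The agent loop would now execute the named tool with these arguments.
    console.log("tool call:", functionCall.name, functionCall.arguments);
  } else {
    console.log("assistant:", result);
  }
})();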


@@ -0,0 +1,110 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const { RetryError } = require("../error.js");

/**
 * The provider for the Groq API.
 */
class GroqProvider extends Provider {
  model;

  constructor(config = {}) {
    const { model = "llama3-8b-8192" } = config;
    const client = new OpenAI({
      baseURL: "https://api.groq.com/openai/v1",
      apiKey: process.env.GROQ_API_KEY,
      maxRetries: 3,
    });
    super(client);
    this.model = model;
    this.verbose = true;
  }

  /**
   * Create a completion based on the received messages.
   *
   * @param messages A list of messages to send to the API.
   * @param functions
   * @returns The completion.
   */
  async complete(messages, functions = null) {
    try {
      const response = await this.client.chat.completions.create({
        model: this.model,
        // stream: true,
        messages,
        ...(Array.isArray(functions) && functions?.length > 0
          ? { functions }
          : {}),
      });

      // Right now, we only support one completion,
      // so we just take the first one in the list
      const completion = response.choices[0].message;
      const cost = this.getCost(response.usage);

      // treat function calls
      if (completion.function_call) {
        let functionArgs = {};
        try {
          functionArgs = JSON.parse(completion.function_call.arguments);
        } catch (error) {
          // call the complete function again in case it gets a json error
          return this.complete(
            [
              ...messages,
              {
                role: "function",
                name: completion.function_call.name,
                function_call: completion.function_call,
                content: error?.message,
              },
            ],
            functions
          );
        }

        // console.log(completion, { functionArgs })
        return {
          result: null,
          functionCall: {
            name: completion.function_call.name,
            arguments: functionArgs,
          },
          cost,
        };
      }

      return {
        result: completion.content,
        cost,
      };
    } catch (error) {
      // If invalid Auth error we need to abort because no amount of waiting
      // will make auth better.
      if (error instanceof OpenAI.AuthenticationError) throw error;

      if (
        error instanceof OpenAI.RateLimitError ||
        error instanceof OpenAI.InternalServerError ||
        error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
      ) {
        throw new RetryError(error.message);
      }

      throw error;
    }
  }

  /**
   * Get the cost of the completion.
   *
   * @param _usage The completion to get the cost for.
   * @returns The cost of the completion.
   * Stubbed since Groq has no cost basis.
   */
  getCost(_usage) {
    return 0;
  }
}

module.exports = GroqProvider;


@@ -130,7 +130,6 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
         ...history,
       ],
     });
-
     const call = safeJsonParse(response, null);
     if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text.
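
The untooled helper above leans on safeJsonParse to decide whether the model replied with a tool call or with plain text. Its implementation is not shown in this commit; a minimal sketch under that assumption:

// Sketch only — the repository's actual safeJsonParse may differ.
function safeJsonParse(jsonString, fallback = null) {
  try {
    return JSON.parse(jsonString);
  } catch {
    return fallback; // not valid JSON, so the caller treats the reply as text
  }
}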


@@ -2,10 +2,20 @@ const OpenAIProvider = require("./openai.js");
 const AnthropicProvider = require("./anthropic.js");
 const LMStudioProvider = require("./lmstudio.js");
 const OllamaProvider = require("./ollama.js");
+const GroqProvider = require("./groq.js");
+const TogetherAIProvider = require("./togetherai.js");
+const AzureOpenAiProvider = require("./azure.js");
+const KoboldCPPProvider = require("./koboldcpp.js");
+const LocalAIProvider = require("./localai.js");
 
 module.exports = {
   OpenAIProvider,
   AnthropicProvider,
   LMStudioProvider,
   OllamaProvider,
+  GroqProvider,
+  TogetherAIProvider,
+  AzureOpenAiProvider,
+  KoboldCPPProvider,
+  LocalAIProvider,
 };


@@ -0,0 +1,113 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");

/**
 * The provider for the KoboldCPP API.
 */
class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) {
  model;

  constructor(_config = {}) {
    super();
    const model = process.env.KOBOLD_CPP_MODEL_PREF ?? null;
    const client = new OpenAI({
      baseURL: process.env.KOBOLD_CPP_BASE_PATH?.replace(/\/+$/, ""),
      apiKey: null,
      maxRetries: 3,
    });
    this._client = client;
    this.model = model;
    this.verbose = true;
  }

  get client() {
    return this._client;
  }

  async #handleFunctionCallChat({ messages = [] }) {
    return await this.client.chat.completions
      .create({
        model: this.model,
        temperature: 0,
        messages,
      })
      .then((result) => {
        if (!result.hasOwnProperty("choices"))
          throw new Error("KoboldCPP chat: No results!");
        if (result.choices.length === 0)
          throw new Error("KoboldCPP chat: No results length!");
        return result.choices[0].message.content;
      })
      .catch((_) => {
        return null;
      });
  }

  /**
   * Create a completion based on the received messages.
   *
   * @param messages A list of messages to send to the API.
   * @param functions
   * @returns The completion.
   */
  async complete(messages, functions = null) {
    try {
      let completion;
      if (functions.length > 0) {
        const { toolCall, text } = await this.functionCall(
          messages,
          functions,
          this.#handleFunctionCallChat.bind(this)
        );

        if (toolCall !== null) {
          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
          return {
            result: null,
            functionCall: {
              name: toolCall.name,
              arguments: toolCall.arguments,
            },
            cost: 0,
          };
        }
        completion = { content: text };
      }

      if (!completion?.content) {
        this.providerLog(
          "Will assume chat completion without tool call inputs."
        );
        const response = await this.client.chat.completions.create({
          model: this.model,
          messages: this.cleanMsgs(messages),
        });
        completion = response.choices[0].message;
      }

      return {
        result: completion.content,
        cost: 0,
      };
    } catch (error) {
      throw error;
    }
  }

  /**
   * Get the cost of the completion.
   *
   * @param _usage The completion to get the cost for.
   * @returns The cost of the completion.
   * Stubbed since KoboldCPP has no cost basis.
   */
  getCost(_usage) {
    return 0;
  }
}

module.exports = KoboldCPPProvider;
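
All of the UnTooled providers compose behavior via InheritMultiple([Provider, UnTooled]). Its implementation lives in helpers/classes.js and is not part of this diff; a minimal sketch of how such a helper could combine classes (an assumption, not the repository's actual code):

// Simplified mixin combiner: copies instance methods from each base prototype
// onto one synthetic class. Static members are intentionally ignored here.
function InheritMultiple(bases = []) {
  class Combined {}
  for (const Base of bases) {
    for (const name of Object.getOwnPropertyNames(Base.prototype)) {
      if (name === "constructor") continue;
      Object.defineProperty(
        Combined.prototype,
        name,
        Object.getOwnPropertyDescriptor(Base.prototype, name)
      );
    }
  }
  return Combined;
}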


@@ -16,8 +16,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
       baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance
       apiKey: null,
       maxRetries: 3,
-      model,
     });
     this._client = client;
     this.model = model;
     this.verbose = true;


@@ -0,0 +1,114 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");

/**
 * The provider for the LocalAI API.
 */
class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
  model;

  constructor(config = {}) {
    const { model = null } = config;
    super();
    const client = new OpenAI({
      baseURL: process.env.LOCAL_AI_BASE_PATH,
      apiKey: process.env.LOCAL_AI_API_KEY ?? null,
      maxRetries: 3,
    });
    this._client = client;
    this.model = model;
    this.verbose = true;
  }

  get client() {
    return this._client;
  }

  async #handleFunctionCallChat({ messages = [] }) {
    return await this.client.chat.completions
      .create({
        model: this.model,
        temperature: 0,
        messages,
      })
      .then((result) => {
        if (!result.hasOwnProperty("choices"))
          throw new Error("LocalAI chat: No results!");
        if (result.choices.length === 0)
          throw new Error("LocalAI chat: No results length!");
        return result.choices[0].message.content;
      })
      .catch((_) => {
        return null;
      });
  }

  /**
   * Create a completion based on the received messages.
   *
   * @param messages A list of messages to send to the API.
   * @param functions
   * @returns The completion.
   */
  async complete(messages, functions = null) {
    try {
      let completion;
      if (functions.length > 0) {
        const { toolCall, text } = await this.functionCall(
          messages,
          functions,
          this.#handleFunctionCallChat.bind(this)
        );

        if (toolCall !== null) {
          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
          return {
            result: null,
            functionCall: {
              name: toolCall.name,
              arguments: toolCall.arguments,
            },
            cost: 0,
          };
        }
        completion = { content: text };
      }

      if (!completion?.content) {
        this.providerLog(
          "Will assume chat completion without tool call inputs."
        );
        const response = await this.client.chat.completions.create({
          model: this.model,
          messages: this.cleanMsgs(messages),
        });
        completion = response.choices[0].message;
      }

      return { result: completion.content, cost: 0 };
    } catch (error) {
      throw error;
    }
  }

  /**
   * Get the cost of the completion.
   *
   * @param _usage The completion to get the cost for.
   * @returns The cost of the completion.
   * Stubbed since LocalAI has no cost basis.
   */
  getCost(_usage) {
    return 0;
  }
}

module.exports = LocalAiProvider;


@@ -0,0 +1,113 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");

/**
 * The provider for the TogetherAI API.
 */
class TogetherAIProvider extends InheritMultiple([Provider, UnTooled]) {
  model;

  constructor(config = {}) {
    const { model = "mistralai/Mistral-7B-Instruct-v0.1" } = config;
    super();
    const client = new OpenAI({
      baseURL: "https://api.together.xyz/v1",
      apiKey: process.env.TOGETHER_AI_API_KEY,
      maxRetries: 3,
    });
    this._client = client;
    this.model = model;
    this.verbose = true;
  }

  get client() {
    return this._client;
  }

  async #handleFunctionCallChat({ messages = [] }) {
    return await this.client.chat.completions
      .create({
        model: this.model,
        temperature: 0,
        messages,
      })
      .then((result) => {
        if (!result.hasOwnProperty("choices"))
          throw new Error("TogetherAI chat: No results!");
        if (result.choices.length === 0)
          throw new Error("TogetherAI chat: No results length!");
        return result.choices[0].message.content;
      })
      .catch((_) => {
        return null;
      });
  }

  /**
   * Create a completion based on the received messages.
   *
   * @param messages A list of messages to send to the API.
   * @param functions
   * @returns The completion.
   */
  async complete(messages, functions = null) {
    try {
      let completion;
      if (functions.length > 0) {
        const { toolCall, text } = await this.functionCall(
          messages,
          functions,
          this.#handleFunctionCallChat.bind(this)
        );

        if (toolCall !== null) {
          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
          return {
            result: null,
            functionCall: {
              name: toolCall.name,
              arguments: toolCall.arguments,
            },
            cost: 0,
          };
        }
        completion = { content: text };
      }

      if (!completion?.content) {
        this.providerLog(
          "Will assume chat completion without tool call inputs."
        );
        const response = await this.client.chat.completions.create({
          model: this.model,
          messages: this.cleanMsgs(messages),
        });
        completion = response.choices[0].message;
      }

      return {
        result: completion.content,
        cost: 0,
      };
    } catch (error) {
      throw error;
    }
  }

  /**
   * Get the cost of the completion.
   *
   * @param _usage The completion to get the cost for.
   * @returns The cost of the completion.
   * Stubbed since TogetherAI has no cost basis.
   */
  getCost(_usage) {
    return 0;
  }
}

module.exports = TogetherAIProvider;


@@ -85,6 +85,36 @@ class AgentHandler {
         if (!process.env.OLLAMA_BASE_PATH)
           throw new Error("Ollama base path must be provided to use agents.");
         break;
+      case "groq":
+        if (!process.env.GROQ_API_KEY)
+          throw new Error("Groq API key must be provided to use agents.");
+        break;
+      case "togetherai":
+        if (!process.env.TOGETHER_AI_API_KEY)
+          throw new Error("TogetherAI API key must be provided to use agents.");
+        break;
+      case "azure":
+        if (!process.env.AZURE_OPENAI_ENDPOINT || !process.env.AZURE_OPENAI_KEY)
+          throw new Error(
+            "Azure OpenAI API endpoint and key must be provided to use agents."
+          );
+        break;
+      case "koboldcpp":
+        if (!process.env.KOBOLD_CPP_BASE_PATH)
+          throw new Error(
+            "KoboldCPP must have a valid base path to use for the api."
+          );
+        break;
+      case "localai":
+        if (!process.env.LOCAL_AI_BASE_PATH)
+          throw new Error(
+            "LocalAI must have a valid base path to use for the api."
+          );
+        break;
+      case "gemini":
+        if (!process.env.GEMINI_API_KEY)
+          throw new Error("Gemini API key must be provided to use agents.");
+        break;
       default:
         throw new Error("No provider found to power agent cluster.");
     }
@@ -100,6 +130,18 @@
         return "server-default";
       case "ollama":
         return "llama3:latest";
+      case "groq":
+        return "llama3-70b-8192";
+      case "togetherai":
+        return "mistralai/Mixtral-8x7B-Instruct-v0.1";
+      case "azure":
+        return "gpt-3.5-turbo";
+      case "koboldcpp":
+        return null;
+      case "gemini":
+        return "gemini-pro";
+      case "localai":
+        return null;
       default:
         return "unknown";
     }
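
For reference, the environment variables checked above, as a hypothetical .env excerpt (values are placeholders; the local base paths are only examples of typical defaults, not values this commit prescribes):

GROQ_API_KEY=gsk-xxxxxxxx
TOGETHER_AI_API_KEY=xxxxxxxx
AZURE_OPENAI_ENDPOINT=https://my-deployment.openai.azure.com
AZURE_OPENAI_KEY=xxxxxxxx
KOBOLD_CPP_BASE_PATH=http://127.0.0.1:5001/v1
LOCAL_AI_BASE_PATH=http://localhost:8080/v1
GEMINI_API_KEY=xxxxxxxx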


@@ -178,7 +178,7 @@ async function getKoboldCPPModels(basePath = null) {
   try {
     const { OpenAI: OpenAIApi } = require("openai");
     const openai = new OpenAIApi({
-      baseURL: basePath || process.env.LMSTUDIO_BASE_PATH,
+      baseURL: basePath || process.env.KOBOLD_CPP_BASE_PATH,
       apiKey: null,
     });
     const models = await openai.models