mirror of
https://github.com/Mintplex-Labs/anything-llm.git
synced 2024-11-14 18:40:11 +01:00
Agent support for LLMs with no function calling (#1295)
* add LMStudio agent support (generic) support "work" with non-tool callable LLMs, highly dependent on system specs * add comments * enable few-shot prompting per function for OSS models * Add Agent support for Ollama models * azure, groq, koboldcpp agent support complete + WIP togetherai * WIP gemini agent support * WIP gemini blocked and will not fix for now * azure fix * merge fix * add localai agent support * azure untooled agent support * merge fix * refactor implementation of several agent provideers * update bad merge comment --------- Co-authored-by: timothycarambat <rambat1010@gmail.com>
This commit is contained in:
parent
b2b41db110
commit
8422f92542
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
@ -28,6 +28,7 @@
|
|||||||
"openrouter",
|
"openrouter",
|
||||||
"Qdrant",
|
"Qdrant",
|
||||||
"Serper",
|
"Serper",
|
||||||
|
"togetherai",
|
||||||
"vectordbs",
|
"vectordbs",
|
||||||
"Weaviate",
|
"Weaviate",
|
||||||
"Zilliz"
|
"Zilliz"
|
||||||
|
@ -5,8 +5,25 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference";
|
|||||||
import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react";
|
import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react";
|
||||||
import AgentModelSelection from "../AgentModelSelection";
|
import AgentModelSelection from "../AgentModelSelection";
|
||||||
|
|
||||||
const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"];
|
const ENABLED_PROVIDERS = [
|
||||||
const WARN_PERFORMANCE = ["lmstudio", "ollama"];
|
"openai",
|
||||||
|
"anthropic",
|
||||||
|
"lmstudio",
|
||||||
|
"ollama",
|
||||||
|
"localai",
|
||||||
|
"groq",
|
||||||
|
"azure",
|
||||||
|
"koboldcpp",
|
||||||
|
"togetherai",
|
||||||
|
];
|
||||||
|
const WARN_PERFORMANCE = [
|
||||||
|
"lmstudio",
|
||||||
|
"groq",
|
||||||
|
"azure",
|
||||||
|
"koboldcpp",
|
||||||
|
"ollama",
|
||||||
|
"localai",
|
||||||
|
];
|
||||||
|
|
||||||
const LLM_DEFAULT = {
|
const LLM_DEFAULT = {
|
||||||
name: "Please make a selection",
|
name: "Please make a selection",
|
||||||
|
@ -743,6 +743,16 @@ ${this.getHistory({ to: route.to })
|
|||||||
return new Providers.LMStudioProvider({});
|
return new Providers.LMStudioProvider({});
|
||||||
case "ollama":
|
case "ollama":
|
||||||
return new Providers.OllamaProvider({ model: config.model });
|
return new Providers.OllamaProvider({ model: config.model });
|
||||||
|
case "groq":
|
||||||
|
return new Providers.GroqProvider({ model: config.model });
|
||||||
|
case "togetherai":
|
||||||
|
return new Providers.TogetherAIProvider({ model: config.model });
|
||||||
|
case "azure":
|
||||||
|
return new Providers.AzureOpenAiProvider({ model: config.model });
|
||||||
|
case "koboldcpp":
|
||||||
|
return new Providers.KoboldCPPProvider({});
|
||||||
|
case "localai":
|
||||||
|
return new Providers.LocalAIProvider({ model: config.model });
|
||||||
|
|
||||||
default:
|
default:
|
||||||
throw new Error(
|
throw new Error(
|
||||||
|
@ -58,6 +58,9 @@ class Provider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For some providers we may want to override the system prompt to be more verbose.
|
||||||
|
// Currently we only do this for lmstudio, but we probably will want to expand this even more
|
||||||
|
// to any Untooled LLM.
|
||||||
static systemPrompt(provider = null) {
|
static systemPrompt(provider = null) {
|
||||||
switch (provider) {
|
switch (provider) {
|
||||||
case "lmstudio":
|
case "lmstudio":
|
||||||
|
105
server/utils/agents/aibitat/providers/azure.js
Normal file
105
server/utils/agents/aibitat/providers/azure.js
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
|
||||||
|
const Provider = require("./ai-provider.js");
|
||||||
|
const InheritMultiple = require("./helpers/classes.js");
|
||||||
|
const UnTooled = require("./helpers/untooled.js");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The provider for the Azure OpenAI API.
|
||||||
|
*/
|
||||||
|
class AzureOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||||
|
model;
|
||||||
|
|
||||||
|
constructor(_config = {}) {
|
||||||
|
super();
|
||||||
|
const client = new OpenAIClient(
|
||||||
|
process.env.AZURE_OPENAI_ENDPOINT,
|
||||||
|
new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
|
||||||
|
);
|
||||||
|
this._client = client;
|
||||||
|
this.model = process.env.OPEN_MODEL_PREF ?? "gpt-3.5-turbo";
|
||||||
|
this.verbose = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
get client() {
|
||||||
|
return this._client;
|
||||||
|
}
|
||||||
|
|
||||||
|
async #handleFunctionCallChat({ messages = [] }) {
|
||||||
|
return await this.client
|
||||||
|
.getChatCompletions(this.model, messages, {
|
||||||
|
temperature: 0,
|
||||||
|
})
|
||||||
|
.then((result) => {
|
||||||
|
if (!result.hasOwnProperty("choices"))
|
||||||
|
throw new Error("Azure OpenAI chat: No results!");
|
||||||
|
if (result.choices.length === 0)
|
||||||
|
throw new Error("Azure OpenAI chat: No results length!");
|
||||||
|
return result.choices[0].message.content;
|
||||||
|
})
|
||||||
|
.catch((_) => {
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a completion based on the received messages.
|
||||||
|
*
|
||||||
|
* @param messages A list of messages to send to the API.
|
||||||
|
* @param functions
|
||||||
|
* @returns The completion.
|
||||||
|
*/
|
||||||
|
async complete(messages, functions = null) {
|
||||||
|
try {
|
||||||
|
let completion;
|
||||||
|
if (functions.length > 0) {
|
||||||
|
const { toolCall, text } = await this.functionCall(
|
||||||
|
messages,
|
||||||
|
functions,
|
||||||
|
this.#handleFunctionCallChat.bind(this)
|
||||||
|
);
|
||||||
|
if (toolCall !== null) {
|
||||||
|
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
|
||||||
|
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
|
||||||
|
return {
|
||||||
|
result: null,
|
||||||
|
functionCall: {
|
||||||
|
name: toolCall.name,
|
||||||
|
arguments: toolCall.arguments,
|
||||||
|
},
|
||||||
|
cost: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
completion = { content: text };
|
||||||
|
}
|
||||||
|
if (!completion?.content) {
|
||||||
|
this.providerLog(
|
||||||
|
"Will assume chat completion without tool call inputs."
|
||||||
|
);
|
||||||
|
const response = await this.client.getChatCompletions(
|
||||||
|
this.model,
|
||||||
|
this.cleanMsgs(messages),
|
||||||
|
{
|
||||||
|
temperature: 0.7,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
completion = response.choices[0].message;
|
||||||
|
}
|
||||||
|
return { result: completion.content, cost: 0 };
|
||||||
|
} catch (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the cost of the completion.
|
||||||
|
* Stubbed since Azure OpenAI has no public cost basis.
|
||||||
|
*
|
||||||
|
* @param _usage The completion to get the cost for.
|
||||||
|
* @returns The cost of the completion.
|
||||||
|
*/
|
||||||
|
getCost(_usage) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = AzureOpenAiProvider;
|
110
server/utils/agents/aibitat/providers/groq.js
Normal file
110
server/utils/agents/aibitat/providers/groq.js
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
const OpenAI = require("openai");
|
||||||
|
const Provider = require("./ai-provider.js");
|
||||||
|
const { RetryError } = require("../error.js");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The provider for the Groq provider.
|
||||||
|
*/
|
||||||
|
class GroqProvider extends Provider {
|
||||||
|
model;
|
||||||
|
|
||||||
|
constructor(config = {}) {
|
||||||
|
const { model = "llama3-8b-8192" } = config;
|
||||||
|
const client = new OpenAI({
|
||||||
|
baseURL: "https://api.groq.com/openai/v1",
|
||||||
|
apiKey: process.env.GROQ_API_KEY,
|
||||||
|
maxRetries: 3,
|
||||||
|
});
|
||||||
|
super(client);
|
||||||
|
this.model = model;
|
||||||
|
this.verbose = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a completion based on the received messages.
|
||||||
|
*
|
||||||
|
* @param messages A list of messages to send to the API.
|
||||||
|
* @param functions
|
||||||
|
* @returns The completion.
|
||||||
|
*/
|
||||||
|
async complete(messages, functions = null) {
|
||||||
|
try {
|
||||||
|
const response = await this.client.chat.completions.create({
|
||||||
|
model: this.model,
|
||||||
|
// stream: true,
|
||||||
|
messages,
|
||||||
|
...(Array.isArray(functions) && functions?.length > 0
|
||||||
|
? { functions }
|
||||||
|
: {}),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Right now, we only support one completion,
|
||||||
|
// so we just take the first one in the list
|
||||||
|
const completion = response.choices[0].message;
|
||||||
|
const cost = this.getCost(response.usage);
|
||||||
|
// treat function calls
|
||||||
|
if (completion.function_call) {
|
||||||
|
let functionArgs = {};
|
||||||
|
try {
|
||||||
|
functionArgs = JSON.parse(completion.function_call.arguments);
|
||||||
|
} catch (error) {
|
||||||
|
// call the complete function again in case it gets a json error
|
||||||
|
return this.complete(
|
||||||
|
[
|
||||||
|
...messages,
|
||||||
|
{
|
||||||
|
role: "function",
|
||||||
|
name: completion.function_call.name,
|
||||||
|
function_call: completion.function_call,
|
||||||
|
content: error?.message,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functions
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// console.log(completion, { functionArgs })
|
||||||
|
return {
|
||||||
|
result: null,
|
||||||
|
functionCall: {
|
||||||
|
name: completion.function_call.name,
|
||||||
|
arguments: functionArgs,
|
||||||
|
},
|
||||||
|
cost,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
result: completion.content,
|
||||||
|
cost,
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
// If invalid Auth error we need to abort because no amount of waiting
|
||||||
|
// will make auth better.
|
||||||
|
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||||
|
|
||||||
|
if (
|
||||||
|
error instanceof OpenAI.RateLimitError ||
|
||||||
|
error instanceof OpenAI.InternalServerError ||
|
||||||
|
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
|
||||||
|
) {
|
||||||
|
throw new RetryError(error.message);
|
||||||
|
}
|
||||||
|
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the cost of the completion.
|
||||||
|
*
|
||||||
|
* @param _usage The completion to get the cost for.
|
||||||
|
* @returns The cost of the completion.
|
||||||
|
* Stubbed since Groq has no cost basis.
|
||||||
|
*/
|
||||||
|
getCost(_usage) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = GroqProvider;
|
@ -130,7 +130,6 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
|
|||||||
...history,
|
...history,
|
||||||
],
|
],
|
||||||
});
|
});
|
||||||
|
|
||||||
const call = safeJsonParse(response, null);
|
const call = safeJsonParse(response, null);
|
||||||
if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text.
|
if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text.
|
||||||
|
|
||||||
|
@ -2,10 +2,20 @@ const OpenAIProvider = require("./openai.js");
|
|||||||
const AnthropicProvider = require("./anthropic.js");
|
const AnthropicProvider = require("./anthropic.js");
|
||||||
const LMStudioProvider = require("./lmstudio.js");
|
const LMStudioProvider = require("./lmstudio.js");
|
||||||
const OllamaProvider = require("./ollama.js");
|
const OllamaProvider = require("./ollama.js");
|
||||||
|
const GroqProvider = require("./groq.js");
|
||||||
|
const TogetherAIProvider = require("./togetherai.js");
|
||||||
|
const AzureOpenAiProvider = require("./azure.js");
|
||||||
|
const KoboldCPPProvider = require("./koboldcpp.js");
|
||||||
|
const LocalAIProvider = require("./localai.js");
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
OpenAIProvider,
|
OpenAIProvider,
|
||||||
AnthropicProvider,
|
AnthropicProvider,
|
||||||
LMStudioProvider,
|
LMStudioProvider,
|
||||||
OllamaProvider,
|
OllamaProvider,
|
||||||
|
GroqProvider,
|
||||||
|
TogetherAIProvider,
|
||||||
|
AzureOpenAiProvider,
|
||||||
|
KoboldCPPProvider,
|
||||||
|
LocalAIProvider,
|
||||||
};
|
};
|
||||||
|
113
server/utils/agents/aibitat/providers/koboldcpp.js
Normal file
113
server/utils/agents/aibitat/providers/koboldcpp.js
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
const OpenAI = require("openai");
|
||||||
|
const Provider = require("./ai-provider.js");
|
||||||
|
const InheritMultiple = require("./helpers/classes.js");
|
||||||
|
const UnTooled = require("./helpers/untooled.js");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The provider for the KoboldCPP provider.
|
||||||
|
*/
|
||||||
|
class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||||
|
model;
|
||||||
|
|
||||||
|
constructor(_config = {}) {
|
||||||
|
super();
|
||||||
|
const model = process.env.KOBOLD_CPP_MODEL_PREF ?? null;
|
||||||
|
const client = new OpenAI({
|
||||||
|
baseURL: process.env.KOBOLD_CPP_BASE_PATH?.replace(/\/+$/, ""),
|
||||||
|
apiKey: null,
|
||||||
|
maxRetries: 3,
|
||||||
|
});
|
||||||
|
|
||||||
|
this._client = client;
|
||||||
|
this.model = model;
|
||||||
|
this.verbose = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
get client() {
|
||||||
|
return this._client;
|
||||||
|
}
|
||||||
|
|
||||||
|
async #handleFunctionCallChat({ messages = [] }) {
|
||||||
|
return await this.client.chat.completions
|
||||||
|
.create({
|
||||||
|
model: this.model,
|
||||||
|
temperature: 0,
|
||||||
|
messages,
|
||||||
|
})
|
||||||
|
.then((result) => {
|
||||||
|
if (!result.hasOwnProperty("choices"))
|
||||||
|
throw new Error("KoboldCPP chat: No results!");
|
||||||
|
if (result.choices.length === 0)
|
||||||
|
throw new Error("KoboldCPP chat: No results length!");
|
||||||
|
return result.choices[0].message.content;
|
||||||
|
})
|
||||||
|
.catch((_) => {
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a completion based on the received messages.
|
||||||
|
*
|
||||||
|
* @param messages A list of messages to send to the API.
|
||||||
|
* @param functions
|
||||||
|
* @returns The completion.
|
||||||
|
*/
|
||||||
|
async complete(messages, functions = null) {
|
||||||
|
try {
|
||||||
|
let completion;
|
||||||
|
if (functions.length > 0) {
|
||||||
|
const { toolCall, text } = await this.functionCall(
|
||||||
|
messages,
|
||||||
|
functions,
|
||||||
|
this.#handleFunctionCallChat.bind(this)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (toolCall !== null) {
|
||||||
|
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
|
||||||
|
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
|
||||||
|
return {
|
||||||
|
result: null,
|
||||||
|
functionCall: {
|
||||||
|
name: toolCall.name,
|
||||||
|
arguments: toolCall.arguments,
|
||||||
|
},
|
||||||
|
cost: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
completion = { content: text };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!completion?.content) {
|
||||||
|
this.providerLog(
|
||||||
|
"Will assume chat completion without tool call inputs."
|
||||||
|
);
|
||||||
|
const response = await this.client.chat.completions.create({
|
||||||
|
model: this.model,
|
||||||
|
messages: this.cleanMsgs(messages),
|
||||||
|
});
|
||||||
|
completion = response.choices[0].message;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
result: completion.content,
|
||||||
|
cost: 0,
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the cost of the completion.
|
||||||
|
*
|
||||||
|
* @param _usage The completion to get the cost for.
|
||||||
|
* @returns The cost of the completion.
|
||||||
|
* Stubbed since KoboldCPP has no cost basis.
|
||||||
|
*/
|
||||||
|
getCost(_usage) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = KoboldCPPProvider;
|
@ -16,8 +16,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance
|
baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance
|
||||||
apiKey: null,
|
apiKey: null,
|
||||||
maxRetries: 3,
|
maxRetries: 3,
|
||||||
model,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
this._client = client;
|
this._client = client;
|
||||||
this.model = model;
|
this.model = model;
|
||||||
this.verbose = true;
|
this.verbose = true;
|
||||||
|
114
server/utils/agents/aibitat/providers/localai.js
Normal file
114
server/utils/agents/aibitat/providers/localai.js
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
const OpenAI = require("openai");
|
||||||
|
const Provider = require("./ai-provider.js");
|
||||||
|
const InheritMultiple = require("./helpers/classes.js");
|
||||||
|
const UnTooled = require("./helpers/untooled.js");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The provider for the LocalAI provider.
|
||||||
|
*/
|
||||||
|
class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||||
|
model;
|
||||||
|
|
||||||
|
constructor(config = {}) {
|
||||||
|
const { model = null } = config;
|
||||||
|
super();
|
||||||
|
const client = new OpenAI({
|
||||||
|
baseURL: process.env.LOCAL_AI_BASE_PATH,
|
||||||
|
apiKey: process.env.LOCAL_AI_API_KEY ?? null,
|
||||||
|
maxRetries: 3,
|
||||||
|
});
|
||||||
|
|
||||||
|
this._client = client;
|
||||||
|
this.model = model;
|
||||||
|
this.verbose = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
get client() {
|
||||||
|
return this._client;
|
||||||
|
}
|
||||||
|
|
||||||
|
async #handleFunctionCallChat({ messages = [] }) {
|
||||||
|
return await this.client.chat.completions
|
||||||
|
.create({
|
||||||
|
model: this.model,
|
||||||
|
temperature: 0,
|
||||||
|
messages,
|
||||||
|
})
|
||||||
|
.then((result) => {
|
||||||
|
if (!result.hasOwnProperty("choices"))
|
||||||
|
throw new Error("LocalAI chat: No results!");
|
||||||
|
|
||||||
|
if (result.choices.length === 0)
|
||||||
|
throw new Error("LocalAI chat: No results length!");
|
||||||
|
|
||||||
|
return result.choices[0].message.content;
|
||||||
|
})
|
||||||
|
.catch((_) => {
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a completion based on the received messages.
|
||||||
|
*
|
||||||
|
* @param messages A list of messages to send to the API.
|
||||||
|
* @param functions
|
||||||
|
* @returns The completion.
|
||||||
|
*/
|
||||||
|
async complete(messages, functions = null) {
|
||||||
|
try {
|
||||||
|
let completion;
|
||||||
|
|
||||||
|
if (functions.length > 0) {
|
||||||
|
const { toolCall, text } = await this.functionCall(
|
||||||
|
messages,
|
||||||
|
functions,
|
||||||
|
this.#handleFunctionCallChat.bind(this)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (toolCall !== null) {
|
||||||
|
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
|
||||||
|
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
|
||||||
|
return {
|
||||||
|
result: null,
|
||||||
|
functionCall: {
|
||||||
|
name: toolCall.name,
|
||||||
|
arguments: toolCall.arguments,
|
||||||
|
},
|
||||||
|
cost: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
completion = { content: text };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!completion?.content) {
|
||||||
|
this.providerLog(
|
||||||
|
"Will assume chat completion without tool call inputs."
|
||||||
|
);
|
||||||
|
const response = await this.client.chat.completions.create({
|
||||||
|
model: this.model,
|
||||||
|
messages: this.cleanMsgs(messages),
|
||||||
|
});
|
||||||
|
completion = response.choices[0].message;
|
||||||
|
}
|
||||||
|
|
||||||
|
return { result: completion.content, cost: 0 };
|
||||||
|
} catch (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the cost of the completion.
|
||||||
|
*
|
||||||
|
* @param _usage The completion to get the cost for.
|
||||||
|
* @returns The cost of the completion.
|
||||||
|
* Stubbed since LocalAI has no cost basis.
|
||||||
|
*/
|
||||||
|
getCost(_usage) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = LocalAiProvider;
|
113
server/utils/agents/aibitat/providers/togetherai.js
Normal file
113
server/utils/agents/aibitat/providers/togetherai.js
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
const OpenAI = require("openai");
|
||||||
|
const Provider = require("./ai-provider.js");
|
||||||
|
const InheritMultiple = require("./helpers/classes.js");
|
||||||
|
const UnTooled = require("./helpers/untooled.js");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The provider for the TogetherAI provider.
|
||||||
|
*/
|
||||||
|
class TogetherAIProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||||
|
model;
|
||||||
|
|
||||||
|
constructor(config = {}) {
|
||||||
|
const { model = "mistralai/Mistral-7B-Instruct-v0.1" } = config;
|
||||||
|
super();
|
||||||
|
const client = new OpenAI({
|
||||||
|
baseURL: "https://api.together.xyz/v1",
|
||||||
|
apiKey: process.env.TOGETHER_AI_API_KEY,
|
||||||
|
maxRetries: 3,
|
||||||
|
});
|
||||||
|
|
||||||
|
this._client = client;
|
||||||
|
this.model = model;
|
||||||
|
this.verbose = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
get client() {
|
||||||
|
return this._client;
|
||||||
|
}
|
||||||
|
|
||||||
|
async #handleFunctionCallChat({ messages = [] }) {
|
||||||
|
return await this.client.chat.completions
|
||||||
|
.create({
|
||||||
|
model: this.model,
|
||||||
|
temperature: 0,
|
||||||
|
messages,
|
||||||
|
})
|
||||||
|
.then((result) => {
|
||||||
|
if (!result.hasOwnProperty("choices"))
|
||||||
|
throw new Error("LMStudio chat: No results!");
|
||||||
|
if (result.choices.length === 0)
|
||||||
|
throw new Error("LMStudio chat: No results length!");
|
||||||
|
return result.choices[0].message.content;
|
||||||
|
})
|
||||||
|
.catch((_) => {
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a completion based on the received messages.
|
||||||
|
*
|
||||||
|
* @param messages A list of messages to send to the API.
|
||||||
|
* @param functions
|
||||||
|
* @returns The completion.
|
||||||
|
*/
|
||||||
|
async complete(messages, functions = null) {
|
||||||
|
try {
|
||||||
|
let completion;
|
||||||
|
if (functions.length > 0) {
|
||||||
|
const { toolCall, text } = await this.functionCall(
|
||||||
|
messages,
|
||||||
|
functions,
|
||||||
|
this.#handleFunctionCallChat.bind(this)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (toolCall !== null) {
|
||||||
|
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
|
||||||
|
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
|
||||||
|
return {
|
||||||
|
result: null,
|
||||||
|
functionCall: {
|
||||||
|
name: toolCall.name,
|
||||||
|
arguments: toolCall.arguments,
|
||||||
|
},
|
||||||
|
cost: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
completion = { content: text };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!completion?.content) {
|
||||||
|
this.providerLog(
|
||||||
|
"Will assume chat completion without tool call inputs."
|
||||||
|
);
|
||||||
|
const response = await this.client.chat.completions.create({
|
||||||
|
model: this.model,
|
||||||
|
messages: this.cleanMsgs(messages),
|
||||||
|
});
|
||||||
|
completion = response.choices[0].message;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
result: completion.content,
|
||||||
|
cost: 0,
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the cost of the completion.
|
||||||
|
*
|
||||||
|
* @param _usage The completion to get the cost for.
|
||||||
|
* @returns The cost of the completion.
|
||||||
|
* Stubbed since LMStudio has no cost basis.
|
||||||
|
*/
|
||||||
|
getCost(_usage) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = TogetherAIProvider;
|
@ -85,6 +85,36 @@ class AgentHandler {
|
|||||||
if (!process.env.OLLAMA_BASE_PATH)
|
if (!process.env.OLLAMA_BASE_PATH)
|
||||||
throw new Error("Ollama base path must be provided to use agents.");
|
throw new Error("Ollama base path must be provided to use agents.");
|
||||||
break;
|
break;
|
||||||
|
case "groq":
|
||||||
|
if (!process.env.GROQ_API_KEY)
|
||||||
|
throw new Error("Groq API key must be provided to use agents.");
|
||||||
|
break;
|
||||||
|
case "togetherai":
|
||||||
|
if (!process.env.TOGETHER_AI_API_KEY)
|
||||||
|
throw new Error("TogetherAI API key must be provided to use agents.");
|
||||||
|
break;
|
||||||
|
case "azure":
|
||||||
|
if (!process.env.AZURE_OPENAI_ENDPOINT || !process.env.AZURE_OPENAI_KEY)
|
||||||
|
throw new Error(
|
||||||
|
"Azure OpenAI API endpoint and key must be provided to use agents."
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case "koboldcpp":
|
||||||
|
if (!process.env.KOBOLD_CPP_BASE_PATH)
|
||||||
|
throw new Error(
|
||||||
|
"KoboldCPP must have a valid base path to use for the api."
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case "localai":
|
||||||
|
if (!process.env.LOCAL_AI_BASE_PATH)
|
||||||
|
throw new Error(
|
||||||
|
"LocalAI must have a valid base path to use for the api."
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case "gemini":
|
||||||
|
if (!process.env.GEMINI_API_KEY)
|
||||||
|
throw new Error("Gemini API key must be provided to use agents.");
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw new Error("No provider found to power agent cluster.");
|
throw new Error("No provider found to power agent cluster.");
|
||||||
}
|
}
|
||||||
@ -100,6 +130,18 @@ class AgentHandler {
|
|||||||
return "server-default";
|
return "server-default";
|
||||||
case "ollama":
|
case "ollama":
|
||||||
return "llama3:latest";
|
return "llama3:latest";
|
||||||
|
case "groq":
|
||||||
|
return "llama3-70b-8192";
|
||||||
|
case "togetherai":
|
||||||
|
return "mistralai/Mixtral-8x7B-Instruct-v0.1";
|
||||||
|
case "azure":
|
||||||
|
return "gpt-3.5-turbo";
|
||||||
|
case "koboldcpp":
|
||||||
|
return null;
|
||||||
|
case "gemini":
|
||||||
|
return "gemini-pro";
|
||||||
|
case "localai":
|
||||||
|
return null;
|
||||||
default:
|
default:
|
||||||
return "unknown";
|
return "unknown";
|
||||||
}
|
}
|
||||||
|
@ -178,7 +178,7 @@ async function getKoboldCPPModels(basePath = null) {
|
|||||||
try {
|
try {
|
||||||
const { OpenAI: OpenAIApi } = require("openai");
|
const { OpenAI: OpenAIApi } = require("openai");
|
||||||
const openai = new OpenAIApi({
|
const openai = new OpenAIApi({
|
||||||
baseURL: basePath || process.env.LMSTUDIO_BASE_PATH,
|
baseURL: basePath || process.env.KOBOLD_CPP_BASE_PATH,
|
||||||
apiKey: null,
|
apiKey: null,
|
||||||
});
|
});
|
||||||
const models = await openai.models
|
const models = await openai.models
|
||||||
|
Loading…
Reference in New Issue
Block a user