mirror of
https://github.com/Mintplex-Labs/anything-llm.git
synced 2024-11-11 01:10:11 +01:00
Agent support for LLMs with no function calling (#1295)
* add LMStudio agent support (generic) support "work" with non-tool callable LLMs, highly dependent on system specs * add comments * enable few-shot prompting per function for OSS models * Add Agent support for Ollama models * azure, groq, koboldcpp agent support complete + WIP togetherai * WIP gemini agent support * WIP gemini blocked and will not fix for now * azure fix * merge fix * add localai agent support * azure untooled agent support * merge fix * refactor implementation of several agent provideers * update bad merge comment --------- Co-authored-by: timothycarambat <rambat1010@gmail.com>
This commit is contained in:
parent
b2b41db110
commit
8422f92542
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
@ -28,6 +28,7 @@
|
||||
"openrouter",
|
||||
"Qdrant",
|
||||
"Serper",
|
||||
"togetherai",
|
||||
"vectordbs",
|
||||
"Weaviate",
|
||||
"Zilliz"
|
||||
|
@ -5,8 +5,25 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference";
|
||||
import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react";
|
||||
import AgentModelSelection from "../AgentModelSelection";
|
||||
|
||||
const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"];
|
||||
const WARN_PERFORMANCE = ["lmstudio", "ollama"];
|
||||
const ENABLED_PROVIDERS = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"lmstudio",
|
||||
"ollama",
|
||||
"localai",
|
||||
"groq",
|
||||
"azure",
|
||||
"koboldcpp",
|
||||
"togetherai",
|
||||
];
|
||||
const WARN_PERFORMANCE = [
|
||||
"lmstudio",
|
||||
"groq",
|
||||
"azure",
|
||||
"koboldcpp",
|
||||
"ollama",
|
||||
"localai",
|
||||
];
|
||||
|
||||
const LLM_DEFAULT = {
|
||||
name: "Please make a selection",
|
||||
|
@ -480,7 +480,7 @@ Read the following conversation.
|
||||
CHAT HISTORY
|
||||
${history.map((c) => `@${c.from}: ${c.content}`).join("\n")}
|
||||
|
||||
Then select the next role from that is going to speak next.
|
||||
Then select the next role from that is going to speak next.
|
||||
Only return the role.
|
||||
`,
|
||||
},
|
||||
@ -522,7 +522,7 @@ Only return the role.
|
||||
? [
|
||||
{
|
||||
role: "user",
|
||||
content: `You are in a whatsapp group. Read the following conversation and then reply.
|
||||
content: `You are in a whatsapp group. Read the following conversation and then reply.
|
||||
Do not add introduction or conclusion to your reply because this will be a continuous conversation. Don't introduce yourself.
|
||||
|
||||
CHAT HISTORY
|
||||
@ -743,6 +743,16 @@ ${this.getHistory({ to: route.to })
|
||||
return new Providers.LMStudioProvider({});
|
||||
case "ollama":
|
||||
return new Providers.OllamaProvider({ model: config.model });
|
||||
case "groq":
|
||||
return new Providers.GroqProvider({ model: config.model });
|
||||
case "togetherai":
|
||||
return new Providers.TogetherAIProvider({ model: config.model });
|
||||
case "azure":
|
||||
return new Providers.AzureOpenAiProvider({ model: config.model });
|
||||
case "koboldcpp":
|
||||
return new Providers.KoboldCPPProvider({});
|
||||
case "localai":
|
||||
return new Providers.LocalAIProvider({ model: config.model });
|
||||
|
||||
default:
|
||||
throw new Error(
|
||||
|
@ -58,6 +58,9 @@ class Provider {
|
||||
}
|
||||
}
|
||||
|
||||
// For some providers we may want to override the system prompt to be more verbose.
|
||||
// Currently we only do this for lmstudio, but we probably will want to expand this even more
|
||||
// to any Untooled LLM.
|
||||
static systemPrompt(provider = null) {
|
||||
switch (provider) {
|
||||
case "lmstudio":
|
||||
|
105
server/utils/agents/aibitat/providers/azure.js
Normal file
105
server/utils/agents/aibitat/providers/azure.js
Normal file
@ -0,0 +1,105 @@
|
||||
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
|
||||
const Provider = require("./ai-provider.js");
|
||||
const InheritMultiple = require("./helpers/classes.js");
|
||||
const UnTooled = require("./helpers/untooled.js");
|
||||
|
||||
/**
|
||||
* The provider for the Azure OpenAI API.
|
||||
*/
|
||||
class AzureOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
model;
|
||||
|
||||
constructor(_config = {}) {
|
||||
super();
|
||||
const client = new OpenAIClient(
|
||||
process.env.AZURE_OPENAI_ENDPOINT,
|
||||
new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
|
||||
);
|
||||
this._client = client;
|
||||
this.model = process.env.OPEN_MODEL_PREF ?? "gpt-3.5-turbo";
|
||||
this.verbose = true;
|
||||
}
|
||||
|
||||
get client() {
|
||||
return this._client;
|
||||
}
|
||||
|
||||
async #handleFunctionCallChat({ messages = [] }) {
|
||||
return await this.client
|
||||
.getChatCompletions(this.model, messages, {
|
||||
temperature: 0,
|
||||
})
|
||||
.then((result) => {
|
||||
if (!result.hasOwnProperty("choices"))
|
||||
throw new Error("Azure OpenAI chat: No results!");
|
||||
if (result.choices.length === 0)
|
||||
throw new Error("Azure OpenAI chat: No results length!");
|
||||
return result.choices[0].message.content;
|
||||
})
|
||||
.catch((_) => {
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a completion based on the received messages.
|
||||
*
|
||||
* @param messages A list of messages to send to the API.
|
||||
* @param functions
|
||||
* @returns The completion.
|
||||
*/
|
||||
async complete(messages, functions = null) {
|
||||
try {
|
||||
let completion;
|
||||
if (functions.length > 0) {
|
||||
const { toolCall, text } = await this.functionCall(
|
||||
messages,
|
||||
functions,
|
||||
this.#handleFunctionCallChat.bind(this)
|
||||
);
|
||||
if (toolCall !== null) {
|
||||
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
|
||||
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
|
||||
return {
|
||||
result: null,
|
||||
functionCall: {
|
||||
name: toolCall.name,
|
||||
arguments: toolCall.arguments,
|
||||
},
|
||||
cost: 0,
|
||||
};
|
||||
}
|
||||
completion = { content: text };
|
||||
}
|
||||
if (!completion?.content) {
|
||||
this.providerLog(
|
||||
"Will assume chat completion without tool call inputs."
|
||||
);
|
||||
const response = await this.client.getChatCompletions(
|
||||
this.model,
|
||||
this.cleanMsgs(messages),
|
||||
{
|
||||
temperature: 0.7,
|
||||
}
|
||||
);
|
||||
completion = response.choices[0].message;
|
||||
}
|
||||
return { result: completion.content, cost: 0 };
|
||||
} catch (error) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the cost of the completion.
|
||||
* Stubbed since Azure OpenAI has no public cost basis.
|
||||
*
|
||||
* @param _usage The completion to get the cost for.
|
||||
* @returns The cost of the completion.
|
||||
*/
|
||||
getCost(_usage) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = AzureOpenAiProvider;
|
110
server/utils/agents/aibitat/providers/groq.js
Normal file
110
server/utils/agents/aibitat/providers/groq.js
Normal file
@ -0,0 +1,110 @@
|
||||
const OpenAI = require("openai");
|
||||
const Provider = require("./ai-provider.js");
|
||||
const { RetryError } = require("../error.js");
|
||||
|
||||
/**
|
||||
* The provider for the Groq provider.
|
||||
*/
|
||||
class GroqProvider extends Provider {
|
||||
model;
|
||||
|
||||
constructor(config = {}) {
|
||||
const { model = "llama3-8b-8192" } = config;
|
||||
const client = new OpenAI({
|
||||
baseURL: "https://api.groq.com/openai/v1",
|
||||
apiKey: process.env.GROQ_API_KEY,
|
||||
maxRetries: 3,
|
||||
});
|
||||
super(client);
|
||||
this.model = model;
|
||||
this.verbose = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a completion based on the received messages.
|
||||
*
|
||||
* @param messages A list of messages to send to the API.
|
||||
* @param functions
|
||||
* @returns The completion.
|
||||
*/
|
||||
async complete(messages, functions = null) {
|
||||
try {
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: this.model,
|
||||
// stream: true,
|
||||
messages,
|
||||
...(Array.isArray(functions) && functions?.length > 0
|
||||
? { functions }
|
||||
: {}),
|
||||
});
|
||||
|
||||
// Right now, we only support one completion,
|
||||
// so we just take the first one in the list
|
||||
const completion = response.choices[0].message;
|
||||
const cost = this.getCost(response.usage);
|
||||
// treat function calls
|
||||
if (completion.function_call) {
|
||||
let functionArgs = {};
|
||||
try {
|
||||
functionArgs = JSON.parse(completion.function_call.arguments);
|
||||
} catch (error) {
|
||||
// call the complete function again in case it gets a json error
|
||||
return this.complete(
|
||||
[
|
||||
...messages,
|
||||
{
|
||||
role: "function",
|
||||
name: completion.function_call.name,
|
||||
function_call: completion.function_call,
|
||||
content: error?.message,
|
||||
},
|
||||
],
|
||||
functions
|
||||
);
|
||||
}
|
||||
|
||||
// console.log(completion, { functionArgs })
|
||||
return {
|
||||
result: null,
|
||||
functionCall: {
|
||||
name: completion.function_call.name,
|
||||
arguments: functionArgs,
|
||||
},
|
||||
cost,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
result: completion.content,
|
||||
cost,
|
||||
};
|
||||
} catch (error) {
|
||||
// If invalid Auth error we need to abort because no amount of waiting
|
||||
// will make auth better.
|
||||
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||
|
||||
if (
|
||||
error instanceof OpenAI.RateLimitError ||
|
||||
error instanceof OpenAI.InternalServerError ||
|
||||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
|
||||
) {
|
||||
throw new RetryError(error.message);
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the cost of the completion.
|
||||
*
|
||||
* @param _usage The completion to get the cost for.
|
||||
* @returns The cost of the completion.
|
||||
* Stubbed since Groq has no cost basis.
|
||||
*/
|
||||
getCost(_usage) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = GroqProvider;
|
@ -110,7 +110,7 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
|
||||
const response = await chatCb({
|
||||
messages: [
|
||||
{
|
||||
content: `You are a program which picks the most optimal function and parameters to call.
|
||||
content: `You are a program which picks the most optimal function and parameters to call.
|
||||
DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY.
|
||||
When a function is selection, respond in JSON with no additional text.
|
||||
When there is no relevant function to call - return with a regular chat text response.
|
||||
@ -130,7 +130,6 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
|
||||
...history,
|
||||
],
|
||||
});
|
||||
|
||||
const call = safeJsonParse(response, null);
|
||||
if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text.
|
||||
|
||||
|
@ -2,10 +2,20 @@ const OpenAIProvider = require("./openai.js");
|
||||
const AnthropicProvider = require("./anthropic.js");
|
||||
const LMStudioProvider = require("./lmstudio.js");
|
||||
const OllamaProvider = require("./ollama.js");
|
||||
const GroqProvider = require("./groq.js");
|
||||
const TogetherAIProvider = require("./togetherai.js");
|
||||
const AzureOpenAiProvider = require("./azure.js");
|
||||
const KoboldCPPProvider = require("./koboldcpp.js");
|
||||
const LocalAIProvider = require("./localai.js");
|
||||
|
||||
module.exports = {
|
||||
OpenAIProvider,
|
||||
AnthropicProvider,
|
||||
LMStudioProvider,
|
||||
OllamaProvider,
|
||||
GroqProvider,
|
||||
TogetherAIProvider,
|
||||
AzureOpenAiProvider,
|
||||
KoboldCPPProvider,
|
||||
LocalAIProvider,
|
||||
};
|
||||
|
113
server/utils/agents/aibitat/providers/koboldcpp.js
Normal file
113
server/utils/agents/aibitat/providers/koboldcpp.js
Normal file
@ -0,0 +1,113 @@
|
||||
const OpenAI = require("openai");
|
||||
const Provider = require("./ai-provider.js");
|
||||
const InheritMultiple = require("./helpers/classes.js");
|
||||
const UnTooled = require("./helpers/untooled.js");
|
||||
|
||||
/**
|
||||
* The provider for the KoboldCPP provider.
|
||||
*/
|
||||
class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
model;
|
||||
|
||||
constructor(_config = {}) {
|
||||
super();
|
||||
const model = process.env.KOBOLD_CPP_MODEL_PREF ?? null;
|
||||
const client = new OpenAI({
|
||||
baseURL: process.env.KOBOLD_CPP_BASE_PATH?.replace(/\/+$/, ""),
|
||||
apiKey: null,
|
||||
maxRetries: 3,
|
||||
});
|
||||
|
||||
this._client = client;
|
||||
this.model = model;
|
||||
this.verbose = true;
|
||||
}
|
||||
|
||||
get client() {
|
||||
return this._client;
|
||||
}
|
||||
|
||||
async #handleFunctionCallChat({ messages = [] }) {
|
||||
return await this.client.chat.completions
|
||||
.create({
|
||||
model: this.model,
|
||||
temperature: 0,
|
||||
messages,
|
||||
})
|
||||
.then((result) => {
|
||||
if (!result.hasOwnProperty("choices"))
|
||||
throw new Error("KoboldCPP chat: No results!");
|
||||
if (result.choices.length === 0)
|
||||
throw new Error("KoboldCPP chat: No results length!");
|
||||
return result.choices[0].message.content;
|
||||
})
|
||||
.catch((_) => {
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a completion based on the received messages.
|
||||
*
|
||||
* @param messages A list of messages to send to the API.
|
||||
* @param functions
|
||||
* @returns The completion.
|
||||
*/
|
||||
async complete(messages, functions = null) {
|
||||
try {
|
||||
let completion;
|
||||
if (functions.length > 0) {
|
||||
const { toolCall, text } = await this.functionCall(
|
||||
messages,
|
||||
functions,
|
||||
this.#handleFunctionCallChat.bind(this)
|
||||
);
|
||||
|
||||
if (toolCall !== null) {
|
||||
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
|
||||
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
|
||||
return {
|
||||
result: null,
|
||||
functionCall: {
|
||||
name: toolCall.name,
|
||||
arguments: toolCall.arguments,
|
||||
},
|
||||
cost: 0,
|
||||
};
|
||||
}
|
||||
completion = { content: text };
|
||||
}
|
||||
|
||||
if (!completion?.content) {
|
||||
this.providerLog(
|
||||
"Will assume chat completion without tool call inputs."
|
||||
);
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: this.model,
|
||||
messages: this.cleanMsgs(messages),
|
||||
});
|
||||
completion = response.choices[0].message;
|
||||
}
|
||||
|
||||
return {
|
||||
result: completion.content,
|
||||
cost: 0,
|
||||
};
|
||||
} catch (error) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the cost of the completion.
|
||||
*
|
||||
* @param _usage The completion to get the cost for.
|
||||
* @returns The cost of the completion.
|
||||
* Stubbed since KoboldCPP has no cost basis.
|
||||
*/
|
||||
getCost(_usage) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = KoboldCPPProvider;
|
@ -16,8 +16,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance
|
||||
apiKey: null,
|
||||
maxRetries: 3,
|
||||
model,
|
||||
});
|
||||
|
||||
this._client = client;
|
||||
this.model = model;
|
||||
this.verbose = true;
|
||||
|
114
server/utils/agents/aibitat/providers/localai.js
Normal file
114
server/utils/agents/aibitat/providers/localai.js
Normal file
@ -0,0 +1,114 @@
|
||||
const OpenAI = require("openai");
|
||||
const Provider = require("./ai-provider.js");
|
||||
const InheritMultiple = require("./helpers/classes.js");
|
||||
const UnTooled = require("./helpers/untooled.js");
|
||||
|
||||
/**
|
||||
* The provider for the LocalAI provider.
|
||||
*/
|
||||
class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
model;
|
||||
|
||||
constructor(config = {}) {
|
||||
const { model = null } = config;
|
||||
super();
|
||||
const client = new OpenAI({
|
||||
baseURL: process.env.LOCAL_AI_BASE_PATH,
|
||||
apiKey: process.env.LOCAL_AI_API_KEY ?? null,
|
||||
maxRetries: 3,
|
||||
});
|
||||
|
||||
this._client = client;
|
||||
this.model = model;
|
||||
this.verbose = true;
|
||||
}
|
||||
|
||||
get client() {
|
||||
return this._client;
|
||||
}
|
||||
|
||||
async #handleFunctionCallChat({ messages = [] }) {
|
||||
return await this.client.chat.completions
|
||||
.create({
|
||||
model: this.model,
|
||||
temperature: 0,
|
||||
messages,
|
||||
})
|
||||
.then((result) => {
|
||||
if (!result.hasOwnProperty("choices"))
|
||||
throw new Error("LocalAI chat: No results!");
|
||||
|
||||
if (result.choices.length === 0)
|
||||
throw new Error("LocalAI chat: No results length!");
|
||||
|
||||
return result.choices[0].message.content;
|
||||
})
|
||||
.catch((_) => {
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a completion based on the received messages.
|
||||
*
|
||||
* @param messages A list of messages to send to the API.
|
||||
* @param functions
|
||||
* @returns The completion.
|
||||
*/
|
||||
async complete(messages, functions = null) {
|
||||
try {
|
||||
let completion;
|
||||
|
||||
if (functions.length > 0) {
|
||||
const { toolCall, text } = await this.functionCall(
|
||||
messages,
|
||||
functions,
|
||||
this.#handleFunctionCallChat.bind(this)
|
||||
);
|
||||
|
||||
if (toolCall !== null) {
|
||||
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
|
||||
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
|
||||
return {
|
||||
result: null,
|
||||
functionCall: {
|
||||
name: toolCall.name,
|
||||
arguments: toolCall.arguments,
|
||||
},
|
||||
cost: 0,
|
||||
};
|
||||
}
|
||||
|
||||
completion = { content: text };
|
||||
}
|
||||
|
||||
if (!completion?.content) {
|
||||
this.providerLog(
|
||||
"Will assume chat completion without tool call inputs."
|
||||
);
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: this.model,
|
||||
messages: this.cleanMsgs(messages),
|
||||
});
|
||||
completion = response.choices[0].message;
|
||||
}
|
||||
|
||||
return { result: completion.content, cost: 0 };
|
||||
} catch (error) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the cost of the completion.
|
||||
*
|
||||
* @param _usage The completion to get the cost for.
|
||||
* @returns The cost of the completion.
|
||||
* Stubbed since LocalAI has no cost basis.
|
||||
*/
|
||||
getCost(_usage) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = LocalAiProvider;
|
113
server/utils/agents/aibitat/providers/togetherai.js
Normal file
113
server/utils/agents/aibitat/providers/togetherai.js
Normal file
@ -0,0 +1,113 @@
|
||||
const OpenAI = require("openai");
|
||||
const Provider = require("./ai-provider.js");
|
||||
const InheritMultiple = require("./helpers/classes.js");
|
||||
const UnTooled = require("./helpers/untooled.js");
|
||||
|
||||
/**
|
||||
* The provider for the TogetherAI provider.
|
||||
*/
|
||||
class TogetherAIProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
model;
|
||||
|
||||
constructor(config = {}) {
|
||||
const { model = "mistralai/Mistral-7B-Instruct-v0.1" } = config;
|
||||
super();
|
||||
const client = new OpenAI({
|
||||
baseURL: "https://api.together.xyz/v1",
|
||||
apiKey: process.env.TOGETHER_AI_API_KEY,
|
||||
maxRetries: 3,
|
||||
});
|
||||
|
||||
this._client = client;
|
||||
this.model = model;
|
||||
this.verbose = true;
|
||||
}
|
||||
|
||||
get client() {
|
||||
return this._client;
|
||||
}
|
||||
|
||||
async #handleFunctionCallChat({ messages = [] }) {
|
||||
return await this.client.chat.completions
|
||||
.create({
|
||||
model: this.model,
|
||||
temperature: 0,
|
||||
messages,
|
||||
})
|
||||
.then((result) => {
|
||||
if (!result.hasOwnProperty("choices"))
|
||||
throw new Error("LMStudio chat: No results!");
|
||||
if (result.choices.length === 0)
|
||||
throw new Error("LMStudio chat: No results length!");
|
||||
return result.choices[0].message.content;
|
||||
})
|
||||
.catch((_) => {
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a completion based on the received messages.
|
||||
*
|
||||
* @param messages A list of messages to send to the API.
|
||||
* @param functions
|
||||
* @returns The completion.
|
||||
*/
|
||||
async complete(messages, functions = null) {
|
||||
try {
|
||||
let completion;
|
||||
if (functions.length > 0) {
|
||||
const { toolCall, text } = await this.functionCall(
|
||||
messages,
|
||||
functions,
|
||||
this.#handleFunctionCallChat.bind(this)
|
||||
);
|
||||
|
||||
if (toolCall !== null) {
|
||||
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
|
||||
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
|
||||
return {
|
||||
result: null,
|
||||
functionCall: {
|
||||
name: toolCall.name,
|
||||
arguments: toolCall.arguments,
|
||||
},
|
||||
cost: 0,
|
||||
};
|
||||
}
|
||||
completion = { content: text };
|
||||
}
|
||||
|
||||
if (!completion?.content) {
|
||||
this.providerLog(
|
||||
"Will assume chat completion without tool call inputs."
|
||||
);
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: this.model,
|
||||
messages: this.cleanMsgs(messages),
|
||||
});
|
||||
completion = response.choices[0].message;
|
||||
}
|
||||
|
||||
return {
|
||||
result: completion.content,
|
||||
cost: 0,
|
||||
};
|
||||
} catch (error) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the cost of the completion.
|
||||
*
|
||||
* @param _usage The completion to get the cost for.
|
||||
* @returns The cost of the completion.
|
||||
* Stubbed since LMStudio has no cost basis.
|
||||
*/
|
||||
getCost(_usage) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = TogetherAIProvider;
|
@ -85,6 +85,36 @@ class AgentHandler {
|
||||
if (!process.env.OLLAMA_BASE_PATH)
|
||||
throw new Error("Ollama base path must be provided to use agents.");
|
||||
break;
|
||||
case "groq":
|
||||
if (!process.env.GROQ_API_KEY)
|
||||
throw new Error("Groq API key must be provided to use agents.");
|
||||
break;
|
||||
case "togetherai":
|
||||
if (!process.env.TOGETHER_AI_API_KEY)
|
||||
throw new Error("TogetherAI API key must be provided to use agents.");
|
||||
break;
|
||||
case "azure":
|
||||
if (!process.env.AZURE_OPENAI_ENDPOINT || !process.env.AZURE_OPENAI_KEY)
|
||||
throw new Error(
|
||||
"Azure OpenAI API endpoint and key must be provided to use agents."
|
||||
);
|
||||
break;
|
||||
case "koboldcpp":
|
||||
if (!process.env.KOBOLD_CPP_BASE_PATH)
|
||||
throw new Error(
|
||||
"KoboldCPP must have a valid base path to use for the api."
|
||||
);
|
||||
break;
|
||||
case "localai":
|
||||
if (!process.env.LOCAL_AI_BASE_PATH)
|
||||
throw new Error(
|
||||
"LocalAI must have a valid base path to use for the api."
|
||||
);
|
||||
break;
|
||||
case "gemini":
|
||||
if (!process.env.GEMINI_API_KEY)
|
||||
throw new Error("Gemini API key must be provided to use agents.");
|
||||
break;
|
||||
default:
|
||||
throw new Error("No provider found to power agent cluster.");
|
||||
}
|
||||
@ -100,6 +130,18 @@ class AgentHandler {
|
||||
return "server-default";
|
||||
case "ollama":
|
||||
return "llama3:latest";
|
||||
case "groq":
|
||||
return "llama3-70b-8192";
|
||||
case "togetherai":
|
||||
return "mistralai/Mixtral-8x7B-Instruct-v0.1";
|
||||
case "azure":
|
||||
return "gpt-3.5-turbo";
|
||||
case "koboldcpp":
|
||||
return null;
|
||||
case "gemini":
|
||||
return "gemini-pro";
|
||||
case "localai":
|
||||
return null;
|
||||
default:
|
||||
return "unknown";
|
||||
}
|
||||
|
@ -178,7 +178,7 @@ async function getKoboldCPPModels(basePath = null) {
|
||||
try {
|
||||
const { OpenAI: OpenAIApi } = require("openai");
|
||||
const openai = new OpenAIApi({
|
||||
baseURL: basePath || process.env.LMSTUDIO_BASE_PATH,
|
||||
baseURL: basePath || process.env.KOBOLD_CPP_BASE_PATH,
|
||||
apiKey: null,
|
||||
});
|
||||
const models = await openai.models
|
||||
|
Loading…
Reference in New Issue
Block a user