mirror of
https://github.com/Mintplex-Labs/anything-llm.git
synced 2024-11-10 17:00:11 +01:00
feat: expand summarization to OSS LLM (#1648)
This commit is contained in:
parent
0d84244ca1
commit
d470845931
@ -41,6 +41,7 @@ class AIbitat {
|
||||
...rest,
|
||||
};
|
||||
this.provider = this.defaultProvider.provider;
|
||||
this.model = this.defaultProvider.model;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -154,11 +154,12 @@ const docSummarizer = {
|
||||
this.controller.abort();
|
||||
});
|
||||
|
||||
return await summarizeContent(
|
||||
this.super.provider,
|
||||
this.controller.signal,
|
||||
document.content
|
||||
);
|
||||
return await summarizeContent({
|
||||
provider: this.super.provider,
|
||||
model: this.super.model,
|
||||
controllerSignal: this.controller.signal,
|
||||
content: document.content,
|
||||
});
|
||||
} catch (error) {
|
||||
this.super.handlerProps.log(
|
||||
`document-summarizer.summarizeDoc raised an error. ${error.message}`
|
||||
|
@ -90,11 +90,13 @@ const webScraping = {
|
||||
);
|
||||
this.controller.abort();
|
||||
});
|
||||
return summarizeContent(
|
||||
this.super.provider,
|
||||
this.controller.signal,
|
||||
content
|
||||
);
|
||||
|
||||
return summarizeContent({
|
||||
provider: this.super.provider,
|
||||
model: this.super.model,
|
||||
controllerSignal: this.controller.signal,
|
||||
content,
|
||||
});
|
||||
},
|
||||
});
|
||||
},
|
||||
|
@ -2,8 +2,19 @@
|
||||
* A service that provides an AI client to create a completion.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} LangChainModelConfig
|
||||
* @property {(string|null)} baseURL - Override the default base URL process.env for this provider
|
||||
* @property {(string|null)} apiKey - Override the default process.env for this provider
|
||||
* @property {(number|null)} temperature - Override the default temperature
|
||||
* @property {(string|null)} model - Overrides model used for provider.
|
||||
*/
|
||||
|
||||
const { ChatOpenAI } = require("@langchain/openai");
|
||||
const { ChatAnthropic } = require("@langchain/anthropic");
|
||||
const { ChatOllama } = require("@langchain/community/chat_models/ollama");
|
||||
const { toValidNumber } = require("../../../http");
|
||||
|
||||
const DEFAULT_WORKSPACE_PROMPT =
|
||||
"You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
|
||||
|
||||
@ -27,8 +38,15 @@ class Provider {
|
||||
return this._client;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {string} provider - the string key of the provider LLM being loaded.
|
||||
* @param {LangChainModelConfig} config - Config to be used to override default connection object.
|
||||
* @returns
|
||||
*/
|
||||
static LangChainChatModel(provider = "openai", config = {}) {
|
||||
switch (provider) {
|
||||
// Cloud models
|
||||
case "openai":
|
||||
return new ChatOpenAI({
|
||||
apiKey: process.env.OPEN_AI_KEY,
|
||||
@ -39,11 +57,108 @@ class Provider {
|
||||
apiKey: process.env.ANTHROPIC_API_KEY,
|
||||
...config,
|
||||
});
|
||||
default:
|
||||
case "groq":
|
||||
return new ChatOpenAI({
|
||||
apiKey: process.env.OPEN_AI_KEY,
|
||||
configuration: {
|
||||
baseURL: "https://api.groq.com/openai/v1",
|
||||
},
|
||||
apiKey: process.env.GROQ_API_KEY,
|
||||
...config,
|
||||
});
|
||||
case "mistral":
|
||||
return new ChatOpenAI({
|
||||
configuration: {
|
||||
baseURL: "https://api.mistral.ai/v1",
|
||||
},
|
||||
apiKey: process.env.MISTRAL_API_KEY ?? null,
|
||||
...config,
|
||||
});
|
||||
case "openrouter":
|
||||
return new ChatOpenAI({
|
||||
configuration: {
|
||||
baseURL: "https://openrouter.ai/api/v1",
|
||||
defaultHeaders: {
|
||||
"HTTP-Referer": "https://useanything.com",
|
||||
"X-Title": "AnythingLLM",
|
||||
},
|
||||
},
|
||||
apiKey: process.env.OPENROUTER_API_KEY ?? null,
|
||||
...config,
|
||||
});
|
||||
case "perplexity":
|
||||
return new ChatOpenAI({
|
||||
configuration: {
|
||||
baseURL: "https://api.perplexity.ai",
|
||||
},
|
||||
apiKey: process.env.PERPLEXITY_API_KEY ?? null,
|
||||
...config,
|
||||
});
|
||||
case "togetherai":
|
||||
return new ChatOpenAI({
|
||||
configuration: {
|
||||
baseURL: "https://api.together.xyz/v1",
|
||||
},
|
||||
apiKey: process.env.TOGETHER_AI_API_KEY ?? null,
|
||||
...config,
|
||||
});
|
||||
case "generic-openai":
|
||||
return new ChatOpenAI({
|
||||
configuration: {
|
||||
baseURL: process.env.GENERIC_OPEN_AI_BASE_PATH,
|
||||
},
|
||||
apiKey: process.env.GENERIC_OPEN_AI_API_KEY,
|
||||
maxTokens: toValidNumber(
|
||||
process.env.GENERIC_OPEN_AI_MAX_TOKENS,
|
||||
1024
|
||||
),
|
||||
...config,
|
||||
});
|
||||
|
||||
// OSS Model Runners
|
||||
// case "anythingllm_ollama":
|
||||
// return new ChatOllama({
|
||||
// baseUrl: process.env.PLACEHOLDER,
|
||||
// ...config,
|
||||
// });
|
||||
case "ollama":
|
||||
return new ChatOllama({
|
||||
baseUrl: process.env.OLLAMA_BASE_PATH,
|
||||
...config,
|
||||
});
|
||||
case "lmstudio":
|
||||
return new ChatOpenAI({
|
||||
configuration: {
|
||||
baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""),
|
||||
},
|
||||
apiKey: "not-used", // Needs to be specified or else will assume OpenAI
|
||||
...config,
|
||||
});
|
||||
case "koboldcpp":
|
||||
return new ChatOpenAI({
|
||||
configuration: {
|
||||
baseURL: process.env.KOBOLD_CPP_BASE_PATH,
|
||||
},
|
||||
apiKey: "not-used",
|
||||
...config,
|
||||
});
|
||||
case "localai":
|
||||
return new ChatOpenAI({
|
||||
configuration: {
|
||||
baseURL: process.env.LOCAL_AI_BASE_PATH,
|
||||
},
|
||||
apiKey: process.env.LOCAL_AI_API_KEY ?? "not-used",
|
||||
...config,
|
||||
});
|
||||
case "textgenwebui":
|
||||
return new ChatOpenAI({
|
||||
configuration: {
|
||||
baseURL: process.env.TEXT_GEN_WEB_UI_BASE_PATH,
|
||||
},
|
||||
apiKey: process.env.TEXT_GEN_WEB_UI_API_KEY ?? "not-used",
|
||||
...config,
|
||||
});
|
||||
default:
|
||||
throw new Error(`Unsupported provider ${provider} for this task.`);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,28 +1,52 @@
|
||||
const OpenAI = require("openai");
|
||||
const Provider = require("./ai-provider.js");
|
||||
const { RetryError } = require("../error.js");
|
||||
const InheritMultiple = require("./helpers/classes.js");
|
||||
const UnTooled = require("./helpers/untooled.js");
|
||||
|
||||
/**
|
||||
* The agent provider for the Groq provider.
|
||||
* Using OpenAI tool calling with groq really sucks right now
|
||||
* its just fast and bad. We should probably migrate this to Untooled to improve
|
||||
* coherence.
|
||||
* The agent provider for the GroqAI provider.
|
||||
* We wrap Groq in UnTooled because its tool-calling built in is quite bad and wasteful.
|
||||
*/
|
||||
class GroqProvider extends Provider {
|
||||
class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
model;
|
||||
|
||||
constructor(config = {}) {
|
||||
const { model = "llama3-8b-8192" } = config;
|
||||
super();
|
||||
const client = new OpenAI({
|
||||
baseURL: "https://api.groq.com/openai/v1",
|
||||
apiKey: process.env.GROQ_API_KEY,
|
||||
maxRetries: 3,
|
||||
});
|
||||
super(client);
|
||||
|
||||
this._client = client;
|
||||
this.model = model;
|
||||
this.verbose = true;
|
||||
}
|
||||
|
||||
get client() {
|
||||
return this._client;
|
||||
}
|
||||
|
||||
async #handleFunctionCallChat({ messages = [] }) {
|
||||
return await this.client.chat.completions
|
||||
.create({
|
||||
model: this.model,
|
||||
temperature: 0,
|
||||
messages,
|
||||
})
|
||||
.then((result) => {
|
||||
if (!result.hasOwnProperty("choices"))
|
||||
throw new Error("GroqAI chat: No results!");
|
||||
if (result.choices.length === 0)
|
||||
throw new Error("GroqAI chat: No results length!");
|
||||
return result.choices[0].message.content;
|
||||
})
|
||||
.catch((_) => {
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a completion based on the received messages.
|
||||
*
|
||||
@ -32,68 +56,49 @@ class GroqProvider extends Provider {
|
||||
*/
|
||||
async complete(messages, functions = null) {
|
||||
try {
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: this.model,
|
||||
// stream: true,
|
||||
messages,
|
||||
...(Array.isArray(functions) && functions?.length > 0
|
||||
? { functions }
|
||||
: {}),
|
||||
});
|
||||
let completion;
|
||||
if (functions.length > 0) {
|
||||
const { toolCall, text } = await this.functionCall(
|
||||
messages,
|
||||
functions,
|
||||
this.#handleFunctionCallChat.bind(this)
|
||||
);
|
||||
|
||||
// Right now, we only support one completion,
|
||||
// so we just take the first one in the list
|
||||
const completion = response.choices[0].message;
|
||||
const cost = this.getCost(response.usage);
|
||||
// treat function calls
|
||||
if (completion.function_call) {
|
||||
let functionArgs = {};
|
||||
try {
|
||||
functionArgs = JSON.parse(completion.function_call.arguments);
|
||||
} catch (error) {
|
||||
// call the complete function again in case it gets a json error
|
||||
return this.complete(
|
||||
[
|
||||
...messages,
|
||||
{
|
||||
role: "function",
|
||||
name: completion.function_call.name,
|
||||
function_call: completion.function_call,
|
||||
content: error?.message,
|
||||
},
|
||||
],
|
||||
functions
|
||||
);
|
||||
if (toolCall !== null) {
|
||||
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
|
||||
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
|
||||
return {
|
||||
result: null,
|
||||
functionCall: {
|
||||
name: toolCall.name,
|
||||
arguments: toolCall.arguments,
|
||||
},
|
||||
cost: 0,
|
||||
};
|
||||
}
|
||||
|
||||
// console.log(completion, { functionArgs })
|
||||
return {
|
||||
result: null,
|
||||
functionCall: {
|
||||
name: completion.function_call.name,
|
||||
arguments: functionArgs,
|
||||
},
|
||||
cost,
|
||||
};
|
||||
completion = { content: text };
|
||||
}
|
||||
|
||||
if (!completion?.content) {
|
||||
this.providerLog(
|
||||
"Will assume chat completion without tool call inputs."
|
||||
);
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: this.model,
|
||||
messages: this.cleanMsgs(messages),
|
||||
});
|
||||
completion = response.choices[0].message;
|
||||
}
|
||||
|
||||
// The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
|
||||
// from calling the exact same function over and over in a loop within a single chat exchange
|
||||
// _but_ we should enable it to call previously used tools in a new chat interaction.
|
||||
this.deduplicator.reset("runs");
|
||||
return {
|
||||
result: completion.content,
|
||||
cost,
|
||||
cost: 0,
|
||||
};
|
||||
} catch (error) {
|
||||
// If invalid Auth error we need to abort because no amount of waiting
|
||||
// will make auth better.
|
||||
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||
|
||||
if (
|
||||
error instanceof OpenAI.RateLimitError ||
|
||||
error instanceof OpenAI.InternalServerError ||
|
||||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
|
||||
) {
|
||||
throw new RetryError(error.message);
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
@ -103,7 +108,7 @@ class GroqProvider extends Provider {
|
||||
*
|
||||
* @param _usage The completion to get the cost for.
|
||||
* @returns The cost of the completion.
|
||||
* Stubbed since Groq has no cost basis.
|
||||
* Stubbed since LMStudio has no cost basis.
|
||||
*/
|
||||
getCost(_usage) {
|
||||
return 0;
|
||||
|
@ -3,26 +3,27 @@ const { PromptTemplate } = require("@langchain/core/prompts");
|
||||
const { RecursiveCharacterTextSplitter } = require("@langchain/textsplitters");
|
||||
const Provider = require("../providers/ai-provider");
|
||||
/**
|
||||
* Summarize content using OpenAI's GPT-3.5 model.
|
||||
*
|
||||
* @param self The context of the caller function
|
||||
* @param content The content to summarize.
|
||||
* @returns The summarized content.
|
||||
* @typedef {Object} LCSummarizationConfig
|
||||
* @property {string} provider The LLM to use for summarization (inherited)
|
||||
* @property {string} model The LLM Model to use for summarization (inherited)
|
||||
* @property {AbortController['signal']} controllerSignal Abort controller to stop recursive summarization
|
||||
* @property {string} content The text content of the text to summarize
|
||||
*/
|
||||
|
||||
const SUMMARY_MODEL = {
|
||||
anthropic: "claude-3-opus-20240229", // 200,000 tokens
|
||||
openai: "gpt-4o", // 128,000 tokens
|
||||
};
|
||||
|
||||
async function summarizeContent(
|
||||
/**
|
||||
* Summarize content using LLM LC-Chain call
|
||||
* @param {LCSummarizationConfig} The LLM to use for summarization (inherited)
|
||||
* @returns {Promise<string>} The summarized content.
|
||||
*/
|
||||
async function summarizeContent({
|
||||
provider = "openai",
|
||||
model = null,
|
||||
controllerSignal,
|
||||
content
|
||||
) {
|
||||
content,
|
||||
}) {
|
||||
const llm = Provider.LangChainChatModel(provider, {
|
||||
temperature: 0,
|
||||
modelName: SUMMARY_MODEL[provider],
|
||||
model: model,
|
||||
});
|
||||
|
||||
const textSplitter = new RecursiveCharacterTextSplitter({
|
||||
|
Loading…
Reference in New Issue
Block a user