Mirror of https://github.com/Mintplex-Labs/anything-llm.git
Synced 2024-11-14 02:20:12 +01:00
feat: expand summarization to OSS LLM (#1648)

This commit is contained in:
parent 0d84244ca1
commit d470845931
@@ -41,6 +41,7 @@ class AIbitat {
       ...rest,
     };
     this.provider = this.defaultProvider.provider;
+    this.model = this.defaultProvider.model;
   }

   /**
@@ -154,11 +154,12 @@ const docSummarizer = {
             this.controller.abort();
           });

-          return await summarizeContent(
-            this.super.provider,
-            this.controller.signal,
-            document.content
-          );
+          return await summarizeContent({
+            provider: this.super.provider,
+            model: this.super.model,
+            controllerSignal: this.controller.signal,
+            content: document.content,
+          });
         } catch (error) {
           this.super.handlerProps.log(
             `document-summarizer.summarizeDoc raised an error. ${error.message}`
@@ -90,11 +90,13 @@ const webScraping = {
           );
           this.controller.abort();
         });
-        return summarizeContent(
-          this.super.provider,
-          this.controller.signal,
-          content
-        );
+
+        return summarizeContent({
+          provider: this.super.provider,
+          model: this.super.model,
+          controllerSignal: this.controller.signal,
+          content,
+        });
       },
     });
   },
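Both call sites above switch summarizeContent from positional arguments to a single destructured options object, so the new optional model field can ride along without disturbing argument order. A minimal sketch of the new call shape — the values are illustrative placeholders, not from the commit:

// Hypothetical caller, mirroring the plugin call sites above (inside an async context).
const controller = new AbortController();
const summary = await summarizeContent({
  provider: "ollama",                  // any key Provider.LangChainChatModel accepts
  model: "llama3",                     // optional; defaults to null
  controllerSignal: controller.signal, // lets the agent abort mid-summarization
  content: "Very long scraped page text...",
});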
@@ -2,8 +2,19 @@
  * A service that provides an AI client to create a completion.
  */

+/**
+ * @typedef {Object} LangChainModelConfig
+ * @property {(string|null)} baseURL - Override the default base URL process.env for this provider
+ * @property {(string|null)} apiKey - Override the default process.env for this provider
+ * @property {(number|null)} temperature - Override the default temperature
+ * @property {(string|null)} model - Overrides model used for provider.
+ */
+
 const { ChatOpenAI } = require("@langchain/openai");
 const { ChatAnthropic } = require("@langchain/anthropic");
+const { ChatOllama } = require("@langchain/community/chat_models/ollama");
+const { toValidNumber } = require("../../../http");

 const DEFAULT_WORKSPACE_PROMPT =
   "You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
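Every case in the factory below spreads ...config last, so any field of this typedef can override the per-provider defaults. A sketch of a config object matching the typedef — the values are illustrative, not from the commit:

/** @type {LangChainModelConfig} */
const overrides = {
  baseURL: null,    // keep the provider's default endpoint
  apiKey: null,     // fall back to the provider's process.env key
  temperature: 0,   // deterministic output for summarization
  model: "llama3",  // replace the provider's default model
};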
@@ -27,8 +38,15 @@ class Provider {
     return this._client;
   }

+  /**
+   *
+   * @param {string} provider - the string key of the provider LLM being loaded.
+   * @param {LangChainModelConfig} config - Config to be used to override default connection object.
+   * @returns
+   */
   static LangChainChatModel(provider = "openai", config = {}) {
     switch (provider) {
+      // Cloud models
       case "openai":
         return new ChatOpenAI({
           apiKey: process.env.OPEN_AI_KEY,
@@ -39,11 +57,108 @@ class Provider {
           apiKey: process.env.ANTHROPIC_API_KEY,
           ...config,
         });
-      default:
+      case "groq":
         return new ChatOpenAI({
-          apiKey: process.env.OPEN_AI_KEY,
+          configuration: {
+            baseURL: "https://api.groq.com/openai/v1",
+          },
+          apiKey: process.env.GROQ_API_KEY,
           ...config,
         });
+      case "mistral":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: "https://api.mistral.ai/v1",
+          },
+          apiKey: process.env.MISTRAL_API_KEY ?? null,
+          ...config,
+        });
+      case "openrouter":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: "https://openrouter.ai/api/v1",
+            defaultHeaders: {
+              "HTTP-Referer": "https://useanything.com",
+              "X-Title": "AnythingLLM",
+            },
+          },
+          apiKey: process.env.OPENROUTER_API_KEY ?? null,
+          ...config,
+        });
+      case "perplexity":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: "https://api.perplexity.ai",
+          },
+          apiKey: process.env.PERPLEXITY_API_KEY ?? null,
+          ...config,
+        });
+      case "togetherai":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: "https://api.together.xyz/v1",
+          },
+          apiKey: process.env.TOGETHER_AI_API_KEY ?? null,
+          ...config,
+        });
+      case "generic-openai":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: process.env.GENERIC_OPEN_AI_BASE_PATH,
+          },
+          apiKey: process.env.GENERIC_OPEN_AI_API_KEY,
+          maxTokens: toValidNumber(
+            process.env.GENERIC_OPEN_AI_MAX_TOKENS,
+            1024
+          ),
+          ...config,
+        });
+
+      // OSS Model Runners
+      // case "anythingllm_ollama":
+      //   return new ChatOllama({
+      //     baseUrl: process.env.PLACEHOLDER,
+      //     ...config,
+      //   });
+      case "ollama":
+        return new ChatOllama({
+          baseUrl: process.env.OLLAMA_BASE_PATH,
+          ...config,
+        });
+      case "lmstudio":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""),
+          },
+          apiKey: "not-used", // Needs to be specified or else will assume OpenAI
+          ...config,
+        });
+      case "koboldcpp":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: process.env.KOBOLD_CPP_BASE_PATH,
+          },
+          apiKey: "not-used",
+          ...config,
+        });
+      case "localai":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: process.env.LOCAL_AI_BASE_PATH,
+          },
+          apiKey: process.env.LOCAL_AI_API_KEY ?? "not-used",
+          ...config,
+        });
+      case "textgenwebui":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: process.env.TEXT_GEN_WEB_UI_BASE_PATH,
+          },
+          apiKey: process.env.TEXT_GEN_WEB_UI_API_KEY ?? "not-used",
+          ...config,
+        });
+      default:
+        throw new Error(`Unsupported provider ${provider} for this task.`);
     }
   }

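With the expanded switch, the same factory call now yields a chat model for cloud providers and OSS runners alike, and unknown keys fail fast instead of silently falling back to OpenAI. A minimal sketch, assuming OLLAMA_BASE_PATH is set in the environment and the require path is adjusted to the caller's location:

const Provider = require("./ai-provider.js"); // path is illustrative

// Returns a ChatOllama instance pointed at the local Ollama server;
// the config object is spread last, so it can override any default.
const llm = Provider.LangChainChatModel("ollama", {
  temperature: 0,
  model: "llama3",
});

// An unsupported key now throws rather than assuming OpenAI:
// Provider.LangChainChatModel("no-such-provider"); // Error: Unsupported provider...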
@@ -1,28 +1,52 @@
 const OpenAI = require("openai");
 const Provider = require("./ai-provider.js");
-const { RetryError } = require("../error.js");
+const InheritMultiple = require("./helpers/classes.js");
+const UnTooled = require("./helpers/untooled.js");

 /**
- * The agent provider for the Groq provider.
- * Using OpenAI tool calling with groq really sucks right now
- * its just fast and bad. We should probably migrate this to Untooled to improve
- * coherence.
+ * The agent provider for the GroqAI provider.
+ * We wrap Groq in UnTooled because its tool-calling built in is quite bad and wasteful.
  */
-class GroqProvider extends Provider {
+class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
   model;

   constructor(config = {}) {
     const { model = "llama3-8b-8192" } = config;
+    super();
     const client = new OpenAI({
       baseURL: "https://api.groq.com/openai/v1",
       apiKey: process.env.GROQ_API_KEY,
       maxRetries: 3,
     });
-    super(client);
+
+    this._client = client;
     this.model = model;
     this.verbose = true;
   }

+  get client() {
+    return this._client;
+  }
+
+  async #handleFunctionCallChat({ messages = [] }) {
+    return await this.client.chat.completions
+      .create({
+        model: this.model,
+        temperature: 0,
+        messages,
+      })
+      .then((result) => {
+        if (!result.hasOwnProperty("choices"))
+          throw new Error("GroqAI chat: No results!");
+        if (result.choices.length === 0)
+          throw new Error("GroqAI chat: No results length!");
+        return result.choices[0].message.content;
+      })
+      .catch((_) => {
+        return null;
+      });
+  }
+
   /**
    * Create a completion based on the received messages.
    *
@@ -32,68 +56,49 @@ class GroqProvider extends Provider {
    */
   async complete(messages, functions = null) {
     try {
-      const response = await this.client.chat.completions.create({
-        model: this.model,
-        // stream: true,
-        messages,
-        ...(Array.isArray(functions) && functions?.length > 0
-          ? { functions }
-          : {}),
-      });
-
-      // Right now, we only support one completion,
-      // so we just take the first one in the list
-      const completion = response.choices[0].message;
-      const cost = this.getCost(response.usage);
-      // treat function calls
-      if (completion.function_call) {
-        let functionArgs = {};
-        try {
-          functionArgs = JSON.parse(completion.function_call.arguments);
-        } catch (error) {
-          // call the complete function again in case it gets a json error
-          return this.complete(
-            [
-              ...messages,
-              {
-                role: "function",
-                name: completion.function_call.name,
-                function_call: completion.function_call,
-                content: error?.message,
-              },
-            ],
-            functions
-          );
-        }
-
-        // console.log(completion, { functionArgs })
-        return {
-          result: null,
-          functionCall: {
-            name: completion.function_call.name,
-            arguments: functionArgs,
-          },
-          cost,
-        };
-      }
-
-      return {
-        result: completion.content,
-        cost,
-      };
-    } catch (error) {
-      // If invalid Auth error we need to abort because no amount of waiting
-      // will make auth better.
-      if (error instanceof OpenAI.AuthenticationError) throw error;
-
-      if (
-        error instanceof OpenAI.RateLimitError ||
-        error instanceof OpenAI.InternalServerError ||
-        error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
-      ) {
-        throw new RetryError(error.message);
-      }
-
+      let completion;
+      if (functions.length > 0) {
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );
+
+        if (toolCall !== null) {
+          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
+          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
+          return {
+            result: null,
+            functionCall: {
+              name: toolCall.name,
+              arguments: toolCall.arguments,
+            },
+            cost: 0,
+          };
+        }
+        completion = { content: text };
+      }
+
+      if (!completion?.content) {
+        this.providerLog(
+          "Will assume chat completion without tool call inputs."
+        );
+        const response = await this.client.chat.completions.create({
+          model: this.model,
+          messages: this.cleanMsgs(messages),
+        });
+        completion = response.choices[0].message;
+      }
+
+      // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
+      // from calling the exact same function over and over in a loop within a single chat exchange
+      // _but_ we should enable it to call previously used tools in a new chat interaction.
+      this.deduplicator.reset("runs");
+      return {
+        result: completion.content,
+        cost: 0,
+      };
+    } catch (error) {
       throw error;
     }
   }
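The reworked complete() keeps the { result, functionCall, cost } return shape, so agent callers are untouched; only the tool-selection path moved into UnTooled's functionCall(). A sketch of how the two branches surface to a caller — the tool definition and messages are hypothetical placeholders:

const provider = new GroqProvider({ model: "llama3-8b-8192" }); // reads GROQ_API_KEY from env

const response = await provider.complete(
  [{ role: "user", content: "What's the weather in Paris?" }],
  [{ name: "get-weather", description: "Look up current weather.", parameters: {} }]
);

if (response.functionCall) {
  // Tool branch: UnTooled validated a tool call; result is null, cost stubbed to 0.
  console.log(response.functionCall.name, response.functionCall.arguments);
} else {
  // Chat branch: plain completion text from #handleFunctionCallChat's fallback path.
  console.log(response.result);
}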
@@ -103,7 +108,7 @@ class GroqProvider extends Provider {
    *
    * @param _usage The completion to get the cost for.
    * @returns The cost of the completion.
-   * Stubbed since Groq has no cost basis.
+   * Stubbed since LMStudio has no cost basis.
    */
   getCost(_usage) {
     return 0;
@@ -3,26 +3,27 @@ const { PromptTemplate } = require("@langchain/core/prompts");
 const { RecursiveCharacterTextSplitter } = require("@langchain/textsplitters");
 const Provider = require("../providers/ai-provider");
 /**
- * Summarize content using OpenAI's GPT-3.5 model.
- *
- * @param self The context of the caller function
- * @param content The content to summarize.
- * @returns The summarized content.
+ * @typedef {Object} LCSummarizationConfig
+ * @property {string} provider The LLM to use for summarization (inherited)
+ * @property {string} model The LLM Model to use for summarization (inherited)
+ * @property {AbortController['signal']} controllerSignal Abort controller to stop recursive summarization
+ * @property {string} content The text content of the text to summarize
  */

-const SUMMARY_MODEL = {
-  anthropic: "claude-3-opus-20240229", // 200,000 tokens
-  openai: "gpt-4o", // 128,000 tokens
-};
-
-async function summarizeContent(
+/**
+ * Summarize content using LLM LC-Chain call
+ * @param {LCSummarizationConfig} The LLM to use for summarization (inherited)
+ * @returns {Promise<string>} The summarized content.
+ */
+async function summarizeContent({
   provider = "openai",
+  model = null,
   controllerSignal,
-  content
-) {
+  content,
+}) {
   const llm = Provider.LangChainChatModel(provider, {
     temperature: 0,
-    modelName: SUMMARY_MODEL[provider],
+    model: model,
   });

   const textSplitter = new RecursiveCharacterTextSplitter({
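End to end, summarization no longer routes through a hardcoded SUMMARY_MODEL table and can run entirely against a local model. A hedged sketch of a call through the new signature — the require path, model name, and documentText are placeholders:

const { summarizeContent } = require("./summarize"); // path is illustrative

// Inside an async context: abort runaway recursive summarization after 5 minutes.
const controller = new AbortController();
const timeout = setTimeout(() => controller.abort(), 5 * 60 * 1000);

const summary = await summarizeContent({
  provider: "ollama",              // resolved by Provider.LangChainChatModel
  model: "llama3",                 // passed through to the chat model
  controllerSignal: controller.signal,
  content: documentText,           // placeholder for the loaded document body
});
clearTimeout(timeout);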