feat: expand summarization to OSS LLM (#1648)

Timothy Carambat 2024-06-10 14:31:39 -07:00 committed by GitHub
parent 0d84244ca1
commit d470845931
6 changed files with 213 additions and 88 deletions


@@ -41,6 +41,7 @@ class AIbitat {
...rest,
};
this.provider = this.defaultProvider.provider;
this.model = this.defaultProvider.model;
}
/**

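The hunk above records the agent's default model alongside its provider key, so downstream plugins can forward both to the summarizer. A minimal sketch of the assumed construction path (the provider key and model tag are illustrative, and the constructor shape is inferred from the `...rest` spread above):

const agent = new AIbitat({
  provider: "ollama", // illustrative OSS provider key
  model: "llama3:8b", // illustrative model tag
});
// Plugins then read this.super.provider and this.super.model, as the
// docSummarizer and webScraping hunks below demonstrate.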

@@ -154,11 +154,12 @@ const docSummarizer = {
this.controller.abort();
});
return await summarizeContent(
this.super.provider,
this.controller.signal,
document.content
);
return await summarizeContent({
provider: this.super.provider,
model: this.super.model,
controllerSignal: this.controller.signal,
content: document.content,
});
} catch (error) {
this.super.handlerProps.log(
`document-summarizer.summarizeDoc raised an error. ${error.message}`


@@ -90,11 +90,13 @@ const webScraping = {
);
this.controller.abort();
});
return summarizeContent(
this.super.provider,
this.controller.signal,
content
);
return summarizeContent({
provider: this.super.provider,
model: this.super.model,
controllerSignal: this.controller.signal,
content,
});
},
});
},
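Both plugin call sites now pass a single options object instead of positional arguments. With a fourth value (the model) joining the list, named fields prevent positional mix-ups and let callers omit values safely. A standalone usage sketch (provider, model, and content are illustrative):

const controller = new AbortController();
const summary = await summarizeContent({
  provider: "ollama", // any provider key supported by Provider.LangChainChatModel
  model: "llama3:8b", // illustrative tag
  controllerSignal: controller.signal, // lets the caller abort mid-summarization
  content: "…long scraped page text…",
});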


@@ -2,8 +2,19 @@
* A service that provides an AI client to create a completion.
*/
/**
* @typedef {Object} LangChainModelConfig
* @property {(string|null)} baseURL - Override this provider's default base URL (normally read from process.env)
* @property {(string|null)} apiKey - Override the API key normally read from process.env for this provider
* @property {(number|null)} temperature - Override the default temperature
* @property {(string|null)} model - Override the model used for the provider.
*/
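// Illustrative LangChainModelConfig override (all fields optional; values are examples only):
// { baseURL: "http://127.0.0.1:11434", apiKey: null, temperature: 0, model: "llama3:8b" }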
const { ChatOpenAI } = require("@langchain/openai");
const { ChatAnthropic } = require("@langchain/anthropic");
const { ChatOllama } = require("@langchain/community/chat_models/ollama");
const { toValidNumber } = require("../../../http");
const DEFAULT_WORKSPACE_PROMPT =
"You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
@@ -27,8 +38,15 @@ class Provider {
return this._client;
}
/**
* Build a LangChain chat model for the given provider.
* @param {string} provider - the string key of the provider LLM being loaded.
* @param {LangChainModelConfig} config - Config used to override the default connection object.
* @returns The LangChain chat model instance for the provider.
*/
static LangChainChatModel(provider = "openai", config = {}) {
switch (provider) {
// Cloud models
case "openai":
return new ChatOpenAI({
apiKey: process.env.OPEN_AI_KEY,
@@ -39,11 +57,108 @@
apiKey: process.env.ANTHROPIC_API_KEY,
...config,
});
default:
case "groq":
return new ChatOpenAI({
apiKey: process.env.OPEN_AI_KEY,
configuration: {
baseURL: "https://api.groq.com/openai/v1",
},
apiKey: process.env.GROQ_API_KEY,
...config,
});
case "mistral":
return new ChatOpenAI({
configuration: {
baseURL: "https://api.mistral.ai/v1",
},
apiKey: process.env.MISTRAL_API_KEY ?? null,
...config,
});
case "openrouter":
return new ChatOpenAI({
configuration: {
baseURL: "https://openrouter.ai/api/v1",
defaultHeaders: {
"HTTP-Referer": "https://useanything.com",
"X-Title": "AnythingLLM",
},
},
apiKey: process.env.OPENROUTER_API_KEY ?? null,
...config,
});
case "perplexity":
return new ChatOpenAI({
configuration: {
baseURL: "https://api.perplexity.ai",
},
apiKey: process.env.PERPLEXITY_API_KEY ?? null,
...config,
});
case "togetherai":
return new ChatOpenAI({
configuration: {
baseURL: "https://api.together.xyz/v1",
},
apiKey: process.env.TOGETHER_AI_API_KEY ?? null,
...config,
});
case "generic-openai":
return new ChatOpenAI({
configuration: {
baseURL: process.env.GENERIC_OPEN_AI_BASE_PATH,
},
apiKey: process.env.GENERIC_OPEN_AI_API_KEY,
maxTokens: toValidNumber(
process.env.GENERIC_OPEN_AI_MAX_TOKENS,
1024
),
...config,
});
// OSS Model Runners
// case "anythingllm_ollama":
// return new ChatOllama({
// baseUrl: process.env.PLACEHOLDER,
// ...config,
// });
case "ollama":
return new ChatOllama({
baseUrl: process.env.OLLAMA_BASE_PATH,
...config,
});
case "lmstudio":
return new ChatOpenAI({
configuration: {
baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""),
},
apiKey: "not-used", // Needs to be specified or else will assume OpenAI
...config,
});
case "koboldcpp":
return new ChatOpenAI({
configuration: {
baseURL: process.env.KOBOLD_CPP_BASE_PATH,
},
apiKey: "not-used",
...config,
});
case "localai":
return new ChatOpenAI({
configuration: {
baseURL: process.env.LOCAL_AI_BASE_PATH,
},
apiKey: process.env.LOCAL_AI_API_KEY ?? "not-used",
...config,
});
case "textgenwebui":
return new ChatOpenAI({
configuration: {
baseURL: process.env.TEXT_GEN_WEB_UI_BASE_PATH,
},
apiKey: process.env.TEXT_GEN_WEB_UI_API_KEY ?? "not-used",
...config,
});
default:
throw new Error(`Unsupported provider ${provider} for this task.`);
}
}
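A hedged usage sketch of the factory above (the provider key and model tag are illustrative; the `ollama` branch requires `OLLAMA_BASE_PATH` to be set):

const llm = Provider.LangChainChatModel("ollama", {
  temperature: 0,
  model: "llama3:8b", // illustrative tag
});
// Every branch returns a LangChain chat model, so they all share .invoke():
const reply = await llm.invoke("Summarize this paragraph: …");
console.log(reply.content);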


@@ -1,28 +1,52 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const { RetryError } = require("../error.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
/**
* The agent provider for the Groq provider.
* Using OpenAI tool calling with groq really sucks right now
* its just fast and bad. We should probably migrate this to Untooled to improve
* coherence.
* The agent provider for the GroqAI provider.
* We wrap Groq in UnTooled because its built-in tool calling is quite poor and wasteful.
*/
class GroqProvider extends Provider {
class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(config = {}) {
const { model = "llama3-8b-8192" } = config;
super();
const client = new OpenAI({
baseURL: "https://api.groq.com/openai/v1",
apiKey: process.env.GROQ_API_KEY,
maxRetries: 3,
});
super(client);
this._client = client;
this.model = model;
this.verbose = true;
}
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client.chat.completions
.create({
model: this.model,
temperature: 0,
messages,
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("GroqAI chat: No results!");
if (result.choices.length === 0)
throw new Error("GroqAI chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Create a completion based on the received messages.
*
@@ -32,68 +56,49 @@ class GroqProvider extends Provider {
*/
async complete(messages, functions = null) {
try {
const response = await this.client.chat.completions.create({
model: this.model,
// stream: true,
messages,
...(Array.isArray(functions) && functions?.length > 0
? { functions }
: {}),
});
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
messages,
functions,
this.#handleFunctionCallChat.bind(this)
);
// Right now, we only support one completion,
// so we just take the first one in the list
const completion = response.choices[0].message;
const cost = this.getCost(response.usage);
// treat function calls
if (completion.function_call) {
let functionArgs = {};
try {
functionArgs = JSON.parse(completion.function_call.arguments);
} catch (error) {
// call the complete function again in case it gets a json error
return this.complete(
[
...messages,
{
role: "function",
name: completion.function_call.name,
function_call: completion.function_call,
content: error?.message,
},
],
functions
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
// console.log(completion, { functionArgs })
return {
result: null,
functionCall: {
name: completion.function_call.name,
arguments: functionArgs,
},
cost,
};
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat.completions.create({
model: this.model,
messages: this.cleanMsgs(messages),
});
completion = response.choices[0].message;
}
// The Deduplicator inherited via the UnTooled class is mostly useful to prevent the agent
// from calling the exact same function over and over in a loop within a single chat exchange
// _but_ we should enable it to call previously used tools in a new chat interaction.
this.deduplicator.reset("runs");
return {
result: completion.content,
cost,
cost: 0,
};
} catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
// will make auth better.
if (error instanceof OpenAI.AuthenticationError) throw error;
if (
error instanceof OpenAI.RateLimitError ||
error instanceof OpenAI.InternalServerError ||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
) {
throw new RetryError(error.message);
}
throw error;
}
}
@@ -103,7 +108,7 @@ class GroqProvider extends Provider {
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* Stubbed since Groq has no cost basis.
* Stubbed since LMStudio has no cost basis.
*/
getCost(_usage) {
return 0;
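To close out the Groq rework, a minimal usage sketch (the message is illustrative; `GROQ_API_KEY` must be set; an empty tool list exercises the plain-chat path, a populated one the UnTooled function-call path):

const provider = new GroqProvider({ model: "llama3-8b-8192" });
const { result, functionCall, cost } = await provider.complete(
  [{ role: "user", content: "Summarize the README for me." }],
  [] // UnTooled-style tool definitions would go here
);
// cost is always 0 (see the stub above); functionCall is set only when a
// valid tool call was detected.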


@@ -3,26 +3,27 @@ const { PromptTemplate } = require("@langchain/core/prompts");
const { RecursiveCharacterTextSplitter } = require("@langchain/textsplitters");
const Provider = require("../providers/ai-provider");
/**
* Summarize content using OpenAI's GPT-3.5 model.
*
* @param self The context of the caller function
* @param content The content to summarize.
* @returns The summarized content.
* @typedef {Object} LCSummarizationConfig
* @property {string} provider The LLM to use for summarization (inherited)
* @property {string} model The LLM Model to use for summarization (inherited)
* @property {AbortController['signal']} controllerSignal Abort controller to stop recursive summarization
* @property {string} content The text content of the text to summarize
*/
const SUMMARY_MODEL = {
anthropic: "claude-3-opus-20240229", // 200,000 tokens
openai: "gpt-4o", // 128,000 tokens
};
async function summarizeContent(
/**
* Summarize content via a LangChain chain call.
* @param {LCSummarizationConfig} config - The summarization configuration.
* @returns {Promise<string>} The summarized content.
*/
async function summarizeContent({
provider = "openai",
model = null,
controllerSignal,
content
) {
content,
}) {
const llm = Provider.LangChainChatModel(provider, {
temperature: 0,
modelName: SUMMARY_MODEL[provider],
model: model,
});
const textSplitter = new RecursiveCharacterTextSplitter({