feat: expand summarization to OSS LLM (#1648)

This commit is contained in:
Timothy Carambat 2024-06-10 14:31:39 -07:00 committed by GitHub
parent 0d84244ca1
commit d470845931
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 213 additions and 88 deletions

View File

@ -41,6 +41,7 @@ class AIbitat {
...rest, ...rest,
}; };
this.provider = this.defaultProvider.provider; this.provider = this.defaultProvider.provider;
this.model = this.defaultProvider.model;
} }
/** /**

View File

@ -154,11 +154,12 @@ const docSummarizer = {
this.controller.abort(); this.controller.abort();
}); });
return await summarizeContent( return await summarizeContent({
this.super.provider, provider: this.super.provider,
this.controller.signal, model: this.super.model,
document.content controllerSignal: this.controller.signal,
); content: document.content,
});
} catch (error) { } catch (error) {
this.super.handlerProps.log( this.super.handlerProps.log(
`document-summarizer.summarizeDoc raised an error. ${error.message}` `document-summarizer.summarizeDoc raised an error. ${error.message}`

View File

@ -90,11 +90,13 @@ const webScraping = {
); );
this.controller.abort(); this.controller.abort();
}); });
return summarizeContent(
this.super.provider, return summarizeContent({
this.controller.signal, provider: this.super.provider,
content model: this.super.model,
); controllerSignal: this.controller.signal,
content,
});
}, },
}); });
}, },

View File

@ -2,8 +2,19 @@
* A service that provides an AI client to create a completion. * A service that provides an AI client to create a completion.
*/ */
/**
* @typedef {Object} LangChainModelConfig
* @property {(string|null)} baseURL - Override the default base URL process.env for this provider
* @property {(string|null)} apiKey - Override the default process.env for this provider
* @property {(number|null)} temperature - Override the default temperature
* @property {(string|null)} model - Overrides model used for provider.
*/
const { ChatOpenAI } = require("@langchain/openai"); const { ChatOpenAI } = require("@langchain/openai");
const { ChatAnthropic } = require("@langchain/anthropic"); const { ChatAnthropic } = require("@langchain/anthropic");
const { ChatOllama } = require("@langchain/community/chat_models/ollama");
const { toValidNumber } = require("../../../http");
const DEFAULT_WORKSPACE_PROMPT = const DEFAULT_WORKSPACE_PROMPT =
"You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions."; "You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
@ -27,8 +38,15 @@ class Provider {
return this._client; return this._client;
} }
/**
*
* @param {string} provider - the string key of the provider LLM being loaded.
* @param {LangChainModelConfig} config - Config to be used to override default connection object.
* @returns
*/
static LangChainChatModel(provider = "openai", config = {}) { static LangChainChatModel(provider = "openai", config = {}) {
switch (provider) { switch (provider) {
// Cloud models
case "openai": case "openai":
return new ChatOpenAI({ return new ChatOpenAI({
apiKey: process.env.OPEN_AI_KEY, apiKey: process.env.OPEN_AI_KEY,
@ -39,11 +57,108 @@ class Provider {
apiKey: process.env.ANTHROPIC_API_KEY, apiKey: process.env.ANTHROPIC_API_KEY,
...config, ...config,
}); });
default: case "groq":
return new ChatOpenAI({ return new ChatOpenAI({
apiKey: process.env.OPEN_AI_KEY, configuration: {
baseURL: "https://api.groq.com/openai/v1",
},
apiKey: process.env.GROQ_API_KEY,
...config, ...config,
}); });
case "mistral":
return new ChatOpenAI({
configuration: {
baseURL: "https://api.mistral.ai/v1",
},
apiKey: process.env.MISTRAL_API_KEY ?? null,
...config,
});
case "openrouter":
return new ChatOpenAI({
configuration: {
baseURL: "https://openrouter.ai/api/v1",
defaultHeaders: {
"HTTP-Referer": "https://useanything.com",
"X-Title": "AnythingLLM",
},
},
apiKey: process.env.OPENROUTER_API_KEY ?? null,
...config,
});
case "perplexity":
return new ChatOpenAI({
configuration: {
baseURL: "https://api.perplexity.ai",
},
apiKey: process.env.PERPLEXITY_API_KEY ?? null,
...config,
});
case "togetherai":
return new ChatOpenAI({
configuration: {
baseURL: "https://api.together.xyz/v1",
},
apiKey: process.env.TOGETHER_AI_API_KEY ?? null,
...config,
});
case "generic-openai":
return new ChatOpenAI({
configuration: {
baseURL: process.env.GENERIC_OPEN_AI_BASE_PATH,
},
apiKey: process.env.GENERIC_OPEN_AI_API_KEY,
maxTokens: toValidNumber(
process.env.GENERIC_OPEN_AI_MAX_TOKENS,
1024
),
...config,
});
// OSS Model Runners
// case "anythingllm_ollama":
// return new ChatOllama({
// baseUrl: process.env.PLACEHOLDER,
// ...config,
// });
case "ollama":
return new ChatOllama({
baseUrl: process.env.OLLAMA_BASE_PATH,
...config,
});
case "lmstudio":
return new ChatOpenAI({
configuration: {
baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""),
},
apiKey: "not-used", // Needs to be specified or else will assume OpenAI
...config,
});
case "koboldcpp":
return new ChatOpenAI({
configuration: {
baseURL: process.env.KOBOLD_CPP_BASE_PATH,
},
apiKey: "not-used",
...config,
});
case "localai":
return new ChatOpenAI({
configuration: {
baseURL: process.env.LOCAL_AI_BASE_PATH,
},
apiKey: process.env.LOCAL_AI_API_KEY ?? "not-used",
...config,
});
case "textgenwebui":
return new ChatOpenAI({
configuration: {
baseURL: process.env.TEXT_GEN_WEB_UI_BASE_PATH,
},
apiKey: process.env.TEXT_GEN_WEB_UI_API_KEY ?? "not-used",
...config,
});
default:
throw new Error(`Unsupported provider ${provider} for this task.`);
} }
} }

View File

@ -1,28 +1,52 @@
const OpenAI = require("openai"); const OpenAI = require("openai");
const Provider = require("./ai-provider.js"); const Provider = require("./ai-provider.js");
const { RetryError } = require("../error.js"); const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
/** /**
* The agent provider for the Groq provider. * The agent provider for the GroqAI provider.
* Using OpenAI tool calling with groq really sucks right now * We wrap Groq in UnTooled because its tool-calling built in is quite bad and wasteful.
* its just fast and bad. We should probably migrate this to Untooled to improve
* coherence.
*/ */
class GroqProvider extends Provider { class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
model; model;
constructor(config = {}) { constructor(config = {}) {
const { model = "llama3-8b-8192" } = config; const { model = "llama3-8b-8192" } = config;
super();
const client = new OpenAI({ const client = new OpenAI({
baseURL: "https://api.groq.com/openai/v1", baseURL: "https://api.groq.com/openai/v1",
apiKey: process.env.GROQ_API_KEY, apiKey: process.env.GROQ_API_KEY,
maxRetries: 3, maxRetries: 3,
}); });
super(client);
this._client = client;
this.model = model; this.model = model;
this.verbose = true; this.verbose = true;
} }
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client.chat.completions
.create({
model: this.model,
temperature: 0,
messages,
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("GroqAI chat: No results!");
if (result.choices.length === 0)
throw new Error("GroqAI chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/** /**
* Create a completion based on the received messages. * Create a completion based on the received messages.
* *
@ -32,68 +56,49 @@ class GroqProvider extends Provider {
*/ */
async complete(messages, functions = null) { async complete(messages, functions = null) {
try { try {
const response = await this.client.chat.completions.create({ let completion;
model: this.model, if (functions.length > 0) {
// stream: true, const { toolCall, text } = await this.functionCall(
messages, messages,
...(Array.isArray(functions) && functions?.length > 0 functions,
? { functions } this.#handleFunctionCallChat.bind(this)
: {}), );
});
// Right now, we only support one completion, if (toolCall !== null) {
// so we just take the first one in the list this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
const completion = response.choices[0].message; this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
const cost = this.getCost(response.usage); return {
// treat function calls result: null,
if (completion.function_call) { functionCall: {
let functionArgs = {}; name: toolCall.name,
try { arguments: toolCall.arguments,
functionArgs = JSON.parse(completion.function_call.arguments); },
} catch (error) { cost: 0,
// call the complete function again in case it gets a json error };
return this.complete(
[
...messages,
{
role: "function",
name: completion.function_call.name,
function_call: completion.function_call,
content: error?.message,
},
],
functions
);
} }
completion = { content: text };
// console.log(completion, { functionArgs })
return {
result: null,
functionCall: {
name: completion.function_call.name,
arguments: functionArgs,
},
cost,
};
} }
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat.completions.create({
model: this.model,
messages: this.cleanMsgs(messages),
});
completion = response.choices[0].message;
}
// The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
// from calling the exact same function over and over in a loop within a single chat exchange
// _but_ we should enable it to call previously used tools in a new chat interaction.
this.deduplicator.reset("runs");
return { return {
result: completion.content, result: completion.content,
cost, cost: 0,
}; };
} catch (error) { } catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
// will make auth better.
if (error instanceof OpenAI.AuthenticationError) throw error;
if (
error instanceof OpenAI.RateLimitError ||
error instanceof OpenAI.InternalServerError ||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
) {
throw new RetryError(error.message);
}
throw error; throw error;
} }
} }
@ -103,7 +108,7 @@ class GroqProvider extends Provider {
* *
* @param _usage The completion to get the cost for. * @param _usage The completion to get the cost for.
* @returns The cost of the completion. * @returns The cost of the completion.
* Stubbed since Groq has no cost basis. * Stubbed since LMStudio has no cost basis.
*/ */
getCost(_usage) { getCost(_usage) {
return 0; return 0;

View File

@ -3,26 +3,27 @@ const { PromptTemplate } = require("@langchain/core/prompts");
const { RecursiveCharacterTextSplitter } = require("@langchain/textsplitters"); const { RecursiveCharacterTextSplitter } = require("@langchain/textsplitters");
const Provider = require("../providers/ai-provider"); const Provider = require("../providers/ai-provider");
/** /**
* Summarize content using OpenAI's GPT-3.5 model. * @typedef {Object} LCSummarizationConfig
* * @property {string} provider The LLM to use for summarization (inherited)
* @param self The context of the caller function * @property {string} model The LLM Model to use for summarization (inherited)
* @param content The content to summarize. * @property {AbortController['signal']} controllerSignal Abort controller to stop recursive summarization
* @returns The summarized content. * @property {string} content The text content of the text to summarize
*/ */
const SUMMARY_MODEL = { /**
anthropic: "claude-3-opus-20240229", // 200,000 tokens * Summarize content using LLM LC-Chain call
openai: "gpt-4o", // 128,000 tokens * @param {LCSummarizationConfig} The LLM to use for summarization (inherited)
}; * @returns {Promise<string>} The summarized content.
*/
async function summarizeContent( async function summarizeContent({
provider = "openai", provider = "openai",
model = null,
controllerSignal, controllerSignal,
content content,
) { }) {
const llm = Provider.LangChainChatModel(provider, { const llm = Provider.LangChainChatModel(provider, {
temperature: 0, temperature: 0,
modelName: SUMMARY_MODEL[provider], model: model,
}); });
const textSplitter = new RecursiveCharacterTextSplitter({ const textSplitter = new RecursiveCharacterTextSplitter({