Llm chore cleanup (#501)

* Move internal helper functions to private class methods;
  simplify the LangChain message converter

* Fix hanging "Context:" text in the system prompt when no context is present
Timothy Carambat 2023-12-28 14:42:34 -08:00 committed by GitHub
parent 2a1202de54
commit 6d5968bf7e
7 changed files with 129 additions and 98 deletions
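
The recurring edit across the providers below is the "hanging Context" fix: each system prompt used to be built with an inline template literal that printed a "Context:" heading even when contextTexts was empty. Every provider now gets a private #appendContext helper that returns an empty string when there is nothing to append. A minimal sketch of the pattern, using a hypothetical ExampleLLM class rather than any of the real provider classes changed below:

// Sketch only: "ExampleLLM" is a stand-in for any provider class in this commit.
class ExampleLLM {
  #appendContext(contextTexts = []) {
    // With no context texts, contribute nothing to the system prompt,
    // so it no longer ends with a hanging "Context:" label.
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`)
        .join("")
    );
  }

  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
  }
}

// With no context texts the system message is just the system prompt:
// new ExampleLLM().constructPrompt({ systemPrompt: "You are helpful.", userPrompt: "Hi" })
// -> [ { role: "system", content: "You are helpful." },
//      { role: "user", content: "Hi" } ]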

View File

@@ -27,6 +27,18 @@ class AzureOpenAiLLM {
     this.embedder = !embedder ? new AzureOpenAiEmbedder() : embedder;
   }
 
+  #appendContext(contextTexts = []) {
+    if (!contextTexts || !contextTexts.length) return "";
+    return (
+      "\nContext:\n" +
+      contextTexts
+        .map((text, i) => {
+          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
+        })
+        .join("")
+    );
+  }
+
   streamingEnabled() {
     return "streamChat" in this && "streamGetChatCompletion" in this;
   }
@@ -55,13 +67,7 @@ class AzureOpenAiLLM {
   }) {
     const prompt = {
       role: "system",
-      content: `${systemPrompt}
-Context:
-    ${contextTexts
-      .map((text, i) => {
-        return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
-      })
-      .join("")}`,
+      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
     };
     return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
   }

View File

@@ -1,4 +1,3 @@
-const { v4 } = require("uuid");
 const { chatPrompt } = require("../../chats");
 
 class GeminiLLM {
@@ -22,7 +21,18 @@ class GeminiLLM {
         "INVALID GEMINI LLM SETUP. No embedding engine has been set. Go to instance settings and set up an embedding interface to use Gemini as your LLM."
       );
     this.embedder = embedder;
-    this.answerKey = v4().split("-")[0];
   }
 
+  #appendContext(contextTexts = []) {
+    if (!contextTexts || !contextTexts.length) return "";
+    return (
+      "\nContext:\n" +
+      contextTexts
+        .map((text, i) => {
+          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
+        })
+        .join("")
+    );
+  }
+
   streamingEnabled() {
@@ -57,13 +67,7 @@ class GeminiLLM {
   }) {
     const prompt = {
       role: "system",
-      content: `${systemPrompt}
-Context:
-    ${contextTexts
-      .map((text, i) => {
-        return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
-      })
-      .join("")}`,
+      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
     };
     return [
       prompt,

View File

@@ -27,6 +27,18 @@ class LMStudioLLM {
     this.embedder = embedder;
   }
 
+  #appendContext(contextTexts = []) {
+    if (!contextTexts || !contextTexts.length) return "";
+    return (
+      "\nContext:\n" +
+      contextTexts
+        .map((text, i) => {
+          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
+        })
+        .join("")
+    );
+  }
+
   streamingEnabled() {
     return "streamChat" in this && "streamGetChatCompletion" in this;
   }
@@ -54,13 +66,7 @@ class LMStudioLLM {
   }) {
     const prompt = {
       role: "system",
-      content: `${systemPrompt}
-Context:
-    ${contextTexts
-      .map((text, i) => {
-        return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
-      })
-      .join("")}`,
+      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
     return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
   }

View File

@@ -29,6 +29,18 @@ class LocalAiLLM {
     this.embedder = embedder;
   }
 
+  #appendContext(contextTexts = []) {
+    if (!contextTexts || !contextTexts.length) return "";
+    return (
+      "\nContext:\n" +
+      contextTexts
+        .map((text, i) => {
+          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
+        })
+        .join("")
+    );
+  }
+
   streamingEnabled() {
     return "streamChat" in this && "streamGetChatCompletion" in this;
   }
@@ -54,13 +66,7 @@ class LocalAiLLM {
   }) {
     const prompt = {
       role: "system",
-      content: `${systemPrompt}
-Context:
-    ${contextTexts
-      .map((text, i) => {
-        return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
-      })
-      .join("")}`,
+      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
     };
     return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
   }

View File

@@ -1,8 +1,6 @@
-const os = require("os");
 const fs = require("fs");
 const path = require("path");
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
-const { HumanMessage, SystemMessage, AIMessage } = require("langchain/schema");
 const { chatPrompt } = require("../../chats");
 
 // Docs: https://api.js.langchain.com/classes/chat_models_llama_cpp.ChatLlamaCpp.html
@@ -29,12 +27,6 @@ class NativeLLM {
         : path.resolve(__dirname, `../../../storage/models/downloaded`)
     );
 
-    // Set ENV for if llama.cpp needs to rebuild at runtime and machine is not
-    // running Apple Silicon.
-    process.env.NODE_LLAMA_CPP_METAL = os
-      .cpus()
-      .some((cpu) => cpu.model.includes("Apple"));
-
     // Make directory when it does not exist in existing installations
     if (!fs.existsSync(this.cacheDir)) fs.mkdirSync(this.cacheDir);
   }
@@ -56,12 +48,46 @@ class NativeLLM {
   // If the model has been loaded once, it is in the memory now
   // so we can skip re-loading it and instead go straight to inference.
   // Note: this will break temperature setting hopping between workspaces with different temps.
-  async llamaClient({ temperature = 0.7 }) {
+  async #llamaClient({ temperature = 0.7 }) {
     if (global.llamaModelInstance) return global.llamaModelInstance;
     await this.#initializeLlamaModel(temperature);
     return global.llamaModelInstance;
   }
 
+  #convertToLangchainPrototypes(chats = []) {
+    const {
+      HumanMessage,
+      SystemMessage,
+      AIMessage,
+    } = require("langchain/schema");
+    const langchainChats = [];
+    const roleToMessageMap = {
+      system: SystemMessage,
+      user: HumanMessage,
+      assistant: AIMessage,
+    };
+
+    for (const chat of chats) {
+      if (!roleToMessageMap.hasOwnProperty(chat.role)) continue;
+      const MessageClass = roleToMessageMap[chat.role];
+      langchainChats.push(new MessageClass({ content: chat.content }));
+    }
+
+    return langchainChats;
+  }
+
+  #appendContext(contextTexts = []) {
+    if (!contextTexts || !contextTexts.length) return "";
+    return (
+      "\nContext:\n" +
+      contextTexts
+        .map((text, i) => {
+          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
+        })
+        .join("")
+    );
+  }
+
   streamingEnabled() {
     return "streamChat" in this && "streamGetChatCompletion" in this;
   }
@@ -84,13 +110,7 @@ class NativeLLM {
   }) {
     const prompt = {
       role: "system",
-      content: `${systemPrompt}
-Context:
-    ${contextTexts
-      .map((text, i) => {
-        return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
-      })
-      .join("")}`,
+      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
     };
     return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
   }
@@ -111,7 +131,7 @@ Context:
       rawHistory
     );
 
-    const model = await this.llamaClient({
+    const model = await this.#llamaClient({
       temperature: Number(workspace?.openAiTemp ?? 0.7),
     });
     const response = await model.call(messages);
@@ -124,7 +144,7 @@ Context:
   }
 
   async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
-    const model = await this.llamaClient({
+    const model = await this.#llamaClient({
       temperature: Number(workspace?.openAiTemp ?? 0.7),
     });
     const messages = await this.compressMessages(
@@ -140,13 +160,13 @@ Context:
   }
 
   async getChatCompletion(messages = null, { temperature = 0.7 }) {
-    const model = await this.llamaClient({ temperature });
+    const model = await this.#llamaClient({ temperature });
     const response = await model.call(messages);
     return response.content;
   }
 
   async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
-    const model = await this.llamaClient({ temperature });
+    const model = await this.#llamaClient({ temperature });
     const responseStream = await model.stream(messages);
     return responseStream;
   }
@@ -167,27 +187,7 @@ Context:
       messageArray,
       rawHistory
     );
-    return this.convertToLangchainPrototypes(compressedMessages);
-  }
-
-  convertToLangchainPrototypes(chats = []) {
-    const langchainChats = [];
-    for (const chat of chats) {
-      switch (chat.role) {
-        case "system":
-          langchainChats.push(new SystemMessage({ content: chat.content }));
-          break;
-        case "user":
-          langchainChats.push(new HumanMessage({ content: chat.content }));
-          break;
-        case "assistant":
-          langchainChats.push(new AIMessage({ content: chat.content }));
-          break;
-        default:
-          break;
-      }
-    }
-    return langchainChats;
+    return this.#convertToLangchainPrototypes(compressedMessages);
   }
 }
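
The converter simplification above (and its twin in OllamaAILLM below) swaps the per-role switch statement for a role-to-class lookup table. A standalone sketch of that pattern, assuming langchain's schema message classes are available:

// Sketch of the simplified converter; mirrors the private methods in this commit.
const { HumanMessage, SystemMessage, AIMessage } = require("langchain/schema");

function convertToLangchainPrototypes(chats = []) {
  const roleToMessageMap = {
    system: SystemMessage,
    user: HumanMessage,
    assistant: AIMessage,
  };
  const langchainChats = [];
  for (const chat of chats) {
    // Unknown roles are skipped, matching the old switch's default branch.
    if (!roleToMessageMap.hasOwnProperty(chat.role)) continue;
    const MessageClass = roleToMessageMap[chat.role];
    langchainChats.push(new MessageClass({ content: chat.content }));
  }
  return langchainChats;
}

// convertToLangchainPrototypes([{ role: "user", content: "Hello" }])
// -> [ HumanMessage { content: "Hello" } ]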

View File

@@ -40,24 +40,33 @@ class OllamaAILLM {
       AIMessage,
     } = require("langchain/schema");
     const langchainChats = [];
+    const roleToMessageMap = {
+      system: SystemMessage,
+      user: HumanMessage,
+      assistant: AIMessage,
+    };
+
     for (const chat of chats) {
-      switch (chat.role) {
-        case "system":
-          langchainChats.push(new SystemMessage({ content: chat.content }));
-          break;
-        case "user":
-          langchainChats.push(new HumanMessage({ content: chat.content }));
-          break;
-        case "assistant":
-          langchainChats.push(new AIMessage({ content: chat.content }));
-          break;
-        default:
-          break;
-      }
+      if (!roleToMessageMap.hasOwnProperty(chat.role)) continue;
+      const MessageClass = roleToMessageMap[chat.role];
+      langchainChats.push(new MessageClass({ content: chat.content }));
     }
     return langchainChats;
   }
 
+  #appendContext(contextTexts = []) {
+    if (!contextTexts || !contextTexts.length) return "";
+    return (
+      "\nContext:\n" +
+      contextTexts
+        .map((text, i) => {
+          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
+        })
+        .join("")
+    );
+  }
+
   streamingEnabled() {
     return "streamChat" in this && "streamGetChatCompletion" in this;
   }
@@ -83,13 +92,7 @@ class OllamaAILLM {
   }) {
     const prompt = {
       role: "system",
-      content: `${systemPrompt}
-Context:
-    ${contextTexts
-      .map((text, i) => {
-        return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
-      })
-      .join("")}`,
+      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
     };
     return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
   }

View File

@@ -24,6 +24,18 @@ class OpenAiLLM {
     this.embedder = !embedder ? new OpenAiEmbedder() : embedder;
   }
 
+  #appendContext(contextTexts = []) {
+    if (!contextTexts || !contextTexts.length) return "";
+    return (
+      "\nContext:\n" +
+      contextTexts
+        .map((text, i) => {
+          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
+        })
+        .join("")
+    );
+  }
+
   streamingEnabled() {
     return "streamChat" in this && "streamGetChatCompletion" in this;
   }
@@ -68,13 +80,7 @@ class OpenAiLLM {
  }) {
     const prompt = {
       role: "system",
-      content: `${systemPrompt}
-Context:
-    ${contextTexts
-      .map((text, i) => {
-        return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
-      })
-      .join("")}`,
+      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
     };
     return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
   }