2024-05-17 02:25:05 +02:00
|
|
|
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
|
2024-03-12 23:21:27 +01:00
|
|
|
const {
|
|
|
|
writeResponseChunk,
|
|
|
|
clientAbortedHandler,
|
|
|
|
} = require("../../helpers/chat/responses");
|
2023-12-28 02:08:03 +01:00
|
|
|
|
|
|
|
class GeminiLLM {
  /**
   * Chat provider backed by the `@google/generative-ai` SDK.
   *
   * @param {object|null} [embedder=null] - Engine used by `embedTextInput` /
   *   `embedChunks`; a `NativeEmbedder` is created when null or undefined.
   * @param {string|null} [modelPreference=null] - Model name; falls back to
   *   `process.env.GEMINI_LLM_MODEL_PREF`, then to `"gemini-pro"`.
   * @throws {Error} When `GEMINI_API_KEY` is not set.
   */
  constructor(embedder = null, modelPreference = null) {
    if (!process.env.GEMINI_API_KEY)
      throw new Error("No Gemini API key was set.");

    // Docs: https://ai.google.dev/tutorials/node_quickstart
    const { GoogleGenerativeAI } = require("@google/generative-ai");
    const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY);
    this.model =
      modelPreference || process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
    this.gemini = genAI.getGenerativeModel(
      { model: this.model },
      {
        // Gemini-1.5-pro is only available on the v1beta API.
        apiVersion: this.model === "gemini-1.5-pro-latest" ? "v1beta" : "v1",
      }
    );
    // Fractions of promptWindowLimit() allotted to each part of the prompt.
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7; // not used for Gemini
    this.safetyThreshold = this.#fetchSafetyThreshold();
  }

  /**
   * Renders retrieved context snippets as a labelled text section.
   * @param {string[]} [contextTexts=[]]
   * @returns {string} "" when there is no context, otherwise a "\nContext:\n" block.
   */
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }

  // BLOCK_NONE can be a special candidate for some fields
  // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes#how_to_remove_automated_response_blocking_for_select_safety_attributes
  // so if you are wondering why BLOCK_NONE still failed, the link above will explain why.
  /**
   * Reads GEMINI_SAFETY_SETTING and falls back to "BLOCK_MEDIUM_AND_ABOVE"
   * when it is unset or not one of the recognised thresholds.
   * @returns {string}
   */
  #fetchSafetyThreshold() {
    const threshold =
      process.env.GEMINI_SAFETY_SETTING ?? "BLOCK_MEDIUM_AND_ABOVE";
    const safetyThresholds = [
      "BLOCK_NONE",
      "BLOCK_ONLY_HIGH",
      "BLOCK_MEDIUM_AND_ABOVE",
      "BLOCK_LOW_AND_ABOVE",
    ];
    return safetyThresholds.includes(threshold)
      ? threshold
      : "BLOCK_MEDIUM_AND_ABOVE";
  }

  /**
   * Safety settings sent with every chat session: the configured threshold
   * applied to each of the four harm categories listed below.
   * @returns {{category: string, threshold: string}[]}
   */
  #safetySettings() {
    return [
      {
        category: "HARM_CATEGORY_HATE_SPEECH",
        threshold: this.safetyThreshold,
      },
      {
        category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        threshold: this.safetyThreshold,
      },
      { category: "HARM_CATEGORY_HARASSMENT", threshold: this.safetyThreshold },
      {
        category: "HARM_CATEGORY_DANGEROUS_CONTENT",
        threshold: this.safetyThreshold,
      },
    ];
  }

  /** @returns {boolean} true because streamGetChatCompletion is implemented. */
  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }

  /**
   * Size of the prompt window for the selected model.
   * @returns {number}
   */
  promptWindowLimit() {
    switch (this.model) {
      case "gemini-pro":
        return 30_720;
      case "gemini-1.5-pro-latest":
        return 1_048_576;
      default:
        return 30_720; // assume a gemini-pro model
    }
  }

  /**
   * @param {string} [modelName=""]
   * @returns {boolean} Whether the model is one this provider supports for chat.
   */
  isValidChatCompletionModel(modelName = "") {
    const validModels = ["gemini-pro", "gemini-1.5-pro-latest"];
    return validModels.includes(modelName);
  }

  // Moderation cannot be done with Gemini.
  // Not implemented so must be stubbed
  async isSafe(_input = "") {
    return { safe: true, reasons: [] };
  }

  /**
   * Builds the provider-neutral message array. The user's prompt is tagged
   * with the pseudo-role "USER_PROMPT" so the completion methods can pull it
   * out and send it separately from the history.
   * @returns {{role: string, content: string}[]}
   */
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      { role: "assistant", content: "Okay." },
      ...chatHistory,
      { role: "USER_PROMPT", content: userPrompt },
    ];
  }

  // This will take an OpenAi format message array and only pluck valid roles from it.
  /**
   * Converts OpenAI-style messages into Gemini chat history.
   * Gemini roles are either "user" or "model", and "content" is relabeled to
   * "parts". Messages with any other role (including "USER_PROMPT") are dropped.
   * @param {{role: string, content: string}[]} [messages=[]]
   * @returns {{role: "user"|"model", parts: {text: string}[]}[]}
   */
  formatMessages(messages = []) {
    const allMessages = messages
      .map((message) => {
        if (message.role === "system")
          return { role: "user", parts: [{ text: message.content }] };
        if (message.role === "user")
          return { role: "user", parts: [{ text: message.content }] };
        if (message.role === "assistant")
          return { role: "model", parts: [{ text: message.content }] };
        return null;
      })
      .filter((msg) => !!msg);

    // Specifically, Google cannot have the last sent message be from a user with no assistant reply
    // otherwise it will crash. So if the last item is from the user, it was not completed so pop it off
    // the history.
    if (
      allMessages.length > 0 &&
      allMessages[allMessages.length - 1].role === "user"
    )
      allMessages.pop();

    // Validate that after every user message, there is a model message.
    // Message compression (to retain as much context as possible) can leave two
    // user messages adjacent, which the Gemini model rejects, so a placeholder
    // model reply is inserted between them.
    for (let i = 0; i < allMessages.length; i++) {
      if (
        allMessages[i].role === "user" &&
        i < allMessages.length - 1 &&
        allMessages[i + 1].role !== "model"
      ) {
        allMessages.splice(i + 1, 0, {
          role: "model",
          parts: [{ text: "Okay." }],
        });
      }
    }

    return allMessages;
  }

  /**
   * Shared setup for both completion methods: validates the model, extracts
   * the "USER_PROMPT" content and opens a chat session seeded with the
   * formatted history and safety settings.
   * @param {{role: string, content: string}[]} messages
   * @returns {{chatThread: object, prompt: (string|undefined)}}
   * @throws {Error} When the configured model is not a supported chat model.
   */
  #startChatThread(messages = []) {
    if (!this.isValidChatCompletionModel(this.model))
      throw new Error(
        `Gemini chat: ${this.model} is not valid for chat completion!`
      );

    const prompt = messages.find(
      (chat) => chat.role === "USER_PROMPT"
    )?.content;
    const chatThread = this.gemini.startChat({
      history: this.formatMessages(messages),
      safetySettings: this.#safetySettings(),
    });
    return { chatThread, prompt };
  }

  /**
   * Sends the prompt and returns the full response text.
   * @throws {Error} On an invalid model or an empty response.
   */
  async getChatCompletion(messages = [], _opts = {}) {
    const { chatThread, prompt } = this.#startChatThread(messages);
    const result = await chatThread.sendMessage(prompt);
    const response = result.response;
    const responseText = response.text();

    if (!responseText) throw new Error("Gemini: No response could be parsed.");

    return responseText;
  }

  /**
   * Sends the prompt and returns the SDK's async-iterable chunk stream.
   * @throws {Error} On an invalid model or when no stream is returned.
   */
  async streamGetChatCompletion(messages = [], _opts = {}) {
    const { chatThread, prompt } = this.#startChatThread(messages);
    const responseStream = await chatThread.sendMessageStream(prompt);
    if (!responseStream.stream)
      throw new Error("Could not stream response stream from Gemini.");

    return responseStream.stream;
  }

  /** Builds the prompt and compresses it via the shared message compressor. */
  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }

  /**
   * Relays a Gemini chunk stream to the client via writeResponseChunk.
   *
   * @param {object} response - Writable HTTP response; must support `on` and
   *   `removeListener` for the "close" event.
   * @param {AsyncIterable<{text: () => string}>} stream - Chunks from streamGetChatCompletion.
   * @param {{uuid?: string, sources?: object[]}} responseProps
   * @returns {Promise<string>} The full text on success. The error message if
   *   a chunk or the stream itself fails. The partial text if the client aborts.
   */
  handleStream(response, stream, responseProps) {
    // Previously defaulted to an undeclared `uuidv4`, which threw a
    // ReferenceError whenever the caller omitted uuid. The built-in crypto
    // module is required lazily, only when a uuid must be generated.
    const { uuid = require("crypto").randomUUID(), sources = [] } =
      responseProps;

    return new Promise(async (resolve) => {
      let fullText = "";

      // Establish listener to early-abort a streaming response
      // in case things go sideways or the user does not like the response.
      // We preserve the generated text but continue as if chat was completed
      // to preserve previously generated content.
      const handleAbort = () => clientAbortedHandler(resolve, fullText);
      response.on("close", handleAbort);

      // Reports a failure to the client, detaches the abort listener and
      // settles with the error text.
      const abortWithError = (error) => {
        const message = error?.message ?? String(error);
        writeResponseChunk(response, {
          uuid,
          sources: [],
          type: "abort",
          textResponse: null,
          close: true,
          error: message,
        });
        response.removeListener("close", handleAbort);
        resolve(message);
      };

      try {
        for await (const chunk of stream) {
          let chunkText;
          try {
            // Due to content sensitivity we cannot always get the function .text();
            // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes#gemini-TASK-samples-nodejs
            // and it is not possible to unblock or disable this safety protocol without being allowlisted by Google.
            chunkText = chunk.text();
          } catch (e) {
            abortWithError(e);
            return;
          }

          fullText += chunkText;
          writeResponseChunk(response, {
            uuid,
            sources: [],
            type: "textResponseChunk",
            textResponse: chunkText,
            close: false,
            error: false,
          });
        }
      } catch (e) {
        // The iterator itself failed (e.g. a network or SDK error). Without
        // this catch the rejection would escape the executor and the promise
        // would never settle.
        abortWithError(e);
        return;
      }

      writeResponseChunk(response, {
        uuid,
        sources,
        type: "textResponseChunk",
        textResponse: "",
        close: true,
        error: false,
      });
      response.removeListener("close", handleAbort);
      resolve(fullText);
    });
  }

  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }
}
|
|
|
|
|
|
|
|
// CommonJS export surface of this module: the Gemini provider class.
const geminiProviderExports = { GeminiLLM };
module.exports = geminiProviderExports;
|