mirror of
https://github.com/Mintplex-Labs/anything-llm.git
synced 2024-11-10 17:00:11 +01:00
add middleware validations on embed chat
This commit is contained in:
parent
3b9f7cb373
commit
2f0942afac
@ -23,11 +23,13 @@ const ChatService = {
|
||||
async onopen(response) {
|
||||
if (response.ok) {
|
||||
return; // everything's good
|
||||
} else if (
|
||||
response.status >= 400 &&
|
||||
response.status < 500 &&
|
||||
response.status !== 429
|
||||
) {
|
||||
} else if (response.status >= 400) {
|
||||
await response
|
||||
.json()
|
||||
.then((serverResponse) => {
|
||||
handleChat(serverResponse);
|
||||
})
|
||||
.catch(() => {
|
||||
handleChat({
|
||||
id: v4(),
|
||||
type: "abort",
|
||||
@ -36,8 +38,9 @@ const ChatService = {
|
||||
close: true,
|
||||
error: `An error occurred while streaming response. Code ${response.status}`,
|
||||
});
|
||||
});
|
||||
ctrl.abort();
|
||||
throw new Error("Invalid Status code response.");
|
||||
throw new Error();
|
||||
} else {
|
||||
handleChat({
|
||||
id: v4(),
|
||||
@ -48,7 +51,7 @@ const ChatService = {
|
||||
error: `An error occurred while streaming response. Unknown Error.`,
|
||||
});
|
||||
ctrl.abort();
|
||||
throw new Error("Unknown error");
|
||||
throw new Error("Unknown Error");
|
||||
}
|
||||
},
|
||||
async onmessage(msg) {
|
||||
|
@ -1,22 +1,25 @@
|
||||
const { v4: uuidv4 } = require("uuid");
|
||||
const { reqBody, multiUserMode } = require("../../utils/http");
|
||||
const { Telemetry } = require("../../models/telemetry");
|
||||
const {
|
||||
writeResponseChunk,
|
||||
VALID_CHAT_MODE,
|
||||
} = require("../../utils/chats/stream");
|
||||
const { writeResponseChunk } = require("../../utils/chats/stream");
|
||||
const { streamChatWithForEmbed } = require("../../utils/chats/embed");
|
||||
const { convertToChatHistory } = require("../../utils/chats");
|
||||
const { EmbedConfig } = require("../../models/embedConfig");
|
||||
const { EmbedChats } = require("../../models/embedChats");
|
||||
const {
|
||||
validEmbedConfig,
|
||||
canRespond,
|
||||
setConnectionMeta,
|
||||
} = require("../../utils/middleware/embedMiddleware");
|
||||
|
||||
function embeddedEndpoints(app) {
|
||||
if (!app) return;
|
||||
|
||||
// TODO: middleware
|
||||
app.post("/embed/:embedId/stream-chat", async (request, response) => {
|
||||
app.post(
|
||||
"/embed/:embedId/stream-chat",
|
||||
[validEmbedConfig, setConnectionMeta, canRespond],
|
||||
async (request, response) => {
|
||||
try {
|
||||
const { embedId } = request.params;
|
||||
const embed = response.locals.embedConfig;
|
||||
const {
|
||||
sessionId,
|
||||
message,
|
||||
@ -26,47 +29,12 @@ function embeddedEndpoints(app) {
|
||||
temperature = null,
|
||||
} = reqBody(request);
|
||||
|
||||
const embed = await EmbedConfig.getWithWorkspace({ uuid: embedId });
|
||||
if (!embed) {
|
||||
response.sendStatus(400).end();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!embed.enabled) {
|
||||
response.status(200).json({
|
||||
id: uuidv4(),
|
||||
type: "abort",
|
||||
textResponse: null,
|
||||
sources: [],
|
||||
close: true,
|
||||
error:
|
||||
"This chat has been disabled by the administrator - try again later.",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (!message?.length || !VALID_CHAT_MODE.includes(embed.chat_mode)) {
|
||||
response.status(200).json({
|
||||
id: uuidv4(),
|
||||
type: "abort",
|
||||
textResponse: null,
|
||||
sources: [],
|
||||
close: true,
|
||||
error: !message?.length
|
||||
? "Message is empty."
|
||||
: `${embed.chat_mode} is not a valid mode.`,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
response.setHeader("Cache-Control", "no-cache");
|
||||
response.setHeader("Content-Type", "text/event-stream");
|
||||
response.setHeader("Access-Control-Allow-Origin", "*");
|
||||
response.setHeader("Connection", "keep-alive");
|
||||
response.flushHeaders();
|
||||
|
||||
// TODO Per-user and Per-day limit checks for embed_config
|
||||
|
||||
await streamChatWithForEmbed(response, embed, message, sessionId, {
|
||||
prompt,
|
||||
model,
|
||||
@ -90,17 +58,16 @@ function embeddedEndpoints(app) {
|
||||
});
|
||||
response.end();
|
||||
}
|
||||
});
|
||||
|
||||
// TODO: middleware
|
||||
app.get("/embed/:embedId/:sessionId", async (request, response) => {
|
||||
try {
|
||||
const { embedId, sessionId } = request.params;
|
||||
const embed = await EmbedConfig.get({ uuid: embedId });
|
||||
if (!embed) {
|
||||
response.sendStatus(400).end();
|
||||
return;
|
||||
}
|
||||
);
|
||||
|
||||
app.get(
|
||||
"/embed/:embedId/:sessionId",
|
||||
[validEmbedConfig],
|
||||
async (request, response) => {
|
||||
try {
|
||||
const { sessionId } = request.params;
|
||||
const embed = response.locals.embedConfig;
|
||||
|
||||
const history = await EmbedChats.forEmbedByUser(embed.id, sessionId);
|
||||
response.status(200).json({
|
||||
@ -110,7 +77,8 @@ function embeddedEndpoints(app) {
|
||||
console.log(e.message, e);
|
||||
response.sendStatus(500).end();
|
||||
}
|
||||
});
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
module.exports = { embeddedEndpoints };
|
||||
|
@ -1,13 +1,20 @@
|
||||
const prisma = require("../utils/prisma");
|
||||
|
||||
const EmbedChats = {
|
||||
new: async function ({ embedId, prompt, response = {}, sessionId }) {
|
||||
new: async function ({
|
||||
embedId,
|
||||
prompt,
|
||||
response = {},
|
||||
connection_information = {},
|
||||
sessionId,
|
||||
}) {
|
||||
try {
|
||||
const chat = await prisma.embed_chats.create({
|
||||
data: {
|
||||
prompt,
|
||||
embed_id: Number(embedId),
|
||||
response: JSON.stringify(response),
|
||||
connection_information: JSON.stringify(connection_information),
|
||||
session_id: sessionId,
|
||||
},
|
||||
});
|
||||
|
@ -5,6 +5,10 @@ const EmbedConfig = {
|
||||
// Used for generic updates so we can validate keys in request body
|
||||
"allowlist_domains",
|
||||
"allow_model_override",
|
||||
"allow_temperature_override",
|
||||
"allow_prompt_override",
|
||||
"max_chats_per_day",
|
||||
"max_chats_per_session",
|
||||
"chat_mode",
|
||||
],
|
||||
|
||||
@ -94,6 +98,20 @@ const EmbedConfig = {
|
||||
return [];
|
||||
}
|
||||
},
|
||||
|
||||
// Will return null if process should be skipped
|
||||
// an empty array means the system will check. This
|
||||
// prevents a bad parse from allowing all requests
|
||||
parseAllowedHosts: function (embed) {
|
||||
if (!embed.allowlist_domains) return null;
|
||||
|
||||
try {
|
||||
return JSON.parse(embed.allowlist_domains);
|
||||
} catch {
|
||||
console.error(`Failed to parse allowlist_domains for Embed ${embed.id}!`);
|
||||
return [];
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
module.exports = { EmbedConfig };
|
||||
|
@ -85,8 +85,8 @@ async function streamChatWithForEmbed(
|
||||
namespace: embed.workspace.slug,
|
||||
input: message,
|
||||
LLMConnector,
|
||||
similarityThreshold: workspace?.similarityThreshold,
|
||||
topN: workspace?.topN,
|
||||
similarityThreshold: embed.workspace?.similarityThreshold,
|
||||
topN: embed.workspace?.topN,
|
||||
});
|
||||
|
||||
// Failed similarity search.
|
||||
@ -136,7 +136,7 @@ async function streamChatWithForEmbed(
|
||||
`\x1b[31m[STREAMING DISABLED]\x1b[0m Streaming is not available for ${LLMConnector.constructor.name}. Will use regular chat method.`
|
||||
);
|
||||
completeText = await LLMConnector.getChatCompletion(messages, {
|
||||
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
||||
temperature: embed.workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
||||
});
|
||||
writeResponseChunk(response, {
|
||||
uuid,
|
||||
@ -148,7 +148,7 @@ async function streamChatWithForEmbed(
|
||||
});
|
||||
} else {
|
||||
const stream = await LLMConnector.streamGetChatCompletion(messages, {
|
||||
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
||||
temperature: embed.workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
||||
});
|
||||
completeText = await handleStreamResponses(response, stream, {
|
||||
uuid,
|
||||
@ -160,6 +160,9 @@ async function streamChatWithForEmbed(
|
||||
embedId: embed.id,
|
||||
prompt: message,
|
||||
response: { text: completeText, type: chatMode },
|
||||
connection_information: response.locals.connection
|
||||
? { ...response.locals.connection }
|
||||
: {},
|
||||
sessionId,
|
||||
});
|
||||
return;
|
||||
@ -233,6 +236,9 @@ async function streamEmptyEmbeddingChat({
|
||||
embedId: embed.id,
|
||||
prompt: message,
|
||||
response: { text: completeText, type: "chat" },
|
||||
connection_information: response.locals.connection
|
||||
? { ...response.locals.connection }
|
||||
: {},
|
||||
sessionId,
|
||||
});
|
||||
return;
|
||||
|
138
server/utils/middleware/embedMiddleware.js
Normal file
138
server/utils/middleware/embedMiddleware.js
Normal file
@ -0,0 +1,138 @@
|
||||
const { v4: uuidv4 } = require("uuid");
|
||||
const { VALID_CHAT_MODE, writeResponseChunk } = require("../chats/stream");
|
||||
const { EmbedChats } = require("../../models/embedChats");
|
||||
const { EmbedConfig } = require("../../models/embedConfig");
|
||||
const { reqBody } = require("../http");
|
||||
|
||||
/**
 * Finds the embed config for a `/:embedId/` url or aborts the request. This
 * should always be the first middleware in the chain and `:embedId` must be
 * part of the route URL. On success the record is stored on
 * `response.locals.embedConfig`; when nothing is found the request ends
 * with a 404.
 *
 * @param {object} request - Express request; `params.embedId` is read.
 * @param {object} response - Express response; `locals.embedConfig` is written.
 * @param {Function} next - Continues the chain, or receives the lookup error.
 * @returns {Promise<void>}
 */
async function validEmbedConfig(request, response, next) {
  const { embedId } = request.params;

  let embed;
  try {
    embed = await EmbedConfig.getWithWorkspace({ uuid: embedId });
  } catch (error) {
    // Hand lookup failures to Express's error handling instead of leaving a
    // rejected promise unhandled (Express 4 does not catch rejections from
    // async middleware, so the request would otherwise never be answered).
    next(error);
    return;
  }

  if (!embed) {
    // sendStatus() already finishes the response - no trailing end() needed.
    response.sendStatus(404);
    return;
  }

  response.locals.embedConfig = embed;
  next();
}
|
||||
|
||||
/**
 * Express middleware that copies the request's `hostname`, `path` and `ip`
 * onto `response.locals.connection` (as `host`, `path`, `ip`) and then
 * continues the chain.
 */
function setConnectionMeta(request, response, next) {
  const { hostname: host, path, ip } = request;
  response.locals.connection = { host, path, ip };
  next();
}
|
||||
|
||||
/**
 * Writes the JSON "abort" payload shared by every rejected embed chat request.
 *
 * @param {object} response - Express response.
 * @param {number} status - HTTP status code to reply with.
 * @param {string} error - Message placed in the payload's `error` field.
 */
function abortEmbedChat(response, status, error) {
  response.status(status).json({
    id: uuidv4(),
    type: "abort",
    textResponse: null,
    sources: [],
    close: true,
    error,
  });
}

/**
 * Normalizes a configured chat limit.
 *
 * @param {*} value - Raw limit taken from the embed config.
 * @returns {number|null} The limit when it is a number above zero, else null.
 */
function activeChatLimit(value) {
  const limit = Number(value);
  return limit > 0 ? limit : null;
}

/**
 * Counts embed chats matching `clause` that were created in the last 24 hours.
 *
 * @param {object} clause - Filter fields passed through to `EmbedChats.count`.
 * @returns {Promise<number>}
 */
function countChatsInLastDay(clause) {
  return EmbedChats.count({
    ...clause,
    createdAt: { gte: new Date(Date.now() - 24 * 60 * 60 * 1000) },
  });
}

/**
 * Express middleware deciding whether an embed chat request may proceed.
 * Reads the embed from `response.locals.embedConfig` and replies with:
 *  - 404 when no embed config is present,
 *  - 503 when the embed is disabled,
 *  - 401 when an allowlist exists and the request host is not on it,
 *  - 400 when the message is empty or the embed's chat mode is not valid,
 *  - 429 when the per-embed or per-session count for the last 24 hours has
 *    reached its configured limit.
 * Otherwise calls `next()`. Any error thrown while checking is passed to
 * `next(error)` rather than being left as an unhandled rejection.
 *
 * @param {object} request - Express request.
 * @param {object} response - Express response.
 * @param {Function} next - Continues the chain, or receives a thrown error.
 * @returns {Promise<void>}
 */
async function canRespond(request, response, next) {
  try {
    const embed = response.locals.embedConfig;
    if (!embed) {
      // sendStatus() already finishes the response - no trailing end() needed.
      response.sendStatus(404);
      return;
    }

    // Block if disabled by admin.
    if (!embed.enabled) {
      abortEmbedChat(
        response,
        503,
        "This chat has been disabled by the administrator - try again later."
      );
      return;
    }

    // Check if requester hostname is in the valid allowlist of domains.
    // NOTE(review): Express derives `request.hostname` from the Host header,
    // i.e. the host this server was addressed as - confirm that this, rather
    // than the Origin header, is the value the allowlist is meant to match.
    const host = request.hostname;
    const parsedHosts = EmbedConfig.parseAllowedHosts(embed);
    if (parsedHosts !== null) {
      // Anything other than an array is treated as an empty allowlist so a
      // malformed value denies the request instead of throwing on `.includes`.
      const allowedHosts = Array.isArray(parsedHosts) ? parsedHosts : [];
      if (!allowedHosts.includes(host)) {
        abortEmbedChat(response, 401, "Invalid request.");
        return;
      }
    }

    const { sessionId, message } = reqBody(request);

    if (!message?.length) {
      abortEmbedChat(response, 400, "Message is empty.");
      return;
    }

    if (!VALID_CHAT_MODE.includes(embed.chat_mode)) {
      abortEmbedChat(response, 400, `${embed.chat_mode} is not a valid mode.`);
      return;
    }

    const dailyLimit = activeChatLimit(embed.max_chats_per_day);
    if (dailyLimit !== null) {
      const dailyChatCount = await countChatsInLastDay({ embed_id: embed.id });
      if (dailyChatCount >= dailyLimit) {
        abortEmbedChat(
          response,
          429,
          "The quota for this chat has been reached. Try again later or contact the site owner."
        );
        return;
      }
    }

    const sessionLimit = activeChatLimit(embed.max_chats_per_session);
    if (sessionLimit !== null) {
      const dailySessionCount = await countChatsInLastDay({
        embed_id: embed.id,
        session_id: sessionId,
      });
      if (dailySessionCount >= sessionLimit) {
        abortEmbedChat(
          response,
          429,
          "Your quota for this chat has been reached. Try again later or contact the site owner."
        );
        return;
      }
    }
  } catch (error) {
    next(error);
    return;
  }

  next();
}
|
||||
|
||||
module.exports = {
|
||||
setConnectionMeta,
|
||||
validEmbedConfig,
|
||||
canRespond,
|
||||
};
|
Loading…
Reference in New Issue
Block a user