add middleware validations on embed chat

This commit is contained in:
timothycarambat 2024-02-01 15:08:30 -08:00
parent 3b9f7cb373
commit 2f0942afac
6 changed files with 250 additions and 110 deletions

View File

@ -23,11 +23,13 @@ const ChatService = {
async onopen(response) {
if (response.ok) {
return; // everything's good
} else if (
response.status >= 400 &&
response.status < 500 &&
response.status !== 429
) {
} else if (response.status >= 400) {
await response
.json()
.then((serverResponse) => {
handleChat(serverResponse);
})
.catch(() => {
handleChat({
id: v4(),
type: "abort",
@ -36,8 +38,9 @@ const ChatService = {
close: true,
error: `An error occurred while streaming response. Code ${response.status}`,
});
});
ctrl.abort();
throw new Error("Invalid Status code response.");
throw new Error();
} else {
handleChat({
id: v4(),
@ -48,7 +51,7 @@ const ChatService = {
error: `An error occurred while streaming response. Unknown Error.`,
});
ctrl.abort();
throw new Error("Unknown error");
throw new Error("Unknown Error");
}
},
async onmessage(msg) {

View File

@ -1,22 +1,25 @@
const { v4: uuidv4 } = require("uuid");
const { reqBody, multiUserMode } = require("../../utils/http");
const { Telemetry } = require("../../models/telemetry");
const {
writeResponseChunk,
VALID_CHAT_MODE,
} = require("../../utils/chats/stream");
const { writeResponseChunk } = require("../../utils/chats/stream");
const { streamChatWithForEmbed } = require("../../utils/chats/embed");
const { convertToChatHistory } = require("../../utils/chats");
const { EmbedConfig } = require("../../models/embedConfig");
const { EmbedChats } = require("../../models/embedChats");
const {
validEmbedConfig,
canRespond,
setConnectionMeta,
} = require("../../utils/middleware/embedMiddleware");
function embeddedEndpoints(app) {
if (!app) return;
// TODO: middleware
app.post("/embed/:embedId/stream-chat", async (request, response) => {
app.post(
"/embed/:embedId/stream-chat",
[validEmbedConfig, setConnectionMeta, canRespond],
async (request, response) => {
try {
const { embedId } = request.params;
const embed = response.locals.embedConfig;
const {
sessionId,
message,
@ -26,47 +29,12 @@ function embeddedEndpoints(app) {
temperature = null,
} = reqBody(request);
const embed = await EmbedConfig.getWithWorkspace({ uuid: embedId });
if (!embed) {
response.sendStatus(400).end();
return;
}
if (!embed.enabled) {
response.status(200).json({
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error:
"This chat has been disabled by the administrator - try again later.",
});
return;
}
if (!message?.length || !VALID_CHAT_MODE.includes(embed.chat_mode)) {
response.status(200).json({
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error: !message?.length
? "Message is empty."
: `${embed.chat_mode} is not a valid mode.`,
});
return;
}
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Content-Type", "text/event-stream");
response.setHeader("Access-Control-Allow-Origin", "*");
response.setHeader("Connection", "keep-alive");
response.flushHeaders();
// TODO Per-user and Per-day limit checks for embed_config
await streamChatWithForEmbed(response, embed, message, sessionId, {
prompt,
model,
@ -90,17 +58,16 @@ function embeddedEndpoints(app) {
});
response.end();
}
});
// TODO: middleware
app.get("/embed/:embedId/:sessionId", async (request, response) => {
try {
const { embedId, sessionId } = request.params;
const embed = await EmbedConfig.get({ uuid: embedId });
if (!embed) {
response.sendStatus(400).end();
return;
}
);
app.get(
"/embed/:embedId/:sessionId",
[validEmbedConfig],
async (request, response) => {
try {
const { sessionId } = request.params;
const embed = response.locals.embedConfig;
const history = await EmbedChats.forEmbedByUser(embed.id, sessionId);
response.status(200).json({
@ -110,7 +77,8 @@ function embeddedEndpoints(app) {
console.log(e.message, e);
response.sendStatus(500).end();
}
});
}
);
}
module.exports = { embeddedEndpoints };

View File

@ -1,13 +1,20 @@
const prisma = require("../utils/prisma");
const EmbedChats = {
new: async function ({ embedId, prompt, response = {}, sessionId }) {
new: async function ({
embedId,
prompt,
response = {},
connection_information = {},
sessionId,
}) {
try {
const chat = await prisma.embed_chats.create({
data: {
prompt,
embed_id: Number(embedId),
response: JSON.stringify(response),
connection_information: JSON.stringify(connection_information),
session_id: sessionId,
},
});

View File

@ -5,6 +5,10 @@ const EmbedConfig = {
// Used for generic updates so we can validate keys in request body
"allowlist_domains",
"allow_model_override",
"allow_temperature_override",
"allow_prompt_override",
"max_chats_per_day",
"max_chats_per_session",
"chat_mode",
],
@ -94,6 +98,20 @@ const EmbedConfig = {
return [];
}
},
// Will return null if process should be skipped
// an empty array means the system will check. This
// prevents a bad parse from allowing all requests
parseAllowedHosts: function (embed) {
if (!embed.allowlist_domains) return null;
try {
return JSON.parse(embed.allowlist_domains);
} catch {
console.error(`Failed to parse allowlist_domains for Embed ${embed.id}!`);
return [];
}
},
};
module.exports = { EmbedConfig };

View File

@ -85,8 +85,8 @@ async function streamChatWithForEmbed(
namespace: embed.workspace.slug,
input: message,
LLMConnector,
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
similarityThreshold: embed.workspace?.similarityThreshold,
topN: embed.workspace?.topN,
});
// Failed similarity search.
@ -136,7 +136,7 @@ async function streamChatWithForEmbed(
`\x1b[31m[STREAMING DISABLED]\x1b[0m Streaming is not available for ${LLMConnector.constructor.name}. Will use regular chat method.`
);
completeText = await LLMConnector.getChatCompletion(messages, {
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
temperature: embed.workspace?.openAiTemp ?? LLMConnector.defaultTemp,
});
writeResponseChunk(response, {
uuid,
@ -148,7 +148,7 @@ async function streamChatWithForEmbed(
});
} else {
const stream = await LLMConnector.streamGetChatCompletion(messages, {
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
temperature: embed.workspace?.openAiTemp ?? LLMConnector.defaultTemp,
});
completeText = await handleStreamResponses(response, stream, {
uuid,
@ -160,6 +160,9 @@ async function streamChatWithForEmbed(
embedId: embed.id,
prompt: message,
response: { text: completeText, type: chatMode },
connection_information: response.locals.connection
? { ...response.locals.connection }
: {},
sessionId,
});
return;
@ -233,6 +236,9 @@ async function streamEmptyEmbeddingChat({
embedId: embed.id,
prompt: message,
response: { text: completeText, type: "chat" },
connection_information: response.locals.connection
? { ...response.locals.connection }
: {},
sessionId,
});
return;

View File

@ -0,0 +1,138 @@
const { v4: uuidv4 } = require("uuid");
const { VALID_CHAT_MODE, writeResponseChunk } = require("../chats/stream");
const { EmbedChats } = require("../../models/embedChats");
const { EmbedConfig } = require("../../models/embedConfig");
const { reqBody } = require("../http");
// Finds or Aborts request for a /:embedId/ url. This should always
// be the first middleware and the :embedID should be in the URL.
async function validEmbedConfig(request, response, next) {
const { embedId } = request.params;
const embed = await EmbedConfig.getWithWorkspace({ uuid: embedId });
if (!embed) {
response.sendStatus(404).end();
return;
}
response.locals.embedConfig = embed;
next();
}
function setConnectionMeta(request, response, next) {
response.locals.connection = {
host: request.hostname,
path: request.path,
ip: request.ip,
};
next();
}
async function canRespond(request, response, next) {
const embed = response.locals.embedConfig;
if (!embed) {
response.sendStatus(404).end();
return;
}
// Block if disabled by admin.
if (!embed.enabled) {
response.status(503).json({
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error:
"This chat has been disabled by the administrator - try again later.",
});
return;
}
// Check if requester hostname is in the valid allowlist of domains.
const host = request.hostname;
const allowedHosts = EmbedConfig.parseAllowedHosts(embed);
if (allowedHosts !== null && !allowedHosts.includes(host)) {
response.status(401).json({
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error: "Invalid request.",
});
return;
}
const { sessionId, message } = reqBody(request);
if (!message?.length || !VALID_CHAT_MODE.includes(embed.chat_mode)) {
response.status(400).json({
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error: !message?.length
? "Message is empty."
: `${embed.chat_mode} is not a valid mode.`,
});
return;
}
if (!isNaN(embed.max_chats_per_day) && Number(embed.max_chats_per_day) > 0) {
const dailyChatCount = await EmbedChats.count({
embed_id: embed.id,
createdAt: {
gte: new Date(new Date() - 24 * 60 * 60 * 1000),
},
});
if (dailyChatCount >= Number(embed.max_chats_per_day)) {
response.status(429).json({
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error:
"The quota for this chat has been reached. Try again later or contact the site owner.",
});
return;
}
}
if (
!isNaN(embed.max_chats_per_session) &&
Number(embed.max_chats_per_session) > 0
) {
const dailySessionCount = await EmbedChats.count({
embed_id: embed.id,
session_id: sessionId,
createdAt: {
gte: new Date(new Date() - 24 * 60 * 60 * 1000),
},
});
if (dailySessionCount >= Number(embed.max_chats_per_session)) {
response.status(429).json({
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error:
"Your quota for this chat has been reached. Try again later or contact the site owner.",
});
return;
}
}
next();
}
module.exports = {
setConnectionMeta,
validEmbedConfig,
canRespond,
};