From 2f0942afac9630ace38338e0e4d3882686177dac Mon Sep 17 00:00:00 2001 From: timothycarambat Date: Thu, 1 Feb 2024 15:08:30 -0800 Subject: [PATCH] add middleware validations on embed chat --- embed/src/models/chatService.js | 33 ++--- server/endpoints/embed/index.js | 148 ++++++++------------- server/models/embedChats.js | 9 +- server/models/embedConfig.js | 18 +++ server/utils/chats/embed.js | 14 +- server/utils/middleware/embedMiddleware.js | 138 +++++++++++++++++++ 6 files changed, 250 insertions(+), 110 deletions(-) create mode 100644 server/utils/middleware/embedMiddleware.js diff --git a/embed/src/models/chatService.js b/embed/src/models/chatService.js index 5590b3b2..a53b857c 100644 --- a/embed/src/models/chatService.js +++ b/embed/src/models/chatService.js @@ -23,21 +23,24 @@ const ChatService = { async onopen(response) { if (response.ok) { return; // everything's good - } else if ( - response.status >= 400 && - response.status < 500 && - response.status !== 429 - ) { - handleChat({ - id: v4(), - type: "abort", - textResponse: null, - sources: [], - close: true, - error: `An error occurred while streaming response. Code ${response.status}`, - }); + } else if (response.status >= 400) { + await response + .json() + .then((serverResponse) => { + handleChat(serverResponse); + }) + .catch(() => { + handleChat({ + id: v4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: `An error occurred while streaming response. Code ${response.status}`, + }); + }); ctrl.abort(); - throw new Error("Invalid Status code response."); + throw new Error(); } else { handleChat({ id: v4(), @@ -48,7 +51,7 @@ const ChatService = { error: `An error occurred while streaming response. Unknown Error.`, }); ctrl.abort(); - throw new Error("Unknown error"); + throw new Error("Unknown Error"); } }, async onmessage(msg) { diff --git a/server/endpoints/embed/index.js b/server/endpoints/embed/index.js index 9b0c1f3b..348a1c4b 100644 --- a/server/endpoints/embed/index.js +++ b/server/endpoints/embed/index.js @@ -1,116 +1,84 @@ const { v4: uuidv4 } = require("uuid"); const { reqBody, multiUserMode } = require("../../utils/http"); const { Telemetry } = require("../../models/telemetry"); -const { - writeResponseChunk, - VALID_CHAT_MODE, -} = require("../../utils/chats/stream"); +const { writeResponseChunk } = require("../../utils/chats/stream"); const { streamChatWithForEmbed } = require("../../utils/chats/embed"); const { convertToChatHistory } = require("../../utils/chats"); -const { EmbedConfig } = require("../../models/embedConfig"); const { EmbedChats } = require("../../models/embedChats"); +const { + validEmbedConfig, + canRespond, + setConnectionMeta, +} = require("../../utils/middleware/embedMiddleware"); function embeddedEndpoints(app) { if (!app) return; - // TODO: middleware - app.post("/embed/:embedId/stream-chat", async (request, response) => { - try { - const { embedId } = request.params; - const { - sessionId, - message, - // optional keys for override of defaults if enabled. - prompt = null, - model = null, - temperature = null, - } = reqBody(request); + app.post( + "/embed/:embedId/stream-chat", + [validEmbedConfig, setConnectionMeta, canRespond], + async (request, response) => { + try { + const embed = response.locals.embedConfig; + const { + sessionId, + message, + // optional keys for override of defaults if enabled. + prompt = null, + model = null, + temperature = null, + } = reqBody(request); - const embed = await EmbedConfig.getWithWorkspace({ uuid: embedId }); - if (!embed) { - response.sendStatus(400).end(); - return; - } + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Content-Type", "text/event-stream"); + response.setHeader("Access-Control-Allow-Origin", "*"); + response.setHeader("Connection", "keep-alive"); + response.flushHeaders(); - if (!embed.enabled) { - response.status(200).json({ + await streamChatWithForEmbed(response, embed, message, sessionId, { + prompt, + model, + temperature, + }); + await Telemetry.sendTelemetry("embed_sent_chat", { + multiUserMode: multiUserMode(response), + LLMSelection: process.env.LLM_PROVIDER || "openai", + Embedder: process.env.EMBEDDING_ENGINE || "inherit", + VectorDbSelection: process.env.VECTOR_DB || "pinecone", + }); + response.end(); + } catch (e) { + console.error(e); + writeResponseChunk(response, { id: uuidv4(), type: "abort", textResponse: null, - sources: [], close: true, - error: - "This chat has been disabled by the administrator - try again later.", + error: e.message, }); - return; + response.end(); } + } + ); - if (!message?.length || !VALID_CHAT_MODE.includes(embed.chat_mode)) { + app.get( + "/embed/:embedId/:sessionId", + [validEmbedConfig], + async (request, response) => { + try { + const { sessionId } = request.params; + const embed = response.locals.embedConfig; + + const history = await EmbedChats.forEmbedByUser(embed.id, sessionId); response.status(200).json({ - id: uuidv4(), - type: "abort", - textResponse: null, - sources: [], - close: true, - error: !message?.length - ? "Message is empty." - : `${embed.chat_mode} is not a valid mode.`, + history: convertToChatHistory(history), }); - return; + } catch (e) { + console.log(e.message, e); + response.sendStatus(500).end(); } - - response.setHeader("Cache-Control", "no-cache"); - response.setHeader("Content-Type", "text/event-stream"); - response.setHeader("Access-Control-Allow-Origin", "*"); - response.setHeader("Connection", "keep-alive"); - response.flushHeaders(); - - // TODO Per-user and Per-day limit checks for embed_config - - await streamChatWithForEmbed(response, embed, message, sessionId, { - prompt, - model, - temperature, - }); - await Telemetry.sendTelemetry("embed_sent_chat", { - multiUserMode: multiUserMode(response), - LLMSelection: process.env.LLM_PROVIDER || "openai", - Embedder: process.env.EMBEDDING_ENGINE || "inherit", - VectorDbSelection: process.env.VECTOR_DB || "pinecone", - }); - response.end(); - } catch (e) { - console.error(e); - writeResponseChunk(response, { - id: uuidv4(), - type: "abort", - textResponse: null, - close: true, - error: e.message, - }); - response.end(); } - }); - - // TODO: middleware - app.get("/embed/:embedId/:sessionId", async (request, response) => { - try { - const { embedId, sessionId } = request.params; - const embed = await EmbedConfig.get({ uuid: embedId }); - if (!embed) { - response.sendStatus(400).end(); - return; - } - - const history = await EmbedChats.forEmbedByUser(embed.id, sessionId); - response.status(200).json({ - history: convertToChatHistory(history), - }); - } catch (e) { - console.log(e.message, e); - response.sendStatus(500).end(); - } - }); + ); } module.exports = { embeddedEndpoints }; diff --git a/server/models/embedChats.js b/server/models/embedChats.js index 46bc4ef1..bdbc0dcb 100644 --- a/server/models/embedChats.js +++ b/server/models/embedChats.js @@ -1,13 +1,20 @@ const prisma = require("../utils/prisma"); const EmbedChats = { - new: async function ({ embedId, prompt, response = {}, sessionId }) { + new: async function ({ + embedId, + prompt, + response = {}, + connection_information = {}, + sessionId, + }) { try { const chat = await prisma.embed_chats.create({ data: { prompt, embed_id: Number(embedId), response: JSON.stringify(response), + connection_information: JSON.stringify(connection_information), session_id: sessionId, }, }); diff --git a/server/models/embedConfig.js b/server/models/embedConfig.js index 16b1e6c6..f627f760 100644 --- a/server/models/embedConfig.js +++ b/server/models/embedConfig.js @@ -5,6 +5,10 @@ const EmbedConfig = { // Used for generic updates so we can validate keys in request body "allowlist_domains", "allow_model_override", + "allow_temperature_override", + "allow_prompt_override", + "max_chats_per_day", + "max_chats_per_session", "chat_mode", ], @@ -94,6 +98,20 @@ const EmbedConfig = { return []; } }, + + // Will return null if process should be skipped + // an empty array means the system will check. This + // prevents a bad parse from allowing all requests + parseAllowedHosts: function (embed) { + if (!embed.allowlist_domains) return null; + + try { + return JSON.parse(embed.allowlist_domains); + } catch { + console.error(`Failed to parse allowlist_domains for Embed ${embed.id}!`); + return []; + } + }, }; module.exports = { EmbedConfig }; diff --git a/server/utils/chats/embed.js b/server/utils/chats/embed.js index e08c1cdc..fb119b72 100644 --- a/server/utils/chats/embed.js +++ b/server/utils/chats/embed.js @@ -85,8 +85,8 @@ async function streamChatWithForEmbed( namespace: embed.workspace.slug, input: message, LLMConnector, - similarityThreshold: workspace?.similarityThreshold, - topN: workspace?.topN, + similarityThreshold: embed.workspace?.similarityThreshold, + topN: embed.workspace?.topN, }); // Failed similarity search. @@ -136,7 +136,7 @@ async function streamChatWithForEmbed( `\x1b[31m[STREAMING DISABLED]\x1b[0m Streaming is not available for ${LLMConnector.constructor.name}. Will use regular chat method.` ); completeText = await LLMConnector.getChatCompletion(messages, { - temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp, + temperature: embed.workspace?.openAiTemp ?? LLMConnector.defaultTemp, }); writeResponseChunk(response, { uuid, @@ -148,7 +148,7 @@ async function streamChatWithForEmbed( }); } else { const stream = await LLMConnector.streamGetChatCompletion(messages, { - temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp, + temperature: embed.workspace?.openAiTemp ?? LLMConnector.defaultTemp, }); completeText = await handleStreamResponses(response, stream, { uuid, @@ -160,6 +160,9 @@ async function streamChatWithForEmbed( embedId: embed.id, prompt: message, response: { text: completeText, type: chatMode }, + connection_information: response.locals.connection + ? { ...response.locals.connection } + : {}, sessionId, }); return; @@ -233,6 +236,9 @@ async function streamEmptyEmbeddingChat({ embedId: embed.id, prompt: message, response: { text: completeText, type: "chat" }, + connection_information: response.locals.connection + ? { ...response.locals.connection } + : {}, sessionId, }); return; diff --git a/server/utils/middleware/embedMiddleware.js b/server/utils/middleware/embedMiddleware.js new file mode 100644 index 00000000..e9636f85 --- /dev/null +++ b/server/utils/middleware/embedMiddleware.js @@ -0,0 +1,138 @@ +const { v4: uuidv4 } = require("uuid"); +const { VALID_CHAT_MODE, writeResponseChunk } = require("../chats/stream"); +const { EmbedChats } = require("../../models/embedChats"); +const { EmbedConfig } = require("../../models/embedConfig"); +const { reqBody } = require("../http"); + +// Finds or Aborts request for a /:embedId/ url. This should always +// be the first middleware and the :embedID should be in the URL. +async function validEmbedConfig(request, response, next) { + const { embedId } = request.params; + + const embed = await EmbedConfig.getWithWorkspace({ uuid: embedId }); + if (!embed) { + response.sendStatus(404).end(); + return; + } + + response.locals.embedConfig = embed; + next(); +} + +function setConnectionMeta(request, response, next) { + response.locals.connection = { + host: request.hostname, + path: request.path, + ip: request.ip, + }; + next(); +} + +async function canRespond(request, response, next) { + const embed = response.locals.embedConfig; + if (!embed) { + response.sendStatus(404).end(); + return; + } + + // Block if disabled by admin. + if (!embed.enabled) { + response.status(503).json({ + id: uuidv4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: + "This chat has been disabled by the administrator - try again later.", + }); + return; + } + + // Check if requester hostname is in the valid allowlist of domains. + const host = request.hostname; + const allowedHosts = EmbedConfig.parseAllowedHosts(embed); + if (allowedHosts !== null && !allowedHosts.includes(host)) { + response.status(401).json({ + id: uuidv4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: "Invalid request.", + }); + return; + } + + const { sessionId, message } = reqBody(request); + + if (!message?.length || !VALID_CHAT_MODE.includes(embed.chat_mode)) { + response.status(400).json({ + id: uuidv4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: !message?.length + ? "Message is empty." + : `${embed.chat_mode} is not a valid mode.`, + }); + return; + } + + if (!isNaN(embed.max_chats_per_day) && Number(embed.max_chats_per_day) > 0) { + const dailyChatCount = await EmbedChats.count({ + embed_id: embed.id, + createdAt: { + gte: new Date(new Date() - 24 * 60 * 60 * 1000), + }, + }); + + if (dailyChatCount >= Number(embed.max_chats_per_day)) { + response.status(429).json({ + id: uuidv4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: + "The quota for this chat has been reached. Try again later or contact the site owner.", + }); + return; + } + } + + if ( + !isNaN(embed.max_chats_per_session) && + Number(embed.max_chats_per_session) > 0 + ) { + const dailySessionCount = await EmbedChats.count({ + embed_id: embed.id, + session_id: sessionId, + createdAt: { + gte: new Date(new Date() - 24 * 60 * 60 * 1000), + }, + }); + + if (dailySessionCount >= Number(embed.max_chats_per_session)) { + response.status(429).json({ + id: uuidv4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: + "Your quota for this chat has been reached. Try again later or contact the site owner.", + }); + return; + } + } + + next(); +} + +module.exports = { + setConnectionMeta, + validEmbedConfig, + canRespond, +};