From c22c50cca8f9b83de86596d2decc2646be93dff4 Mon Sep 17 00:00:00 2001 From: Timothy Carambat Date: Mon, 13 Nov 2023 15:07:30 -0800 Subject: [PATCH] Enable chat streaming for LLMs (#354) * [Draft] Enable chat streaming for LLMs * stream only, move sendChat to deprecated * Update TODO deprecation comments update console output color for streaming disabled --- frontend/package.json | 3 +- .../ChatHistory/PromptReply/index.jsx | 2 +- .../ChatContainer/ChatHistory/index.jsx | 4 +- .../WorkspaceChat/ChatContainer/index.jsx | 35 ++- frontend/src/index.css | 21 ++ frontend/src/models/workspace.js | 63 +++- frontend/src/utils/chat/index.js | 39 ++- frontend/yarn.lock | 5 + server/endpoints/chat.js | 85 ++++++ server/utils/AiProviders/anthropic/index.js | 4 + server/utils/AiProviders/azureOpenAi/index.js | 4 + server/utils/AiProviders/lmStudio/index.js | 48 +++ server/utils/AiProviders/openAi/index.js | 49 +++ server/utils/chats/index.js | 3 + server/utils/chats/stream.js | 279 ++++++++++++++++++ 15 files changed, 618 insertions(+), 26 deletions(-) create mode 100644 server/utils/chats/stream.js diff --git a/frontend/package.json b/frontend/package.json index 9328570b..a1531b9c 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -12,6 +12,7 @@ "dependencies": { "@esbuild-plugins/node-globals-polyfill": "^0.1.1", "@metamask/jazzicon": "^2.0.0", + "@microsoft/fetch-event-source": "^2.0.1", "@phosphor-icons/react": "^2.0.13", "buffer": "^6.0.3", "he": "^1.2.0", @@ -46,4 +47,4 @@ "tailwindcss": "^3.3.1", "vite": "^4.3.0" } -} \ No newline at end of file +} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/PromptReply/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/PromptReply/index.jsx index 9b2ade1e..427da6c3 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/PromptReply/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/PromptReply/index.jsx @@ -72,7 +72,7 @@ const PromptReply = forwardRef( role="assistant" /> diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx index e684d237..20ee990e 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx @@ -53,8 +53,10 @@ export default function ChatHistory({ history = [], workspace }) { > {history.map((props, index) => { const isLastMessage = index === history.length - 1; + const isLastBotReply = + index === history.length - 1 && props.role === "assistant"; - if (props.role === "assistant" && props.animate) { + if (isLastBotReply && props.animate) { return ( + handleChat( + chatResult, + setLoadingResponse, + setChatHistory, + remHistory, + _chatHistory + ) ); + return; } loadingResponse === true && fetchReply(); }, [loadingResponse, chatHistory, workspace]); diff --git a/frontend/src/index.css b/frontend/src/index.css index 2553629d..937631be 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -358,3 +358,24 @@ dialog::backdrop { .user-reply > div:first-of-type { border: 2px solid white; } + +.reply > *:last-child::after { + content: "|"; + animation: blink 1.5s steps(1) infinite; + color: white; + font-size: 14px; +} + +@keyframes blink { + 0% { + opacity: 0; + } + + 50% { + opacity: 1; + } + + 100% { + opacity: 0; + } +} diff --git a/frontend/src/models/workspace.js b/frontend/src/models/workspace.js 
index 540a6f13..0f30592d 100644 --- a/frontend/src/models/workspace.js +++ b/frontend/src/models/workspace.js @@ -1,5 +1,7 @@ import { API_BASE } from "../utils/constants"; import { baseHeaders } from "../utils/request"; +import { fetchEventSource } from "@microsoft/fetch-event-source"; +import { v4 } from "uuid"; const Workspace = { new: async function (data = {}) { @@ -57,19 +59,44 @@ const Workspace = { .catch(() => []); return history; }, - sendChat: async function ({ slug }, message, mode = "query") { - const chatResult = await fetch(`${API_BASE}/workspace/${slug}/chat`, { + streamChat: async function ({ slug }, message, mode = "query", handleChat) { + const ctrl = new AbortController(); + await fetchEventSource(`${API_BASE}/workspace/${slug}/stream-chat`, { method: "POST", body: JSON.stringify({ message, mode }), headers: baseHeaders(), - }) - .then((res) => res.json()) - .catch((e) => { - console.error(e); - return null; - }); - - return chatResult; + signal: ctrl.signal, + async onopen(response) { + if (response.ok) { + return; // everything's good + } else if ( + response.status >= 400 && + response.status < 500 && + response.status !== 429 + ) { + throw new Error("Invalid Status code response."); + } else { + throw new Error("Unknown error"); + } + }, + async onmessage(msg) { + try { + const chatResult = JSON.parse(msg.data); + handleChat(chatResult); + } catch {} + }, + onerror(err) { + handleChat({ + id: v4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: `An error occurred while streaming response. ${err.message}`, + }); + ctrl.abort(); + }, + }); }, all: async function () { const workspaces = await fetch(`${API_BASE}/workspaces`, { @@ -111,6 +138,22 @@ const Workspace = { const data = await response.json(); return { response, data }; }, + + // TODO: Deprecated and should be removed from frontend. 
+ sendChat: async function ({ slug }, message, mode = "query") { + const chatResult = await fetch(`${API_BASE}/workspace/${slug}/chat`, { + method: "POST", + body: JSON.stringify({ message, mode }), + headers: baseHeaders(), + }) + .then((res) => res.json()) + .catch((e) => { + console.error(e); + return null; + }); + + return chatResult; + }, }; export default Workspace; diff --git a/frontend/src/utils/chat/index.js b/frontend/src/utils/chat/index.js index 35a911d0..f2587484 100644 --- a/frontend/src/utils/chat/index.js +++ b/frontend/src/utils/chat/index.js @@ -19,7 +19,8 @@ export default function handleChat( sources, closed: true, error, - animate: true, + animate: false, + pending: false, }, ]); _chatHistory.push({ @@ -29,7 +30,8 @@ export default function handleChat( sources, closed: true, error, - animate: true, + animate: false, + pending: false, }); } else if (type === "textResponse") { setLoadingResponse(false); @@ -42,7 +44,8 @@ export default function handleChat( sources, closed: close, error, - animate: true, + animate: !close, + pending: false, }, ]); _chatHistory.push({ @@ -52,8 +55,36 @@ export default function handleChat( sources, closed: close, error, - animate: true, + animate: !close, + pending: false, }); + } else if (type === "textResponseChunk") { + const chatIdx = _chatHistory.findIndex((chat) => chat.uuid === uuid); + if (chatIdx !== -1) { + const existingHistory = { ..._chatHistory[chatIdx] }; + const updatedHistory = { + ...existingHistory, + content: existingHistory.content + textResponse, + sources, + error, + closed: close, + animate: !close, + pending: false, + }; + _chatHistory[chatIdx] = updatedHistory; + } else { + _chatHistory.push({ + uuid, + sources, + error, + content: textResponse, + role: "assistant", + closed: close, + animate: !close, + pending: false, + }); + } + setChatHistory([..._chatHistory]); } } diff --git a/frontend/yarn.lock b/frontend/yarn.lock index 27023b51..fdb7aae6 100644 --- a/frontend/yarn.lock +++ b/frontend/yarn.lock @@ -426,6 +426,11 @@ color "^0.11.3" mersenne-twister "^1.1.0" +"@microsoft/fetch-event-source@^2.0.1": + version "2.0.1" + resolved "https://registry.yarnpkg.com/@microsoft/fetch-event-source/-/fetch-event-source-2.0.1.tgz#9ceecc94b49fbaa15666e38ae8587f64acce007d" + integrity sha512-W6CLUJ2eBMw3Rec70qrsEW0jOm/3twwJv21mrmj2yORiaVmVYGS4sSS5yUwvQc1ZlDLYGPnClVWmUUMagKNsfA== + "@nodelib/fs.scandir@2.1.5": version "2.1.5" resolved "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz" diff --git a/server/endpoints/chat.js b/server/endpoints/chat.js index b0afed76..efe10c69 100644 --- a/server/endpoints/chat.js +++ b/server/endpoints/chat.js @@ -6,10 +6,95 @@ const { validatedRequest } = require("../utils/middleware/validatedRequest"); const { WorkspaceChats } = require("../models/workspaceChats"); const { SystemSettings } = require("../models/systemSettings"); const { Telemetry } = require("../models/telemetry"); +const { + streamChatWithWorkspace, + writeResponseChunk, +} = require("../utils/chats/stream"); function chatEndpoints(app) { if (!app) return; + app.post( + "/workspace/:slug/stream-chat", + [validatedRequest], + async (request, response) => { + try { + const user = await userFromSession(request, response); + const { slug } = request.params; + const { message, mode = "query" } = reqBody(request); + + const workspace = multiUserMode(response) + ? 
await Workspace.getWithUser(user, { slug }) + : await Workspace.get({ slug }); + + if (!workspace) { + response.sendStatus(400).end(); + return; + } + + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Content-Type", "text/event-stream"); + response.setHeader("Access-Control-Allow-Origin", "*"); + response.setHeader("Connection", "keep-alive"); + response.flushHeaders(); + + if (multiUserMode(response) && user.role !== "admin") { + const limitMessagesSetting = await SystemSettings.get({ + label: "limit_user_messages", + }); + const limitMessages = limitMessagesSetting?.value === "true"; + + if (limitMessages) { + const messageLimitSetting = await SystemSettings.get({ + label: "message_limit", + }); + const systemLimit = Number(messageLimitSetting?.value); + + if (!!systemLimit) { + const currentChatCount = await WorkspaceChats.count({ + user_id: user.id, + createdAt: { + gte: new Date(new Date() - 24 * 60 * 60 * 1000), + }, + }); + + if (currentChatCount >= systemLimit) { + writeResponseChunk(response, { + id: uuidv4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: `You have met your maximum 24 hour chat quota of ${systemLimit} chats set by the instance administrators. Try again later.`, + }); + return; + } + } + } + } + + await streamChatWithWorkspace(response, workspace, message, mode, user); + await Telemetry.sendTelemetry("sent_chat", { + multiUserMode: multiUserMode(response), + LLMSelection: process.env.LLM_PROVIDER || "openai", + VectorDbSelection: process.env.VECTOR_DB || "pinecone", + }); + response.end(); + } catch (e) { + console.error(e); + writeResponseChunk(response, { + id: uuidv4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: e.message, + }); + response.end(); + } + } + ); + app.post( "/workspace/:slug/chat", [validatedRequest], diff --git a/server/utils/AiProviders/anthropic/index.js b/server/utils/AiProviders/anthropic/index.js index dca21422..703c0859 100644 --- a/server/utils/AiProviders/anthropic/index.js +++ b/server/utils/AiProviders/anthropic/index.js @@ -27,6 +27,10 @@ class AnthropicLLM { this.answerKey = v4().split("-")[0]; } + streamingEnabled() { + return "streamChat" in this && "streamGetChatCompletion" in this; + } + promptWindowLimit() { switch (this.model) { case "claude-instant-1": diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js index 30059035..a424902b 100644 --- a/server/utils/AiProviders/azureOpenAi/index.js +++ b/server/utils/AiProviders/azureOpenAi/index.js @@ -22,6 +22,10 @@ class AzureOpenAiLLM extends AzureOpenAiEmbedder { }; } + streamingEnabled() { + return "streamChat" in this && "streamGetChatCompletion" in this; + } + // Sure the user selected a proper value for the token limit // could be any of these https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-models // and if undefined - assume it is the lowest end. diff --git a/server/utils/AiProviders/lmStudio/index.js b/server/utils/AiProviders/lmStudio/index.js index bb025b3b..e0ccc316 100644 --- a/server/utils/AiProviders/lmStudio/index.js +++ b/server/utils/AiProviders/lmStudio/index.js @@ -27,6 +27,10 @@ class LMStudioLLM { this.embedder = embedder; } + streamingEnabled() { + return "streamChat" in this && "streamGetChatCompletion" in this; + } + // Ensure the user set a value for the token limit // and if undefined - assume 4096 window. 
promptWindowLimit() { @@ -103,6 +107,32 @@ Context: return textResponse; } + async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) { + if (!this.model) + throw new Error( + `LMStudio chat: ${this.model} is not valid or defined for chat completion!` + ); + + const streamRequest = await this.lmstudio.createChatCompletion( + { + model: this.model, + temperature: Number(workspace?.openAiTemp ?? 0.7), + n: 1, + stream: true, + messages: await this.compressMessages( + { + systemPrompt: chatPrompt(workspace), + userPrompt: prompt, + chatHistory, + }, + rawHistory + ), + }, + { responseType: "stream" } + ); + return streamRequest; + } + async getChatCompletion(messages = null, { temperature = 0.7 }) { if (!this.model) throw new Error( `LMStudio chat: ${this.model} is not valid or defined model for chat completion!` ); @@ -119,6 +149,24 @@ Context: return data.choices[0].message.content; } + async streamGetChatCompletion(messages = null, { temperature = 0.7 }) { + if (!this.model) + throw new Error( + `LMStudio chat: ${this.model} is not valid or defined model for chat completion!` + ); + + const streamRequest = await this.lmstudio.createChatCompletion( + { + model: this.model, + stream: true, + messages, + temperature, + }, + { responseType: "stream" } + ); + return streamRequest; + } + // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations async embedTextInput(textInput) { return await this.embedder.embedTextInput(textInput); diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js index 1a5072f2..33ed1b19 100644 --- a/server/utils/AiProviders/openAi/index.js +++ b/server/utils/AiProviders/openAi/index.js @@ -19,6 +19,10 @@ class OpenAiLLM extends OpenAiEmbedder { }; } + streamingEnabled() { + return "streamChat" in this && "streamGetChatCompletion" in this; + } + promptWindowLimit() { switch (this.model) { case "gpt-3.5-turbo": @@ -140,6 +144,33 @@ Context: return textResponse; } + async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) { + const model = process.env.OPEN_MODEL_PREF; + if (!(await this.isValidChatCompletionModel(model))) + throw new Error( + `OpenAI chat: ${model} is not valid for chat completion!` + ); + + const streamRequest = await this.openai.createChatCompletion( + { + model, + stream: true, + temperature: Number(workspace?.openAiTemp ??
0.7), + n: 1, + messages: await this.compressMessages( + { + systemPrompt: chatPrompt(workspace), + userPrompt: prompt, + chatHistory, + }, + rawHistory + ), + }, + { responseType: "stream" } + ); + return streamRequest; + } + async getChatCompletion(messages = null, { temperature = 0.7 }) { if (!(await this.isValidChatCompletionModel(this.model))) throw new Error( @@ -156,6 +187,24 @@ Context: return data.choices[0].message.content; } + async streamGetChatCompletion(messages = null, { temperature = 0.7 }) { + if (!(await this.isValidChatCompletionModel(this.model))) + throw new Error( + `OpenAI chat: ${this.model} is not valid for chat completion!` + ); + + const streamRequest = await this.openai.createChatCompletion( + { + model: this.model, + stream: true, + messages, + temperature, + }, + { responseType: "stream" } + ); + return streamRequest; + } + async compressMessages(promptArgs = {}, rawHistory = []) { const { messageArrayCompressor } = require("../../helpers/chat"); const messageArray = this.constructPrompt(promptArgs); diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js index a1ed4758..7e9be6e5 100644 --- a/server/utils/chats/index.js +++ b/server/utils/chats/index.js @@ -242,8 +242,11 @@ function chatPrompt(workspace) { } module.exports = { + recentChatHistory, convertToPromptHistory, convertToChatHistory, chatWithWorkspace, chatPrompt, + grepCommand, + VALID_COMMANDS, }; diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js new file mode 100644 index 00000000..b6d011bd --- /dev/null +++ b/server/utils/chats/stream.js @@ -0,0 +1,279 @@ +const { v4: uuidv4 } = require("uuid"); +const { WorkspaceChats } = require("../../models/workspaceChats"); +const { getVectorDbClass, getLLMProvider } = require("../helpers"); +const { + grepCommand, + recentChatHistory, + VALID_COMMANDS, + chatPrompt, +} = require("."); + +function writeResponseChunk(response, data) { + response.write(`data: ${JSON.stringify(data)}\n\n`); + return; +} + +async function streamChatWithWorkspace( + response, + workspace, + message, + chatMode = "chat", + user = null +) { + const uuid = uuidv4(); + const command = grepCommand(message); + + if (!!command && Object.keys(VALID_COMMANDS).includes(command)) { + const data = await VALID_COMMANDS[command](workspace, message, uuid, user); + writeResponseChunk(response, data); + return; + } + + const LLMConnector = getLLMProvider(); + const VectorDb = getVectorDbClass(); + const { safe, reasons = [] } = await LLMConnector.isSafe(message); + if (!safe) { + writeResponseChunk(response, { + id: uuid, + type: "abort", + textResponse: null, + sources: [], + close: true, + error: `This message was moderated and will not be allowed. Violations for ${reasons.join( + ", " + )} found.`, + }); + return; + } + + const messageLimit = workspace?.openAiHistory || 20; + const hasVectorizedSpace = await VectorDb.hasNamespace(workspace.slug); + const embeddingsCount = await VectorDb.namespaceCount(workspace.slug); + if (!hasVectorizedSpace || embeddingsCount === 0) { + // If there are no embeddings - chat like a normal LLM chat interface. 
+ return await streamEmptyEmbeddingChat({ + response, + uuid, + user, + message, + workspace, + messageLimit, + LLMConnector, + }); + } + + let completeText; + const { rawHistory, chatHistory } = await recentChatHistory( + user, + workspace, + messageLimit, + chatMode + ); + const { + contextTexts = [], + sources = [], + message: error, + } = await VectorDb.performSimilaritySearch({ + namespace: workspace.slug, + input: message, + LLMConnector, + similarityThreshold: workspace?.similarityThreshold, + }); + + // Failed similarity search. + if (!!error) { + writeResponseChunk(response, { + id: uuid, + type: "abort", + textResponse: null, + sources: [], + close: true, + error, + }); + return; + } + + // Compress message to ensure prompt passes token limit with room for response + // and build system messages based on inputs and history. + const messages = await LLMConnector.compressMessages( + { + systemPrompt: chatPrompt(workspace), + userPrompt: message, + contextTexts, + chatHistory, + }, + rawHistory + ); + + // If streaming is not explicitly enabled for connector + // we do regular waiting of a response and send a single chunk. + if (LLMConnector.streamingEnabled() !== true) { + console.log( + `\x1b[31m[STREAMING DISABLED]\x1b[0m Streaming is not available for ${LLMConnector.constructor.name}. Will use regular chat method.` + ); + completeText = await LLMConnector.getChatCompletion(messages, { + temperature: workspace?.openAiTemp ?? 0.7, + }); + writeResponseChunk(response, { + uuid, + sources, + type: "textResponseChunk", + textResponse: completeText, + close: true, + error: false, + }); + } else { + const stream = await LLMConnector.streamGetChatCompletion(messages, { + temperature: workspace?.openAiTemp ?? 0.7, + }); + completeText = await handleStreamResponses(response, stream, { + uuid, + sources, + }); + } + + await WorkspaceChats.new({ + workspaceId: workspace.id, + prompt: message, + response: { text: completeText, sources, type: chatMode }, + user, + }); + return; +} + +async function streamEmptyEmbeddingChat({ + response, + uuid, + user, + message, + workspace, + messageLimit, + LLMConnector, +}) { + let completeText; + const { rawHistory, chatHistory } = await recentChatHistory( + user, + workspace, + messageLimit + ); + + // If streaming is not explicitly enabled for connector + // we do regular waiting of a response and send a single chunk. + if (LLMConnector.streamingEnabled() !== true) { + console.log( + `\x1b[31m[STREAMING DISABLED]\x1b[0m Streaming is not available for ${LLMConnector.constructor.name}. 
Will use regular chat method.` + ); + completeText = await LLMConnector.sendChat( + chatHistory, + message, + workspace, + rawHistory + ); + writeResponseChunk(response, { + uuid, + type: "textResponseChunk", + textResponse: completeText, + sources: [], + close: true, + error: false, + }); + } else { + const stream = await LLMConnector.streamChat( + chatHistory, + message, + workspace, + rawHistory + ); + completeText = await handleStreamResponses(response, stream, { + uuid, + sources: [], + }); + } + + await WorkspaceChats.new({ + workspaceId: workspace.id, + prompt: message, + response: { text: completeText, sources: [], type: "chat" }, + user, + }); + return; +} + +function handleStreamResponses(response, stream, responseProps) { + const { uuid = uuidv4(), sources = [] } = responseProps; + return new Promise((resolve) => { + let fullText = ""; + let chunk = ""; + stream.data.on("data", (data) => { + const lines = data + ?.toString() + ?.split("\n") + .filter((line) => line.trim() !== ""); + + for (const line of lines) { + const message = chunk + line.replace(/^data: /, ""); + + // JSON chunk is incomplete and has not ended yet + // so we need to stitch it together. You would think JSON + // chunks would only come complete - but they don't! + if (message.slice(-3) !== "}]}") { + chunk += message; + continue; + } else { + chunk = ""; + } + + if (message == "[DONE]") { + writeResponseChunk(response, { + uuid, + sources, + type: "textResponseChunk", + textResponse: "", + close: true, + error: false, + }); + resolve(fullText); + } else { + let finishReason; + let token = ""; + try { + const json = JSON.parse(message); + token = json?.choices?.[0]?.delta?.content; + finishReason = json?.choices?.[0]?.finish_reason; + } catch { + continue; + } + + if (token) { + fullText += token; + writeResponseChunk(response, { + uuid, + sources: [], + type: "textResponseChunk", + textResponse: token, + close: false, + error: false, + }); + } + + if (finishReason !== null) { + writeResponseChunk(response, { + uuid, + sources, + type: "textResponseChunk", + textResponse: "", + close: true, + error: false, + }); + resolve(fullText); + } + } + } + }); + }); +} + +module.exports = { + streamChatWithWorkspace, + writeResponseChunk, +};
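
Note on the wire format this patch adopts: the new /workspace/:slug/stream-chat endpoint and the frontend's fetchEventSource handler exchange one JSON payload per SSE "data:" frame, produced by writeResponseChunk in server/utils/chats/stream.js. The standalone sketch below reproduces that framing with a plain Express route so the chunk shapes can be exercised in isolation; the /demo-stream route name, port 3001, and the hard-coded token list are illustrative assumptions, not part of this patch, and only the frame shape mirrors the code above.

// Minimal sketch of the SSE framing used by writeResponseChunk.
// Assumes only express and uuid are installed; everything named
// "demo" here is made up for illustration.
const express = require("express");
const { v4: uuidv4 } = require("uuid");

const app = express();

// Same framing as server/utils/chats/stream.js: one JSON object per
// "data:" line, each frame terminated by a blank line.
function writeResponseChunk(response, data) {
  response.write(`data: ${JSON.stringify(data)}\n\n`);
}

app.post("/demo-stream", (_request, response) => {
  response.setHeader("Cache-Control", "no-cache");
  response.setHeader("Content-Type", "text/event-stream");
  response.setHeader("Connection", "keep-alive");
  response.flushHeaders();

  const uuid = uuidv4();

  // Intermediate chunks carry the token text with close: false.
  for (const token of ["Hello", ", ", "world", "!"]) {
    writeResponseChunk(response, {
      uuid,
      sources: [],
      type: "textResponseChunk",
      textResponse: token,
      close: false,
      error: false,
    });
  }

  // The final chunk is empty with close: true, which handleChat() in
  // frontend/src/utils/chat/index.js treats as end of stream
  // (closed: true, animate: false).
  writeResponseChunk(response, {
    uuid,
    sources: [],
    type: "textResponseChunk",
    textResponse: "",
    close: true,
    error: false,
  });
  response.end();
});

app.listen(3001);

Hitting the route with `curl -N -X POST http://localhost:3001/demo-stream` prints the raw `data: {...}` frames, which are the same strings @microsoft/fetch-event-source delivers as msg.data to the onmessage handler in frontend/src/models/workspace.js.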