Enable chat streaming for LLMs (#354)

* [Draft] Enable chat streaming for LLMs

* Stream only; move sendChat to deprecated

* Update TODO deprecation comments
* Update console output color when streaming is disabled
Timothy Carambat 2023-11-13 15:07:30 -08:00 committed by GitHub
parent fa29003a46
commit c22c50cca8
15 changed files with 618 additions and 26 deletions
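
As context for the diffs below: the new route streams plain server-sent events. The server-side helper writeResponseChunk (added in utils/chats/stream below) frames every payload as a data: line, and the frontend's fetchEventSource onmessage handler JSON-parses each frame and forwards it to handleChat. A minimal sketch of that framing, with illustrative uuid and token values (field names mirror the code in this commit):

// Illustrative only: the SSE frames emitted by POST /workspace/:slug/stream-chat.
const frames = [
  { uuid: "abc-123", sources: [], type: "textResponseChunk", textResponse: "Hel", close: false, error: false },
  { uuid: "abc-123", sources: [], type: "textResponseChunk", textResponse: "lo!", close: false, error: false },
  // Final frame: empty token, citations (if any) in sources, close: true ends the reply.
  { uuid: "abc-123", sources: [], type: "textResponseChunk", textResponse: "", close: true, error: false },
];
for (const frame of frames) {
  // Exactly the framing writeResponseChunk() uses on the server side.
  process.stdout.write(`data: ${JSON.stringify(frame)}\n\n`);
}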

View File

@ -12,6 +12,7 @@
"dependencies": {
"@esbuild-plugins/node-globals-polyfill": "^0.1.1",
"@metamask/jazzicon": "^2.0.0",
"@microsoft/fetch-event-source": "^2.0.1",
"@phosphor-icons/react": "^2.0.13",
"buffer": "^6.0.3",
"he": "^1.2.0",
@ -46,4 +47,4 @@
"tailwindcss": "^3.3.1",
"vite": "^4.3.0"
}
}
}

View File

@ -72,7 +72,7 @@ const PromptReply = forwardRef(
role="assistant"
/>
<span
className={`whitespace-pre-line text-white font-normal text-sm md:text-sm flex flex-col gap-y-1 mt-2`}
className={`reply whitespace-pre-line text-white font-normal text-sm md:text-sm flex flex-col gap-y-1 mt-2`}
dangerouslySetInnerHTML={{ __html: renderMarkdown(reply) }}
/>
</div>

View File

@ -53,8 +53,10 @@ export default function ChatHistory({ history = [], workspace }) {
>
{history.map((props, index) => {
const isLastMessage = index === history.length - 1;
const isLastBotReply =
index === history.length - 1 && props.role === "assistant";
if (props.role === "assistant" && props.animate) {
if (isLastBotReply && props.animate) {
return (
<PromptReply
key={props.uuid}

View File

@ -48,19 +48,36 @@ export default function ChatContainer({ workspace, knownHistory = [] }) {
return false;
}
const chatResult = await Workspace.sendChat(
// TODO: Delete this snippet once we have streaming stable.
// const chatResult = await Workspace.sendChat(
// workspace,
// promptMessage.userMessage,
// window.localStorage.getItem(`workspace_chat_mode_${workspace.slug}`) ??
// "chat",
// )
// handleChat(
// chatResult,
// setLoadingResponse,
// setChatHistory,
// remHistory,
// _chatHistory
// )
await Workspace.streamChat(
workspace,
promptMessage.userMessage,
window.localStorage.getItem(`workspace_chat_mode_${workspace.slug}`) ??
"chat"
);
handleChat(
chatResult,
setLoadingResponse,
setChatHistory,
remHistory,
_chatHistory
"chat",
(chatResult) =>
handleChat(
chatResult,
setLoadingResponse,
setChatHistory,
remHistory,
_chatHistory
)
);
return;
}
loadingResponse === true && fetchReply();
}, [loadingResponse, chatHistory, workspace]);

View File

@ -358,3 +358,24 @@ dialog::backdrop {
.user-reply > div:first-of-type {
border: 2px solid white;
}
.reply > *:last-child::after {
content: "|";
animation: blink 1.5s steps(1) infinite;
color: white;
font-size: 14px;
}
@keyframes blink {
0% {
opacity: 0;
}
50% {
opacity: 1;
}
100% {
opacity: 0;
}
}

View File

@ -1,5 +1,7 @@
import { API_BASE } from "../utils/constants";
import { baseHeaders } from "../utils/request";
import { fetchEventSource } from "@microsoft/fetch-event-source";
import { v4 } from "uuid";
const Workspace = {
new: async function (data = {}) {
@ -57,19 +59,44 @@ const Workspace = {
.catch(() => []);
return history;
},
sendChat: async function ({ slug }, message, mode = "query") {
const chatResult = await fetch(`${API_BASE}/workspace/${slug}/chat`, {
streamChat: async function ({ slug }, message, mode = "query", handleChat) {
const ctrl = new AbortController();
await fetchEventSource(`${API_BASE}/workspace/${slug}/stream-chat`, {
method: "POST",
body: JSON.stringify({ message, mode }),
headers: baseHeaders(),
})
.then((res) => res.json())
.catch((e) => {
console.error(e);
return null;
});
return chatResult;
signal: ctrl.signal,
async onopen(response) {
if (response.ok) {
return; // everything's good
} else if (
response.status >= 400 &&
response.status < 500 &&
response.status !== 429
) {
throw new Error("Invalid Status code response.");
} else {
throw new Error("Unknown error");
}
},
async onmessage(msg) {
try {
const chatResult = JSON.parse(msg.data);
handleChat(chatResult);
} catch {}
},
onerror(err) {
handleChat({
id: v4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error: `An error occurred while streaming response. ${err.message}`,
});
ctrl.abort();
},
});
},
all: async function () {
const workspaces = await fetch(`${API_BASE}/workspaces`, {
@ -111,6 +138,22 @@ const Workspace = {
const data = await response.json();
return { response, data };
},
// TODO: Deprecated and should be removed from frontend.
sendChat: async function ({ slug }, message, mode = "query") {
const chatResult = await fetch(`${API_BASE}/workspace/${slug}/chat`, {
method: "POST",
body: JSON.stringify({ message, mode }),
headers: baseHeaders(),
})
.then((res) => res.json())
.catch((e) => {
console.error(e);
return null;
});
return chatResult;
},
};
export default Workspace;

View File

@ -19,7 +19,8 @@ export default function handleChat(
sources,
closed: true,
error,
animate: true,
animate: false,
pending: false,
},
]);
_chatHistory.push({
@ -29,7 +30,8 @@ export default function handleChat(
sources,
closed: true,
error,
animate: true,
animate: false,
pending: false,
});
} else if (type === "textResponse") {
setLoadingResponse(false);
@ -42,7 +44,8 @@ export default function handleChat(
sources,
closed: close,
error,
animate: true,
animate: !close,
pending: false,
},
]);
_chatHistory.push({
@ -52,8 +55,36 @@ export default function handleChat(
sources,
closed: close,
error,
animate: true,
animate: !close,
pending: false,
});
} else if (type === "textResponseChunk") {
const chatIdx = _chatHistory.findIndex((chat) => chat.uuid === uuid);
if (chatIdx !== -1) {
const existingHistory = { ..._chatHistory[chatIdx] };
const updatedHistory = {
...existingHistory,
content: existingHistory.content + textResponse,
sources,
error,
closed: close,
animate: !close,
pending: false,
};
_chatHistory[chatIdx] = updatedHistory;
} else {
_chatHistory.push({
uuid,
sources,
error,
content: textResponse,
role: "assistant",
closed: close,
animate: !close,
pending: false,
});
}
setChatHistory([..._chatHistory]);
}
}
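
As a quick illustration of the textResponseChunk branch above, here is a minimal sketch (not part of the diff) of how successive chunks sharing one uuid accumulate into a single assistant message. It assumes the handleChat function shown above is in scope; the stub setters and id value are placeholders.

// Sketch only: feed handleChat a few fake streamed chunks.
const noop = () => {};   // stand-ins for setLoadingResponse / setChatHistory
const remHistory = [];   // history already rendered before this reply
const _chatHistory = []; // working copy that handleChat mutates

const uuid = "abc-123";  // illustrative; the server generates a uuidv4 per reply
const chunk = (textResponse, close, sources = []) => ({
  uuid,
  type: "textResponseChunk",
  textResponse,
  sources,
  close,
  error: false,
});

handleChat(chunk("Hel", false), noop, noop, remHistory, _chatHistory);
handleChat(chunk("lo!", false), noop, noop, remHistory, _chatHistory);
handleChat(chunk("", true), noop, noop, remHistory, _chatHistory);
// _chatHistory now holds one assistant entry: content "Hello!", closed: true,
// animate: false, pending: false.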

View File

@ -426,6 +426,11 @@
color "^0.11.3"
mersenne-twister "^1.1.0"
"@microsoft/fetch-event-source@^2.0.1":
version "2.0.1"
resolved "https://registry.yarnpkg.com/@microsoft/fetch-event-source/-/fetch-event-source-2.0.1.tgz#9ceecc94b49fbaa15666e38ae8587f64acce007d"
integrity sha512-W6CLUJ2eBMw3Rec70qrsEW0jOm/3twwJv21mrmj2yORiaVmVYGS4sSS5yUwvQc1ZlDLYGPnClVWmUUMagKNsfA==
"@nodelib/fs.scandir@2.1.5":
version "2.1.5"
resolved "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz"

View File

@ -6,10 +6,95 @@ const { validatedRequest } = require("../utils/middleware/validatedRequest");
const { WorkspaceChats } = require("../models/workspaceChats");
const { SystemSettings } = require("../models/systemSettings");
const { Telemetry } = require("../models/telemetry");
const {
streamChatWithWorkspace,
writeResponseChunk,
} = require("../utils/chats/stream");
function chatEndpoints(app) {
if (!app) return;
app.post(
"/workspace/:slug/stream-chat",
[validatedRequest],
async (request, response) => {
try {
const user = await userFromSession(request, response);
const { slug } = request.params;
const { message, mode = "query" } = reqBody(request);
const workspace = multiUserMode(response)
? await Workspace.getWithUser(user, { slug })
: await Workspace.get({ slug });
if (!workspace) {
response.sendStatus(400).end();
return;
}
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Content-Type", "text/event-stream");
response.setHeader("Access-Control-Allow-Origin", "*");
response.setHeader("Connection", "keep-alive");
response.flushHeaders();
if (multiUserMode(response) && user.role !== "admin") {
const limitMessagesSetting = await SystemSettings.get({
label: "limit_user_messages",
});
const limitMessages = limitMessagesSetting?.value === "true";
if (limitMessages) {
const messageLimitSetting = await SystemSettings.get({
label: "message_limit",
});
const systemLimit = Number(messageLimitSetting?.value);
if (!!systemLimit) {
const currentChatCount = await WorkspaceChats.count({
user_id: user.id,
createdAt: {
gte: new Date(new Date() - 24 * 60 * 60 * 1000),
},
});
if (currentChatCount >= systemLimit) {
writeResponseChunk(response, {
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error: `You have met your maximum 24 hour chat quota of ${systemLimit} chats set by the instance administrators. Try again later.`,
});
return;
}
}
}
}
await streamChatWithWorkspace(response, workspace, message, mode, user);
await Telemetry.sendTelemetry("sent_chat", {
multiUserMode: multiUserMode(response),
LLMSelection: process.env.LLM_PROVIDER || "openai",
VectorDbSelection: process.env.VECTOR_DB || "pinecone",
});
response.end();
} catch (e) {
console.error(e);
writeResponseChunk(response, {
id: uuidv4(),
type: "abort",
textResponse: null,
sources: [],
close: true,
error: e.message,
});
response.end();
}
}
);
app.post(
"/workspace/:slug/chat",
[validatedRequest],

View File

@ -27,6 +27,10 @@ class AnthropicLLM {
this.answerKey = v4().split("-")[0];
}
streamingEnabled() {
return "streamChat" in this && "streamGetChatCompletion" in this;
}
promptWindowLimit() {
switch (this.model) {
case "claude-instant-1":

View File

@ -22,6 +22,10 @@ class AzureOpenAiLLM extends AzureOpenAiEmbedder {
};
}
streamingEnabled() {
return "streamChat" in this && "streamGetChatCompletion" in this;
}
// Ensure the user selected a proper value for the token limit
// could be any of these https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-models
// and if undefined - assume it is the lowest end.

View File

@ -27,6 +27,10 @@ class LMStudioLLM {
this.embedder = embedder;
}
streamingEnabled() {
return "streamChat" in this && "streamGetChatCompletion" in this;
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {
@ -103,6 +107,32 @@ Context:
return textResponse;
}
async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
if (!this.model)
throw new Error(
`LMStudio chat: ${this.model} is not valid or defined for chat completion!`
);
const streamRequest = await this.lmstudio.createChatCompletion(
{
model: this.model,
temperature: Number(workspace?.openAiTemp ?? 0.7),
n: 1,
stream: true,
messages: await this.compressMessages(
{
systemPrompt: chatPrompt(workspace),
userPrompt: prompt,
chatHistory,
},
rawHistory
),
},
{ responseType: "stream" }
);
return streamRequest;
}
async getChatCompletion(messages = null, { temperature = 0.7 }) {
if (!this.model)
throw new Error(
@ -119,6 +149,24 @@ Context:
return data.choices[0].message.content;
}
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
if (!this.model)
throw new Error(
`LMStudio chat: ${this.model} is not a valid or defined model for chat completion!`
);
const streamRequest = await this.lmstudio.createChatCompletion(
{
model: this.model,
stream: true,
messages,
temperature,
},
{ responseType: "stream" }
);
return streamRequest;
}
// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
async embedTextInput(textInput) {
return await this.embedder.embedTextInput(textInput);

View File

@ -19,6 +19,10 @@ class OpenAiLLM extends OpenAiEmbedder {
};
}
streamingEnabled() {
return "streamChat" in this && "streamGetChatCompletion" in this;
}
promptWindowLimit() {
switch (this.model) {
case "gpt-3.5-turbo":
@ -140,6 +144,33 @@ Context:
return textResponse;
}
async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
const model = process.env.OPEN_MODEL_PREF;
if (!(await this.isValidChatCompletionModel(model)))
throw new Error(
`OpenAI chat: ${model} is not valid for chat completion!`
);
const streamRequest = await this.openai.createChatCompletion(
{
model,
stream: true,
temperature: Number(workspace?.openAiTemp ?? 0.7),
n: 1,
messages: await this.compressMessages(
{
systemPrompt: chatPrompt(workspace),
userPrompt: prompt,
chatHistory,
},
rawHistory
),
},
{ responseType: "stream" }
);
return streamRequest;
}
async getChatCompletion(messages = null, { temperature = 0.7 }) {
if (!(await this.isValidChatCompletionModel(this.model)))
throw new Error(
@ -156,6 +187,24 @@ Context:
return data.choices[0].message.content;
}
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
if (!(await this.isValidChatCompletionModel(this.model)))
throw new Error(
`OpenAI chat: ${this.model} is not valid for chat completion!`
);
const streamRequest = await this.openai.createChatCompletion(
{
model: this.model,
stream: true,
messages,
temperature,
},
{ responseType: "stream" }
);
return streamRequest;
}
async compressMessages(promptArgs = {}, rawHistory = []) {
const { messageArrayCompressor } = require("../../helpers/chat");
const messageArray = this.constructPrompt(promptArgs);

View File

@ -242,8 +242,11 @@ function chatPrompt(workspace) {
}
module.exports = {
recentChatHistory,
convertToPromptHistory,
convertToChatHistory,
chatWithWorkspace,
chatPrompt,
grepCommand,
VALID_COMMANDS,
};

View File

@ -0,0 +1,279 @@
const { v4: uuidv4 } = require("uuid");
const { WorkspaceChats } = require("../../models/workspaceChats");
const { getVectorDbClass, getLLMProvider } = require("../helpers");
const {
grepCommand,
recentChatHistory,
VALID_COMMANDS,
chatPrompt,
} = require(".");
function writeResponseChunk(response, data) {
response.write(`data: ${JSON.stringify(data)}\n\n`);
return;
}
async function streamChatWithWorkspace(
response,
workspace,
message,
chatMode = "chat",
user = null
) {
const uuid = uuidv4();
const command = grepCommand(message);
if (!!command && Object.keys(VALID_COMMANDS).includes(command)) {
const data = await VALID_COMMANDS[command](workspace, message, uuid, user);
writeResponseChunk(response, data);
return;
}
const LLMConnector = getLLMProvider();
const VectorDb = getVectorDbClass();
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
if (!safe) {
writeResponseChunk(response, {
id: uuid,
type: "abort",
textResponse: null,
sources: [],
close: true,
error: `This message was moderated and will not be allowed. Violations for ${reasons.join(
", "
)} found.`,
});
return;
}
const messageLimit = workspace?.openAiHistory || 20;
const hasVectorizedSpace = await VectorDb.hasNamespace(workspace.slug);
const embeddingsCount = await VectorDb.namespaceCount(workspace.slug);
if (!hasVectorizedSpace || embeddingsCount === 0) {
// If there are no embeddings - chat like a normal LLM chat interface.
return await streamEmptyEmbeddingChat({
response,
uuid,
user,
message,
workspace,
messageLimit,
LLMConnector,
});
}
let completeText;
const { rawHistory, chatHistory } = await recentChatHistory(
user,
workspace,
messageLimit,
chatMode
);
const {
contextTexts = [],
sources = [],
message: error,
} = await VectorDb.performSimilaritySearch({
namespace: workspace.slug,
input: message,
LLMConnector,
similarityThreshold: workspace?.similarityThreshold,
});
// Failed similarity search.
if (!!error) {
writeResponseChunk(response, {
id: uuid,
type: "abort",
textResponse: null,
sources: [],
close: true,
error,
});
return;
}
// Compress message to ensure prompt passes token limit with room for response
// and build system messages based on inputs and history.
const messages = await LLMConnector.compressMessages(
{
systemPrompt: chatPrompt(workspace),
userPrompt: message,
contextTexts,
chatHistory,
},
rawHistory
);
// If streaming is not explicitly enabled for connector
// we do regular waiting of a response and send a single chunk.
if (LLMConnector.streamingEnabled() !== true) {
console.log(
`\x1b[31m[STREAMING DISABLED]\x1b[0m Streaming is not available for ${LLMConnector.constructor.name}. Will use regular chat method.`
);
completeText = await LLMConnector.getChatCompletion(messages, {
temperature: workspace?.openAiTemp ?? 0.7,
});
writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: completeText,
close: true,
error: false,
});
} else {
const stream = await LLMConnector.streamGetChatCompletion(messages, {
temperature: workspace?.openAiTemp ?? 0.7,
});
completeText = await handleStreamResponses(response, stream, {
uuid,
sources,
});
}
await WorkspaceChats.new({
workspaceId: workspace.id,
prompt: message,
response: { text: completeText, sources, type: chatMode },
user,
});
return;
}
async function streamEmptyEmbeddingChat({
response,
uuid,
user,
message,
workspace,
messageLimit,
LLMConnector,
}) {
let completeText;
const { rawHistory, chatHistory } = await recentChatHistory(
user,
workspace,
messageLimit
);
// If streaming is not explicitly enabled for connector
// we do regular waiting of a response and send a single chunk.
if (LLMConnector.streamingEnabled() !== true) {
console.log(
`\x1b[31m[STREAMING DISABLED]\x1b[0m Streaming is not available for ${LLMConnector.constructor.name}. Will use regular chat method.`
);
completeText = await LLMConnector.sendChat(
chatHistory,
message,
workspace,
rawHistory
);
writeResponseChunk(response, {
uuid,
type: "textResponseChunk",
textResponse: completeText,
sources: [],
close: true,
error: false,
});
} else {
const stream = await LLMConnector.streamChat(
chatHistory,
message,
workspace,
rawHistory
);
completeText = await handleStreamResponses(response, stream, {
uuid,
sources: [],
});
}
await WorkspaceChats.new({
workspaceId: workspace.id,
prompt: message,
response: { text: completeText, sources: [], type: "chat" },
user,
});
return;
}
function handleStreamResponses(response, stream, responseProps) {
const { uuid = uuidv4(), sources = [] } = responseProps;
return new Promise((resolve) => {
let fullText = "";
let chunk = "";
stream.data.on("data", (data) => {
const lines = data
?.toString()
?.split("\n")
.filter((line) => line.trim() !== "");
for (const line of lines) {
const message = chunk + line.replace(/^data: /, "");
// JSON chunk is incomplete and has not ended yet
// so we need to stitch it together. You would think JSON
// chunks would only come complete - but they don't!
if (message.slice(-3) !== "}]}") {
chunk += message;
continue;
} else {
chunk = "";
}
if (message == "[DONE]") {
writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: "",
close: true,
error: false,
});
resolve(fullText);
} else {
let finishReason;
let token = "";
try {
const json = JSON.parse(message);
token = json?.choices?.[0]?.delta?.content;
finishReason = json?.choices?.[0]?.finish_reason;
} catch {
continue;
}
if (token) {
fullText += token;
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: token,
close: false,
error: false,
});
}
if (finishReason !== null) {
writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: "",
close: true,
error: false,
});
resolve(fullText);
}
}
}
});
});
}
module.exports = {
streamChatWithWorkspace,
writeResponseChunk,
};
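
A closing note on the design: streamingEnabled() in the connector classes above only duck-types for the two streaming methods, so a provider opts in simply by defining them; any connector without them takes the [STREAMING DISABLED] fallback in streamChatWithWorkspace / streamEmptyEmbeddingChat and has its full answer sent as one closing textResponseChunk. A hypothetical connector sketch (class name and stub bodies are illustrative, not part of this commit):

// Hypothetical: defining both methods is all it takes for the streaming path to be chosen.
class ExampleStreamingLLM {
  streamingEnabled() {
    // Same duck-typing check added to the OpenAI, Azure, Anthropic and LMStudio classes.
    return "streamChat" in this && "streamGetChatCompletion" in this;
  }
  async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
    // Would return a stream (e.g. an axios response with responseType: "stream")
    // for handleStreamResponses() to consume token by token.
    throw new Error("illustrative stub");
  }
  async streamGetChatCompletion(messages = null, { temperature = 0.7 } = {}) {
    // Same idea, but for an already compressed message array.
    throw new Error("illustrative stub");
  }
}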