From 406732830f196cf768093bd2505fc5be5e3bdbb4 Mon Sep 17 00:00:00 2001 From: Timothy Carambat Date: Thu, 8 Feb 2024 18:37:22 -0800 Subject: [PATCH] Implement workspace threading that is backwards compatible (#699) * Implement workspace thread that is compatible with legacy versions * last touches * comment on chat qty enforcement --- frontend/src/App.jsx | 4 + .../ThreadContainer/ThreadItem/index.jsx | 189 ++++++++++++++++++ .../ThreadContainer/index.jsx | 90 +++++++++ .../Sidebar/ActiveWorkspaces/index.jsx | 96 ++++----- .../WorkspaceChat/ChatContainer/index.jsx | 49 +++-- .../src/components/WorkspaceChat/index.jsx | 6 +- frontend/src/models/workspace.js | 2 + frontend/src/models/workspaceThread.js | 146 ++++++++++++++ frontend/src/pages/WorkspaceChat/index.jsx | 6 +- frontend/src/utils/paths.js | 3 + server/endpoints/chat.js | 114 +++++++++++ server/endpoints/workspaceThreads.js | 150 ++++++++++++++ server/index.js | 2 + server/models/workspaceChats.js | 34 +++- server/models/workspaceThread.js | 86 ++++++++ .../20240208224848_init/migration.sql | 24 +++ server/prisma/schema.prisma | 30 ++- server/utils/chats/commands/reset.js | 18 +- server/utils/chats/index.js | 27 +++ server/utils/chats/stream.js | 42 ++-- server/utils/middleware/validWorkspace.js | 52 +++++ 21 files changed, 1087 insertions(+), 83 deletions(-) create mode 100644 frontend/src/components/Sidebar/ActiveWorkspaces/ThreadContainer/ThreadItem/index.jsx create mode 100644 frontend/src/components/Sidebar/ActiveWorkspaces/ThreadContainer/index.jsx create mode 100644 frontend/src/models/workspaceThread.js create mode 100644 server/endpoints/workspaceThreads.js create mode 100644 server/models/workspaceThread.js create mode 100644 server/prisma/migrations/20240208224848_init/migration.sql create mode 100644 server/utils/middleware/validWorkspace.js diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 7d4ee4c5..7a1395f1 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -61,6 +61,10 @@ export default function App() { path="/workspace/:slug" element={} /> + } + /> } /> {/* Admin */} diff --git a/frontend/src/components/Sidebar/ActiveWorkspaces/ThreadContainer/ThreadItem/index.jsx b/frontend/src/components/Sidebar/ActiveWorkspaces/ThreadContainer/ThreadItem/index.jsx new file mode 100644 index 00000000..29e4f67e --- /dev/null +++ b/frontend/src/components/Sidebar/ActiveWorkspaces/ThreadContainer/ThreadItem/index.jsx @@ -0,0 +1,189 @@ +import Workspace from "@/models/workspace"; +import paths from "@/utils/paths"; +import showToast from "@/utils/toast"; +import { DotsThree, PencilSimple, Trash } from "@phosphor-icons/react"; +import { useEffect, useRef, useState } from "react"; +import { useParams } from "react-router-dom"; +import truncate from "truncate"; + +const THREAD_CALLOUT_DETAIL_WIDTH = 26; +export default function ThreadItem({ workspace, thread, onRemove, hasNext }) { + const optionsContainer = useRef(null); + const { slug, threadSlug = null } = useParams(); + const [showOptions, setShowOptions] = useState(false); + const [name, setName] = useState(thread.name); + + const isActive = threadSlug === thread.slug; + const linkTo = !thread.slug + ? paths.workspace.chat(slug) + : paths.workspace.thread(slug, thread.slug); + + return ( +
+ {/* Curved line Element and leader if required */} +
+ {hasNext && ( +
+ )} + + {/* Curved line inline placeholder for spacing */} +
+
+ +

+ {truncate(name, 25)} +

+
+ {!!thread.slug && ( +
+
+ +
+ {showOptions && ( + setShowOptions(false)} + /> + )} +
+ )} +
+
+ ); +} + +function OptionsMenu({ + containerRef, + workspace, + thread, + onRename, + onRemove, + close, +}) { + const menuRef = useRef(null); + + // Ref menu options + const outsideClick = (e) => { + if (!menuRef.current) return false; + if ( + !menuRef.current?.contains(e.target) && + !containerRef.current?.contains(e.target) + ) + close(); + return false; + }; + + const isEsc = (e) => { + if (e.key === "Escape" || e.key === "Esc") close(); + }; + + function cleanupListeners() { + window.removeEventListener("click", outsideClick); + window.removeEventListener("keyup", isEsc); + } + // end Ref menu options + + useEffect(() => { + function setListeners() { + if (!menuRef?.current || !containerRef.current) return false; + window.document.addEventListener("click", outsideClick); + window.document.addEventListener("keyup", isEsc); + } + + setListeners(); + return cleanupListeners; + }, [menuRef.current, containerRef.current]); + + const renameThread = async () => { + const name = window + .prompt("What would you like to rename this thread to?") + ?.trim(); + if (!name || name.length === 0) { + close(); + return; + } + + const { message } = await Workspace.threads.update( + workspace.slug, + thread.slug, + { name } + ); + if (!!message) { + showToast(`Thread could not be updated! ${message}`, "error", { + clear: true, + }); + close(); + return; + } + + onRename(name); + close(); + }; + + const handleDelete = async () => { + if ( + !window.confirm( + "Are you sure you want to delete this thread? All of its chats will be deleted. You cannot undo this." + ) + ) + return; + const success = await Workspace.threads.delete(workspace.slug, thread.slug); + if (!success) { + showToast("Thread could not be deleted!", "error", { clear: true }); + return; + } + if (success) { + showToast("Thread deleted successfully!", "success", { clear: true }); + onRemove(thread.id); + return; + } + }; + + return ( +
+ + +
+ ); +} diff --git a/frontend/src/components/Sidebar/ActiveWorkspaces/ThreadContainer/index.jsx b/frontend/src/components/Sidebar/ActiveWorkspaces/ThreadContainer/index.jsx new file mode 100644 index 00000000..5667d639 --- /dev/null +++ b/frontend/src/components/Sidebar/ActiveWorkspaces/ThreadContainer/index.jsx @@ -0,0 +1,90 @@ +import Workspace from "@/models/workspace"; +import paths from "@/utils/paths"; +import showToast from "@/utils/toast"; +import { Plus, CircleNotch } from "@phosphor-icons/react"; +import { useEffect, useState } from "react"; +import ThreadItem from "./ThreadItem"; + +export default function ThreadContainer({ workspace }) { + const [threads, setThreads] = useState([]); + const [loading, setLoading] = useState(true); + + useEffect(() => { + async function fetchThreads() { + if (!workspace.slug) return; + const { threads } = await Workspace.threads.all(workspace.slug); + setLoading(false); + setThreads(threads); + } + fetchThreads(); + }, [workspace.slug]); + + function removeThread(threadId) { + setThreads((prev) => prev.filter((thread) => thread.id !== threadId)); + } + + if (loading) { + return ( +
+

+ loading threads.... +

+
+ ); + } + + return ( +
+ 0} + /> + {threads.map((thread, i) => ( + + ))} + +
+ ); +} + +function NewThreadButton({ workspace }) { + const [loading, setLoading] = useState(); + const onClick = async () => { + setLoading(true); + const { thread, error } = await Workspace.threads.new(workspace.slug); + if (!!error) { + showToast(`Could not create thread - ${error}`, "error", { clear: true }); + setLoading(false); + return; + } + window.location.replace( + paths.workspace.thread(workspace.slug, thread.slug) + ); + }; + + return ( + + ); +} diff --git a/frontend/src/components/Sidebar/ActiveWorkspaces/index.jsx b/frontend/src/components/Sidebar/ActiveWorkspaces/index.jsx index cefd6b97..4603009b 100644 --- a/frontend/src/components/Sidebar/ActiveWorkspaces/index.jsx +++ b/frontend/src/components/Sidebar/ActiveWorkspaces/index.jsx @@ -10,6 +10,7 @@ import { useParams } from "react-router-dom"; import { GearSix, SquaresFour } from "@phosphor-icons/react"; import truncate from "truncate"; import useUser from "@/hooks/useUser"; +import ThreadContainer from "./ThreadContainer"; export default function ActiveWorkspaces() { const { slug } = useParams(); @@ -68,15 +69,16 @@ export default function ActiveWorkspaces() { const isHovered = hoverStates[workspace.id]; const isGearHovered = settingHover[workspace.id]; return ( -
handleMouseEnter(workspace.id)} - onMouseLeave={() => handleMouseLeave(workspace.id)} - > - +
handleMouseEnter(workspace.id)} + onMouseLeave={() => handleMouseLeave(workspace.id)} + > + - + {isActive && ( + + )}
); })} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx index 7a5a974a..543d6105 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx @@ -5,8 +5,10 @@ import Workspace from "@/models/workspace"; import handleChat from "@/utils/chat"; import { isMobile } from "react-device-detect"; import { SidebarMobileHeader } from "../../Sidebar"; +import { useParams } from "react-router-dom"; export default function ChatContainer({ workspace, knownHistory = [] }) { + const { threadSlug = null } = useParams(); const [message, setMessage] = useState(""); const [loadingResponse, setLoadingResponse] = useState(false); const [chatHistory, setChatHistory] = useState(knownHistory); @@ -71,20 +73,39 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { return false; } - await Workspace.streamChat( - workspace, - promptMessage.userMessage, - window.localStorage.getItem(`workspace_chat_mode_${workspace.slug}`) ?? - "chat", - (chatResult) => - handleChat( - chatResult, - setLoadingResponse, - setChatHistory, - remHistory, - _chatHistory - ) - ); + if (!!threadSlug) { + await Workspace.threads.streamChat( + { workspaceSlug: workspace.slug, threadSlug }, + promptMessage.userMessage, + window.localStorage.getItem( + `workspace_chat_mode_${workspace.slug}` + ) ?? "chat", + (chatResult) => + handleChat( + chatResult, + setLoadingResponse, + setChatHistory, + remHistory, + _chatHistory + ) + ); + } else { + await Workspace.streamChat( + workspace, + promptMessage.userMessage, + window.localStorage.getItem( + `workspace_chat_mode_${workspace.slug}` + ) ?? "chat", + (chatResult) => + handleChat( + chatResult, + setLoadingResponse, + setChatHistory, + remHistory, + _chatHistory + ) + ); + } return; } loadingResponse === true && fetchReply(); diff --git a/frontend/src/components/WorkspaceChat/index.jsx b/frontend/src/components/WorkspaceChat/index.jsx index cbe7dbdd..990ac7f5 100644 --- a/frontend/src/components/WorkspaceChat/index.jsx +++ b/frontend/src/components/WorkspaceChat/index.jsx @@ -4,8 +4,10 @@ import LoadingChat from "./LoadingChat"; import ChatContainer from "./ChatContainer"; import paths from "@/utils/paths"; import ModalWrapper from "../ModalWrapper"; +import { useParams } from "react-router-dom"; export default function WorkspaceChat({ loading, workspace }) { + const { threadSlug = null } = useParams(); const [history, setHistory] = useState([]); const [loadingHistory, setLoadingHistory] = useState(true); @@ -17,7 +19,9 @@ export default function WorkspaceChat({ loading, workspace }) { return false; } - const chatHistory = await Workspace.chatHistory(workspace.slug); + const chatHistory = threadSlug + ? await Workspace.threads.chatHistory(workspace.slug, threadSlug) + : await Workspace.chatHistory(workspace.slug); setHistory(chatHistory); setLoadingHistory(false); } diff --git a/frontend/src/models/workspace.js b/frontend/src/models/workspace.js index 811f662d..0adcf3fa 100644 --- a/frontend/src/models/workspace.js +++ b/frontend/src/models/workspace.js @@ -1,6 +1,7 @@ import { API_BASE } from "@/utils/constants"; import { baseHeaders } from "@/utils/request"; import { fetchEventSource } from "@microsoft/fetch-event-source"; +import WorkspaceThread from "@/models/workspaceThread"; import { v4 } from "uuid"; const Workspace = { @@ -204,6 +205,7 @@ const Workspace = { return { success: false, error: e.message }; }); }, + threads: WorkspaceThread, }; export default Workspace; diff --git a/frontend/src/models/workspaceThread.js b/frontend/src/models/workspaceThread.js new file mode 100644 index 00000000..256ea496 --- /dev/null +++ b/frontend/src/models/workspaceThread.js @@ -0,0 +1,146 @@ +import { API_BASE } from "@/utils/constants"; +import { baseHeaders } from "@/utils/request"; +import { fetchEventSource } from "@microsoft/fetch-event-source"; +import { v4 } from "uuid"; + +const WorkspaceThread = { + all: async function (workspaceSlug) { + const { threads } = await fetch( + `${API_BASE}/workspace/${workspaceSlug}/threads`, + { + method: "GET", + headers: baseHeaders(), + } + ) + .then((res) => res.json()) + .catch((e) => { + return { threads: [] }; + }); + + return { threads }; + }, + new: async function (workspaceSlug) { + const { thread, error } = await fetch( + `${API_BASE}/workspace/${workspaceSlug}/thread/new`, + { + method: "POST", + headers: baseHeaders(), + } + ) + .then((res) => res.json()) + .catch((e) => { + return { thread: null, error: e.message }; + }); + + return { thread, error }; + }, + update: async function (workspaceSlug, threadSlug, data = {}) { + const { thread, message } = await fetch( + `${API_BASE}/workspace/${workspaceSlug}/thread/${threadSlug}/update`, + { + method: "POST", + body: JSON.stringify(data), + headers: baseHeaders(), + } + ) + .then((res) => res.json()) + .catch((e) => { + return { thread: null, message: e.message }; + }); + + return { thread, message }; + }, + delete: async function (workspaceSlug, threadSlug) { + return await fetch( + `${API_BASE}/workspace/${workspaceSlug}/thread/${threadSlug}`, + { + method: "DELETE", + headers: baseHeaders(), + } + ) + .then((res) => res.ok) + .catch(() => false); + }, + chatHistory: async function (workspaceSlug, threadSlug) { + const history = await fetch( + `${API_BASE}/workspace/${workspaceSlug}/thread/${threadSlug}/chats`, + { + method: "GET", + headers: baseHeaders(), + } + ) + .then((res) => res.json()) + .then((res) => res.history || []) + .catch(() => []); + return history; + }, + streamChat: async function ( + { workspaceSlug, threadSlug }, + message, + mode = "query", + handleChat + ) { + const ctrl = new AbortController(); + await fetchEventSource( + `${API_BASE}/workspace/${workspaceSlug}/thread/${threadSlug}/stream-chat`, + { + method: "POST", + body: JSON.stringify({ message, mode }), + headers: baseHeaders(), + signal: ctrl.signal, + openWhenHidden: true, + async onopen(response) { + if (response.ok) { + return; // everything's good + } else if ( + response.status >= 400 && + response.status < 500 && + response.status !== 429 + ) { + handleChat({ + id: v4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: `An error occurred while streaming response. Code ${response.status}`, + }); + ctrl.abort(); + throw new Error("Invalid Status code response."); + } else { + handleChat({ + id: v4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: `An error occurred while streaming response. Unknown Error.`, + }); + ctrl.abort(); + throw new Error("Unknown error"); + } + }, + async onmessage(msg) { + try { + const chatResult = JSON.parse(msg.data); + handleChat(chatResult); + } catch {} + }, + onerror(err) { + handleChat({ + id: v4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: `An error occurred while streaming response. ${err.message}`, + }); + ctrl.abort(); + throw new Error(); + }, + } + ); + }, +}; + +export default WorkspaceThread; diff --git a/frontend/src/pages/WorkspaceChat/index.jsx b/frontend/src/pages/WorkspaceChat/index.jsx index 9575b64c..88a6744c 100644 --- a/frontend/src/pages/WorkspaceChat/index.jsx +++ b/frontend/src/pages/WorkspaceChat/index.jsx @@ -19,7 +19,7 @@ export default function WorkspaceChat() { } function ShowWorkspaceChat() { - const { slug } = useParams(); + const { slug, threadSlug = null } = useParams(); const [workspace, setWorkspace] = useState(null); const [loading, setLoading] = useState(true); @@ -27,6 +27,10 @@ function ShowWorkspaceChat() { async function getWorkspace() { if (!slug) return; const _workspace = await Workspace.bySlug(slug); + if (!_workspace) { + setLoading(false); + return; + } const suggestedMessages = await Workspace.getSuggestedMessages(slug); setWorkspace({ ..._workspace, diff --git a/frontend/src/utils/paths.js b/frontend/src/utils/paths.js index 06428c60..e57a2641 100644 --- a/frontend/src/utils/paths.js +++ b/frontend/src/utils/paths.js @@ -58,6 +58,9 @@ export default { additionalSettings: (slug) => { return `/workspace/${slug}/settings`; }, + thread: (wsSlug, threadSlug) => { + return `/workspace/${wsSlug}/t/${threadSlug}`; + }, }, apiDocs: () => { return `${API_BASE}/docs`; diff --git a/server/endpoints/chat.js b/server/endpoints/chat.js index 848a7a36..d45ad7b4 100644 --- a/server/endpoints/chat.js +++ b/server/endpoints/chat.js @@ -15,6 +15,9 @@ const { flexUserRoleValid, } = require("../utils/middleware/multiUserProtected"); const { EventLogs } = require("../models/eventLogs"); +const { + validWorkspaceAndThreadSlug, +} = require("../utils/middleware/validWorkspace"); function chatEndpoints(app) { if (!app) return; @@ -123,6 +126,117 @@ function chatEndpoints(app) { } } ); + + app.post( + "/workspace/:slug/thread/:threadSlug/stream-chat", + [ + validatedRequest, + flexUserRoleValid([ROLES.all]), + validWorkspaceAndThreadSlug, + ], + async (request, response) => { + try { + const user = await userFromSession(request, response); + const { message, mode = "query" } = reqBody(request); + const workspace = response.locals.workspace; + const thread = response.locals.thread; + + if (!message?.length || !VALID_CHAT_MODE.includes(mode)) { + response.status(400).json({ + id: uuidv4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: !message?.length + ? "Message is empty." + : `${mode} is not a valid mode.`, + }); + return; + } + + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Content-Type", "text/event-stream"); + response.setHeader("Access-Control-Allow-Origin", "*"); + response.setHeader("Connection", "keep-alive"); + response.flushHeaders(); + + if (multiUserMode(response) && user.role !== ROLES.admin) { + const limitMessagesSetting = await SystemSettings.get({ + label: "limit_user_messages", + }); + const limitMessages = limitMessagesSetting?.value === "true"; + + if (limitMessages) { + const messageLimitSetting = await SystemSettings.get({ + label: "message_limit", + }); + const systemLimit = Number(messageLimitSetting?.value); + + if (!!systemLimit) { + // Chat qty includes all threads because any user can freely + // create threads and would bypass this rule. + const currentChatCount = await WorkspaceChats.count({ + user_id: user.id, + createdAt: { + gte: new Date(new Date() - 24 * 60 * 60 * 1000), + }, + }); + + if (currentChatCount >= systemLimit) { + writeResponseChunk(response, { + id: uuidv4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: `You have met your maximum 24 hour chat quota of ${systemLimit} chats set by the instance administrators. Try again later.`, + }); + return; + } + } + } + } + + await streamChatWithWorkspace( + response, + workspace, + message, + mode, + user, + thread + ); + await Telemetry.sendTelemetry("sent_chat", { + multiUserMode: multiUserMode(response), + LLMSelection: process.env.LLM_PROVIDER || "openai", + Embedder: process.env.EMBEDDING_ENGINE || "inherit", + VectorDbSelection: process.env.VECTOR_DB || "pinecone", + }); + + await EventLogs.logEvent( + "sent_chat", + { + workspaceName: workspace.name, + thread: thread.name, + chatModel: workspace?.chatModel || "System Default", + }, + user?.id + ); + response.end(); + } catch (e) { + console.error(e); + writeResponseChunk(response, { + id: uuidv4(), + type: "abort", + textResponse: null, + sources: [], + close: true, + error: e.message, + }); + response.end(); + } + } + ); } module.exports = { chatEndpoints }; diff --git a/server/endpoints/workspaceThreads.js b/server/endpoints/workspaceThreads.js new file mode 100644 index 00000000..d1d0909c --- /dev/null +++ b/server/endpoints/workspaceThreads.js @@ -0,0 +1,150 @@ +const { multiUserMode, userFromSession, reqBody } = require("../utils/http"); +const { validatedRequest } = require("../utils/middleware/validatedRequest"); +const { Telemetry } = require("../models/telemetry"); +const { + flexUserRoleValid, + ROLES, +} = require("../utils/middleware/multiUserProtected"); +const { EventLogs } = require("../models/eventLogs"); +const { WorkspaceThread } = require("../models/workspaceThread"); +const { + validWorkspaceSlug, + validWorkspaceAndThreadSlug, +} = require("../utils/middleware/validWorkspace"); +const { WorkspaceChats } = require("../models/workspaceChats"); +const { convertToChatHistory } = require("../utils/chats"); + +function workspaceThreadEndpoints(app) { + if (!app) return; + + app.post( + "/workspace/:slug/thread/new", + [validatedRequest, flexUserRoleValid([ROLES.all]), validWorkspaceSlug], + async (request, response) => { + try { + const user = await userFromSession(request, response); + const workspace = response.locals.workspace; + const { thread, message } = await WorkspaceThread.new( + workspace, + user?.id + ); + await Telemetry.sendTelemetry( + "workspace_thread_created", + { + multiUserMode: multiUserMode(response), + LLMSelection: process.env.LLM_PROVIDER || "openai", + Embedder: process.env.EMBEDDING_ENGINE || "inherit", + VectorDbSelection: process.env.VECTOR_DB || "pinecone", + }, + user?.id + ); + + await EventLogs.logEvent( + "workspace_thread_created", + { + workspaceName: workspace?.name || "Unknown Workspace", + }, + user?.id + ); + response.status(200).json({ thread, message }); + } catch (e) { + console.log(e.message, e); + response.sendStatus(500).end(); + } + } + ); + + app.get( + "/workspace/:slug/threads", + [validatedRequest, flexUserRoleValid([ROLES.all]), validWorkspaceSlug], + async (request, response) => { + try { + const user = await userFromSession(request, response); + const workspace = response.locals.workspace; + const threads = await WorkspaceThread.where({ + workspace_id: workspace.id, + user_id: user?.id || null, + }); + response.status(200).json({ threads }); + } catch (e) { + console.log(e.message, e); + response.sendStatus(500).end(); + } + } + ); + + app.delete( + "/workspace/:slug/thread/:threadSlug", + [ + validatedRequest, + flexUserRoleValid([ROLES.all]), + validWorkspaceAndThreadSlug, + ], + async (_, response) => { + try { + const thread = response.locals.thread; + await WorkspaceThread.delete({ id: thread.id }); + response.sendStatus(200).end(); + } catch (e) { + console.log(e.message, e); + response.sendStatus(500).end(); + } + } + ); + + app.get( + "/workspace/:slug/thread/:threadSlug/chats", + [ + validatedRequest, + flexUserRoleValid([ROLES.all]), + validWorkspaceAndThreadSlug, + ], + async (request, response) => { + try { + const user = await userFromSession(request, response); + const workspace = response.locals.workspace; + const thread = response.locals.thread; + const history = await WorkspaceChats.where( + { + workspaceId: workspace.id, + user_id: user?.id || null, + thread_id: thread.id, + include: true, + }, + null, + { id: "asc" } + ); + + response.status(200).json({ history: convertToChatHistory(history) }); + } catch (e) { + console.log(e.message, e); + response.sendStatus(500).end(); + } + } + ); + + app.post( + "/workspace/:slug/thread/:threadSlug/update", + [ + validatedRequest, + flexUserRoleValid([ROLES.all]), + validWorkspaceAndThreadSlug, + ], + async (request, response) => { + try { + const data = reqBody(request); + const currentThread = response.locals.thread; + const { thread, message } = await WorkspaceThread.update( + currentThread, + data + ); + response.status(200).json({ thread, message }); + } catch (e) { + console.log(e.message, e); + response.sendStatus(500).end(); + } + } + ); +} + +module.exports = { workspaceThreadEndpoints }; diff --git a/server/index.js b/server/index.js index 3d613191..ca09cc92 100644 --- a/server/index.js +++ b/server/index.js @@ -19,6 +19,7 @@ const { utilEndpoints } = require("./endpoints/utils"); const { developerEndpoints } = require("./endpoints/api"); const { extensionEndpoints } = require("./endpoints/extensions"); const { bootHTTP, bootSSL } = require("./utils/boot"); +const { workspaceThreadEndpoints } = require("./endpoints/workspaceThreads"); const app = express(); const apiRouter = express.Router(); const FILE_LIMIT = "3GB"; @@ -37,6 +38,7 @@ app.use("/api", apiRouter); systemEndpoints(apiRouter); extensionEndpoints(apiRouter); workspaceEndpoints(apiRouter); +workspaceThreadEndpoints(apiRouter); chatEndpoints(apiRouter); adminEndpoints(apiRouter); inviteEndpoints(apiRouter); diff --git a/server/models/workspaceChats.js b/server/models/workspaceChats.js index b91b675e..1dd20517 100644 --- a/server/models/workspaceChats.js +++ b/server/models/workspaceChats.js @@ -1,7 +1,13 @@ const prisma = require("../utils/prisma"); const WorkspaceChats = { - new: async function ({ workspaceId, prompt, response = {}, user = null }) { + new: async function ({ + workspaceId, + prompt, + response = {}, + user = null, + threadId = null, + }) { try { const chat = await prisma.workspace_chats.create({ data: { @@ -9,6 +15,7 @@ const WorkspaceChats = { prompt, response: JSON.stringify(response), user_id: user?.id || null, + thread_id: threadId, }, }); return { chat, message: null }; @@ -30,6 +37,7 @@ const WorkspaceChats = { where: { workspaceId, user_id: userId, + thread_id: null, // this function is now only used for the default thread on workspaces and users include: true, }, ...(limit !== null ? { take: limit } : {}), @@ -52,6 +60,7 @@ const WorkspaceChats = { const chats = await prisma.workspace_chats.findMany({ where: { workspaceId, + thread_id: null, // this function is now only used for the default thread on workspaces include: true, }, ...(limit !== null ? { take: limit } : {}), @@ -82,6 +91,29 @@ const WorkspaceChats = { } }, + markThreadHistoryInvalid: async function ( + workspaceId = null, + user = null, + threadId = null + ) { + if (!workspaceId || !threadId) return; + try { + await prisma.workspace_chats.updateMany({ + where: { + workspaceId, + thread_id: threadId, + user_id: user?.id, + }, + data: { + include: false, + }, + }); + return; + } catch (error) { + console.error(error.message); + } + }, + get: async function (clause = {}, limit = null, orderBy = null) { try { const chat = await prisma.workspace_chats.findFirst({ diff --git a/server/models/workspaceThread.js b/server/models/workspaceThread.js new file mode 100644 index 00000000..45c9b0f1 --- /dev/null +++ b/server/models/workspaceThread.js @@ -0,0 +1,86 @@ +const prisma = require("../utils/prisma"); +const { v4: uuidv4 } = require("uuid"); + +const WorkspaceThread = { + writable: ["name"], + + new: async function (workspace, userId = null) { + try { + const thread = await prisma.workspace_threads.create({ + data: { + name: "New thread", + slug: uuidv4(), + user_id: userId ? Number(userId) : null, + workspace_id: workspace.id, + }, + }); + + return { thread, message: null }; + } catch (error) { + console.error(error.message); + return { thread: null, message: error.message }; + } + }, + + update: async function (prevThread = null, data = {}) { + if (!prevThread) throw new Error("No thread id provided for update"); + + const validKeys = Object.keys(data).filter((key) => + this.writable.includes(key) + ); + if (validKeys.length === 0) + return { thread: prevThread, message: "No valid fields to update!" }; + + try { + const thread = await prisma.workspace_threads.update({ + where: { id: prevThread.id }, + data, + }); + return { thread, message: null }; + } catch (error) { + console.error(error.message); + return { thread: null, message: error.message }; + } + }, + + get: async function (clause = {}) { + try { + const thread = await prisma.workspace_threads.findFirst({ + where: clause, + }); + + return thread || null; + } catch (error) { + console.error(error.message); + return null; + } + }, + + delete: async function (clause = {}) { + try { + await prisma.workspace_threads.delete({ + where: clause, + }); + return true; + } catch (error) { + console.error(error.message); + return false; + } + }, + + where: async function (clause = {}, limit = null, orderBy = null) { + try { + const results = await prisma.workspace_threads.findMany({ + where: clause, + ...(limit !== null ? { take: limit } : {}), + ...(orderBy !== null ? { orderBy } : {}), + }); + return results; + } catch (error) { + console.error(error.message); + return []; + } + }, +}; + +module.exports = { WorkspaceThread }; diff --git a/server/prisma/migrations/20240208224848_init/migration.sql b/server/prisma/migrations/20240208224848_init/migration.sql new file mode 100644 index 00000000..f7e65619 --- /dev/null +++ b/server/prisma/migrations/20240208224848_init/migration.sql @@ -0,0 +1,24 @@ +-- AlterTable +ALTER TABLE "workspace_chats" ADD COLUMN "thread_id" INTEGER; + +-- CreateTable +CREATE TABLE "workspace_threads" ( + "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + "name" TEXT NOT NULL, + "slug" TEXT NOT NULL, + "workspace_id" INTEGER NOT NULL, + "user_id" INTEGER, + "createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + "lastUpdatedAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT "workspace_threads_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces" ("id") ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT "workspace_threads_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON DELETE CASCADE ON UPDATE CASCADE +); + +-- CreateIndex +CREATE UNIQUE INDEX "workspace_threads_slug_key" ON "workspace_threads"("slug"); + +-- CreateIndex +CREATE INDEX "workspace_threads_workspace_id_idx" ON "workspace_threads"("workspace_id"); + +-- CreateIndex +CREATE INDEX "workspace_threads_user_id_idx" ON "workspace_threads"("user_id"); diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma index 1747db32..c52e1a4b 100644 --- a/server/prisma/schema.prisma +++ b/server/prisma/schema.prisma @@ -54,18 +54,19 @@ model system_settings { } model users { - id Int @id @default(autoincrement()) - username String? @unique + id Int @id @default(autoincrement()) + username String? @unique password String pfpFilename String? - role String @default("default") - suspended Int @default(0) - createdAt DateTime @default(now()) - lastUpdatedAt DateTime @default(now()) + role String @default("default") + suspended Int @default(0) + createdAt DateTime @default(now()) + lastUpdatedAt DateTime @default(now()) workspace_chats workspace_chats[] workspace_users workspace_users[] embed_configs embed_configs[] embed_chats embed_chats[] + threads workspace_threads[] } model document_vectors { @@ -101,6 +102,22 @@ model workspaces { documents workspace_documents[] workspace_suggested_messages workspace_suggested_messages[] embed_configs embed_configs[] + threads workspace_threads[] +} + +model workspace_threads { + id Int @id @default(autoincrement()) + name String + slug String @unique + workspace_id Int + user_id Int? + createdAt DateTime @default(now()) + lastUpdatedAt DateTime @default(now()) + workspace workspaces @relation(fields: [workspace_id], references: [id], onDelete: Cascade) + user users? @relation(fields: [user_id], references: [id], onDelete: Cascade) + + @@index([workspace_id]) + @@index([user_id]) } model workspace_suggested_messages { @@ -122,6 +139,7 @@ model workspace_chats { response String include Boolean @default(true) user_id Int? + thread_id Int? // No relation to prevent whole table migration createdAt DateTime @default(now()) lastUpdatedAt DateTime @default(now()) users users? @relation(fields: [user_id], references: [id], onDelete: Cascade, onUpdate: Cascade) diff --git a/server/utils/chats/commands/reset.js b/server/utils/chats/commands/reset.js index 8851efdf..a23eef7a 100644 --- a/server/utils/chats/commands/reset.js +++ b/server/utils/chats/commands/reset.js @@ -1,7 +1,21 @@ const { WorkspaceChats } = require("../../../models/workspaceChats"); -async function resetMemory(workspace, _message, msgUUID, user = null) { - await WorkspaceChats.markHistoryInvalid(workspace.id, user); +async function resetMemory( + workspace, + _message, + msgUUID, + user = null, + thread = null +) { + // If thread is present we are wanting to reset this specific thread. Not the whole workspace. + thread + ? await WorkspaceChats.markThreadHistoryInvalid( + workspace.id, + user, + thread.id + ) + : await WorkspaceChats.markHistoryInvalid(workspace.id, user); + return { uuid: msgUUID, type: "textResponse", diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js index 102189cb..8ec7d900 100644 --- a/server/utils/chats/index.js +++ b/server/utils/chats/index.js @@ -204,6 +204,8 @@ async function chatWithWorkspace( // On query we dont return message history. All other chat modes and when chatting // with no embeddings we return history. +// TODO: Refactor to just run a .where on WorkspaceChat to simplify what is going on here. +// see recentThreadChatHistory async function recentChatHistory( user = null, workspace, @@ -226,6 +228,30 @@ async function recentChatHistory( return { rawHistory, chatHistory: convertToPromptHistory(rawHistory) }; } +// Extension of recentChatHistory that supports threads +async function recentThreadChatHistory( + user = null, + workspace, + thread, + messageLimit = 20, + chatMode = null +) { + if (chatMode === "query") return []; + const rawHistory = ( + await WorkspaceChats.where( + { + workspaceId: workspace.id, + user_id: user?.id || null, + thread_id: thread?.id || null, + include: true, + }, + messageLimit, + { id: "desc" } + ) + ).reverse(); + return { rawHistory, chatHistory: convertToPromptHistory(rawHistory) }; +} + async function emptyEmbeddingChat({ uuid, user, @@ -270,6 +296,7 @@ function chatPrompt(workspace) { module.exports = { recentChatHistory, + recentThreadChatHistory, convertToPromptHistory, convertToChatHistory, chatWithWorkspace, diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js index d16f6e60..11190d63 100644 --- a/server/utils/chats/stream.js +++ b/server/utils/chats/stream.js @@ -6,6 +6,7 @@ const { recentChatHistory, VALID_COMMANDS, chatPrompt, + recentThreadChatHistory, } = require("."); const VALID_CHAT_MODE = ["chat", "query"]; @@ -19,13 +20,20 @@ async function streamChatWithWorkspace( workspace, message, chatMode = "chat", - user = null + user = null, + thread = null ) { const uuid = uuidv4(); const command = grepCommand(message); if (!!command && Object.keys(VALID_COMMANDS).includes(command)) { - const data = await VALID_COMMANDS[command](workspace, message, uuid, user); + const data = await VALID_COMMANDS[command]( + workspace, + message, + uuid, + user, + thread + ); writeResponseChunk(response, data); return; } @@ -65,6 +73,8 @@ async function streamChatWithWorkspace( } // If there are no embeddings - chat like a normal LLM chat interface. + // no need to pass in chat mode - because if we are here we are in + // "chat" mode + have embeddings. return await streamEmptyEmbeddingChat({ response, uuid, @@ -73,16 +83,21 @@ async function streamChatWithWorkspace( workspace, messageLimit, LLMConnector, + thread, }); } let completeText; - const { rawHistory, chatHistory } = await recentChatHistory( - user, - workspace, - messageLimit, - chatMode - ); + const { rawHistory, chatHistory } = thread + ? await recentThreadChatHistory( + user, + workspace, + thread, + messageLimit, + chatMode + ) + : await recentChatHistory(user, workspace, messageLimit, chatMode); + const { contextTexts = [], sources = [], @@ -167,6 +182,7 @@ async function streamChatWithWorkspace( prompt: message, response: { text: completeText, sources, type: chatMode }, user, + threadId: thread?.id, }); return; } @@ -179,13 +195,12 @@ async function streamEmptyEmbeddingChat({ workspace, messageLimit, LLMConnector, + thread = null, }) { let completeText; - const { rawHistory, chatHistory } = await recentChatHistory( - user, - workspace, - messageLimit - ); + const { rawHistory, chatHistory } = thread + ? await recentThreadChatHistory(user, workspace, thread, messageLimit) + : await recentChatHistory(user, workspace, messageLimit); // If streaming is not explicitly enabled for connector // we do regular waiting of a response and send a single chunk. @@ -225,6 +240,7 @@ async function streamEmptyEmbeddingChat({ prompt: message, response: { text: completeText, sources: [], type: "chat" }, user, + threadId: thread?.id, }); return; } diff --git a/server/utils/middleware/validWorkspace.js b/server/utils/middleware/validWorkspace.js new file mode 100644 index 00000000..10c58d98 --- /dev/null +++ b/server/utils/middleware/validWorkspace.js @@ -0,0 +1,52 @@ +const { Workspace } = require("../../models/workspace"); +const { WorkspaceThread } = require("../../models/workspaceThread"); +const { userFromSession, multiUserMode } = require("../http"); + +// Will pre-validate and set the workspace for a request if the slug is provided in the URL path. +async function validWorkspaceSlug(request, response, next) { + const { slug } = request.params; + const user = await userFromSession(request, response); + const workspace = multiUserMode(response) + ? await Workspace.getWithUser(user, { slug }) + : await Workspace.get({ slug }); + + if (!workspace) { + response.status(404).send("Workspace does not exist."); + return; + } + + response.locals.workspace = workspace; + next(); +} + +// Will pre-validate and set the workspace AND a thread for a request if the slugs are provided in the URL path. +async function validWorkspaceAndThreadSlug(request, response, next) { + const { slug, threadSlug } = request.params; + const user = await userFromSession(request, response); + const workspace = multiUserMode(response) + ? await Workspace.getWithUser(user, { slug }) + : await Workspace.get({ slug }); + + if (!workspace) { + response.status(404).send("Workspace does not exist."); + return; + } + + const thread = await WorkspaceThread.get({ + slug: threadSlug, + user_id: user?.id || null, + }); + if (!thread) { + response.status(404).send("Workspace thread does not exist."); + return; + } + + response.locals.workspace = workspace; + response.locals.thread = thread; + next(); +} + +module.exports = { + validWorkspaceSlug, + validWorkspaceAndThreadSlug, +};