From 608f28d7455a64e6521fc2522a158fa398ab474d Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Tue, 6 Feb 2024 11:24:33 -0800 Subject: [PATCH] [FEAT] create custom prompt suggestions per workspace (#664) * create custom suggested chat messages per workspace * update how suggestedChats are passed to chat window * update mobile styles * update edit change handler --------- Co-authored-by: timothycarambat --- frontend/src/App.jsx | 5 + .../Modals/MangeWorkspace/Settings/index.jsx | 8 + .../ChatContainer/ChatHistory/index.jsx | 54 +++-- .../WorkspaceChat/ChatContainer/index.jsx | 6 +- frontend/src/models/workspace.js | 36 +++ frontend/src/pages/WorkspaceChat/index.jsx | 6 +- .../src/pages/WorkspaceSettings/index.jsx | 208 ++++++++++++++++++ frontend/src/utils/paths.js | 3 + server/endpoints/workspaces.js | 50 +++++ server/models/workspacesSuggestedMessages.js | 83 +++++++ .../20240206181106_init/migration.sql | 13 ++ server/prisma/schema.prisma | 43 ++-- 12 files changed, 483 insertions(+), 32 deletions(-) create mode 100644 frontend/src/pages/WorkspaceSettings/index.jsx create mode 100644 server/models/workspacesSuggestedMessages.js create mode 100644 server/prisma/migrations/20240206181106_init/migration.sql diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 0c6bfaf4..a0836512 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -41,6 +41,7 @@ const DataConnectors = lazy( const DataConnectorSetup = lazy( () => import("@/pages/GeneralSettings/DataConnectors/Connectors") ); +const WorkspaceSettings = lazy(() => import("@/pages/WorkspaceSettings")); const EmbedConfigSetup = lazy( () => import("@/pages/GeneralSettings/EmbedConfigs") ); @@ -62,6 +63,10 @@ export default function App() { } /> {/* Admin */} + } + /> } diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx index a9471388..48a3ff5d 100644 --- a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx +++ b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx @@ -7,6 +7,7 @@ import PreLoader from "../../../Preloader"; import { useParams } from "react-router-dom"; import showToast from "../../../../utils/toast"; import ChatModelPreference from "./ChatModelPreference"; +import { Link } from "react-router-dom"; // Ensure that a type is correct before sending the body // to the backend. @@ -313,6 +314,13 @@ export default function WorkspaceSettings({ active, workspace, settings }) { + diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx index 358e520a..74c159f4 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx @@ -6,7 +6,7 @@ import ManageWorkspace from "../../../Modals/MangeWorkspace"; import { ArrowDown } from "@phosphor-icons/react"; import debounce from "lodash.debounce"; -export default function ChatHistory({ history = [], workspace }) { +export default function ChatHistory({ history = [], workspace, sendCommand }) { const replyRef = useRef(null); const { showing, showModal, hideModal } = useManageWorkspaceModal(); const [isAtBottom, setIsAtBottom] = useState(true); @@ -46,25 +46,31 @@ export default function ChatHistory({ history = [], workspace }) { } }; + const handleSendSuggestedMessage = (heading, message) => { + sendCommand(`${heading} ${message}`, true); + }; + if (history.length === 0) { return ( -
-
+
+

Welcome to your new workspace.

-
-

- To get started either{" "} - - upload a document - - or send a chat. -

-
+

+ To get started either{" "} + + upload a document + + or send a chat. +

+
{showing && ( ); } + +function WorkspaceChatSuggestions({ suggestions = [], sendSuggestion }) { + if (suggestions.length === 0) return null; + return ( +
+ {suggestions.map((suggestion, index) => ( + + ))} +
+ ); +} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx index 372c79a7..7a5a974a 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx @@ -97,7 +97,11 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { > {isMobile && }
- + { + if (!res.ok) throw new Error("Could not fetch suggested messages."); + return res.json(); + }) + .then((res) => res.suggestedMessages) + .catch((e) => { + console.error(e); + return null; + }); + }, + setSuggestedMessages: async function (slug, messages) { + return fetch(`${API_BASE}/workspace/${slug}/suggested-messages`, { + method: "POST", + headers: baseHeaders(), + body: JSON.stringify({ messages }), + }) + .then((res) => { + if (!res.ok) { + throw new Error( + res.statusText || "Error setting suggested messages." + ); + } + return { success: true, ...res.json() }; + }) + .catch((e) => { + console.error(e); + return { success: false, error: e.message }; + }); + }, }; export default Workspace; diff --git a/frontend/src/pages/WorkspaceChat/index.jsx b/frontend/src/pages/WorkspaceChat/index.jsx index 7db652a9..9575b64c 100644 --- a/frontend/src/pages/WorkspaceChat/index.jsx +++ b/frontend/src/pages/WorkspaceChat/index.jsx @@ -27,7 +27,11 @@ function ShowWorkspaceChat() { async function getWorkspace() { if (!slug) return; const _workspace = await Workspace.bySlug(slug); - setWorkspace(_workspace); + const suggestedMessages = await Workspace.getSuggestedMessages(slug); + setWorkspace({ + ..._workspace, + suggestedMessages, + }); setLoading(false); } getWorkspace(); diff --git a/frontend/src/pages/WorkspaceSettings/index.jsx b/frontend/src/pages/WorkspaceSettings/index.jsx new file mode 100644 index 00000000..35743d13 --- /dev/null +++ b/frontend/src/pages/WorkspaceSettings/index.jsx @@ -0,0 +1,208 @@ +import React, { useState, useEffect } from "react"; +import { useParams } from "react-router-dom"; +import { isMobile } from "react-device-detect"; +import showToast from "@/utils/toast"; +import { ArrowUUpLeft, Plus, X } from "@phosphor-icons/react"; +import Workspace from "@/models/workspace"; +import paths from "@/utils/paths"; + +export default function WorkspaceSettings() { + const [hasChanges, setHasChanges] = useState(false); + const [workspace, setWorkspace] = useState(null); + const [suggestedMessages, setSuggestedMessages] = useState([]); + const [editingIndex, setEditingIndex] = useState(-1); + const [newMessage, setNewMessage] = useState({ heading: "", message: "" }); + const { slug } = useParams(); + + useEffect(() => { + async function fetchWorkspace() { + if (!slug) return; + const workspace = await Workspace.bySlug(slug); + const suggestedMessages = await Workspace.getSuggestedMessages(slug); + setWorkspace(workspace); + setSuggestedMessages(suggestedMessages); + } + fetchWorkspace(); + }, [slug]); + + const handleSaveSuggestedMessages = async () => { + const validMessages = suggestedMessages.filter( + (msg) => + msg?.heading?.trim()?.length > 0 || msg?.message?.trim()?.length > 0 + ); + const { success, error } = await Workspace.setSuggestedMessages( + slug, + validMessages + ); + if (!success) { + showToast(`Failed to update welcome messages: ${error}`, "error"); + return; + } + showToast("Successfully updated welcome messages.", "success"); + setHasChanges(false); + }; + + const addMessage = () => { + setEditingIndex(-1); + if (suggestedMessages.length >= 4) { + showToast("Maximum of 4 messages allowed.", "warning"); + return; + } + const defaultMessage = { + heading: "Explain to me", + message: "the benefits of AnythingLLM", + }; + setNewMessage(defaultMessage); + setSuggestedMessages([...suggestedMessages, { ...defaultMessage }]); + setHasChanges(true); + }; + + const removeMessage = (index) => { + const messages = [...suggestedMessages]; + messages.splice(index, 1); + setSuggestedMessages(messages); + setHasChanges(true); + }; + + const startEditing = (index) => { + setEditingIndex(index); + setNewMessage({ ...suggestedMessages[index] }); + }; + + const handleRemoveMessage = (index) => { + removeMessage(index); + setEditingIndex(-1); + }; + + const onEditChange = (e) => { + const updatedNewMessage = { + ...newMessage, + [e.target.name]: e.target.value, + }; + setNewMessage(updatedNewMessage); + const updatedMessages = suggestedMessages.map((message, index) => { + if (index === editingIndex) { + return { ...message, [e.target.name]: e.target.value }; + } + return message; + }); + + setSuggestedMessages(updatedMessages); + setHasChanges(true); + }; + + return ( +
+ + + +
+
+
+
+

+ Workspace Settings ({workspace?.name}) +

+
+

+ Customize your workspace. +

+
+
+
+

+ Suggested Chat Messages +

+

+ Customize the messages that will be suggested to your workspace + users. +

+
+ +
+ {suggestedMessages.map((suggestion, index) => ( +
+ + +
+ ))} +
+ {editingIndex >= 0 && ( +
+
+ + +
+
+ + +
+
+ )} + {suggestedMessages.length < 4 && ( + + )} + + {hasChanges && ( +
+ +
+ )} +
+
+
+
+ ); +} diff --git a/frontend/src/utils/paths.js b/frontend/src/utils/paths.js index a9669300..8fbaacec 100644 --- a/frontend/src/utils/paths.js +++ b/frontend/src/utils/paths.js @@ -55,6 +55,9 @@ export default { chat: (slug) => { return `/workspace/${slug}`; }, + additionalSettings: (slug) => { + return `/workspace/${slug}/settings`; + }, }, apiDocs: () => { return `${API_BASE}/docs`; diff --git a/server/endpoints/workspaces.js b/server/endpoints/workspaces.js index 25e39103..b04d2337 100644 --- a/server/endpoints/workspaces.js +++ b/server/endpoints/workspaces.js @@ -17,6 +17,9 @@ const { flexUserRoleValid, ROLES, } = require("../utils/middleware/multiUserProtected"); +const { + WorkspaceSuggestedMessages, +} = require("../models/workspacesSuggestedMessages"); const { handleUploads } = setupMulter(); function workspaceEndpoints(app) { @@ -283,6 +286,53 @@ function workspaceEndpoints(app) { } } ); + + app.get( + "/workspace/:slug/suggested-messages", + [validatedRequest, flexUserRoleValid([ROLES.all])], + async function (request, response) { + try { + const { slug } = request.params; + const suggestedMessages = + await WorkspaceSuggestedMessages.getMessages(slug); + response.status(200).json({ success: true, suggestedMessages }); + } catch (error) { + console.error("Error fetching suggested messages:", error); + response + .status(500) + .json({ success: false, message: "Internal server error" }); + } + } + ); + + app.post( + "/workspace/:slug/suggested-messages", + [validatedRequest, flexUserRoleValid([ROLES.admin, ROLES.manager])], + async (request, response) => { + try { + const { messages = [] } = reqBody(request); + const { slug } = request.params; + if (!Array.isArray(messages)) { + return response.status(400).json({ + success: false, + message: "Invalid message format. Expected an array of messages.", + }); + } + + await WorkspaceSuggestedMessages.saveAll(messages, slug); + return response.status(200).json({ + success: true, + message: "Suggested messages saved successfully.", + }); + } catch (error) { + console.error("Error processing the suggested messages:", error); + response.status(500).json({ + success: true, + message: "Error saving the suggested messages.", + }); + } + } + ); } module.exports = { workspaceEndpoints }; diff --git a/server/models/workspacesSuggestedMessages.js b/server/models/workspacesSuggestedMessages.js new file mode 100644 index 00000000..ef35a5bb --- /dev/null +++ b/server/models/workspacesSuggestedMessages.js @@ -0,0 +1,83 @@ +const prisma = require("../utils/prisma"); + +const WorkspaceSuggestedMessages = { + get: async function (clause = {}) { + try { + const message = await prisma.workspace_suggested_messages.findFirst({ + where: clause, + }); + return message || null; + } catch (error) { + console.error(error.message); + return null; + } + }, + + where: async function (clause = {}, limit) { + try { + const messages = await prisma.workspace_suggested_messages.findMany({ + where: clause, + take: limit || undefined, + }); + return messages; + } catch (error) { + console.error(error.message); + return []; + } + }, + + saveAll: async function (messages, workspaceSlug) { + try { + const workspace = await prisma.workspaces.findUnique({ + where: { slug: workspaceSlug }, + }); + + if (!workspace) throw new Error("Workspace not found"); + + // Delete all existing messages for the workspace + await prisma.workspace_suggested_messages.deleteMany({ + where: { workspaceId: workspace.id }, + }); + + // Create new messages + // We create each message individually because prisma + // with sqlite does not support createMany() + for (const message of messages) { + await prisma.workspace_suggested_messages.create({ + data: { + workspaceId: workspace.id, + heading: message.heading, + message: message.message, + }, + }); + } + } catch (error) { + console.error("Failed to save all messages", error.message); + } + }, + + getMessages: async function (workspaceSlug) { + try { + const workspace = await prisma.workspaces.findUnique({ + where: { slug: workspaceSlug }, + }); + + if (!workspace) throw new Error("Workspace not found"); + + const messages = await prisma.workspace_suggested_messages.findMany({ + where: { workspaceId: workspace.id }, + orderBy: { createdAt: "asc" }, + }); + + return messages.map((msg) => ({ + heading: msg.heading, + message: msg.message, + })); + } catch (error) { + console.error("Failed to get all messages", error.message); + return []; + } + }, +}; + +module.exports.WorkspaceSuggestedMessages = WorkspaceSuggestedMessages; diff --git a/server/prisma/migrations/20240206181106_init/migration.sql b/server/prisma/migrations/20240206181106_init/migration.sql new file mode 100644 index 00000000..9655c7b7 --- /dev/null +++ b/server/prisma/migrations/20240206181106_init/migration.sql @@ -0,0 +1,13 @@ +-- CreateTable +CREATE TABLE "workspace_suggested_messages" ( + "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + "workspaceId" INTEGER NOT NULL, + "heading" TEXT NOT NULL, + "message" TEXT NOT NULL, + "createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + "lastUpdatedAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT "workspace_suggested_messages_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "workspaces" ("id") ON DELETE CASCADE ON UPDATE CASCADE +); + +-- CreateIndex +CREATE INDEX "workspace_suggested_messages_workspaceId_idx" ON "workspace_suggested_messages"("workspaceId"); diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma index 314a8359..ede8a1fd 100644 --- a/server/prisma/schema.prisma +++ b/server/prisma/schema.prisma @@ -85,21 +85,34 @@ model welcome_messages { } model workspaces { - id Int @id @default(autoincrement()) - name String - slug String @unique - vectorTag String? - createdAt DateTime @default(now()) - openAiTemp Float? - openAiHistory Int @default(20) - lastUpdatedAt DateTime @default(now()) - openAiPrompt String? - similarityThreshold Float? @default(0.25) - chatModel String? - topN Int? @default(4) - workspace_users workspace_users[] - documents workspace_documents[] - embed_configs embed_configs[] + id Int @id @default(autoincrement()) + name String + slug String @unique + vectorTag String? + createdAt DateTime @default(now()) + openAiTemp Float? + openAiHistory Int @default(20) + lastUpdatedAt DateTime @default(now()) + openAiPrompt String? + similarityThreshold Float? @default(0.25) + chatModel String? + topN Int? @default(4) + workspace_users workspace_users[] + documents workspace_documents[] + workspace_suggested_messages workspace_suggested_messages[] + embed_configs embed_configs[] +} + +model workspace_suggested_messages { + id Int @id @default(autoincrement()) + workspaceId Int + heading String + message String + createdAt DateTime @default(now()) + lastUpdatedAt DateTime @default(now()) + workspace workspaces @relation(fields: [workspaceId], references: [id], onDelete: Cascade) + + @@index([workspaceId]) } model workspace_chats {