[FEAT] RLHF on response messages (#708)

* WIP RLHF works on historical messages

* refactor Actions component

* completed RLHF up and down votes for chats

* add defaults for HistoricalMessage params

* refactor RLHF implmenation
remove forwardRef on history items to prevent rerenders

* remove dup id

* Add rating to CSV output

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
This commit is contained in:
Sean Hatfield 2024-02-13 11:33:05 -08:00 committed by GitHub
parent 1b29882c71
commit f4b09a8c79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 319 additions and 140 deletions

View File

@ -89,6 +89,8 @@ export default function ChatHistory({ settings = {}, history = [] }) {
message={props.content}
role={props.role}
sources={props.sources}
chatId={props.chatId}
feedbackScore={props.feedbackScore}
error={props.error}
/>
);

View File

@ -1,27 +1,91 @@
import React, { memo, useState } from "react";
import useCopyText from "@/hooks/useCopyText";
import { Check, ClipboardText } from "@phosphor-icons/react";
import { memo } from "react";
import {
Check,
ClipboardText,
ThumbsUp,
ThumbsDown,
} from "@phosphor-icons/react";
import { Tooltip } from "react-tooltip";
import Workspace from "@/models/workspace";
const Actions = ({ message, feedbackScore, chatId, slug }) => {
const [selectedFeedback, setSelectedFeedback] = useState(feedbackScore);
const handleFeedback = async (newFeedback) => {
const updatedFeedback =
selectedFeedback === newFeedback ? null : newFeedback;
await Workspace.updateChatFeedback(chatId, slug, updatedFeedback);
setSelectedFeedback(updatedFeedback);
};
const Actions = ({ message }) => {
return (
<div className="flex justify-start items-center gap-x-4">
<CopyMessage message={message} />
{/* Other actions to go here later. */}
{chatId && (
<>
<FeedbackButton
isSelected={selectedFeedback === true}
handleFeedback={() => handleFeedback(true)}
tooltipId={`${chatId}-thumbs-up`}
tooltipContent="Good response"
IconComponent={ThumbsUp}
/>
<FeedbackButton
isSelected={selectedFeedback === false}
handleFeedback={() => handleFeedback(false)}
tooltipId={`${chatId}-thumbs-down`}
tooltipContent="Bad response"
IconComponent={ThumbsDown}
/>
</>
)}
</div>
);
};
function FeedbackButton({
isSelected,
handleFeedback,
tooltipId,
tooltipContent,
IconComponent,
}) {
return (
<div className="mt-3 relative">
<button
onClick={handleFeedback}
data-tooltip-id={tooltipId}
data-tooltip-content={tooltipContent}
className="text-zinc-300"
>
<IconComponent
size={18}
className="mb-1"
weight={isSelected ? "fill" : "regular"}
/>
</button>
<Tooltip
id={tooltipId}
place="bottom"
delayShow={300}
className="tooltip !text-xs"
/>
</div>
);
}
function CopyMessage({ message }) {
const { copied, copyText } = useCopyText();
return (
<>
<div className="mt-3 relative">
<button
onClick={() => copyText(message)}
data-tooltip-id="copy-assistant-text"
data-tooltip-content="Copy"
className="text-zinc-300"
onClick={() => copyText(message)}
>
{copied ? (
<Check size={18} className="mb-1" />
@ -29,13 +93,13 @@ function CopyMessage({ message }) {
<ClipboardText size={18} className="mb-1" />
)}
</button>
<Tooltip
id="copy-assistant-text"
place="bottom"
delayShow={300}
className="tooltip !text-xs"
/>
</div>
<Tooltip
id="copy-assistant-text"
place="bottom"
delayShow={300}
className="tooltip !text-xs"
/>
</>
);
}

View File

@ -1,4 +1,4 @@
import React, { memo, forwardRef } from "react";
import React, { memo } from "react";
import { Warning } from "@phosphor-icons/react";
import Jazzicon from "../../../../UserIcon";
import Actions from "./Actions";
@ -10,64 +10,70 @@ import { v4 } from "uuid";
import createDOMPurify from "dompurify";
const DOMPurify = createDOMPurify(window);
const HistoricalMessage = forwardRef(
(
{ uuid = v4(), message, role, workspace, sources = [], error = false },
ref
) => {
return (
const HistoricalMessage = ({
uuid = v4(),
message,
role,
workspace,
sources = [],
error = false,
feedbackScore = null,
chatId = null,
}) => {
return (
<div
key={uuid}
className={`flex justify-center items-end w-full ${
role === "user" ? USER_BACKGROUND_COLOR : AI_BACKGROUND_COLOR
}`}
>
<div
key={uuid}
ref={ref}
className={`flex justify-center items-end w-full ${
role === "user" ? USER_BACKGROUND_COLOR : AI_BACKGROUND_COLOR
}`}
className={`py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col`}
>
<div
className={`py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col`}
>
<div className="flex gap-x-5">
<Jazzicon
size={36}
user={{
uid:
role === "user"
? userFromStorage()?.username
: workspace.slug,
}}
role={role}
/>
<div className="flex gap-x-5">
<Jazzicon
size={36}
user={{
uid:
role === "user" ? userFromStorage()?.username : workspace.slug,
}}
role={role}
/>
{error ? (
<div className="p-2 rounded-lg bg-red-50 text-red-500">
<span className={`inline-block `}>
<Warning className="h-4 w-4 mb-1 inline-block" /> Could not
respond to message.
</span>
<p className="text-xs font-mono mt-2 border-l-2 border-red-300 pl-2 bg-red-200 p-2 rounded-sm">
{error}
</p>
</div>
) : (
<span
className={`whitespace-pre-line text-white font-normal text-sm md:text-sm flex flex-col gap-y-1 mt-2`}
dangerouslySetInnerHTML={{
__html: DOMPurify.sanitize(renderMarkdown(message)),
}}
/>
)}
</div>
{role === "assistant" && !error && (
<div className="flex gap-x-5">
<div className="relative w-[35px] h-[35px] rounded-full flex-shrink-0 overflow-hidden" />
<Actions message={DOMPurify.sanitize(message)} />
{error ? (
<div className="p-2 rounded-lg bg-red-50 text-red-500">
<span className={`inline-block `}>
<Warning className="h-4 w-4 mb-1 inline-block" /> Could not
respond to message.
</span>
<p className="text-xs font-mono mt-2 border-l-2 border-red-300 pl-2 bg-red-200 p-2 rounded-sm">
{error}
</p>
</div>
) : (
<span
className={`whitespace-pre-line text-white font-normal text-sm md:text-sm flex flex-col gap-y-1 mt-2`}
dangerouslySetInnerHTML={{
__html: DOMPurify.sanitize(renderMarkdown(message)),
}}
/>
)}
{role === "assistant" && <Citations sources={sources} />}
</div>
{role === "assistant" && !error && (
<div className="flex gap-x-5">
<div className="relative w-[35px] h-[35px] rounded-full flex-shrink-0 overflow-hidden" />
<Actions
message={DOMPurify.sanitize(message)}
feedbackScore={feedbackScore}
chatId={chatId}
slug={workspace?.slug}
/>
</div>
)}
{role === "assistant" && <Citations sources={sources} />}
</div>
);
}
);
</div>
);
};
export default memo(HistoricalMessage);

View File

@ -1,67 +1,44 @@
import { forwardRef, memo } from "react";
import { memo } from "react";
import { Warning } from "@phosphor-icons/react";
import Jazzicon from "../../../../UserIcon";
import renderMarkdown from "@/utils/chat/markdown";
import Citations from "../Citation";
const PromptReply = forwardRef(
(
{ uuid, reply, pending, error, workspace, sources = [], closed = true },
ref
) => {
const assistantBackgroundColor = "bg-historical-msg-system";
const PromptReply = ({
uuid,
reply,
pending,
error,
workspace,
sources = [],
closed = true,
}) => {
const assistantBackgroundColor = "bg-historical-msg-system";
if (!reply && sources.length === 0 && !pending && !error) return null;
if (pending) {
return (
<div
ref={ref}
className={`flex justify-center items-end w-full ${assistantBackgroundColor}`}
>
<div className="py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col">
<div className="flex gap-x-5">
<Jazzicon
size={36}
user={{ uid: workspace.slug }}
role="assistant"
/>
<div className="mt-3 ml-5 dot-falling"></div>
</div>
</div>
</div>
);
}
if (error) {
return (
<div
className={`flex justify-center items-end w-full ${assistantBackgroundColor}`}
>
<div className="py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col">
<div className="flex gap-x-5">
<Jazzicon
size={36}
user={{ uid: workspace.slug }}
role="assistant"
/>
<span
className={`inline-block p-2 rounded-lg bg-red-50 text-red-500`}
>
<Warning className="h-4 w-4 mb-1 inline-block" /> Could not
respond to message.
<span className="text-xs">Reason: {error || "unknown"}</span>
</span>
</div>
</div>
</div>
);
}
if (!reply && sources.length === 0 && !pending && !error) return null;
if (pending) {
return (
<div
className={`flex justify-center items-end w-full ${assistantBackgroundColor}`}
>
<div className="py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col">
<div className="flex gap-x-5">
<Jazzicon
size={36}
user={{ uid: workspace.slug }}
role="assistant"
/>
<div className="mt-3 ml-5 dot-falling"></div>
</div>
</div>
</div>
);
}
if (error) {
return (
<div
key={uuid}
ref={ref}
className={`flex justify-center items-end w-full ${assistantBackgroundColor}`}
>
<div className="py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col">
@ -72,15 +49,35 @@ const PromptReply = forwardRef(
role="assistant"
/>
<span
className={`reply whitespace-pre-line text-white font-normal text-sm md:text-sm flex flex-col gap-y-1 mt-2`}
dangerouslySetInnerHTML={{ __html: renderMarkdown(reply) }}
/>
className={`inline-block p-2 rounded-lg bg-red-50 text-red-500`}
>
<Warning className="h-4 w-4 mb-1 inline-block" /> Could not
respond to message.
<span className="text-xs">Reason: {error || "unknown"}</span>
</span>
</div>
<Citations sources={sources} />
</div>
</div>
);
}
);
return (
<div
key={uuid}
className={`flex justify-center items-end w-full ${assistantBackgroundColor}`}
>
<div className="py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col">
<div className="flex gap-x-5">
<Jazzicon size={36} user={{ uid: workspace.slug }} role="assistant" />
<span
className={`reply whitespace-pre-line text-white font-normal text-sm md:text-sm flex flex-col gap-y-1 mt-2`}
dangerouslySetInnerHTML={{ __html: renderMarkdown(reply) }}
/>
</div>
<Citations sources={sources} />
</div>
</div>
);
};
export default memo(PromptReply);

View File

@ -7,7 +7,6 @@ import { ArrowDown } from "@phosphor-icons/react";
import debounce from "lodash.debounce";
export default function ChatHistory({ history = [], workspace, sendCommand }) {
const replyRef = useRef(null);
const { showing, showModal, hideModal } = useManageWorkspaceModal();
const [isAtBottom, setIsAtBottom] = useState(true);
const chatHistoryRef = useRef(null);
@ -89,7 +88,6 @@ export default function ChatHistory({ history = [], workspace, sendCommand }) {
ref={chatHistoryRef}
>
{history.map((props, index) => {
const isLastMessage = index === history.length - 1;
const isLastBotReply =
index === history.length - 1 && props.role === "assistant";
@ -97,7 +95,6 @@ export default function ChatHistory({ history = [], workspace, sendCommand }) {
return (
<PromptReply
key={props.uuid}
ref={isLastMessage ? replyRef : null}
uuid={props.uuid}
reply={props.content}
pending={props.pending}
@ -112,11 +109,12 @@ export default function ChatHistory({ history = [], workspace, sendCommand }) {
return (
<HistoricalMessage
key={index}
ref={isLastMessage ? replyRef : null}
message={props.content}
role={props.role}
workspace={workspace}
sources={props.sources}
feedbackScore={props.feedbackScore}
chatId={props.chatId}
error={props.error}
/>
);

View File

@ -60,6 +60,19 @@ const Workspace = {
.catch(() => []);
return history;
},
updateChatFeedback: async function (chatId, slug, feedback) {
const result = await fetch(
`${API_BASE}/workspace/${slug}/chat-feedback/${chatId}`,
{
method: "POST",
headers: baseHeaders(),
body: JSON.stringify({ feedback }),
}
)
.then((res) => res.ok)
.catch(() => false);
return result;
},
streamChat: async function ({ slug }, message, mode = "query", handleChat) {
const ctrl = new AbortController();
await fetchEventSource(`${API_BASE}/workspace/${slug}/stream-chat`, {

View File

@ -1,4 +1,4 @@
// For handling of synchronous chats that are not utilizing streaming or chat requests.
// For handling of chat responses in the frontend by their various types.
export default function handleChat(
chatResult,
setLoadingResponse,
@ -6,7 +6,15 @@ export default function handleChat(
remHistory,
_chatHistory
) {
const { uuid, textResponse, type, sources = [], error, close } = chatResult;
const {
uuid,
textResponse,
type,
sources = [],
error,
close,
chatId = null,
} = chatResult;
if (type === "abort") {
setLoadingResponse(false);
@ -46,6 +54,7 @@ export default function handleChat(
error,
animate: !close,
pending: false,
chatId,
},
]);
_chatHistory.push({
@ -57,6 +66,7 @@ export default function handleChat(
error,
animate: !close,
pending: false,
chatId,
});
} else if (type === "textResponseChunk") {
const chatIdx = _chatHistory.findIndex((chat) => chat.uuid === uuid);
@ -70,6 +80,7 @@ export default function handleChat(
closed: close,
animate: !close,
pending: false,
chatId,
};
_chatHistory[chatIdx] = updatedHistory;
} else {
@ -82,9 +93,21 @@ export default function handleChat(
closed: close,
animate: !close,
pending: false,
chatId,
});
}
setChatHistory([..._chatHistory]);
} else if (type === "finalizeResponseStream") {
const chatIdx = _chatHistory.findIndex((chat) => chat.uuid === uuid);
if (chatIdx !== -1) {
const existingHistory = { ..._chatHistory[chatIdx] };
const updatedHistory = {
...existingHistory,
chatId, // finalize response stream only has some specific keys for data. we are explicitly listing them here.
};
_chatHistory[chatIdx] = updatedHistory;
}
setChatHistory([..._chatHistory]);
}
}

View File

@ -21,6 +21,7 @@ const { EventLogs } = require("../models/eventLogs");
const {
WorkspaceSuggestedMessages,
} = require("../models/workspacesSuggestedMessages");
const { validWorkspaceSlug } = require("../utils/middleware/validWorkspace");
const { handleUploads } = setupMulter();
function workspaceEndpoints(app) {
@ -321,6 +322,35 @@ function workspaceEndpoints(app) {
}
);
app.post(
"/workspace/:slug/chat-feedback/:chatId",
[validatedRequest, flexUserRoleValid([ROLES.all]), validWorkspaceSlug],
async (request, response) => {
try {
const { chatId } = request.params;
const { feedback = null } = reqBody(request);
const existingChat = await WorkspaceChats.get({
id: Number(chatId),
workspaceId: response.locals.workspace.id,
});
if (!existingChat) {
response.status(404).end();
return;
}
const result = await WorkspaceChats.updateFeedbackScore(
chatId,
feedback
);
response.status(200).json({ success: result });
} catch (error) {
console.error("Error updating chat feedback:", error);
response.status(500).end();
}
}
);
app.get(
"/workspace/:slug/suggested-messages",
[validatedRequest, flexUserRoleValid([ROLES.all])],

View File

@ -203,6 +203,23 @@ const WorkspaceChats = {
return [];
}
},
updateFeedbackScore: async function (chatId = null, feedbackScore = null) {
if (!chatId) return;
try {
await prisma.workspace_chats.update({
where: {
id: Number(chatId),
},
data: {
feedbackScore:
feedbackScore === null ? null : Number(feedbackScore) === 1,
},
});
return;
} catch (error) {
console.error(error.message);
}
},
};
module.exports = { WorkspaceChats };

View File

@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "workspace_chats" ADD COLUMN "feedbackScore" BOOLEAN;

View File

@ -142,6 +142,7 @@ model workspace_chats {
thread_id Int? // No relation to prevent whole table migration
createdAt DateTime @default(now())
lastUpdatedAt DateTime @default(now())
feedbackScore Boolean?
users users? @relation(fields: [user_id], references: [id], onDelete: Cascade, onUpdate: Cascade)
}

View File

@ -7,7 +7,7 @@ const { getVectorDbClass, getLLMProvider } = require("../helpers");
function convertToChatHistory(history = []) {
const formattedHistory = [];
history.forEach((history) => {
const { prompt, response, createdAt } = history;
const { prompt, response, createdAt, feedbackScore = null, id } = history;
const data = JSON.parse(response);
formattedHistory.push([
{
@ -19,7 +19,9 @@ function convertToChatHistory(history = []) {
role: "assistant",
content: data.text,
sources: data.sources || [],
chatId: id,
sentAt: moment(createdAt).unix(),
feedbackScore,
},
]);
});
@ -185,8 +187,7 @@ async function chatWithWorkspace(
error: "No text completion could be completed with this input.",
};
}
await WorkspaceChats.new({
const { chat } = await WorkspaceChats.new({
workspaceId: workspace.id,
prompt: message,
response: { text: textResponse, sources, type: chatMode },
@ -196,9 +197,10 @@ async function chatWithWorkspace(
id: uuid,
type: "textResponse",
close: true,
error: null,
chatId: chat.id,
textResponse,
sources,
error,
};
}
@ -271,7 +273,7 @@ async function emptyEmbeddingChat({
workspace,
rawHistory
);
await WorkspaceChats.new({
const { chat } = await WorkspaceChats.new({
workspaceId: workspace.id,
prompt: message,
response: { text: textResponse, sources: [], type: "chat" },
@ -283,6 +285,7 @@ async function emptyEmbeddingChat({
sources: [],
close: true,
error: null,
chatId: chat.id,
textResponse,
};
}

View File

@ -177,12 +177,20 @@ async function streamChatWithWorkspace(
});
}
await WorkspaceChats.new({
const { chat } = await WorkspaceChats.new({
workspaceId: workspace.id,
prompt: message,
response: { text: completeText, sources, type: chatMode },
user,
threadId: thread?.id,
user,
});
writeResponseChunk(response, {
uuid,
type: "finalizeResponseStream",
close: true,
error: false,
chatId: chat.id,
});
return;
}
@ -235,12 +243,20 @@ async function streamEmptyEmbeddingChat({
});
}
await WorkspaceChats.new({
const { chat } = await WorkspaceChats.new({
workspaceId: workspace.id,
prompt: message,
response: { text: completeText, sources: [], type: "chat" },
user,
threadId: thread?.id,
user,
});
writeResponseChunk(response, {
uuid,
type: "finalizeResponseStream",
close: true,
error: false,
chatId: chat.id,
});
return;
}

View File

@ -6,7 +6,7 @@ const { WorkspaceChats } = require("../../../models/workspaceChats");
// Todo: add RLHF feedbackScore field support
async function convertToCSV(preparedData) {
const rows = ["id,username,workspace,prompt,response,sent_at"];
const rows = ["id,username,workspace,prompt,response,sent_at,rating"];
for (const item of preparedData) {
const record = [
item.id,
@ -15,6 +15,7 @@ async function convertToCSV(preparedData) {
escapeCsv(item.prompt),
escapeCsv(item.response),
item.sent_at,
item.feedback,
].join(",");
rows.push(record);
}
@ -53,6 +54,12 @@ async function prepareWorkspaceChatsForExport(format = "jsonl") {
prompt: chat.prompt,
response: responseJson.text,
sent_at: chat.createdAt,
feedback:
chat.feedbackScore === null
? "--"
: chat.feedbackScore
? "GOOD"
: "BAD",
};
});