[FEAT] RLHF on response messages (#708)

* WIP RLHF works on historical messages

* refactor Actions component

* completed RLHF up and down votes for chats

* add defaults for HistoricalMessage params

* refactor RLHF implmenation
remove forwardRef on history items to prevent rerenders

* remove dup id

* Add rating to CSV output

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
This commit is contained in:
Sean Hatfield 2024-02-13 11:33:05 -08:00 committed by GitHub
parent 1b29882c71
commit f4b09a8c79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 319 additions and 140 deletions

View File

@ -89,6 +89,8 @@ export default function ChatHistory({ settings = {}, history = [] }) {
message={props.content} message={props.content}
role={props.role} role={props.role}
sources={props.sources} sources={props.sources}
chatId={props.chatId}
feedbackScore={props.feedbackScore}
error={props.error} error={props.error}
/> />
); );

View File

@ -1,27 +1,91 @@
import React, { memo, useState } from "react";
import useCopyText from "@/hooks/useCopyText"; import useCopyText from "@/hooks/useCopyText";
import { Check, ClipboardText } from "@phosphor-icons/react"; import {
import { memo } from "react"; Check,
ClipboardText,
ThumbsUp,
ThumbsDown,
} from "@phosphor-icons/react";
import { Tooltip } from "react-tooltip"; import { Tooltip } from "react-tooltip";
import Workspace from "@/models/workspace";
const Actions = ({ message, feedbackScore, chatId, slug }) => {
const [selectedFeedback, setSelectedFeedback] = useState(feedbackScore);
const handleFeedback = async (newFeedback) => {
const updatedFeedback =
selectedFeedback === newFeedback ? null : newFeedback;
await Workspace.updateChatFeedback(chatId, slug, updatedFeedback);
setSelectedFeedback(updatedFeedback);
};
const Actions = ({ message }) => {
return ( return (
<div className="flex justify-start items-center gap-x-4"> <div className="flex justify-start items-center gap-x-4">
<CopyMessage message={message} /> <CopyMessage message={message} />
{/* Other actions to go here later. */} {chatId && (
<>
<FeedbackButton
isSelected={selectedFeedback === true}
handleFeedback={() => handleFeedback(true)}
tooltipId={`${chatId}-thumbs-up`}
tooltipContent="Good response"
IconComponent={ThumbsUp}
/>
<FeedbackButton
isSelected={selectedFeedback === false}
handleFeedback={() => handleFeedback(false)}
tooltipId={`${chatId}-thumbs-down`}
tooltipContent="Bad response"
IconComponent={ThumbsDown}
/>
</>
)}
</div> </div>
); );
}; };
function FeedbackButton({
isSelected,
handleFeedback,
tooltipId,
tooltipContent,
IconComponent,
}) {
return (
<div className="mt-3 relative">
<button
onClick={handleFeedback}
data-tooltip-id={tooltipId}
data-tooltip-content={tooltipContent}
className="text-zinc-300"
>
<IconComponent
size={18}
className="mb-1"
weight={isSelected ? "fill" : "regular"}
/>
</button>
<Tooltip
id={tooltipId}
place="bottom"
delayShow={300}
className="tooltip !text-xs"
/>
</div>
);
}
function CopyMessage({ message }) { function CopyMessage({ message }) {
const { copied, copyText } = useCopyText(); const { copied, copyText } = useCopyText();
return ( return (
<> <>
<div className="mt-3 relative"> <div className="mt-3 relative">
<button <button
onClick={() => copyText(message)}
data-tooltip-id="copy-assistant-text" data-tooltip-id="copy-assistant-text"
data-tooltip-content="Copy" data-tooltip-content="Copy"
className="text-zinc-300" className="text-zinc-300"
onClick={() => copyText(message)}
> >
{copied ? ( {copied ? (
<Check size={18} className="mb-1" /> <Check size={18} className="mb-1" />
@ -29,13 +93,13 @@ function CopyMessage({ message }) {
<ClipboardText size={18} className="mb-1" /> <ClipboardText size={18} className="mb-1" />
)} )}
</button> </button>
</div>
<Tooltip <Tooltip
id="copy-assistant-text" id="copy-assistant-text"
place="bottom" place="bottom"
delayShow={300} delayShow={300}
className="tooltip !text-xs" className="tooltip !text-xs"
/> />
</div>
</> </>
); );
} }

View File

@ -1,4 +1,4 @@
import React, { memo, forwardRef } from "react"; import React, { memo } from "react";
import { Warning } from "@phosphor-icons/react"; import { Warning } from "@phosphor-icons/react";
import Jazzicon from "../../../../UserIcon"; import Jazzicon from "../../../../UserIcon";
import Actions from "./Actions"; import Actions from "./Actions";
@ -10,15 +10,19 @@ import { v4 } from "uuid";
import createDOMPurify from "dompurify"; import createDOMPurify from "dompurify";
const DOMPurify = createDOMPurify(window); const DOMPurify = createDOMPurify(window);
const HistoricalMessage = forwardRef( const HistoricalMessage = ({
( uuid = v4(),
{ uuid = v4(), message, role, workspace, sources = [], error = false }, message,
ref role,
) => { workspace,
sources = [],
error = false,
feedbackScore = null,
chatId = null,
}) => {
return ( return (
<div <div
key={uuid} key={uuid}
ref={ref}
className={`flex justify-center items-end w-full ${ className={`flex justify-center items-end w-full ${
role === "user" ? USER_BACKGROUND_COLOR : AI_BACKGROUND_COLOR role === "user" ? USER_BACKGROUND_COLOR : AI_BACKGROUND_COLOR
}`} }`}
@ -31,9 +35,7 @@ const HistoricalMessage = forwardRef(
size={36} size={36}
user={{ user={{
uid: uid:
role === "user" role === "user" ? userFromStorage()?.username : workspace.slug,
? userFromStorage()?.username
: workspace.slug,
}} }}
role={role} role={role}
/> />
@ -60,14 +62,18 @@ const HistoricalMessage = forwardRef(
{role === "assistant" && !error && ( {role === "assistant" && !error && (
<div className="flex gap-x-5"> <div className="flex gap-x-5">
<div className="relative w-[35px] h-[35px] rounded-full flex-shrink-0 overflow-hidden" /> <div className="relative w-[35px] h-[35px] rounded-full flex-shrink-0 overflow-hidden" />
<Actions message={DOMPurify.sanitize(message)} /> <Actions
message={DOMPurify.sanitize(message)}
feedbackScore={feedbackScore}
chatId={chatId}
slug={workspace?.slug}
/>
</div> </div>
)} )}
{role === "assistant" && <Citations sources={sources} />} {role === "assistant" && <Citations sources={sources} />}
</div> </div>
</div> </div>
); );
} };
);
export default memo(HistoricalMessage); export default memo(HistoricalMessage);

View File

@ -1,14 +1,18 @@
import { forwardRef, memo } from "react"; import { memo } from "react";
import { Warning } from "@phosphor-icons/react"; import { Warning } from "@phosphor-icons/react";
import Jazzicon from "../../../../UserIcon"; import Jazzicon from "../../../../UserIcon";
import renderMarkdown from "@/utils/chat/markdown"; import renderMarkdown from "@/utils/chat/markdown";
import Citations from "../Citation"; import Citations from "../Citation";
const PromptReply = forwardRef( const PromptReply = ({
( uuid,
{ uuid, reply, pending, error, workspace, sources = [], closed = true }, reply,
ref pending,
) => { error,
workspace,
sources = [],
closed = true,
}) => {
const assistantBackgroundColor = "bg-historical-msg-system"; const assistantBackgroundColor = "bg-historical-msg-system";
if (!reply && sources.length === 0 && !pending && !error) return null; if (!reply && sources.length === 0 && !pending && !error) return null;
@ -16,7 +20,6 @@ const PromptReply = forwardRef(
if (pending) { if (pending) {
return ( return (
<div <div
ref={ref}
className={`flex justify-center items-end w-full ${assistantBackgroundColor}`} className={`flex justify-center items-end w-full ${assistantBackgroundColor}`}
> >
<div className="py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col"> <div className="py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col">
@ -61,16 +64,11 @@ const PromptReply = forwardRef(
return ( return (
<div <div
key={uuid} key={uuid}
ref={ref}
className={`flex justify-center items-end w-full ${assistantBackgroundColor}`} className={`flex justify-center items-end w-full ${assistantBackgroundColor}`}
> >
<div className="py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col"> <div className="py-8 px-4 w-full flex gap-x-5 md:max-w-[800px] flex-col">
<div className="flex gap-x-5"> <div className="flex gap-x-5">
<Jazzicon <Jazzicon size={36} user={{ uid: workspace.slug }} role="assistant" />
size={36}
user={{ uid: workspace.slug }}
role="assistant"
/>
<span <span
className={`reply whitespace-pre-line text-white font-normal text-sm md:text-sm flex flex-col gap-y-1 mt-2`} className={`reply whitespace-pre-line text-white font-normal text-sm md:text-sm flex flex-col gap-y-1 mt-2`}
dangerouslySetInnerHTML={{ __html: renderMarkdown(reply) }} dangerouslySetInnerHTML={{ __html: renderMarkdown(reply) }}
@ -80,7 +78,6 @@ const PromptReply = forwardRef(
</div> </div>
</div> </div>
); );
} };
);
export default memo(PromptReply); export default memo(PromptReply);

View File

@ -7,7 +7,6 @@ import { ArrowDown } from "@phosphor-icons/react";
import debounce from "lodash.debounce"; import debounce from "lodash.debounce";
export default function ChatHistory({ history = [], workspace, sendCommand }) { export default function ChatHistory({ history = [], workspace, sendCommand }) {
const replyRef = useRef(null);
const { showing, showModal, hideModal } = useManageWorkspaceModal(); const { showing, showModal, hideModal } = useManageWorkspaceModal();
const [isAtBottom, setIsAtBottom] = useState(true); const [isAtBottom, setIsAtBottom] = useState(true);
const chatHistoryRef = useRef(null); const chatHistoryRef = useRef(null);
@ -89,7 +88,6 @@ export default function ChatHistory({ history = [], workspace, sendCommand }) {
ref={chatHistoryRef} ref={chatHistoryRef}
> >
{history.map((props, index) => { {history.map((props, index) => {
const isLastMessage = index === history.length - 1;
const isLastBotReply = const isLastBotReply =
index === history.length - 1 && props.role === "assistant"; index === history.length - 1 && props.role === "assistant";
@ -97,7 +95,6 @@ export default function ChatHistory({ history = [], workspace, sendCommand }) {
return ( return (
<PromptReply <PromptReply
key={props.uuid} key={props.uuid}
ref={isLastMessage ? replyRef : null}
uuid={props.uuid} uuid={props.uuid}
reply={props.content} reply={props.content}
pending={props.pending} pending={props.pending}
@ -112,11 +109,12 @@ export default function ChatHistory({ history = [], workspace, sendCommand }) {
return ( return (
<HistoricalMessage <HistoricalMessage
key={index} key={index}
ref={isLastMessage ? replyRef : null}
message={props.content} message={props.content}
role={props.role} role={props.role}
workspace={workspace} workspace={workspace}
sources={props.sources} sources={props.sources}
feedbackScore={props.feedbackScore}
chatId={props.chatId}
error={props.error} error={props.error}
/> />
); );

View File

@ -60,6 +60,19 @@ const Workspace = {
.catch(() => []); .catch(() => []);
return history; return history;
}, },
updateChatFeedback: async function (chatId, slug, feedback) {
const result = await fetch(
`${API_BASE}/workspace/${slug}/chat-feedback/${chatId}`,
{
method: "POST",
headers: baseHeaders(),
body: JSON.stringify({ feedback }),
}
)
.then((res) => res.ok)
.catch(() => false);
return result;
},
streamChat: async function ({ slug }, message, mode = "query", handleChat) { streamChat: async function ({ slug }, message, mode = "query", handleChat) {
const ctrl = new AbortController(); const ctrl = new AbortController();
await fetchEventSource(`${API_BASE}/workspace/${slug}/stream-chat`, { await fetchEventSource(`${API_BASE}/workspace/${slug}/stream-chat`, {

View File

@ -1,4 +1,4 @@
// For handling of synchronous chats that are not utilizing streaming or chat requests. // For handling of chat responses in the frontend by their various types.
export default function handleChat( export default function handleChat(
chatResult, chatResult,
setLoadingResponse, setLoadingResponse,
@ -6,7 +6,15 @@ export default function handleChat(
remHistory, remHistory,
_chatHistory _chatHistory
) { ) {
const { uuid, textResponse, type, sources = [], error, close } = chatResult; const {
uuid,
textResponse,
type,
sources = [],
error,
close,
chatId = null,
} = chatResult;
if (type === "abort") { if (type === "abort") {
setLoadingResponse(false); setLoadingResponse(false);
@ -46,6 +54,7 @@ export default function handleChat(
error, error,
animate: !close, animate: !close,
pending: false, pending: false,
chatId,
}, },
]); ]);
_chatHistory.push({ _chatHistory.push({
@ -57,6 +66,7 @@ export default function handleChat(
error, error,
animate: !close, animate: !close,
pending: false, pending: false,
chatId,
}); });
} else if (type === "textResponseChunk") { } else if (type === "textResponseChunk") {
const chatIdx = _chatHistory.findIndex((chat) => chat.uuid === uuid); const chatIdx = _chatHistory.findIndex((chat) => chat.uuid === uuid);
@ -70,6 +80,7 @@ export default function handleChat(
closed: close, closed: close,
animate: !close, animate: !close,
pending: false, pending: false,
chatId,
}; };
_chatHistory[chatIdx] = updatedHistory; _chatHistory[chatIdx] = updatedHistory;
} else { } else {
@ -82,9 +93,21 @@ export default function handleChat(
closed: close, closed: close,
animate: !close, animate: !close,
pending: false, pending: false,
chatId,
}); });
} }
setChatHistory([..._chatHistory]); setChatHistory([..._chatHistory]);
} else if (type === "finalizeResponseStream") {
const chatIdx = _chatHistory.findIndex((chat) => chat.uuid === uuid);
if (chatIdx !== -1) {
const existingHistory = { ..._chatHistory[chatIdx] };
const updatedHistory = {
...existingHistory,
chatId, // finalize response stream only has some specific keys for data. we are explicitly listing them here.
};
_chatHistory[chatIdx] = updatedHistory;
}
setChatHistory([..._chatHistory]);
} }
} }

View File

@ -21,6 +21,7 @@ const { EventLogs } = require("../models/eventLogs");
const { const {
WorkspaceSuggestedMessages, WorkspaceSuggestedMessages,
} = require("../models/workspacesSuggestedMessages"); } = require("../models/workspacesSuggestedMessages");
const { validWorkspaceSlug } = require("../utils/middleware/validWorkspace");
const { handleUploads } = setupMulter(); const { handleUploads } = setupMulter();
function workspaceEndpoints(app) { function workspaceEndpoints(app) {
@ -321,6 +322,35 @@ function workspaceEndpoints(app) {
} }
); );
app.post(
"/workspace/:slug/chat-feedback/:chatId",
[validatedRequest, flexUserRoleValid([ROLES.all]), validWorkspaceSlug],
async (request, response) => {
try {
const { chatId } = request.params;
const { feedback = null } = reqBody(request);
const existingChat = await WorkspaceChats.get({
id: Number(chatId),
workspaceId: response.locals.workspace.id,
});
if (!existingChat) {
response.status(404).end();
return;
}
const result = await WorkspaceChats.updateFeedbackScore(
chatId,
feedback
);
response.status(200).json({ success: result });
} catch (error) {
console.error("Error updating chat feedback:", error);
response.status(500).end();
}
}
);
app.get( app.get(
"/workspace/:slug/suggested-messages", "/workspace/:slug/suggested-messages",
[validatedRequest, flexUserRoleValid([ROLES.all])], [validatedRequest, flexUserRoleValid([ROLES.all])],

View File

@ -203,6 +203,23 @@ const WorkspaceChats = {
return []; return [];
} }
}, },
updateFeedbackScore: async function (chatId = null, feedbackScore = null) {
if (!chatId) return;
try {
await prisma.workspace_chats.update({
where: {
id: Number(chatId),
},
data: {
feedbackScore:
feedbackScore === null ? null : Number(feedbackScore) === 1,
},
});
return;
} catch (error) {
console.error(error.message);
}
},
}; };
module.exports = { WorkspaceChats }; module.exports = { WorkspaceChats };

View File

@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "workspace_chats" ADD COLUMN "feedbackScore" BOOLEAN;

View File

@ -142,6 +142,7 @@ model workspace_chats {
thread_id Int? // No relation to prevent whole table migration thread_id Int? // No relation to prevent whole table migration
createdAt DateTime @default(now()) createdAt DateTime @default(now())
lastUpdatedAt DateTime @default(now()) lastUpdatedAt DateTime @default(now())
feedbackScore Boolean?
users users? @relation(fields: [user_id], references: [id], onDelete: Cascade, onUpdate: Cascade) users users? @relation(fields: [user_id], references: [id], onDelete: Cascade, onUpdate: Cascade)
} }

View File

@ -7,7 +7,7 @@ const { getVectorDbClass, getLLMProvider } = require("../helpers");
function convertToChatHistory(history = []) { function convertToChatHistory(history = []) {
const formattedHistory = []; const formattedHistory = [];
history.forEach((history) => { history.forEach((history) => {
const { prompt, response, createdAt } = history; const { prompt, response, createdAt, feedbackScore = null, id } = history;
const data = JSON.parse(response); const data = JSON.parse(response);
formattedHistory.push([ formattedHistory.push([
{ {
@ -19,7 +19,9 @@ function convertToChatHistory(history = []) {
role: "assistant", role: "assistant",
content: data.text, content: data.text,
sources: data.sources || [], sources: data.sources || [],
chatId: id,
sentAt: moment(createdAt).unix(), sentAt: moment(createdAt).unix(),
feedbackScore,
}, },
]); ]);
}); });
@ -185,8 +187,7 @@ async function chatWithWorkspace(
error: "No text completion could be completed with this input.", error: "No text completion could be completed with this input.",
}; };
} }
const { chat } = await WorkspaceChats.new({
await WorkspaceChats.new({
workspaceId: workspace.id, workspaceId: workspace.id,
prompt: message, prompt: message,
response: { text: textResponse, sources, type: chatMode }, response: { text: textResponse, sources, type: chatMode },
@ -196,9 +197,10 @@ async function chatWithWorkspace(
id: uuid, id: uuid,
type: "textResponse", type: "textResponse",
close: true, close: true,
error: null,
chatId: chat.id,
textResponse, textResponse,
sources, sources,
error,
}; };
} }
@ -271,7 +273,7 @@ async function emptyEmbeddingChat({
workspace, workspace,
rawHistory rawHistory
); );
await WorkspaceChats.new({ const { chat } = await WorkspaceChats.new({
workspaceId: workspace.id, workspaceId: workspace.id,
prompt: message, prompt: message,
response: { text: textResponse, sources: [], type: "chat" }, response: { text: textResponse, sources: [], type: "chat" },
@ -283,6 +285,7 @@ async function emptyEmbeddingChat({
sources: [], sources: [],
close: true, close: true,
error: null, error: null,
chatId: chat.id,
textResponse, textResponse,
}; };
} }

View File

@ -177,12 +177,20 @@ async function streamChatWithWorkspace(
}); });
} }
await WorkspaceChats.new({ const { chat } = await WorkspaceChats.new({
workspaceId: workspace.id, workspaceId: workspace.id,
prompt: message, prompt: message,
response: { text: completeText, sources, type: chatMode }, response: { text: completeText, sources, type: chatMode },
user,
threadId: thread?.id, threadId: thread?.id,
user,
});
writeResponseChunk(response, {
uuid,
type: "finalizeResponseStream",
close: true,
error: false,
chatId: chat.id,
}); });
return; return;
} }
@ -235,12 +243,20 @@ async function streamEmptyEmbeddingChat({
}); });
} }
await WorkspaceChats.new({ const { chat } = await WorkspaceChats.new({
workspaceId: workspace.id, workspaceId: workspace.id,
prompt: message, prompt: message,
response: { text: completeText, sources: [], type: "chat" }, response: { text: completeText, sources: [], type: "chat" },
user,
threadId: thread?.id, threadId: thread?.id,
user,
});
writeResponseChunk(response, {
uuid,
type: "finalizeResponseStream",
close: true,
error: false,
chatId: chat.id,
}); });
return; return;
} }

View File

@ -6,7 +6,7 @@ const { WorkspaceChats } = require("../../../models/workspaceChats");
// Todo: add RLHF feedbackScore field support // Todo: add RLHF feedbackScore field support
async function convertToCSV(preparedData) { async function convertToCSV(preparedData) {
const rows = ["id,username,workspace,prompt,response,sent_at"]; const rows = ["id,username,workspace,prompt,response,sent_at,rating"];
for (const item of preparedData) { for (const item of preparedData) {
const record = [ const record = [
item.id, item.id,
@ -15,6 +15,7 @@ async function convertToCSV(preparedData) {
escapeCsv(item.prompt), escapeCsv(item.prompt),
escapeCsv(item.response), escapeCsv(item.response),
item.sent_at, item.sent_at,
item.feedback,
].join(","); ].join(",");
rows.push(record); rows.push(record);
} }
@ -53,6 +54,12 @@ async function prepareWorkspaceChatsForExport(format = "jsonl") {
prompt: chat.prompt, prompt: chat.prompt,
response: responseJson.text, response: responseJson.text,
sent_at: chat.createdAt, sent_at: chat.createdAt,
feedback:
chat.feedbackScore === null
? "--"
: chat.feedbackScore
? "GOOD"
: "BAD",
}; };
}); });