From 0e46a11cb65844640111c44d7588ef7d3861ffba Mon Sep 17 00:00:00 2001 From: Timothy Carambat Date: Tue, 12 Mar 2024 15:21:27 -0700 Subject: [PATCH 1/6] Stop generation button during stream-response (#892) * Stop generation button during stream-response * add custom stop icon * add stop to thread chats --- .../StopGenerationButton/index.jsx | 50 +++++++++++++++++++ .../PromptInput/StopGenerationButton/stop.svg | 4 ++ .../ChatContainer/PromptInput/index.jsx | 26 +++++----- .../WorkspaceChat/ChatContainer/index.jsx | 6 +-- frontend/src/models/workspace.js | 11 ++++ frontend/src/models/workspaceThread.js | 11 ++++ frontend/src/utils/chat/index.js | 18 +++++++ server/utils/AiProviders/anthropic/index.js | 13 ++++- server/utils/AiProviders/azureOpenAi/index.js | 14 +++++- server/utils/AiProviders/gemini/index.js | 14 +++++- server/utils/AiProviders/huggingface/index.js | 16 +++++- server/utils/AiProviders/native/index.js | 14 +++++- server/utils/AiProviders/ollama/index.js | 17 ++++++- server/utils/AiProviders/openRouter/index.js | 15 +++++- server/utils/AiProviders/togetherAi/index.js | 15 +++++- server/utils/helpers/chat/responses.js | 19 +++++++ 16 files changed, 236 insertions(+), 27 deletions(-) create mode 100644 frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/StopGenerationButton/index.jsx create mode 100644 frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/StopGenerationButton/stop.svg diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/StopGenerationButton/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/StopGenerationButton/index.jsx new file mode 100644 index 000000000..09a7c2ceb --- /dev/null +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/StopGenerationButton/index.jsx @@ -0,0 +1,50 @@ +import { ABORT_STREAM_EVENT } from "@/utils/chat"; +import { Tooltip } from "react-tooltip"; + +export default function StopGenerationButton() { + function emitHaltEvent() { + window.dispatchEvent(new CustomEvent(ABORT_STREAM_EVENT)); + } + + return ( + <> + + + + ); +} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/StopGenerationButton/stop.svg b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/StopGenerationButton/stop.svg new file mode 100644 index 000000000..ab98895c2 --- /dev/null +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/StopGenerationButton/stop.svg @@ -0,0 +1,4 @@ + + + + diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx index 2b9c5ca4f..52e870123 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/index.jsx @@ -1,4 +1,3 @@ -import { CircleNotch, PaperPlaneRight } from "@phosphor-icons/react"; import React, { useState, useRef } from "react"; import SlashCommandsButton, { SlashCommands, @@ -6,6 +5,8 @@ import SlashCommandsButton, { } from "./SlashCommands"; import { isMobile } from "react-device-detect"; import debounce from "lodash.debounce"; +import { PaperPlaneRight } from "@phosphor-icons/react"; +import StopGenerationButton from "./StopGenerationButton"; export default function PromptInput({ workspace, @@ -83,19 +84,18 @@ export default function PromptInput({ className="cursor-text max-h-[100px] md:min-h-[40px] mx-2 md:mx-0 py-2 w-full text-[16px] md:text-md text-white bg-transparent 
placeholder:text-white/60 resize-none active:outline-none focus:outline-none flex-grow" placeholder={"Send a message"} /> - + Send message + + )}
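Taken together, the frontend half of this patch is a small event relay: the prompt input swaps its send button for the new StopGenerationButton while a response is streaming, that button only broadcasts a window-level ABORT_STREAM_EVENT, and the streaming model layer (workspace.js and workspaceThread.js, further down) owns the AbortController for the SSE request and turns the event into an abort plus a synthetic "stopGeneration" chat update. A condensed sketch of that flow, using the names from this patch; `attachAbortListener` and `onAbort` are illustrative wrappers of my own, and `crypto.randomUUID()` stands in for the `uuid` package's `v4()`:

```js
// Sketch of the client-side abort flow added in this patch (simplified).
export const ABORT_STREAM_EVENT = "abort-chat-stream";

// 1) StopGenerationButton: broadcast the halt request.
export function emitHaltEvent() {
  window.dispatchEvent(new CustomEvent(ABORT_STREAM_EVENT));
}

// 2) Workspace.streamChat / WorkspaceThread.streamChat: convert the event into
//    an abort of the in-flight stream plus a "stopGeneration" chat update.
export function attachAbortListener(ctrl, handleChat) {
  const onAbort = () => {
    ctrl.abort(); // cancels fetchEventSource via its AbortSignal
    handleChat({ id: crypto.randomUUID(), type: "stopGeneration" });
  };
  window.addEventListener(ABORT_STREAM_EVENT, onAbort);
  return () => window.removeEventListener(ABORT_STREAM_EVENT, onAbort);
}

// 3) handleChat's new "stopGeneration" branch (utils/chat, below) then closes
//    the last history entry (animate/pending -> false) and flips
//    setLoadingResponse(false), which swaps the stop button back to send.
```

On the server, the matching half is the `clientAbortedHandler` each provider now registers with `response.on("close", ...)` (see the responses.js and AiProviders hunks below), so a dropped client connection resolves the stream promise with whatever text was already generated instead of hanging.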
diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx index 293da491f..209fed5d6 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx @@ -68,11 +68,7 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { const remHistory = chatHistory.length > 0 ? chatHistory.slice(0, -1) : []; var _chatHistory = [...remHistory]; - if (!promptMessage || !promptMessage?.userMessage) { - setLoadingResponse(false); - return false; - } - + if (!promptMessage || !promptMessage?.userMessage) return false; if (!!threadSlug) { await Workspace.threads.streamChat( { workspaceSlug: workspace.slug, threadSlug }, diff --git a/frontend/src/models/workspace.js b/frontend/src/models/workspace.js index 6786abffd..ae2cd5590 100644 --- a/frontend/src/models/workspace.js +++ b/frontend/src/models/workspace.js @@ -3,6 +3,7 @@ import { baseHeaders } from "@/utils/request"; import { fetchEventSource } from "@microsoft/fetch-event-source"; import WorkspaceThread from "@/models/workspaceThread"; import { v4 } from "uuid"; +import { ABORT_STREAM_EVENT } from "@/utils/chat"; const Workspace = { new: async function (data = {}) { @@ -75,6 +76,16 @@ const Workspace = { }, streamChat: async function ({ slug }, message, handleChat) { const ctrl = new AbortController(); + + // Listen for the ABORT_STREAM_EVENT key to be emitted by the client + // to early abort the streaming response. On abort we send a special `stopGeneration` + // event to be handled which resets the UI for us to be able to send another message. + // The backend response abort handling is done in each LLM's handleStreamResponse. + window.addEventListener(ABORT_STREAM_EVENT, () => { + ctrl.abort(); + handleChat({ id: v4(), type: "stopGeneration" }); + }); + await fetchEventSource(`${API_BASE}/workspace/${slug}/stream-chat`, { method: "POST", body: JSON.stringify({ message }), diff --git a/frontend/src/models/workspaceThread.js b/frontend/src/models/workspaceThread.js index f9fad3173..b1bcaf644 100644 --- a/frontend/src/models/workspaceThread.js +++ b/frontend/src/models/workspaceThread.js @@ -1,3 +1,4 @@ +import { ABORT_STREAM_EVENT } from "@/utils/chat"; import { API_BASE } from "@/utils/constants"; import { baseHeaders } from "@/utils/request"; import { fetchEventSource } from "@microsoft/fetch-event-source"; @@ -80,6 +81,16 @@ const WorkspaceThread = { handleChat ) { const ctrl = new AbortController(); + + // Listen for the ABORT_STREAM_EVENT key to be emitted by the client + // to early abort the streaming response. On abort we send a special `stopGeneration` + // event to be handled which resets the UI for us to be able to send another message. + // The backend response abort handling is done in each LLM's handleStreamResponse. + window.addEventListener(ABORT_STREAM_EVENT, () => { + ctrl.abort(); + handleChat({ id: v4(), type: "stopGeneration" }); + }); + await fetchEventSource( `${API_BASE}/workspace/${workspaceSlug}/thread/${threadSlug}/stream-chat`, { diff --git a/frontend/src/utils/chat/index.js b/frontend/src/utils/chat/index.js index f1df11fea..37237c9ec 100644 --- a/frontend/src/utils/chat/index.js +++ b/frontend/src/utils/chat/index.js @@ -1,3 +1,5 @@ +export const ABORT_STREAM_EVENT = "abort-chat-stream"; + // For handling of chat responses in the frontend by their various types. 
export default function handleChat( chatResult, @@ -108,6 +110,22 @@ export default function handleChat( _chatHistory[chatIdx] = updatedHistory; } setChatHistory([..._chatHistory]); + setLoadingResponse(false); + } else if (type === "stopGeneration") { + const chatIdx = _chatHistory.length - 1; + const existingHistory = { ..._chatHistory[chatIdx] }; + const updatedHistory = { + ...existingHistory, + sources: [], + closed: true, + error: null, + animate: false, + pending: false, + }; + _chatHistory[chatIdx] = updatedHistory; + + setChatHistory([..._chatHistory]); + setLoadingResponse(false); } } diff --git a/server/utils/AiProviders/anthropic/index.js b/server/utils/AiProviders/anthropic/index.js index a48058e81..fea083329 100644 --- a/server/utils/AiProviders/anthropic/index.js +++ b/server/utils/AiProviders/anthropic/index.js @@ -1,6 +1,9 @@ const { v4 } = require("uuid"); const { chatPrompt } = require("../../chats"); -const { writeResponseChunk } = require("../../helpers/chat/responses"); +const { + writeResponseChunk, + clientAbortedHandler, +} = require("../../helpers/chat/responses"); class AnthropicLLM { constructor(embedder = null, modelPreference = null) { if (!process.env.ANTHROPIC_API_KEY) @@ -150,6 +153,13 @@ class AnthropicLLM { let fullText = ""; const { uuid = v4(), sources = [] } = responseProps; + // Establish listener to early-abort a streaming response + // in case things go sideways or the user does not like the response. + // We preserve the generated text but continue as if chat was completed + // to preserve previously generated content. + const handleAbort = () => clientAbortedHandler(resolve, fullText); + response.on("close", handleAbort); + stream.on("streamEvent", (message) => { const data = message; if ( @@ -181,6 +191,7 @@ class AnthropicLLM { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); } }); diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js index 2ac6de3a1..21fc5cd91 100644 --- a/server/utils/AiProviders/azureOpenAi/index.js +++ b/server/utils/AiProviders/azureOpenAi/index.js @@ -1,6 +1,9 @@ const { AzureOpenAiEmbedder } = require("../../EmbeddingEngines/azureOpenAi"); const { chatPrompt } = require("../../chats"); -const { writeResponseChunk } = require("../../helpers/chat/responses"); +const { + writeResponseChunk, + clientAbortedHandler, +} = require("../../helpers/chat/responses"); class AzureOpenAiLLM { constructor(embedder = null, _modelPreference = null) { @@ -174,6 +177,14 @@ class AzureOpenAiLLM { return new Promise(async (resolve) => { let fullText = ""; + + // Establish listener to early-abort a streaming response + // in case things go sideways or the user does not like the response. + // We preserve the generated text but continue as if chat was completed + // to preserve previously generated content. 
+ const handleAbort = () => clientAbortedHandler(resolve, fullText); + response.on("close", handleAbort); + for await (const event of stream) { for (const choice of event.choices) { const delta = choice.delta?.content; @@ -198,6 +209,7 @@ class AzureOpenAiLLM { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); }); } diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js index bd84a3856..3d334b291 100644 --- a/server/utils/AiProviders/gemini/index.js +++ b/server/utils/AiProviders/gemini/index.js @@ -1,5 +1,8 @@ const { chatPrompt } = require("../../chats"); -const { writeResponseChunk } = require("../../helpers/chat/responses"); +const { + writeResponseChunk, + clientAbortedHandler, +} = require("../../helpers/chat/responses"); class GeminiLLM { constructor(embedder = null, modelPreference = null) { @@ -198,6 +201,14 @@ class GeminiLLM { return new Promise(async (resolve) => { let fullText = ""; + + // Establish listener to early-abort a streaming response + // in case things go sideways or the user does not like the response. + // We preserve the generated text but continue as if chat was completed + // to preserve previously generated content. + const handleAbort = () => clientAbortedHandler(resolve, fullText); + response.on("close", handleAbort); + for await (const chunk of stream) { fullText += chunk.text(); writeResponseChunk(response, { @@ -218,6 +229,7 @@ class GeminiLLM { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); }); } diff --git a/server/utils/AiProviders/huggingface/index.js b/server/utils/AiProviders/huggingface/index.js index 416e622a3..751d3595c 100644 --- a/server/utils/AiProviders/huggingface/index.js +++ b/server/utils/AiProviders/huggingface/index.js @@ -1,7 +1,10 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { OpenAiEmbedder } = require("../../EmbeddingEngines/openAi"); const { chatPrompt } = require("../../chats"); -const { writeResponseChunk } = require("../../helpers/chat/responses"); +const { + writeResponseChunk, + clientAbortedHandler, +} = require("../../helpers/chat/responses"); class HuggingFaceLLM { constructor(embedder = null, _modelPreference = null) { @@ -172,6 +175,14 @@ class HuggingFaceLLM { return new Promise((resolve) => { let fullText = ""; let chunk = ""; + + // Establish listener to early-abort a streaming response + // in case things go sideways or the user does not like the response. + // We preserve the generated text but continue as if chat was completed + // to preserve previously generated content. 
+ const handleAbort = () => clientAbortedHandler(resolve, fullText); + response.on("close", handleAbort); + stream.data.on("data", (data) => { const lines = data ?.toString() @@ -218,6 +229,7 @@ class HuggingFaceLLM { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); } else { let error = null; @@ -241,6 +253,7 @@ class HuggingFaceLLM { close: true, error, }); + response.removeListener("close", handleAbort); resolve(""); return; } @@ -266,6 +279,7 @@ class HuggingFaceLLM { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); } } diff --git a/server/utils/AiProviders/native/index.js b/server/utils/AiProviders/native/index.js index 157fb7520..5764d4ee2 100644 --- a/server/utils/AiProviders/native/index.js +++ b/server/utils/AiProviders/native/index.js @@ -2,7 +2,10 @@ const fs = require("fs"); const path = require("path"); const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { chatPrompt } = require("../../chats"); -const { writeResponseChunk } = require("../../helpers/chat/responses"); +const { + writeResponseChunk, + clientAbortedHandler, +} = require("../../helpers/chat/responses"); // Docs: https://api.js.langchain.com/classes/chat_models_llama_cpp.ChatLlamaCpp.html const ChatLlamaCpp = (...args) => @@ -176,6 +179,14 @@ class NativeLLM { return new Promise(async (resolve) => { let fullText = ""; + + // Establish listener to early-abort a streaming response + // in case things go sideways or the user does not like the response. + // We preserve the generated text but continue as if chat was completed + // to preserve previously generated content. + const handleAbort = () => clientAbortedHandler(resolve, fullText); + response.on("close", handleAbort); + for await (const chunk of stream) { if (chunk === undefined) throw new Error( @@ -202,6 +213,7 @@ class NativeLLM { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); }); } diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js index 035d4a9d0..6bd857b4e 100644 --- a/server/utils/AiProviders/ollama/index.js +++ b/server/utils/AiProviders/ollama/index.js @@ -1,6 +1,9 @@ const { chatPrompt } = require("../../chats"); const { StringOutputParser } = require("langchain/schema/output_parser"); -const { writeResponseChunk } = require("../../helpers/chat/responses"); +const { + writeResponseChunk, + clientAbortedHandler, +} = require("../../helpers/chat/responses"); // Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md class OllamaAILLM { @@ -180,8 +183,16 @@ class OllamaAILLM { const { uuid = uuidv4(), sources = [] } = responseProps; return new Promise(async (resolve) => { + let fullText = ""; + + // Establish listener to early-abort a streaming response + // in case things go sideways or the user does not like the response. + // We preserve the generated text but continue as if chat was completed + // to preserve previously generated content. + const handleAbort = () => clientAbortedHandler(resolve, fullText); + response.on("close", handleAbort); + try { - let fullText = ""; for await (const chunk of stream) { if (chunk === undefined) throw new Error( @@ -210,6 +221,7 @@ class OllamaAILLM { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); } catch (error) { writeResponseChunk(response, { @@ -222,6 +234,7 @@ class OllamaAILLM { error?.cause ?? 
error.message }`, }); + response.removeListener("close", handleAbort); } }); } diff --git a/server/utils/AiProviders/openRouter/index.js b/server/utils/AiProviders/openRouter/index.js index 38a6f9f09..a1f606f60 100644 --- a/server/utils/AiProviders/openRouter/index.js +++ b/server/utils/AiProviders/openRouter/index.js @@ -1,7 +1,10 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { chatPrompt } = require("../../chats"); const { v4: uuidv4 } = require("uuid"); -const { writeResponseChunk } = require("../../helpers/chat/responses"); +const { + writeResponseChunk, + clientAbortedHandler, +} = require("../../helpers/chat/responses"); function openRouterModels() { const { MODELS } = require("./models.js"); @@ -195,6 +198,13 @@ class OpenRouterLLM { let chunk = ""; let lastChunkTime = null; // null when first token is still not received. + // Establish listener to early-abort a streaming response + // in case things go sideways or the user does not like the response. + // We preserve the generated text but continue as if chat was completed + // to preserve previously generated content. + const handleAbort = () => clientAbortedHandler(resolve, fullText); + response.on("close", handleAbort); + // NOTICE: Not all OpenRouter models will return a stop reason // which keeps the connection open and so the model never finalizes the stream // like the traditional OpenAI response schema does. So in the case the response stream @@ -220,6 +230,7 @@ class OpenRouterLLM { error: false, }); clearInterval(timeoutCheck); + response.removeListener("close", handleAbort); resolve(fullText); } }, 500); @@ -269,6 +280,7 @@ class OpenRouterLLM { error: false, }); clearInterval(timeoutCheck); + response.removeListener("close", handleAbort); resolve(fullText); } else { let finishReason = null; @@ -305,6 +317,7 @@ class OpenRouterLLM { error: false, }); clearInterval(timeoutCheck); + response.removeListener("close", handleAbort); resolve(fullText); } } diff --git a/server/utils/AiProviders/togetherAi/index.js b/server/utils/AiProviders/togetherAi/index.js index 15b254a15..def03df96 100644 --- a/server/utils/AiProviders/togetherAi/index.js +++ b/server/utils/AiProviders/togetherAi/index.js @@ -1,5 +1,8 @@ const { chatPrompt } = require("../../chats"); -const { writeResponseChunk } = require("../../helpers/chat/responses"); +const { + writeResponseChunk, + clientAbortedHandler, +} = require("../../helpers/chat/responses"); function togetherAiModels() { const { MODELS } = require("./models.js"); @@ -185,6 +188,14 @@ class TogetherAiLLM { return new Promise((resolve) => { let fullText = ""; let chunk = ""; + + // Establish listener to early-abort a streaming response + // in case things go sideways or the user does not like the response. + // We preserve the generated text but continue as if chat was completed + // to preserve previously generated content. 
+ const handleAbort = () => clientAbortedHandler(resolve, fullText); + response.on("close", handleAbort); + stream.data.on("data", (data) => { const lines = data ?.toString() @@ -230,6 +241,7 @@ class TogetherAiLLM { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); } else { let finishReason = null; @@ -263,6 +275,7 @@ class TogetherAiLLM { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); } } diff --git a/server/utils/helpers/chat/responses.js b/server/utils/helpers/chat/responses.js index c4371d818..e2ec7bd0d 100644 --- a/server/utils/helpers/chat/responses.js +++ b/server/utils/helpers/chat/responses.js @@ -1,6 +1,14 @@ const { v4: uuidv4 } = require("uuid"); const moment = require("moment"); +function clientAbortedHandler(resolve, fullText) { + console.log( + "\x1b[43m\x1b[34m[STREAM ABORTED]\x1b[0m Client requested to abort stream. Exiting LLM stream handler early." + ); + resolve(fullText); + return; +} + // The default way to handle a stream response. Functions best with OpenAI. // Currently used for LMStudio, LocalAI, Mistral API, and OpenAI function handleDefaultStreamResponse(response, stream, responseProps) { @@ -9,6 +17,14 @@ function handleDefaultStreamResponse(response, stream, responseProps) { return new Promise((resolve) => { let fullText = ""; let chunk = ""; + + // Establish listener to early-abort a streaming response + // in case things go sideways or the user does not like the response. + // We preserve the generated text but continue as if chat was completed + // to preserve previously generated content. + const handleAbort = () => clientAbortedHandler(resolve, fullText); + response.on("close", handleAbort); + stream.data.on("data", (data) => { const lines = data ?.toString() @@ -52,6 +68,7 @@ function handleDefaultStreamResponse(response, stream, responseProps) { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); } else { let finishReason = null; @@ -85,6 +102,7 @@ function handleDefaultStreamResponse(response, stream, responseProps) { close: true, error: false, }); + response.removeListener("close", handleAbort); resolve(fullText); } } @@ -141,4 +159,5 @@ module.exports = { convertToChatHistory, convertToPromptHistory, writeResponseChunk, + clientAbortedHandler, }; From ac0e62d490eb8ee51d126735c710273e6ce2fd76 Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Wed, 13 Mar 2024 17:32:02 -0700 Subject: [PATCH 2/6] [FEAT] Anthropic Haiku model support (#901) add Haiku model support --- .../src/components/LLMSelection/AnthropicAiOptions/index.jsx | 1 + frontend/src/hooks/useGetProvidersModels.js | 1 + server/utils/AiProviders/anthropic/index.js | 3 +++ server/utils/helpers/updateENV.js | 1 + 4 files changed, 6 insertions(+) diff --git a/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx b/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx index 6bc18a5ac..e8c288d60 100644 --- a/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx +++ b/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx @@ -52,6 +52,7 @@ export default function AnthropicAiOptions({ settings, showAlert = false }) { "claude-instant-1.2", "claude-2.0", "claude-2.1", + "claude-3-haiku-20240307", "claude-3-opus-20240229", "claude-3-sonnet-20240229", ].map((model) => { diff --git a/frontend/src/hooks/useGetProvidersModels.js b/frontend/src/hooks/useGetProvidersModels.js index 57a95ea7a..f578c929f 100644 --- 
a/frontend/src/hooks/useGetProvidersModels.js +++ b/frontend/src/hooks/useGetProvidersModels.js @@ -19,6 +19,7 @@ const PROVIDER_DEFAULT_MODELS = { "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", ], azure: [], lmstudio: [], diff --git a/server/utils/AiProviders/anthropic/index.js b/server/utils/AiProviders/anthropic/index.js index fea083329..24a07f6e5 100644 --- a/server/utils/AiProviders/anthropic/index.js +++ b/server/utils/AiProviders/anthropic/index.js @@ -48,6 +48,8 @@ class AnthropicLLM { return 200_000; case "claude-3-sonnet-20240229": return 200_000; + case "claude-3-haiku-20240307": + return 200_000; default: return 100_000; // assume a claude-instant-1.2 model } @@ -60,6 +62,7 @@ class AnthropicLLM { "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", ]; return validModels.includes(modelName); } diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index aa814d690..e46074a7c 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -365,6 +365,7 @@ function validAnthropicModel(input = "") { "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", ]; return validModels.includes(input) ? null From 1352b18b5fc55e2fafc66c7ee9c2eca891cb3dbd Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Thu, 14 Mar 2024 10:52:35 -0700 Subject: [PATCH 3/6] [FEAT] Implement correct highlight colors on document picker (#897) * implement alternating color rows for file picker * implement alternating colors for workspace directory * remove unneeded props/variables * remove unused border classes * remove unneeded expanded prop from filerow component --- .../Documents/Directory/FileRow/index.jsx | 12 ++++++------ .../Documents/Directory/FolderRow/index.jsx | 9 ++++----- .../WorkspaceFileRow/index.jsx | 4 ++-- .../Documents/WorkspaceDirectory/index.jsx | 4 ++-- frontend/src/index.css | 16 ++++++++++++++++ 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/frontend/src/components/Modals/MangeWorkspace/Documents/Directory/FileRow/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Documents/Directory/FileRow/index.jsx index 7e2259b22..976c65988 100644 --- a/frontend/src/components/Modals/MangeWorkspace/Documents/Directory/FileRow/index.jsx +++ b/frontend/src/components/Modals/MangeWorkspace/Documents/Directory/FileRow/index.jsx @@ -13,7 +13,6 @@ export default function FileRow({ folderName, selected, toggleSelection, - expanded, fetchKeys, setLoading, setLoadingMessage, @@ -53,12 +52,13 @@ export default function FileRow({ const handleMouseEnter = debounce(handleShowTooltip, 500); const handleMouseLeave = debounce(handleHideTooltip, 500); + return ( -
toggleSelection(item)} - className={`transition-all duration-200 text-white/80 text-xs grid grid-cols-12 py-2 pl-3.5 pr-8 border-b border-white/20 hover:bg-sky-500/20 cursor-pointer ${`${ - selected ? "bg-sky-500/20" : "" - } ${expanded ? "bg-sky-500/10" : ""}`}`} + className={`transition-all duration-200 text-white/80 text-xs grid grid-cols-12 py-2 pl-3.5 pr-8 hover:bg-sky-500/20 cursor-pointer file-row ${ + selected ? "selected" : "" + }`} >
-
+ ); } diff --git a/frontend/src/components/Modals/MangeWorkspace/Documents/Directory/FolderRow/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Documents/Directory/FolderRow/index.jsx index 5b7f1be39..48953ab1f 100644 --- a/frontend/src/components/Modals/MangeWorkspace/Documents/Directory/FolderRow/index.jsx +++ b/frontend/src/components/Modals/MangeWorkspace/Documents/Directory/FolderRow/index.jsx @@ -47,10 +47,10 @@ export default function FolderRow({ return ( <> -
@@ -88,7 +88,7 @@ export default function FolderRow({ /> )}
-
+ {expanded && (
{item.items.map((fileItem) => ( @@ -97,7 +97,6 @@ export default function FolderRow({ item={fileItem} folderName={item.name} selected={isSelected(fileItem.id)} - expanded={expanded} toggleSelection={toggleSelection} fetchKeys={fetchKeys} setLoading={setLoading} diff --git a/frontend/src/components/Modals/MangeWorkspace/Documents/WorkspaceDirectory/WorkspaceFileRow/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Documents/WorkspaceDirectory/WorkspaceFileRow/index.jsx index 3367c7289..f73916290 100644 --- a/frontend/src/components/Modals/MangeWorkspace/Documents/WorkspaceDirectory/WorkspaceFileRow/index.jsx +++ b/frontend/src/components/Modals/MangeWorkspace/Documents/WorkspaceDirectory/WorkspaceFileRow/index.jsx @@ -53,8 +53,8 @@ export default function WorkspaceFileRow({ const handleMouseLeave = debounce(handleHideTooltip, 500); return (
-
+

Name

Date

Kind

@@ -148,7 +148,7 @@ const PinAlert = memo(() => {
-
+
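The document-picker hunks above all make the same move: the per-row conditional background classes, and the `expanded` prop that existed only to pick a background, are dropped in favor of a shared `file-row` class plus a `selected` modifier, while the actual odd/even striping and selection tint move into the global rules added to index.css in the hunk that follows. A minimal sketch of the row side of that convention; `PickerRow`, `onToggle`, and `item.title` are illustrative stand-ins, and the real components also render grid columns and tooltips:

```jsx
// Simplified row illustrating the file-row / selected convention from this patch.
// The background is not computed here: .file-row:nth-child(odd|even) in index.css
// supplies the striping, and .file-row.selected overrides it with the sky tint.
function PickerRow({ item, selected, onToggle }) {
  return (
    <div
      onClick={() => onToggle(item)}
      className={`file-row ${
        selected ? "selected" : ""
      } transition-all duration-200 text-white/80 text-xs cursor-pointer`}
    >
      {item.title}
    </div>
  );
}
```

Keeping the striping in CSS means the rows stay correctly alternated as items are added or removed, with no index bookkeeping in JavaScript.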

diff --git a/frontend/src/index.css b/frontend/src/index.css index e2141d8de..b355eb20a 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -597,3 +597,19 @@ dialog::backdrop { font-weight: 600; color: #fff; } + +.file-row:nth-child(odd) { + @apply bg-[#1C1E21]; +} + +.file-row:nth-child(even) { + @apply bg-[#2C2C2C]; +} + +.file-row.selected:nth-child(odd) { + @apply bg-sky-500/20; +} + +.file-row.selected:nth-child(even) { + @apply bg-sky-500/10; +} From 0ada8829915472cb9c94eff97382d5664aa0dec4 Mon Sep 17 00:00:00 2001 From: Timothy Carambat Date: Thu, 14 Mar 2024 15:43:26 -0700 Subject: [PATCH 4/6] Support external transcription providers (#909) * Support External Transcription providers * patch files * update docs * fix return data --- collector/index.js | 4 +- collector/package.json | 3 +- .../processSingleFile/convert/asAudio.js | 117 ++---------- collector/processSingleFile/index.js | 3 +- .../utils/WhisperProviders/OpenAiWhisper.js | 44 +++++ .../utils/WhisperProviders/localWhisper.js | 126 +++++++++++- collector/yarn.lock | 20 ++ docker/.env.example | 10 + frontend/src/App.jsx | 9 + .../src/components/SettingsSidebar/index.jsx | 11 +- .../NativeTranscriptionOptions/index.jsx | 38 ++++ .../OpenAiOptions/index.jsx | 41 ++++ .../TranscriptionPreference/index.jsx | 180 ++++++++++++++++++ frontend/src/utils/paths.js | 3 + server/.env.example | 10 + server/models/systemSettings.js | 1 + server/storage/models/README.md | 3 + server/utils/collectorApi/index.js | 14 +- server/utils/helpers/updateENV.js | 14 ++ 19 files changed, 541 insertions(+), 110 deletions(-) create mode 100644 collector/utils/WhisperProviders/OpenAiWhisper.js create mode 100644 frontend/src/components/TranscriptionSelection/NativeTranscriptionOptions/index.jsx create mode 100644 frontend/src/components/TranscriptionSelection/OpenAiOptions/index.jsx create mode 100644 frontend/src/pages/GeneralSettings/TranscriptionPreference/index.jsx diff --git a/collector/index.js b/collector/index.js index 9ebe5f1ce..a1142d756 100644 --- a/collector/index.js +++ b/collector/index.js @@ -25,7 +25,7 @@ app.use( ); app.post("/process", async function (request, response) { - const { filename } = reqBody(request); + const { filename, options = {} } = reqBody(request); try { const targetFilename = path .normalize(filename) @@ -34,7 +34,7 @@ app.post("/process", async function (request, response) { success, reason, documents = [], - } = await processSingleFile(targetFilename); + } = await processSingleFile(targetFilename, options); response .status(200) .json({ filename: targetFilename, success, reason, documents }); diff --git a/collector/package.json b/collector/package.json index d145ab865..8a0441d78 100644 --- a/collector/package.json +++ b/collector/package.json @@ -33,6 +33,7 @@ "moment": "^2.29.4", "multer": "^1.4.5-lts.1", "officeparser": "^4.0.5", + "openai": "^3.2.1", "pdf-parse": "^1.1.1", "puppeteer": "~21.5.2", "slugify": "^1.6.6", @@ -46,4 +47,4 @@ "nodemon": "^2.0.22", "prettier": "^2.4.1" } -} +} \ No newline at end of file diff --git a/collector/processSingleFile/convert/asAudio.js b/collector/processSingleFile/convert/asAudio.js index 15ae5cf00..170426e40 100644 --- a/collector/processSingleFile/convert/asAudio.js +++ b/collector/processSingleFile/convert/asAudio.js @@ -1,5 +1,3 @@ -const fs = require("fs"); -const path = require("path"); const { v4 } = require("uuid"); const { createdDate, @@ -9,39 +7,35 @@ const { const { tokenizeString } = require("../../utils/tokenizer"); const { default: slugify } = 
require("slugify"); const { LocalWhisper } = require("../../utils/WhisperProviders/localWhisper"); +const { OpenAiWhisper } = require("../../utils/WhisperProviders/OpenAiWhisper"); -async function asAudio({ fullFilePath = "", filename = "" }) { - const whisper = new LocalWhisper(); +const WHISPER_PROVIDERS = { + openai: OpenAiWhisper, + local: LocalWhisper, +}; + +async function asAudio({ fullFilePath = "", filename = "", options = {} }) { + const WhisperProvider = WHISPER_PROVIDERS.hasOwnProperty( + options?.whisperProvider + ) + ? WHISPER_PROVIDERS[options?.whisperProvider] + : WHISPER_PROVIDERS.local; console.log(`-- Working ${filename} --`); - const transcriberPromise = new Promise((resolve) => - whisper.client().then((client) => resolve(client)) - ); - const audioDataPromise = new Promise((resolve) => - convertToWavAudioData(fullFilePath).then((audioData) => resolve(audioData)) - ); - const [audioData, transcriber] = await Promise.all([ - audioDataPromise, - transcriberPromise, - ]); + const whisper = new WhisperProvider({ options }); + const { content, error } = await whisper.processFile(fullFilePath, filename); - if (!audioData) { - console.error(`Failed to parse content from ${filename}.`); + if (!!error) { + console.error(`Error encountered for parsing of ${filename}.`); trashFile(fullFilePath); return { success: false, - reason: `Failed to parse content from ${filename}.`, + reason: error, documents: [], }; } - console.log(`[Model Working]: Transcribing audio data to text`); - const { text: content } = await transcriber(audioData, { - chunk_length_s: 30, - stride_length_s: 5, - }); - - if (!content.length) { + if (!content?.length) { console.error(`Resulting text content was empty for ${filename}.`); trashFile(fullFilePath); return { @@ -76,79 +70,4 @@ async function asAudio({ fullFilePath = "", filename = "" }) { return { success: true, reason: null, documents: [document] }; } -async function convertToWavAudioData(sourcePath) { - try { - let buffer; - const wavefile = require("wavefile"); - const ffmpeg = require("fluent-ffmpeg"); - const outFolder = path.resolve(__dirname, `../../storage/tmp`); - if (!fs.existsSync(outFolder)) fs.mkdirSync(outFolder, { recursive: true }); - - const fileExtension = path.extname(sourcePath).toLowerCase(); - if (fileExtension !== ".wav") { - console.log( - `[Conversion Required] ${fileExtension} file detected - converting to .wav` - ); - const outputFile = path.resolve(outFolder, `${v4()}.wav`); - const convert = new Promise((resolve) => { - ffmpeg(sourcePath) - .toFormat("wav") - .on("error", (error) => { - console.error(`[Conversion Error] ${error.message}`); - resolve(false); - }) - .on("progress", (progress) => - console.log( - `[Conversion Processing]: ${progress.targetSize}KB converted` - ) - ) - .on("end", () => { - console.log("[Conversion Complete]: File converted to .wav!"); - resolve(true); - }) - .save(outputFile); - }); - const success = await convert; - if (!success) - throw new Error( - "[Conversion Failed]: Could not convert file to .wav format!" 
- ); - - const chunks = []; - const stream = fs.createReadStream(outputFile); - for await (let chunk of stream) chunks.push(chunk); - buffer = Buffer.concat(chunks); - fs.rmSync(outputFile); - } else { - const chunks = []; - const stream = fs.createReadStream(sourcePath); - for await (let chunk of stream) chunks.push(chunk); - buffer = Buffer.concat(chunks); - } - - const wavFile = new wavefile.WaveFile(buffer); - wavFile.toBitDepth("32f"); - wavFile.toSampleRate(16000); - - let audioData = wavFile.getSamples(); - if (Array.isArray(audioData)) { - if (audioData.length > 1) { - const SCALING_FACTOR = Math.sqrt(2); - - // Merge channels into first channel to save memory - for (let i = 0; i < audioData[0].length; ++i) { - audioData[0][i] = - (SCALING_FACTOR * (audioData[0][i] + audioData[1][i])) / 2; - } - } - audioData = audioData[0]; - } - - return audioData; - } catch (error) { - console.error(`convertToWavAudioData`, error); - return null; - } -} - module.exports = asAudio; diff --git a/collector/processSingleFile/index.js b/collector/processSingleFile/index.js index 569a2cde2..5d9e6a38a 100644 --- a/collector/processSingleFile/index.js +++ b/collector/processSingleFile/index.js @@ -7,7 +7,7 @@ const { const { trashFile, isTextType } = require("../utils/files"); const RESERVED_FILES = ["__HOTDIR__.md"]; -async function processSingleFile(targetFilename) { +async function processSingleFile(targetFilename, options = {}) { const fullFilePath = path.resolve(WATCH_DIRECTORY, targetFilename); if (RESERVED_FILES.includes(targetFilename)) return { @@ -54,6 +54,7 @@ async function processSingleFile(targetFilename) { return await FileTypeProcessor({ fullFilePath, filename: targetFilename, + options, }); } diff --git a/collector/utils/WhisperProviders/OpenAiWhisper.js b/collector/utils/WhisperProviders/OpenAiWhisper.js new file mode 100644 index 000000000..3b9d08e6a --- /dev/null +++ b/collector/utils/WhisperProviders/OpenAiWhisper.js @@ -0,0 +1,44 @@ +const fs = require("fs"); + +class OpenAiWhisper { + constructor({ options }) { + const { Configuration, OpenAIApi } = require("openai"); + if (!options.openAiKey) throw new Error("No OpenAI API key was set."); + + const config = new Configuration({ + apiKey: options.openAiKey, + }); + this.openai = new OpenAIApi(config); + this.model = "whisper-1"; + this.temperature = 0; + this.#log("Initialized."); + } + + #log(text, ...args) { + console.log(`\x1b[32m[OpenAiWhisper]\x1b[0m ${text}`, ...args); + } + + async processFile(fullFilePath) { + return await this.openai + .createTranscription( + fs.createReadStream(fullFilePath), + this.model, + undefined, + "text", + this.temperature + ) + .then((res) => { + if (res.hasOwnProperty("data")) + return { content: res.data, error: null }; + return { content: "", error: "No content was able to be transcribed." 
}; + }) + .catch((e) => { + this.#log(`Could not get any response from openai whisper`, e.message); + return { content: "", error: e.message }; + }); + } +} + +module.exports = { + OpenAiWhisper, +}; diff --git a/collector/utils/WhisperProviders/localWhisper.js b/collector/utils/WhisperProviders/localWhisper.js index 6503e2021..46dbe226b 100644 --- a/collector/utils/WhisperProviders/localWhisper.js +++ b/collector/utils/WhisperProviders/localWhisper.js @@ -1,5 +1,6 @@ -const path = require("path"); const fs = require("fs"); +const path = require("path"); +const { v4 } = require("uuid"); class LocalWhisper { constructor() { @@ -16,12 +17,94 @@ class LocalWhisper { // Make directory when it does not exist in existing installations if (!fs.existsSync(this.cacheDir)) fs.mkdirSync(this.cacheDir, { recursive: true }); + + this.#log("Initialized."); + } + + #log(text, ...args) { + console.log(`\x1b[32m[LocalWhisper]\x1b[0m ${text}`, ...args); + } + + async #convertToWavAudioData(sourcePath) { + try { + let buffer; + const wavefile = require("wavefile"); + const ffmpeg = require("fluent-ffmpeg"); + const outFolder = path.resolve(__dirname, `../../storage/tmp`); + if (!fs.existsSync(outFolder)) + fs.mkdirSync(outFolder, { recursive: true }); + + const fileExtension = path.extname(sourcePath).toLowerCase(); + if (fileExtension !== ".wav") { + this.#log( + `File conversion required! ${fileExtension} file detected - converting to .wav` + ); + const outputFile = path.resolve(outFolder, `${v4()}.wav`); + const convert = new Promise((resolve) => { + ffmpeg(sourcePath) + .toFormat("wav") + .on("error", (error) => { + this.#log(`Conversion Error! ${error.message}`); + resolve(false); + }) + .on("progress", (progress) => + this.#log( + `Conversion Processing! ${progress.targetSize}KB converted` + ) + ) + .on("end", () => { + this.#log(`Conversion Complete! File converted to .wav!`); + resolve(true); + }) + .save(outputFile); + }); + const success = await convert; + if (!success) + throw new Error( + "[Conversion Failed]: Could not convert file to .wav format!" + ); + + const chunks = []; + const stream = fs.createReadStream(outputFile); + for await (let chunk of stream) chunks.push(chunk); + buffer = Buffer.concat(chunks); + fs.rmSync(outputFile); + } else { + const chunks = []; + const stream = fs.createReadStream(sourcePath); + for await (let chunk of stream) chunks.push(chunk); + buffer = Buffer.concat(chunks); + } + + const wavFile = new wavefile.WaveFile(buffer); + wavFile.toBitDepth("32f"); + wavFile.toSampleRate(16000); + + let audioData = wavFile.getSamples(); + if (Array.isArray(audioData)) { + if (audioData.length > 1) { + const SCALING_FACTOR = Math.sqrt(2); + + // Merge channels into first channel to save memory + for (let i = 0; i < audioData[0].length; ++i) { + audioData[0][i] = + (SCALING_FACTOR * (audioData[0][i] + audioData[1][i])) / 2; + } + } + audioData = audioData[0]; + } + + return audioData; + } catch (error) { + console.error(`convertToWavAudioData`, error); + return null; + } } async client() { if (!fs.existsSync(this.modelPath)) { - console.log( - "\x1b[34m[INFO]\x1b[0m The native whisper model has never been run and will be downloaded right now. Subsequent runs will be faster. (~250MB)\n\n" + this.#log( + `The native whisper model has never been run and will be downloaded right now. Subsequent runs will be faster. 
(~250MB)` ); } @@ -48,10 +131,45 @@ class LocalWhisper { : {}), }); } catch (error) { - console.error("Failed to load the native whisper model:", error); + this.#log("Failed to load the native whisper model:", error); throw error; } } + + async processFile(fullFilePath, filename) { + try { + const transcriberPromise = new Promise((resolve) => + this.client().then((client) => resolve(client)) + ); + const audioDataPromise = new Promise((resolve) => + this.#convertToWavAudioData(fullFilePath).then((audioData) => + resolve(audioData) + ) + ); + const [audioData, transcriber] = await Promise.all([ + audioDataPromise, + transcriberPromise, + ]); + + if (!audioData) { + this.#log(`Failed to parse content from ${filename}.`); + return { + content: null, + error: `Failed to parse content from ${filename}.`, + }; + } + + this.#log(`Transcribing audio data to text...`); + const { text } = await transcriber(audioData, { + chunk_length_s: 30, + stride_length_s: 5, + }); + + return { content: text, error: null }; + } catch (error) { + return { content: null, error: error.message }; + } + } } module.exports = { diff --git a/collector/yarn.lock b/collector/yarn.lock index bf979c86c..3bb0f1ea7 100644 --- a/collector/yarn.lock +++ b/collector/yarn.lock @@ -372,6 +372,13 @@ asynckit@^0.4.0: resolved "https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79" integrity sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q== +axios@^0.26.0: + version "0.26.1" + resolved "https://registry.yarnpkg.com/axios/-/axios-0.26.1.tgz#1ede41c51fcf51bbbd6fd43669caaa4f0495aaa9" + integrity sha512-fPwcX4EvnSHuInCMItEhAGnaSEXRBjtzh9fOtsE6E1G6p7vl7edEeZe11QHf18+6+9gR5PbKV/sGKNaD8YaMeA== + dependencies: + follow-redirects "^1.14.8" + b4a@^1.6.4: version "1.6.4" resolved "https://registry.yarnpkg.com/b4a/-/b4a-1.6.4.tgz#ef1c1422cae5ce6535ec191baeed7567443f36c9" @@ -1203,6 +1210,11 @@ fluent-ffmpeg@^2.1.2: async ">=0.2.9" which "^1.1.1" +follow-redirects@^1.14.8: + version "1.15.6" + resolved "https://registry.yarnpkg.com/follow-redirects/-/follow-redirects-1.15.6.tgz#7f815c0cda4249c74ff09e95ef97c23b5fd0399b" + integrity sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA== + form-data-encoder@1.7.2: version "1.7.2" resolved "https://registry.yarnpkg.com/form-data-encoder/-/form-data-encoder-1.7.2.tgz#1f1ae3dccf58ed4690b86d87e4f57c654fbab040" @@ -2304,6 +2316,14 @@ onnxruntime-web@1.14.0: onnxruntime-common "~1.14.0" platform "^1.3.6" +openai@^3.2.1: + version "3.3.0" + resolved "https://registry.yarnpkg.com/openai/-/openai-3.3.0.tgz#a6408016ad0945738e1febf43f2fccca83a3f532" + integrity sha512-uqxI/Au+aPRnsaQRe8CojU0eCR7I0mBiKjD3sNMzY6DaC1ZVrc85u98mtJW6voDug8fgGN+DIZmTDxTthxb7dQ== + dependencies: + axios "^0.26.0" + form-data "^4.0.0" + openai@^4.19.0: version "4.20.1" resolved "https://registry.yarnpkg.com/openai/-/openai-4.20.1.tgz#afa0d496d125b5a0f6cebcb4b9aeabf71e00214e" diff --git a/docker/.env.example b/docker/.env.example index ae4913dc4..ed6fd3bce 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -131,6 +131,16 @@ GID='1000' # ASTRA_DB_APPLICATION_TOKEN= # ASTRA_DB_ENDPOINT= +########################################### +######## Audio Model Selection ############ +########################################### +# (default) use built-in whisper-small model. +# WHISPER_PROVIDER="local" + +# use openai hosted whisper model. 
+# WHISPER_PROVIDER="openai" +# OPEN_AI_KEY=sk-xxxxxxxx + # CLOUD DEPLOYMENT VARIRABLES ONLY # AUTH_TOKEN="hunter2" # This is the password to your application if remote hosting. # DISABLE_TELEMETRY="false" diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 86f6eb08a..8a57d27bb 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -29,6 +29,9 @@ const GeneralApiKeys = lazy(() => import("@/pages/GeneralSettings/ApiKeys")); const GeneralLLMPreference = lazy( () => import("@/pages/GeneralSettings/LLMPreference") ); +const GeneralTranscriptionPreference = lazy( + () => import("@/pages/GeneralSettings/TranscriptionPreference") +); const GeneralEmbeddingPreference = lazy( () => import("@/pages/GeneralSettings/EmbeddingPreference") ); @@ -76,6 +79,12 @@ export default function App() { path="/settings/llm-preference" element={} /> + + } + /> } diff --git a/frontend/src/components/SettingsSidebar/index.jsx b/frontend/src/components/SettingsSidebar/index.jsx index 84b78064a..a7aca7ffe 100644 --- a/frontend/src/components/SettingsSidebar/index.jsx +++ b/frontend/src/components/SettingsSidebar/index.jsx @@ -19,6 +19,7 @@ import { Notepad, CodeBlock, Barcode, + ClosedCaptioning, } from "@phosphor-icons/react"; import useUser from "@/hooks/useUser"; import { USER_BACKGROUND_COLOR } from "@/utils/constants"; @@ -278,9 +279,17 @@ const SidebarOptions = ({ user = null }) => ( flex={true} allowedRole={["admin"]} /> +
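Stepping back to the collector changes at the top of this fourth patch: asAudio.js no longer transcribes inline, it picks a whisper provider from the request options and defers to it, falling back to the local model when the provider is unknown. A condensed sketch of that dispatch, using the provider classes defined earlier in the patch; `transcribe` is an illustrative wrapper (the real entry point is `asAudio`, which also builds the document metadata), and the exact way the server forwards `whisperProvider` and `openAiKey` in `options` is implied by the collectorApi changes in the diffstat rather than shown in full here:

```js
// Sketch of the provider dispatch now performed in
// collector/processSingleFile/convert/asAudio.js.
const { LocalWhisper } = require("../../utils/WhisperProviders/localWhisper");
const { OpenAiWhisper } = require("../../utils/WhisperProviders/OpenAiWhisper");

const WHISPER_PROVIDERS = {
  openai: OpenAiWhisper, // hosted whisper-1; requires options.openAiKey
  local: LocalWhisper, // default; runs the bundled whisper-small (~250MB) via ffmpeg + wavefile
};

async function transcribe(fullFilePath, filename, options = {}) {
  const Provider =
    WHISPER_PROVIDERS[options.whisperProvider] ?? WHISPER_PROVIDERS.local;
  const whisper = new Provider({ options });

  // Both providers expose the same processFile() shape and resolve
  // { content, error }, so the caller stays provider-agnostic.
  const { content, error } = await whisper.processFile(fullFilePath, filename);
  if (error) throw new Error(error);
  return content;
}
```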