From 2a1202de548e8857764f0f129169a0dfd911c5c7 Mon Sep 17 00:00:00 2001
From: Timothy Carambat
Date: Thu, 28 Dec 2023 13:59:47 -0800
Subject: [PATCH] Patch Ollama Streaming chunk issues (#500)

Replace stream/sync chats with Langchain interface for now
connect #499
ref: https://github.com/Mintplex-Labs/anything-llm/issues/495#issuecomment-1871476091
---
 .vscode/settings.json                    |   1 +
 server/utils/AiProviders/ollama/index.js | 175 ++++++++++-------------
 server/utils/chats/stream.js             |  35 +----
 3 files changed, 79 insertions(+), 132 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 2e43b192..e6e720a9 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,6 +1,7 @@
 {
   "cSpell.words": [
     "Dockerized",
+    "Langchain",
     "Ollama",
     "openai",
     "Qdrant",
diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js
index 3aa58f76..f160e5d3 100644
--- a/server/utils/AiProviders/ollama/index.js
+++ b/server/utils/AiProviders/ollama/index.js
@@ -1,4 +1,5 @@
 const { chatPrompt } = require("../../chats");
+const { StringOutputParser } = require("langchain/schema/output_parser");
 
 // Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md
 class OllamaAILLM {
@@ -21,6 +22,42 @@ class OllamaAILLM {
     this.embedder = embedder;
   }
 
+  #ollamaClient({ temperature = 0.07 }) {
+    const { ChatOllama } = require("langchain/chat_models/ollama");
+    return new ChatOllama({
+      baseUrl: this.basePath,
+      model: this.model,
+      temperature,
+    });
+  }
+
+  // For streaming we use Langchain's wrapper to handle weird chunks
+  // or otherwise absorb headaches that can arise from Ollama models
+  #convertToLangchainPrototypes(chats = []) {
+    const {
+      HumanMessage,
+      SystemMessage,
+      AIMessage,
+    } = require("langchain/schema");
+    const langchainChats = [];
+    for (const chat of chats) {
+      switch (chat.role) {
+        case "system":
+          langchainChats.push(new SystemMessage({ content: chat.content }));
+          break;
+        case "user":
+          langchainChats.push(new HumanMessage({ content: chat.content }));
+          break;
+        case "assistant":
+          langchainChats.push(new AIMessage({ content: chat.content }));
+          break;
+        default:
+          break;
+      }
+    }
+    return langchainChats;
+  }
+
   streamingEnabled() {
     return "streamChat" in this && "streamGetChatCompletion" in this;
   }
@@ -63,37 +100,21 @@ Context:
   }
 
   async sendChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
-    const textResponse = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
+    const messages = await this.compressMessages(
+      {
+        systemPrompt: chatPrompt(workspace),
+        userPrompt: prompt,
+        chatHistory,
       },
-      body: JSON.stringify({
-        model: this.model,
-        stream: false,
-        options: {
-          temperature: Number(workspace?.openAiTemp ?? 0.7),
-        },
-        messages: await this.compressMessages(
-          {
-            systemPrompt: chatPrompt(workspace),
-            userPrompt: prompt,
-            chatHistory,
-          },
-          rawHistory
-        ),
-      }),
-    })
-      .then((res) => {
-        if (!res.ok)
-          throw new Error(`Ollama:sendChat ${res.status} ${res.statusText}`);
-        return res.json();
-      })
-      .then((data) => data?.message?.content)
-      .catch((e) => {
-        console.error(e);
-        throw new Error(`Ollama::sendChat failed with: ${error.message}`);
-      });
+      rawHistory
+    );
+
+    const model = this.#ollamaClient({
+      temperature: Number(workspace?.openAiTemp ?? 0.7),
+    });
+    const textResponse = await model
+      .pipe(new StringOutputParser())
+      .invoke(this.#convertToLangchainPrototypes(messages));
 
     if (!textResponse.length)
       throw new Error(`Ollama::sendChat text response was empty.`);
@@ -102,63 +123,29 @@ Context:
   }
 
   async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
-    const response = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
+    const messages = await this.compressMessages(
+      {
+        systemPrompt: chatPrompt(workspace),
+        userPrompt: prompt,
+        chatHistory,
       },
-      body: JSON.stringify({
-        model: this.model,
-        stream: true,
-        options: {
-          temperature: Number(workspace?.openAiTemp ?? 0.7),
-        },
-        messages: await this.compressMessages(
-          {
-            systemPrompt: chatPrompt(workspace),
-            userPrompt: prompt,
-            chatHistory,
-          },
-          rawHistory
-        ),
-      }),
-    }).catch((e) => {
-      console.error(e);
-      throw new Error(`Ollama:streamChat ${error.message}`);
-    });
+      rawHistory
+    );
 
-    return { type: "ollamaStream", response };
+    const model = this.#ollamaClient({
+      temperature: Number(workspace?.openAiTemp ?? 0.7),
+    });
+    const stream = await model
+      .pipe(new StringOutputParser())
+      .stream(this.#convertToLangchainPrototypes(messages));
+    return stream;
   }
 
   async getChatCompletion(messages = null, { temperature = 0.7 }) {
-    const textResponse = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify({
-        model: this.model,
-        messages,
-        stream: false,
-        options: {
-          temperature,
-        },
-      }),
-    })
-      .then((res) => {
-        if (!res.ok)
-          throw new Error(
-            `Ollama:getChatCompletion ${res.status} ${res.statusText}`
-          );
-        return res.json();
-      })
-      .then((data) => data?.message?.content)
-      .catch((e) => {
-        console.error(e);
-        throw new Error(
-          `Ollama::getChatCompletion failed with: ${error.message}`
-        );
-      });
+    const model = this.#ollamaClient({ temperature });
+    const textResponse = await model
+      .pipe(new StringOutputParser())
+      .invoke(this.#convertToLangchainPrototypes(messages));
 
     if (!textResponse.length)
       throw new Error(`Ollama::getChatCompletion text response was empty.`);
@@ -167,25 +154,11 @@ Context:
   }
 
   async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
-    const response = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify({
-        model: this.model,
-        stream: true,
-        messages,
-        options: {
-          temperature,
-        },
-      }),
-    }).catch((e) => {
-      console.error(e);
-      throw new Error(`Ollama:streamGetChatCompletion ${error.message}`);
-    });
-
-    return { type: "ollamaStream", response };
+    const model = this.#ollamaClient({ temperature });
+    const stream = await model
+      .pipe(new StringOutputParser())
+      .stream(this.#convertToLangchainPrototypes(messages));
+    return stream;
  }
 
   // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
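
Note (not part of the patch): the provider rewrite above boils down to the small Langchain pipeline sketched here. The sketch only uses calls that already appear in the diff (ChatOllama, StringOutputParser, the schema message classes, and pipe/invoke/stream); the base URL, model tag, and prompts are placeholders.

  // Illustrative sketch: the ChatOllama + StringOutputParser pipeline the
  // provider methods now build. baseUrl and model are placeholder values.
  const { ChatOllama } = require("langchain/chat_models/ollama");
  const { StringOutputParser } = require("langchain/schema/output_parser");
  const { SystemMessage, HumanMessage } = require("langchain/schema");

  async function sketch() {
    const model = new ChatOllama({
      baseUrl: "http://127.0.0.1:11434", // placeholder Ollama endpoint
      model: "llama2", // placeholder model tag
      temperature: 0.7,
    });
    const messages = [
      new SystemMessage({ content: "You are a helpful assistant." }),
      new HumanMessage({ content: "Why is the sky blue?" }),
    ];

    // Blocking call: the parser collapses the chat result to a plain string.
    const text = await model.pipe(new StringOutputParser()).invoke(messages);

    // Streaming call: yields string chunks, so callers never touch raw Ollama JSON.
    const stream = await model.pipe(new StringOutputParser()).stream(messages);
    for await (const chunk of stream) process.stdout.write(chunk);

    return text;
  }

Because StringOutputParser reduces both the blocking and streaming paths to plain strings, the generic handler in server/utils/chats/stream.js below can stay provider-agnostic.
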
diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js
index b0dc9186..240e4a17 100644
--- a/server/utils/chats/stream.js
+++ b/server/utils/chats/stream.js
@@ -232,46 +232,19 @@ function handleStreamResponses(response, stream, responseProps) {
     });
   }
 
-  if (stream?.type === "ollamaStream") {
-    return new Promise(async (resolve) => {
-      let fullText = "";
-      for await (const dataChunk of stream.response.body) {
-        const chunk = JSON.parse(Buffer.from(dataChunk).toString());
-        fullText += chunk.message.content;
-        writeResponseChunk(response, {
-          uuid,
-          sources: [],
-          type: "textResponseChunk",
-          textResponse: chunk.message.content,
-          close: false,
-          error: false,
-        });
-      }
-
-      writeResponseChunk(response, {
-        uuid,
-        sources,
-        type: "textResponseChunk",
-        textResponse: "",
-        close: true,
-        error: false,
-      });
-      resolve(fullText);
-    });
-  }
-
-  // If stream is not a regular OpenAI Stream (like if using native model)
+  // If stream is not a regular OpenAI Stream (like if using native model, Ollama, or most LangChain interfaces)
   // we can just iterate the stream content instead.
   if (!stream.hasOwnProperty("data")) {
     return new Promise(async (resolve) => {
       let fullText = "";
       for await (const chunk of stream) {
-        fullText += chunk.content;
+        const content = chunk.hasOwnProperty("content") ? chunk.content : chunk;
+        fullText += content;
         writeResponseChunk(response, {
           uuid,
           sources: [],
           type: "textResponseChunk",
-          textResponse: chunk.content,
+          textResponse: content,
           close: false,
           error: false,
         });
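
Note (not part of the patch): with the bespoke ollamaStream branch gone, the surviving generic branch only has to cope with two chunk shapes, OpenAI-style objects carrying a content field and the plain strings produced by the Langchain pipeline above. A minimal sketch of that normalization, using a hypothetical helper name:

  // Illustrative sketch: how the generic handler now reduces either chunk
  // shape to text. The chunkToText name is invented for illustration.
  function chunkToText(chunk) {
    return chunk.hasOwnProperty("content") ? chunk.content : chunk;
  }

  chunkToText({ content: "Hello " }); // => "Hello "  (OpenAI-style object chunk)
  chunkToText("world");               // => "world"   (string chunk from the parser)

This also retires the JSON.parse(Buffer.from(dataChunk).toString()) step, which had to assume every network chunk was a complete JSON document; the Langchain wrapper now absorbs those irregular chunks, as the new comment in index.js puts it.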