mirror of
https://github.com/Mintplex-Labs/anything-llm.git
synced 2024-11-10 17:00:11 +01:00
[FEAT] Cohere LLM and embedder support (#1233)
* getChatCompletion working WIP streaming * WIP * working streaming WIP abort stream * implement cohere embedder support * remove inputType option from cohere embedder * fix cohere LLM from not aborting stream when canceled by user * Patch Cohere implementation * add cohere to onboarding --------- Co-authored-by: timothycarambat <rambat1010@gmail.com>
This commit is contained in:
parent
d02013fd71
commit
3caebc47b4
@ -72,6 +72,10 @@ GID='1000'
|
||||
# GENERIC_OPEN_AI_MODEL_TOKEN_LIMIT=4096
|
||||
# GENERIC_OPEN_AI_API_KEY=sk-123abc
|
||||
|
||||
# LLM_PROVIDER='cohere'
|
||||
# COHERE_API_KEY=
|
||||
# COHERE_MODEL_PREF='command-r'
|
||||
|
||||
###########################################
|
||||
######## Embedding API SELECTION ##########
|
||||
###########################################
|
||||
@ -100,6 +104,10 @@ GID='1000'
|
||||
# EMBEDDING_MODEL_PREF='nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.Q4_0.gguf'
|
||||
# EMBEDDING_MODEL_MAX_CHUNK_LENGTH=8192
|
||||
|
||||
# EMBEDDING_ENGINE='cohere'
|
||||
# COHERE_API_KEY=
|
||||
# EMBEDDING_MODEL_PREF='embed-english-v3.0'
|
||||
|
||||
###########################################
|
||||
######## Vector Database Selection ########
|
||||
###########################################
|
||||
|
@ -0,0 +1,55 @@
|
||||
export default function CohereEmbeddingOptions({ settings }) {
|
||||
return (
|
||||
<div className="w-full flex flex-col gap-y-4">
|
||||
<div className="w-full flex items-center gap-4">
|
||||
<div className="flex flex-col w-60">
|
||||
<label className="text-white text-sm font-semibold block mb-4">
|
||||
API Key
|
||||
</label>
|
||||
<input
|
||||
type="password"
|
||||
name="CohereApiKey"
|
||||
className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:border-white block w-full p-2.5"
|
||||
placeholder="Cohere API Key"
|
||||
defaultValue={settings?.CohereApiKey ? "*".repeat(20) : ""}
|
||||
required={true}
|
||||
autoComplete="off"
|
||||
spellCheck={false}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col w-60">
|
||||
<label className="text-white text-sm font-semibold block mb-4">
|
||||
Model Preference
|
||||
</label>
|
||||
<select
|
||||
name="EmbeddingModelPref"
|
||||
required={true}
|
||||
className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
|
||||
>
|
||||
<optgroup label="Available embedding models">
|
||||
{[
|
||||
"embed-english-v3.0",
|
||||
"embed-multilingual-v3.0",
|
||||
"embed-english-light-v3.0",
|
||||
"embed-multilingual-light-v3.0",
|
||||
"embed-english-v2.0",
|
||||
"embed-english-light-v2.0",
|
||||
"embed-multilingual-v2.0",
|
||||
].map((model) => {
|
||||
return (
|
||||
<option
|
||||
key={model}
|
||||
value={model}
|
||||
selected={settings?.EmbeddingModelPref === model}
|
||||
>
|
||||
{model}
|
||||
</option>
|
||||
);
|
||||
})}
|
||||
</optgroup>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
export default function CohereAiOptions({ settings }) {
|
||||
return (
|
||||
<div className="w-full flex flex-col">
|
||||
<div className="w-full flex items-center gap-4">
|
||||
<div className="flex flex-col w-60">
|
||||
<label className="text-white text-sm font-semibold block mb-4">
|
||||
Cohere API Key
|
||||
</label>
|
||||
<input
|
||||
type="password"
|
||||
name="CohereApiKey"
|
||||
className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:border-white block w-full p-2.5"
|
||||
placeholder="Cohere API Key"
|
||||
defaultValue={settings?.CohereApiKey ? "*".repeat(20) : ""}
|
||||
required={true}
|
||||
autoComplete="off"
|
||||
spellCheck={false}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col w-60">
|
||||
<label className="text-white text-sm font-semibold block mb-4">
|
||||
Chat Model Selection
|
||||
</label>
|
||||
<select
|
||||
name="CohereModelPref"
|
||||
defaultValue={settings?.CohereModelPref || "command-r"}
|
||||
required={true}
|
||||
className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
|
||||
>
|
||||
{[
|
||||
"command-r",
|
||||
"command-r-plus",
|
||||
"command",
|
||||
"command-light",
|
||||
"command-nightly",
|
||||
"command-light-nightly",
|
||||
].map((model) => {
|
||||
return (
|
||||
<option key={model} value={model}>
|
||||
{model}
|
||||
</option>
|
||||
);
|
||||
})}
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
@ -26,6 +26,14 @@ const PROVIDER_DEFAULT_MODELS = {
|
||||
"gemma-7b-it",
|
||||
],
|
||||
native: [],
|
||||
cohere: [
|
||||
"command-r",
|
||||
"command-r-plus",
|
||||
"command",
|
||||
"command-light",
|
||||
"command-nightly",
|
||||
"command-light-nightly",
|
||||
],
|
||||
};
|
||||
|
||||
// For togetherAi, which has a large model list - we subgroup the options
|
||||
|
BIN
frontend/src/media/llmprovider/cohere.png
Normal file
BIN
frontend/src/media/llmprovider/cohere.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 139 KiB |
@ -9,6 +9,7 @@ import AzureOpenAiLogo from "@/media/llmprovider/azure.png";
|
||||
import LocalAiLogo from "@/media/llmprovider/localai.png";
|
||||
import OllamaLogo from "@/media/llmprovider/ollama.png";
|
||||
import LMStudioLogo from "@/media/llmprovider/lmstudio.png";
|
||||
import CohereLogo from "@/media/llmprovider/cohere.png";
|
||||
import PreLoader from "@/components/Preloader";
|
||||
import ChangeWarningModal from "@/components/ChangeWarning";
|
||||
import OpenAiOptions from "@/components/EmbeddingSelection/OpenAiOptions";
|
||||
@ -17,6 +18,8 @@ import LocalAiOptions from "@/components/EmbeddingSelection/LocalAiOptions";
|
||||
import NativeEmbeddingOptions from "@/components/EmbeddingSelection/NativeEmbeddingOptions";
|
||||
import OllamaEmbeddingOptions from "@/components/EmbeddingSelection/OllamaOptions";
|
||||
import LMStudioEmbeddingOptions from "@/components/EmbeddingSelection/LMStudioOptions";
|
||||
import CohereEmbeddingOptions from "@/components/EmbeddingSelection/CohereOptions";
|
||||
|
||||
import EmbedderItem from "@/components/EmbeddingSelection/EmbedderItem";
|
||||
import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react";
|
||||
import { useModal } from "@/hooks/useModal";
|
||||
@ -68,6 +71,13 @@ const EMBEDDERS = [
|
||||
description:
|
||||
"Discover, download, and run thousands of cutting edge LLMs in a few clicks.",
|
||||
},
|
||||
{
|
||||
name: "Cohere",
|
||||
value: "cohere",
|
||||
logo: CohereLogo,
|
||||
options: (settings) => <CohereEmbeddingOptions settings={settings} />,
|
||||
description: "Run powerful embedding models from Cohere.",
|
||||
},
|
||||
];
|
||||
|
||||
export default function GeneralEmbeddingPreference() {
|
||||
|
@ -18,6 +18,7 @@ import HuggingFaceLogo from "@/media/llmprovider/huggingface.png";
|
||||
import PerplexityLogo from "@/media/llmprovider/perplexity.png";
|
||||
import OpenRouterLogo from "@/media/llmprovider/openrouter.jpeg";
|
||||
import GroqLogo from "@/media/llmprovider/groq.png";
|
||||
import CohereLogo from "@/media/llmprovider/cohere.png";
|
||||
import PreLoader from "@/components/Preloader";
|
||||
import OpenAiOptions from "@/components/LLMSelection/OpenAiOptions";
|
||||
import GenericOpenAiOptions from "@/components/LLMSelection/GenericOpenAiOptions";
|
||||
@ -34,6 +35,7 @@ import HuggingFaceOptions from "@/components/LLMSelection/HuggingFaceOptions";
|
||||
import PerplexityOptions from "@/components/LLMSelection/PerplexityOptions";
|
||||
import OpenRouterOptions from "@/components/LLMSelection/OpenRouterOptions";
|
||||
import GroqAiOptions from "@/components/LLMSelection/GroqAiOptions";
|
||||
import CohereAiOptions from "@/components/LLMSelection/CohereAiOptions";
|
||||
|
||||
import LLMItem from "@/components/LLMSelection/LLMItem";
|
||||
import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react";
|
||||
@ -152,6 +154,14 @@ export const AVAILABLE_LLM_PROVIDERS = [
|
||||
"The fastest LLM inferencing available for real-time AI applications.",
|
||||
requiredConfig: ["GroqApiKey"],
|
||||
},
|
||||
{
|
||||
name: "Cohere",
|
||||
value: "cohere",
|
||||
logo: CohereLogo,
|
||||
options: (settings) => <CohereAiOptions settings={settings} />,
|
||||
description: "Run Cohere's powerful Command models.",
|
||||
requiredConfig: ["CohereApiKey"],
|
||||
},
|
||||
{
|
||||
name: "Generic OpenAI",
|
||||
value: "generic-openai",
|
||||
|
@ -15,6 +15,7 @@ import HuggingFaceLogo from "@/media/llmprovider/huggingface.png";
|
||||
import PerplexityLogo from "@/media/llmprovider/perplexity.png";
|
||||
import OpenRouterLogo from "@/media/llmprovider/openrouter.jpeg";
|
||||
import GroqLogo from "@/media/llmprovider/groq.png";
|
||||
import CohereLogo from "@/media/llmprovider/cohere.png";
|
||||
import ZillizLogo from "@/media/vectordbs/zilliz.png";
|
||||
import AstraDBLogo from "@/media/vectordbs/astraDB.png";
|
||||
import ChromaLogo from "@/media/vectordbs/chroma.png";
|
||||
@ -144,6 +145,13 @@ export const LLM_SELECTION_PRIVACY = {
|
||||
],
|
||||
logo: GenericOpenAiLogo,
|
||||
},
|
||||
cohere: {
|
||||
name: "Cohere",
|
||||
description: [
|
||||
"Data is shared according to the terms of service of cohere.com and your localities privacy laws.",
|
||||
],
|
||||
logo: CohereLogo,
|
||||
},
|
||||
};
|
||||
|
||||
export const VECTOR_DB_PRIVACY = {
|
||||
@ -252,6 +260,13 @@ export const EMBEDDING_ENGINE_PRIVACY = {
|
||||
],
|
||||
logo: LMStudioLogo,
|
||||
},
|
||||
cohere: {
|
||||
name: "Cohere",
|
||||
description: [
|
||||
"Data is shared according to the terms of service of cohere.com and your localities privacy laws.",
|
||||
],
|
||||
logo: CohereLogo,
|
||||
},
|
||||
};
|
||||
|
||||
export default function DataHandling({ setHeader, setForwardBtn, setBackBtn }) {
|
||||
|
@ -15,6 +15,7 @@ import HuggingFaceLogo from "@/media/llmprovider/huggingface.png";
|
||||
import PerplexityLogo from "@/media/llmprovider/perplexity.png";
|
||||
import OpenRouterLogo from "@/media/llmprovider/openrouter.jpeg";
|
||||
import GroqLogo from "@/media/llmprovider/groq.png";
|
||||
import CohereLogo from "@/media/llmprovider/cohere.png";
|
||||
import OpenAiOptions from "@/components/LLMSelection/OpenAiOptions";
|
||||
import GenericOpenAiOptions from "@/components/LLMSelection/GenericOpenAiOptions";
|
||||
import AzureAiOptions from "@/components/LLMSelection/AzureAiOptions";
|
||||
@ -30,6 +31,8 @@ import TogetherAiOptions from "@/components/LLMSelection/TogetherAiOptions";
|
||||
import PerplexityOptions from "@/components/LLMSelection/PerplexityOptions";
|
||||
import OpenRouterOptions from "@/components/LLMSelection/OpenRouterOptions";
|
||||
import GroqAiOptions from "@/components/LLMSelection/GroqAiOptions";
|
||||
import CohereAiOptions from "@/components/LLMSelection/CohereAiOptions";
|
||||
|
||||
import LLMItem from "@/components/LLMSelection/LLMItem";
|
||||
import System from "@/models/system";
|
||||
import paths from "@/utils/paths";
|
||||
@ -136,6 +139,13 @@ const LLMS = [
|
||||
description:
|
||||
"The fastest LLM inferencing available for real-time AI applications.",
|
||||
},
|
||||
{
|
||||
name: "Cohere",
|
||||
value: "cohere",
|
||||
logo: CohereLogo,
|
||||
options: (settings) => <CohereAiOptions settings={settings} />,
|
||||
description: "Run Cohere's powerful Command models.",
|
||||
},
|
||||
{
|
||||
name: "Generic OpenAI",
|
||||
value: "generic-openai",
|
||||
|
@ -69,6 +69,10 @@ JWT_SECRET="my-random-string-for-seeding" # Please generate random string at lea
|
||||
# GENERIC_OPEN_AI_MODEL_TOKEN_LIMIT=4096
|
||||
# GENERIC_OPEN_AI_API_KEY=sk-123abc
|
||||
|
||||
# LLM_PROVIDER='cohere'
|
||||
# COHERE_API_KEY=
|
||||
# COHERE_MODEL_PREF='command-r'
|
||||
|
||||
###########################################
|
||||
######## Embedding API SELECTION ##########
|
||||
###########################################
|
||||
@ -97,6 +101,10 @@ JWT_SECRET="my-random-string-for-seeding" # Please generate random string at lea
|
||||
# EMBEDDING_MODEL_PREF='nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.Q4_0.gguf'
|
||||
# EMBEDDING_MODEL_MAX_CHUNK_LENGTH=8192
|
||||
|
||||
# EMBEDDING_ENGINE='cohere'
|
||||
# COHERE_API_KEY=
|
||||
# EMBEDDING_MODEL_PREF='embed-english-v3.0'
|
||||
|
||||
###########################################
|
||||
######## Vector Database Selection ########
|
||||
###########################################
|
||||
|
@ -364,6 +364,10 @@ const SystemSettings = {
|
||||
GenericOpenAiModelPref: process.env.GENERIC_OPEN_AI_MODEL_PREF,
|
||||
GenericOpenAiTokenLimit: process.env.GENERIC_OPEN_AI_MODEL_TOKEN_LIMIT,
|
||||
GenericOpenAiKey: !!process.env.GENERIC_OPEN_AI_API_KEY,
|
||||
|
||||
// Cohere API Keys
|
||||
CohereApiKey: !!process.env.COHERE_API_KEY,
|
||||
CohereModelPref: process.env.COHERE_MODEL_PREF,
|
||||
};
|
||||
},
|
||||
};
|
||||
|
@ -41,6 +41,7 @@
|
||||
"chalk": "^4",
|
||||
"check-disk-space": "^3.4.0",
|
||||
"chromadb": "^1.5.2",
|
||||
"cohere-ai": "^7.9.5",
|
||||
"cors": "^2.8.5",
|
||||
"dotenv": "^16.0.3",
|
||||
"express": "^4.18.2",
|
||||
|
226
server/utils/AiProviders/cohere/index.js
Normal file
226
server/utils/AiProviders/cohere/index.js
Normal file
@ -0,0 +1,226 @@
|
||||
const { v4 } = require("uuid");
|
||||
const { writeResponseChunk } = require("../../helpers/chat/responses");
|
||||
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
|
||||
|
||||
class CohereLLM {
|
||||
constructor(embedder = null) {
|
||||
const { CohereClient } = require("cohere-ai");
|
||||
if (!process.env.COHERE_API_KEY)
|
||||
throw new Error("No Cohere API key was set.");
|
||||
|
||||
const cohere = new CohereClient({
|
||||
token: process.env.COHERE_API_KEY,
|
||||
});
|
||||
|
||||
this.cohere = cohere;
|
||||
this.model = process.env.COHERE_MODEL_PREF;
|
||||
this.limits = {
|
||||
history: this.promptWindowLimit() * 0.15,
|
||||
system: this.promptWindowLimit() * 0.15,
|
||||
user: this.promptWindowLimit() * 0.7,
|
||||
};
|
||||
this.embedder = !!embedder ? embedder : new NativeEmbedder();
|
||||
}
|
||||
|
||||
#appendContext(contextTexts = []) {
|
||||
if (!contextTexts || !contextTexts.length) return "";
|
||||
return (
|
||||
"\nContext:\n" +
|
||||
contextTexts
|
||||
.map((text, i) => {
|
||||
return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
|
||||
})
|
||||
.join("")
|
||||
);
|
||||
}
|
||||
|
||||
#convertChatHistoryCohere(chatHistory = []) {
|
||||
let cohereHistory = [];
|
||||
chatHistory.forEach((message) => {
|
||||
switch (message.role) {
|
||||
case "system":
|
||||
cohereHistory.push({ role: "SYSTEM", message: message.content });
|
||||
break;
|
||||
case "user":
|
||||
cohereHistory.push({ role: "USER", message: message.content });
|
||||
break;
|
||||
case "assistant":
|
||||
cohereHistory.push({ role: "CHATBOT", message: message.content });
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
return cohereHistory;
|
||||
}
|
||||
|
||||
streamingEnabled() {
|
||||
return "streamGetChatCompletion" in this;
|
||||
}
|
||||
|
||||
promptWindowLimit() {
|
||||
switch (this.model) {
|
||||
case "command-r":
|
||||
return 128_000;
|
||||
case "command-r-plus":
|
||||
return 128_000;
|
||||
case "command":
|
||||
return 4_096;
|
||||
case "command-light":
|
||||
return 4_096;
|
||||
case "command-nightly":
|
||||
return 8_192;
|
||||
case "command-light-nightly":
|
||||
return 8_192;
|
||||
default:
|
||||
return 4_096;
|
||||
}
|
||||
}
|
||||
|
||||
async isValidChatCompletionModel(model = "") {
|
||||
const validModels = [
|
||||
"command-r",
|
||||
"command-r-plus",
|
||||
"command",
|
||||
"command-light",
|
||||
"command-nightly",
|
||||
"command-light-nightly",
|
||||
];
|
||||
return validModels.includes(model);
|
||||
}
|
||||
|
||||
constructPrompt({
|
||||
systemPrompt = "",
|
||||
contextTexts = [],
|
||||
chatHistory = [],
|
||||
userPrompt = "",
|
||||
}) {
|
||||
const prompt = {
|
||||
role: "system",
|
||||
content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
|
||||
};
|
||||
return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
|
||||
}
|
||||
|
||||
async isSafe(_input = "") {
|
||||
// Not implemented so must be stubbed
|
||||
return { safe: true, reasons: [] };
|
||||
}
|
||||
|
||||
async getChatCompletion(messages = null, { temperature = 0.7 }) {
|
||||
if (!(await this.isValidChatCompletionModel(this.model)))
|
||||
throw new Error(
|
||||
`Cohere chat: ${this.model} is not valid for chat completion!`
|
||||
);
|
||||
|
||||
const message = messages[messages.length - 1].content; // Get the last message
|
||||
const cohereHistory = this.#convertChatHistoryCohere(messages.slice(0, -1)); // Remove the last message and convert to Cohere
|
||||
|
||||
const chat = await this.cohere.chat({
|
||||
model: this.model,
|
||||
message: message,
|
||||
chatHistory: cohereHistory,
|
||||
temperature,
|
||||
});
|
||||
|
||||
if (!chat.hasOwnProperty("text")) return null;
|
||||
return chat.text;
|
||||
}
|
||||
|
||||
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
|
||||
if (!(await this.isValidChatCompletionModel(this.model)))
|
||||
throw new Error(
|
||||
`Cohere chat: ${this.model} is not valid for chat completion!`
|
||||
);
|
||||
|
||||
const message = messages[messages.length - 1].content; // Get the last message
|
||||
const cohereHistory = this.#convertChatHistoryCohere(messages.slice(0, -1)); // Remove the last message and convert to Cohere
|
||||
|
||||
const stream = await this.cohere.chatStream({
|
||||
model: this.model,
|
||||
message: message,
|
||||
chatHistory: cohereHistory,
|
||||
temperature,
|
||||
});
|
||||
|
||||
return { type: "stream", stream: stream };
|
||||
}
|
||||
|
||||
async handleStream(response, stream, responseProps) {
|
||||
return new Promise(async (resolve) => {
|
||||
let fullText = "";
|
||||
const { uuid = v4(), sources = [] } = responseProps;
|
||||
|
||||
const handleAbort = () => {
|
||||
writeResponseChunk(response, {
|
||||
uuid,
|
||||
sources,
|
||||
type: "abort",
|
||||
textResponse: fullText,
|
||||
close: true,
|
||||
error: false,
|
||||
});
|
||||
response.removeListener("close", handleAbort);
|
||||
resolve(fullText);
|
||||
};
|
||||
response.on("close", handleAbort);
|
||||
|
||||
try {
|
||||
for await (const chat of stream.stream) {
|
||||
if (chat.eventType === "text-generation") {
|
||||
const text = chat.text;
|
||||
fullText += text;
|
||||
|
||||
writeResponseChunk(response, {
|
||||
uuid,
|
||||
sources,
|
||||
type: "textResponseChunk",
|
||||
textResponse: text,
|
||||
close: false,
|
||||
error: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
writeResponseChunk(response, {
|
||||
uuid,
|
||||
sources,
|
||||
type: "textResponseChunk",
|
||||
textResponse: "",
|
||||
close: true,
|
||||
error: false,
|
||||
});
|
||||
response.removeListener("close", handleAbort);
|
||||
resolve(fullText);
|
||||
} catch (error) {
|
||||
writeResponseChunk(response, {
|
||||
uuid,
|
||||
sources,
|
||||
type: "abort",
|
||||
textResponse: null,
|
||||
close: true,
|
||||
error: error.message,
|
||||
});
|
||||
response.removeListener("close", handleAbort);
|
||||
resolve(fullText);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
|
||||
async embedTextInput(textInput) {
|
||||
return await this.embedder.embedTextInput(textInput);
|
||||
}
|
||||
async embedChunks(textChunks = []) {
|
||||
return await this.embedder.embedChunks(textChunks);
|
||||
}
|
||||
|
||||
async compressMessages(promptArgs = {}, rawHistory = []) {
|
||||
const { messageArrayCompressor } = require("../../helpers/chat");
|
||||
const messageArray = this.constructPrompt(promptArgs);
|
||||
return await messageArrayCompressor(this, messageArray, rawHistory);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
CohereLLM,
|
||||
};
|
86
server/utils/EmbeddingEngines/cohere/index.js
Normal file
86
server/utils/EmbeddingEngines/cohere/index.js
Normal file
@ -0,0 +1,86 @@
|
||||
const { toChunks } = require("../../helpers");
|
||||
|
||||
class CohereEmbedder {
|
||||
constructor() {
|
||||
if (!process.env.COHERE_API_KEY)
|
||||
throw new Error("No Cohere API key was set.");
|
||||
|
||||
const { CohereClient } = require("cohere-ai");
|
||||
const cohere = new CohereClient({
|
||||
token: process.env.COHERE_API_KEY,
|
||||
});
|
||||
|
||||
this.cohere = cohere;
|
||||
this.model = process.env.EMBEDDING_MODEL_PREF || "embed-english-v3.0";
|
||||
this.inputType = "search_document";
|
||||
|
||||
// Limit of how many strings we can process in a single pass to stay with resource or network limits
|
||||
this.maxConcurrentChunks = 96; // Cohere's limit per request is 96
|
||||
this.embeddingMaxChunkLength = 1945; // https://docs.cohere.com/docs/embed-2 - assume a token is roughly 4 letters with some padding
|
||||
}
|
||||
|
||||
async embedTextInput(textInput) {
|
||||
this.inputType = "search_query";
|
||||
const result = await this.embedChunks([textInput]);
|
||||
return result?.[0] || [];
|
||||
}
|
||||
|
||||
async embedChunks(textChunks = []) {
|
||||
const embeddingRequests = [];
|
||||
this.inputType = "search_document";
|
||||
|
||||
for (const chunk of toChunks(textChunks, this.maxConcurrentChunks)) {
|
||||
embeddingRequests.push(
|
||||
new Promise((resolve) => {
|
||||
this.cohere
|
||||
.embed({
|
||||
texts: chunk,
|
||||
model: this.model,
|
||||
inputType: this.inputType,
|
||||
})
|
||||
.then((res) => {
|
||||
resolve({ data: res.embeddings, error: null });
|
||||
})
|
||||
.catch((e) => {
|
||||
e.type =
|
||||
e?.response?.data?.error?.code ||
|
||||
e?.response?.status ||
|
||||
"failed_to_embed";
|
||||
e.message = e?.response?.data?.error?.message || e.message;
|
||||
resolve({ data: [], error: e });
|
||||
});
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
const { data = [], error = null } = await Promise.all(
|
||||
embeddingRequests
|
||||
).then((results) => {
|
||||
const errors = results
|
||||
.filter((res) => !!res.error)
|
||||
.map((res) => res.error)
|
||||
.flat();
|
||||
|
||||
if (errors.length > 0) {
|
||||
let uniqueErrors = new Set();
|
||||
errors.map((error) =>
|
||||
uniqueErrors.add(`[${error.type}]: ${error.message}`)
|
||||
);
|
||||
return { data: [], error: Array.from(uniqueErrors).join(", ") };
|
||||
}
|
||||
|
||||
return {
|
||||
data: results.map((res) => res?.data || []).flat(),
|
||||
error: null,
|
||||
};
|
||||
});
|
||||
|
||||
if (!!error) throw new Error(`Cohere Failed to embed: ${error}`);
|
||||
|
||||
return data.length > 0 ? data : null;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
CohereEmbedder,
|
||||
};
|
@ -77,6 +77,9 @@ function getLLMProvider({ provider = null, model = null } = {}) {
|
||||
case "groq":
|
||||
const { GroqLLM } = require("../AiProviders/groq");
|
||||
return new GroqLLM(embedder, model);
|
||||
case "cohere":
|
||||
const { CohereLLM } = require("../AiProviders/cohere");
|
||||
return new CohereLLM(embedder, model);
|
||||
case "generic-openai":
|
||||
const { GenericOpenAiLLM } = require("../AiProviders/genericOpenAi");
|
||||
return new GenericOpenAiLLM(embedder, model);
|
||||
@ -110,6 +113,9 @@ function getEmbeddingEngineSelection() {
|
||||
case "lmstudio":
|
||||
const { LMStudioEmbedder } = require("../EmbeddingEngines/lmstudio");
|
||||
return new LMStudioEmbedder();
|
||||
case "cohere":
|
||||
const { CohereEmbedder } = require("../EmbeddingEngines/cohere");
|
||||
return new CohereEmbedder();
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
|
@ -290,6 +290,16 @@ const KEY_MAPPING = {
|
||||
checks: [isNotEmpty],
|
||||
},
|
||||
|
||||
// Cohere Options
|
||||
CohereApiKey: {
|
||||
envKey: "COHERE_API_KEY",
|
||||
checks: [isNotEmpty],
|
||||
},
|
||||
CohereModelPref: {
|
||||
envKey: "COHERE_MODEL_PREF",
|
||||
checks: [isNotEmpty],
|
||||
},
|
||||
|
||||
// Whisper (transcription) providers
|
||||
WhisperProvider: {
|
||||
envKey: "WHISPER_PROVIDER",
|
||||
@ -393,6 +403,7 @@ function supportedLLM(input = "") {
|
||||
"perplexity",
|
||||
"openrouter",
|
||||
"groq",
|
||||
"cohere",
|
||||
"generic-openai",
|
||||
].includes(input);
|
||||
return validSelection ? null : `${input} is not a valid LLM provider.`;
|
||||
@ -434,6 +445,7 @@ function supportedEmbeddingModel(input = "") {
|
||||
"native",
|
||||
"ollama",
|
||||
"lmstudio",
|
||||
"cohere",
|
||||
];
|
||||
return supported.includes(input)
|
||||
? null
|
||||
|
@ -1817,6 +1817,17 @@ cmake-js@^7.2.1:
|
||||
which "^2.0.2"
|
||||
yargs "^17.7.2"
|
||||
|
||||
cohere-ai@^7.9.5:
|
||||
version "7.9.5"
|
||||
resolved "https://registry.yarnpkg.com/cohere-ai/-/cohere-ai-7.9.5.tgz#05a592fe19decb8692d1b19d93ac835d7f816b8b"
|
||||
integrity sha512-tr8LUR3Q46agFpfEwaYwzYO4qAuN0/R/8YroG4bc86LadOacBAabctZUq0zfCdLiL7gB4yWJs4QCzfpRH3rQuw==
|
||||
dependencies:
|
||||
form-data "4.0.0"
|
||||
js-base64 "3.7.2"
|
||||
node-fetch "2.7.0"
|
||||
qs "6.11.2"
|
||||
url-join "4.0.1"
|
||||
|
||||
color-convert@^1.9.3:
|
||||
version "1.9.3"
|
||||
resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-1.9.3.tgz#bb71850690e1f136567de629d2d5471deda4c1e8"
|
||||
@ -2846,19 +2857,19 @@ form-data-encoder@1.7.2:
|
||||
resolved "https://registry.yarnpkg.com/form-data-encoder/-/form-data-encoder-1.7.2.tgz#1f1ae3dccf58ed4690b86d87e4f57c654fbab040"
|
||||
integrity sha512-qfqtYan3rxrnCk1VYaA4H+Ms9xdpPqvLZa6xmMgFvhO32x7/3J/ExcTd6qpxM0vH2GdMI+poehyBZvqfMTto8A==
|
||||
|
||||
form-data@^3.0.0:
|
||||
version "3.0.1"
|
||||
resolved "https://registry.yarnpkg.com/form-data/-/form-data-3.0.1.tgz#ebd53791b78356a99af9a300d4282c4d5eb9755f"
|
||||
integrity sha512-RHkBKtLWUVwd7SqRIvCZMEvAMoGUp0XU+seQiZejj0COz3RI3hWP4sCv3gZWWLjJTd7rGwcsF5eKZGii0r/hbg==
|
||||
form-data@4.0.0, form-data@^4.0.0:
|
||||
version "4.0.0"
|
||||
resolved "https://registry.yarnpkg.com/form-data/-/form-data-4.0.0.tgz#93919daeaf361ee529584b9b31664dc12c9fa452"
|
||||
integrity sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==
|
||||
dependencies:
|
||||
asynckit "^0.4.0"
|
||||
combined-stream "^1.0.8"
|
||||
mime-types "^2.1.12"
|
||||
|
||||
form-data@^4.0.0:
|
||||
version "4.0.0"
|
||||
resolved "https://registry.yarnpkg.com/form-data/-/form-data-4.0.0.tgz#93919daeaf361ee529584b9b31664dc12c9fa452"
|
||||
integrity sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==
|
||||
form-data@^3.0.0:
|
||||
version "3.0.1"
|
||||
resolved "https://registry.yarnpkg.com/form-data/-/form-data-3.0.1.tgz#ebd53791b78356a99af9a300d4282c4d5eb9755f"
|
||||
integrity sha512-RHkBKtLWUVwd7SqRIvCZMEvAMoGUp0XU+seQiZejj0COz3RI3hWP4sCv3gZWWLjJTd7rGwcsF5eKZGii0r/hbg==
|
||||
dependencies:
|
||||
asynckit "^0.4.0"
|
||||
combined-stream "^1.0.8"
|
||||
@ -3652,6 +3663,11 @@ joi@^17.11.0:
|
||||
"@sideway/formula" "^3.0.1"
|
||||
"@sideway/pinpoint" "^2.0.0"
|
||||
|
||||
js-base64@3.7.2:
|
||||
version "3.7.2"
|
||||
resolved "https://registry.yarnpkg.com/js-base64/-/js-base64-3.7.2.tgz#816d11d81a8aff241603d19ce5761e13e41d7745"
|
||||
integrity sha512-NnRs6dsyqUXejqk/yv2aiXlAvOs56sLkX6nUdeaNezI5LFFLlsZjOThmwnrcwh5ZZRwZlCMnVAY3CvhIhoVEKQ==
|
||||
|
||||
js-tiktoken@^1.0.11, js-tiktoken@^1.0.7, js-tiktoken@^1.0.8:
|
||||
version "1.0.11"
|
||||
resolved "https://registry.yarnpkg.com/js-tiktoken/-/js-tiktoken-1.0.11.tgz#d7d707b849f703841112660d9d55169424a35344"
|
||||
@ -4324,7 +4340,7 @@ node-domexception@1.0.0:
|
||||
resolved "https://registry.yarnpkg.com/node-domexception/-/node-domexception-1.0.0.tgz#6888db46a1f71c0b76b3f7555016b63fe64766e5"
|
||||
integrity sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==
|
||||
|
||||
node-fetch@^2.6.1, node-fetch@^2.6.12, node-fetch@^2.6.7, node-fetch@^2.6.9:
|
||||
node-fetch@2.7.0, node-fetch@^2.6.1, node-fetch@^2.6.12, node-fetch@^2.6.7, node-fetch@^2.6.9:
|
||||
version "2.7.0"
|
||||
resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.7.0.tgz#d0f0fa6e3e2dc1d27efcd8ad99d550bda94d187d"
|
||||
integrity sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==
|
||||
@ -4947,6 +4963,13 @@ qs@6.11.0:
|
||||
dependencies:
|
||||
side-channel "^1.0.4"
|
||||
|
||||
qs@6.11.2:
|
||||
version "6.11.2"
|
||||
resolved "https://registry.yarnpkg.com/qs/-/qs-6.11.2.tgz#64bea51f12c1f5da1bc01496f48ffcff7c69d7d9"
|
||||
integrity sha512-tDNIz22aBzCDxLtVH++VnTfzxlfeK5CbqohpSqpJgj1Wg/cQbStNAz3NuqCs5vV+pjBsK4x4pN9HlVh7rcYRiA==
|
||||
dependencies:
|
||||
side-channel "^1.0.4"
|
||||
|
||||
qs@^6.7.0:
|
||||
version "6.12.1"
|
||||
resolved "https://registry.yarnpkg.com/qs/-/qs-6.12.1.tgz#39422111ca7cbdb70425541cba20c7d7b216599a"
|
||||
@ -5862,7 +5885,7 @@ uri-js@^4.2.2, uri-js@^4.4.1:
|
||||
dependencies:
|
||||
punycode "^2.1.0"
|
||||
|
||||
url-join@^4.0.1:
|
||||
url-join@4.0.1, url-join@^4.0.1:
|
||||
version "4.0.1"
|
||||
resolved "https://registry.yarnpkg.com/url-join/-/url-join-4.0.1.tgz#b642e21a2646808ffa178c4c5fda39844e12cde7"
|
||||
integrity sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==
|
||||
|
Loading…
Reference in New Issue
Block a user