Do not go through LLM to embed when embedding documents (#1428)

This commit is contained in:
Timothy Carambat 2024-05-16 17:51:04 -07:00 committed by GitHub
parent 01cf2fed17
commit cae6cee1b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 34 additions and 66 deletions

View File

@ -100,6 +100,7 @@ function getLLMProvider({ provider = null, model = null } = {}) {
} }
function getEmbeddingEngineSelection() { function getEmbeddingEngineSelection() {
const { NativeEmbedder } = require("../EmbeddingEngines/native");
const engineSelection = process.env.EMBEDDING_ENGINE; const engineSelection = process.env.EMBEDDING_ENGINE;
switch (engineSelection) { switch (engineSelection) {
case "openai": case "openai":
@ -117,7 +118,6 @@ function getEmbeddingEngineSelection() {
const { OllamaEmbedder } = require("../EmbeddingEngines/ollama"); const { OllamaEmbedder } = require("../EmbeddingEngines/ollama");
return new OllamaEmbedder(); return new OllamaEmbedder();
case "native": case "native":
const { NativeEmbedder } = require("../EmbeddingEngines/native");
return new NativeEmbedder(); return new NativeEmbedder();
case "lmstudio": case "lmstudio":
const { LMStudioEmbedder } = require("../EmbeddingEngines/lmstudio"); const { LMStudioEmbedder } = require("../EmbeddingEngines/lmstudio");
@ -126,7 +126,7 @@ function getEmbeddingEngineSelection() {
const { CohereEmbedder } = require("../EmbeddingEngines/cohere"); const { CohereEmbedder } = require("../EmbeddingEngines/cohere");
return new CohereEmbedder(); return new CohereEmbedder();
default: default:
return null; return new NativeEmbedder();
} }
} }

View File

@ -3,11 +3,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings"); const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const AstraDB = { const AstraDB = {
@ -149,12 +145,13 @@ const AstraDB = {
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} }
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({ const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize( chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({ await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size", label: "text_splitter_chunk_size",
}), }),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength EmbedderEngine?.embeddingMaxChunkLength
), ),
chunkOverlap: await SystemSettings.getValueOrFallback( chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" }, { label: "text_splitter_chunk_overlap" },
@ -164,10 +161,9 @@ const AstraDB = {
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Chunks created from document:", textChunks.length); console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
if (!!vectorValues && vectorValues.length > 0) { if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) { for (const [i, vector] of vectorValues.entries()) {

View File

@ -3,11 +3,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings"); const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { parseAuthHeader } = require("../../http"); const { parseAuthHeader } = require("../../http");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
@ -192,12 +188,13 @@ const Chroma = {
// We have to do this manually as opposed to using LangChains `Chroma.fromDocuments` // We have to do this manually as opposed to using LangChains `Chroma.fromDocuments`
// because we then cannot atomically control our namespace to granularly find/remove documents // because we then cannot atomically control our namespace to granularly find/remove documents
// from vectordb. // from vectordb.
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({ const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize( chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({ await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size", label: "text_splitter_chunk_size",
}), }),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength EmbedderEngine?.embeddingMaxChunkLength
), ),
chunkOverlap: await SystemSettings.getValueOrFallback( chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" }, { label: "text_splitter_chunk_overlap" },
@ -207,10 +204,9 @@ const Chroma = {
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Chunks created from document:", textChunks.length); console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
const submission = { const submission = {
ids: [], ids: [],
embeddings: [], embeddings: [],

View File

@ -1,9 +1,5 @@
const lancedb = require("vectordb"); const lancedb = require("vectordb");
const { const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { TextSplitter } = require("../../TextSplitter"); const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings"); const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { storeVectorResult, cachedVectorInformation } = require("../../files");
@ -190,12 +186,13 @@ const LanceDb = {
// We have to do this manually as opposed to using LangChains `xyz.fromDocuments` // We have to do this manually as opposed to using LangChains `xyz.fromDocuments`
// because we then cannot atomically control our namespace to granularly find/remove documents // because we then cannot atomically control our namespace to granularly find/remove documents
// from vectordb. // from vectordb.
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({ const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize( chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({ await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size", label: "text_splitter_chunk_size",
}), }),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength EmbedderEngine?.embeddingMaxChunkLength
), ),
chunkOverlap: await SystemSettings.getValueOrFallback( chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" }, { label: "text_splitter_chunk_overlap" },
@ -205,11 +202,10 @@ const LanceDb = {
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Chunks created from document:", textChunks.length); console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const submissions = []; const submissions = [];
const vectorValues = await LLMConnector.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
if (!!vectorValues && vectorValues.length > 0) { if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) { for (const [i, vector] of vectorValues.entries()) {

View File

@ -8,11 +8,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings"); const { SystemSettings } = require("../../../models/systemSettings");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const Milvus = { const Milvus = {
@ -184,12 +180,13 @@ const Milvus = {
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} }
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({ const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize( chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({ await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size", label: "text_splitter_chunk_size",
}), }),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength EmbedderEngine?.embeddingMaxChunkLength
), ),
chunkOverlap: await SystemSettings.getValueOrFallback( chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" }, { label: "text_splitter_chunk_overlap" },
@ -199,10 +196,9 @@ const Milvus = {
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Chunks created from document:", textChunks.length); console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
if (!!vectorValues && vectorValues.length > 0) { if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) { for (const [i, vector] of vectorValues.entries()) {

View File

@ -3,11 +3,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings"); const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const PineconeDB = { const PineconeDB = {
@ -135,12 +131,13 @@ const PineconeDB = {
// because we then cannot atomically control our namespace to granularly find/remove documents // because we then cannot atomically control our namespace to granularly find/remove documents
// from vectordb. // from vectordb.
// https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L167 // https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L167
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({ const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize( chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({ await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size", label: "text_splitter_chunk_size",
}), }),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength EmbedderEngine?.embeddingMaxChunkLength
), ),
chunkOverlap: await SystemSettings.getValueOrFallback( chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" }, { label: "text_splitter_chunk_overlap" },
@ -150,10 +147,9 @@ const PineconeDB = {
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Chunks created from document:", textChunks.length); console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
if (!!vectorValues && vectorValues.length > 0) { if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) { for (const [i, vector] of vectorValues.entries()) {

View File

@ -3,11 +3,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings"); const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const QDrant = { const QDrant = {
@ -209,12 +205,13 @@ const QDrant = {
// We have to do this manually as opposed to using LangChains `Qdrant.fromDocuments` // We have to do this manually as opposed to using LangChains `Qdrant.fromDocuments`
// because we then cannot atomically control our namespace to granularly find/remove documents // because we then cannot atomically control our namespace to granularly find/remove documents
// from vectordb. // from vectordb.
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({ const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize( chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({ await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size", label: "text_splitter_chunk_size",
}), }),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength EmbedderEngine?.embeddingMaxChunkLength
), ),
chunkOverlap: await SystemSettings.getValueOrFallback( chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" }, { label: "text_splitter_chunk_overlap" },
@ -224,10 +221,9 @@ const QDrant = {
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Chunks created from document:", textChunks.length); console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
const submission = { const submission = {
ids: [], ids: [],
vectors: [], vectors: [],

View File

@ -3,11 +3,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings"); const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { camelCase } = require("../../helpers/camelcase"); const { camelCase } = require("../../helpers/camelcase");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
@ -251,12 +247,13 @@ const Weaviate = {
// We have to do this manually as opposed to using LangChains `Chroma.fromDocuments` // We have to do this manually as opposed to using LangChains `Chroma.fromDocuments`
// because we then cannot atomically control our namespace to granularly find/remove documents // because we then cannot atomically control our namespace to granularly find/remove documents
// from vectordb. // from vectordb.
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({ const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize( chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({ await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size", label: "text_splitter_chunk_size",
}), }),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength EmbedderEngine?.embeddingMaxChunkLength
), ),
chunkOverlap: await SystemSettings.getValueOrFallback( chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" }, { label: "text_splitter_chunk_overlap" },
@ -266,10 +263,9 @@ const Weaviate = {
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Chunks created from document:", textChunks.length); console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
const submission = { const submission = {
ids: [], ids: [],
vectors: [], vectors: [],

View File

@ -8,11 +8,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings"); const { SystemSettings } = require("../../../models/systemSettings");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
// Zilliz is basically a copy of Milvus DB class with a different constructor // Zilliz is basically a copy of Milvus DB class with a different constructor
@ -185,12 +181,13 @@ const Zilliz = {
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} }
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({ const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize( chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({ await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size", label: "text_splitter_chunk_size",
}), }),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength EmbedderEngine?.embeddingMaxChunkLength
), ),
chunkOverlap: await SystemSettings.getValueOrFallback( chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" }, { label: "text_splitter_chunk_overlap" },
@ -200,10 +197,9 @@ const Zilliz = {
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Chunks created from document:", textChunks.length); console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
if (!!vectorValues && vectorValues.length > 0) { if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) { for (const [i, vector] of vectorValues.entries()) {