[Fork] Batch embed by jwaltz (#153)

* refactor: convert chunk embedding to one API call

* chore: lint

* fix chroma for batch and single vectorization of text

* Fix LanceDB multi and single vectorization

* Fix pinecone for single and multiple embeddings

---------

Co-authored-by: Jonathan Waltz <volcanicislander@gmail.com>
This commit is contained in:
Timothy Carambat 2023-07-20 12:05:23 -07:00 committed by GitHub
parent 5a7d8add6f
commit c1deca4928
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 84 additions and 61 deletions

View File

@@ -26,14 +26,24 @@ function curateSources(sources = []) {
const knownDocs = [];
const documents = [];
// Sometimes the source may or may not have a metadata property
// in the response so we search for it explicitly or just spread the entire
// source and check to see if at least title exists.
for (const source of sources) {
const { metadata = {} } = source;
if (
Object.keys(metadata).length > 0 &&
!knownDocs.includes(metadata.title)
) {
documents.push({ ...metadata });
knownDocs.push(metadata.title);
if (source.hasOwnProperty("metadata")) {
const { metadata = {} } = source;
if (
Object.keys(metadata).length > 0 &&
!knownDocs.includes(metadata.title)
) {
documents.push({ ...metadata });
knownDocs.push(metadata.title);
}
} else {
if (Object.keys(source).length > 0 && !knownDocs.includes(source.title)) {
documents.push({ ...source });
knownDocs.push(source.title);
}
}
}

View File

@@ -80,15 +80,20 @@ const Chroma = {
temperature,
});
},
embedChunk: async function (openai, textChunk) {
embedTextInput: async function (openai, textInput) {
const result = await this.embedChunks(openai, textInput);
return result?.[0] || [];
},
embedChunks: async function (openai, chunks = []) {
const {
data: { data },
} = await openai.createEmbedding({
model: "text-embedding-ada-002",
input: textChunk,
input: chunks,
});
return data.length > 0 && data[0].hasOwnProperty("embedding")
? data[0].embedding
return data.length > 0 &&
data.every((embd) => embd.hasOwnProperty("embedding"))
? data.map((embd) => embd.embedding)
: null;
},
similarityResponse: async function (client, namespace, queryVector) {
@@ -205,7 +210,7 @@ const Chroma = {
const documentVectors = [];
const vectors = [];
const openai = this.openai();
const vectorValues = await this.embedChunks(openai, textChunks);
const submission = {
ids: [],
embeddings: [],
@@ -213,31 +218,29 @@ const Chroma = {
documents: [],
};
for (const textChunk of textChunks) {
const vectorValues = await this.embedChunk(openai, textChunk);
if (!!vectorValues) {
if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) {
const vectorRecord = {
id: uuidv4(),
values: vectorValues,
values: vector,
// [DO NOT REMOVE]
// LangChain will be unable to find your text if you embed manually and dont include the `text` key.
// https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L64
metadata: { ...metadata, text: textChunk },
metadata: { ...metadata, text: textChunks[i] },
};
submission.ids.push(vectorRecord.id);
submission.embeddings.push(vectorRecord.values);
submission.metadatas.push(metadata);
submission.documents.push(textChunk);
submission.documents.push(textChunks[i]);
vectors.push(vectorRecord);
documentVectors.push({ docId, vectorId: vectorRecord.id });
} else {
console.error(
"Could not use OpenAI to embed document chunk! This document will not be recorded."
);
}
} else {
console.error(
"Could not use OpenAI to embed document chunks! This document will not be recorded."
);
}
const { client } = await this.connect();
@@ -340,7 +343,7 @@ const Chroma = {
};
}
const queryVector = await this.embedChunk(this.openai(), input);
const queryVector = await this.embedTextInput(this.openai(), input);
const { contextTexts, sourceDocuments } = await this.similarityResponse(
client,
namespace,

View File

@@ -51,6 +51,22 @@ const LanceDb = {
process.env.OPEN_AI_KEY
);
},
// Embeds a single text input by delegating to the batch embedChunks call
// and unwrapping the first (and only) resulting vector.
// @param {OpenAIApi} openai - configured openai-node v3 client
// @param {string} textInput - the text to embed
// @returns {Promise<number[]>} the embedding vector, or [] when embedding
//   fails (embedChunks returns null, so `result?.[0]` is undefined).
//   NOTE(review): failures yield an empty vector rather than an error —
//   confirm downstream similarity queries handle [] sanely.
embedTextInput: async function (openai, textInput) {
const result = await this.embedChunks(openai, textInput);
return result?.[0] || [];
},
// Batch-embeds an array of text chunks in a single OpenAI API call
// (the point of this commit: one request instead of one per chunk).
// @param {OpenAIApi} openai - configured openai-node v3 client
// @param {string[]|string} chunks - texts to embed; the API also accepts
//   a single string (as passed by embedTextInput above)
// @returns {Promise<number[][]|null>} one embedding vector per input, in
//   order, or null if the response is empty or any item lacks an
//   `embedding` key — all-or-nothing semantics.
embedChunks: async function (openai, chunks = []) {
// openai-node v3 wraps the API payload in an axios response, hence the
// nested destructure: response.data (axios) -> .data (embeddings list).
const {
data: { data },
} = await openai.createEmbedding({
model: "text-embedding-ada-002",
input: chunks,
});
// Only return the vectors when every item carries an embedding;
// otherwise signal total failure with null.
return data.length > 0 &&
data.every((embd) => embd.hasOwnProperty("embedding"))
? data.map((embd) => embd.embedding)
: null;
},
embedder: function () {
return new OpenAIEmbeddings({ openAIApiKey: process.env.OPEN_AI_KEY });
},
@@ -59,17 +75,6 @@ const LanceDb = {
const openai = new OpenAIApi(config);
return openai;
},
// [Removed by this commit] Legacy single-chunk embedder: one API call per
// text chunk, superseded by the batched embedChunks above.
// @param {OpenAIApi} openai - configured openai-node v3 client
// @param {string} textChunk - the text to embed
// @returns {Promise<number[]|null>} the embedding vector of the first
//   response item, or null when the response is empty or the item has no
//   `embedding` key.
embedChunk: async function (openai, textChunk) {
// axios response wrapper -> API payload `data` array of embedding objects.
const {
data: { data },
} = await openai.createEmbedding({
model: "text-embedding-ada-002",
input: textChunk,
});
return data.length > 0 && data[0].hasOwnProperty("embedding")
? data[0].embedding
: null;
},
getChatCompletion: async function (
openai,
messages = [],
@@ -194,18 +199,17 @@ const LanceDb = {
const vectors = [];
const submissions = [];
const openai = this.openai();
const vectorValues = await this.embedChunks(openai, textChunks);
for (const textChunk of textChunks) {
const vectorValues = await this.embedChunk(openai, textChunk);
if (!!vectorValues) {
if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) {
const vectorRecord = {
id: uuidv4(),
values: vectorValues,
values: vector,
// [DO NOT REMOVE]
// LangChain will be unable to find your text if you embed manually and dont include the `text` key.
// https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L64
metadata: { ...metadata, text: textChunk },
metadata: { ...metadata, text: textChunks[i] },
};
vectors.push(vectorRecord);
@@ -215,11 +219,11 @@ const LanceDb = {
...vectorRecord.metadata,
});
documentVectors.push({ docId, vectorId: vectorRecord.id });
} else {
console.error(
"Could not use OpenAI to embed document chunk! This document will not be recorded."
);
}
} else {
console.error(
"Could not use OpenAI to embed document chunks! This document will not be recorded."
);
}
if (vectors.length > 0) {
@@ -253,7 +257,7 @@ const LanceDb = {
}
// LanceDB does not have langchainJS support so we roll our own here.
const queryVector = await this.embedChunk(this.openai(), input);
const queryVector = await this.embedTextInput(this.openai(), input);
const { contextTexts, sourceDocuments } = await this.similarityResponse(
client,
namespace,
@@ -302,7 +306,7 @@ const LanceDb = {
};
}
const queryVector = await this.embedChunk(this.openai(), input);
const queryVector = await this.embedTextInput(this.openai(), input);
const { contextTexts, sourceDocuments } = await this.similarityResponse(
client,
namespace,

View File

@@ -54,15 +54,20 @@ const Pinecone = {
if (!data.hasOwnProperty("choices")) return null;
return data.choices[0].message.content;
},
embedChunk: async function (openai, textChunk) {
embedTextInput: async function (openai, textInput) {
const result = await this.embedChunks(openai, textInput);
return result?.[0] || [];
},
embedChunks: async function (openai, chunks = []) {
const {
data: { data },
} = await openai.createEmbedding({
model: "text-embedding-ada-002",
input: textChunk,
input: chunks,
});
return data.length > 0 && data[0].hasOwnProperty("embedding")
? data[0].embedding
return data.length > 0 &&
data.every((embd) => embd.hasOwnProperty("embedding"))
? data.map((embd) => embd.embedding)
: null;
},
llm: function ({ temperature = 0.7 }) {
@@ -175,25 +180,26 @@ const Pinecone = {
const documentVectors = [];
const vectors = [];
const openai = this.openai();
for (const textChunk of textChunks) {
const vectorValues = await this.embedChunk(openai, textChunk);
const vectorValues = await this.embedChunks(openai, textChunks);
if (!!vectorValues) {
if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) {
const vectorRecord = {
id: uuidv4(),
values: vectorValues,
values: vector,
// [DO NOT REMOVE]
// LangChain will be unable to find your text if you embed manually and dont include the `text` key.
// https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L64
metadata: { ...metadata, text: textChunk },
metadata: { ...metadata, text: textChunks[i] },
};
vectors.push(vectorRecord);
documentVectors.push({ docId, vectorId: vectorRecord.id });
} else {
console.error(
"Could not use OpenAI to embed document chunk! This document will not be recorded."
);
}
} else {
console.error(
"Could not use OpenAI to embed document chunks! This document will not be recorded."
);
}
if (vectors.length > 0) {
@@ -311,7 +317,7 @@ const Pinecone = {
"Invalid namespace - has it been collected and seeded yet?"
);
const queryVector = await this.embedChunk(this.openai(), input);
const queryVector = await this.embedTextInput(this.openai(), input);
const { contextTexts, sourceDocuments } = await this.similarityResponse(
pineconeIndex,
namespace,