From 6f52a2b72970d02eb76ddbc8951ec5fcbfc0447a Mon Sep 17 00:00:00 2001 From: Timothy Carambat Date: Sat, 6 Apr 2024 11:49:15 -0700 Subject: [PATCH] Embedder download - fallback URL (#1056) * Embedder download - fallback URL * improve logging for native embedder --- server/utils/EmbeddingEngines/native/index.js | 101 +++++++++++++----- server/utils/helpers/index.js | 1 - 2 files changed, 73 insertions(+), 29 deletions(-) diff --git a/server/utils/EmbeddingEngines/native/index.js b/server/utils/EmbeddingEngines/native/index.js index fc933e1b..04b754e0 100644 --- a/server/utils/EmbeddingEngines/native/index.js +++ b/server/utils/EmbeddingEngines/native/index.js @@ -4,6 +4,12 @@ const { toChunks } = require("../../helpers"); const { v4 } = require("uuid"); class NativeEmbedder { + // This is a folder that Mintplex Labs hosts for those who cannot capture the HF model download + // endpoint for various reasons. This endpoint is not guaranteed to be active or maintained + // and may go offline at any time at Mintplex Labs's discretion. + #fallbackHost = + "https://s3.us-west-1.amazonaws.com/public.useanything.com/support/models/"; + constructor() { // Model Card: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 this.model = "Xenova/all-MiniLM-L6-v2"; @@ -13,6 +19,7 @@ class NativeEmbedder { : path.resolve(__dirname, `../../../storage/models`) ); this.modelPath = path.resolve(this.cacheDir, "Xenova", "all-MiniLM-L6-v2"); + this.modelDownloaded = fs.existsSync(this.modelPath); // Limit of how many strings we can process in a single pass to stay with resource or network limits this.maxConcurrentChunks = 25; @@ -20,6 +27,11 @@ class NativeEmbedder { // Make directory when it does not exist in existing installations if (!fs.existsSync(this.cacheDir)) fs.mkdirSync(this.cacheDir); + this.log("Initialized"); + } + + log(text, ...args) { + console.log(`\x1b[36m[NativeEmbedder]\x1b[0m ${text}`, ...args); } #tempfilePath() { @@ -39,41 +51,73 @@ class NativeEmbedder { } } - async embedderClient() { - if (!fs.existsSync(this.modelPath)) { - console.log( - "\x1b[34m[INFO]\x1b[0m The native embedding model has never been run and will be downloaded right now. Subsequent runs will be faster. (~23MB)\n\n" - ); - } - + async #fetchWithHost(hostOverride = null) { try { // Convert ESM to CommonJS via import so we can load this library. const pipeline = (...args) => - import("@xenova/transformers").then(({ pipeline }) => - pipeline(...args) - ); - return await pipeline("feature-extraction", this.model, { - cache_dir: this.cacheDir, - ...(!fs.existsSync(this.modelPath) - ? { - // Show download progress if we need to download any files - progress_callback: (data) => { - if (!data.hasOwnProperty("progress")) return; - console.log( - `\x1b[34m[Embedding - Downloading Model Files]\x1b[0m ${ - data.file - } ${~~data?.progress}%` - ); - }, + import("@xenova/transformers").then(({ pipeline, env }) => { + if (!this.modelDownloaded) { + // if model is not downloaded, we will log where we are fetching from. + if (hostOverride) { + env.remoteHost = hostOverride; + env.remotePathTemplate = "{model}/"; // Our S3 fallback url does not support revision File structure. } - : {}), - }); + this.log(`Downloading ${this.model} from ${env.remoteHost}`); + } + return pipeline(...args); + }); + return { + pipeline: await pipeline("feature-extraction", this.model, { + cache_dir: this.cacheDir, + ...(!this.modelDownloaded + ? { + // Show download progress if we need to download any files + progress_callback: (data) => { + if (!data.hasOwnProperty("progress")) return; + console.log( + `\x1b[36m[NativeEmbedder - Downloading model]\x1b[0m ${ + data.file + } ${~~data?.progress}%` + ); + }, + } + : {}), + }), + retry: false, + error: null, + }; } catch (error) { - console.error("Failed to load the native embedding model:", error); - throw error; + return { + pipeline: null, + retry: hostOverride === null ? this.#fallbackHost : false, + error, + }; } } + // This function will do a single fallback attempt (not recursive on purpose) to try to grab the embedder model on first embed + // since at time, some clients cannot properly download the model from HF servers due to a number of reasons (IP, VPN, etc). + // Given this model is critical and nobody reads the GitHub issues before submitting the bug, we get the same bug + // report 20 times a day: https://github.com/Mintplex-Labs/anything-llm/issues/821 + // So to attempt to monkey-patch this we have a single fallback URL to help alleviate duplicate bug reports. + async embedderClient() { + if (!this.modelDownloaded) + this.log( + "The native embedding model has never been run and will be downloaded right now. Subsequent runs will be faster. (~23MB)" + ); + + let fetchResponse = await this.#fetchWithHost(); + if (fetchResponse.pipeline !== null) return fetchResponse.pipeline; + + this.log( + `Failed to download model from primary URL. Using fallback ${fetchResponse.retry}` + ); + if (!!fetchResponse.retry) + fetchResponse = await this.#fetchWithHost(fetchResponse.retry); + if (fetchResponse.pipeline !== null) return fetchResponse.pipeline; + throw fetchResponse.error; + } + async embedTextInput(textInput) { const result = await this.embedChunks(textInput); return result?.[0] || []; @@ -89,6 +133,7 @@ class NativeEmbedder { // during a very large document (>100K words) but can spike up to 70% before gc. // This seems repeatable for all document sizes. // While this does take a while, it is zero set up and is 100% free and on-instance. + // It still may crash depending on other elements at play - so no promises it works under all conditions. async embedChunks(textChunks = []) { const tmpFilePath = this.#tempfilePath(); const chunks = toChunks(textChunks, this.maxConcurrentChunks); @@ -112,7 +157,7 @@ class NativeEmbedder { data = JSON.stringify(output.tolist()); await this.#writeToTempfile(tmpFilePath, data); - console.log(`\x1b[34m[Embedded Chunk ${idx + 1} of ${chunkLen}]\x1b[0m`); + this.log(`Embedded Chunk ${idx + 1} of ${chunkLen}`); if (chunkLen - 1 !== idx) await this.#writeToTempfile(tmpFilePath, ","); if (chunkLen - 1 === idx) await this.#writeToTempfile(tmpFilePath, "]"); pipeline = null; diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js index a441bf82..3d8bb915 100644 --- a/server/utils/helpers/index.js +++ b/server/utils/helpers/index.js @@ -101,7 +101,6 @@ function getEmbeddingEngineSelection() { return new OllamaEmbedder(); case "native": const { NativeEmbedder } = require("../EmbeddingEngines/native"); - console.log("\x1b[34m[INFO]\x1b[0m Using Native Embedder"); return new NativeEmbedder(); default: return null;