From 9ace0e67e68aa5dbe9c29c2fc66d981de18469f6 Mon Sep 17 00:00:00 2001 From: Timothy Carambat Date: Fri, 17 May 2024 21:44:55 -0700 Subject: [PATCH] Validate max_tokens is number (#1445) --- server/utils/AiProviders/genericOpenAi/index.js | 5 ++++- server/utils/agents/aibitat/plugins/summarize.js | 1 - server/utils/agents/aibitat/providers/genericOpenAi.js | 5 ++++- server/utils/http/index.js | 6 ++++++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/server/utils/AiProviders/genericOpenAi/index.js b/server/utils/AiProviders/genericOpenAi/index.js index dc0264e4..46b8aefb 100644 --- a/server/utils/AiProviders/genericOpenAi/index.js +++ b/server/utils/AiProviders/genericOpenAi/index.js @@ -2,6 +2,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { handleDefaultStreamResponseV2, } = require("../../helpers/chat/responses"); +const { toValidNumber } = require("../../http"); class GenericOpenAiLLM { constructor(embedder = null, modelPreference = null) { @@ -18,7 +19,9 @@ class GenericOpenAiLLM { }); this.model = modelPreference ?? process.env.GENERIC_OPEN_AI_MODEL_PREF ?? null; - this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS ?? 1024; + this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS + ? toValidNumber(process.env.GENERIC_OPEN_AI_MAX_TOKENS, 1024) + : 1024; if (!this.model) throw new Error("GenericOpenAI must have a valid model set."); this.limits = { diff --git a/server/utils/agents/aibitat/plugins/summarize.js b/server/utils/agents/aibitat/plugins/summarize.js index 526de116..de1657c9 100644 --- a/server/utils/agents/aibitat/plugins/summarize.js +++ b/server/utils/agents/aibitat/plugins/summarize.js @@ -1,6 +1,5 @@ const { Document } = require("../../../../models/documents"); const { safeJsonParse } = require("../../../http"); -const { validate } = require("uuid"); const { summarizeContent } = require("../utils/summarize"); const Provider = require("../providers/ai-provider"); diff --git a/server/utils/agents/aibitat/providers/genericOpenAi.js b/server/utils/agents/aibitat/providers/genericOpenAi.js index a1b2db3e..9a753ca2 100644 --- a/server/utils/agents/aibitat/providers/genericOpenAi.js +++ b/server/utils/agents/aibitat/providers/genericOpenAi.js @@ -2,6 +2,7 @@ const OpenAI = require("openai"); const Provider = require("./ai-provider.js"); const InheritMultiple = require("./helpers/classes.js"); const UnTooled = require("./helpers/untooled.js"); +const { toValidNumber } = require("../../../http/index.js"); /** * The agent provider for the Generic OpenAI provider. @@ -24,7 +25,9 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) { this._client = client; this.model = model; this.verbose = true; - this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS ?? 1024; + this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS + ? toValidNumber(process.env.GENERIC_OPEN_AI_MAX_TOKENS, 1024) + : 1024; } get client() { diff --git a/server/utils/http/index.js b/server/utils/http/index.js index 6400c36b..e812b8ab 100644 --- a/server/utils/http/index.js +++ b/server/utils/http/index.js @@ -91,6 +91,11 @@ function isValidUrl(urlString = "") { return false; } +function toValidNumber(number = null, fallback = null) { + if (isNaN(Number(number))) return fallback; + return Number(number); +} + module.exports = { reqBody, multiUserMode, @@ -101,4 +106,5 @@ module.exports = { parseAuthHeader, safeJsonParse, isValidUrl, + toValidNumber, };