Validate max_tokens is number (#1445)

This commit is contained in:
Timothy Carambat 2024-05-17 21:44:55 -07:00 committed by GitHub
parent 1a5aacb001
commit 9ace0e67e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 14 additions and 3 deletions

View File

@ -2,6 +2,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
handleDefaultStreamResponseV2,
} = require("../../helpers/chat/responses");
const { toValidNumber } = require("../../http");
class GenericOpenAiLLM {
constructor(embedder = null, modelPreference = null) {
@ -18,7 +19,9 @@ class GenericOpenAiLLM {
});
this.model =
modelPreference ?? process.env.GENERIC_OPEN_AI_MODEL_PREF ?? null;
this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS ?? 1024;
this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS
? toValidNumber(process.env.GENERIC_OPEN_AI_MAX_TOKENS, 1024)
: 1024;
if (!this.model)
throw new Error("GenericOpenAI must have a valid model set.");
this.limits = {

View File

@ -1,6 +1,5 @@
const { Document } = require("../../../../models/documents");
const { safeJsonParse } = require("../../../http");
const { validate } = require("uuid");
const { summarizeContent } = require("../utils/summarize");
const Provider = require("../providers/ai-provider");

View File

@ -2,6 +2,7 @@ const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
const { toValidNumber } = require("../../../http/index.js");
/**
* The agent provider for the Generic OpenAI provider.
@ -24,7 +25,9 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
this._client = client;
this.model = model;
this.verbose = true;
this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS ?? 1024;
this.maxTokens = process.env.GENERIC_OPEN_AI_MAX_TOKENS
? toValidNumber(process.env.GENERIC_OPEN_AI_MAX_TOKENS, 1024)
: 1024;
}
get client() {

View File

@ -91,6 +91,11 @@ function isValidUrl(urlString = "") {
return false;
}
function toValidNumber(number = null, fallback = null) {
if (isNaN(Number(number))) return fallback;
return Number(number);
}
module.exports = {
reqBody,
multiUserMode,
@ -101,4 +106,5 @@ module.exports = {
parseAuthHeader,
safeJsonParse,
isValidUrl,
toValidNumber,
};