// File: anything-llm/server/models/embedConfig.js
// Source-view metadata: 240 lines, 6.1 KiB, JavaScript (Raw | Normal View | History)
// Blame: 2024-02-02 20:12:06 +01:00
const { v4 } = require("uuid");
2024-02-01 20:24:42 +01:00
const prisma = require("../utils/prisma");
2024-02-02 20:12:06 +01:00
const { VALID_CHAT_MODE } = require("../utils/chats/stream");
2024-02-01 20:24:42 +01:00
/**
 * Data-access helpers for the `embed_configs` table.
 *
 * Apart from the missing-id guard in `update`, every method catches Prisma
 * errors, logs `error.message`, and returns a fallback value instead of throwing.
 */
const EmbedConfig = {
  // Keys accepted by generic updates; used to filter keys in the request body.
  writable: [
    "enabled",
    "allowlist_domains",
    "allow_model_override",
    "allow_temperature_override",
    "allow_prompt_override",
    "max_chats_per_day",
    "max_chats_per_session",
    "chat_mode",
    "workspace_id",
  ],

  /**
   * Create a new embed config (always enabled) connected to a workspace.
   *
   * @param {object} data - Request body; each known field is normalized with
   *   `validatedCreationData`. `data.workspace_id` selects the workspace.
   * @param {number|string|null} [creatorId=null] - Id of the creating user, if any.
   * @returns {Promise<{embed: object|null, message: string|null}>} The created
   *   row, or `embed: null` with the error message on failure.
   */
  new: async function (data, creatorId = null) {
    try {
      // `Number(x) ?? null` can never yield null (Number(null) is 0 and
      // Number(undefined) is NaN), so validate the id explicitly.
      const creator = Number(creatorId);
      const createdBy =
        Number.isInteger(creator) && creator > 0 ? creator : null;

      // Every optional column is sanitized the same way, keyed by its name.
      const fields = Object.fromEntries(
        [
          "chat_mode",
          "allowlist_domains",
          "allow_model_override",
          "allow_temperature_override",
          "allow_prompt_override",
          "max_chats_per_day",
          "max_chats_per_session",
        ].map((field) => [field, validatedCreationData(data?.[field], field)])
      );

      const embed = await prisma.embed_configs.create({
        data: {
          uuid: v4(),
          enabled: true,
          ...fields,
          createdBy,
          workspace: {
            connect: { id: Number(data?.workspace_id) },
          },
        },
      });
      return { embed, message: null };
    } catch (error) {
      console.error(error.message);
      return { embed: null, message: error.message };
    }
  },

  /**
   * Apply a partial update to an embed config. Keys outside `writable` are ignored.
   *
   * @param {number|string} embedId - Id of the embed config to update.
   * @param {object} [data={}] - Fields to change; each is normalized with
   *   `validatedCreationData`.
   * @returns {Promise<{success: boolean, error: string|null}>}
   * @throws {Error} When no `embedId` is provided.
   */
  update: async function (embedId = null, data = {}) {
    if (!embedId) throw new Error("No embed id provided for update");
    const validKeys = Object.keys(data ?? {}).filter((key) =>
      this.writable.includes(key)
    );
    // Report "nothing to do" with the same { success, error } shape as the
    // other return paths of this method.
    if (validKeys.length === 0)
      return { success: false, error: "No valid fields to update!" };

    const updates = Object.fromEntries(
      validKeys.map((key) => [key, validatedCreationData(data[key], key)])
    );
    try {
      await prisma.embed_configs.update({
        where: { id: Number(embedId) },
        data: updates,
      });
      return { success: true, error: null };
    } catch (error) {
      console.error(error.message);
      return { success: false, error: error.message };
    }
  },

  /**
   * Fetch the first embed config matching a Prisma `where` clause.
   * @param {object} [clause={}]
   * @returns {Promise<object|null>} The row, or null if none or on error.
   */
  get: async function (clause = {}) {
    try {
      const embedConfig = await prisma.embed_configs.findFirst({
        where: clause,
      });
      return embedConfig || null;
    } catch (error) {
      console.error(error.message);
      return null;
    }
  },

  /**
   * Like `get`, but also loads the related `workspace` record.
   * @param {object} [clause={}]
   * @returns {Promise<object|null>}
   */
  getWithWorkspace: async function (clause = {}) {
    try {
      const embedConfig = await prisma.embed_configs.findFirst({
        where: clause,
        include: {
          workspace: true,
        },
      });
      return embedConfig || null;
    } catch (error) {
      console.error(error.message);
      return null;
    }
  },

  /**
   * Delete the embed config matching a Prisma unique `where` clause.
   * @param {object} [clause={}]
   * @returns {Promise<boolean>} true on success, false on error.
   */
  delete: async function (clause = {}) {
    try {
      await prisma.embed_configs.delete({
        where: clause,
      });
      return true;
    } catch (error) {
      console.error(error.message);
      return false;
    }
  },

  /**
   * List embed configs matching a clause.
   * @param {object} [clause={}]
   * @param {number|null} [limit=null] - Passed as Prisma `take` when not null.
   * @param {object|null} [orderBy=null] - Passed as Prisma `orderBy` when not null.
   * @returns {Promise<object[]>} Matching rows, or [] on error.
   */
  where: async function (clause = {}, limit = null, orderBy = null) {
    try {
      const results = await prisma.embed_configs.findMany({
        where: clause,
        ...(limit !== null ? { take: limit } : {}),
        ...(orderBy !== null ? { orderBy } : {}),
      });
      return results;
    } catch (error) {
      console.error(error.message);
      return [];
    }
  },

  /**
   * Like `where`, but includes the related workspace and a count of
   * related `embed_chats` rows (`_count.embed_chats`).
   * @param {object} [clause={}]
   * @param {number|null} [limit=null]
   * @param {object|null} [orderBy=null]
   * @returns {Promise<object[]>}
   */
  whereWithWorkspace: async function (
    clause = {},
    limit = null,
    orderBy = null
  ) {
    try {
      const results = await prisma.embed_configs.findMany({
        where: clause,
        include: {
          workspace: true,
          _count: {
            select: { embed_chats: true },
          },
        },
        ...(limit !== null ? { take: limit } : {}),
        ...(orderBy !== null ? { orderBy } : {}),
      });
      return results;
    } catch (error) {
      console.error(error.message);
      return [];
    }
  },

  /**
   * Parse the stored JSON allowlist of an embed.
   *
   * Returns null when no allowlist is configured, meaning the check should be
   * skipped. Otherwise returns an array to check against. Malformed or
   * non-array data yields [] rather than null, so a bad value can never turn
   * into "allow every request".
   *
   * @param {{id: *, allowlist_domains: string|null}} embed
   * @returns {string[]|null}
   */
  parseAllowedHosts: function (embed) {
    if (!embed.allowlist_domains) return null;
    try {
      const hosts = JSON.parse(embed.allowlist_domains);
      // Anything other than an array is treated as malformed so callers
      // always receive a list.
      if (Array.isArray(hosts)) return hosts;
    } catch {
      // Fall through to the shared failure handling below.
    }
    console.error(`Failed to parse allowlist_domains for Embed ${embed.id}!`);
    return [];
  },
};
// Embed config fields that must hold a real boolean; in validatedCreationData
// any non-boolean input normalizes to false. Frozen so the shared list cannot
// be mutated at runtime.
const BOOLEAN_KEYS = Object.freeze([
  "allow_model_override",
  "allow_temperature_override",
  "allow_prompt_override",
  "enabled",
]);
// Embed config fields that hold a positive number; in validatedCreationData
// missing or invalid input normalizes to null.
const NUMBER_KEYS = Object.freeze([
  "max_chats_per_day",
  "max_chats_per_session",
  "workspace_id",
]);
/**
 * Strictly normalize one user-supplied embed config value into the format
 * stored in `embed_configs`.
 *
 * - `chat_mode`: a member of VALID_CHAT_MODE, otherwise "query".
 * - `allowlist_domains`: JSON array string of URLs (see normalizedAllowlist).
 * - BOOLEAN_KEYS: the boolean itself, otherwise false.
 * - NUMBER_KEYS: a positive safe integer, otherwise null.
 * - Any other field: null.
 *
 * @param {*} value - Raw value from the request body.
 * @param {string} field - Column name the value is destined for.
 * @returns {string|boolean|number|null} The sanitized value.
 */
function validatedCreationData(value, field) {
  if (field === "chat_mode") {
    if (!value || !VALID_CHAT_MODE.includes(value)) return "query";
    return value;
  }
  if (field === "allowlist_domains") return normalizedAllowlist(value);
  if (BOOLEAN_KEYS.includes(field)) {
    return value === true || value === false ? value : false;
  }
  if (NUMBER_KEYS.includes(field)) {
    // These map to integer ids/limits: reject NaN, Infinity, fractions and
    // non-positive values instead of letting the database reject them.
    const num = Number(value);
    return Number.isSafeInteger(num) && num > 0 ? num : null;
  }
  return null;
}

/**
 * Turn a comma-separated domain list into a JSON array string of URLs.
 * Entries are trimmed, blanks are dropped, `https://` is prefixed when no
 * http(s) scheme is present, and entries that still fail URL parsing are
 * discarded.
 *
 * @param {*} value - Expected to be a string such as "a.com, https://b.com".
 * @returns {string|null} JSON array string, or null for empty/non-string input.
 */
function normalizedAllowlist(value) {
  if (!value || typeof value !== "string") return null;
  const urls = value
    .split(",")
    .map((entry) => entry.trim())
    .filter((entry) => entry.length > 0)
    .map((entry) => (/^https?:\/\//i.test(entry) ? entry : `https://${entry}`))
    .filter(isParsableUrl);
  return JSON.stringify(urls);
}

// True when the WHATWG URL parser accepts `candidate`.
function isParsableUrl(candidate) {
  try {
    new URL(candidate);
    return true;
  } catch {
    return false;
  }
}
module.exports = { EmbedConfig };