Compare commits

...

5 Commits

Author SHA1 Message Date
Sean Hatfield 34027c6e18
Merge d7d36460b0 into 8eda75d624 2024-04-26 17:08:45 -07:00
Sean Hatfield 8eda75d624
[FIX] Loading message in document picker bug (#1202)
* fix loading message in document picker bug

* linting

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
2024-04-26 17:08:10 -07:00
Timothy Carambat 1b35bcbeab
Strengthen field validations on user Updates (#1201)
* Strengthen field validations on user Updates

* update writables
2024-04-26 16:46:04 -07:00
shatfield4 d7d36460b0 linting + loading small ui tweak 2024-04-26 16:28:27 -07:00
timothycarambat df2c01b176 patch OpenRouter model fetcher when key is not present 2024-04-26 15:58:30 -07:00
8 changed files with 149 additions and 59 deletions

View File

@ -12,11 +12,10 @@ export default function WebsiteDepthOptions() {
try {
setLoading(true);
showToast(
"Scraping website - this may take a while.",
"info",
{ clear: true, autoClose: false }
);
showToast("Scraping website - this may take a while.", "info", {
clear: true,
autoClose: false,
});
const { data, error } = await System.dataConnectors.websiteDepth.scrape({
url: form.get("url"),
@ -33,7 +32,10 @@ export default function WebsiteDepthOptions() {
}
showToast(
`Successfully scraped ${data.length} ${pluralize("page", data.length)}!`,
`Successfully scraped ${data.length} ${pluralize(
"page",
data.length
)}!`,
"success",
{ clear: true }
);
@ -90,7 +92,9 @@ export default function WebsiteDepthOptions() {
</div>
<div className="flex flex-col pr-10">
<div className="flex flex-col gap-y-1 mb-4">
<label className="text-white text-sm font-bold">Max Links</label>
<label className="text-white text-sm font-bold">
Max Links
</label>
<p className="text-xs font-normal text-white/50">
Maximum number of links to scrape.
</p>
@ -111,7 +115,9 @@ export default function WebsiteDepthOptions() {
<button
type="submit"
disabled={loading}
className="mt-2 w-full justify-center border border-slate-200 px-4 py-2 rounded-lg text-[#222628] text-sm font-bold items-center flex gap-x-2 bg-slate-200 hover:bg-slate-300 hover:text-slate-800 disabled:bg-slate-300 disabled:cursor-not-allowed"
className={`mt-2 w-full ${
loading ? "cursor-not-allowed animate-pulse" : ""
} justify-center border border-slate-200 px-4 py-2 rounded-lg text-[#222628] text-sm font-bold items-center flex gap-x-2 bg-slate-200 hover:bg-slate-300 hover:text-slate-800 disabled:bg-slate-300 disabled:cursor-not-allowed`}
>
{loading ? "Scraping website..." : "Submit"}
</button>
@ -126,4 +132,4 @@ export default function WebsiteDepthOptions() {
</div>
</div>
);
}
}

View File

@ -306,6 +306,7 @@ function Directory({
workspace={workspace}
fetchKeys={fetchKeys}
setLoading={setLoading}
setLoadingMessage={setLoadingMessage}
/>
</div>
</div>

View File

@ -12,6 +12,8 @@ function FileUploadProgressComponent({
reason = null,
onUploadSuccess,
onUploadError,
setLoading,
setLoadingMessage,
}) {
const [timerMs, setTimerMs] = useState(10);
const [status, setStatus] = useState("pending");
@ -19,6 +21,8 @@ function FileUploadProgressComponent({
useEffect(() => {
async function uploadFile() {
setLoading(true);
setLoadingMessage("Uploading file...");
const start = Number(new Date());
const formData = new FormData();
formData.append("file", file, file.name);
@ -34,6 +38,8 @@ function FileUploadProgressComponent({
onUploadError(data.error);
setError(data.error);
} else {
setLoading(false);
setLoadingMessage("");
setStatus("complete");
clearInterval(timer);
onUploadSuccess();

View File

@ -7,7 +7,12 @@ import { v4 } from "uuid";
import FileUploadProgress from "./FileUploadProgress";
import Workspace from "../../../../../models/workspace";
export default function UploadFile({ workspace, fetchKeys, setLoading }) {
export default function UploadFile({
workspace,
fetchKeys,
setLoading,
setLoadingMessage,
}) {
const [ready, setReady] = useState(false);
const [files, setFiles] = useState([]);
const [fetchingUrl, setFetchingUrl] = useState(false);
@ -15,6 +20,7 @@ export default function UploadFile({ workspace, fetchKeys, setLoading }) {
const handleSendLink = async (e) => {
e.preventDefault();
setLoading(true);
setLoadingMessage("Scraping link...");
setFetchingUrl(true);
const formEl = e.target;
const form = new FormData(formEl);
@ -114,6 +120,8 @@ export default function UploadFile({ workspace, fetchKeys, setLoading }) {
reason={file?.reason}
onUploadSuccess={handleUploadSuccess}
onUploadError={handleUploadError}
setLoading={setLoading}
setLoadingMessage={setLoadingMessage}
/>
))}
</div>

View File

@ -2,6 +2,23 @@ const prisma = require("../utils/prisma");
const { EventLogs } = require("./eventLogs");
const User = {
writable: [
// Used for generic updates so we can validate keys in request body
"username",
"password",
"pfpFilename",
"role",
"suspended",
],
// validations for the above writable fields.
castColumnValue: function (key, value) {
switch (key) {
case "suspended":
return Number(Boolean(value));
default:
return String(value);
}
},
create: async function ({ username, password, role = "default" }) {
const passwordCheck = this.checkPasswordComplexity(password);
if (!passwordCheck.checkedOK) {
@ -42,13 +59,26 @@ const User = {
update: async function (userId, updates = {}) {
try {
if (!userId) throw new Error("No user id provided for update");
const currentUser = await prisma.users.findUnique({
where: { id: parseInt(userId) },
});
if (!currentUser) {
return { success: false, error: "User not found" };
}
if (!currentUser) return { success: false, error: "User not found" };
// Removes non-writable fields for generic updates
// and force-casts to the proper type;
Object.entries(updates).forEach(([key, value]) => {
if (this.writable.includes(key)) {
updates[key] = this.castColumnValue(key, value);
return;
}
delete updates[key];
});
if (Object.keys(updates).length === 0)
return { success: false, error: "No valid updates applied." };
// Handle password specific updates
if (updates.hasOwnProperty("password")) {
const passwordCheck = this.checkPasswordComplexity(updates.password);
if (!passwordCheck.checkedOK) {
@ -78,6 +108,24 @@ const User = {
}
},
// Explicit direct update of user object.
// Only use this method when directly setting a key value
// that takes no user input for the keys being modified.
_update: async function (id = null, data = {}) {
if (!id) throw new Error("No user id provided for update");
try {
const user = await prisma.users.update({
where: { id },
data,
});
return { user, message: null };
} catch (error) {
console.error(error.message);
return { user: null, message: error.message };
}
},
get: async function (clause = {}) {
try {
const user = await prisma.users.findFirst({ where: clause });

View File

@ -8,6 +8,11 @@ const {
const fs = require("fs");
const path = require("path");
const { safeJsonParse } = require("../../http");
const cacheFolder = path.resolve(
process.env.STORAGE_DIR
? path.resolve(process.env.STORAGE_DIR, "models", "openrouter")
: path.resolve(__dirname, `../../../storage/models/openrouter`)
);
class OpenRouterLLM {
constructor(embedder = null, modelPreference = null) {
@ -38,12 +43,8 @@ class OpenRouterLLM {
this.embedder = !embedder ? new NativeEmbedder() : embedder;
this.defaultTemp = 0.7;
const cacheFolder = path.resolve(
process.env.STORAGE_DIR
? path.resolve(process.env.STORAGE_DIR, "models", "openrouter")
: path.resolve(__dirname, `../../../storage/models/openrouter`)
);
fs.mkdirSync(cacheFolder, { recursive: true });
if (!fs.existsSync(cacheFolder))
fs.mkdirSync(cacheFolder, { recursive: true });
this.cacheModelPath = path.resolve(cacheFolder, "models.json");
this.cacheAtPath = path.resolve(cacheFolder, ".cached_at");
}
@ -52,11 +53,6 @@ class OpenRouterLLM {
console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
}
async init() {
await this.#syncModels();
return this;
}
// This checks if the .cached_at file has a timestamp that is more than 1Week (in millis)
// from the current date. If it is, then we will refetch the API so that all the models are up
// to date.
@ -80,37 +76,7 @@ class OpenRouterLLM {
this.log(
"Model cache is not present or stale. Fetching from OpenRouter API."
);
await fetch(`${this.basePath}/models`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
})
.then((res) => res.json())
.then(({ data = [] }) => {
const models = {};
data.forEach((model) => {
models[model.id] = {
id: model.id,
name: model.name,
organization:
model.id.split("/")[0].charAt(0).toUpperCase() +
model.id.split("/")[0].slice(1),
maxLength: model.context_length,
};
});
fs.writeFileSync(this.cacheModelPath, JSON.stringify(models), {
encoding: "utf-8",
});
fs.writeFileSync(this.cacheAtPath, String(Number(new Date())), {
encoding: "utf-8",
});
return models;
})
.catch((e) => {
console.error(e);
return {};
});
await fetchOpenRouterModels();
return;
}
@ -420,6 +386,54 @@ class OpenRouterLLM {
}
}
async function fetchOpenRouterModels() {
return await fetch(`https://openrouter.ai/api/v1/models`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
})
.then((res) => res.json())
.then(({ data = [] }) => {
const models = {};
data.forEach((model) => {
models[model.id] = {
id: model.id,
name: model.name,
organization:
model.id.split("/")[0].charAt(0).toUpperCase() +
model.id.split("/")[0].slice(1),
maxLength: model.context_length,
};
});
// Cache all response information
if (!fs.existsSync(cacheFolder))
fs.mkdirSync(cacheFolder, { recursive: true });
fs.writeFileSync(
path.resolve(cacheFolder, "models.json"),
JSON.stringify(models),
{
encoding: "utf-8",
}
);
fs.writeFileSync(
path.resolve(cacheFolder, ".cached_at"),
String(Number(new Date())),
{
encoding: "utf-8",
}
);
return models;
})
.catch((e) => {
console.error(e);
return {};
});
}
module.exports = {
OpenRouterLLM,
fetchOpenRouterModels,
};

View File

@ -22,7 +22,7 @@ async function generateRecoveryCodes(userId) {
const { error } = await RecoveryCode.createMany(newRecoveryCodes);
if (!!error) throw new Error(error);
const { success } = await User.update(userId, {
const { user: success } = await User._update(userId, {
seen_recovery_codes: true,
});
if (!success) throw new Error("Failed to generate user recovery codes!");
@ -80,6 +80,11 @@ async function resetPassword(token, _newPassword = "", confirmPassword = "") {
// JOI password rules will be enforced inside .update.
const { error } = await User.update(resetToken.user_id, {
password: newPassword,
});
// seen_recovery_codes is not publicly writable
// so we have to do direct update here
await User._update(resetToken.user_id, {
seen_recovery_codes: false,
});

View File

@ -1,4 +1,7 @@
const { OpenRouterLLM } = require("../AiProviders/openRouter");
const {
OpenRouterLLM,
fetchOpenRouterModels,
} = require("../AiProviders/openRouter");
const { perplexityModels } = require("../AiProviders/perplexity");
const { togetherAiModels } = require("../AiProviders/togetherAi");
const SUPPORT_CUSTOM_MODELS = [
@ -232,8 +235,7 @@ async function getPerplexityModels() {
}
async function getOpenRouterModels() {
const openrouter = await new OpenRouterLLM().init();
const knownModels = openrouter.models();
const knownModels = await fetchOpenRouterModels();
if (!Object.keys(knownModels).length === 0)
return { models: [], error: null };