Support free-form input for workspace model for providers with no /models endpoint (#2397)

* support generic openai workspace model

* Update UI for free form input for some providers

---------

Co-authored-by: Timothy Carambat <rambat1010@gmail.com>
This commit is contained in:
Sean Hatfield 2024-10-15 15:24:44 -07:00 committed by GitHub
parent c3723ce2ff
commit 6674e5aab8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 76 additions and 54 deletions

View File

@ -71,23 +71,6 @@ export default function AzureAiOptions({ settings }) {
</option> </option>
</select> </select>
</div> </div>
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Embedding Deployment Name
</label>
<input
type="text"
name="AzureOpenAiEmbeddingModelPref"
className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
placeholder="Azure OpenAI embedding model deployment name"
defaultValue={settings?.AzureOpenAiEmbeddingModelPref}
required={true}
autoComplete="off"
spellCheck={false}
/>
</div>
<div className="flex-flex-col w-60"></div>
</div> </div>
</div> </div>
); );

View File

@ -8,15 +8,18 @@ import { useTranslation } from "react-i18next";
import { Link } from "react-router-dom"; import { Link } from "react-router-dom";
import paths from "@/utils/paths"; import paths from "@/utils/paths";
// Some providers can only be associated with a single model. // Some providers do not support model selection via /models.
// In that case there is no selection to be made so we can just move on. // In that case we allow the user to enter the model name manually and hope they
const NO_MODEL_SELECTION = [ // type it correctly.
"default", const FREE_FORM_LLM_SELECTION = ["bedrock", "azure", "generic-openai"];
"huggingface",
"generic-openai", // Some providers do not support model selection via /models
"bedrock", // and only have a fixed single-model they can use.
]; const NO_MODEL_SELECTION = ["default", "huggingface"];
const DISABLED_PROVIDERS = ["azure", "native"];
// Some providers we just fully disable for ease of use.
const DISABLED_PROVIDERS = ["native"];
const LLM_DEFAULT = { const LLM_DEFAULT = {
name: "System default", name: "System default",
value: "default", value: "default",
@ -65,8 +68,8 @@ export default function WorkspaceLLMSelection({
); );
setFilteredLLMs(filtered); setFilteredLLMs(filtered);
}, [LLMS, searchQuery, selectedLLM]); }, [LLMS, searchQuery, selectedLLM]);
const selectedLLMObject = LLMS.find((llm) => llm.value === selectedLLM); const selectedLLMObject = LLMS.find((llm) => llm.value === selectedLLM);
return ( return (
<div className="border-b border-white/40 pb-8"> <div className="border-b border-white/40 pb-8">
<div className="flex flex-col"> <div className="flex flex-col">
@ -155,9 +158,20 @@ export default function WorkspaceLLMSelection({
</button> </button>
)} )}
</div> </div>
{NO_MODEL_SELECTION.includes(selectedLLM) ? ( <ModelSelector
<> selectedLLM={selectedLLM}
{selectedLLM !== "default" && ( workspace={workspace}
setHasChanges={setHasChanges}
/>
</div>
);
}
// TODO: Add this to agent selector as well as make generic component.
function ModelSelector({ selectedLLM, workspace, setHasChanges }) {
if (NO_MODEL_SELECTION.includes(selectedLLM)) {
if (selectedLLM !== "default") {
return (
<div className="w-full h-10 justify-center items-center flex mt-4"> <div className="w-full h-10 justify-center items-center flex mt-4">
<p className="text-sm font-base text-white text-opacity-60 text-center"> <p className="text-sm font-base text-white text-opacity-60 text-center">
Multi-model support is not supported for this provider yet. Multi-model support is not supported for this provider yet.
@ -168,17 +182,42 @@ export default function WorkspaceLLMSelection({
</Link> </Link>
</p> </p>
</div> </div>
)} );
</> }
) : ( return null;
<div className="mt-4 flex flex-col gap-y-1"> }
if (FREE_FORM_LLM_SELECTION.includes(selectedLLM)) {
return (
<FreeFormLLMInput workspace={workspace} setHasChanges={setHasChanges} />
);
}
return (
<ChatModelSelection <ChatModelSelection
provider={selectedLLM} provider={selectedLLM}
workspace={workspace} workspace={workspace}
setHasChanges={setHasChanges} setHasChanges={setHasChanges}
/> />
</div> );
)} }
function FreeFormLLMInput({ workspace, setHasChanges }) {
const { t } = useTranslation();
return (
<div className="mt-4 flex flex-col gap-y-1">
<label className="block input-label">{t("chat.model.title")}</label>
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
{t("chat.model.description")}
</p>
<input
type="text"
name="chatModel"
defaultValue={workspace?.chatModel || ""}
onChange={() => setHasChanges(true)}
className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
placeholder="Enter model name exactly as referenced in the API (e.g., gpt-3.5-turbo)"
/>
</div> </div>
); );
} }

View File

@ -5,7 +5,7 @@ const {
} = require("../../helpers/chat/responses"); } = require("../../helpers/chat/responses");
class AzureOpenAiLLM { class AzureOpenAiLLM {
constructor(embedder = null, _modelPreference = null) { constructor(embedder = null, modelPreference = null) {
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai"); const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
if (!process.env.AZURE_OPENAI_ENDPOINT) if (!process.env.AZURE_OPENAI_ENDPOINT)
throw new Error("No Azure API endpoint was set."); throw new Error("No Azure API endpoint was set.");
@ -16,7 +16,7 @@ class AzureOpenAiLLM {
process.env.AZURE_OPENAI_ENDPOINT, process.env.AZURE_OPENAI_ENDPOINT,
new AzureKeyCredential(process.env.AZURE_OPENAI_KEY) new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
); );
this.model = process.env.OPEN_MODEL_PREF; this.model = modelPreference ?? process.env.OPEN_MODEL_PREF;
this.limits = { this.limits = {
history: this.promptWindowLimit() * 0.15, history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15, system: this.promptWindowLimit() * 0.15,

View File

@ -32,7 +32,7 @@ class AWSBedrockLLM {
#bedrockClient({ temperature = 0.7 }) { #bedrockClient({ temperature = 0.7 }) {
const { ChatBedrockConverse } = require("@langchain/aws"); const { ChatBedrockConverse } = require("@langchain/aws");
return new ChatBedrockConverse({ return new ChatBedrockConverse({
model: process.env.AWS_BEDROCK_LLM_MODEL_PREFERENCE, model: this.model,
region: process.env.AWS_BEDROCK_LLM_REGION, region: process.env.AWS_BEDROCK_LLM_REGION,
credentials: { credentials: {
accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID, accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,