Workspace agent autoselection (#2357)

* refactor agent to add fallback to workspace, then to chat provider/model

* commenting
update logic for bedrock and fireworks fallbacks

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
This commit is contained in:
Sean Hatfield 2024-09-25 13:30:20 -07:00 committed by GitHub
parent 074088d3cb
commit e2195a96d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -16,6 +16,7 @@ class AgentHandler {
lmstudio: "LMSTUDIO_MODEL_PREF", lmstudio: "LMSTUDIO_MODEL_PREF",
textgenwebui: null, // does not even use `model` in API req textgenwebui: null, // does not even use `model` in API req
"generic-openai": "GENERIC_OPEN_AI_MODEL_PREF", "generic-openai": "GENERIC_OPEN_AI_MODEL_PREF",
bedrock: "AWS_BEDROCK_LLM_MODEL_PREFERENCE",
}; };
invocation = null; invocation = null;
aibitat = null; aibitat = null;
@ -149,20 +150,16 @@ class AgentHandler {
if ( if (
!process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID || !process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID ||
!process.env.AWS_BEDROCK_LLM_ACCESS_KEY || !process.env.AWS_BEDROCK_LLM_ACCESS_KEY ||
!process.env.AWS_BEDROCK_LLM_REGION || !process.env.AWS_BEDROCK_LLM_REGION
!process.env.AWS_BEDROCK_LLM_MODEL_PREFERENCE
) )
throw new Error( throw new Error(
"AWS Bedrock Access Keys, model and region must be provided to use agents." "AWS Bedrock Access Keys and region must be provided to use agents."
); );
break; break;
case "fireworksai": case "fireworksai":
if ( if (!process.env.FIREWORKS_AI_LLM_API_KEY)
!process.env.FIREWORKS_AI_LLM_API_KEY ||
!process.env.FIREWORKS_AI_LLM_MODEL_PREF
)
throw new Error( throw new Error(
"FireworksAI API Key & model must be provided to use agents." "FireworksAI API Key must be provided to use agents."
); );
break; break;
@ -173,8 +170,8 @@ class AgentHandler {
} }
} }
providerDefault() { providerDefault(provider = this.provider) {
switch (this.provider) { switch (provider) {
case "openai": case "openai":
return "gpt-4o"; return "gpt-4o";
case "anthropic": case "anthropic":
@ -214,6 +211,32 @@ class AgentHandler {
} }
} }
#getFallbackProvider() {
// First, fallback to the workspace chat provider and model if they exist
if (
this.invocation.workspace.chatProvider &&
this.invocation.workspace.chatModel
) {
return {
provider: this.invocation.workspace.chatProvider,
model: this.invocation.workspace.chatModel,
};
}
// If workspace does not have chat provider and model fallback
// to system provider and try to load provider default model
const systemProvider = process.env.LLM_PROVIDER;
const systemModel = this.providerDefault(systemProvider);
if (systemProvider && systemModel) {
return {
provider: systemProvider,
model: systemModel,
};
}
return null;
}
/** /**
* Finds or assumes the model preference value to use for API calls. * Finds or assumes the model preference value to use for API calls.
* If multi-model loading is supported, we use their agent model selection of the workspace * If multi-model loading is supported, we use their agent model selection of the workspace
@ -222,22 +245,41 @@ class AgentHandler {
* @returns {string} the model preference value to use in API calls * @returns {string} the model preference value to use in API calls
*/ */
#fetchModel() { #fetchModel() {
if (!Object.keys(this.noProviderModelDefault).includes(this.provider)) // Provider was not explicitly set for workspace, so we are going to run our fallback logic
return this.invocation.workspace.agentModel || this.providerDefault(); // that will set a provider and model for us to use.
if (!this.provider) {
const fallback = this.#getFallbackProvider();
if (!fallback) throw new Error("No valid provider found for the agent.");
this.provider = fallback.provider; // re-set the provider to the fallback provider so it is not null.
return fallback.model; // set its defined model based on fallback logic.
}
// Provider has no reliable default (cant load many models) - so we need to look at system // The provider was explicitly set, so check if the workspace has an agent model set.
// for the model param. if (this.invocation.workspace.agentModel) {
return this.invocation.workspace.agentModel;
}
// If the provider we are using is not supported or does not support multi-model loading
// then we use the default model for the provider.
if (!Object.keys(this.noProviderModelDefault).includes(this.provider)) {
return this.providerDefault();
}
// Load the model from the system environment variable for providers with no multi-model loading.
const sysModelKey = this.noProviderModelDefault[this.provider]; const sysModelKey = this.noProviderModelDefault[this.provider];
if (!!sysModelKey) if (sysModelKey) return process.env[sysModelKey] ?? this.providerDefault();
return process.env[sysModelKey] ?? this.providerDefault();
// If all else fails - look at the provider default list // Otherwise, we have no model to use - so guess a default model to use.
return this.providerDefault(); return this.providerDefault();
} }
#providerSetupAndCheck() { #providerSetupAndCheck() {
this.provider = this.invocation.workspace.agentProvider; this.provider = this.invocation.workspace.agentProvider ?? null; // set provider to workspace agent provider if it exists
this.model = this.#fetchModel(); this.model = this.#fetchModel();
if (!this.provider)
throw new Error("No valid provider found for the agent.");
this.log(`Start ${this.#invocationUUID}::${this.provider}:${this.model}`); this.log(`Start ${this.#invocationUUID}::${this.provider}:${this.model}`);
this.checkSetup(); this.checkSetup();
} }