// anything-llm/server/utils/agents/aibitat/providers/openai.js

const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const { RetryError } = require("../error.js");
/**
 * The agent provider for the OpenAI API.
 * By default, the model is set to 'gpt-4o'.
 */
class OpenAIProvider extends Provider {
  model;

  // USD cost per 1K tokens, keyed by base model name (version suffix stripped
  // by getCost before lookup).
  static COST_PER_TOKEN = {
    "gpt-4": {
      input: 0.03,
      output: 0.06,
    },
    "gpt-4o": {
      input: 0.005,
      output: 0.015,
    },
    "gpt-4-32k": {
      input: 0.06,
      output: 0.12,
    },
    "gpt-3.5-turbo": {
      input: 0.0015,
      output: 0.002,
    },
    "gpt-3.5-turbo-16k": {
      input: 0.003,
      output: 0.004,
    },
  };

  /**
   * @param {object} [config]
   * @param {object} [config.options] - Options forwarded to the OpenAI client
   *   constructor. Defaults to `OPEN_AI_KEY` from the environment with 3 retries.
   * @param {string} [config.model="gpt-4o"] - Model identifier to use.
   */
  constructor(config = {}) {
    const {
      options = {
        apiKey: process.env.OPEN_AI_KEY,
        maxRetries: 3,
      },
      model = "gpt-4o",
    } = config;
    const client = new OpenAI(options);
    super(client);
    this.model = model;
  }

  /**
   * Create a completion based on the received messages.
   *
   * @param messages A list of messages to send to the OpenAI API.
   * @param functions Optional list of function definitions the model may call.
   * @param retriesLeft Internal: remaining re-attempts when the model emits
   *   unparseable function-call arguments. Callers should not pass this.
   * @returns The completion: `{ result, cost }` for plain text, or
   *   `{ result: null, functionCall: { name, arguments }, cost }` for a
   *   function call.
   * @throws {RetryError} On retryable API failures (rate limit, 5xx, generic
   *   API errors).
   */
  async complete(messages, functions = null, retriesLeft = 3) {
    try {
      const response = await this.client.chat.completions.create({
        model: this.model,
        messages,
        ...(Array.isArray(functions) && functions.length > 0
          ? { functions }
          : {}),
      });

      // Right now, we only support one completion,
      // so we just take the first one in the list.
      const completion = response.choices?.[0]?.message;
      if (!completion)
        throw new Error("OpenAI response contained no completion choices.");
      const cost = this.getCost(response.usage);

      // Treat function calls.
      if (completion.function_call) {
        let functionArgs = {};
        try {
          functionArgs = JSON.parse(completion.function_call.arguments);
        } catch (error) {
          // The model produced invalid JSON for the function arguments.
          // Feed the parse error back so the model can correct itself — but
          // only a bounded number of times, since the previous unbounded
          // recursion could loop forever on a persistently-broken response.
          if (retriesLeft <= 0) {
            throw new Error(
              `Failed to parse function call arguments after retries: ${error?.message}`
            );
          }
          return this.complete(
            [
              ...messages,
              {
                role: "function",
                name: completion.function_call.name,
                function_call: completion.function_call,
                content: error?.message,
              },
            ],
            functions,
            retriesLeft - 1
          );
        }
        return {
          result: null,
          functionCall: {
            name: completion.function_call.name,
            arguments: functionArgs,
          },
          cost,
        };
      }

      return {
        result: completion.content,
        cost,
      };
    } catch (error) {
      // If invalid Auth error we need to abort because no amount of waiting
      // will make auth better. This check must come before the APIError check
      // below, because AuthenticationError is a subclass of APIError.
      if (error instanceof OpenAI.AuthenticationError) throw error;
      if (
        error instanceof OpenAI.RateLimitError ||
        error instanceof OpenAI.InternalServerError ||
        error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
      ) {
        throw new RetryError(error.message);
      }
      throw error;
    }
  }

  /**
   * Get the USD cost of the completion from its token usage.
   *
   * @param usage The `usage` object from the API response
   *   (`prompt_tokens` / `completion_tokens`).
   * @returns The cost in USD, or NaN when usage is missing or the model has
   *   no known pricing.
   */
  getCost(usage) {
    if (!usage) {
      return Number.NaN;
    }
    // Strip a 4-digit version suffix (e.g. "gpt-4-0613" -> "gpt-4") so the
    // pricing table lookup works for dated model snapshots.
    const modelBase = this.model.replace(/-(\d{4})$/, "");
    const costPerToken = OpenAIProvider.COST_PER_TOKEN[modelBase];
    if (!costPerToken) {
      return Number.NaN;
    }
    const inputCost = (usage.prompt_tokens / 1000) * costPerToken.input;
    const outputCost = (usage.completion_tokens / 1000) * costPerToken.output;
    return inputCost + outputCost;
  }
}

module.exports = OpenAIProvider;