const OpenAI = require("openai:latest"); const Provider = require("./ai-provider.js"); const { RetryError } = require("../error.js"); /** * The provider for the OpenAI API. * By default, the model is set to 'gpt-3.5-turbo'. */ class OpenAIProvider extends Provider { model; static COST_PER_TOKEN = { "gpt-4": { input: 0.03, output: 0.06, }, "gpt-4-32k": { input: 0.06, output: 0.12, }, "gpt-3.5-turbo": { input: 0.0015, output: 0.002, }, "gpt-3.5-turbo-16k": { input: 0.003, output: 0.004, }, }; constructor(config = {}) { const { options = { apiKey: process.env.OPEN_AI_KEY, maxRetries: 3, }, model = "gpt-3.5-turbo", } = config; const client = new OpenAI(options); super(client); this.model = model; } /** * Create a completion based on the received messages. * * @param messages A list of messages to send to the OpenAI API. * @param functions * @returns The completion. */ async complete(messages, functions = null) { try { const response = await this.client.chat.completions.create({ model: this.model, // stream: true, messages, ...(Array.isArray(functions) && functions?.length > 0 ? { functions } : {}), }); // Right now, we only support one completion, // so we just take the first one in the list const completion = response.choices[0].message; const cost = this.getCost(response.usage); // treat function calls if (completion.function_call) { let functionArgs = {}; try { functionArgs = JSON.parse(completion.function_call.arguments); } catch (error) { // call the complete function again in case it gets a json error return this.complete( [ ...messages, { role: "function", name: completion.function_call.name, function_call: completion.function_call, content: error?.message, }, ], functions ); } // console.log(completion, { functionArgs }) return { result: null, functionCall: { name: completion.function_call.name, arguments: functionArgs, }, cost, }; } return { result: completion.content, cost, }; } catch (error) { // If invalid Auth error we need to abort because no amount of waiting // will make auth better. if (error instanceof OpenAI.AuthenticationError) throw error; if ( error instanceof OpenAI.RateLimitError || error instanceof OpenAI.InternalServerError || error instanceof OpenAI.APIError // Also will catch AuthenticationError!!! ) { throw new RetryError(error.message); } throw error; } } /** * Get the cost of the completion. * * @param usage The completion to get the cost for. * @returns The cost of the completion. */ getCost(usage) { if (!usage) { return Number.NaN; } // regex to remove the version number from the model const modelBase = this.model.replace(/-(\d{4})$/, ""); if (!(modelBase in OpenAIProvider.COST_PER_TOKEN)) { return Number.NaN; } const costPerToken = OpenAIProvider.COST_PER_TOKEN?.[modelBase]; const inputCost = (usage.prompt_tokens / 1000) * costPerToken.input; const outputCost = (usage.completion_tokens / 1000) * costPerToken.output; return inputCost + outputCost; } } module.exports = OpenAIProvider;