[FEAT] Groq LLM support (#865)

* Groq LLM support complete

* update useGetProvidersModels for groq models

* Add definitions
update comments and error log reports
add example envs

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
Sean Hatfield 2024-03-06 14:48:38 -08:00 committed by GitHub
parent 4731ec8be8
commit 0634013788
13 changed files with 320 additions and 4 deletions


@@ -4,12 +4,15 @@
"Astra",
"Dockerized",
"Embeddable",
"GROQ",
"hljs",
"inferencing",
"Langchain",
"Milvus",
"Mintplex",
"Ollama",
"openai",
"openrouter",
"Qdrant",
"vectordbs",
"Weaviate",


@@ -61,6 +61,10 @@ GID='1000'
# HUGGING_FACE_LLM_API_KEY=hf_xxxxxx
# HUGGING_FACE_LLM_TOKEN_LIMIT=8000

# LLM_PROVIDER='groq'
# GROQ_API_KEY=gsk_abcxyz
# GROQ_MODEL_PREF=llama2-70b-4096
###########################################
######## Embedding API SELECTION ##########
###########################################
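
Side note: Groq exposes an OpenAI-compatible API, so the server consumes the variables above through the stock openai client with an overridden base path. A minimal sketch, mirroring the GroqLLM constructor added later in this commit:

// Minimal sketch (openai v3 client, as used by GroqLLM below).
const { Configuration, OpenAIApi } = require("openai");

const config = new Configuration({
  basePath: "https://api.groq.com/openai/v1", // Groq's OpenAI-compatible endpoint
  apiKey: process.env.GROQ_API_KEY,
});
const groqClient = new OpenAIApi(config);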


@@ -0,0 +1,41 @@
export default function GroqAiOptions({ settings }) {
return (
<div className="flex gap-x-4">
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-4">
Groq API Key
</label>
<input
type="password"
name="GroqApiKey"
className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:border-white block w-full p-2.5"
placeholder="Groq API Key"
defaultValue={settings?.GroqApiKey ? "*".repeat(20) : ""}
required={true}
autoComplete="off"
spellCheck={false}
/>
</div>
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-4">
Chat Model Selection
</label>
<select
name="GroqModelPref"
defaultValue={settings?.GroqModelPref || "llama2-70b-4096"}
required={true}
className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
{["llama2-70b-4096", "mixtral-8x7b-32768"].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</select>
</div>
</div>
);
}
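
A hypothetical mount of this component, following the pattern the preference pages below use for the sibling *Options components; the settings object only needs the two keys the component reads:

// Hypothetical usage sketch; the `settings` shape is inferred from the reads above.
import GroqAiOptions from "@/components/LLMSelection/GroqAiOptions";

export default function ExampleMount() {
  const settings = { GroqApiKey: true, GroqModelPref: "mixtral-8x7b-32768" };
  return <GroqAiOptions settings={settings} />;
}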


@@ -19,6 +19,7 @@ const PROVIDER_DEFAULT_MODELS = {
localai: [],
ollama: [],
togetherai: [],
groq: ["llama2-70b-4096", "mixtral-8x7b-32768"],
native: [],
};
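
Per the commit message, useGetProvidersModels was updated to serve these defaults. A hedged sketch of consuming them; the hook's import path and return shape are assumptions, not shown in this diff:

// Assumed hook path and { defaultModels } return shape; verify against the hook.
import useGetProvidersModels from "@/hooks/useGetProvidersModels";

function GroqModelList() {
  const { defaultModels } = useGetProvidersModels("groq");
  // defaultModels -> ["llama2-70b-4096", "mixtral-8x7b-32768"]
  return defaultModels.map((model) => <option key={model}>{model}</option>);
}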

(New binary file: the Groq provider logo, groq.png, 1.4 KiB; imported below as @/media/llmprovider/groq.png.)


@@ -16,6 +16,7 @@ import MistralLogo from "@/media/llmprovider/mistral.jpeg";
import HuggingFaceLogo from "@/media/llmprovider/huggingface.png";
import PerplexityLogo from "@/media/llmprovider/perplexity.png";
import OpenRouterLogo from "@/media/llmprovider/openrouter.jpeg";
import GroqLogo from "@/media/llmprovider/groq.png";
import PreLoader from "@/components/Preloader";
import OpenAiOptions from "@/components/LLMSelection/OpenAiOptions";
import AzureAiOptions from "@/components/LLMSelection/AzureAiOptions";
@@ -28,11 +29,12 @@ import OllamaLLMOptions from "@/components/LLMSelection/OllamaLLMOptions";
import TogetherAiOptions from "@/components/LLMSelection/TogetherAiOptions";
import MistralOptions from "@/components/LLMSelection/MistralOptions";
import HuggingFaceOptions from "@/components/LLMSelection/HuggingFaceOptions";
import PerplexityOptions from "@/components/LLMSelection/PerplexityOptions";
import OpenRouterOptions from "@/components/LLMSelection/OpenRouterOptions";
import GroqAiOptions from "@/components/LLMSelection/GroqAiOptions";
import LLMItem from "@/components/LLMSelection/LLMItem";
import { MagnifyingGlass } from "@phosphor-icons/react";
export default function GeneralLLMPreference() {
const [saving, setSaving] = useState(false);
@@ -173,6 +175,14 @@ export default function GeneralLLMPreference() {
options: <OpenRouterOptions settings={settings} />,
description: "A unified interface for LLMs.",
},
{
name: "Groq",
value: "groq",
logo: GroqLogo,
options: <GroqAiOptions settings={settings} />,
description:
"The fastest LLM inferencing available for real-time AI applications.",
},
{
name: "Native",
value: "native",


@@ -13,6 +13,7 @@ import MistralLogo from "@/media/llmprovider/mistral.jpeg";
import HuggingFaceLogo from "@/media/llmprovider/huggingface.png";
import PerplexityLogo from "@/media/llmprovider/perplexity.png";
import OpenRouterLogo from "@/media/llmprovider/openrouter.jpeg";
import GroqLogo from "@/media/llmprovider/groq.png";
import ZillizLogo from "@/media/vectordbs/zilliz.png";
import AstraDBLogo from "@/media/vectordbs/astraDB.png";
import ChromaLogo from "@/media/vectordbs/chroma.png";
@@ -127,6 +128,14 @@ const LLM_SELECTION_PRIVACY = {
],
logo: OpenRouterLogo,
},
groq: {
name: "Groq",
description: [
"Your chats will not be used for training",
"Your prompts and document text used in response creation are visible to Groq",
],
logo: GroqLogo,
},
};
const VECTOR_DB_PRIVACY = {


@@ -13,6 +13,7 @@ import MistralLogo from "@/media/llmprovider/mistral.jpeg";
import HuggingFaceLogo from "@/media/llmprovider/huggingface.png";
import PerplexityLogo from "@/media/llmprovider/perplexity.png";
import OpenRouterLogo from "@/media/llmprovider/openrouter.jpeg";
import GroqLogo from "@/media/llmprovider/groq.png";
import OpenAiOptions from "@/components/LLMSelection/OpenAiOptions";
import AzureAiOptions from "@/components/LLMSelection/AzureAiOptions";
import AnthropicAiOptions from "@/components/LLMSelection/AnthropicAiOptions";
@@ -25,12 +26,13 @@ import MistralOptions from "@/components/LLMSelection/MistralOptions";
import HuggingFaceOptions from "@/components/LLMSelection/HuggingFaceOptions";
import TogetherAiOptions from "@/components/LLMSelection/TogetherAiOptions";
import PerplexityOptions from "@/components/LLMSelection/PerplexityOptions";
import OpenRouterOptions from "@/components/LLMSelection/OpenRouterOptions";
import GroqAiOptions from "@/components/LLMSelection/GroqAiOptions";
import LLMItem from "@/components/LLMSelection/LLMItem";
import System from "@/models/system";
import paths from "@/utils/paths";
import showToast from "@/utils/toast";
import { useNavigate } from "react-router-dom";
const TITLE = "LLM Preference";
const DESCRIPTION =
@@ -147,6 +149,14 @@ export default function LLMPreference({
options: <OpenRouterOptions settings={settings} />,
description: "A unified interface for LLMs.",
},
{
name: "Groq",
value: "groq",
logo: GroqLogo,
options: <GroqAiOptions settings={settings} />,
description:
"The fastest LLM inferencing available for real-time AI applications.",
},
{
name: "Native",
value: "native",


@@ -58,6 +58,10 @@ JWT_SECRET="my-random-string-for-seeding" # Please generate random string at lea
# HUGGING_FACE_LLM_API_KEY=hf_xxxxxx
# HUGGING_FACE_LLM_TOKEN_LIMIT=8000

# LLM_PROVIDER='groq'
# GROQ_API_KEY=gsk_abcxyz
# GROQ_MODEL_PREF=llama2-70b-4096
###########################################
######## Embedding API SELECTION ##########
###########################################


@@ -219,12 +219,25 @@ const SystemSettings = {
AzureOpenAiEmbeddingModelPref: process.env.EMBEDDING_MODEL_PREF,
}
: {}),
...(llmProvider === "groq"
? {
GroqApiKey: !!process.env.GROQ_API_KEY,
GroqModelPref: process.env.GROQ_MODEL_PREF,
// For embedding credentials when groq is selected.
OpenAiKey: !!process.env.OPEN_AI_KEY,
AzureOpenAiEndpoint: process.env.AZURE_OPENAI_ENDPOINT,
AzureOpenAiKey: !!process.env.AZURE_OPENAI_KEY,
AzureOpenAiEmbeddingModelPref: process.env.EMBEDDING_MODEL_PREF,
}
: {}),
...(llmProvider === "native"
? {
NativeLLMModelPref: process.env.NATIVE_LLM_MODEL_PREF,
NativeLLMTokenLimit: process.env.NATIVE_LLM_MODEL_TOKEN_LIMIT,
// For embedding credentials when native is selected.
OpenAiKey: !!process.env.OPEN_AI_KEY,
AzureOpenAiEndpoint: process.env.AZURE_OPENAI_ENDPOINT,
AzureOpenAiKey: !!process.env.AZURE_OPENAI_KEY,
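
Note the !! coercion: secret values are reported to the frontend only as presence booleans, while non-secret preferences pass through verbatim. A quick illustration:

// Secrets become true/false flags; plain preferences are returned as-is.
const exposed = {
  GroqApiKey: !!process.env.GROQ_API_KEY, // true when set, never the gsk_... value
  GroqModelPref: process.env.GROQ_MODEL_PREF, // model name is safe to expose
};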


@@ -0,0 +1,207 @@
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const { chatPrompt } = require("../../chats");
const { handleDefaultStreamResponse } = require("../../helpers/chat/responses");
class GroqLLM {
constructor(embedder = null, modelPreference = null) {
const { Configuration, OpenAIApi } = require("openai");
if (!process.env.GROQ_API_KEY) throw new Error("No Groq API key was set.");
const config = new Configuration({
basePath: "https://api.groq.com/openai/v1",
apiKey: process.env.GROQ_API_KEY,
});
this.openai = new OpenAIApi(config);
this.model =
modelPreference || process.env.GROQ_MODEL_PREF || "llama2-70b-4096";
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,
user: this.promptWindowLimit() * 0.7,
};
this.embedder = !embedder ? new NativeEmbedder() : embedder;
this.defaultTemp = 0.7;
}
#appendContext(contextTexts = []) {
if (!contextTexts || !contextTexts.length) return "";
return (
"\nContext:\n" +
contextTexts
.map((text, i) => {
return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
})
.join("")
);
}
streamingEnabled() {
return "streamChat" in this && "streamGetChatCompletion" in this;
}
promptWindowLimit() {
switch (this.model) {
case "llama2-70b-4096":
return 4096;
case "mixtral-8x7b-32768":
return 32_768;
default:
return 4096;
}
}
async isValidChatCompletionModel(modelName = "") {
const validModels = ["llama2-70b-4096", "mixtral-8x7b-32768"];
const isPreset = validModels.some((model) => modelName === model);
if (isPreset) return true;
const model = await this.openai
.retrieveModel(modelName)
.then((res) => res.data)
.catch(() => null);
return !!model;
}
constructPrompt({
systemPrompt = "",
contextTexts = [],
chatHistory = [],
userPrompt = "",
}) {
const prompt = {
role: "system",
content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
};
return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
}
async isSafe(_input = "") {
// Not implemented so must be stubbed
return { safe: true, reasons: [] };
}
async sendChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
if (!(await this.isValidChatCompletionModel(this.model)))
throw new Error(
`Groq chat: ${this.model} is not valid for chat completion!`
);
const textResponse = await this.openai
.createChatCompletion({
model: this.model,
temperature: Number(workspace?.openAiTemp ?? this.defaultTemp),
n: 1,
messages: await this.compressMessages(
{
systemPrompt: chatPrompt(workspace),
userPrompt: prompt,
chatHistory,
},
rawHistory
),
})
.then((json) => {
const res = json.data;
if (!res.hasOwnProperty("choices"))
throw new Error("GroqAI chat: No results!");
if (res.choices.length === 0)
throw new Error("GroqAI chat: No results length!");
return res.choices[0].message.content;
})
.catch((error) => {
throw new Error(
`GroqAI::createChatCompletion failed with: ${error.message}`
);
});
return textResponse;
}
async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
if (!(await this.isValidChatCompletionModel(this.model)))
throw new Error(
`GroqAI:streamChat: ${this.model} is not valid for chat completion!`
);
const streamRequest = await this.openai.createChatCompletion(
{
model: this.model,
stream: true,
temperature: Number(workspace?.openAiTemp ?? this.defaultTemp),
n: 1,
messages: await this.compressMessages(
{
systemPrompt: chatPrompt(workspace),
userPrompt: prompt,
chatHistory,
},
rawHistory
),
},
{ responseType: "stream" }
);
return streamRequest;
}
async getChatCompletion(messages = null, { temperature = 0.7 }) {
if (!(await this.isValidChatCompletionModel(this.model)))
throw new Error(
`GroqAI:chatCompletion: ${this.model} is not valid for chat completion!`
);
const { data } = await this.openai
.createChatCompletion({
model: this.model,
messages,
temperature,
})
.catch((e) => {
throw new Error(e.response.data.error.message);
});
if (!data.hasOwnProperty("choices")) return null;
return data.choices[0].message.content;
}
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
if (!(await this.isValidChatCompletionModel(this.model)))
throw new Error(
`GroqAI:streamChatCompletion: ${this.model} is not valid for chat completion!`
);
const streamRequest = await this.openai.createChatCompletion(
{
model: this.model,
stream: true,
messages,
temperature,
},
{ responseType: "stream" }
);
return streamRequest;
}
handleStream(response, stream, responseProps) {
return handleDefaultStreamResponse(response, stream, responseProps);
}
// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
async embedTextInput(textInput) {
return await this.embedder.embedTextInput(textInput);
}
async embedChunks(textChunks = []) {
return await this.embedder.embedChunks(textChunks);
}
async compressMessages(promptArgs = {}, rawHistory = []) {
const { messageArrayCompressor } = require("../../helpers/chat");
const messageArray = this.constructPrompt(promptArgs);
return await messageArrayCompressor(this, messageArray, rawHistory);
}
}
module.exports = {
GroqLLM,
};
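
A minimal usage sketch of the new class, assuming GROQ_API_KEY (and optionally GROQ_MODEL_PREF) are set as in the .env examples above; the require path is illustrative:

// Usage sketch; assumes env configured per the examples above.
const { GroqLLM } = require("./utils/AiProviders/groq"); // illustrative path

async function demo() {
  const llm = new GroqLLM(); // defaults: NativeEmbedder, GROQ_MODEL_PREF or llama2-70b-4096
  const reply = await llm.getChatCompletion(
    [
      { role: "system", content: "You are a concise assistant." },
      { role: "user", content: "Say hello in one sentence." },
    ],
    { temperature: 0.7 }
  );
  console.log(reply);
}

demo();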


@@ -73,6 +73,9 @@ function getLLMProvider(modelPreference = null) {
case "huggingface":
const { HuggingFaceLLM } = require("../AiProviders/huggingface");
return new HuggingFaceLLM(embedder, modelPreference);
case "groq":
const { GroqLLM } = require("../AiProviders/groq");
return new GroqLLM(embedder, modelPreference);
default:
throw new Error("ENV: No LLM_PROVIDER value found in environment!");
}
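
With LLM_PROVIDER='groq' set (per the .env examples above), callers now receive the new provider from this factory:

// Resolves to a GroqLLM instance when LLM_PROVIDER='groq'.
const llm = getLLMProvider(); // or getLLMProvider("mixtral-8x7b-32768")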


@@ -259,6 +259,16 @@ const KEY_MAPPING = {
checks: [isNotEmpty],
},
// Groq Options
GroqApiKey: {
envKey: "GROQ_API_KEY",
checks: [isNotEmpty],
},
GroqModelPref: {
envKey: "GROQ_MODEL_PREF",
checks: [isNotEmpty],
},
// System Settings
AuthToken: {
envKey: "AUTH_TOKEN",
@@ -336,6 +346,7 @@ function supportedLLM(input = "") {
"huggingface",
"perplexity",
"openrouter",
"groq",
].includes(input);
return validSelection ? null : `${input} is not a valid LLM provider.`;
}
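
A sketch of how these entries would be exercised by the settings updater. The runner below is hypothetical, inferred from the checks arrays and the error-string convention supportedLLM follows (a string on failure, null on success):

// Hypothetical runner; real update logic also persists the value to the env file.
function validateSetting(key, value) {
  const mapping = KEY_MAPPING[key];
  if (!mapping) return `${key} is not a recognized setting.`;
  for (const check of mapping.checks) {
    const error = check(value); // e.g. isNotEmpty above
    if (error) return error;
  }
  // On success the value would be written under mapping.envKey (e.g. GROQ_API_KEY).
  return null;
}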