diff --git a/iopaint/api.py b/iopaint/api.py index a69a7f8..ab54cfe 100644 --- a/iopaint/api.py +++ b/iopaint/api.py @@ -42,10 +42,10 @@ from iopaint.helper import ( adjust_mask, ) from iopaint.model.utils import torch_gc -from iopaint.model_info import ModelInfo from iopaint.model_manager import ModelManager from iopaint.plugins import build_plugins from iopaint.plugins.base_plugin import BasePlugin +from iopaint.plugins.remove_bg import RemoveBG from iopaint.schema import ( GenInfoResponse, ApiConfig, @@ -56,6 +56,9 @@ from iopaint.schema import ( SDSampler, PluginInfo, AdjustMaskRequest, + RemoveBGModel, + SwitchPluginModelRequest, + ModelInfo, ) CURRENT_DIR = Path(__file__).parent.absolute().resolve() @@ -154,11 +157,11 @@ class Api: # fmt: off self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse) self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse) - self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo]) self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo) self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo) self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"]) self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"]) + self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"]) self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"]) self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"]) self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"]) @@ -175,9 +178,6 @@ class Api: def add_api_route(self, path: str, endpoint, **kwargs): return self.app.add_api_route(path, endpoint, **kwargs) - def api_models(self) -> List[ModelInfo]: - return self.model_manager.scan_models() - def api_current_model(self) -> ModelInfo: return self.model_manager.current_model @@ -187,16 +187,28 @@ class Api: self.model_manager.switch(req.name) return self.model_manager.current_model + def api_switch_plugin_model(self, req: SwitchPluginModelRequest): + if req.plugin_name in self.plugins: + self.plugins[req.plugin_name].switch_model(req.model_name) + if req.plugin_name == RemoveBG.name: + self.config.remove_bg_model = req.model_name + def api_server_config(self) -> ServerConfigResponse: - return ServerConfigResponse( - plugins=[ + plugins = [] + for it in self.plugins.values(): + plugins.append( PluginInfo( name=it.name, support_gen_image=it.support_gen_image, support_gen_mask=it.support_gen_mask, ) - for it in self.plugins.values() - ], + ) + + return ServerConfigResponse( + plugins=plugins, + modelInfos=self.model_manager.scan_models(), + removeBGModel=self.config.remove_bg_model, + removeBGModels=RemoveBGModel.values(), enableFileManager=self.file_manager is not None, enableAutoSaving=self.config.output_dir is not None, enableControlnet=self.model_manager.enable_controlnet, @@ -340,6 +352,7 @@ class Api: self.config.interactive_seg_model, self.config.interactive_seg_device, self.config.enable_remove_bg, + self.config.remove_bg_model, self.config.enable_anime_seg, self.config.enable_realesrgan, self.config.realesrgan_device, diff --git a/iopaint/cli.py b/iopaint/cli.py index 20f6ea7..4ad06a7 100644 --- a/iopaint/cli.py +++ b/iopaint/cli.py @@ -9,7 +9,7 @@ from typer_config import use_json_config from iopaint.const import * from iopaint.runtime import setup_model_dir, dump_environment_info, check_device -from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel +from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel, RemoveBGModel typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False) @@ -127,6 +127,7 @@ def start( ), interactive_seg_device: Device = Option(Device.cpu), enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP), + remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4), enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP), enable_realesrgan: bool = Option(False), realesrgan_device: Device = Option(Device.cpu), @@ -183,6 +184,7 @@ def start( interactive_seg_model=interactive_seg_model, interactive_seg_device=interactive_seg_device, enable_remove_bg=enable_remove_bg, + remove_bg_model=remove_bg_model, enable_anime_seg=enable_anime_seg, enable_realesrgan=enable_realesrgan, realesrgan_device=realesrgan_device, diff --git a/iopaint/const.py b/iopaint/const.py index 5cdd5d7..5d5605b 100644 --- a/iopaint/const.py +++ b/iopaint/const.py @@ -1,8 +1,5 @@ -import json import os -from pathlib import Path - -from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel +from typing import List INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix" KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint" @@ -57,7 +54,7 @@ CPU_TEXTENCODER_HELP = """ Run diffusion models text encoder on CPU to reduce vRAM usage. """ -SD_CONTROLNET_CHOICES = [ +SD_CONTROLNET_CHOICES: List[str] = [ "lllyasviel/control_v11p_sd15_canny", # "lllyasviel/control_v11p_sd15_seg", "lllyasviel/control_v11p_sd15_openpose", @@ -113,38 +110,9 @@ Quality of image encoding, 0-100. Default is 95, higher quality will generate la INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything." INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed." -REMOVE_BG_HELP = "Enable remove background. Always run on CPU" -ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU" +REMOVE_BG_HELP = "Enable remove background plugin. Always run on CPU" +ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU" REALESRGAN_HELP = "Enable realesrgan super resolution" GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan" RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan" GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image" - -default_configs = dict( - host="127.0.0.1", - port=8080, - model=DEFAULT_MODEL, - model_dir=DEFAULT_MODEL_DIR, - no_half=False, - low_mem=False, - cpu_offload=False, - disable_nsfw_checker=False, - local_files_only=False, - cpu_textencoder=False, - device=Device.cuda, - input=None, - output_dir=None, - quality=95, - enable_interactive_seg=False, - interactive_seg_model=InteractiveSegModel.vit_b, - interactive_seg_device=Device.cpu, - enable_remove_bg=False, - enable_anime_seg=False, - enable_realesrgan=False, - realesrgan_device=Device.cpu, - realesrgan_model=RealESRGANModel.realesr_general_x4v3, - enable_gfpgan=False, - gfpgan_device=Device.cpu, - enable_restoreformer=False, - restoreformer_device=Device.cpu, -) diff --git a/iopaint/download.py b/iopaint/download.py index 67eb813..2ebd7fc 100644 --- a/iopaint/download.py +++ b/iopaint/download.py @@ -3,6 +3,7 @@ import os from functools import lru_cache from typing import List +from iopaint.schema import ModelType, ModelInfo from loguru import logger from pathlib import Path @@ -15,7 +16,6 @@ from iopaint.const import ( ANYTEXT_NAME, ) from iopaint.model.original_sd_configs import get_config_files -from iopaint.model_info import ModelInfo, ModelType def cli_download_model(model: str): diff --git a/iopaint/model_info.py b/iopaint/model_info.py deleted file mode 100644 index 8021fa3..0000000 --- a/iopaint/model_info.py +++ /dev/null @@ -1,103 +0,0 @@ -from typing import List - -from pydantic import computed_field, BaseModel - -from iopaint.const import ( - SDXL_CONTROLNET_CHOICES, - SD2_CONTROLNET_CHOICES, - SD_CONTROLNET_CHOICES, - INSTRUCT_PIX2PIX_NAME, - KANDINSKY22_NAME, - POWERPAINT_NAME, - ANYTEXT_NAME, -) -from iopaint.schema import ModelType - - -class ModelInfo(BaseModel): - name: str - path: str - model_type: ModelType - is_single_file_diffusers: bool = False - - @computed_field - @property - def need_prompt(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] or self.name in [ - INSTRUCT_PIX2PIX_NAME, - KANDINSKY22_NAME, - POWERPAINT_NAME, - ANYTEXT_NAME, - ] - - @computed_field - @property - def controlnets(self) -> List[str]: - if self.model_type in [ - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SDXL_INPAINT, - ]: - return SDXL_CONTROLNET_CHOICES - if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]: - if "sd2" in self.name.lower(): - return SD2_CONTROLNET_CHOICES - else: - return SD_CONTROLNET_CHOICES - if self.name == POWERPAINT_NAME: - return SD_CONTROLNET_CHOICES - return [] - - @computed_field - @property - def support_strength(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME] - - @computed_field - @property - def support_outpainting(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME] - - @computed_field - @property - def support_lcm_lora(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] - - @computed_field - @property - def support_controlnet(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] - - @computed_field - @property - def support_freeu(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] or self.name in [INSTRUCT_PIX2PIX_NAME] diff --git a/iopaint/model_manager.py b/iopaint/model_manager.py index a5ad87d..de13612 100644 --- a/iopaint/model_manager.py +++ b/iopaint/model_manager.py @@ -8,8 +8,7 @@ from iopaint.download import scan_models from iopaint.helper import switch_mps_device from iopaint.model import models, ControlNet, SD, SDXL from iopaint.model.utils import torch_gc, is_local_files_only -from iopaint.model_info import ModelInfo, ModelType -from iopaint.schema import InpaintRequest +from iopaint.schema import InpaintRequest, ModelInfo, ModelType class ModelManager: diff --git a/iopaint/plugins/__init__.py b/iopaint/plugins/__init__.py index 3e7d5cf..8128025 100644 --- a/iopaint/plugins/__init__.py +++ b/iopaint/plugins/__init__.py @@ -16,6 +16,7 @@ def build_plugins( interactive_seg_model: InteractiveSegModel, interactive_seg_device: Device, enable_remove_bg: bool, + remove_bg_model: str, enable_anime_seg: bool, enable_realesrgan: bool, realesrgan_device: Device, @@ -35,7 +36,7 @@ def build_plugins( if enable_remove_bg: logger.info(f"Initialize {RemoveBG.name} plugin") - plugins[RemoveBG.name] = RemoveBG() + plugins[RemoveBG.name] = RemoveBG(remove_bg_model) if enable_anime_seg: logger.info(f"Initialize {AnimeSeg.name} plugin") diff --git a/iopaint/plugins/base_plugin.py b/iopaint/plugins/base_plugin.py index 13dfdad..1f8bddc 100644 --- a/iopaint/plugins/base_plugin.py +++ b/iopaint/plugins/base_plugin.py @@ -25,3 +25,6 @@ class BasePlugin: def check_dep(self): ... + + def switch_model(self, new_model_name: str): + ... diff --git a/iopaint/plugins/briarmbg.py b/iopaint/plugins/briarmbg.py new file mode 100644 index 0000000..baa7adf --- /dev/null +++ b/iopaint/plugins/briarmbg.py @@ -0,0 +1,515 @@ +# copy from: https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4/blob/main/briarmbg.py +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +import numpy as np +from torchvision.transforms.functional import normalize + + +class REBNCONV(nn.Module): + def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1): + super(REBNCONV, self).__init__() + + self.conv_s1 = nn.Conv2d( + in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride + ) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self, x): + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + + return xout + + +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src, tar): + src = F.interpolate(src, size=tar.shape[2:], mode="bilinear") + + return src + + +### RSU-7 ### +class RSU7(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512): + super(RSU7, self).__init__() + + self.in_ch = in_ch + self.mid_ch = mid_ch + self.out_ch = out_ch + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2 + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + b, c, h, w = x.shape + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) + hx6dup = _upsample_like(hx6d, hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-6 ### +class RSU6(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-5 ### +class RSU5(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4 ### +class RSU4(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4F ### +class RSU4F(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) + hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) + + return hx1d + hxin + + +class myrebnconv(nn.Module): + def __init__( + self, + in_ch=3, + out_ch=1, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + ): + super(myrebnconv, self).__init__() + + self.conv = nn.Conv2d( + in_ch, + out_ch, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_ch) + self.rl = nn.ReLU(inplace=True) + + def forward(self, x): + return self.rl(self.bn(self.conv(x))) + + +class BriaRMBG(nn.Module): + def __init__(self, in_ch=3, out_ch=1): + super(BriaRMBG, self).__init__() + + self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1) + self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage1 = RSU7(64, 32, 64) + self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage2 = RSU6(64, 32, 128) + self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage3 = RSU5(128, 64, 256) + self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage4 = RSU4(256, 128, 512) + self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage5 = RSU4F(512, 256, 512) + self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage6 = RSU4F(512, 256, 512) + + # decoder + self.stage5d = RSU4F(1024, 256, 512) + self.stage4d = RSU4(1024, 128, 256) + self.stage3d = RSU5(512, 64, 128) + self.stage2d = RSU6(256, 32, 64) + self.stage1d = RSU7(128, 16, 64) + + self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) + self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) + self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) + self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) + + # self.outconv = nn.Conv2d(6*out_ch,out_ch,1) + + def forward(self, x): + hx = x + + hxin = self.conv_in(hx) + # hx = self.pool_in(hxin) + + # stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + # stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + # stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + # stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + # stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + # stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6, hx5) + + # -------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) + + # side output + d1 = self.side1(hx1d) + d1 = _upsample_like(d1, x) + + d2 = self.side2(hx2d) + d2 = _upsample_like(d2, x) + + d3 = self.side3(hx3d) + d3 = _upsample_like(d3, x) + + d4 = self.side4(hx4d) + d4 = _upsample_like(d4, x) + + d5 = self.side5(hx5d) + d5 = _upsample_like(d5, x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6, x) + + return [ + F.sigmoid(d1), + F.sigmoid(d2), + F.sigmoid(d3), + F.sigmoid(d4), + F.sigmoid(d5), + F.sigmoid(d6), + ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6] + + +def resize_image(image): + image = image.convert("RGB") + model_input_size = (1024, 1024) + image = image.resize(model_input_size, Image.BILINEAR) + return image + + +def create_briarmbg_session(): + from huggingface_hub import hf_hub_download + + net = BriaRMBG() + model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth") + net.load_state_dict(torch.load(model_path, map_location="cpu")) + net.eval() + return net + + +def briarmbg_process(bgr_np_image, session, only_mask=False): + # prepare input + orig_bgr_image = Image.fromarray(bgr_np_image) + w, h = orig_im_size = orig_bgr_image.size + image = resize_image(orig_bgr_image) + im_np = np.array(image) + im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1) + im_tensor = torch.unsqueeze(im_tensor, 0) + im_tensor = torch.divide(im_tensor, 255.0) + im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]) + if torch.cuda.is_available(): + im_tensor = im_tensor.cuda() + + # inference + result = session(im_tensor) + # post process + result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0) + ma = torch.max(result) + mi = torch.min(result) + result = (result - mi) / (ma - mi) + # image to pil + im_array = (result * 255).cpu().data.numpy().astype(np.uint8) + + mask = np.squeeze(im_array) + if only_mask: + return mask + + pil_im = Image.fromarray(mask) + # paste the mask on the original image + new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0)) + new_im.paste(orig_bgr_image, mask=pil_im) + rgba_np_img = np.asarray(new_im) + return rgba_np_img diff --git a/iopaint/plugins/interactive_seg.py b/iopaint/plugins/interactive_seg.py index 7c03ba5..16b819d 100644 --- a/iopaint/plugins/interactive_seg.py +++ b/iopaint/plugins/interactive_seg.py @@ -1,8 +1,6 @@ import hashlib -import json from typing import List -import cv2 import numpy as np import torch from loguru import logger diff --git a/iopaint/plugins/remove_bg.py b/iopaint/plugins/remove_bg.py index 55de64f..64bf785 100644 --- a/iopaint/plugins/remove_bg.py +++ b/iopaint/plugins/remove_bg.py @@ -1,10 +1,11 @@ import os import cv2 import numpy as np +from loguru import logger from torch.hub import get_dir from iopaint.plugins.base_plugin import BasePlugin -from iopaint.schema import RunPluginRequest +from iopaint.schema import RunPluginRequest, RemoveBGModel class RemoveBG(BasePlugin): @@ -12,32 +13,53 @@ class RemoveBG(BasePlugin): support_gen_mask = True support_gen_image = True - def __init__(self): + def __init__(self, model_name): super().__init__() - from rembg import new_session + self.model_name = model_name hub_dir = get_dir() model_dir = os.path.join(hub_dir, "checkpoints") os.environ["U2NET_HOME"] = model_dir - self.session = new_session(model_name="u2net") + self._init_session(model_name) + + def _init_session(self, model_name: str): + if model_name == RemoveBGModel.briaai_rmbg_1_4: + from iopaint.plugins.briarmbg import ( + create_briarmbg_session, + briarmbg_process, + ) + + self.session = create_briarmbg_session() + self.remove = briarmbg_process + else: + from rembg import new_session, remove + + self.session = new_session(model_name=model_name) + self.remove = remove + + def switch_model(self, new_model_name): + if self.model_name == new_model_name: + return + + logger.info( + f"Switching removebg model from {self.model_name} to {new_model_name}" + ) + self._init_session(new_model_name) + self.model_name = new_model_name def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: - from rembg import remove - bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) # return BGRA image - output = remove(bgr_np_img, session=self.session) + output = self.remove(bgr_np_img, session=self.session) return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: - from rembg import remove - bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) # return BGR image, 255 means foreground, 0 means background - output = remove(bgr_np_img, session=self.session, only_mask=True) + output = self.remove(bgr_np_img, session=self.session, only_mask=True) return output def check_dep(self): diff --git a/iopaint/runtime.py b/iopaint/runtime.py index 44e6779..7199f83 100644 --- a/iopaint/runtime.py +++ b/iopaint/runtime.py @@ -5,11 +5,11 @@ import sys from pathlib import Path import packaging.version +from iopaint.schema import Device from loguru import logger from rich import print from typing import Dict, Any -from iopaint.const import Device _PY_VERSION: str = sys.version.split()[0].rstrip("+") diff --git a/iopaint/schema.py b/iopaint/schema.py index 6fd3980..91fe7c0 100644 --- a/iopaint/schema.py +++ b/iopaint/schema.py @@ -1,11 +1,117 @@ -import json import random from enum import Enum from pathlib import Path from typing import Optional, Literal, List +from iopaint.const import ( + INSTRUCT_PIX2PIX_NAME, + KANDINSKY22_NAME, + POWERPAINT_NAME, + ANYTEXT_NAME, + SDXL_CONTROLNET_CHOICES, + SD2_CONTROLNET_CHOICES, + SD_CONTROLNET_CHOICES, +) from loguru import logger -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, computed_field + + +class ModelType(str, Enum): + INPAINT = "inpaint" # LaMa, MAT... + DIFFUSERS_SD = "diffusers_sd" + DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint" + DIFFUSERS_SDXL = "diffusers_sdxl" + DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint" + DIFFUSERS_OTHER = "diffusers_other" + + +class ModelInfo(BaseModel): + name: str + path: str + model_type: ModelType + is_single_file_diffusers: bool = False + + @computed_field + @property + def need_prompt(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [ + INSTRUCT_PIX2PIX_NAME, + KANDINSKY22_NAME, + POWERPAINT_NAME, + ANYTEXT_NAME, + ] + + @computed_field + @property + def controlnets(self) -> List[str]: + if self.model_type in [ + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SDXL_INPAINT, + ]: + return SDXL_CONTROLNET_CHOICES + if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]: + if "sd2" in self.name.lower(): + return SD2_CONTROLNET_CHOICES + else: + return SD_CONTROLNET_CHOICES + if self.name == POWERPAINT_NAME: + return SD_CONTROLNET_CHOICES + return [] + + @computed_field + @property + def support_strength(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME] + + @computed_field + @property + def support_outpainting(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME] + + @computed_field + @property + def support_lcm_lora(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] + + @computed_field + @property + def support_controlnet(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] + + @computed_field + @property + def support_freeu(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [INSTRUCT_PIX2PIX_NAME] class Choices(str, Enum): @@ -20,6 +126,16 @@ class RealESRGANModel(Choices): RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B" +class RemoveBGModel(Choices): + u2net = "u2net" + u2netp = "u2netp" + u2net_human_seg = "u2net_human_seg" + u2net_cloth_seg = "u2net_cloth_seg" + silueta = "silueta" + isnet_general_use = "isnet-general-use" + briaai_rmbg_1_4 = "briaai/RMBG-1.4" + + class Device(Choices): cpu = "cpu" cuda = "cuda" @@ -44,15 +160,6 @@ class CV2Flag(str, Enum): INPAINT_TELEA = "INPAINT_TELEA" -class ModelType(str, Enum): - INPAINT = "inpaint" # LaMa, MAT... - DIFFUSERS_SD = "diffusers_sd" - DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint" - DIFFUSERS_SDXL = "diffusers_sdxl" - DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint" - DIFFUSERS_OTHER = "diffusers_other" - - class HDStrategy(str, Enum): # Use original image size ORIGINAL = "Original" @@ -124,6 +231,7 @@ class ApiConfig(BaseModel): interactive_seg_model: InteractiveSegModel interactive_seg_device: Device enable_remove_bg: bool + remove_bg_model: str enable_anime_seg: bool enable_realesrgan: bool realesrgan_device: Device @@ -313,6 +421,9 @@ class GenInfoResponse(BaseModel): class ServerConfigResponse(BaseModel): plugins: List[PluginInfo] + modelInfos: List[ModelInfo] + removeBGModel: RemoveBGModel + removeBGModels: List[str] enableFileManager: bool enableAutoSaving: bool enableControlnet: bool @@ -326,6 +437,11 @@ class SwitchModelRequest(BaseModel): name: str +class SwitchPluginModelRequest(BaseModel): + plugin_name: str + model_name: str + + AdjustMaskOperate = Literal["expand", "shrink", "reverse"] diff --git a/iopaint/tests/test_plugins.py b/iopaint/tests/test_plugins.py index c481cb1..77efdbe 100644 --- a/iopaint/tests/test_plugins.py +++ b/iopaint/tests/test_plugins.py @@ -5,7 +5,7 @@ from PIL import Image from iopaint.helper import encode_pil_to_base64, gen_frontend_mask from iopaint.plugins.anime_seg import AnimeSeg -from iopaint.schema import RunPluginRequest +from iopaint.schema import RunPluginRequest, RemoveBGModel from iopaint.tests.utils import check_device, current_dir, save_dir os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -34,7 +34,7 @@ def _save(img, name): def test_remove_bg(): - model = RemoveBG() + model = RemoveBG(RemoveBGModel.briaai_rmbg_1_4) rgba_np_img = model.gen_image( rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) ) diff --git a/iopaint/web_config.py b/iopaint/web_config.py index 761116b..4f71b24 100644 --- a/iopaint/web_config.py +++ b/iopaint/web_config.py @@ -1,4 +1,14 @@ +import json import os +from pathlib import Path + +from iopaint.schema import ( + Device, + InteractiveSegModel, + RemoveBGModel, + RealESRGANModel, + ApiConfig, +) os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -15,6 +25,37 @@ from iopaint.const import * _config_file: Path = None +default_configs = dict( + host="127.0.0.1", + port=8080, + model=DEFAULT_MODEL, + model_dir=DEFAULT_MODEL_DIR, + no_half=False, + low_mem=False, + cpu_offload=False, + disable_nsfw_checker=False, + local_files_only=False, + cpu_textencoder=False, + device=Device.cuda, + input=None, + output_dir=None, + quality=95, + enable_interactive_seg=False, + interactive_seg_model=InteractiveSegModel.vit_b, + interactive_seg_device=Device.cpu, + enable_remove_bg=False, + remove_bg_model=RemoveBGModel.u2net, + enable_anime_seg=False, + enable_realesrgan=False, + realesrgan_device=Device.cpu, + realesrgan_model=RealESRGANModel.realesr_general_x4v3, + enable_gfpgan=False, + gfpgan_device=Device.cpu, + enable_restoreformer=False, + restoreformer_device=Device.cpu, +) + + class WebConfig(ApiConfig): model_dir: str = DEFAULT_MODEL_DIR @@ -50,6 +91,7 @@ def save_config( interactive_seg_model, interactive_seg_device, enable_remove_bg, + remove_bg_model, enable_anime_seg, enable_realesrgan, realesrgan_device, @@ -115,7 +157,7 @@ def main(config_file: Path): with gr.Row(): recommend_model = gr.Dropdown( ["lama", "mat", "migan"] + DIFFUSION_MODELS, - label="Recommend Models", + label="Recommended Models", ) downloaded_model = gr.Dropdown( downloaded_models, label="Downloaded Models" @@ -179,6 +221,11 @@ def main(config_file: Path): enable_remove_bg = gr.Checkbox( init_config.enable_remove_bg, label=REMOVE_BG_HELP ) + remove_bg_model = gr.Radio( + RemoveBGModel.values(), + label="Remove bg model", + value=init_config.remove_bg_model, + ) with gr.Row(): enable_anime_seg = gr.Checkbox( init_config.enable_anime_seg, label=ANIMESEG_HELP @@ -241,6 +288,7 @@ def main(config_file: Path): interactive_seg_model, interactive_seg_device, enable_remove_bg, + remove_bg_model, enable_anime_seg, enable_realesrgan, realesrgan_device, diff --git a/web_app/src/components/Settings.tsx b/web_app/src/components/Settings.tsx index ff1c692..8ba5514 100644 --- a/web_app/src/components/Settings.tsx +++ b/web_app/src/components/Settings.tsx @@ -20,8 +20,8 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs" import { useEffect, useState } from "react" import { cn } from "@/lib/utils" import { useQuery } from "@tanstack/react-query" -import { fetchModelInfos, switchModel } from "@/lib/api" -import { ModelInfo } from "@/lib/types" +import { getServerConfig, switchModel, switchPluginModel } from "@/lib/api" +import { ModelInfo, PluginName } from "@/lib/types" import { useStore } from "@/lib/states" import { ScrollArea } from "./ui/scroll-area" import { useToast } from "./ui/use-toast" @@ -39,6 +39,14 @@ import { MODEL_TYPE_OTHER, } from "@/lib/const" import useHotKey from "@/hooks/useHotkey" +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue, +} from "./ui/select" const formSchema = z.object({ enableFileManager: z.boolean(), @@ -48,42 +56,45 @@ const formSchema = z.object({ enableManualInpainting: z.boolean(), enableUploadMask: z.boolean(), enableAutoExtractPrompt: z.boolean(), + removeBGModel: z.string(), }) const TAB_GENERAL = "General" const TAB_MODEL = "Model" +const TAB_PLUGINS = "Plugins" // const TAB_FILE_MANAGER = "File Manager" -const TAB_NAMES = [TAB_MODEL, TAB_GENERAL] +const TAB_NAMES = [TAB_MODEL, TAB_GENERAL, TAB_PLUGINS] export function SettingsDialog() { const [open, toggleOpen] = useToggle(false) - const [openModelSwitching, toggleOpenModelSwitching] = useToggle(false) const [tab, setTab] = useState(TAB_MODEL) const [ updateAppState, settings, updateSettings, fileManagerState, - updateFileManagerState, setAppModel, + setServerConfig, ] = useStore((state) => [ state.updateAppState, state.settings, state.updateSettings, state.fileManagerState, - state.updateFileManagerState, state.setModel, + state.setServerConfig, ]) const { toast } = useToast() const [model, setModel] = useState(settings.model) + const [modelSwitchingTexts, setModelSwitchingTexts] = useState([]) + const openModelSwitching = modelSwitchingTexts.length > 0 useEffect(() => { setModel(settings.model) }, [settings.model]) - const { data: modelInfos, status } = useQuery({ - queryKey: ["modelInfos"], - queryFn: fetchModelInfos, + const { data: serverConfig, status } = useQuery({ + queryKey: ["serverConfig"], + queryFn: getServerConfig, }) // 1. Define your form. @@ -96,9 +107,17 @@ export function SettingsDialog() { enableAutoExtractPrompt: settings.enableAutoExtractPrompt, inputDirectory: fileManagerState.inputDirectory, outputDirectory: fileManagerState.outputDirectory, + removeBGModel: serverConfig?.removeBGModel, }, }) + useEffect(() => { + if (serverConfig) { + setServerConfig(serverConfig) + form.setValue("removeBGModel", serverConfig.removeBGModel) + } + }, [form, serverConfig]) + async function onSubmit(values: z.infer) { // Do something with the form values. ✅ This will be type-safe and validated. updateSettings({ @@ -109,29 +128,67 @@ export function SettingsDialog() { }) // TODO: validate input/output Directory - updateFileManagerState({ - inputDirectory: values.inputDirectory, - outputDirectory: values.outputDirectory, - }) - if (model.name !== settings.model.name) { - toggleOpenModelSwitching() - updateAppState({ disableShortCuts: true }) - try { - const newModel = await switchModel(model.name) - toast({ - title: `Switch to ${newModel.name} success`, - }) - setAppModel(model) - } catch (error: any) { - toast({ - variant: "destructive", - title: `Switch to ${model.name} failed: ${error}`, - }) - setModel(settings.model) - } finally { - toggleOpenModelSwitching() - updateAppState({ disableShortCuts: false }) + // updateFileManagerState({ + // inputDirectory: values.inputDirectory, + // outputDirectory: values.outputDirectory, + // }) + + const shouldSwitchModel = model.name !== settings.model.name + const shouldSwitchRemoveBGModel = + serverConfig?.removeBGModel !== values.removeBGModel + const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel + + if (showModelSwitching) { + const newModelSwitchingTexts: string[] = [] + if (shouldSwitchModel) { + newModelSwitchingTexts.push( + `Switching model from ${settings.model.name} to ${model.name}` + ) } + if (shouldSwitchRemoveBGModel) { + newModelSwitchingTexts.push( + `Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}` + ) + } + setModelSwitchingTexts(newModelSwitchingTexts) + + updateAppState({ disableShortCuts: true }) + + if (shouldSwitchModel) { + try { + const newModel = await switchModel(model.name) + toast({ + title: `Switch to ${newModel.name} success`, + }) + setAppModel(model) + } catch (error: any) { + toast({ + variant: "destructive", + title: `Switch to ${model.name} failed: ${error}`, + }) + setModel(settings.model) + } + } + + if (shouldSwitchRemoveBGModel) { + try { + const res = await switchPluginModel( + PluginName.RemoveBG, + values.removeBGModel + ) + if (res.status !== 200) { + throw new Error(res.statusText) + } + } catch (error: any) { + toast({ + variant: "destructive", + title: `Switch removebg model to ${model.name} failed: ${error}`, + }) + } + } + + setModelSwitchingTexts([]) + updateAppState({ disableShortCuts: false }) } } @@ -143,7 +200,17 @@ export function SettingsDialog() { onSubmit(form.getValues()) } }, - [open, form, model] + [open, form, model, serverConfig] + ) + + if (status !== "success") { + return <> + } + + const modelInfos = serverConfig.modelInfos + const plugins = serverConfig.plugins + const removeBGEnabled = plugins.some( + (plugin) => plugin.name === PluginName.RemoveBG ) function onOpenChange(value: boolean) { @@ -186,10 +253,6 @@ export function SettingsDialog() { } function renderModelSettings() { - if (status !== "success") { - return <> - } - let defaultTab = MODEL_TYPE_INPAINT for (let info of modelInfos) { if (model.name === info.name) { @@ -356,6 +419,44 @@ export function SettingsDialog() { ) } + function renderPluginsSettings() { + return ( +
+ ( + +
+ Remove Background + Remove background model +
+ +
+ )} + /> +
+ ) + } // function renderFileManagerSettings() { // return ( //
@@ -446,7 +547,9 @@ export function SettingsDialog() { Loading...
-
Switching to {model.name}
+ {modelSwitchingTexts.map((text, index) => ( +
{text}
+ ))} {/* */} @@ -473,6 +576,7 @@ export function SettingsDialog() {