add remove bg model selection

This commit is contained in:
Qing 2024-02-08 16:49:54 +08:00
parent cf9ceea4e6
commit 8060e16c70
19 changed files with 915 additions and 222 deletions

View File

@ -42,10 +42,10 @@ from iopaint.helper import (
adjust_mask,
)
from iopaint.model.utils import torch_gc
from iopaint.model_info import ModelInfo
from iopaint.model_manager import ModelManager
from iopaint.plugins import build_plugins
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.remove_bg import RemoveBG
from iopaint.schema import (
GenInfoResponse,
ApiConfig,
@ -56,6 +56,9 @@ from iopaint.schema import (
SDSampler,
PluginInfo,
AdjustMaskRequest,
RemoveBGModel,
SwitchPluginModelRequest,
ModelInfo,
)
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
@ -154,11 +157,11 @@ class Api:
# fmt: off
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo])
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
@ -175,9 +178,6 @@ class Api:
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
def api_models(self) -> List[ModelInfo]:
return self.model_manager.scan_models()
def api_current_model(self) -> ModelInfo:
return self.model_manager.current_model
@ -187,16 +187,28 @@ class Api:
self.model_manager.switch(req.name)
return self.model_manager.current_model
def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
if req.plugin_name in self.plugins:
self.plugins[req.plugin_name].switch_model(req.model_name)
if req.plugin_name == RemoveBG.name:
self.config.remove_bg_model = req.model_name
def api_server_config(self) -> ServerConfigResponse:
return ServerConfigResponse(
plugins=[
plugins = []
for it in self.plugins.values():
plugins.append(
PluginInfo(
name=it.name,
support_gen_image=it.support_gen_image,
support_gen_mask=it.support_gen_mask,
)
for it in self.plugins.values()
],
)
return ServerConfigResponse(
plugins=plugins,
modelInfos=self.model_manager.scan_models(),
removeBGModel=self.config.remove_bg_model,
removeBGModels=RemoveBGModel.values(),
enableFileManager=self.file_manager is not None,
enableAutoSaving=self.config.output_dir is not None,
enableControlnet=self.model_manager.enable_controlnet,
@ -340,6 +352,7 @@ class Api:
self.config.interactive_seg_model,
self.config.interactive_seg_device,
self.config.enable_remove_bg,
self.config.remove_bg_model,
self.config.enable_anime_seg,
self.config.enable_realesrgan,
self.config.realesrgan_device,

View File

@ -9,7 +9,7 @@ from typer_config import use_json_config
from iopaint.const import *
from iopaint.runtime import setup_model_dir, dump_environment_info, check_device
from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel
from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel, RemoveBGModel
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
@ -127,6 +127,7 @@ def start(
),
interactive_seg_device: Device = Option(Device.cpu),
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
enable_realesrgan: bool = Option(False),
realesrgan_device: Device = Option(Device.cpu),
@ -183,6 +184,7 @@ def start(
interactive_seg_model=interactive_seg_model,
interactive_seg_device=interactive_seg_device,
enable_remove_bg=enable_remove_bg,
remove_bg_model=remove_bg_model,
enable_anime_seg=enable_anime_seg,
enable_realesrgan=enable_realesrgan,
realesrgan_device=realesrgan_device,

View File

@ -1,8 +1,5 @@
import json
import os
from pathlib import Path
from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
from typing import List
INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
@ -57,7 +54,7 @@ CPU_TEXTENCODER_HELP = """
Run diffusion models text encoder on CPU to reduce vRAM usage.
"""
SD_CONTROLNET_CHOICES = [
SD_CONTROLNET_CHOICES: List[str] = [
"lllyasviel/control_v11p_sd15_canny",
# "lllyasviel/control_v11p_sd15_seg",
"lllyasviel/control_v11p_sd15_openpose",
@ -113,38 +110,9 @@ Quality of image encoding, 0-100. Default is 95, higher quality will generate la
INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
REMOVE_BG_HELP = "Enable remove background. Always run on CPU"
ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU"
REMOVE_BG_HELP = "Enable remove background plugin. Always run on CPU"
ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU"
REALESRGAN_HELP = "Enable realesrgan super resolution"
GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
default_configs = dict(
host="127.0.0.1",
port=8080,
model=DEFAULT_MODEL,
model_dir=DEFAULT_MODEL_DIR,
no_half=False,
low_mem=False,
cpu_offload=False,
disable_nsfw_checker=False,
local_files_only=False,
cpu_textencoder=False,
device=Device.cuda,
input=None,
output_dir=None,
quality=95,
enable_interactive_seg=False,
interactive_seg_model=InteractiveSegModel.vit_b,
interactive_seg_device=Device.cpu,
enable_remove_bg=False,
enable_anime_seg=False,
enable_realesrgan=False,
realesrgan_device=Device.cpu,
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
enable_gfpgan=False,
gfpgan_device=Device.cpu,
enable_restoreformer=False,
restoreformer_device=Device.cpu,
)

View File

@ -3,6 +3,7 @@ import os
from functools import lru_cache
from typing import List
from iopaint.schema import ModelType, ModelInfo
from loguru import logger
from pathlib import Path
@ -15,7 +16,6 @@ from iopaint.const import (
ANYTEXT_NAME,
)
from iopaint.model.original_sd_configs import get_config_files
from iopaint.model_info import ModelInfo, ModelType
def cli_download_model(model: str):

View File

@ -1,103 +0,0 @@
from typing import List
from pydantic import computed_field, BaseModel
from iopaint.const import (
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
SD_CONTROLNET_CHOICES,
INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME,
POWERPAINT_NAME,
ANYTEXT_NAME,
)
from iopaint.schema import ModelType
class ModelInfo(BaseModel):
name: str
path: str
model_type: ModelType
is_single_file_diffusers: bool = False
@computed_field
@property
def need_prompt(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [
INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME,
POWERPAINT_NAME,
ANYTEXT_NAME,
]
@computed_field
@property
def controlnets(self) -> List[str]:
if self.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return SDXL_CONTROLNET_CHOICES
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
if "sd2" in self.name.lower():
return SD2_CONTROLNET_CHOICES
else:
return SD_CONTROLNET_CHOICES
if self.name == POWERPAINT_NAME:
return SD_CONTROLNET_CHOICES
return []
@computed_field
@property
def support_strength(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]
@computed_field
@property
def support_outpainting(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME]
@computed_field
@property
def support_lcm_lora(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_controlnet(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_freeu(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [INSTRUCT_PIX2PIX_NAME]

View File

@ -8,8 +8,7 @@ from iopaint.download import scan_models
from iopaint.helper import switch_mps_device
from iopaint.model import models, ControlNet, SD, SDXL
from iopaint.model.utils import torch_gc, is_local_files_only
from iopaint.model_info import ModelInfo, ModelType
from iopaint.schema import InpaintRequest
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
class ModelManager:

View File

@ -16,6 +16,7 @@ def build_plugins(
interactive_seg_model: InteractiveSegModel,
interactive_seg_device: Device,
enable_remove_bg: bool,
remove_bg_model: str,
enable_anime_seg: bool,
enable_realesrgan: bool,
realesrgan_device: Device,
@ -35,7 +36,7 @@ def build_plugins(
if enable_remove_bg:
logger.info(f"Initialize {RemoveBG.name} plugin")
plugins[RemoveBG.name] = RemoveBG()
plugins[RemoveBG.name] = RemoveBG(remove_bg_model)
if enable_anime_seg:
logger.info(f"Initialize {AnimeSeg.name} plugin")

View File

@ -25,3 +25,6 @@ class BasePlugin:
def check_dep(self):
...
def switch_model(self, new_model_name: str):
...

515
iopaint/plugins/briarmbg.py Normal file
View File

@ -0,0 +1,515 @@
# copy from: https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4/blob/main/briarmbg.py
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
from torchvision.transforms.functional import normalize
class REBNCONV(nn.Module):
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
super(REBNCONV, self).__init__()
self.conv_s1 = nn.Conv2d(
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self, x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
return src
### RSU-7 ###
class RSU7(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
super(RSU7, self).__init__()
self.in_ch = in_ch
self.mid_ch = mid_ch
self.out_ch = out_ch
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
b, c, h, w = x.shape
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
hx6dup = _upsample_like(hx6d, hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
return hx1d + hxin
class myrebnconv(nn.Module):
def __init__(
self,
in_ch=3,
out_ch=1,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1,
):
super(myrebnconv, self).__init__()
self.conv = nn.Conv2d(
in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
self.bn = nn.BatchNorm2d(out_ch)
self.rl = nn.ReLU(inplace=True)
def forward(self, x):
return self.rl(self.bn(self.conv(x)))
class BriaRMBG(nn.Module):
def __init__(self, in_ch=3, out_ch=1):
super(BriaRMBG, self).__init__()
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage1 = RSU7(64, 32, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 32, 128)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(128, 64, 256)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(256, 128, 512)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(512, 256, 512)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(512, 256, 512)
# decoder
self.stage5d = RSU4F(1024, 256, 512)
self.stage4d = RSU4(1024, 128, 256)
self.stage3d = RSU5(512, 64, 128)
self.stage2d = RSU6(256, 32, 64)
self.stage1d = RSU7(128, 16, 64)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
def forward(self, x):
hx = x
hxin = self.conv_in(hx)
# hx = self.pool_in(hxin)
# stage 1
hx1 = self.stage1(hxin)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6, hx5)
# -------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
# side output
d1 = self.side1(hx1d)
d1 = _upsample_like(d1, x)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, x)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, x)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, x)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, x)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, x)
return [
F.sigmoid(d1),
F.sigmoid(d2),
F.sigmoid(d3),
F.sigmoid(d4),
F.sigmoid(d5),
F.sigmoid(d6),
], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
def resize_image(image):
image = image.convert("RGB")
model_input_size = (1024, 1024)
image = image.resize(model_input_size, Image.BILINEAR)
return image
def create_briarmbg_session():
from huggingface_hub import hf_hub_download
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()
return net
def briarmbg_process(bgr_np_image, session, only_mask=False):
# prepare input
orig_bgr_image = Image.fromarray(bgr_np_image)
w, h = orig_im_size = orig_bgr_image.size
image = resize_image(orig_bgr_image)
im_np = np.array(image)
im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
im_tensor = torch.unsqueeze(im_tensor, 0)
im_tensor = torch.divide(im_tensor, 255.0)
im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
if torch.cuda.is_available():
im_tensor = im_tensor.cuda()
# inference
result = session(im_tensor)
# post process
result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
ma = torch.max(result)
mi = torch.min(result)
result = (result - mi) / (ma - mi)
# image to pil
im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
mask = np.squeeze(im_array)
if only_mask:
return mask
pil_im = Image.fromarray(mask)
# paste the mask on the original image
new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
new_im.paste(orig_bgr_image, mask=pil_im)
rgba_np_img = np.asarray(new_im)
return rgba_np_img

View File

@ -1,8 +1,6 @@
import hashlib
import json
from typing import List
import cv2
import numpy as np
import torch
from loguru import logger

View File

@ -1,10 +1,11 @@
import os
import cv2
import numpy as np
from loguru import logger
from torch.hub import get_dir
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import RunPluginRequest
from iopaint.schema import RunPluginRequest, RemoveBGModel
class RemoveBG(BasePlugin):
@ -12,32 +13,53 @@ class RemoveBG(BasePlugin):
support_gen_mask = True
support_gen_image = True
def __init__(self):
def __init__(self, model_name):
super().__init__()
from rembg import new_session
self.model_name = model_name
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
os.environ["U2NET_HOME"] = model_dir
self.session = new_session(model_name="u2net")
self._init_session(model_name)
def _init_session(self, model_name: str):
if model_name == RemoveBGModel.briaai_rmbg_1_4:
from iopaint.plugins.briarmbg import (
create_briarmbg_session,
briarmbg_process,
)
self.session = create_briarmbg_session()
self.remove = briarmbg_process
else:
from rembg import new_session, remove
self.session = new_session(model_name=model_name)
self.remove = remove
def switch_model(self, new_model_name):
if self.model_name == new_model_name:
return
logger.info(
f"Switching removebg model from {self.model_name} to {new_model_name}"
)
self._init_session(new_model_name)
self.model_name = new_model_name
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
from rembg import remove
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
# return BGRA image
output = remove(bgr_np_img, session=self.session)
output = self.remove(bgr_np_img, session=self.session)
return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
from rembg import remove
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
# return BGR image, 255 means foreground, 0 means background
output = remove(bgr_np_img, session=self.session, only_mask=True)
output = self.remove(bgr_np_img, session=self.session, only_mask=True)
return output
def check_dep(self):

View File

@ -5,11 +5,11 @@ import sys
from pathlib import Path
import packaging.version
from iopaint.schema import Device
from loguru import logger
from rich import print
from typing import Dict, Any
from iopaint.const import Device
_PY_VERSION: str = sys.version.split()[0].rstrip("+")

View File

@ -1,11 +1,117 @@
import json
import random
from enum import Enum
from pathlib import Path
from typing import Optional, Literal, List
from iopaint.const import (
INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME,
POWERPAINT_NAME,
ANYTEXT_NAME,
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
SD_CONTROLNET_CHOICES,
)
from loguru import logger
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, computed_field
class ModelType(str, Enum):
INPAINT = "inpaint" # LaMa, MAT...
DIFFUSERS_SD = "diffusers_sd"
DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
DIFFUSERS_SDXL = "diffusers_sdxl"
DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
DIFFUSERS_OTHER = "diffusers_other"
class ModelInfo(BaseModel):
name: str
path: str
model_type: ModelType
is_single_file_diffusers: bool = False
@computed_field
@property
def need_prompt(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [
INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME,
POWERPAINT_NAME,
ANYTEXT_NAME,
]
@computed_field
@property
def controlnets(self) -> List[str]:
if self.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return SDXL_CONTROLNET_CHOICES
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
if "sd2" in self.name.lower():
return SD2_CONTROLNET_CHOICES
else:
return SD_CONTROLNET_CHOICES
if self.name == POWERPAINT_NAME:
return SD_CONTROLNET_CHOICES
return []
@computed_field
@property
def support_strength(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]
@computed_field
@property
def support_outpainting(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME]
@computed_field
@property
def support_lcm_lora(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_controlnet(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_freeu(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [INSTRUCT_PIX2PIX_NAME]
class Choices(str, Enum):
@ -20,6 +126,16 @@ class RealESRGANModel(Choices):
RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
class RemoveBGModel(Choices):
u2net = "u2net"
u2netp = "u2netp"
u2net_human_seg = "u2net_human_seg"
u2net_cloth_seg = "u2net_cloth_seg"
silueta = "silueta"
isnet_general_use = "isnet-general-use"
briaai_rmbg_1_4 = "briaai/RMBG-1.4"
class Device(Choices):
cpu = "cpu"
cuda = "cuda"
@ -44,15 +160,6 @@ class CV2Flag(str, Enum):
INPAINT_TELEA = "INPAINT_TELEA"
class ModelType(str, Enum):
INPAINT = "inpaint" # LaMa, MAT...
DIFFUSERS_SD = "diffusers_sd"
DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
DIFFUSERS_SDXL = "diffusers_sdxl"
DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
DIFFUSERS_OTHER = "diffusers_other"
class HDStrategy(str, Enum):
# Use original image size
ORIGINAL = "Original"
@ -124,6 +231,7 @@ class ApiConfig(BaseModel):
interactive_seg_model: InteractiveSegModel
interactive_seg_device: Device
enable_remove_bg: bool
remove_bg_model: str
enable_anime_seg: bool
enable_realesrgan: bool
realesrgan_device: Device
@ -313,6 +421,9 @@ class GenInfoResponse(BaseModel):
class ServerConfigResponse(BaseModel):
plugins: List[PluginInfo]
modelInfos: List[ModelInfo]
removeBGModel: RemoveBGModel
removeBGModels: List[str]
enableFileManager: bool
enableAutoSaving: bool
enableControlnet: bool
@ -326,6 +437,11 @@ class SwitchModelRequest(BaseModel):
name: str
class SwitchPluginModelRequest(BaseModel):
plugin_name: str
model_name: str
AdjustMaskOperate = Literal["expand", "shrink", "reverse"]

View File

@ -5,7 +5,7 @@ from PIL import Image
from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
from iopaint.plugins.anime_seg import AnimeSeg
from iopaint.schema import RunPluginRequest
from iopaint.schema import RunPluginRequest, RemoveBGModel
from iopaint.tests.utils import check_device, current_dir, save_dir
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@ -34,7 +34,7 @@ def _save(img, name):
def test_remove_bg():
model = RemoveBG()
model = RemoveBG(RemoveBGModel.briaai_rmbg_1_4)
rgba_np_img = model.gen_image(
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
)

View File

@ -1,4 +1,14 @@
import json
import os
from pathlib import Path
from iopaint.schema import (
Device,
InteractiveSegModel,
RemoveBGModel,
RealESRGANModel,
ApiConfig,
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -15,6 +25,37 @@ from iopaint.const import *
_config_file: Path = None
default_configs = dict(
host="127.0.0.1",
port=8080,
model=DEFAULT_MODEL,
model_dir=DEFAULT_MODEL_DIR,
no_half=False,
low_mem=False,
cpu_offload=False,
disable_nsfw_checker=False,
local_files_only=False,
cpu_textencoder=False,
device=Device.cuda,
input=None,
output_dir=None,
quality=95,
enable_interactive_seg=False,
interactive_seg_model=InteractiveSegModel.vit_b,
interactive_seg_device=Device.cpu,
enable_remove_bg=False,
remove_bg_model=RemoveBGModel.u2net,
enable_anime_seg=False,
enable_realesrgan=False,
realesrgan_device=Device.cpu,
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
enable_gfpgan=False,
gfpgan_device=Device.cpu,
enable_restoreformer=False,
restoreformer_device=Device.cpu,
)
class WebConfig(ApiConfig):
model_dir: str = DEFAULT_MODEL_DIR
@ -50,6 +91,7 @@ def save_config(
interactive_seg_model,
interactive_seg_device,
enable_remove_bg,
remove_bg_model,
enable_anime_seg,
enable_realesrgan,
realesrgan_device,
@ -115,7 +157,7 @@ def main(config_file: Path):
with gr.Row():
recommend_model = gr.Dropdown(
["lama", "mat", "migan"] + DIFFUSION_MODELS,
label="Recommend Models",
label="Recommended Models",
)
downloaded_model = gr.Dropdown(
downloaded_models, label="Downloaded Models"
@ -179,6 +221,11 @@ def main(config_file: Path):
enable_remove_bg = gr.Checkbox(
init_config.enable_remove_bg, label=REMOVE_BG_HELP
)
remove_bg_model = gr.Radio(
RemoveBGModel.values(),
label="Remove bg model",
value=init_config.remove_bg_model,
)
with gr.Row():
enable_anime_seg = gr.Checkbox(
init_config.enable_anime_seg, label=ANIMESEG_HELP
@ -241,6 +288,7 @@ def main(config_file: Path):
interactive_seg_model,
interactive_seg_device,
enable_remove_bg,
remove_bg_model,
enable_anime_seg,
enable_realesrgan,
realesrgan_device,

View File

@ -20,8 +20,8 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs"
import { useEffect, useState } from "react"
import { cn } from "@/lib/utils"
import { useQuery } from "@tanstack/react-query"
import { fetchModelInfos, switchModel } from "@/lib/api"
import { ModelInfo } from "@/lib/types"
import { getServerConfig, switchModel, switchPluginModel } from "@/lib/api"
import { ModelInfo, PluginName } from "@/lib/types"
import { useStore } from "@/lib/states"
import { ScrollArea } from "./ui/scroll-area"
import { useToast } from "./ui/use-toast"
@ -39,6 +39,14 @@ import {
MODEL_TYPE_OTHER,
} from "@/lib/const"
import useHotKey from "@/hooks/useHotkey"
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectTrigger,
SelectValue,
} from "./ui/select"
const formSchema = z.object({
enableFileManager: z.boolean(),
@ -48,42 +56,45 @@ const formSchema = z.object({
enableManualInpainting: z.boolean(),
enableUploadMask: z.boolean(),
enableAutoExtractPrompt: z.boolean(),
removeBGModel: z.string(),
})
const TAB_GENERAL = "General"
const TAB_MODEL = "Model"
const TAB_PLUGINS = "Plugins"
// const TAB_FILE_MANAGER = "File Manager"
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL]
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL, TAB_PLUGINS]
export function SettingsDialog() {
const [open, toggleOpen] = useToggle(false)
const [openModelSwitching, toggleOpenModelSwitching] = useToggle(false)
const [tab, setTab] = useState(TAB_MODEL)
const [
updateAppState,
settings,
updateSettings,
fileManagerState,
updateFileManagerState,
setAppModel,
setServerConfig,
] = useStore((state) => [
state.updateAppState,
state.settings,
state.updateSettings,
state.fileManagerState,
state.updateFileManagerState,
state.setModel,
state.setServerConfig,
])
const { toast } = useToast()
const [model, setModel] = useState<ModelInfo>(settings.model)
const [modelSwitchingTexts, setModelSwitchingTexts] = useState<string[]>([])
const openModelSwitching = modelSwitchingTexts.length > 0
useEffect(() => {
setModel(settings.model)
}, [settings.model])
const { data: modelInfos, status } = useQuery({
queryKey: ["modelInfos"],
queryFn: fetchModelInfos,
const { data: serverConfig, status } = useQuery({
queryKey: ["serverConfig"],
queryFn: getServerConfig,
})
// 1. Define your form.
@ -96,9 +107,17 @@ export function SettingsDialog() {
enableAutoExtractPrompt: settings.enableAutoExtractPrompt,
inputDirectory: fileManagerState.inputDirectory,
outputDirectory: fileManagerState.outputDirectory,
removeBGModel: serverConfig?.removeBGModel,
},
})
useEffect(() => {
if (serverConfig) {
setServerConfig(serverConfig)
form.setValue("removeBGModel", serverConfig.removeBGModel)
}
}, [form, serverConfig])
async function onSubmit(values: z.infer<typeof formSchema>) {
// Do something with the form values. ✅ This will be type-safe and validated.
updateSettings({
@ -109,13 +128,33 @@ export function SettingsDialog() {
})
// TODO: validate input/output Directory
updateFileManagerState({
inputDirectory: values.inputDirectory,
outputDirectory: values.outputDirectory,
})
if (model.name !== settings.model.name) {
toggleOpenModelSwitching()
// updateFileManagerState({
// inputDirectory: values.inputDirectory,
// outputDirectory: values.outputDirectory,
// })
const shouldSwitchModel = model.name !== settings.model.name
const shouldSwitchRemoveBGModel =
serverConfig?.removeBGModel !== values.removeBGModel
const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel
if (showModelSwitching) {
const newModelSwitchingTexts: string[] = []
if (shouldSwitchModel) {
newModelSwitchingTexts.push(
`Switching model from ${settings.model.name} to ${model.name}`
)
}
if (shouldSwitchRemoveBGModel) {
newModelSwitchingTexts.push(
`Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}`
)
}
setModelSwitchingTexts(newModelSwitchingTexts)
updateAppState({ disableShortCuts: true })
if (shouldSwitchModel) {
try {
const newModel = await switchModel(model.name)
toast({
@ -128,11 +167,29 @@ export function SettingsDialog() {
title: `Switch to ${model.name} failed: ${error}`,
})
setModel(settings.model)
} finally {
toggleOpenModelSwitching()
updateAppState({ disableShortCuts: false })
}
}
if (shouldSwitchRemoveBGModel) {
try {
const res = await switchPluginModel(
PluginName.RemoveBG,
values.removeBGModel
)
if (res.status !== 200) {
throw new Error(res.statusText)
}
} catch (error: any) {
toast({
variant: "destructive",
title: `Switch removebg model to ${model.name} failed: ${error}`,
})
}
}
setModelSwitchingTexts([])
updateAppState({ disableShortCuts: false })
}
}
useHotKey(
@ -143,7 +200,17 @@ export function SettingsDialog() {
onSubmit(form.getValues())
}
},
[open, form, model]
[open, form, model, serverConfig]
)
if (status !== "success") {
return <></>
}
const modelInfos = serverConfig.modelInfos
const plugins = serverConfig.plugins
const removeBGEnabled = plugins.some(
(plugin) => plugin.name === PluginName.RemoveBG
)
function onOpenChange(value: boolean) {
@ -186,10 +253,6 @@ export function SettingsDialog() {
}
function renderModelSettings() {
if (status !== "success") {
return <></>
}
let defaultTab = MODEL_TYPE_INPAINT
for (let info of modelInfos) {
if (model.name === info.name) {
@ -356,6 +419,44 @@ export function SettingsDialog() {
)
}
function renderPluginsSettings() {
return (
<div className="space-y-4 w-[510px]">
<FormField
control={form.control}
name="removeBGModel"
render={({ field }) => (
<FormItem className="flex items-center justify-between">
<div className="space-y-0.5">
<FormLabel>Remove Background</FormLabel>
<FormDescription>Remove background model</FormDescription>
</div>
<Select
onValueChange={field.onChange}
defaultValue={field.value}
disabled={!removeBGEnabled}
>
<FormControl>
<SelectTrigger className="w-[200px]">
<SelectValue placeholder="Select removebg model" />
</SelectTrigger>
</FormControl>
<SelectContent align="end">
<SelectGroup>
{serverConfig?.removeBGModels.map((model) => (
<SelectItem key={model} value={model}>
{model}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</FormItem>
)}
/>
</div>
)
}
// function renderFileManagerSettings() {
// return (
// <div className="flex flex-col justify-between rounded-lg gap-4 w-[400px]">
@ -446,7 +547,9 @@ export function SettingsDialog() {
<span className="sr-only">Loading...</span>
</div>
<div>Switching to {model.name}</div>
{modelSwitchingTexts.map((text, index) => (
<div key={index}>{text}</div>
))}
</div>
{/* </AlertDialogDescription> */}
</AlertDialogHeader>
@ -473,6 +576,7 @@ export function SettingsDialog() {
<Button
key={item}
variant="ghost"
disabled={item === TAB_PLUGINS && !removeBGEnabled}
onClick={() => setTab(item)}
className={cn(
tab === item ? "bg-muted " : "hover:bg-muted",
@ -489,6 +593,7 @@ export function SettingsDialog() {
<form onSubmit={form.handleSubmit(onSubmit)}>
{tab === TAB_MODEL ? renderModelSettings() : <></>}
{tab === TAB_GENERAL ? renderGeneralSettings() : <></>}
{tab === TAB_PLUGINS ? renderPluginsSettings() : <></>}
{/* {tab === TAB_FILE_MANAGER ? (
renderFileManagerSettings()
) : (

View File

@ -12,7 +12,7 @@ export default function useInputImage() {
fetch(`${API_ENDPOINT}/inputimage`, { headers })
.then(async (res) => {
if (!res.ok) {
throw new Error("No input image found")
return
}
const filename = res.headers
.get("content-disposition")

View File

@ -104,15 +104,18 @@ export async function switchModel(name: string): Promise<ModelInfo> {
return res.data
}
export async function switchPluginModel(
plugin_name: string,
model_name: string
) {
return api.post(`/switch_plugin_model`, { plugin_name, model_name })
}
export async function currentModel(): Promise<ModelInfo> {
const res = await api.get("/model")
return res.data
}
export function fetchModelInfos(): Promise<ModelInfo[]> {
return api.get("/models").then((response) => response.data)
}
export async function runPlugin(
genMask: boolean,
name: string,

View File

@ -14,6 +14,9 @@ export interface PluginInfo {
export interface ServerConfig {
plugins: PluginInfo[]
modelInfos: ModelInfo[]
removeBGModel: string
removeBGModels: string[]
enableFileManager: boolean
enableAutoSaving: boolean
enableControlnet: boolean