add remove bg model selection

Qing 2024-02-08 16:49:54 +08:00
parent cf9ceea4e6
commit 8060e16c70
19 changed files with 915 additions and 222 deletions

View File

@ -42,10 +42,10 @@ from iopaint.helper import (
adjust_mask, adjust_mask,
) )
from iopaint.model.utils import torch_gc from iopaint.model.utils import torch_gc
from iopaint.model_info import ModelInfo
from iopaint.model_manager import ModelManager from iopaint.model_manager import ModelManager
from iopaint.plugins import build_plugins from iopaint.plugins import build_plugins
from iopaint.plugins.base_plugin import BasePlugin from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.remove_bg import RemoveBG
from iopaint.schema import ( from iopaint.schema import (
GenInfoResponse, GenInfoResponse,
ApiConfig, ApiConfig,
@ -56,6 +56,9 @@ from iopaint.schema import (
SDSampler, SDSampler,
PluginInfo, PluginInfo,
AdjustMaskRequest, AdjustMaskRequest,
RemoveBGModel,
SwitchPluginModelRequest,
ModelInfo,
) )
CURRENT_DIR = Path(__file__).parent.absolute().resolve() CURRENT_DIR = Path(__file__).parent.absolute().resolve()
@ -154,11 +157,11 @@ class Api:
# fmt: off # fmt: off
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse) self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse) self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo])
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo) self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo) self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"]) self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"]) self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"]) self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"]) self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"]) self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
@ -175,9 +178,6 @@ class Api:
def add_api_route(self, path: str, endpoint, **kwargs): def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs) return self.app.add_api_route(path, endpoint, **kwargs)
def api_models(self) -> List[ModelInfo]:
return self.model_manager.scan_models()
def api_current_model(self) -> ModelInfo: def api_current_model(self) -> ModelInfo:
return self.model_manager.current_model return self.model_manager.current_model
@@ -187,16 +187,28 @@ class Api:
         self.model_manager.switch(req.name)
         return self.model_manager.current_model
 
+    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
+        if req.plugin_name in self.plugins:
+            self.plugins[req.plugin_name].switch_model(req.model_name)
+            if req.plugin_name == RemoveBG.name:
+                self.config.remove_bg_model = req.model_name
+
     def api_server_config(self) -> ServerConfigResponse:
-        return ServerConfigResponse(
-            plugins=[
+        plugins = []
+        for it in self.plugins.values():
+            plugins.append(
                 PluginInfo(
                     name=it.name,
                     support_gen_image=it.support_gen_image,
                     support_gen_mask=it.support_gen_mask,
                 )
-                for it in self.plugins.values()
-            ],
+            )
+
+        return ServerConfigResponse(
+            plugins=plugins,
+            modelInfos=self.model_manager.scan_models(),
+            removeBGModel=self.config.remove_bg_model,
+            removeBGModels=RemoveBGModel.values(),
             enableFileManager=self.file_manager is not None,
             enableAutoSaving=self.config.output_dir is not None,
             enableControlnet=self.model_manager.enable_controlnet,
@ -340,6 +352,7 @@ class Api:
self.config.interactive_seg_model, self.config.interactive_seg_model,
self.config.interactive_seg_device, self.config.interactive_seg_device,
self.config.enable_remove_bg, self.config.enable_remove_bg,
self.config.remove_bg_model,
self.config.enable_anime_seg, self.config.enable_anime_seg,
self.config.enable_realesrgan, self.config.enable_realesrgan,
self.config.realesrgan_device, self.config.realesrgan_device,
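The new `/api/v1/switch_plugin_model` route accepts the `plugin_name`/`model_name` pair defined by `SwitchPluginModelRequest`. A minimal sketch of calling it over HTTP, assuming the default `127.0.0.1:8080` address from `default_configs` and that the plugin name string matches `RemoveBG.name` on the server:

```python
# Sketch only, not part of the commit: exercising the new plugin-model endpoint.
import requests  # any HTTP client works; requests is just an assumption here

resp = requests.post(
    "http://127.0.0.1:8080/api/v1/switch_plugin_model",  # default host/port from default_configs
    json={
        "plugin_name": "RemoveBG",        # assumed to equal RemoveBG.name on the server
        "model_name": "briaai/RMBG-1.4",  # any member of RemoveBGModel.values()
    },
)
resp.raise_for_status()  # the endpoint returns no payload; a 2xx status means the switch ran
```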

View File

@ -9,7 +9,7 @@ from typer_config import use_json_config
from iopaint.const import * from iopaint.const import *
from iopaint.runtime import setup_model_dir, dump_environment_info, check_device from iopaint.runtime import setup_model_dir, dump_environment_info, check_device
-from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel
+from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel, RemoveBGModel
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False) typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
@ -127,6 +127,7 @@ def start(
), ),
interactive_seg_device: Device = Option(Device.cpu), interactive_seg_device: Device = Option(Device.cpu),
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP), enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP), enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
enable_realesrgan: bool = Option(False), enable_realesrgan: bool = Option(False),
realesrgan_device: Device = Option(Device.cpu), realesrgan_device: Device = Option(Device.cpu),
@ -183,6 +184,7 @@ def start(
interactive_seg_model=interactive_seg_model, interactive_seg_model=interactive_seg_model,
interactive_seg_device=interactive_seg_device, interactive_seg_device=interactive_seg_device,
enable_remove_bg=enable_remove_bg, enable_remove_bg=enable_remove_bg,
remove_bg_model=remove_bg_model,
enable_anime_seg=enable_anime_seg, enable_anime_seg=enable_anime_seg,
enable_realesrgan=enable_realesrgan, enable_realesrgan=enable_realesrgan,
realesrgan_device=realesrgan_device, realesrgan_device=realesrgan_device,

View File

@@ -1,8 +1,5 @@
-import json
 import os
-from pathlib import Path
+from typing import List
 
-from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix" INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint" KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
@ -57,7 +54,7 @@ CPU_TEXTENCODER_HELP = """
Run diffusion models text encoder on CPU to reduce vRAM usage. Run diffusion models text encoder on CPU to reduce vRAM usage.
""" """
-SD_CONTROLNET_CHOICES = [
+SD_CONTROLNET_CHOICES: List[str] = [
"lllyasviel/control_v11p_sd15_canny", "lllyasviel/control_v11p_sd15_canny",
# "lllyasviel/control_v11p_sd15_seg", # "lllyasviel/control_v11p_sd15_seg",
"lllyasviel/control_v11p_sd15_openpose", "lllyasviel/control_v11p_sd15_openpose",
@ -113,38 +110,9 @@ Quality of image encoding, 0-100. Default is 95, higher quality will generate la
INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything." INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed." INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
-REMOVE_BG_HELP = "Enable remove background. Always run on CPU"
+REMOVE_BG_HELP = "Enable remove background plugin. Always run on CPU"
-ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU"
+ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU"
REALESRGAN_HELP = "Enable realesrgan super resolution" REALESRGAN_HELP = "Enable realesrgan super resolution"
GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan" GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan" RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image" GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
default_configs = dict(
host="127.0.0.1",
port=8080,
model=DEFAULT_MODEL,
model_dir=DEFAULT_MODEL_DIR,
no_half=False,
low_mem=False,
cpu_offload=False,
disable_nsfw_checker=False,
local_files_only=False,
cpu_textencoder=False,
device=Device.cuda,
input=None,
output_dir=None,
quality=95,
enable_interactive_seg=False,
interactive_seg_model=InteractiveSegModel.vit_b,
interactive_seg_device=Device.cpu,
enable_remove_bg=False,
enable_anime_seg=False,
enable_realesrgan=False,
realesrgan_device=Device.cpu,
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
enable_gfpgan=False,
gfpgan_device=Device.cpu,
enable_restoreformer=False,
restoreformer_device=Device.cpu,
)

View File

@ -3,6 +3,7 @@ import os
from functools import lru_cache from functools import lru_cache
from typing import List from typing import List
from iopaint.schema import ModelType, ModelInfo
from loguru import logger from loguru import logger
from pathlib import Path from pathlib import Path
@ -15,7 +16,6 @@ from iopaint.const import (
ANYTEXT_NAME, ANYTEXT_NAME,
) )
from iopaint.model.original_sd_configs import get_config_files from iopaint.model.original_sd_configs import get_config_files
from iopaint.model_info import ModelInfo, ModelType
def cli_download_model(model: str): def cli_download_model(model: str):

View File

@ -1,103 +0,0 @@
from typing import List
from pydantic import computed_field, BaseModel
from iopaint.const import (
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
SD_CONTROLNET_CHOICES,
INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME,
POWERPAINT_NAME,
ANYTEXT_NAME,
)
from iopaint.schema import ModelType
class ModelInfo(BaseModel):
name: str
path: str
model_type: ModelType
is_single_file_diffusers: bool = False
@computed_field
@property
def need_prompt(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [
INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME,
POWERPAINT_NAME,
ANYTEXT_NAME,
]
@computed_field
@property
def controlnets(self) -> List[str]:
if self.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return SDXL_CONTROLNET_CHOICES
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
if "sd2" in self.name.lower():
return SD2_CONTROLNET_CHOICES
else:
return SD_CONTROLNET_CHOICES
if self.name == POWERPAINT_NAME:
return SD_CONTROLNET_CHOICES
return []
@computed_field
@property
def support_strength(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]
@computed_field
@property
def support_outpainting(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME]
@computed_field
@property
def support_lcm_lora(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_controlnet(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_freeu(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [INSTRUCT_PIX2PIX_NAME]

View File

@ -8,8 +8,7 @@ from iopaint.download import scan_models
from iopaint.helper import switch_mps_device from iopaint.helper import switch_mps_device
from iopaint.model import models, ControlNet, SD, SDXL from iopaint.model import models, ControlNet, SD, SDXL
from iopaint.model.utils import torch_gc, is_local_files_only from iopaint.model.utils import torch_gc, is_local_files_only
-from iopaint.model_info import ModelInfo, ModelType
-from iopaint.schema import InpaintRequest
+from iopaint.schema import InpaintRequest, ModelInfo, ModelType
class ModelManager: class ModelManager:

View File

@ -16,6 +16,7 @@ def build_plugins(
interactive_seg_model: InteractiveSegModel, interactive_seg_model: InteractiveSegModel,
interactive_seg_device: Device, interactive_seg_device: Device,
enable_remove_bg: bool, enable_remove_bg: bool,
remove_bg_model: str,
enable_anime_seg: bool, enable_anime_seg: bool,
enable_realesrgan: bool, enable_realesrgan: bool,
realesrgan_device: Device, realesrgan_device: Device,
@ -35,7 +36,7 @@ def build_plugins(
if enable_remove_bg: if enable_remove_bg:
logger.info(f"Initialize {RemoveBG.name} plugin") logger.info(f"Initialize {RemoveBG.name} plugin")
-        plugins[RemoveBG.name] = RemoveBG()
+        plugins[RemoveBG.name] = RemoveBG(remove_bg_model)
if enable_anime_seg: if enable_anime_seg:
logger.info(f"Initialize {AnimeSeg.name} plugin") logger.info(f"Initialize {AnimeSeg.name} plugin")

View File

@ -25,3 +25,6 @@ class BasePlugin:
def check_dep(self): def check_dep(self):
... ...
def switch_model(self, new_model_name: str):
...

iopaint/plugins/briarmbg.py (new file, +515 lines)
View File

@ -0,0 +1,515 @@
# copy from: https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4/blob/main/briarmbg.py
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
from torchvision.transforms.functional import normalize
class REBNCONV(nn.Module):
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
super(REBNCONV, self).__init__()
self.conv_s1 = nn.Conv2d(
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self, x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
return src
### RSU-7 ###
class RSU7(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
super(RSU7, self).__init__()
self.in_ch = in_ch
self.mid_ch = mid_ch
self.out_ch = out_ch
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
b, c, h, w = x.shape
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
hx6dup = _upsample_like(hx6d, hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
return hx1d + hxin
class myrebnconv(nn.Module):
def __init__(
self,
in_ch=3,
out_ch=1,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1,
):
super(myrebnconv, self).__init__()
self.conv = nn.Conv2d(
in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
self.bn = nn.BatchNorm2d(out_ch)
self.rl = nn.ReLU(inplace=True)
def forward(self, x):
return self.rl(self.bn(self.conv(x)))
class BriaRMBG(nn.Module):
def __init__(self, in_ch=3, out_ch=1):
super(BriaRMBG, self).__init__()
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage1 = RSU7(64, 32, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 32, 128)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(128, 64, 256)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(256, 128, 512)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(512, 256, 512)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(512, 256, 512)
# decoder
self.stage5d = RSU4F(1024, 256, 512)
self.stage4d = RSU4(1024, 128, 256)
self.stage3d = RSU5(512, 64, 128)
self.stage2d = RSU6(256, 32, 64)
self.stage1d = RSU7(128, 16, 64)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
def forward(self, x):
hx = x
hxin = self.conv_in(hx)
# hx = self.pool_in(hxin)
# stage 1
hx1 = self.stage1(hxin)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6, hx5)
# -------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
# side output
d1 = self.side1(hx1d)
d1 = _upsample_like(d1, x)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, x)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, x)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, x)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, x)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, x)
return [
F.sigmoid(d1),
F.sigmoid(d2),
F.sigmoid(d3),
F.sigmoid(d4),
F.sigmoid(d5),
F.sigmoid(d6),
], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
def resize_image(image):
image = image.convert("RGB")
model_input_size = (1024, 1024)
image = image.resize(model_input_size, Image.BILINEAR)
return image
def create_briarmbg_session():
from huggingface_hub import hf_hub_download
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()
return net
def briarmbg_process(bgr_np_image, session, only_mask=False):
# prepare input
orig_bgr_image = Image.fromarray(bgr_np_image)
w, h = orig_im_size = orig_bgr_image.size
image = resize_image(orig_bgr_image)
im_np = np.array(image)
im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
im_tensor = torch.unsqueeze(im_tensor, 0)
im_tensor = torch.divide(im_tensor, 255.0)
im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
if torch.cuda.is_available():
im_tensor = im_tensor.cuda()
# inference
result = session(im_tensor)
# post process
result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
ma = torch.max(result)
mi = torch.min(result)
result = (result - mi) / (ma - mi)
# image to pil
im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
mask = np.squeeze(im_array)
if only_mask:
return mask
pil_im = Image.fromarray(mask)
# paste the mask on the original image
new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
new_im.paste(orig_bgr_image, mask=pil_im)
rgba_np_img = np.asarray(new_im)
return rgba_np_img
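A minimal sketch of driving the two helpers above directly, outside the `RemoveBG` plugin wrapper; `input.jpg` is a placeholder path, and the BGR input mirrors how the plugin feeds these functions:

```python
# Sketch only: standalone use of the BriaRMBG helpers defined in briarmbg.py.
import cv2
from PIL import Image

from iopaint.plugins.briarmbg import briarmbg_process, create_briarmbg_session

session = create_briarmbg_session()   # downloads briaai/RMBG-1.4 weights via huggingface_hub
bgr = cv2.imread("input.jpg")         # placeholder path; cv2 loads images as BGR arrays

mask = briarmbg_process(bgr, session, only_mask=True)  # uint8 mask, 255 = foreground
rgba = briarmbg_process(bgr, session)                  # original image with transparent background

cv2.imwrite("mask.png", mask)
Image.fromarray(rgba).save("cutout.png")
```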

View File

@ -1,8 +1,6 @@
import hashlib import hashlib
import json
from typing import List from typing import List
import cv2
import numpy as np import numpy as np
import torch import torch
from loguru import logger from loguru import logger

View File

@@ -1,10 +1,11 @@
 import os
 
 import cv2
 import numpy as np
+from loguru import logger
 from torch.hub import get_dir
 
 from iopaint.plugins.base_plugin import BasePlugin
-from iopaint.schema import RunPluginRequest
+from iopaint.schema import RunPluginRequest, RemoveBGModel
 
 
 class RemoveBG(BasePlugin):
@@ -12,32 +13,53 @@ class RemoveBG(BasePlugin):
     support_gen_mask = True
     support_gen_image = True
 
-    def __init__(self):
+    def __init__(self, model_name):
         super().__init__()
-        from rembg import new_session
+        self.model_name = model_name
 
         hub_dir = get_dir()
         model_dir = os.path.join(hub_dir, "checkpoints")
         os.environ["U2NET_HOME"] = model_dir
 
-        self.session = new_session(model_name="u2net")
+        self._init_session(model_name)
+
+    def _init_session(self, model_name: str):
+        if model_name == RemoveBGModel.briaai_rmbg_1_4:
+            from iopaint.plugins.briarmbg import (
+                create_briarmbg_session,
+                briarmbg_process,
+            )
+
+            self.session = create_briarmbg_session()
+            self.remove = briarmbg_process
+        else:
+            from rembg import new_session, remove
+
+            self.session = new_session(model_name=model_name)
+            self.remove = remove
+
+    def switch_model(self, new_model_name):
+        if self.model_name == new_model_name:
+            return
+        logger.info(
+            f"Switching removebg model from {self.model_name} to {new_model_name}"
+        )
+        self._init_session(new_model_name)
+        self.model_name = new_model_name
 
     def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
-        from rembg import remove
-
         bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
 
         # return BGRA image
-        output = remove(bgr_np_img, session=self.session)
+        output = self.remove(bgr_np_img, session=self.session)
         return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
 
     def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
-        from rembg import remove
-
         bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
 
         # return BGR image, 255 means foreground, 0 means background
-        output = remove(bgr_np_img, session=self.session, only_mask=True)
+        output = self.remove(bgr_np_img, session=self.session, only_mask=True)
         return output
 
     def check_dep(self):
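A sketch of the new switching flow on the plugin itself: construct `RemoveBG` with one model, then hand it a different name and it re-initialises its session. `photo.jpg` and the base64 handling are placeholders patterned on the plugin tests further down in this commit:

```python
# Sketch only: runtime model switching on the RemoveBG plugin.
import base64

import cv2
from iopaint.plugins.remove_bg import RemoveBG
from iopaint.schema import RemoveBGModel, RunPluginRequest

rgb = cv2.cvtColor(cv2.imread("photo.jpg"), cv2.COLOR_BGR2RGB)
with open("photo.jpg", "rb") as f:
    req = RunPluginRequest(name=RemoveBG.name, image=base64.b64encode(f.read()).decode())

plugin = RemoveBG(RemoveBGModel.u2net)               # starts on a rembg session
mask_u2net = plugin.gen_mask(rgb, req)

plugin.switch_model(RemoveBGModel.briaai_rmbg_1_4)   # swaps to the BriaRMBG session
mask_rmbg = plugin.gen_mask(rgb, req)
```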

View File

@ -5,11 +5,11 @@ import sys
from pathlib import Path from pathlib import Path
import packaging.version import packaging.version
from iopaint.schema import Device
from loguru import logger from loguru import logger
from rich import print from rich import print
from typing import Dict, Any from typing import Dict, Any
from iopaint.const import Device
_PY_VERSION: str = sys.version.split()[0].rstrip("+") _PY_VERSION: str = sys.version.split()[0].rstrip("+")

View File

@ -1,11 +1,117 @@
import json
import random import random
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, Literal, List from typing import Optional, Literal, List
from iopaint.const import (
INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME,
POWERPAINT_NAME,
ANYTEXT_NAME,
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
SD_CONTROLNET_CHOICES,
)
from loguru import logger from loguru import logger
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, field_validator, computed_field
class ModelType(str, Enum):
INPAINT = "inpaint" # LaMa, MAT...
DIFFUSERS_SD = "diffusers_sd"
DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
DIFFUSERS_SDXL = "diffusers_sdxl"
DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
DIFFUSERS_OTHER = "diffusers_other"
class ModelInfo(BaseModel):
name: str
path: str
model_type: ModelType
is_single_file_diffusers: bool = False
@computed_field
@property
def need_prompt(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [
INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME,
POWERPAINT_NAME,
ANYTEXT_NAME,
]
@computed_field
@property
def controlnets(self) -> List[str]:
if self.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return SDXL_CONTROLNET_CHOICES
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
if "sd2" in self.name.lower():
return SD2_CONTROLNET_CHOICES
else:
return SD_CONTROLNET_CHOICES
if self.name == POWERPAINT_NAME:
return SD_CONTROLNET_CHOICES
return []
@computed_field
@property
def support_strength(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]
@computed_field
@property
def support_outpainting(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME]
@computed_field
@property
def support_lcm_lora(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_controlnet(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_freeu(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [INSTRUCT_PIX2PIX_NAME]
class Choices(str, Enum): class Choices(str, Enum):
@ -20,6 +126,16 @@ class RealESRGANModel(Choices):
RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B" RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
class RemoveBGModel(Choices):
u2net = "u2net"
u2netp = "u2netp"
u2net_human_seg = "u2net_human_seg"
u2net_cloth_seg = "u2net_cloth_seg"
silueta = "silueta"
isnet_general_use = "isnet-general-use"
briaai_rmbg_1_4 = "briaai/RMBG-1.4"
class Device(Choices): class Device(Choices):
cpu = "cpu" cpu = "cpu"
cuda = "cuda" cuda = "cuda"
@ -44,15 +160,6 @@ class CV2Flag(str, Enum):
INPAINT_TELEA = "INPAINT_TELEA" INPAINT_TELEA = "INPAINT_TELEA"
class ModelType(str, Enum):
INPAINT = "inpaint" # LaMa, MAT...
DIFFUSERS_SD = "diffusers_sd"
DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
DIFFUSERS_SDXL = "diffusers_sdxl"
DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
DIFFUSERS_OTHER = "diffusers_other"
class HDStrategy(str, Enum): class HDStrategy(str, Enum):
# Use original image size # Use original image size
ORIGINAL = "Original" ORIGINAL = "Original"
@ -124,6 +231,7 @@ class ApiConfig(BaseModel):
interactive_seg_model: InteractiveSegModel interactive_seg_model: InteractiveSegModel
interactive_seg_device: Device interactive_seg_device: Device
enable_remove_bg: bool enable_remove_bg: bool
remove_bg_model: str
enable_anime_seg: bool enable_anime_seg: bool
enable_realesrgan: bool enable_realesrgan: bool
realesrgan_device: Device realesrgan_device: Device
@ -313,6 +421,9 @@ class GenInfoResponse(BaseModel):
class ServerConfigResponse(BaseModel): class ServerConfigResponse(BaseModel):
plugins: List[PluginInfo] plugins: List[PluginInfo]
modelInfos: List[ModelInfo]
removeBGModel: RemoveBGModel
removeBGModels: List[str]
enableFileManager: bool enableFileManager: bool
enableAutoSaving: bool enableAutoSaving: bool
enableControlnet: bool enableControlnet: bool
@ -326,6 +437,11 @@ class SwitchModelRequest(BaseModel):
name: str name: str
class SwitchPluginModelRequest(BaseModel):
plugin_name: str
model_name: str
AdjustMaskOperate = Literal["expand", "shrink", "reverse"] AdjustMaskOperate = Literal["expand", "shrink", "reverse"]
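With `ModelInfo`/`ModelType` now living in `iopaint.schema`, the `computed_field` properties travel with every serialised model entry (for example in `ServerConfigResponse.modelInfos`). A small sketch; the `lama` name and path are only illustrative:

```python
# Sketch only: how the relocated ModelInfo exposes its computed fields.
from iopaint.schema import ModelInfo, ModelType

info = ModelInfo(name="lama", path="lama", model_type=ModelType.INPAINT)
assert info.need_prompt is False          # erase-only models take no prompt
assert info.support_controlnet is False
assert info.controlnets == []
print(info.model_dump()["support_outpainting"])  # computed fields are included when serialised
```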

View File

@ -5,7 +5,7 @@ from PIL import Image
from iopaint.helper import encode_pil_to_base64, gen_frontend_mask from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
from iopaint.plugins.anime_seg import AnimeSeg from iopaint.plugins.anime_seg import AnimeSeg
-from iopaint.schema import RunPluginRequest
+from iopaint.schema import RunPluginRequest, RemoveBGModel
from iopaint.tests.utils import check_device, current_dir, save_dir from iopaint.tests.utils import check_device, current_dir, save_dir
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@ -34,7 +34,7 @@ def _save(img, name):
def test_remove_bg(): def test_remove_bg():
-    model = RemoveBG()
+    model = RemoveBG(RemoveBGModel.briaai_rmbg_1_4)
rgba_np_img = model.gen_image( rgba_np_img = model.gen_image(
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
) )

View File

@ -1,4 +1,14 @@
import json
import os import os
from pathlib import Path
from iopaint.schema import (
Device,
InteractiveSegModel,
RemoveBGModel,
RealESRGANModel,
ApiConfig,
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -15,6 +25,37 @@ from iopaint.const import *
_config_file: Path = None _config_file: Path = None
default_configs = dict(
host="127.0.0.1",
port=8080,
model=DEFAULT_MODEL,
model_dir=DEFAULT_MODEL_DIR,
no_half=False,
low_mem=False,
cpu_offload=False,
disable_nsfw_checker=False,
local_files_only=False,
cpu_textencoder=False,
device=Device.cuda,
input=None,
output_dir=None,
quality=95,
enable_interactive_seg=False,
interactive_seg_model=InteractiveSegModel.vit_b,
interactive_seg_device=Device.cpu,
enable_remove_bg=False,
remove_bg_model=RemoveBGModel.u2net,
enable_anime_seg=False,
enable_realesrgan=False,
realesrgan_device=Device.cpu,
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
enable_gfpgan=False,
gfpgan_device=Device.cpu,
enable_restoreformer=False,
restoreformer_device=Device.cpu,
)
class WebConfig(ApiConfig): class WebConfig(ApiConfig):
model_dir: str = DEFAULT_MODEL_DIR model_dir: str = DEFAULT_MODEL_DIR
@ -50,6 +91,7 @@ def save_config(
interactive_seg_model, interactive_seg_model,
interactive_seg_device, interactive_seg_device,
enable_remove_bg, enable_remove_bg,
remove_bg_model,
enable_anime_seg, enable_anime_seg,
enable_realesrgan, enable_realesrgan,
realesrgan_device, realesrgan_device,
@ -115,7 +157,7 @@ def main(config_file: Path):
with gr.Row(): with gr.Row():
recommend_model = gr.Dropdown( recommend_model = gr.Dropdown(
["lama", "mat", "migan"] + DIFFUSION_MODELS, ["lama", "mat", "migan"] + DIFFUSION_MODELS,
label="Recommend Models", label="Recommended Models",
) )
downloaded_model = gr.Dropdown( downloaded_model = gr.Dropdown(
downloaded_models, label="Downloaded Models" downloaded_models, label="Downloaded Models"
@ -179,6 +221,11 @@ def main(config_file: Path):
enable_remove_bg = gr.Checkbox( enable_remove_bg = gr.Checkbox(
init_config.enable_remove_bg, label=REMOVE_BG_HELP init_config.enable_remove_bg, label=REMOVE_BG_HELP
) )
remove_bg_model = gr.Radio(
RemoveBGModel.values(),
label="Remove bg model",
value=init_config.remove_bg_model,
)
with gr.Row(): with gr.Row():
enable_anime_seg = gr.Checkbox( enable_anime_seg = gr.Checkbox(
init_config.enable_anime_seg, label=ANIMESEG_HELP init_config.enable_anime_seg, label=ANIMESEG_HELP
@ -241,6 +288,7 @@ def main(config_file: Path):
interactive_seg_model, interactive_seg_model,
interactive_seg_device, interactive_seg_device,
enable_remove_bg, enable_remove_bg,
remove_bg_model,
enable_anime_seg, enable_anime_seg,
enable_realesrgan, enable_realesrgan,
realesrgan_device, realesrgan_device,

View File

@ -20,8 +20,8 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs"
import { useEffect, useState } from "react" import { useEffect, useState } from "react"
import { cn } from "@/lib/utils" import { cn } from "@/lib/utils"
import { useQuery } from "@tanstack/react-query" import { useQuery } from "@tanstack/react-query"
-import { fetchModelInfos, switchModel } from "@/lib/api"
-import { ModelInfo } from "@/lib/types"
+import { getServerConfig, switchModel, switchPluginModel } from "@/lib/api"
+import { ModelInfo, PluginName } from "@/lib/types"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import { ScrollArea } from "./ui/scroll-area" import { ScrollArea } from "./ui/scroll-area"
import { useToast } from "./ui/use-toast" import { useToast } from "./ui/use-toast"
@ -39,6 +39,14 @@ import {
MODEL_TYPE_OTHER, MODEL_TYPE_OTHER,
} from "@/lib/const" } from "@/lib/const"
import useHotKey from "@/hooks/useHotkey" import useHotKey from "@/hooks/useHotkey"
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectTrigger,
SelectValue,
} from "./ui/select"
const formSchema = z.object({ const formSchema = z.object({
enableFileManager: z.boolean(), enableFileManager: z.boolean(),
@ -48,42 +56,45 @@ const formSchema = z.object({
enableManualInpainting: z.boolean(), enableManualInpainting: z.boolean(),
enableUploadMask: z.boolean(), enableUploadMask: z.boolean(),
enableAutoExtractPrompt: z.boolean(), enableAutoExtractPrompt: z.boolean(),
removeBGModel: z.string(),
}) })
const TAB_GENERAL = "General" const TAB_GENERAL = "General"
const TAB_MODEL = "Model" const TAB_MODEL = "Model"
const TAB_PLUGINS = "Plugins"
// const TAB_FILE_MANAGER = "File Manager" // const TAB_FILE_MANAGER = "File Manager"
-const TAB_NAMES = [TAB_MODEL, TAB_GENERAL]
+const TAB_NAMES = [TAB_MODEL, TAB_GENERAL, TAB_PLUGINS]
export function SettingsDialog() { export function SettingsDialog() {
const [open, toggleOpen] = useToggle(false) const [open, toggleOpen] = useToggle(false)
const [openModelSwitching, toggleOpenModelSwitching] = useToggle(false)
const [tab, setTab] = useState(TAB_MODEL) const [tab, setTab] = useState(TAB_MODEL)
const [ const [
updateAppState, updateAppState,
settings, settings,
updateSettings, updateSettings,
fileManagerState, fileManagerState,
updateFileManagerState,
setAppModel, setAppModel,
setServerConfig,
] = useStore((state) => [ ] = useStore((state) => [
state.updateAppState, state.updateAppState,
state.settings, state.settings,
state.updateSettings, state.updateSettings,
state.fileManagerState, state.fileManagerState,
state.updateFileManagerState,
state.setModel, state.setModel,
state.setServerConfig,
]) ])
const { toast } = useToast() const { toast } = useToast()
const [model, setModel] = useState<ModelInfo>(settings.model) const [model, setModel] = useState<ModelInfo>(settings.model)
const [modelSwitchingTexts, setModelSwitchingTexts] = useState<string[]>([])
const openModelSwitching = modelSwitchingTexts.length > 0
useEffect(() => { useEffect(() => {
setModel(settings.model) setModel(settings.model)
}, [settings.model]) }, [settings.model])
-  const { data: modelInfos, status } = useQuery({
-    queryKey: ["modelInfos"],
-    queryFn: fetchModelInfos,
+  const { data: serverConfig, status } = useQuery({
+    queryKey: ["serverConfig"],
+    queryFn: getServerConfig,
   })
// 1. Define your form. // 1. Define your form.
@ -96,9 +107,17 @@ export function SettingsDialog() {
enableAutoExtractPrompt: settings.enableAutoExtractPrompt, enableAutoExtractPrompt: settings.enableAutoExtractPrompt,
inputDirectory: fileManagerState.inputDirectory, inputDirectory: fileManagerState.inputDirectory,
outputDirectory: fileManagerState.outputDirectory, outputDirectory: fileManagerState.outputDirectory,
removeBGModel: serverConfig?.removeBGModel,
}, },
}) })
useEffect(() => {
if (serverConfig) {
setServerConfig(serverConfig)
form.setValue("removeBGModel", serverConfig.removeBGModel)
}
}, [form, serverConfig])
async function onSubmit(values: z.infer<typeof formSchema>) { async function onSubmit(values: z.infer<typeof formSchema>) {
// Do something with the form values. ✅ This will be type-safe and validated. // Do something with the form values. ✅ This will be type-safe and validated.
updateSettings({ updateSettings({
@@ -109,29 +128,67 @@ export function SettingsDialog() {
     })
 
     // TODO: validate input/output Directory
-    updateFileManagerState({
-      inputDirectory: values.inputDirectory,
-      outputDirectory: values.outputDirectory,
-    })
+    // updateFileManagerState({
+    //   inputDirectory: values.inputDirectory,
+    //   outputDirectory: values.outputDirectory,
+    // })
 
-    if (model.name !== settings.model.name) {
-      toggleOpenModelSwitching()
-      updateAppState({ disableShortCuts: true })
-      try {
-        const newModel = await switchModel(model.name)
-        toast({
-          title: `Switch to ${newModel.name} success`,
-        })
-        setAppModel(model)
-      } catch (error: any) {
-        toast({
-          variant: "destructive",
-          title: `Switch to ${model.name} failed: ${error}`,
-        })
-        setModel(settings.model)
-      } finally {
-        toggleOpenModelSwitching()
-        updateAppState({ disableShortCuts: false })
-      }
+    const shouldSwitchModel = model.name !== settings.model.name
+    const shouldSwitchRemoveBGModel =
+      serverConfig?.removeBGModel !== values.removeBGModel
+    const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel
+
+    if (showModelSwitching) {
+      const newModelSwitchingTexts: string[] = []
+      if (shouldSwitchModel) {
+        newModelSwitchingTexts.push(
+          `Switching model from ${settings.model.name} to ${model.name}`
+        )
+      }
+      if (shouldSwitchRemoveBGModel) {
+        newModelSwitchingTexts.push(
+          `Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}`
+        )
+      }
+      setModelSwitchingTexts(newModelSwitchingTexts)
+
+      updateAppState({ disableShortCuts: true })
+
+      if (shouldSwitchModel) {
+        try {
+          const newModel = await switchModel(model.name)
+          toast({
+            title: `Switch to ${newModel.name} success`,
+          })
+          setAppModel(model)
+        } catch (error: any) {
+          toast({
+            variant: "destructive",
+            title: `Switch to ${model.name} failed: ${error}`,
+          })
+          setModel(settings.model)
+        }
+      }
+
+      if (shouldSwitchRemoveBGModel) {
+        try {
+          const res = await switchPluginModel(
+            PluginName.RemoveBG,
+            values.removeBGModel
+          )
+          if (res.status !== 200) {
+            throw new Error(res.statusText)
+          }
+        } catch (error: any) {
+          toast({
+            variant: "destructive",
+            title: `Switch removebg model to ${model.name} failed: ${error}`,
+          })
+        }
+      }
+
+      setModelSwitchingTexts([])
+      updateAppState({ disableShortCuts: false })
     }
   }
@@ -143,7 +200,17 @@ export function SettingsDialog() {
         onSubmit(form.getValues())
       }
     },
-    [open, form, model]
+    [open, form, model, serverConfig]
   )
 
+  if (status !== "success") {
+    return <></>
+  }
+
+  const modelInfos = serverConfig.modelInfos
+  const plugins = serverConfig.plugins
+  const removeBGEnabled = plugins.some(
+    (plugin) => plugin.name === PluginName.RemoveBG
+  )
function onOpenChange(value: boolean) { function onOpenChange(value: boolean) {
@ -186,10 +253,6 @@ export function SettingsDialog() {
} }
function renderModelSettings() { function renderModelSettings() {
if (status !== "success") {
return <></>
}
let defaultTab = MODEL_TYPE_INPAINT let defaultTab = MODEL_TYPE_INPAINT
for (let info of modelInfos) { for (let info of modelInfos) {
if (model.name === info.name) { if (model.name === info.name) {
@ -356,6 +419,44 @@ export function SettingsDialog() {
) )
} }
function renderPluginsSettings() {
return (
<div className="space-y-4 w-[510px]">
<FormField
control={form.control}
name="removeBGModel"
render={({ field }) => (
<FormItem className="flex items-center justify-between">
<div className="space-y-0.5">
<FormLabel>Remove Background</FormLabel>
<FormDescription>Remove background model</FormDescription>
</div>
<Select
onValueChange={field.onChange}
defaultValue={field.value}
disabled={!removeBGEnabled}
>
<FormControl>
<SelectTrigger className="w-[200px]">
<SelectValue placeholder="Select removebg model" />
</SelectTrigger>
</FormControl>
<SelectContent align="end">
<SelectGroup>
{serverConfig?.removeBGModels.map((model) => (
<SelectItem key={model} value={model}>
{model}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</FormItem>
)}
/>
</div>
)
}
// function renderFileManagerSettings() { // function renderFileManagerSettings() {
// return ( // return (
// <div className="flex flex-col justify-between rounded-lg gap-4 w-[400px]"> // <div className="flex flex-col justify-between rounded-lg gap-4 w-[400px]">
@ -446,7 +547,9 @@ export function SettingsDialog() {
<span className="sr-only">Loading...</span> <span className="sr-only">Loading...</span>
</div> </div>
-          <div>Switching to {model.name}</div>
+          {modelSwitchingTexts.map((text, index) => (
+            <div key={index}>{text}</div>
+          ))}
</div> </div>
{/* </AlertDialogDescription> */} {/* </AlertDialogDescription> */}
</AlertDialogHeader> </AlertDialogHeader>
@ -473,6 +576,7 @@ export function SettingsDialog() {
<Button <Button
key={item} key={item}
variant="ghost" variant="ghost"
disabled={item === TAB_PLUGINS && !removeBGEnabled}
onClick={() => setTab(item)} onClick={() => setTab(item)}
className={cn( className={cn(
tab === item ? "bg-muted " : "hover:bg-muted", tab === item ? "bg-muted " : "hover:bg-muted",
@ -489,6 +593,7 @@ export function SettingsDialog() {
<form onSubmit={form.handleSubmit(onSubmit)}> <form onSubmit={form.handleSubmit(onSubmit)}>
{tab === TAB_MODEL ? renderModelSettings() : <></>} {tab === TAB_MODEL ? renderModelSettings() : <></>}
{tab === TAB_GENERAL ? renderGeneralSettings() : <></>} {tab === TAB_GENERAL ? renderGeneralSettings() : <></>}
{tab === TAB_PLUGINS ? renderPluginsSettings() : <></>}
{/* {tab === TAB_FILE_MANAGER ? ( {/* {tab === TAB_FILE_MANAGER ? (
renderFileManagerSettings() renderFileManagerSettings()
) : ( ) : (

View File

@ -12,7 +12,7 @@ export default function useInputImage() {
fetch(`${API_ENDPOINT}/inputimage`, { headers }) fetch(`${API_ENDPOINT}/inputimage`, { headers })
.then(async (res) => { .then(async (res) => {
if (!res.ok) { if (!res.ok) {
-        throw new Error("No input image found")
+        return
} }
const filename = res.headers const filename = res.headers
.get("content-disposition") .get("content-disposition")

View File

@ -104,15 +104,18 @@ export async function switchModel(name: string): Promise<ModelInfo> {
return res.data return res.data
} }
export async function switchPluginModel(
plugin_name: string,
model_name: string
) {
return api.post(`/switch_plugin_model`, { plugin_name, model_name })
}
export async function currentModel(): Promise<ModelInfo> { export async function currentModel(): Promise<ModelInfo> {
const res = await api.get("/model") const res = await api.get("/model")
return res.data return res.data
} }
export function fetchModelInfos(): Promise<ModelInfo[]> {
return api.get("/models").then((response) => response.data)
}
export async function runPlugin( export async function runPlugin(
genMask: boolean, genMask: boolean,
name: string, name: string,

View File

@ -14,6 +14,9 @@ export interface PluginInfo {
export interface ServerConfig { export interface ServerConfig {
plugins: PluginInfo[] plugins: PluginInfo[]
modelInfos: ModelInfo[]
removeBGModel: string
removeBGModels: string[]
enableFileManager: boolean enableFileManager: boolean
enableAutoSaving: boolean enableAutoSaving: boolean
enableControlnet: boolean enableControlnet: boolean