add remove bg model selection
This commit is contained in:
parent cf9ceea4e6
commit 8060e16c70
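The feature added here is driven through the new POST /api/v1/switch_plugin_model route and the SwitchPluginModelRequest schema introduced below. A minimal client-side sketch, assuming the server runs on the default 127.0.0.1:8080 and that the remove-background plugin registers under the name "RemoveBG" (neither value is shown in this diff):

import requests

resp = requests.post(
    "http://127.0.0.1:8080/api/v1/switch_plugin_model",  # route added in api.py below
    json={"plugin_name": "RemoveBG", "model_name": "briaai/RMBG-1.4"},  # plugin name is an assumption
)
resp.raise_for_status()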
@@ -42,10 +42,10 @@ from iopaint.helper import (
    adjust_mask,
)
from iopaint.model.utils import torch_gc
from iopaint.model_info import ModelInfo
from iopaint.model_manager import ModelManager
from iopaint.plugins import build_plugins
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.remove_bg import RemoveBG
from iopaint.schema import (
    GenInfoResponse,
    ApiConfig,
@@ -56,6 +56,9 @@ from iopaint.schema import (
    SDSampler,
    PluginInfo,
    AdjustMaskRequest,
    RemoveBGModel,
    SwitchPluginModelRequest,
    ModelInfo,
)

CURRENT_DIR = Path(__file__).parent.absolute().resolve()
@@ -154,11 +157,11 @@ class Api:
        # fmt: off
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
        self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo])
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
@@ -175,9 +178,6 @@ class Api:
    def add_api_route(self, path: str, endpoint, **kwargs):
        return self.app.add_api_route(path, endpoint, **kwargs)

    def api_models(self) -> List[ModelInfo]:
        return self.model_manager.scan_models()

    def api_current_model(self) -> ModelInfo:
        return self.model_manager.current_model

@@ -187,16 +187,28 @@ class Api:
        self.model_manager.switch(req.name)
        return self.model_manager.current_model

    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
        if req.plugin_name in self.plugins:
            self.plugins[req.plugin_name].switch_model(req.model_name)
            if req.plugin_name == RemoveBG.name:
                self.config.remove_bg_model = req.model_name

    def api_server_config(self) -> ServerConfigResponse:
        return ServerConfigResponse(
            plugins=[
        plugins = []
        for it in self.plugins.values():
            plugins.append(
                PluginInfo(
                    name=it.name,
                    support_gen_image=it.support_gen_image,
                    support_gen_mask=it.support_gen_mask,
                )
                for it in self.plugins.values()
            ],
            )

        return ServerConfigResponse(
            plugins=plugins,
            modelInfos=self.model_manager.scan_models(),
            removeBGModel=self.config.remove_bg_model,
            removeBGModels=RemoveBGModel.values(),
            enableFileManager=self.file_manager is not None,
            enableAutoSaving=self.config.output_dir is not None,
            enableControlnet=self.model_manager.enable_controlnet,
@@ -340,6 +352,7 @@ class Api:
            self.config.interactive_seg_model,
            self.config.interactive_seg_device,
            self.config.enable_remove_bg,
            self.config.remove_bg_model,
            self.config.enable_anime_seg,
            self.config.enable_realesrgan,
            self.config.realesrgan_device,
@@ -9,7 +9,7 @@ from typer_config import use_json_config

from iopaint.const import *
from iopaint.runtime import setup_model_dir, dump_environment_info, check_device
from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel
from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel, RemoveBGModel

typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)

@@ -127,6 +127,7 @@ def start(
    ),
    interactive_seg_device: Device = Option(Device.cpu),
    enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
    remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
    enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
    enable_realesrgan: bool = Option(False),
    realesrgan_device: Device = Option(Device.cpu),
@@ -183,6 +184,7 @@ def start(
        interactive_seg_model=interactive_seg_model,
        interactive_seg_device=interactive_seg_device,
        enable_remove_bg=enable_remove_bg,
        remove_bg_model=remove_bg_model,
        enable_anime_seg=enable_anime_seg,
        enable_realesrgan=enable_realesrgan,
        realesrgan_device=realesrgan_device,
@@ -1,8 +1,5 @@
import json
import os
from pathlib import Path

from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
from typing import List

INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
@@ -57,7 +54,7 @@ CPU_TEXTENCODER_HELP = """
Run diffusion models text encoder on CPU to reduce vRAM usage.
"""

SD_CONTROLNET_CHOICES = [
SD_CONTROLNET_CHOICES: List[str] = [
    "lllyasviel/control_v11p_sd15_canny",
    # "lllyasviel/control_v11p_sd15_seg",
    "lllyasviel/control_v11p_sd15_openpose",
@@ -113,38 +110,9 @@ Quality of image encoding, 0-100. Default is 95, higher quality will generate la

INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
REMOVE_BG_HELP = "Enable remove background. Always run on CPU"
ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU"
REMOVE_BG_HELP = "Enable remove background plugin. Always run on CPU"
ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU"
REALESRGAN_HELP = "Enable realesrgan super resolution"
GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"

default_configs = dict(
    host="127.0.0.1",
    port=8080,
    model=DEFAULT_MODEL,
    model_dir=DEFAULT_MODEL_DIR,
    no_half=False,
    low_mem=False,
    cpu_offload=False,
    disable_nsfw_checker=False,
    local_files_only=False,
    cpu_textencoder=False,
    device=Device.cuda,
    input=None,
    output_dir=None,
    quality=95,
    enable_interactive_seg=False,
    interactive_seg_model=InteractiveSegModel.vit_b,
    interactive_seg_device=Device.cpu,
    enable_remove_bg=False,
    enable_anime_seg=False,
    enable_realesrgan=False,
    realesrgan_device=Device.cpu,
    realesrgan_model=RealESRGANModel.realesr_general_x4v3,
    enable_gfpgan=False,
    gfpgan_device=Device.cpu,
    enable_restoreformer=False,
    restoreformer_device=Device.cpu,
)
@@ -3,6 +3,7 @@ import os
from functools import lru_cache
from typing import List

from iopaint.schema import ModelType, ModelInfo
from loguru import logger
from pathlib import Path

@@ -15,7 +16,6 @@ from iopaint.const import (
    ANYTEXT_NAME,
)
from iopaint.model.original_sd_configs import get_config_files
from iopaint.model_info import ModelInfo, ModelType


def cli_download_model(model: str):
@@ -1,103 +0,0 @@
from typing import List

from pydantic import computed_field, BaseModel

from iopaint.const import (
    SDXL_CONTROLNET_CHOICES,
    SD2_CONTROLNET_CHOICES,
    SD_CONTROLNET_CHOICES,
    INSTRUCT_PIX2PIX_NAME,
    KANDINSKY22_NAME,
    POWERPAINT_NAME,
    ANYTEXT_NAME,
)
from iopaint.schema import ModelType


class ModelInfo(BaseModel):
    name: str
    path: str
    model_type: ModelType
    is_single_file_diffusers: bool = False

    @computed_field
    @property
    def need_prompt(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [
            INSTRUCT_PIX2PIX_NAME,
            KANDINSKY22_NAME,
            POWERPAINT_NAME,
            ANYTEXT_NAME,
        ]

    @computed_field
    @property
    def controlnets(self) -> List[str]:
        if self.model_type in [
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]:
            return SDXL_CONTROLNET_CHOICES
        if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
            if "sd2" in self.name.lower():
                return SD2_CONTROLNET_CHOICES
            else:
                return SD_CONTROLNET_CHOICES
        if self.name == POWERPAINT_NAME:
            return SD_CONTROLNET_CHOICES
        return []

    @computed_field
    @property
    def support_strength(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]

    @computed_field
    @property
    def support_outpainting(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME]

    @computed_field
    @property
    def support_lcm_lora(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]

    @computed_field
    @property
    def support_controlnet(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]

    @computed_field
    @property
    def support_freeu(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [INSTRUCT_PIX2PIX_NAME]
@@ -8,8 +8,7 @@ from iopaint.download import scan_models
from iopaint.helper import switch_mps_device
from iopaint.model import models, ControlNet, SD, SDXL
from iopaint.model.utils import torch_gc, is_local_files_only
from iopaint.model_info import ModelInfo, ModelType
from iopaint.schema import InpaintRequest
from iopaint.schema import InpaintRequest, ModelInfo, ModelType


class ModelManager:
@@ -16,6 +16,7 @@ def build_plugins(
    interactive_seg_model: InteractiveSegModel,
    interactive_seg_device: Device,
    enable_remove_bg: bool,
    remove_bg_model: str,
    enable_anime_seg: bool,
    enable_realesrgan: bool,
    realesrgan_device: Device,
@@ -35,7 +36,7 @@ def build_plugins(

    if enable_remove_bg:
        logger.info(f"Initialize {RemoveBG.name} plugin")
        plugins[RemoveBG.name] = RemoveBG()
        plugins[RemoveBG.name] = RemoveBG(remove_bg_model)

    if enable_anime_seg:
        logger.info(f"Initialize {AnimeSeg.name} plugin")
@@ -25,3 +25,6 @@ class BasePlugin:

    def check_dep(self):
        ...

    def switch_model(self, new_model_name: str):
        ...
515  iopaint/plugins/briarmbg.py  Normal file
@@ -0,0 +1,515 @@
# copy from: https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4/blob/main/briarmbg.py
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
from torchvision.transforms.functional import normalize


class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout


## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
    src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")

    return src


### RSU-7 ###
class RSU7(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super(RSU7, self).__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        b, c, h, w = x.shape

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-6 ###
class RSU6(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-5 ###
class RSU5(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4 ###
class RSU4(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4F ###
class RSU4F(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


class myrebnconv(nn.Module):
    def __init__(
        self,
        in_ch=3,
        out_ch=1,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    ):
        super(myrebnconv, self).__init__()

        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.rl(self.bn(self.conv(x)))


class BriaRMBG(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(BriaRMBG, self).__init__()

        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self, x):
        hx = x

        hxin = self.conv_in(hx)
        # hx = self.pool_in(hxin)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        return [
            F.sigmoid(d1),
            F.sigmoid(d2),
            F.sigmoid(d3),
            F.sigmoid(d4),
            F.sigmoid(d5),
            F.sigmoid(d6),
        ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]


def resize_image(image):
    image = image.convert("RGB")
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image


def create_briarmbg_session():
    from huggingface_hub import hf_hub_download

    net = BriaRMBG()
    model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
    net.eval()
    return net


def briarmbg_process(bgr_np_image, session, only_mask=False):
    # prepare input
    orig_bgr_image = Image.fromarray(bgr_np_image)
    w, h = orig_im_size = orig_bgr_image.size
    image = resize_image(orig_bgr_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # inference
    result = session(im_tensor)
    # post process
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    # image to pil
    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)

    mask = np.squeeze(im_array)
    if only_mask:
        return mask

    pil_im = Image.fromarray(mask)
    # paste the mask on the original image
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_bgr_image, mask=pil_im)
    rgba_np_img = np.asarray(new_im)
    return rgba_np_img
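The helpers at the end of this new file are what the RemoveBG plugin calls into. A hedged usage sketch, assuming torch, cv2 and huggingface_hub are installed; the image paths are placeholders:

import cv2

from iopaint.plugins.briarmbg import briarmbg_process, create_briarmbg_session

session = create_briarmbg_session()  # downloads briaai/RMBG-1.4 weights via hf_hub_download
bgr = cv2.imread("input.jpg")  # cv2 loads BGR, matching what RemoveBG.gen_image passes in
cutout = briarmbg_process(bgr, session)  # 4-channel cutout; the plugin treats it as BGRA
mask = briarmbg_process(bgr, session, only_mask=True)  # uint8 mask, 255 = foreground
cv2.imwrite("cutout.png", cutout)
cv2.imwrite("mask.png", mask)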
@@ -1,8 +1,6 @@
import hashlib
import json
from typing import List

import cv2
import numpy as np
import torch
from loguru import logger
@@ -1,10 +1,11 @@
import os
import cv2
import numpy as np
from loguru import logger
from torch.hub import get_dir

from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import RunPluginRequest
from iopaint.schema import RunPluginRequest, RemoveBGModel


class RemoveBG(BasePlugin):
@@ -12,32 +13,53 @@ class RemoveBG(BasePlugin):
    support_gen_mask = True
    support_gen_image = True

    def __init__(self):
    def __init__(self, model_name):
        super().__init__()
        from rembg import new_session
        self.model_name = model_name

        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, "checkpoints")
        os.environ["U2NET_HOME"] = model_dir

        self.session = new_session(model_name="u2net")
        self._init_session(model_name)

    def _init_session(self, model_name: str):
        if model_name == RemoveBGModel.briaai_rmbg_1_4:
            from iopaint.plugins.briarmbg import (
                create_briarmbg_session,
                briarmbg_process,
            )

            self.session = create_briarmbg_session()
            self.remove = briarmbg_process
        else:
            from rembg import new_session, remove

            self.session = new_session(model_name=model_name)
            self.remove = remove

    def switch_model(self, new_model_name):
        if self.model_name == new_model_name:
            return

        logger.info(
            f"Switching removebg model from {self.model_name} to {new_model_name}"
        )
        self._init_session(new_model_name)
        self.model_name = new_model_name

    def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
        from rembg import remove

        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)

        # return BGRA image
        output = remove(bgr_np_img, session=self.session)
        output = self.remove(bgr_np_img, session=self.session)
        return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)

    def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
        from rembg import remove

        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)

        # return BGR image, 255 means foreground, 0 means background
        output = remove(bgr_np_img, session=self.session, only_mask=True)
        output = self.remove(bgr_np_img, session=self.session, only_mask=True)
        return output

    def check_dep(self):
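With this change the plugin takes the model name at construction time and can be re-pointed at runtime through switch_model. A minimal sketch of the new plugin API, assuming rembg is installed and the model downloads succeed in your environment:

from iopaint.plugins.remove_bg import RemoveBG
from iopaint.schema import RemoveBGModel

plugin = RemoveBG(RemoveBGModel.u2net)  # build_plugins() now forwards remove_bg_model here
plugin.switch_model(RemoveBGModel.briaai_rmbg_1_4)  # re-initializes the session only if the name changed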
@@ -5,11 +5,11 @@ import sys
from pathlib import Path

import packaging.version
from iopaint.schema import Device
from loguru import logger
from rich import print
from typing import Dict, Any

from iopaint.const import Device

_PY_VERSION: str = sys.version.split()[0].rstrip("+")
@@ -1,11 +1,117 @@
import json
import random
from enum import Enum
from pathlib import Path
from typing import Optional, Literal, List

from iopaint.const import (
    INSTRUCT_PIX2PIX_NAME,
    KANDINSKY22_NAME,
    POWERPAINT_NAME,
    ANYTEXT_NAME,
    SDXL_CONTROLNET_CHOICES,
    SD2_CONTROLNET_CHOICES,
    SD_CONTROLNET_CHOICES,
)
from loguru import logger
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, computed_field


class ModelType(str, Enum):
    INPAINT = "inpaint"  # LaMa, MAT...
    DIFFUSERS_SD = "diffusers_sd"
    DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
    DIFFUSERS_SDXL = "diffusers_sdxl"
    DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
    DIFFUSERS_OTHER = "diffusers_other"


class ModelInfo(BaseModel):
    name: str
    path: str
    model_type: ModelType
    is_single_file_diffusers: bool = False

    @computed_field
    @property
    def need_prompt(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [
            INSTRUCT_PIX2PIX_NAME,
            KANDINSKY22_NAME,
            POWERPAINT_NAME,
            ANYTEXT_NAME,
        ]

    @computed_field
    @property
    def controlnets(self) -> List[str]:
        if self.model_type in [
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]:
            return SDXL_CONTROLNET_CHOICES
        if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
            if "sd2" in self.name.lower():
                return SD2_CONTROLNET_CHOICES
            else:
                return SD_CONTROLNET_CHOICES
        if self.name == POWERPAINT_NAME:
            return SD_CONTROLNET_CHOICES
        return []

    @computed_field
    @property
    def support_strength(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]

    @computed_field
    @property
    def support_outpainting(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME]

    @computed_field
    @property
    def support_lcm_lora(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]

    @computed_field
    @property
    def support_controlnet(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]

    @computed_field
    @property
    def support_freeu(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [INSTRUCT_PIX2PIX_NAME]


class Choices(str, Enum):
@@ -20,6 +126,16 @@ class RealESRGANModel(Choices):
    RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"


class RemoveBGModel(Choices):
    u2net = "u2net"
    u2netp = "u2netp"
    u2net_human_seg = "u2net_human_seg"
    u2net_cloth_seg = "u2net_cloth_seg"
    silueta = "silueta"
    isnet_general_use = "isnet-general-use"
    briaai_rmbg_1_4 = "briaai/RMBG-1.4"


class Device(Choices):
    cpu = "cpu"
    cuda = "cuda"
@@ -44,15 +160,6 @@ class CV2Flag(str, Enum):
    INPAINT_TELEA = "INPAINT_TELEA"


class ModelType(str, Enum):
    INPAINT = "inpaint"  # LaMa, MAT...
    DIFFUSERS_SD = "diffusers_sd"
    DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
    DIFFUSERS_SDXL = "diffusers_sdxl"
    DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
    DIFFUSERS_OTHER = "diffusers_other"


class HDStrategy(str, Enum):
    # Use original image size
    ORIGINAL = "Original"
@@ -124,6 +231,7 @@ class ApiConfig(BaseModel):
    interactive_seg_model: InteractiveSegModel
    interactive_seg_device: Device
    enable_remove_bg: bool
    remove_bg_model: str
    enable_anime_seg: bool
    enable_realesrgan: bool
    realesrgan_device: Device
@@ -313,6 +421,9 @@ class GenInfoResponse(BaseModel):

class ServerConfigResponse(BaseModel):
    plugins: List[PluginInfo]
    modelInfos: List[ModelInfo]
    removeBGModel: RemoveBGModel
    removeBGModels: List[str]
    enableFileManager: bool
    enableAutoSaving: bool
    enableControlnet: bool
@@ -326,6 +437,11 @@ class SwitchModelRequest(BaseModel):
    name: str


class SwitchPluginModelRequest(BaseModel):
    plugin_name: str
    model_name: str


AdjustMaskOperate = Literal["expand", "shrink", "reverse"]
@@ -5,7 +5,7 @@ from PIL import Image

from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
from iopaint.plugins.anime_seg import AnimeSeg
from iopaint.schema import RunPluginRequest
from iopaint.schema import RunPluginRequest, RemoveBGModel
from iopaint.tests.utils import check_device, current_dir, save_dir

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -34,7 +34,7 @@ def _save(img, name):


def test_remove_bg():
    model = RemoveBG()
    model = RemoveBG(RemoveBGModel.briaai_rmbg_1_4)
    rgba_np_img = model.gen_image(
        rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
    )
@@ -1,4 +1,14 @@
import json
import os
from pathlib import Path

from iopaint.schema import (
    Device,
    InteractiveSegModel,
    RemoveBGModel,
    RealESRGANModel,
    ApiConfig,
)

os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

@@ -15,6 +25,37 @@ from iopaint.const import *
_config_file: Path = None


default_configs = dict(
    host="127.0.0.1",
    port=8080,
    model=DEFAULT_MODEL,
    model_dir=DEFAULT_MODEL_DIR,
    no_half=False,
    low_mem=False,
    cpu_offload=False,
    disable_nsfw_checker=False,
    local_files_only=False,
    cpu_textencoder=False,
    device=Device.cuda,
    input=None,
    output_dir=None,
    quality=95,
    enable_interactive_seg=False,
    interactive_seg_model=InteractiveSegModel.vit_b,
    interactive_seg_device=Device.cpu,
    enable_remove_bg=False,
    remove_bg_model=RemoveBGModel.u2net,
    enable_anime_seg=False,
    enable_realesrgan=False,
    realesrgan_device=Device.cpu,
    realesrgan_model=RealESRGANModel.realesr_general_x4v3,
    enable_gfpgan=False,
    gfpgan_device=Device.cpu,
    enable_restoreformer=False,
    restoreformer_device=Device.cpu,
)


class WebConfig(ApiConfig):
    model_dir: str = DEFAULT_MODEL_DIR

@@ -50,6 +91,7 @@ def save_config(
    interactive_seg_model,
    interactive_seg_device,
    enable_remove_bg,
    remove_bg_model,
    enable_anime_seg,
    enable_realesrgan,
    realesrgan_device,
@@ -115,7 +157,7 @@ def main(config_file: Path):
    with gr.Row():
        recommend_model = gr.Dropdown(
            ["lama", "mat", "migan"] + DIFFUSION_MODELS,
            label="Recommend Models",
            label="Recommended Models",
        )
        downloaded_model = gr.Dropdown(
            downloaded_models, label="Downloaded Models"
@@ -179,6 +221,11 @@ def main(config_file: Path):
            enable_remove_bg = gr.Checkbox(
                init_config.enable_remove_bg, label=REMOVE_BG_HELP
            )
            remove_bg_model = gr.Radio(
                RemoveBGModel.values(),
                label="Remove bg model",
                value=init_config.remove_bg_model,
            )
        with gr.Row():
            enable_anime_seg = gr.Checkbox(
                init_config.enable_anime_seg, label=ANIMESEG_HELP
@@ -241,6 +288,7 @@ def main(config_file: Path):
            interactive_seg_model,
            interactive_seg_device,
            enable_remove_bg,
            remove_bg_model,
            enable_anime_seg,
            enable_realesrgan,
            realesrgan_device,
@@ -20,8 +20,8 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs"
import { useEffect, useState } from "react"
import { cn } from "@/lib/utils"
import { useQuery } from "@tanstack/react-query"
import { fetchModelInfos, switchModel } from "@/lib/api"
import { ModelInfo } from "@/lib/types"
import { getServerConfig, switchModel, switchPluginModel } from "@/lib/api"
import { ModelInfo, PluginName } from "@/lib/types"
import { useStore } from "@/lib/states"
import { ScrollArea } from "./ui/scroll-area"
import { useToast } from "./ui/use-toast"
@@ -39,6 +39,14 @@ import {
  MODEL_TYPE_OTHER,
} from "@/lib/const"
import useHotKey from "@/hooks/useHotkey"
import {
  Select,
  SelectContent,
  SelectGroup,
  SelectItem,
  SelectTrigger,
  SelectValue,
} from "./ui/select"

const formSchema = z.object({
  enableFileManager: z.boolean(),
@@ -48,42 +56,45 @@ const formSchema = z.object({
  enableManualInpainting: z.boolean(),
  enableUploadMask: z.boolean(),
  enableAutoExtractPrompt: z.boolean(),
  removeBGModel: z.string(),
})

const TAB_GENERAL = "General"
const TAB_MODEL = "Model"
const TAB_PLUGINS = "Plugins"
// const TAB_FILE_MANAGER = "File Manager"

const TAB_NAMES = [TAB_MODEL, TAB_GENERAL]
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL, TAB_PLUGINS]

export function SettingsDialog() {
  const [open, toggleOpen] = useToggle(false)
  const [openModelSwitching, toggleOpenModelSwitching] = useToggle(false)
  const [tab, setTab] = useState(TAB_MODEL)
  const [
    updateAppState,
    settings,
    updateSettings,
    fileManagerState,
    updateFileManagerState,
    setAppModel,
    setServerConfig,
  ] = useStore((state) => [
    state.updateAppState,
    state.settings,
    state.updateSettings,
    state.fileManagerState,
    state.updateFileManagerState,
    state.setModel,
    state.setServerConfig,
  ])
  const { toast } = useToast()
  const [model, setModel] = useState<ModelInfo>(settings.model)
  const [modelSwitchingTexts, setModelSwitchingTexts] = useState<string[]>([])
  const openModelSwitching = modelSwitchingTexts.length > 0
  useEffect(() => {
    setModel(settings.model)
  }, [settings.model])

  const { data: modelInfos, status } = useQuery({
    queryKey: ["modelInfos"],
    queryFn: fetchModelInfos,
  const { data: serverConfig, status } = useQuery({
    queryKey: ["serverConfig"],
    queryFn: getServerConfig,
  })

  // 1. Define your form.
@@ -96,9 +107,17 @@ export function SettingsDialog() {
      enableAutoExtractPrompt: settings.enableAutoExtractPrompt,
      inputDirectory: fileManagerState.inputDirectory,
      outputDirectory: fileManagerState.outputDirectory,
      removeBGModel: serverConfig?.removeBGModel,
    },
  })

  useEffect(() => {
    if (serverConfig) {
      setServerConfig(serverConfig)
      form.setValue("removeBGModel", serverConfig.removeBGModel)
    }
  }, [form, serverConfig])

  async function onSubmit(values: z.infer<typeof formSchema>) {
    // Do something with the form values. ✅ This will be type-safe and validated.
    updateSettings({
@@ -109,13 +128,33 @@ export function SettingsDialog() {
    })

    // TODO: validate input/output Directory
    updateFileManagerState({
      inputDirectory: values.inputDirectory,
      outputDirectory: values.outputDirectory,
    })
    if (model.name !== settings.model.name) {
      toggleOpenModelSwitching()
    // updateFileManagerState({
    //   inputDirectory: values.inputDirectory,
    //   outputDirectory: values.outputDirectory,
    // })

    const shouldSwitchModel = model.name !== settings.model.name
    const shouldSwitchRemoveBGModel =
      serverConfig?.removeBGModel !== values.removeBGModel
    const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel

    if (showModelSwitching) {
      const newModelSwitchingTexts: string[] = []
      if (shouldSwitchModel) {
        newModelSwitchingTexts.push(
          `Switching model from ${settings.model.name} to ${model.name}`
        )
      }
      if (shouldSwitchRemoveBGModel) {
        newModelSwitchingTexts.push(
          `Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}`
        )
      }
      setModelSwitchingTexts(newModelSwitchingTexts)

      updateAppState({ disableShortCuts: true })

      if (shouldSwitchModel) {
        try {
          const newModel = await switchModel(model.name)
          toast({
@@ -128,11 +167,29 @@ export function SettingsDialog() {
            title: `Switch to ${model.name} failed: ${error}`,
          })
          setModel(settings.model)
        } finally {
          toggleOpenModelSwitching()
          updateAppState({ disableShortCuts: false })
        }
      }

      if (shouldSwitchRemoveBGModel) {
        try {
          const res = await switchPluginModel(
            PluginName.RemoveBG,
            values.removeBGModel
          )
          if (res.status !== 200) {
            throw new Error(res.statusText)
          }
        } catch (error: any) {
          toast({
            variant: "destructive",
            title: `Switch removebg model to ${model.name} failed: ${error}`,
          })
        }
      }

      setModelSwitchingTexts([])
      updateAppState({ disableShortCuts: false })
    }
  }

  useHotKey(
@@ -143,7 +200,17 @@ export function SettingsDialog() {
        onSubmit(form.getValues())
      }
    },
    [open, form, model]
    [open, form, model, serverConfig]
  )

  if (status !== "success") {
    return <></>
  }

  const modelInfos = serverConfig.modelInfos
  const plugins = serverConfig.plugins
  const removeBGEnabled = plugins.some(
    (plugin) => plugin.name === PluginName.RemoveBG
  )

  function onOpenChange(value: boolean) {
@@ -186,10 +253,6 @@ export function SettingsDialog() {
  }

  function renderModelSettings() {
    if (status !== "success") {
      return <></>
    }

    let defaultTab = MODEL_TYPE_INPAINT
    for (let info of modelInfos) {
      if (model.name === info.name) {
@@ -356,6 +419,44 @@ export function SettingsDialog() {
    )
  }

  function renderPluginsSettings() {
    return (
      <div className="space-y-4 w-[510px]">
        <FormField
          control={form.control}
          name="removeBGModel"
          render={({ field }) => (
            <FormItem className="flex items-center justify-between">
              <div className="space-y-0.5">
                <FormLabel>Remove Background</FormLabel>
                <FormDescription>Remove background model</FormDescription>
              </div>
              <Select
                onValueChange={field.onChange}
                defaultValue={field.value}
                disabled={!removeBGEnabled}
              >
                <FormControl>
                  <SelectTrigger className="w-[200px]">
                    <SelectValue placeholder="Select removebg model" />
                  </SelectTrigger>
                </FormControl>
                <SelectContent align="end">
                  <SelectGroup>
                    {serverConfig?.removeBGModels.map((model) => (
                      <SelectItem key={model} value={model}>
                        {model}
                      </SelectItem>
                    ))}
                  </SelectGroup>
                </SelectContent>
              </Select>
            </FormItem>
          )}
        />
      </div>
    )
  }

  // function renderFileManagerSettings() {
  //   return (
  //     <div className="flex flex-col justify-between rounded-lg gap-4 w-[400px]">
@@ -446,7 +547,9 @@ export function SettingsDialog() {
            <span className="sr-only">Loading...</span>
          </div>

          <div>Switching to {model.name}</div>
          {modelSwitchingTexts.map((text, index) => (
            <div key={index}>{text}</div>
          ))}
        </div>
        {/* </AlertDialogDescription> */}
      </AlertDialogHeader>
@@ -473,6 +576,7 @@ export function SettingsDialog() {
              <Button
                key={item}
                variant="ghost"
                disabled={item === TAB_PLUGINS && !removeBGEnabled}
                onClick={() => setTab(item)}
                className={cn(
                  tab === item ? "bg-muted " : "hover:bg-muted",
@@ -489,6 +593,7 @@ export function SettingsDialog() {
        <form onSubmit={form.handleSubmit(onSubmit)}>
          {tab === TAB_MODEL ? renderModelSettings() : <></>}
          {tab === TAB_GENERAL ? renderGeneralSettings() : <></>}
          {tab === TAB_PLUGINS ? renderPluginsSettings() : <></>}
          {/* {tab === TAB_FILE_MANAGER ? (
            renderFileManagerSettings()
          ) : (
@@ -12,7 +12,7 @@ export default function useInputImage() {
    fetch(`${API_ENDPOINT}/inputimage`, { headers })
      .then(async (res) => {
        if (!res.ok) {
          throw new Error("No input image found")
          return
        }
        const filename = res.headers
          .get("content-disposition")
@@ -104,15 +104,18 @@ export async function switchModel(name: string): Promise<ModelInfo> {
  return res.data
}

export async function switchPluginModel(
  plugin_name: string,
  model_name: string
) {
  return api.post(`/switch_plugin_model`, { plugin_name, model_name })
}

export async function currentModel(): Promise<ModelInfo> {
  const res = await api.get("/model")
  return res.data
}

export function fetchModelInfos(): Promise<ModelInfo[]> {
  return api.get("/models").then((response) => response.data)
}

export async function runPlugin(
  genMask: boolean,
  name: string,
@@ -14,6 +14,9 @@ export interface PluginInfo {

export interface ServerConfig {
  plugins: PluginInfo[]
  modelInfos: ModelInfo[]
  removeBGModel: string
  removeBGModels: string[]
  enableFileManager: boolean
  enableAutoSaving: boolean
  enableControlnet: boolean