add remove bg model selection
This commit is contained in:
parent
cf9ceea4e6
commit
8060e16c70
@ -42,10 +42,10 @@ from iopaint.helper import (
|
|||||||
adjust_mask,
|
adjust_mask,
|
||||||
)
|
)
|
||||||
from iopaint.model.utils import torch_gc
|
from iopaint.model.utils import torch_gc
|
||||||
from iopaint.model_info import ModelInfo
|
|
||||||
from iopaint.model_manager import ModelManager
|
from iopaint.model_manager import ModelManager
|
||||||
from iopaint.plugins import build_plugins
|
from iopaint.plugins import build_plugins
|
||||||
from iopaint.plugins.base_plugin import BasePlugin
|
from iopaint.plugins.base_plugin import BasePlugin
|
||||||
|
from iopaint.plugins.remove_bg import RemoveBG
|
||||||
from iopaint.schema import (
|
from iopaint.schema import (
|
||||||
GenInfoResponse,
|
GenInfoResponse,
|
||||||
ApiConfig,
|
ApiConfig,
|
||||||
@ -56,6 +56,9 @@ from iopaint.schema import (
|
|||||||
SDSampler,
|
SDSampler,
|
||||||
PluginInfo,
|
PluginInfo,
|
||||||
AdjustMaskRequest,
|
AdjustMaskRequest,
|
||||||
|
RemoveBGModel,
|
||||||
|
SwitchPluginModelRequest,
|
||||||
|
ModelInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
||||||
@ -154,11 +157,11 @@ class Api:
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
|
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
|
||||||
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
|
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
|
||||||
self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo])
|
|
||||||
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
|
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
|
||||||
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
||||||
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
||||||
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
||||||
|
self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
|
||||||
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
||||||
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
||||||
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
||||||
@ -175,9 +178,6 @@ class Api:
|
|||||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||||
|
|
||||||
def api_models(self) -> List[ModelInfo]:
|
|
||||||
return self.model_manager.scan_models()
|
|
||||||
|
|
||||||
def api_current_model(self) -> ModelInfo:
|
def api_current_model(self) -> ModelInfo:
|
||||||
return self.model_manager.current_model
|
return self.model_manager.current_model
|
||||||
|
|
||||||
@ -187,16 +187,28 @@ class Api:
|
|||||||
self.model_manager.switch(req.name)
|
self.model_manager.switch(req.name)
|
||||||
return self.model_manager.current_model
|
return self.model_manager.current_model
|
||||||
|
|
||||||
|
def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
|
||||||
|
if req.plugin_name in self.plugins:
|
||||||
|
self.plugins[req.plugin_name].switch_model(req.model_name)
|
||||||
|
if req.plugin_name == RemoveBG.name:
|
||||||
|
self.config.remove_bg_model = req.model_name
|
||||||
|
|
||||||
def api_server_config(self) -> ServerConfigResponse:
|
def api_server_config(self) -> ServerConfigResponse:
|
||||||
return ServerConfigResponse(
|
plugins = []
|
||||||
plugins=[
|
for it in self.plugins.values():
|
||||||
|
plugins.append(
|
||||||
PluginInfo(
|
PluginInfo(
|
||||||
name=it.name,
|
name=it.name,
|
||||||
support_gen_image=it.support_gen_image,
|
support_gen_image=it.support_gen_image,
|
||||||
support_gen_mask=it.support_gen_mask,
|
support_gen_mask=it.support_gen_mask,
|
||||||
)
|
)
|
||||||
for it in self.plugins.values()
|
)
|
||||||
],
|
|
||||||
|
return ServerConfigResponse(
|
||||||
|
plugins=plugins,
|
||||||
|
modelInfos=self.model_manager.scan_models(),
|
||||||
|
removeBGModel=self.config.remove_bg_model,
|
||||||
|
removeBGModels=RemoveBGModel.values(),
|
||||||
enableFileManager=self.file_manager is not None,
|
enableFileManager=self.file_manager is not None,
|
||||||
enableAutoSaving=self.config.output_dir is not None,
|
enableAutoSaving=self.config.output_dir is not None,
|
||||||
enableControlnet=self.model_manager.enable_controlnet,
|
enableControlnet=self.model_manager.enable_controlnet,
|
||||||
@ -340,6 +352,7 @@ class Api:
|
|||||||
self.config.interactive_seg_model,
|
self.config.interactive_seg_model,
|
||||||
self.config.interactive_seg_device,
|
self.config.interactive_seg_device,
|
||||||
self.config.enable_remove_bg,
|
self.config.enable_remove_bg,
|
||||||
|
self.config.remove_bg_model,
|
||||||
self.config.enable_anime_seg,
|
self.config.enable_anime_seg,
|
||||||
self.config.enable_realesrgan,
|
self.config.enable_realesrgan,
|
||||||
self.config.realesrgan_device,
|
self.config.realesrgan_device,
|
||||||
|
@ -9,7 +9,7 @@ from typer_config import use_json_config
|
|||||||
|
|
||||||
from iopaint.const import *
|
from iopaint.const import *
|
||||||
from iopaint.runtime import setup_model_dir, dump_environment_info, check_device
|
from iopaint.runtime import setup_model_dir, dump_environment_info, check_device
|
||||||
from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel
|
from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel, RemoveBGModel
|
||||||
|
|
||||||
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
|
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
|
||||||
|
|
||||||
@ -127,6 +127,7 @@ def start(
|
|||||||
),
|
),
|
||||||
interactive_seg_device: Device = Option(Device.cpu),
|
interactive_seg_device: Device = Option(Device.cpu),
|
||||||
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
|
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
|
||||||
|
remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
|
||||||
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
|
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
|
||||||
enable_realesrgan: bool = Option(False),
|
enable_realesrgan: bool = Option(False),
|
||||||
realesrgan_device: Device = Option(Device.cpu),
|
realesrgan_device: Device = Option(Device.cpu),
|
||||||
@ -183,6 +184,7 @@ def start(
|
|||||||
interactive_seg_model=interactive_seg_model,
|
interactive_seg_model=interactive_seg_model,
|
||||||
interactive_seg_device=interactive_seg_device,
|
interactive_seg_device=interactive_seg_device,
|
||||||
enable_remove_bg=enable_remove_bg,
|
enable_remove_bg=enable_remove_bg,
|
||||||
|
remove_bg_model=remove_bg_model,
|
||||||
enable_anime_seg=enable_anime_seg,
|
enable_anime_seg=enable_anime_seg,
|
||||||
enable_realesrgan=enable_realesrgan,
|
enable_realesrgan=enable_realesrgan,
|
||||||
realesrgan_device=realesrgan_device,
|
realesrgan_device=realesrgan_device,
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from typing import List
|
||||||
|
|
||||||
from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
|
|
||||||
|
|
||||||
INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
|
INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
|
||||||
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||||
@ -57,7 +54,7 @@ CPU_TEXTENCODER_HELP = """
|
|||||||
Run diffusion models text encoder on CPU to reduce vRAM usage.
|
Run diffusion models text encoder on CPU to reduce vRAM usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SD_CONTROLNET_CHOICES = [
|
SD_CONTROLNET_CHOICES: List[str] = [
|
||||||
"lllyasviel/control_v11p_sd15_canny",
|
"lllyasviel/control_v11p_sd15_canny",
|
||||||
# "lllyasviel/control_v11p_sd15_seg",
|
# "lllyasviel/control_v11p_sd15_seg",
|
||||||
"lllyasviel/control_v11p_sd15_openpose",
|
"lllyasviel/control_v11p_sd15_openpose",
|
||||||
@ -113,38 +110,9 @@ Quality of image encoding, 0-100. Default is 95, higher quality will generate la
|
|||||||
|
|
||||||
INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
|
INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
|
||||||
INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
|
INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
|
||||||
REMOVE_BG_HELP = "Enable remove background. Always run on CPU"
|
REMOVE_BG_HELP = "Enable remove background plugin. Always run on CPU"
|
||||||
ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU"
|
ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU"
|
||||||
REALESRGAN_HELP = "Enable realesrgan super resolution"
|
REALESRGAN_HELP = "Enable realesrgan super resolution"
|
||||||
GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
|
GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
|
||||||
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
|
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
|
||||||
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
|
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
|
||||||
|
|
||||||
default_configs = dict(
|
|
||||||
host="127.0.0.1",
|
|
||||||
port=8080,
|
|
||||||
model=DEFAULT_MODEL,
|
|
||||||
model_dir=DEFAULT_MODEL_DIR,
|
|
||||||
no_half=False,
|
|
||||||
low_mem=False,
|
|
||||||
cpu_offload=False,
|
|
||||||
disable_nsfw_checker=False,
|
|
||||||
local_files_only=False,
|
|
||||||
cpu_textencoder=False,
|
|
||||||
device=Device.cuda,
|
|
||||||
input=None,
|
|
||||||
output_dir=None,
|
|
||||||
quality=95,
|
|
||||||
enable_interactive_seg=False,
|
|
||||||
interactive_seg_model=InteractiveSegModel.vit_b,
|
|
||||||
interactive_seg_device=Device.cpu,
|
|
||||||
enable_remove_bg=False,
|
|
||||||
enable_anime_seg=False,
|
|
||||||
enable_realesrgan=False,
|
|
||||||
realesrgan_device=Device.cpu,
|
|
||||||
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
|
|
||||||
enable_gfpgan=False,
|
|
||||||
gfpgan_device=Device.cpu,
|
|
||||||
enable_restoreformer=False,
|
|
||||||
restoreformer_device=Device.cpu,
|
|
||||||
)
|
|
||||||
|
@ -3,6 +3,7 @@ import os
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from iopaint.schema import ModelType, ModelInfo
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -15,7 +16,6 @@ from iopaint.const import (
|
|||||||
ANYTEXT_NAME,
|
ANYTEXT_NAME,
|
||||||
)
|
)
|
||||||
from iopaint.model.original_sd_configs import get_config_files
|
from iopaint.model.original_sd_configs import get_config_files
|
||||||
from iopaint.model_info import ModelInfo, ModelType
|
|
||||||
|
|
||||||
|
|
||||||
def cli_download_model(model: str):
|
def cli_download_model(model: str):
|
||||||
|
@ -1,103 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from pydantic import computed_field, BaseModel
|
|
||||||
|
|
||||||
from iopaint.const import (
|
|
||||||
SDXL_CONTROLNET_CHOICES,
|
|
||||||
SD2_CONTROLNET_CHOICES,
|
|
||||||
SD_CONTROLNET_CHOICES,
|
|
||||||
INSTRUCT_PIX2PIX_NAME,
|
|
||||||
KANDINSKY22_NAME,
|
|
||||||
POWERPAINT_NAME,
|
|
||||||
ANYTEXT_NAME,
|
|
||||||
)
|
|
||||||
from iopaint.schema import ModelType
|
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
model_type: ModelType
|
|
||||||
is_single_file_diffusers: bool = False
|
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def need_prompt(self) -> bool:
|
|
||||||
return self.model_type in [
|
|
||||||
ModelType.DIFFUSERS_SD,
|
|
||||||
ModelType.DIFFUSERS_SDXL,
|
|
||||||
ModelType.DIFFUSERS_SD_INPAINT,
|
|
||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
|
||||||
] or self.name in [
|
|
||||||
INSTRUCT_PIX2PIX_NAME,
|
|
||||||
KANDINSKY22_NAME,
|
|
||||||
POWERPAINT_NAME,
|
|
||||||
ANYTEXT_NAME,
|
|
||||||
]
|
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def controlnets(self) -> List[str]:
|
|
||||||
if self.model_type in [
|
|
||||||
ModelType.DIFFUSERS_SDXL,
|
|
||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
|
||||||
]:
|
|
||||||
return SDXL_CONTROLNET_CHOICES
|
|
||||||
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
|
|
||||||
if "sd2" in self.name.lower():
|
|
||||||
return SD2_CONTROLNET_CHOICES
|
|
||||||
else:
|
|
||||||
return SD_CONTROLNET_CHOICES
|
|
||||||
if self.name == POWERPAINT_NAME:
|
|
||||||
return SD_CONTROLNET_CHOICES
|
|
||||||
return []
|
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def support_strength(self) -> bool:
|
|
||||||
return self.model_type in [
|
|
||||||
ModelType.DIFFUSERS_SD,
|
|
||||||
ModelType.DIFFUSERS_SDXL,
|
|
||||||
ModelType.DIFFUSERS_SD_INPAINT,
|
|
||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
|
||||||
] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]
|
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def support_outpainting(self) -> bool:
|
|
||||||
return self.model_type in [
|
|
||||||
ModelType.DIFFUSERS_SD,
|
|
||||||
ModelType.DIFFUSERS_SDXL,
|
|
||||||
ModelType.DIFFUSERS_SD_INPAINT,
|
|
||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
|
||||||
] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME]
|
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def support_lcm_lora(self) -> bool:
|
|
||||||
return self.model_type in [
|
|
||||||
ModelType.DIFFUSERS_SD,
|
|
||||||
ModelType.DIFFUSERS_SDXL,
|
|
||||||
ModelType.DIFFUSERS_SD_INPAINT,
|
|
||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
|
||||||
]
|
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def support_controlnet(self) -> bool:
|
|
||||||
return self.model_type in [
|
|
||||||
ModelType.DIFFUSERS_SD,
|
|
||||||
ModelType.DIFFUSERS_SDXL,
|
|
||||||
ModelType.DIFFUSERS_SD_INPAINT,
|
|
||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
|
||||||
]
|
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def support_freeu(self) -> bool:
|
|
||||||
return self.model_type in [
|
|
||||||
ModelType.DIFFUSERS_SD,
|
|
||||||
ModelType.DIFFUSERS_SDXL,
|
|
||||||
ModelType.DIFFUSERS_SD_INPAINT,
|
|
||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
|
||||||
] or self.name in [INSTRUCT_PIX2PIX_NAME]
|
|
@ -8,8 +8,7 @@ from iopaint.download import scan_models
|
|||||||
from iopaint.helper import switch_mps_device
|
from iopaint.helper import switch_mps_device
|
||||||
from iopaint.model import models, ControlNet, SD, SDXL
|
from iopaint.model import models, ControlNet, SD, SDXL
|
||||||
from iopaint.model.utils import torch_gc, is_local_files_only
|
from iopaint.model.utils import torch_gc, is_local_files_only
|
||||||
from iopaint.model_info import ModelInfo, ModelType
|
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
|
||||||
from iopaint.schema import InpaintRequest
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
|
@ -16,6 +16,7 @@ def build_plugins(
|
|||||||
interactive_seg_model: InteractiveSegModel,
|
interactive_seg_model: InteractiveSegModel,
|
||||||
interactive_seg_device: Device,
|
interactive_seg_device: Device,
|
||||||
enable_remove_bg: bool,
|
enable_remove_bg: bool,
|
||||||
|
remove_bg_model: str,
|
||||||
enable_anime_seg: bool,
|
enable_anime_seg: bool,
|
||||||
enable_realesrgan: bool,
|
enable_realesrgan: bool,
|
||||||
realesrgan_device: Device,
|
realesrgan_device: Device,
|
||||||
@ -35,7 +36,7 @@ def build_plugins(
|
|||||||
|
|
||||||
if enable_remove_bg:
|
if enable_remove_bg:
|
||||||
logger.info(f"Initialize {RemoveBG.name} plugin")
|
logger.info(f"Initialize {RemoveBG.name} plugin")
|
||||||
plugins[RemoveBG.name] = RemoveBG()
|
plugins[RemoveBG.name] = RemoveBG(remove_bg_model)
|
||||||
|
|
||||||
if enable_anime_seg:
|
if enable_anime_seg:
|
||||||
logger.info(f"Initialize {AnimeSeg.name} plugin")
|
logger.info(f"Initialize {AnimeSeg.name} plugin")
|
||||||
|
@ -25,3 +25,6 @@ class BasePlugin:
|
|||||||
|
|
||||||
def check_dep(self):
|
def check_dep(self):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def switch_model(self, new_model_name: str):
|
||||||
|
...
|
||||||
|
515
iopaint/plugins/briarmbg.py
Normal file
515
iopaint/plugins/briarmbg.py
Normal file
@ -0,0 +1,515 @@
|
|||||||
|
# copy from: https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4/blob/main/briarmbg.py
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
from torchvision.transforms.functional import normalize
|
||||||
|
|
||||||
|
|
||||||
|
class REBNCONV(nn.Module):
|
||||||
|
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
|
||||||
|
super(REBNCONV, self).__init__()
|
||||||
|
|
||||||
|
self.conv_s1 = nn.Conv2d(
|
||||||
|
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
|
||||||
|
)
|
||||||
|
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
||||||
|
self.relu_s1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hx = x
|
||||||
|
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
||||||
|
|
||||||
|
return xout
|
||||||
|
|
||||||
|
|
||||||
|
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
||||||
|
def _upsample_like(src, tar):
|
||||||
|
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
|
||||||
|
|
||||||
|
return src
|
||||||
|
|
||||||
|
|
||||||
|
### RSU-7 ###
|
||||||
|
class RSU7(nn.Module):
|
||||||
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
||||||
|
super(RSU7, self).__init__()
|
||||||
|
|
||||||
|
self.in_ch = in_ch
|
||||||
|
self.mid_ch = mid_ch
|
||||||
|
self.out_ch = out_ch
|
||||||
|
|
||||||
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
|
||||||
|
|
||||||
|
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||||
|
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||||
|
|
||||||
|
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
hx = x
|
||||||
|
hxin = self.rebnconvin(hx)
|
||||||
|
|
||||||
|
hx1 = self.rebnconv1(hxin)
|
||||||
|
hx = self.pool1(hx1)
|
||||||
|
|
||||||
|
hx2 = self.rebnconv2(hx)
|
||||||
|
hx = self.pool2(hx2)
|
||||||
|
|
||||||
|
hx3 = self.rebnconv3(hx)
|
||||||
|
hx = self.pool3(hx3)
|
||||||
|
|
||||||
|
hx4 = self.rebnconv4(hx)
|
||||||
|
hx = self.pool4(hx4)
|
||||||
|
|
||||||
|
hx5 = self.rebnconv5(hx)
|
||||||
|
hx = self.pool5(hx5)
|
||||||
|
|
||||||
|
hx6 = self.rebnconv6(hx)
|
||||||
|
|
||||||
|
hx7 = self.rebnconv7(hx6)
|
||||||
|
|
||||||
|
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
||||||
|
hx6dup = _upsample_like(hx6d, hx5)
|
||||||
|
|
||||||
|
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
||||||
|
hx5dup = _upsample_like(hx5d, hx4)
|
||||||
|
|
||||||
|
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
||||||
|
hx4dup = _upsample_like(hx4d, hx3)
|
||||||
|
|
||||||
|
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
||||||
|
hx3dup = _upsample_like(hx3d, hx2)
|
||||||
|
|
||||||
|
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||||
|
hx2dup = _upsample_like(hx2d, hx1)
|
||||||
|
|
||||||
|
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||||
|
|
||||||
|
return hx1d + hxin
|
||||||
|
|
||||||
|
|
||||||
|
### RSU-6 ###
|
||||||
|
class RSU6(nn.Module):
|
||||||
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||||
|
super(RSU6, self).__init__()
|
||||||
|
|
||||||
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||||
|
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||||
|
|
||||||
|
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.rebnconvin(hx)
|
||||||
|
|
||||||
|
hx1 = self.rebnconv1(hxin)
|
||||||
|
hx = self.pool1(hx1)
|
||||||
|
|
||||||
|
hx2 = self.rebnconv2(hx)
|
||||||
|
hx = self.pool2(hx2)
|
||||||
|
|
||||||
|
hx3 = self.rebnconv3(hx)
|
||||||
|
hx = self.pool3(hx3)
|
||||||
|
|
||||||
|
hx4 = self.rebnconv4(hx)
|
||||||
|
hx = self.pool4(hx4)
|
||||||
|
|
||||||
|
hx5 = self.rebnconv5(hx)
|
||||||
|
|
||||||
|
hx6 = self.rebnconv6(hx5)
|
||||||
|
|
||||||
|
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
||||||
|
hx5dup = _upsample_like(hx5d, hx4)
|
||||||
|
|
||||||
|
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
||||||
|
hx4dup = _upsample_like(hx4d, hx3)
|
||||||
|
|
||||||
|
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
||||||
|
hx3dup = _upsample_like(hx3d, hx2)
|
||||||
|
|
||||||
|
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||||
|
hx2dup = _upsample_like(hx2d, hx1)
|
||||||
|
|
||||||
|
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||||
|
|
||||||
|
return hx1d + hxin
|
||||||
|
|
||||||
|
|
||||||
|
### RSU-5 ###
|
||||||
|
class RSU5(nn.Module):
|
||||||
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||||
|
super(RSU5, self).__init__()
|
||||||
|
|
||||||
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||||
|
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||||
|
|
||||||
|
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.rebnconvin(hx)
|
||||||
|
|
||||||
|
hx1 = self.rebnconv1(hxin)
|
||||||
|
hx = self.pool1(hx1)
|
||||||
|
|
||||||
|
hx2 = self.rebnconv2(hx)
|
||||||
|
hx = self.pool2(hx2)
|
||||||
|
|
||||||
|
hx3 = self.rebnconv3(hx)
|
||||||
|
hx = self.pool3(hx3)
|
||||||
|
|
||||||
|
hx4 = self.rebnconv4(hx)
|
||||||
|
|
||||||
|
hx5 = self.rebnconv5(hx4)
|
||||||
|
|
||||||
|
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
||||||
|
hx4dup = _upsample_like(hx4d, hx3)
|
||||||
|
|
||||||
|
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
||||||
|
hx3dup = _upsample_like(hx3d, hx2)
|
||||||
|
|
||||||
|
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||||
|
hx2dup = _upsample_like(hx2d, hx1)
|
||||||
|
|
||||||
|
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||||
|
|
||||||
|
return hx1d + hxin
|
||||||
|
|
||||||
|
|
||||||
|
### RSU-4 ###
|
||||||
|
class RSU4(nn.Module):
|
||||||
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||||
|
super(RSU4, self).__init__()
|
||||||
|
|
||||||
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||||
|
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||||
|
|
||||||
|
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||||
|
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.rebnconvin(hx)
|
||||||
|
|
||||||
|
hx1 = self.rebnconv1(hxin)
|
||||||
|
hx = self.pool1(hx1)
|
||||||
|
|
||||||
|
hx2 = self.rebnconv2(hx)
|
||||||
|
hx = self.pool2(hx2)
|
||||||
|
|
||||||
|
hx3 = self.rebnconv3(hx)
|
||||||
|
|
||||||
|
hx4 = self.rebnconv4(hx3)
|
||||||
|
|
||||||
|
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
||||||
|
hx3dup = _upsample_like(hx3d, hx2)
|
||||||
|
|
||||||
|
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||||
|
hx2dup = _upsample_like(hx2d, hx1)
|
||||||
|
|
||||||
|
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||||
|
|
||||||
|
return hx1d + hxin
|
||||||
|
|
||||||
|
|
||||||
|
### RSU-4F ###
|
||||||
|
class RSU4F(nn.Module):
|
||||||
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||||
|
super(RSU4F, self).__init__()
|
||||||
|
|
||||||
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||||
|
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||||
|
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
||||||
|
|
||||||
|
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
||||||
|
|
||||||
|
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
||||||
|
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
||||||
|
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.rebnconvin(hx)
|
||||||
|
|
||||||
|
hx1 = self.rebnconv1(hxin)
|
||||||
|
hx2 = self.rebnconv2(hx1)
|
||||||
|
hx3 = self.rebnconv3(hx2)
|
||||||
|
|
||||||
|
hx4 = self.rebnconv4(hx3)
|
||||||
|
|
||||||
|
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
||||||
|
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
||||||
|
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
||||||
|
|
||||||
|
return hx1d + hxin
|
||||||
|
|
||||||
|
|
||||||
|
class myrebnconv(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_ch=3,
|
||||||
|
out_ch=1,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
):
|
||||||
|
super(myrebnconv, self).__init__()
|
||||||
|
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_ch,
|
||||||
|
out_ch,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
self.bn = nn.BatchNorm2d(out_ch)
|
||||||
|
self.rl = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.rl(self.bn(self.conv(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class BriaRMBG(nn.Module):
|
||||||
|
def __init__(self, in_ch=3, out_ch=1):
|
||||||
|
super(BriaRMBG, self).__init__()
|
||||||
|
|
||||||
|
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
|
||||||
|
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage1 = RSU7(64, 32, 64)
|
||||||
|
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage2 = RSU6(64, 32, 128)
|
||||||
|
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage3 = RSU5(128, 64, 256)
|
||||||
|
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage4 = RSU4(256, 128, 512)
|
||||||
|
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage5 = RSU4F(512, 256, 512)
|
||||||
|
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage6 = RSU4F(512, 256, 512)
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
self.stage5d = RSU4F(1024, 256, 512)
|
||||||
|
self.stage4d = RSU4(1024, 128, 256)
|
||||||
|
self.stage3d = RSU5(512, 64, 128)
|
||||||
|
self.stage2d = RSU6(256, 32, 64)
|
||||||
|
self.stage1d = RSU7(128, 16, 64)
|
||||||
|
|
||||||
|
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||||
|
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||||
|
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
||||||
|
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
||||||
|
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
||||||
|
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
||||||
|
|
||||||
|
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.conv_in(hx)
|
||||||
|
# hx = self.pool_in(hxin)
|
||||||
|
|
||||||
|
# stage 1
|
||||||
|
hx1 = self.stage1(hxin)
|
||||||
|
hx = self.pool12(hx1)
|
||||||
|
|
||||||
|
# stage 2
|
||||||
|
hx2 = self.stage2(hx)
|
||||||
|
hx = self.pool23(hx2)
|
||||||
|
|
||||||
|
# stage 3
|
||||||
|
hx3 = self.stage3(hx)
|
||||||
|
hx = self.pool34(hx3)
|
||||||
|
|
||||||
|
# stage 4
|
||||||
|
hx4 = self.stage4(hx)
|
||||||
|
hx = self.pool45(hx4)
|
||||||
|
|
||||||
|
# stage 5
|
||||||
|
hx5 = self.stage5(hx)
|
||||||
|
hx = self.pool56(hx5)
|
||||||
|
|
||||||
|
# stage 6
|
||||||
|
hx6 = self.stage6(hx)
|
||||||
|
hx6up = _upsample_like(hx6, hx5)
|
||||||
|
|
||||||
|
# -------------------- decoder --------------------
|
||||||
|
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
||||||
|
hx5dup = _upsample_like(hx5d, hx4)
|
||||||
|
|
||||||
|
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
||||||
|
hx4dup = _upsample_like(hx4d, hx3)
|
||||||
|
|
||||||
|
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
||||||
|
hx3dup = _upsample_like(hx3d, hx2)
|
||||||
|
|
||||||
|
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
||||||
|
hx2dup = _upsample_like(hx2d, hx1)
|
||||||
|
|
||||||
|
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
||||||
|
|
||||||
|
# side output
|
||||||
|
d1 = self.side1(hx1d)
|
||||||
|
d1 = _upsample_like(d1, x)
|
||||||
|
|
||||||
|
d2 = self.side2(hx2d)
|
||||||
|
d2 = _upsample_like(d2, x)
|
||||||
|
|
||||||
|
d3 = self.side3(hx3d)
|
||||||
|
d3 = _upsample_like(d3, x)
|
||||||
|
|
||||||
|
d4 = self.side4(hx4d)
|
||||||
|
d4 = _upsample_like(d4, x)
|
||||||
|
|
||||||
|
d5 = self.side5(hx5d)
|
||||||
|
d5 = _upsample_like(d5, x)
|
||||||
|
|
||||||
|
d6 = self.side6(hx6)
|
||||||
|
d6 = _upsample_like(d6, x)
|
||||||
|
|
||||||
|
return [
|
||||||
|
F.sigmoid(d1),
|
||||||
|
F.sigmoid(d2),
|
||||||
|
F.sigmoid(d3),
|
||||||
|
F.sigmoid(d4),
|
||||||
|
F.sigmoid(d5),
|
||||||
|
F.sigmoid(d6),
|
||||||
|
], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image(image):
|
||||||
|
image = image.convert("RGB")
|
||||||
|
model_input_size = (1024, 1024)
|
||||||
|
image = image.resize(model_input_size, Image.BILINEAR)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def create_briarmbg_session():
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
net = BriaRMBG()
|
||||||
|
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
|
||||||
|
net.load_state_dict(torch.load(model_path, map_location="cpu"))
|
||||||
|
net.eval()
|
||||||
|
return net
|
||||||
|
|
||||||
|
|
||||||
|
def briarmbg_process(bgr_np_image, session, only_mask=False):
|
||||||
|
# prepare input
|
||||||
|
orig_bgr_image = Image.fromarray(bgr_np_image)
|
||||||
|
w, h = orig_im_size = orig_bgr_image.size
|
||||||
|
image = resize_image(orig_bgr_image)
|
||||||
|
im_np = np.array(image)
|
||||||
|
im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
|
||||||
|
im_tensor = torch.unsqueeze(im_tensor, 0)
|
||||||
|
im_tensor = torch.divide(im_tensor, 255.0)
|
||||||
|
im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
im_tensor = im_tensor.cuda()
|
||||||
|
|
||||||
|
# inference
|
||||||
|
result = session(im_tensor)
|
||||||
|
# post process
|
||||||
|
result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
|
||||||
|
ma = torch.max(result)
|
||||||
|
mi = torch.min(result)
|
||||||
|
result = (result - mi) / (ma - mi)
|
||||||
|
# image to pil
|
||||||
|
im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
|
||||||
|
|
||||||
|
mask = np.squeeze(im_array)
|
||||||
|
if only_mask:
|
||||||
|
return mask
|
||||||
|
|
||||||
|
pil_im = Image.fromarray(mask)
|
||||||
|
# paste the mask on the original image
|
||||||
|
new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
|
||||||
|
new_im.paste(orig_bgr_image, mask=pil_im)
|
||||||
|
rgba_np_img = np.asarray(new_im)
|
||||||
|
return rgba_np_img
|
@ -1,8 +1,6 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from loguru import logger
|
||||||
from torch.hub import get_dir
|
from torch.hub import get_dir
|
||||||
|
|
||||||
from iopaint.plugins.base_plugin import BasePlugin
|
from iopaint.plugins.base_plugin import BasePlugin
|
||||||
from iopaint.schema import RunPluginRequest
|
from iopaint.schema import RunPluginRequest, RemoveBGModel
|
||||||
|
|
||||||
|
|
||||||
class RemoveBG(BasePlugin):
|
class RemoveBG(BasePlugin):
|
||||||
@ -12,32 +13,53 @@ class RemoveBG(BasePlugin):
|
|||||||
support_gen_mask = True
|
support_gen_mask = True
|
||||||
support_gen_image = True
|
support_gen_image = True
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, model_name):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
from rembg import new_session
|
self.model_name = model_name
|
||||||
|
|
||||||
hub_dir = get_dir()
|
hub_dir = get_dir()
|
||||||
model_dir = os.path.join(hub_dir, "checkpoints")
|
model_dir = os.path.join(hub_dir, "checkpoints")
|
||||||
os.environ["U2NET_HOME"] = model_dir
|
os.environ["U2NET_HOME"] = model_dir
|
||||||
|
|
||||||
self.session = new_session(model_name="u2net")
|
self._init_session(model_name)
|
||||||
|
|
||||||
|
def _init_session(self, model_name: str):
|
||||||
|
if model_name == RemoveBGModel.briaai_rmbg_1_4:
|
||||||
|
from iopaint.plugins.briarmbg import (
|
||||||
|
create_briarmbg_session,
|
||||||
|
briarmbg_process,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session = create_briarmbg_session()
|
||||||
|
self.remove = briarmbg_process
|
||||||
|
else:
|
||||||
|
from rembg import new_session, remove
|
||||||
|
|
||||||
|
self.session = new_session(model_name=model_name)
|
||||||
|
self.remove = remove
|
||||||
|
|
||||||
|
def switch_model(self, new_model_name):
|
||||||
|
if self.model_name == new_model_name:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Switching removebg model from {self.model_name} to {new_model_name}"
|
||||||
|
)
|
||||||
|
self._init_session(new_model_name)
|
||||||
|
self.model_name = new_model_name
|
||||||
|
|
||||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
from rembg import remove
|
|
||||||
|
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
# return BGRA image
|
# return BGRA image
|
||||||
output = remove(bgr_np_img, session=self.session)
|
output = self.remove(bgr_np_img, session=self.session)
|
||||||
return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
|
return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
|
||||||
|
|
||||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
from rembg import remove
|
|
||||||
|
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
# return BGR image, 255 means foreground, 0 means background
|
# return BGR image, 255 means foreground, 0 means background
|
||||||
output = remove(bgr_np_img, session=self.session, only_mask=True)
|
output = self.remove(bgr_np_img, session=self.session, only_mask=True)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def check_dep(self):
|
def check_dep(self):
|
||||||
|
@ -5,11 +5,11 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import packaging.version
|
import packaging.version
|
||||||
|
from iopaint.schema import Device
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from rich import print
|
from rich import print
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
from iopaint.const import Device
|
|
||||||
|
|
||||||
_PY_VERSION: str = sys.version.split()[0].rstrip("+")
|
_PY_VERSION: str = sys.version.split()[0].rstrip("+")
|
||||||
|
|
||||||
|
@ -1,11 +1,117 @@
|
|||||||
import json
|
|
||||||
import random
|
import random
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Literal, List
|
from typing import Optional, Literal, List
|
||||||
|
|
||||||
|
from iopaint.const import (
|
||||||
|
INSTRUCT_PIX2PIX_NAME,
|
||||||
|
KANDINSKY22_NAME,
|
||||||
|
POWERPAINT_NAME,
|
||||||
|
ANYTEXT_NAME,
|
||||||
|
SDXL_CONTROLNET_CHOICES,
|
||||||
|
SD2_CONTROLNET_CHOICES,
|
||||||
|
SD_CONTROLNET_CHOICES,
|
||||||
|
)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator, computed_field
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(str, Enum):
|
||||||
|
INPAINT = "inpaint" # LaMa, MAT...
|
||||||
|
DIFFUSERS_SD = "diffusers_sd"
|
||||||
|
DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
|
||||||
|
DIFFUSERS_SDXL = "diffusers_sdxl"
|
||||||
|
DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
|
||||||
|
DIFFUSERS_OTHER = "diffusers_other"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInfo(BaseModel):
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
model_type: ModelType
|
||||||
|
is_single_file_diffusers: bool = False
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def need_prompt(self) -> bool:
|
||||||
|
return self.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
|
] or self.name in [
|
||||||
|
INSTRUCT_PIX2PIX_NAME,
|
||||||
|
KANDINSKY22_NAME,
|
||||||
|
POWERPAINT_NAME,
|
||||||
|
ANYTEXT_NAME,
|
||||||
|
]
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def controlnets(self) -> List[str]:
|
||||||
|
if self.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
|
]:
|
||||||
|
return SDXL_CONTROLNET_CHOICES
|
||||||
|
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
|
||||||
|
if "sd2" in self.name.lower():
|
||||||
|
return SD2_CONTROLNET_CHOICES
|
||||||
|
else:
|
||||||
|
return SD_CONTROLNET_CHOICES
|
||||||
|
if self.name == POWERPAINT_NAME:
|
||||||
|
return SD_CONTROLNET_CHOICES
|
||||||
|
return []
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def support_strength(self) -> bool:
|
||||||
|
return self.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
|
] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def support_outpainting(self) -> bool:
|
||||||
|
return self.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
|
] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME]
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def support_lcm_lora(self) -> bool:
|
||||||
|
return self.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
|
]
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def support_controlnet(self) -> bool:
|
||||||
|
return self.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
|
]
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def support_freeu(self) -> bool:
|
||||||
|
return self.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
|
] or self.name in [INSTRUCT_PIX2PIX_NAME]
|
||||||
|
|
||||||
|
|
||||||
class Choices(str, Enum):
|
class Choices(str, Enum):
|
||||||
@ -20,6 +126,16 @@ class RealESRGANModel(Choices):
|
|||||||
RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
|
RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveBGModel(Choices):
|
||||||
|
u2net = "u2net"
|
||||||
|
u2netp = "u2netp"
|
||||||
|
u2net_human_seg = "u2net_human_seg"
|
||||||
|
u2net_cloth_seg = "u2net_cloth_seg"
|
||||||
|
silueta = "silueta"
|
||||||
|
isnet_general_use = "isnet-general-use"
|
||||||
|
briaai_rmbg_1_4 = "briaai/RMBG-1.4"
|
||||||
|
|
||||||
|
|
||||||
class Device(Choices):
|
class Device(Choices):
|
||||||
cpu = "cpu"
|
cpu = "cpu"
|
||||||
cuda = "cuda"
|
cuda = "cuda"
|
||||||
@ -44,15 +160,6 @@ class CV2Flag(str, Enum):
|
|||||||
INPAINT_TELEA = "INPAINT_TELEA"
|
INPAINT_TELEA = "INPAINT_TELEA"
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
|
||||||
INPAINT = "inpaint" # LaMa, MAT...
|
|
||||||
DIFFUSERS_SD = "diffusers_sd"
|
|
||||||
DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
|
|
||||||
DIFFUSERS_SDXL = "diffusers_sdxl"
|
|
||||||
DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
|
|
||||||
DIFFUSERS_OTHER = "diffusers_other"
|
|
||||||
|
|
||||||
|
|
||||||
class HDStrategy(str, Enum):
|
class HDStrategy(str, Enum):
|
||||||
# Use original image size
|
# Use original image size
|
||||||
ORIGINAL = "Original"
|
ORIGINAL = "Original"
|
||||||
@ -124,6 +231,7 @@ class ApiConfig(BaseModel):
|
|||||||
interactive_seg_model: InteractiveSegModel
|
interactive_seg_model: InteractiveSegModel
|
||||||
interactive_seg_device: Device
|
interactive_seg_device: Device
|
||||||
enable_remove_bg: bool
|
enable_remove_bg: bool
|
||||||
|
remove_bg_model: str
|
||||||
enable_anime_seg: bool
|
enable_anime_seg: bool
|
||||||
enable_realesrgan: bool
|
enable_realesrgan: bool
|
||||||
realesrgan_device: Device
|
realesrgan_device: Device
|
||||||
@ -313,6 +421,9 @@ class GenInfoResponse(BaseModel):
|
|||||||
|
|
||||||
class ServerConfigResponse(BaseModel):
|
class ServerConfigResponse(BaseModel):
|
||||||
plugins: List[PluginInfo]
|
plugins: List[PluginInfo]
|
||||||
|
modelInfos: List[ModelInfo]
|
||||||
|
removeBGModel: RemoveBGModel
|
||||||
|
removeBGModels: List[str]
|
||||||
enableFileManager: bool
|
enableFileManager: bool
|
||||||
enableAutoSaving: bool
|
enableAutoSaving: bool
|
||||||
enableControlnet: bool
|
enableControlnet: bool
|
||||||
@ -326,6 +437,11 @@ class SwitchModelRequest(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchPluginModelRequest(BaseModel):
|
||||||
|
plugin_name: str
|
||||||
|
model_name: str
|
||||||
|
|
||||||
|
|
||||||
AdjustMaskOperate = Literal["expand", "shrink", "reverse"]
|
AdjustMaskOperate = Literal["expand", "shrink", "reverse"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from PIL import Image
|
|||||||
|
|
||||||
from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
|
from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
|
||||||
from iopaint.plugins.anime_seg import AnimeSeg
|
from iopaint.plugins.anime_seg import AnimeSeg
|
||||||
from iopaint.schema import RunPluginRequest
|
from iopaint.schema import RunPluginRequest, RemoveBGModel
|
||||||
from iopaint.tests.utils import check_device, current_dir, save_dir
|
from iopaint.tests.utils import check_device, current_dir, save_dir
|
||||||
|
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
@ -34,7 +34,7 @@ def _save(img, name):
|
|||||||
|
|
||||||
|
|
||||||
def test_remove_bg():
|
def test_remove_bg():
|
||||||
model = RemoveBG()
|
model = RemoveBG(RemoveBGModel.briaai_rmbg_1_4)
|
||||||
rgba_np_img = model.gen_image(
|
rgba_np_img = model.gen_image(
|
||||||
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||||
)
|
)
|
||||||
|
@ -1,4 +1,14 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from iopaint.schema import (
|
||||||
|
Device,
|
||||||
|
InteractiveSegModel,
|
||||||
|
RemoveBGModel,
|
||||||
|
RealESRGANModel,
|
||||||
|
ApiConfig,
|
||||||
|
)
|
||||||
|
|
||||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||||
|
|
||||||
@ -15,6 +25,37 @@ from iopaint.const import *
|
|||||||
_config_file: Path = None
|
_config_file: Path = None
|
||||||
|
|
||||||
|
|
||||||
|
default_configs = dict(
|
||||||
|
host="127.0.0.1",
|
||||||
|
port=8080,
|
||||||
|
model=DEFAULT_MODEL,
|
||||||
|
model_dir=DEFAULT_MODEL_DIR,
|
||||||
|
no_half=False,
|
||||||
|
low_mem=False,
|
||||||
|
cpu_offload=False,
|
||||||
|
disable_nsfw_checker=False,
|
||||||
|
local_files_only=False,
|
||||||
|
cpu_textencoder=False,
|
||||||
|
device=Device.cuda,
|
||||||
|
input=None,
|
||||||
|
output_dir=None,
|
||||||
|
quality=95,
|
||||||
|
enable_interactive_seg=False,
|
||||||
|
interactive_seg_model=InteractiveSegModel.vit_b,
|
||||||
|
interactive_seg_device=Device.cpu,
|
||||||
|
enable_remove_bg=False,
|
||||||
|
remove_bg_model=RemoveBGModel.u2net,
|
||||||
|
enable_anime_seg=False,
|
||||||
|
enable_realesrgan=False,
|
||||||
|
realesrgan_device=Device.cpu,
|
||||||
|
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
|
||||||
|
enable_gfpgan=False,
|
||||||
|
gfpgan_device=Device.cpu,
|
||||||
|
enable_restoreformer=False,
|
||||||
|
restoreformer_device=Device.cpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WebConfig(ApiConfig):
|
class WebConfig(ApiConfig):
|
||||||
model_dir: str = DEFAULT_MODEL_DIR
|
model_dir: str = DEFAULT_MODEL_DIR
|
||||||
|
|
||||||
@ -50,6 +91,7 @@ def save_config(
|
|||||||
interactive_seg_model,
|
interactive_seg_model,
|
||||||
interactive_seg_device,
|
interactive_seg_device,
|
||||||
enable_remove_bg,
|
enable_remove_bg,
|
||||||
|
remove_bg_model,
|
||||||
enable_anime_seg,
|
enable_anime_seg,
|
||||||
enable_realesrgan,
|
enable_realesrgan,
|
||||||
realesrgan_device,
|
realesrgan_device,
|
||||||
@ -115,7 +157,7 @@ def main(config_file: Path):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
recommend_model = gr.Dropdown(
|
recommend_model = gr.Dropdown(
|
||||||
["lama", "mat", "migan"] + DIFFUSION_MODELS,
|
["lama", "mat", "migan"] + DIFFUSION_MODELS,
|
||||||
label="Recommend Models",
|
label="Recommended Models",
|
||||||
)
|
)
|
||||||
downloaded_model = gr.Dropdown(
|
downloaded_model = gr.Dropdown(
|
||||||
downloaded_models, label="Downloaded Models"
|
downloaded_models, label="Downloaded Models"
|
||||||
@ -179,6 +221,11 @@ def main(config_file: Path):
|
|||||||
enable_remove_bg = gr.Checkbox(
|
enable_remove_bg = gr.Checkbox(
|
||||||
init_config.enable_remove_bg, label=REMOVE_BG_HELP
|
init_config.enable_remove_bg, label=REMOVE_BG_HELP
|
||||||
)
|
)
|
||||||
|
remove_bg_model = gr.Radio(
|
||||||
|
RemoveBGModel.values(),
|
||||||
|
label="Remove bg model",
|
||||||
|
value=init_config.remove_bg_model,
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
enable_anime_seg = gr.Checkbox(
|
enable_anime_seg = gr.Checkbox(
|
||||||
init_config.enable_anime_seg, label=ANIMESEG_HELP
|
init_config.enable_anime_seg, label=ANIMESEG_HELP
|
||||||
@ -241,6 +288,7 @@ def main(config_file: Path):
|
|||||||
interactive_seg_model,
|
interactive_seg_model,
|
||||||
interactive_seg_device,
|
interactive_seg_device,
|
||||||
enable_remove_bg,
|
enable_remove_bg,
|
||||||
|
remove_bg_model,
|
||||||
enable_anime_seg,
|
enable_anime_seg,
|
||||||
enable_realesrgan,
|
enable_realesrgan,
|
||||||
realesrgan_device,
|
realesrgan_device,
|
||||||
|
@ -20,8 +20,8 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs"
|
|||||||
import { useEffect, useState } from "react"
|
import { useEffect, useState } from "react"
|
||||||
import { cn } from "@/lib/utils"
|
import { cn } from "@/lib/utils"
|
||||||
import { useQuery } from "@tanstack/react-query"
|
import { useQuery } from "@tanstack/react-query"
|
||||||
import { fetchModelInfos, switchModel } from "@/lib/api"
|
import { getServerConfig, switchModel, switchPluginModel } from "@/lib/api"
|
||||||
import { ModelInfo } from "@/lib/types"
|
import { ModelInfo, PluginName } from "@/lib/types"
|
||||||
import { useStore } from "@/lib/states"
|
import { useStore } from "@/lib/states"
|
||||||
import { ScrollArea } from "./ui/scroll-area"
|
import { ScrollArea } from "./ui/scroll-area"
|
||||||
import { useToast } from "./ui/use-toast"
|
import { useToast } from "./ui/use-toast"
|
||||||
@ -39,6 +39,14 @@ import {
|
|||||||
MODEL_TYPE_OTHER,
|
MODEL_TYPE_OTHER,
|
||||||
} from "@/lib/const"
|
} from "@/lib/const"
|
||||||
import useHotKey from "@/hooks/useHotkey"
|
import useHotKey from "@/hooks/useHotkey"
|
||||||
|
import {
|
||||||
|
Select,
|
||||||
|
SelectContent,
|
||||||
|
SelectGroup,
|
||||||
|
SelectItem,
|
||||||
|
SelectTrigger,
|
||||||
|
SelectValue,
|
||||||
|
} from "./ui/select"
|
||||||
|
|
||||||
const formSchema = z.object({
|
const formSchema = z.object({
|
||||||
enableFileManager: z.boolean(),
|
enableFileManager: z.boolean(),
|
||||||
@ -48,42 +56,45 @@ const formSchema = z.object({
|
|||||||
enableManualInpainting: z.boolean(),
|
enableManualInpainting: z.boolean(),
|
||||||
enableUploadMask: z.boolean(),
|
enableUploadMask: z.boolean(),
|
||||||
enableAutoExtractPrompt: z.boolean(),
|
enableAutoExtractPrompt: z.boolean(),
|
||||||
|
removeBGModel: z.string(),
|
||||||
})
|
})
|
||||||
|
|
||||||
const TAB_GENERAL = "General"
|
const TAB_GENERAL = "General"
|
||||||
const TAB_MODEL = "Model"
|
const TAB_MODEL = "Model"
|
||||||
|
const TAB_PLUGINS = "Plugins"
|
||||||
// const TAB_FILE_MANAGER = "File Manager"
|
// const TAB_FILE_MANAGER = "File Manager"
|
||||||
|
|
||||||
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL]
|
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL, TAB_PLUGINS]
|
||||||
|
|
||||||
export function SettingsDialog() {
|
export function SettingsDialog() {
|
||||||
const [open, toggleOpen] = useToggle(false)
|
const [open, toggleOpen] = useToggle(false)
|
||||||
const [openModelSwitching, toggleOpenModelSwitching] = useToggle(false)
|
|
||||||
const [tab, setTab] = useState(TAB_MODEL)
|
const [tab, setTab] = useState(TAB_MODEL)
|
||||||
const [
|
const [
|
||||||
updateAppState,
|
updateAppState,
|
||||||
settings,
|
settings,
|
||||||
updateSettings,
|
updateSettings,
|
||||||
fileManagerState,
|
fileManagerState,
|
||||||
updateFileManagerState,
|
|
||||||
setAppModel,
|
setAppModel,
|
||||||
|
setServerConfig,
|
||||||
] = useStore((state) => [
|
] = useStore((state) => [
|
||||||
state.updateAppState,
|
state.updateAppState,
|
||||||
state.settings,
|
state.settings,
|
||||||
state.updateSettings,
|
state.updateSettings,
|
||||||
state.fileManagerState,
|
state.fileManagerState,
|
||||||
state.updateFileManagerState,
|
|
||||||
state.setModel,
|
state.setModel,
|
||||||
|
state.setServerConfig,
|
||||||
])
|
])
|
||||||
const { toast } = useToast()
|
const { toast } = useToast()
|
||||||
const [model, setModel] = useState<ModelInfo>(settings.model)
|
const [model, setModel] = useState<ModelInfo>(settings.model)
|
||||||
|
const [modelSwitchingTexts, setModelSwitchingTexts] = useState<string[]>([])
|
||||||
|
const openModelSwitching = modelSwitchingTexts.length > 0
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setModel(settings.model)
|
setModel(settings.model)
|
||||||
}, [settings.model])
|
}, [settings.model])
|
||||||
|
|
||||||
const { data: modelInfos, status } = useQuery({
|
const { data: serverConfig, status } = useQuery({
|
||||||
queryKey: ["modelInfos"],
|
queryKey: ["serverConfig"],
|
||||||
queryFn: fetchModelInfos,
|
queryFn: getServerConfig,
|
||||||
})
|
})
|
||||||
|
|
||||||
// 1. Define your form.
|
// 1. Define your form.
|
||||||
@ -96,9 +107,17 @@ export function SettingsDialog() {
|
|||||||
enableAutoExtractPrompt: settings.enableAutoExtractPrompt,
|
enableAutoExtractPrompt: settings.enableAutoExtractPrompt,
|
||||||
inputDirectory: fileManagerState.inputDirectory,
|
inputDirectory: fileManagerState.inputDirectory,
|
||||||
outputDirectory: fileManagerState.outputDirectory,
|
outputDirectory: fileManagerState.outputDirectory,
|
||||||
|
removeBGModel: serverConfig?.removeBGModel,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (serverConfig) {
|
||||||
|
setServerConfig(serverConfig)
|
||||||
|
form.setValue("removeBGModel", serverConfig.removeBGModel)
|
||||||
|
}
|
||||||
|
}, [form, serverConfig])
|
||||||
|
|
||||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||||
// Do something with the form values. ✅ This will be type-safe and validated.
|
// Do something with the form values. ✅ This will be type-safe and validated.
|
||||||
updateSettings({
|
updateSettings({
|
||||||
@ -109,29 +128,67 @@ export function SettingsDialog() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// TODO: validate input/output Directory
|
// TODO: validate input/output Directory
|
||||||
updateFileManagerState({
|
// updateFileManagerState({
|
||||||
inputDirectory: values.inputDirectory,
|
// inputDirectory: values.inputDirectory,
|
||||||
outputDirectory: values.outputDirectory,
|
// outputDirectory: values.outputDirectory,
|
||||||
})
|
// })
|
||||||
if (model.name !== settings.model.name) {
|
|
||||||
toggleOpenModelSwitching()
|
const shouldSwitchModel = model.name !== settings.model.name
|
||||||
updateAppState({ disableShortCuts: true })
|
const shouldSwitchRemoveBGModel =
|
||||||
try {
|
serverConfig?.removeBGModel !== values.removeBGModel
|
||||||
const newModel = await switchModel(model.name)
|
const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel
|
||||||
toast({
|
|
||||||
title: `Switch to ${newModel.name} success`,
|
if (showModelSwitching) {
|
||||||
})
|
const newModelSwitchingTexts: string[] = []
|
||||||
setAppModel(model)
|
if (shouldSwitchModel) {
|
||||||
} catch (error: any) {
|
newModelSwitchingTexts.push(
|
||||||
toast({
|
`Switching model from ${settings.model.name} to ${model.name}`
|
||||||
variant: "destructive",
|
)
|
||||||
title: `Switch to ${model.name} failed: ${error}`,
|
|
||||||
})
|
|
||||||
setModel(settings.model)
|
|
||||||
} finally {
|
|
||||||
toggleOpenModelSwitching()
|
|
||||||
updateAppState({ disableShortCuts: false })
|
|
||||||
}
|
}
|
||||||
|
if (shouldSwitchRemoveBGModel) {
|
||||||
|
newModelSwitchingTexts.push(
|
||||||
|
`Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
setModelSwitchingTexts(newModelSwitchingTexts)
|
||||||
|
|
||||||
|
updateAppState({ disableShortCuts: true })
|
||||||
|
|
||||||
|
if (shouldSwitchModel) {
|
||||||
|
try {
|
||||||
|
const newModel = await switchModel(model.name)
|
||||||
|
toast({
|
||||||
|
title: `Switch to ${newModel.name} success`,
|
||||||
|
})
|
||||||
|
setAppModel(model)
|
||||||
|
} catch (error: any) {
|
||||||
|
toast({
|
||||||
|
variant: "destructive",
|
||||||
|
title: `Switch to ${model.name} failed: ${error}`,
|
||||||
|
})
|
||||||
|
setModel(settings.model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldSwitchRemoveBGModel) {
|
||||||
|
try {
|
||||||
|
const res = await switchPluginModel(
|
||||||
|
PluginName.RemoveBG,
|
||||||
|
values.removeBGModel
|
||||||
|
)
|
||||||
|
if (res.status !== 200) {
|
||||||
|
throw new Error(res.statusText)
|
||||||
|
}
|
||||||
|
} catch (error: any) {
|
||||||
|
toast({
|
||||||
|
variant: "destructive",
|
||||||
|
title: `Switch removebg model to ${model.name} failed: ${error}`,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setModelSwitchingTexts([])
|
||||||
|
updateAppState({ disableShortCuts: false })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,7 +200,17 @@ export function SettingsDialog() {
|
|||||||
onSubmit(form.getValues())
|
onSubmit(form.getValues())
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[open, form, model]
|
[open, form, model, serverConfig]
|
||||||
|
)
|
||||||
|
|
||||||
|
if (status !== "success") {
|
||||||
|
return <></>
|
||||||
|
}
|
||||||
|
|
||||||
|
const modelInfos = serverConfig.modelInfos
|
||||||
|
const plugins = serverConfig.plugins
|
||||||
|
const removeBGEnabled = plugins.some(
|
||||||
|
(plugin) => plugin.name === PluginName.RemoveBG
|
||||||
)
|
)
|
||||||
|
|
||||||
function onOpenChange(value: boolean) {
|
function onOpenChange(value: boolean) {
|
||||||
@ -186,10 +253,6 @@ export function SettingsDialog() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function renderModelSettings() {
|
function renderModelSettings() {
|
||||||
if (status !== "success") {
|
|
||||||
return <></>
|
|
||||||
}
|
|
||||||
|
|
||||||
let defaultTab = MODEL_TYPE_INPAINT
|
let defaultTab = MODEL_TYPE_INPAINT
|
||||||
for (let info of modelInfos) {
|
for (let info of modelInfos) {
|
||||||
if (model.name === info.name) {
|
if (model.name === info.name) {
|
||||||
@ -356,6 +419,44 @@ export function SettingsDialog() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function renderPluginsSettings() {
|
||||||
|
return (
|
||||||
|
<div className="space-y-4 w-[510px]">
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="removeBGModel"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem className="flex items-center justify-between">
|
||||||
|
<div className="space-y-0.5">
|
||||||
|
<FormLabel>Remove Background</FormLabel>
|
||||||
|
<FormDescription>Remove background model</FormDescription>
|
||||||
|
</div>
|
||||||
|
<Select
|
||||||
|
onValueChange={field.onChange}
|
||||||
|
defaultValue={field.value}
|
||||||
|
disabled={!removeBGEnabled}
|
||||||
|
>
|
||||||
|
<FormControl>
|
||||||
|
<SelectTrigger className="w-[200px]">
|
||||||
|
<SelectValue placeholder="Select removebg model" />
|
||||||
|
</SelectTrigger>
|
||||||
|
</FormControl>
|
||||||
|
<SelectContent align="end">
|
||||||
|
<SelectGroup>
|
||||||
|
{serverConfig?.removeBGModels.map((model) => (
|
||||||
|
<SelectItem key={model} value={model}>
|
||||||
|
{model}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectGroup>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
// function renderFileManagerSettings() {
|
// function renderFileManagerSettings() {
|
||||||
// return (
|
// return (
|
||||||
// <div className="flex flex-col justify-between rounded-lg gap-4 w-[400px]">
|
// <div className="flex flex-col justify-between rounded-lg gap-4 w-[400px]">
|
||||||
@ -446,7 +547,9 @@ export function SettingsDialog() {
|
|||||||
<span className="sr-only">Loading...</span>
|
<span className="sr-only">Loading...</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div>Switching to {model.name}</div>
|
{modelSwitchingTexts.map((text, index) => (
|
||||||
|
<div key={index}>{text}</div>
|
||||||
|
))}
|
||||||
</div>
|
</div>
|
||||||
{/* </AlertDialogDescription> */}
|
{/* </AlertDialogDescription> */}
|
||||||
</AlertDialogHeader>
|
</AlertDialogHeader>
|
||||||
@ -473,6 +576,7 @@ export function SettingsDialog() {
|
|||||||
<Button
|
<Button
|
||||||
key={item}
|
key={item}
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
|
disabled={item === TAB_PLUGINS && !removeBGEnabled}
|
||||||
onClick={() => setTab(item)}
|
onClick={() => setTab(item)}
|
||||||
className={cn(
|
className={cn(
|
||||||
tab === item ? "bg-muted " : "hover:bg-muted",
|
tab === item ? "bg-muted " : "hover:bg-muted",
|
||||||
@ -489,6 +593,7 @@ export function SettingsDialog() {
|
|||||||
<form onSubmit={form.handleSubmit(onSubmit)}>
|
<form onSubmit={form.handleSubmit(onSubmit)}>
|
||||||
{tab === TAB_MODEL ? renderModelSettings() : <></>}
|
{tab === TAB_MODEL ? renderModelSettings() : <></>}
|
||||||
{tab === TAB_GENERAL ? renderGeneralSettings() : <></>}
|
{tab === TAB_GENERAL ? renderGeneralSettings() : <></>}
|
||||||
|
{tab === TAB_PLUGINS ? renderPluginsSettings() : <></>}
|
||||||
{/* {tab === TAB_FILE_MANAGER ? (
|
{/* {tab === TAB_FILE_MANAGER ? (
|
||||||
renderFileManagerSettings()
|
renderFileManagerSettings()
|
||||||
) : (
|
) : (
|
||||||
|
@ -12,7 +12,7 @@ export default function useInputImage() {
|
|||||||
fetch(`${API_ENDPOINT}/inputimage`, { headers })
|
fetch(`${API_ENDPOINT}/inputimage`, { headers })
|
||||||
.then(async (res) => {
|
.then(async (res) => {
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
throw new Error("No input image found")
|
return
|
||||||
}
|
}
|
||||||
const filename = res.headers
|
const filename = res.headers
|
||||||
.get("content-disposition")
|
.get("content-disposition")
|
||||||
|
@ -104,15 +104,18 @@ export async function switchModel(name: string): Promise<ModelInfo> {
|
|||||||
return res.data
|
return res.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function switchPluginModel(
|
||||||
|
plugin_name: string,
|
||||||
|
model_name: string
|
||||||
|
) {
|
||||||
|
return api.post(`/switch_plugin_model`, { plugin_name, model_name })
|
||||||
|
}
|
||||||
|
|
||||||
export async function currentModel(): Promise<ModelInfo> {
|
export async function currentModel(): Promise<ModelInfo> {
|
||||||
const res = await api.get("/model")
|
const res = await api.get("/model")
|
||||||
return res.data
|
return res.data
|
||||||
}
|
}
|
||||||
|
|
||||||
export function fetchModelInfos(): Promise<ModelInfo[]> {
|
|
||||||
return api.get("/models").then((response) => response.data)
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function runPlugin(
|
export async function runPlugin(
|
||||||
genMask: boolean,
|
genMask: boolean,
|
||||||
name: string,
|
name: string,
|
||||||
|
@ -14,6 +14,9 @@ export interface PluginInfo {
|
|||||||
|
|
||||||
export interface ServerConfig {
|
export interface ServerConfig {
|
||||||
plugins: PluginInfo[]
|
plugins: PluginInfo[]
|
||||||
|
modelInfos: ModelInfo[]
|
||||||
|
removeBGModel: string
|
||||||
|
removeBGModels: string[]
|
||||||
enableFileManager: boolean
|
enableFileManager: boolean
|
||||||
enableAutoSaving: boolean
|
enableAutoSaving: boolean
|
||||||
enableControlnet: boolean
|
enableControlnet: boolean
|
||||||
|
Loading…
Reference in New Issue
Block a user