make generate mask from RemoveBG && AnimeSeg work

This commit is contained in:
Qing 2024-01-02 22:32:40 +08:00
parent 6253016019
commit aca85543ca
22 changed files with 244 additions and 100 deletions

View File

@ -28,11 +28,13 @@ from lama_cleaner.helper import (
pil_to_bytes, pil_to_bytes,
numpy_to_bytes, numpy_to_bytes,
concat_alpha_channel, concat_alpha_channel,
gen_frontend_mask,
) )
from lama_cleaner.model.utils import torch_gc from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_info import ModelInfo from lama_cleaner.model_info import ModelInfo
from lama_cleaner.model_manager import ModelManager from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg
from lama_cleaner.plugins.base_plugin import BasePlugin
from lama_cleaner.schema import ( from lama_cleaner.schema import (
GenInfoResponse, GenInfoResponse,
ApiConfig, ApiConfig,
@ -41,6 +43,7 @@ from lama_cleaner.schema import (
InpaintRequest, InpaintRequest,
RunPluginRequest, RunPluginRequest,
SDSampler, SDSampler,
PluginInfo,
) )
from lama_cleaner.file_manager import FileManager from lama_cleaner.file_manager import FileManager
@ -145,7 +148,8 @@ class Api:
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo) self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"]) self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"]) self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"]) self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"]) self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets") self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on # fmt: on
@ -173,7 +177,14 @@ class Api:
def api_server_config(self) -> ServerConfigResponse: def api_server_config(self) -> ServerConfigResponse:
return ServerConfigResponse( return ServerConfigResponse(
plugins=list(self.plugins.keys()), plugins=[
PluginInfo(
name=it.name,
support_gen_image=it.support_gen_image,
support_gen_mask=it.support_gen_mask,
)
for it in self.plugins.values()
],
enableFileManager=self.file_manager is not None, enableFileManager=self.file_manager is not None,
enableAutoSaving=self.config.output_dir is not None, enableAutoSaving=self.config.output_dir is not None,
enableControlnet=self.model_manager.enable_controlnet, enableControlnet=self.model_manager.enable_controlnet,
@ -237,22 +248,22 @@ class Api:
headers={"X-Seed": str(req.sd_seed)}, headers={"X-Seed": str(req.sd_seed)},
) )
def api_run_plugin(self, req: RunPluginRequest): def api_run_plugin_gen_image(self, req: RunPluginRequest):
ext = "png" ext = "png"
if req.name not in self.plugins: if req.name not in self.plugins:
raise HTTPException(status_code=404, detail="Plugin not found") raise HTTPException(status_code=422, detail="Plugin not found")
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) if not self.plugins[req.name].support_gen_image:
bgr_np_img = self.plugins[req.name](rgb_np_img, req) raise HTTPException(
torch_gc() status_code=422, detail="Plugin does not support output image"
if req.name == InteractiveSeg.name:
return Response(
content=numpy_to_bytes(bgr_np_img, ext),
media_type=f"image/{ext}",
) )
if bgr_np_img.shape[2] == 4: rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
rgba_np_img = bgr_np_img bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
torch_gc()
if bgr_or_rgba_np_img.shape[2] == 4:
rgba_np_img = bgr_or_rgba_np_img
else: else:
rgba_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB) rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel) rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
return Response( return Response(
@ -265,6 +276,22 @@ class Api:
media_type=f"image/{ext}", media_type=f"image/{ext}",
) )
def api_run_plugin_gen_mask(self, req: RunPluginRequest):
    """Run a mask-generating plugin and return its mask as a PNG response.

    Args:
        req: Plugin request; ``req.name`` selects the plugin and
            ``req.image`` carries the base64-encoded input image.

    Returns:
        A ``Response`` with media type ``image/png`` whose content is the
        plugin mask after conversion by ``gen_frontend_mask``.

    Raises:
        HTTPException: 422 when the plugin is not registered, or when it
            does not declare ``support_gen_mask``.
    """
    if req.name not in self.plugins:
        raise HTTPException(status_code=422, detail="Plugin not found")
    plugin = self.plugins[req.name]
    if not plugin.support_gen_mask:
        # This endpoint produces masks, so the message names the mask
        # capability rather than the image one.
        raise HTTPException(
            status_code=422, detail="Plugin does not support output mask"
        )
    # Only the decoded RGB pixels are used; the alpha channel and image
    # infos returned alongside them are not needed for mask generation.
    rgb_np_img, _, _ = decode_base64_to_image(req.image)
    bgr_or_gray_mask = plugin.gen_mask(rgb_np_img, req)
    torch_gc()
    res_mask = gen_frontend_mask(bgr_or_gray_mask)
    return Response(
        content=numpy_to_bytes(res_mask, "png"),
        media_type="image/png",
    )
def api_samplers(self) -> List[str]: def api_samplers(self) -> List[str]:
return [member.value for member in SDSampler.__members__.values()] return [member.value for member in SDSampler.__members__.values()]
@ -290,7 +317,7 @@ class Api:
) )
return None return None
def _build_plugins(self) -> Dict: def _build_plugins(self) -> Dict[str, BasePlugin]:
return build_plugins( return build_plugins(
self.config.enable_interactive_seg, self.config.enable_interactive_seg,
self.config.interactive_seg_model, self.config.interactive_seg_model,

View File

@ -350,3 +350,23 @@ def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
(rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
) )
return rgb_np_img return rgb_np_img
def gen_frontend_mask(bgr_or_gray_mask, kernel_size: int = 9) -> np.ndarray:
    """Convert a plugin mask into the 4-channel overlay sent to the frontend.

    The mask is reduced to a single channel, dilated once with a square
    kernel, and every pixel brighter than 128 is painted with the frontend
    brush color; all other pixels stay fully zero.

    Args:
        bgr_or_gray_mask: Mask as a 2-D array, an (H, W, 1) array, or a
            multi-channel array that ``cv2.COLOR_BGR2GRAY`` accepts.
        kernel_size: Side length of the square dilation kernel. Defaults to
            9, the value that was previously hard-coded.

    Returns:
        A ``uint8`` array of shape (H, W, 4).
    """
    if bgr_or_gray_mask.ndim == 3:
        if bgr_or_gray_mask.shape[2] == 1:
            # Drop the singleton channel explicitly so the boolean index
            # below is always 2-D, independent of what cv2.dilate returns.
            bgr_or_gray_mask = np.ascontiguousarray(bgr_or_gray_mask[:, :, 0])
        else:
            bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)

    # Grow the mask a little so it also covers the object's border pixels.
    bgr_or_gray_mask = cv2.dilate(
        bgr_or_gray_mask,
        np.ones((kernel_size, kernel_size), np.uint8),
        iterations=1,
    )

    height, width = bgr_or_gray_mask.shape[:2]
    res_mask = np.zeros((height, width, 4), dtype=np.uint8)
    # frontend brush color "ffcc00bb"
    res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
    # Swap channels 0 and 2; the alpha channel is left untouched.
    res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
    return res_mask

View File

@ -1,4 +1,3 @@
from enum import Enum
from typing import List from typing import List
from pydantic import computed_field, BaseModel from pydantic import computed_field, BaseModel

View File

@ -1,11 +1,13 @@
from typing import Dict
from loguru import logger from loguru import logger
from .interactive_seg import InteractiveSeg
from .remove_bg import RemoveBG
from .realesrgan import RealESRGANUpscaler
from .gfpgan_plugin import GFPGANPlugin
from .restoreformer import RestoreFormerPlugin
from .anime_seg import AnimeSeg from .anime_seg import AnimeSeg
from .gfpgan_plugin import GFPGANPlugin
from .interactive_seg import InteractiveSeg
from .realesrgan import RealESRGANUpscaler
from .remove_bg import RemoveBG
from .restoreformer import RestoreFormerPlugin
from ..const import InteractiveSegModel, Device, RealESRGANModel from ..const import InteractiveSegModel, Device, RealESRGANModel
@ -23,7 +25,7 @@ def build_plugins(
enable_restoreformer: bool, enable_restoreformer: bool,
restoreformer_device: Device, restoreformer_device: Device,
no_half: bool, no_half: bool,
): ) -> Dict:
plugins = {} plugins = {}
if enable_interactive_seg: if enable_interactive_seg:
logger.info(f"Initialize {InteractiveSeg.name} plugin") logger.info(f"Initialize {InteractiveSeg.name} plugin")

View File

@ -416,6 +416,8 @@ ANIME_SEG_MODELS = {
class AnimeSeg(BasePlugin): class AnimeSeg(BasePlugin):
# Model from: https://github.com/SkyTNT/anime-segmentation # Model from: https://github.com/SkyTNT/anime-segmentation
name = "AnimeSeg" name = "AnimeSeg"
support_gen_image = True
support_gen_mask = True
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -426,10 +428,19 @@ class AnimeSeg(BasePlugin):
ANIME_SEG_MODELS["md5"], ANIME_SEG_MODELS["md5"],
) )
def __call__(self, rgb_np_img, req: RunPluginRequest): def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
mask = self.forward(rgb_np_img)
mask = Image.fromarray(mask, mode="L")
h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1]
empty = Image.new("RGBA", (w0, h0), 0)
img = Image.fromarray(rgb_np_img)
cutout = Image.composite(img, empty, mask)
return np.asarray(cutout)
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
return self.forward(rgb_np_img) return self.forward(rgb_np_img)
@torch.no_grad() @torch.inference_mode()
def forward(self, rgb_np_img): def forward(self, rgb_np_img):
s = 1024 s = 1024
@ -448,9 +459,4 @@ class AnimeSeg(BasePlugin):
mask = self.model(tmpImg) mask = self.model(tmpImg)
mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0)) mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0))
mask = Image.fromarray((mask * 255).astype("uint8"), mode="L") return (mask * 255).astype("uint8")
empty = Image.new("RGBA", (w0, h0), 0)
img = Image.fromarray(rgb_np_img)
cutout = Image.composite(img, empty, mask)
return np.asarray(cutout)

View File

@ -5,15 +5,23 @@ from lama_cleaner.schema import RunPluginRequest
class BasePlugin: class BasePlugin:
name: str
support_gen_image: bool = False
support_gen_mask: bool = False
def __init__(self): def __init__(self):
err_msg = self.check_dep() err_msg = self.check_dep()
if err_msg: if err_msg:
logger.error(err_msg) logger.error(err_msg)
exit(-1) exit(-1)
def __call__(self, rgb_np_img, req: RunPluginRequest) -> np.array: def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
# return RGBA np image or BGR np image # return RGBA np image or BGR np image
... ...
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
# return GRAY or BGR np image, 255 means foreground, 0 means background
...
def check_dep(self): def check_dep(self):
... ...

View File

@ -1,4 +1,5 @@
import cv2 import cv2
import numpy as np
from loguru import logger from loguru import logger
from lama_cleaner.helper import download_model from lama_cleaner.helper import download_model
@ -8,6 +9,7 @@ from lama_cleaner.schema import RunPluginRequest
class GFPGANPlugin(BasePlugin): class GFPGANPlugin(BasePlugin):
name = "GFPGAN" name = "GFPGAN"
support_gen_image = True
def __init__(self, device, upscaler=None): def __init__(self, device, upscaler=None):
super().__init__() super().__init__()
@ -37,7 +39,7 @@ class GFPGANPlugin(BasePlugin):
self.face_enhancer.face_helper.face_det.to(device) self.face_enhancer.face_helper.face_det.to(device)
) )
def __call__(self, rgb_np_img, req: RunPluginRequest): def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
weight = 0.5 weight = 0.5
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}") logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")

View File

@ -4,6 +4,7 @@ from typing import List
import cv2 import cv2
import numpy as np import numpy as np
import torch
from loguru import logger from loguru import logger
from lama_cleaner.helper import download_model from lama_cleaner.helper import download_model
@ -34,6 +35,7 @@ SEGMENT_ANYTHING_MODELS = {
class InteractiveSeg(BasePlugin): class InteractiveSeg(BasePlugin):
name = "InteractiveSeg" name = "InteractiveSeg"
support_gen_mask = True
def __init__(self, model_name, device): def __init__(self, model_name, device):
super().__init__() super().__init__()
@ -47,10 +49,11 @@ class InteractiveSeg(BasePlugin):
) )
self.prev_img_md5 = None self.prev_img_md5 = None
def __call__(self, rgb_np_img, req: RunPluginRequest): def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest() img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
return self.forward(rgb_np_img, req.clicks, img_md5) return self.forward(rgb_np_img, req.clicks, img_md5)
@torch.inference_mode()
def forward(self, rgb_np_img, clicks: List[List], img_md5: str): def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
input_point = [] input_point = []
input_label = [] input_label = []
@ -70,13 +73,4 @@ class InteractiveSeg(BasePlugin):
multimask_output=False, multimask_output=False,
) )
mask = masks[0].astype(np.uint8) * 255 mask = masks[0].astype(np.uint8) * 255
# TODO: how to set kernel size? return mask
kernel_size = 9
mask = cv2.dilate(
mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1
)
# fronted brush color "ffcc00bb"
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)]
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
return res_mask

View File

@ -1,6 +1,8 @@
from enum import Enum from enum import Enum
import cv2 import cv2
import numpy as np
import torch
from loguru import logger from loguru import logger
from lama_cleaner.const import RealESRGANModel from lama_cleaner.const import RealESRGANModel
@ -11,6 +13,7 @@ from lama_cleaner.schema import RunPluginRequest
class RealESRGANUpscaler(BasePlugin): class RealESRGANUpscaler(BasePlugin):
name = "RealESRGAN" name = "RealESRGAN"
support_gen_image = True
def __init__(self, name, device, no_half=False): def __init__(self, name, device, no_half=False):
super().__init__() super().__init__()
@ -77,13 +80,14 @@ class RealESRGANUpscaler(BasePlugin):
device=device, device=device,
) )
def __call__(self, rgb_np_img, req: RunPluginRequest): def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}") logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
result = self.forward(bgr_np_img, req.scale) result = self.forward(bgr_np_img, req.scale)
logger.info(f"RealESRGAN output shape: {result.shape}") logger.info(f"RealESRGAN output shape: {result.shape}")
return result return result
@torch.inference_mode()
def forward(self, bgr_np_img, scale: float): def forward(self, bgr_np_img, scale: float):
# Output is BGR # Output is BGR
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0] upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]

View File

@ -9,6 +9,8 @@ from lama_cleaner.schema import RunPluginRequest
class RemoveBG(BasePlugin): class RemoveBG(BasePlugin):
name = "RemoveBG" name = "RemoveBG"
support_gen_mask = True
support_gen_image = True
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -20,17 +22,24 @@ class RemoveBG(BasePlugin):
self.session = new_session(model_name="u2net") self.session = new_session(model_name="u2net")
def __call__(self, rgb_np_img, req: RunPluginRequest): def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
return self.forward(bgr_np_img)
def forward(self, bgr_np_img) -> np.ndarray:
from rembg import remove from rembg import remove
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
# return BGRA image # return BGRA image
output = remove(bgr_np_img, session=self.session) output = remove(bgr_np_img, session=self.session)
return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
    """Return the rembg foreground mask for ``rgb_np_img``.

    The image is converted from RGB to BGR channel order and passed to
    ``rembg.remove`` with ``only_mask=True`` on this plugin's session.
    255 means foreground, 0 means background.
    """
    # Imported lazily, like the other rembg usages in this plugin.
    from rembg import remove

    bgr_input = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
    # NOTE(review): the previous comment called the result a BGR image, but
    # the plugin test asserts a 2-D array — confirm the actual shape.
    return remove(bgr_input, session=self.session, only_mask=True)
def check_dep(self): def check_dep(self):
try: try:
import rembg import rembg

View File

@ -1,4 +1,5 @@
import cv2 import cv2
import numpy as np
from loguru import logger from loguru import logger
from lama_cleaner.helper import download_model from lama_cleaner.helper import download_model
@ -8,6 +9,7 @@ from lama_cleaner.schema import RunPluginRequest
class RestoreFormerPlugin(BasePlugin): class RestoreFormerPlugin(BasePlugin):
name = "RestoreFormer" name = "RestoreFormer"
support_gen_image = True
def __init__(self, device, upscaler=None): def __init__(self, device, upscaler=None):
super().__init__() super().__init__()
@ -32,7 +34,7 @@ class RestoreFormerPlugin(BasePlugin):
bg_upsampler=upscaler.model if upscaler is not None else None, bg_upsampler=upscaler.model if upscaler is not None else None,
) )
def __call__(self, rgb_np_img, req: RunPluginRequest): def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
weight = 0.5 weight = 0.5
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}") logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")

View File

@ -3,12 +3,17 @@ from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, Literal, List from typing import Optional, Literal, List
from PIL.Image import Image from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, validator, field_validator
from lama_cleaner.const import Device, InteractiveSegModel, RealESRGANModel from lama_cleaner.const import Device, InteractiveSegModel, RealESRGANModel
# Describes one plugin for the server-config response: its name plus flags
# for the kinds of output it can produce. Both flags default to False.
class PluginInfo(BaseModel):
    # Plugin name.
    name: str
    # True when the plugin can produce an image result.
    support_gen_image: bool = False
    # True when the plugin can produce a mask result.
    support_gen_mask: bool = False
class CV2Flag(str, Enum): class CV2Flag(str, Enum):
INPAINT_NS = "INPAINT_NS" INPAINT_NS = "INPAINT_NS"
INPAINT_TELEA = "INPAINT_TELEA" INPAINT_TELEA = "INPAINT_TELEA"
@ -272,7 +277,7 @@ class GenInfoResponse(BaseModel):
class ServerConfigResponse(BaseModel): class ServerConfigResponse(BaseModel):
plugins: List[str] plugins: List[PluginInfo]
enableFileManager: bool enableFileManager: bool
enableAutoSaving: bool enableAutoSaving: bool
enableControlnet: bool enableControlnet: bool

View File

@ -3,7 +3,7 @@ import os
import time import time
from PIL import Image from PIL import Image
from lama_cleaner.helper import encode_pil_to_base64 from lama_cleaner.helper import encode_pil_to_base64, gen_frontend_mask
from lama_cleaner.plugins.anime_seg import AnimeSeg from lama_cleaner.plugins.anime_seg import AnimeSeg
from lama_cleaner.schema import RunPluginRequest from lama_cleaner.schema import RunPluginRequest
from lama_cleaner.tests.utils import check_device, current_dir, save_dir from lama_cleaner.tests.utils import check_device, current_dir, save_dir
@ -35,34 +35,48 @@ def _save(img, name):
def test_remove_bg(): def test_remove_bg():
model = RemoveBG() model = RemoveBG()
rgba_np_img = model( rgba_np_img = model.gen_image(
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
) )
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA) res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
_save(res, "test_remove_bg.png") _save(res, "test_remove_bg.png")
bgr_np_img = model.gen_mask(
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
)
res_mask = gen_frontend_mask(bgr_np_img)
_save(res_mask, "test_remove_bg_frontend_mask.png")
assert len(bgr_np_img.shape) == 2
_save(bgr_np_img, "test_remove_bg_mask.jpeg")
def test_anime_seg(): def test_anime_seg():
model = AnimeSeg() model = AnimeSeg()
img = cv2.imread(str(current_dir / "anime_test.png")) img = cv2.imread(str(current_dir / "anime_test.png"))
img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {}) img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {})
res = model(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64)) res = model.gen_image(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
assert len(res.shape) == 3 assert len(res.shape) == 3
assert res.shape[-1] == 4 assert res.shape[-1] == 4
_save(res, "test_anime_seg.png") _save(res, "test_anime_seg.png")
res = model.gen_mask(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
assert len(res.shape) == 2
_save(res, "test_anime_seg_mask.png")
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_upscale(device): def test_upscale(device):
check_device(device) check_device(device)
model = RealESRGANUpscaler("realesr-general-x4v3", device) model = RealESRGANUpscaler("realesr-general-x4v3", device)
res = model( res = model.gen_image(
rgb_img, rgb_img,
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2), RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2),
) )
_save(res, f"test_upscale_x2_{device}.png") _save(res, f"test_upscale_x2_{device}.png")
res = model( res = model.gen_image(
rgb_img, rgb_img,
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4), RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4),
) )
@ -73,7 +87,9 @@ def test_upscale(device):
def test_gfpgan(device): def test_gfpgan(device):
check_device(device) check_device(device)
model = GFPGANPlugin(device) model = GFPGANPlugin(device)
res = model(rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)) res = model.gen_image(
rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)
)
_save(res, f"test_gfpgan_{device}.png") _save(res, f"test_gfpgan_{device}.png")
@ -81,7 +97,7 @@ def test_gfpgan(device):
def test_restoreformer(device): def test_restoreformer(device):
check_device(device) check_device(device)
model = RestoreFormerPlugin(device) model = RestoreFormerPlugin(device)
res = model( res = model.gen_image(
rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64) rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64)
) )
_save(res, f"test_restoreformer_{device}.png") _save(res, f"test_restoreformer_{device}.png")
@ -91,7 +107,7 @@ def test_restoreformer(device):
def test_segment_anything(device): def test_segment_anything(device):
check_device(device) check_device(device)
model = InteractiveSeg("vit_l", device) model = InteractiveSeg("vit_l", device)
new_mask = model( new_mask = model.gen_mask(
rgb_img, rgb_img,
RunPluginRequest( RunPluginRequest(
name=InteractiveSeg.name, name=InteractiveSeg.name,

View File

@ -3,7 +3,7 @@
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Lama Cleaner</title> <title>IOPaint</title>
</head> </head>
<body> <body>
<div id="root"></div> <div id="root"></div>

View File

@ -272,7 +272,7 @@ export default function Editor(props: EditorProps) {
console.log("[useEffect] centerView") console.log("[useEffect] centerView")
// After a render changes the size, undo/redo re-centers the view // After a render changes the size, undo/redo re-centers the view
viewportRef?.current?.centerView(minScale, 1) viewportRef?.current?.centerView(minScale, 1)
}, [context?.canvas.height, context?.canvas.width, viewportRef, minScale]) }, [imageHeight, imageWidth, viewportRef, minScale])
// Zoom reset // Zoom reset
const resetZoom = useCallback(() => { const resetZoom = useCallback(() => {
@ -358,6 +358,7 @@ export default function Editor(props: EditorProps) {
const targetFile = await getCurrentRender() const targetFile = await getCurrentRender()
try { try {
const res = await runPlugin( const res = await runPlugin(
true,
PluginName.InteractiveSeg, PluginName.InteractiveSeg,
targetFile, targetFile,
undefined, undefined,

View File

@ -16,6 +16,7 @@ import {
Smile, Smile,
} from "lucide-react" } from "lucide-react"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import { PluginInfo } from "@/lib/types"
export enum PluginName { export enum PluginName {
RemoveBG = "RemoveBG", RemoveBG = "RemoveBG",
@ -26,6 +27,7 @@ export enum PluginName {
InteractiveSeg = "InteractiveSeg", InteractiveSeg = "InteractiveSeg",
} }
// TODO: get plugin config from the server and use form-render?
const pluginMap = { const pluginMap = {
[PluginName.RemoveBG]: { [PluginName.RemoveBG]: {
IconClass: Slice, IconClass: Slice,
@ -37,7 +39,7 @@ const pluginMap = {
}, },
[PluginName.RealESRGAN]: { [PluginName.RealESRGAN]: {
IconClass: Fullscreen, IconClass: Fullscreen,
showName: "RealESRGAN 4x", showName: "RealESRGAN",
}, },
[PluginName.GFPGAN]: { [PluginName.GFPGAN]: {
IconClass: Smile, IconClass: Smile,
@ -67,11 +69,11 @@ const Plugins = () => {
return null return null
} }
const onPluginClick = (pluginName: string) => { const onPluginClick = (genMask: boolean, pluginName: string) => {
if (pluginName === PluginName.InteractiveSeg) { if (pluginName === PluginName.InteractiveSeg) {
updateInteractiveSegState({ isInteractiveSeg: true }) updateInteractiveSegState({ isInteractiveSeg: true })
} else { } else {
runRenderablePlugin(pluginName) runRenderablePlugin(genMask, pluginName)
} }
} }
@ -87,14 +89,14 @@ const Plugins = () => {
<DropdownMenuSubContent> <DropdownMenuSubContent>
<DropdownMenuItem <DropdownMenuItem
onClick={() => onClick={() =>
runRenderablePlugin(PluginName.RealESRGAN, { upscale: 2 }) runRenderablePlugin(false, PluginName.RealESRGAN, { upscale: 2 })
} }
> >
upscale 2x upscale 2x
</DropdownMenuItem> </DropdownMenuItem>
<DropdownMenuItem <DropdownMenuItem
onClick={() => onClick={() =>
runRenderablePlugin(PluginName.RealESRGAN, { upscale: 4 }) runRenderablePlugin(false, PluginName.RealESRGAN, { upscale: 4 })
} }
> >
upscale 4x upscale 4x
@ -104,16 +106,44 @@ const Plugins = () => {
) )
} }
// Renders a submenu for a plugin that offers two actions: the first entry
// calls onPluginClick with genMask=false, the second with genMask=true.
const renderGenImageAndMaskPlugin = (plugin: PluginInfo) => {
  // Icon and display name come from the static pluginMap, keyed by plugin name.
  const { IconClass, showName } = pluginMap[plugin.name as PluginName]
  return (
    <DropdownMenuSub key={plugin.name}>
      <DropdownMenuSubTrigger disabled={disabled}>
        <div className="flex gap-2 items-center">
          <IconClass className="p-1" />
          {showName}
        </div>
      </DropdownMenuSubTrigger>
      <DropdownMenuSubContent>
        <DropdownMenuItem onClick={() => onPluginClick(false, plugin.name)}>
          Remove Background
        </DropdownMenuItem>
        <DropdownMenuItem onClick={() => onPluginClick(true, plugin.name)}>
          Generate Mask
        </DropdownMenuItem>
      </DropdownMenuSubContent>
    </DropdownMenuSub>
  )
}
const renderPlugins = () => { const renderPlugins = () => {
return plugins.map((plugin: string) => { return plugins.map((plugin: PluginInfo) => {
const { IconClass, showName } = pluginMap[plugin as PluginName] const { IconClass, showName } = pluginMap[plugin.name as PluginName]
if (plugin === PluginName.RealESRGAN) { if (plugin.name === PluginName.RealESRGAN) {
return renderRealESRGANPlugin() return renderRealESRGANPlugin()
} }
if (
plugin.name === PluginName.RemoveBG ||
plugin.name === PluginName.AnimeSeg
) {
return renderGenImageAndMaskPlugin(plugin)
}
return ( return (
<DropdownMenuItem <DropdownMenuItem
key={plugin} key={plugin.name}
onClick={() => onPluginClick(plugin)} onClick={() => onPluginClick(false, plugin.name)}
disabled={disabled} disabled={disabled}
> >
<div className="flex gap-2 items-center"> <div className="flex gap-2 items-center">

View File

@ -64,11 +64,6 @@ export function Shortcuts() {
<ShortCut content="Decrease Brush Size" keys={["["]} /> <ShortCut content="Decrease Brush Size" keys={["["]} />
<ShortCut content="Increase Brush Size" keys={["]"]} /> <ShortCut content="Increase Brush Size" keys={["]"]} />
<ShortCut content="View Original Image" keys={["Hold Tab"]} /> <ShortCut content="View Original Image" keys={["Hold Tab"]} />
<ShortCut
content="Multi-Stroke Drawing"
keys={[`Hold ${CmdOrCtrl()}`]}
/>
<ShortCut content="Cancel Drawing" keys={["Esc"]} />
<ShortCut content="Undo" keys={[CmdOrCtrl(), "Z"]} /> <ShortCut content="Undo" keys={[CmdOrCtrl(), "Z"]} />
<ShortCut content="Redo" keys={[CmdOrCtrl(), "Shift", "Z"]} /> <ShortCut content="Redo" keys={[CmdOrCtrl(), "Shift", "Z"]} />

View File

@ -658,13 +658,13 @@ const DiffusionOptions = () => {
updateSettings({ sdSampler: value }) updateSettings({ sdSampler: value })
}} }}
> >
<SelectTrigger className="w-[180px]"> <SelectTrigger className="w-[175px] text-xs">
<SelectValue placeholder="Select sampler" /> <SelectValue placeholder="Select sampler" />
</SelectTrigger> </SelectTrigger>
<SelectContent align="end"> <SelectContent align="end">
<SelectGroup> <SelectGroup>
{samplers.map((sampler) => ( {samplers.map((sampler) => (
<SelectItem key={sampler} value={sampler}> <SelectItem key={sampler} value={sampler} className="text-xs">
{sampler} {sampler}
</SelectItem> </SelectItem>
))} ))}

View File

@ -17,6 +17,7 @@ const ToastViewport = React.forwardRef<
"fixed top-0 z-[100] flex max-h-screen w-full flex-col-reverse p-4 sm:bottom-0 sm:right-0 sm:top-auto sm:flex-col md:max-w-[420px]", "fixed top-0 z-[100] flex max-h-screen w-full flex-col-reverse p-4 sm:bottom-0 sm:right-0 sm:top-auto sm:flex-col md:max-w-[420px]",
className className
)} )}
tabIndex={-1}
{...props} {...props}
/> />
)) ))
@ -47,6 +48,7 @@ const Toast = React.forwardRef<
<ToastPrimitives.Root <ToastPrimitives.Root
ref={ref} ref={ref}
className={cn(toastVariants({ variant }), className)} className={cn(toastVariants({ variant }), className)}
tabIndex={-1}
{...props} {...props}
/> />
) )
@ -63,6 +65,7 @@ const ToastAction = React.forwardRef<
"inline-flex h-8 shrink-0 items-center justify-center rounded-md border bg-transparent px-3 text-sm font-medium transition-colors hover:bg-secondary focus:outline-none focus:ring-1 focus:ring-ring disabled:pointer-events-none disabled:opacity-50 group-[.destructive]:border-muted/40 group-[.destructive]:hover:border-destructive/30 group-[.destructive]:hover:bg-destructive group-[.destructive]:hover:text-destructive-foreground group-[.destructive]:focus:ring-destructive", "inline-flex h-8 shrink-0 items-center justify-center rounded-md border bg-transparent px-3 text-sm font-medium transition-colors hover:bg-secondary focus:outline-none focus:ring-1 focus:ring-ring disabled:pointer-events-none disabled:opacity-50 group-[.destructive]:border-muted/40 group-[.destructive]:hover:border-destructive/30 group-[.destructive]:hover:bg-destructive group-[.destructive]:hover:text-destructive-foreground group-[.destructive]:focus:ring-destructive",
className className
)} )}
tabIndex={-1}
{...props} {...props}
/> />
)) ))
@ -79,6 +82,7 @@ const ToastClose = React.forwardRef<
className className
)} )}
toast-close="" toast-close=""
tabIndex={-1}
{...props} {...props}
> >
<Cross2Icon className="h-4 w-4" /> <Cross2Icon className="h-4 w-4" />
@ -93,6 +97,7 @@ const ToastTitle = React.forwardRef<
<ToastPrimitives.Title <ToastPrimitives.Title
ref={ref} ref={ref}
className={cn("text-sm font-semibold [&+div]:text-xs", className)} className={cn("text-sm font-semibold [&+div]:text-xs", className)}
tabIndex={-1}
{...props} {...props}
/> />
)) ))
@ -106,6 +111,7 @@ const ToastDescription = React.forwardRef<
ref={ref} ref={ref}
className={cn("text-sm opacity-90", className)} className={cn("text-sm opacity-90", className)}
{...props} {...props}
tabIndex={-1}
/> />
)) ))
ToastDescription.displayName = ToastPrimitives.Description.displayName ToastDescription.displayName = ToastPrimitives.Description.displayName

View File

@ -114,13 +114,15 @@ export function fetchModelInfos(): Promise<ModelInfo[]> {
} }
export async function runPlugin( export async function runPlugin(
genMask: boolean,
name: string, name: string,
imageFile: File, imageFile: File,
upscale?: number, upscale?: number,
clicks?: number[][] clicks?: number[][]
) { ) {
const imageBase64 = await convertToBase64(imageFile) const imageBase64 = await convertToBase64(imageFile)
const res = await fetch(`${API_ENDPOINT}/run_plugin`, { const p = genMask ? "run_plugin_gen_mask" : "run_plugin_gen_image"
const res = await fetch(`${API_ENDPOINT}/${p}`, {
method: "POST", method: "POST",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",

View File

@ -156,7 +156,6 @@ type AppAction = {
setFile: (file: File) => Promise<void> setFile: (file: File) => Promise<void>
setCustomFile: (file: File) => void setCustomFile: (file: File) => void
setIsInpainting: (newValue: boolean) => void setIsInpainting: (newValue: boolean) => void
setIsPluginRunning: (newValue: boolean) => void
getIsProcessing: () => boolean getIsProcessing: () => boolean
setBaseBrushSize: (newValue: number) => void setBaseBrushSize: (newValue: number) => void
getBrushSize: () => number getBrushSize: () => number
@ -190,6 +189,7 @@ type AppAction = {
showPrevMask: () => Promise<void> showPrevMask: () => Promise<void>
hidePrevMask: () => void hidePrevMask: () => void
runRenderablePlugin: ( runRenderablePlugin: (
genMask: boolean,
pluginName: string, pluginName: string,
params?: PluginParams params?: PluginParams
) => Promise<void> ) => Promise<void>
@ -521,28 +521,43 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
}, },
runRenderablePlugin: async ( runRenderablePlugin: async (
genMask: boolean,
pluginName: string, pluginName: string,
params: PluginParams = { upscale: 1 } params: PluginParams = { upscale: 1 }
) => { ) => {
const { renders, lineGroups } = get().editorState const { renders, lineGroups } = get().editorState
set((state) => { set((state) => {
state.isInpainting = true state.isPluginRunning = true
}) })
try { try {
const start = new Date() const start = new Date()
const targetFile = await get().getCurrentTargetFile() const targetFile = await get().getCurrentTargetFile()
const res = await runPlugin(pluginName, targetFile, params.upscale) const res = await runPlugin(
genMask,
pluginName,
targetFile,
params.upscale
)
const { blob } = res const { blob } = res
const newRender = new Image()
await loadImage(newRender, blob) if (!genMask) {
get().setImageSize(newRender.width, newRender.height) const newRender = new Image()
const newRenders = [...renders, newRender] await loadImage(newRender, blob)
const newLineGroups = [...lineGroups, []] get().setImageSize(newRender.width, newRender.height)
get().updateEditorState({ const newRenders = [...renders, newRender]
renders: newRenders, const newLineGroups = [...lineGroups, []]
lineGroups: newLineGroups, get().updateEditorState({
}) renders: newRenders,
lineGroups: newLineGroups,
})
} else {
const newMask = new Image()
await loadImage(newMask, blob)
get().updateInteractiveSegState({
interactiveSegMask: newMask,
})
}
const end = new Date() const end = new Date()
const time = end.getTime() - start.getTime() const time = end.getTime() - start.getTime()
toast({ toast({
@ -555,7 +570,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
}) })
} }
set((state) => { set((state) => {
state.isInpainting = false state.isPluginRunning = false
}) })
}, },
@ -803,11 +818,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
state.isInpainting = newValue state.isInpainting = newValue
}), }),
setIsPluginRunning: (newValue: boolean) =>
set((state) => {
state.isPluginRunning = newValue
}),
setFile: async (file: File) => { setFile: async (file: File) => {
if (get().settings.enableAutoExtractPrompt) { if (get().settings.enableAutoExtractPrompt) {
try { try {

View File

@ -6,8 +6,14 @@ export interface Filename {
mtime: number mtime: number
} }
export interface PluginInfo {
name: string
support_gen_image: boolean
support_gen_mask: boolean
}
export interface ServerConfig { export interface ServerConfig {
plugins: string[] plugins: PluginInfo[]
enableFileManager: boolean enableFileManager: boolean
enableAutoSaving: boolean enableAutoSaving: boolean
enableControlnet: boolean enableControlnet: boolean