diff --git a/lama_cleaner/api.py b/lama_cleaner/api.py index b5e180c..a559326 100644 --- a/lama_cleaner/api.py +++ b/lama_cleaner/api.py @@ -28,11 +28,13 @@ from lama_cleaner.helper import ( pil_to_bytes, numpy_to_bytes, concat_alpha_channel, + gen_frontend_mask, ) from lama_cleaner.model.utils import torch_gc from lama_cleaner.model_info import ModelInfo from lama_cleaner.model_manager import ModelManager from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg +from lama_cleaner.plugins.base_plugin import BasePlugin from lama_cleaner.schema import ( GenInfoResponse, ApiConfig, @@ -41,6 +43,7 @@ from lama_cleaner.schema import ( InpaintRequest, RunPluginRequest, SDSampler, + PluginInfo, ) from lama_cleaner.file_manager import FileManager @@ -145,7 +148,8 @@ class Api: self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo) self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"]) self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"]) - self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"]) + self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"]) + self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"]) self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"]) self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets") # fmt: on @@ -173,7 +177,14 @@ class Api: def api_server_config(self) -> ServerConfigResponse: return ServerConfigResponse( - plugins=list(self.plugins.keys()), + plugins=[ + PluginInfo( + name=it.name, + support_gen_image=it.support_gen_image, + support_gen_mask=it.support_gen_mask, + ) + for it in self.plugins.values() + ], enableFileManager=self.file_manager is not None, enableAutoSaving=self.config.output_dir is not None, enableControlnet=self.model_manager.enable_controlnet, @@ -237,22 +248,22 @@ class Api: headers={"X-Seed": str(req.sd_seed)}, ) - def api_run_plugin(self, req: RunPluginRequest): + def api_run_plugin_gen_image(self, req: RunPluginRequest): ext = "png" if req.name not in self.plugins: - raise HTTPException(status_code=404, detail="Plugin not found") - rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) - bgr_np_img = self.plugins[req.name](rgb_np_img, req) - torch_gc() - if req.name == InteractiveSeg.name: - return Response( - content=numpy_to_bytes(bgr_np_img, ext), - media_type=f"image/{ext}", + raise HTTPException(status_code=422, detail="Plugin not found") + if not self.plugins[req.name].support_gen_image: + raise HTTPException( + status_code=422, detail="Plugin does not support output image" ) - if bgr_np_img.shape[2] == 4: - rgba_np_img = bgr_np_img + rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) + bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req) + torch_gc() + + if bgr_or_rgba_np_img.shape[2] == 4: + rgba_np_img = bgr_or_rgba_np_img else: - rgba_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB) + rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB) rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel) return Response( @@ -265,6 +276,22 @@ class Api: media_type=f"image/{ext}", ) + def api_run_plugin_gen_mask(self, req: RunPluginRequest): + if req.name not in self.plugins: + raise HTTPException(status_code=422, detail="Plugin not found") + if not self.plugins[req.name].support_gen_mask: + raise HTTPException( + status_code=422, detail="Plugin does not support output image" + ) + rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) + bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req) + torch_gc() + res_mask = gen_frontend_mask(bgr_or_gray_mask) + return Response( + content=numpy_to_bytes(res_mask, "png"), + media_type="image/png", + ) + def api_samplers(self) -> List[str]: return [member.value for member in SDSampler.__members__.values()] @@ -290,7 +317,7 @@ class Api: ) return None - def _build_plugins(self) -> Dict: + def _build_plugins(self) -> Dict[str, BasePlugin]: return build_plugins( self.config.enable_interactive_seg, self.config.interactive_seg_model, diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 715ae51..00acc19 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -350,3 +350,23 @@ def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray: (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 ) return rgb_np_img + + +def gen_frontend_mask(bgr_or_gray_mask): + if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1: + bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY) + + # fronted brush color "ffcc00bb" + # TODO: how to set kernel size? + kernel_size = 9 + bgr_or_gray_mask = cv2.dilate( + bgr_or_gray_mask, + np.ones((kernel_size, kernel_size), np.uint8), + iterations=1, + ) + res_mask = np.zeros( + (bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8 + ) + res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)] + res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA) + return res_mask diff --git a/lama_cleaner/model_info.py b/lama_cleaner/model_info.py index 199018b..d799b83 100644 --- a/lama_cleaner/model_info.py +++ b/lama_cleaner/model_info.py @@ -1,4 +1,3 @@ -from enum import Enum from typing import List from pydantic import computed_field, BaseModel diff --git a/lama_cleaner/plugins/__init__.py b/lama_cleaner/plugins/__init__.py index 69ee2b5..fd321cd 100644 --- a/lama_cleaner/plugins/__init__.py +++ b/lama_cleaner/plugins/__init__.py @@ -1,11 +1,13 @@ +from typing import Dict + from loguru import logger -from .interactive_seg import InteractiveSeg -from .remove_bg import RemoveBG -from .realesrgan import RealESRGANUpscaler -from .gfpgan_plugin import GFPGANPlugin -from .restoreformer import RestoreFormerPlugin from .anime_seg import AnimeSeg +from .gfpgan_plugin import GFPGANPlugin +from .interactive_seg import InteractiveSeg +from .realesrgan import RealESRGANUpscaler +from .remove_bg import RemoveBG +from .restoreformer import RestoreFormerPlugin from ..const import InteractiveSegModel, Device, RealESRGANModel @@ -23,7 +25,7 @@ def build_plugins( enable_restoreformer: bool, restoreformer_device: Device, no_half: bool, -): +) -> Dict: plugins = {} if enable_interactive_seg: logger.info(f"Initialize {InteractiveSeg.name} plugin") diff --git a/lama_cleaner/plugins/anime_seg.py b/lama_cleaner/plugins/anime_seg.py index bf6bbf2..b4bb8b1 100644 --- a/lama_cleaner/plugins/anime_seg.py +++ b/lama_cleaner/plugins/anime_seg.py @@ -416,6 +416,8 @@ ANIME_SEG_MODELS = { class AnimeSeg(BasePlugin): # Model from: https://github.com/SkyTNT/anime-segmentation name = "AnimeSeg" + support_gen_image = True + support_gen_mask = True def __init__(self): super().__init__() @@ -426,10 +428,19 @@ class AnimeSeg(BasePlugin): ANIME_SEG_MODELS["md5"], ) - def __call__(self, rgb_np_img, req: RunPluginRequest): + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + mask = self.forward(rgb_np_img) + mask = Image.fromarray(mask, mode="L") + h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1] + empty = Image.new("RGBA", (w0, h0), 0) + img = Image.fromarray(rgb_np_img) + cutout = Image.composite(img, empty, mask) + return np.asarray(cutout) + + def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: return self.forward(rgb_np_img) - @torch.no_grad() + @torch.inference_mode() def forward(self, rgb_np_img): s = 1024 @@ -448,9 +459,4 @@ class AnimeSeg(BasePlugin): mask = self.model(tmpImg) mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0)) - mask = Image.fromarray((mask * 255).astype("uint8"), mode="L") - - empty = Image.new("RGBA", (w0, h0), 0) - img = Image.fromarray(rgb_np_img) - cutout = Image.composite(img, empty, mask) - return np.asarray(cutout) + return (mask * 255).astype("uint8") diff --git a/lama_cleaner/plugins/base_plugin.py b/lama_cleaner/plugins/base_plugin.py index f57e0fc..fecc613 100644 --- a/lama_cleaner/plugins/base_plugin.py +++ b/lama_cleaner/plugins/base_plugin.py @@ -5,15 +5,23 @@ from lama_cleaner.schema import RunPluginRequest class BasePlugin: + name: str + support_gen_image: bool = False + support_gen_mask: bool = False + def __init__(self): err_msg = self.check_dep() if err_msg: logger.error(err_msg) exit(-1) - def __call__(self, rgb_np_img, req: RunPluginRequest) -> np.array: + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: # return RGBA np image or BGR np image ... + def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + # return GRAY or BGR np image, 255 means foreground, 0 means background + ... + def check_dep(self): ... diff --git a/lama_cleaner/plugins/gfpgan_plugin.py b/lama_cleaner/plugins/gfpgan_plugin.py index 635bcf4..a94fba3 100644 --- a/lama_cleaner/plugins/gfpgan_plugin.py +++ b/lama_cleaner/plugins/gfpgan_plugin.py @@ -1,4 +1,5 @@ import cv2 +import numpy as np from loguru import logger from lama_cleaner.helper import download_model @@ -8,6 +9,7 @@ from lama_cleaner.schema import RunPluginRequest class GFPGANPlugin(BasePlugin): name = "GFPGAN" + support_gen_image = True def __init__(self, device, upscaler=None): super().__init__() @@ -37,7 +39,7 @@ class GFPGANPlugin(BasePlugin): self.face_enhancer.face_helper.face_det.to(device) ) - def __call__(self, rgb_np_img, req: RunPluginRequest): + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: weight = 0.5 bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) logger.info(f"GFPGAN input shape: {bgr_np_img.shape}") diff --git a/lama_cleaner/plugins/interactive_seg.py b/lama_cleaner/plugins/interactive_seg.py index 7df184d..6bb1899 100644 --- a/lama_cleaner/plugins/interactive_seg.py +++ b/lama_cleaner/plugins/interactive_seg.py @@ -4,6 +4,7 @@ from typing import List import cv2 import numpy as np +import torch from loguru import logger from lama_cleaner.helper import download_model @@ -34,6 +35,7 @@ SEGMENT_ANYTHING_MODELS = { class InteractiveSeg(BasePlugin): name = "InteractiveSeg" + support_gen_mask = True def __init__(self, model_name, device): super().__init__() @@ -47,10 +49,11 @@ class InteractiveSeg(BasePlugin): ) self.prev_img_md5 = None - def __call__(self, rgb_np_img, req: RunPluginRequest): + def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest() return self.forward(rgb_np_img, req.clicks, img_md5) + @torch.inference_mode() def forward(self, rgb_np_img, clicks: List[List], img_md5: str): input_point = [] input_label = [] @@ -70,13 +73,4 @@ class InteractiveSeg(BasePlugin): multimask_output=False, ) mask = masks[0].astype(np.uint8) * 255 - # TODO: how to set kernel size? - kernel_size = 9 - mask = cv2.dilate( - mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1 - ) - # fronted brush color "ffcc00bb" - res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8) - res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)] - res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA) - return res_mask + return mask diff --git a/lama_cleaner/plugins/realesrgan.py b/lama_cleaner/plugins/realesrgan.py index 3c796b3..d17942b 100644 --- a/lama_cleaner/plugins/realesrgan.py +++ b/lama_cleaner/plugins/realesrgan.py @@ -1,6 +1,8 @@ from enum import Enum import cv2 +import numpy as np +import torch from loguru import logger from lama_cleaner.const import RealESRGANModel @@ -11,6 +13,7 @@ from lama_cleaner.schema import RunPluginRequest class RealESRGANUpscaler(BasePlugin): name = "RealESRGAN" + support_gen_image = True def __init__(self, name, device, no_half=False): super().__init__() @@ -77,13 +80,14 @@ class RealESRGANUpscaler(BasePlugin): device=device, ) - def __call__(self, rgb_np_img, req: RunPluginRequest): + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}") result = self.forward(bgr_np_img, req.scale) logger.info(f"RealESRGAN output shape: {result.shape}") return result + @torch.inference_mode() def forward(self, bgr_np_img, scale: float): # 输出是 BGR upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0] diff --git a/lama_cleaner/plugins/remove_bg.py b/lama_cleaner/plugins/remove_bg.py index 31fc565..a462af5 100644 --- a/lama_cleaner/plugins/remove_bg.py +++ b/lama_cleaner/plugins/remove_bg.py @@ -9,6 +9,8 @@ from lama_cleaner.schema import RunPluginRequest class RemoveBG(BasePlugin): name = "RemoveBG" + support_gen_mask = True + support_gen_image = True def __init__(self): super().__init__() @@ -20,17 +22,24 @@ class RemoveBG(BasePlugin): self.session = new_session(model_name="u2net") - def __call__(self, rgb_np_img, req: RunPluginRequest): - bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) - return self.forward(bgr_np_img) - - def forward(self, bgr_np_img) -> np.ndarray: + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: from rembg import remove + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + # return BGRA image output = remove(bgr_np_img, session=self.session) return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) + def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + from rembg import remove + + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + + # return BGR image, 255 means foreground, 0 means background + output = remove(bgr_np_img, session=self.session, only_mask=True) + return output + def check_dep(self): try: import rembg diff --git a/lama_cleaner/plugins/restoreformer.py b/lama_cleaner/plugins/restoreformer.py index 3adc441..39592af 100644 --- a/lama_cleaner/plugins/restoreformer.py +++ b/lama_cleaner/plugins/restoreformer.py @@ -1,4 +1,5 @@ import cv2 +import numpy as np from loguru import logger from lama_cleaner.helper import download_model @@ -8,6 +9,7 @@ from lama_cleaner.schema import RunPluginRequest class RestoreFormerPlugin(BasePlugin): name = "RestoreFormer" + support_gen_image = True def __init__(self, device, upscaler=None): super().__init__() @@ -32,7 +34,7 @@ class RestoreFormerPlugin(BasePlugin): bg_upsampler=upscaler.model if upscaler is not None else None, ) - def __call__(self, rgb_np_img, req: RunPluginRequest): + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: weight = 0.5 bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}") diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index c347796..626e897 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -3,12 +3,17 @@ from enum import Enum from pathlib import Path from typing import Optional, Literal, List -from PIL.Image import Image -from pydantic import BaseModel, Field, validator, field_validator +from pydantic import BaseModel, Field, field_validator from lama_cleaner.const import Device, InteractiveSegModel, RealESRGANModel +class PluginInfo(BaseModel): + name: str + support_gen_image: bool = False + support_gen_mask: bool = False + + class CV2Flag(str, Enum): INPAINT_NS = "INPAINT_NS" INPAINT_TELEA = "INPAINT_TELEA" @@ -272,7 +277,7 @@ class GenInfoResponse(BaseModel): class ServerConfigResponse(BaseModel): - plugins: List[str] + plugins: List[PluginInfo] enableFileManager: bool enableAutoSaving: bool enableControlnet: bool diff --git a/lama_cleaner/tests/test_plugins.py b/lama_cleaner/tests/test_plugins.py index 721ad51..8f7e8ea 100644 --- a/lama_cleaner/tests/test_plugins.py +++ b/lama_cleaner/tests/test_plugins.py @@ -3,7 +3,7 @@ import os import time from PIL import Image -from lama_cleaner.helper import encode_pil_to_base64 +from lama_cleaner.helper import encode_pil_to_base64, gen_frontend_mask from lama_cleaner.plugins.anime_seg import AnimeSeg from lama_cleaner.schema import RunPluginRequest from lama_cleaner.tests.utils import check_device, current_dir, save_dir @@ -35,34 +35,48 @@ def _save(img, name): def test_remove_bg(): model = RemoveBG() - rgba_np_img = model( + rgba_np_img = model.gen_image( rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) ) res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA) _save(res, "test_remove_bg.png") + bgr_np_img = model.gen_mask( + rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) + ) + + res_mask = gen_frontend_mask(bgr_np_img) + _save(res_mask, "test_remove_bg_frontend_mask.png") + + assert len(bgr_np_img.shape) == 2 + _save(bgr_np_img, "test_remove_bg_mask.jpeg") + def test_anime_seg(): model = AnimeSeg() img = cv2.imread(str(current_dir / "anime_test.png")) img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {}) - res = model(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64)) + res = model.gen_image(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64)) assert len(res.shape) == 3 assert res.shape[-1] == 4 _save(res, "test_anime_seg.png") + res = model.gen_mask(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64)) + assert len(res.shape) == 2 + _save(res, "test_anime_seg_mask.png") + @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) def test_upscale(device): check_device(device) model = RealESRGANUpscaler("realesr-general-x4v3", device) - res = model( + res = model.gen_image( rgb_img, RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2), ) _save(res, f"test_upscale_x2_{device}.png") - res = model( + res = model.gen_image( rgb_img, RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4), ) @@ -73,7 +87,9 @@ def test_upscale(device): def test_gfpgan(device): check_device(device) model = GFPGANPlugin(device) - res = model(rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)) + res = model.gen_image( + rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64) + ) _save(res, f"test_gfpgan_{device}.png") @@ -81,7 +97,7 @@ def test_gfpgan(device): def test_restoreformer(device): check_device(device) model = RestoreFormerPlugin(device) - res = model( + res = model.gen_image( rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64) ) _save(res, f"test_restoreformer_{device}.png") @@ -91,7 +107,7 @@ def test_restoreformer(device): def test_segment_anything(device): check_device(device) model = InteractiveSeg("vit_l", device) - new_mask = model( + new_mask = model.gen_mask( rgb_img, RunPluginRequest( name=InteractiveSeg.name, diff --git a/web_app/index.html b/web_app/index.html index e92402c..01ffd1f 100644 --- a/web_app/index.html +++ b/web_app/index.html @@ -3,7 +3,7 @@ - Lama Cleaner + IOPaint
diff --git a/web_app/src/components/Editor.tsx b/web_app/src/components/Editor.tsx index 5393593..0767d82 100644 --- a/web_app/src/components/Editor.tsx +++ b/web_app/src/components/Editor.tsx @@ -272,7 +272,7 @@ export default function Editor(props: EditorProps) { console.log("[useEffect] centerView") // render 改变尺寸以后,undo/redo 重新 center viewportRef?.current?.centerView(minScale, 1) - }, [context?.canvas.height, context?.canvas.width, viewportRef, minScale]) + }, [imageHeight, imageWidth, viewportRef, minScale]) // Zoom reset const resetZoom = useCallback(() => { @@ -358,6 +358,7 @@ export default function Editor(props: EditorProps) { const targetFile = await getCurrentRender() try { const res = await runPlugin( + true, PluginName.InteractiveSeg, targetFile, undefined, diff --git a/web_app/src/components/Plugins.tsx b/web_app/src/components/Plugins.tsx index 30c0fb1..40c3225 100644 --- a/web_app/src/components/Plugins.tsx +++ b/web_app/src/components/Plugins.tsx @@ -16,6 +16,7 @@ import { Smile, } from "lucide-react" import { useStore } from "@/lib/states" +import { PluginInfo } from "@/lib/types" export enum PluginName { RemoveBG = "RemoveBG", @@ -26,6 +27,7 @@ export enum PluginName { InteractiveSeg = "InteractiveSeg", } +// TODO: get plugin config from server and using form-render?? const pluginMap = { [PluginName.RemoveBG]: { IconClass: Slice, @@ -37,7 +39,7 @@ const pluginMap = { }, [PluginName.RealESRGAN]: { IconClass: Fullscreen, - showName: "RealESRGAN 4x", + showName: "RealESRGAN", }, [PluginName.GFPGAN]: { IconClass: Smile, @@ -67,11 +69,11 @@ const Plugins = () => { return null } - const onPluginClick = (pluginName: string) => { + const onPluginClick = (genMask: boolean, pluginName: string) => { if (pluginName === PluginName.InteractiveSeg) { updateInteractiveSegState({ isInteractiveSeg: true }) } else { - runRenderablePlugin(pluginName) + runRenderablePlugin(genMask, pluginName) } } @@ -87,14 +89,14 @@ const Plugins = () => { - runRenderablePlugin(PluginName.RealESRGAN, { upscale: 2 }) + runRenderablePlugin(false, PluginName.RealESRGAN, { upscale: 2 }) } > upscale 2x - runRenderablePlugin(PluginName.RealESRGAN, { upscale: 4 }) + runRenderablePlugin(false, PluginName.RealESRGAN, { upscale: 4 }) } > upscale 4x @@ -104,16 +106,44 @@ const Plugins = () => { ) } + const renderGenImageAndMaskPlugin = (plugin: PluginInfo) => { + const { IconClass, showName } = pluginMap[plugin.name as PluginName] + return ( + + +
+ + {showName} +
+
+ + onPluginClick(false, plugin.name)}> + Remove Background + + onPluginClick(true, plugin.name)}> + Generate Mask + + +
+ ) + } + const renderPlugins = () => { - return plugins.map((plugin: string) => { - const { IconClass, showName } = pluginMap[plugin as PluginName] - if (plugin === PluginName.RealESRGAN) { + return plugins.map((plugin: PluginInfo) => { + const { IconClass, showName } = pluginMap[plugin.name as PluginName] + if (plugin.name === PluginName.RealESRGAN) { return renderRealESRGANPlugin() } + if ( + plugin.name === PluginName.RemoveBG || + plugin.name === PluginName.AnimeSeg + ) { + return renderGenImageAndMaskPlugin(plugin) + } return ( onPluginClick(plugin)} + key={plugin.name} + onClick={() => onPluginClick(false, plugin.name)} disabled={disabled} >
diff --git a/web_app/src/components/Shortcuts.tsx b/web_app/src/components/Shortcuts.tsx index aac9f78..77105df 100644 --- a/web_app/src/components/Shortcuts.tsx +++ b/web_app/src/components/Shortcuts.tsx @@ -64,11 +64,6 @@ export function Shortcuts() { - - diff --git a/web_app/src/components/SidePanel/DiffusionOptions.tsx b/web_app/src/components/SidePanel/DiffusionOptions.tsx index 557f4ef..3b8f608 100644 --- a/web_app/src/components/SidePanel/DiffusionOptions.tsx +++ b/web_app/src/components/SidePanel/DiffusionOptions.tsx @@ -658,13 +658,13 @@ const DiffusionOptions = () => { updateSettings({ sdSampler: value }) }} > - + {samplers.map((sampler) => ( - + {sampler} ))} diff --git a/web_app/src/components/ui/toast.tsx b/web_app/src/components/ui/toast.tsx index f79747a..768daa1 100644 --- a/web_app/src/components/ui/toast.tsx +++ b/web_app/src/components/ui/toast.tsx @@ -17,6 +17,7 @@ const ToastViewport = React.forwardRef< "fixed top-0 z-[100] flex max-h-screen w-full flex-col-reverse p-4 sm:bottom-0 sm:right-0 sm:top-auto sm:flex-col md:max-w-[420px]", className )} + tabIndex={-1} {...props} /> )) @@ -47,6 +48,7 @@ const Toast = React.forwardRef< ) @@ -63,6 +65,7 @@ const ToastAction = React.forwardRef< "inline-flex h-8 shrink-0 items-center justify-center rounded-md border bg-transparent px-3 text-sm font-medium transition-colors hover:bg-secondary focus:outline-none focus:ring-1 focus:ring-ring disabled:pointer-events-none disabled:opacity-50 group-[.destructive]:border-muted/40 group-[.destructive]:hover:border-destructive/30 group-[.destructive]:hover:bg-destructive group-[.destructive]:hover:text-destructive-foreground group-[.destructive]:focus:ring-destructive", className )} + tabIndex={-1} {...props} /> )) @@ -79,6 +82,7 @@ const ToastClose = React.forwardRef< className )} toast-close="" + tabIndex={-1} {...props} > @@ -93,6 +97,7 @@ const ToastTitle = React.forwardRef< )) @@ -106,6 +111,7 @@ const ToastDescription = React.forwardRef< ref={ref} className={cn("text-sm opacity-90", className)} {...props} + tabIndex={-1} /> )) ToastDescription.displayName = ToastPrimitives.Description.displayName diff --git a/web_app/src/lib/api.ts b/web_app/src/lib/api.ts index 3766e78..77f45f6 100644 --- a/web_app/src/lib/api.ts +++ b/web_app/src/lib/api.ts @@ -114,13 +114,15 @@ export function fetchModelInfos(): Promise { } export async function runPlugin( + genMask: boolean, name: string, imageFile: File, upscale?: number, clicks?: number[][] ) { const imageBase64 = await convertToBase64(imageFile) - const res = await fetch(`${API_ENDPOINT}/run_plugin`, { + const p = genMask ? "run_plugin_gen_mask" : "run_plugin_gen_image" + const res = await fetch(`${API_ENDPOINT}/${p}`, { method: "POST", headers: { "Content-Type": "application/json", diff --git a/web_app/src/lib/states.ts b/web_app/src/lib/states.ts index 6eba773..7814962 100644 --- a/web_app/src/lib/states.ts +++ b/web_app/src/lib/states.ts @@ -156,7 +156,6 @@ type AppAction = { setFile: (file: File) => Promise setCustomFile: (file: File) => void setIsInpainting: (newValue: boolean) => void - setIsPluginRunning: (newValue: boolean) => void getIsProcessing: () => boolean setBaseBrushSize: (newValue: number) => void getBrushSize: () => number @@ -190,6 +189,7 @@ type AppAction = { showPrevMask: () => Promise hidePrevMask: () => void runRenderablePlugin: ( + genMask: boolean, pluginName: string, params?: PluginParams ) => Promise @@ -521,28 +521,43 @@ export const useStore = createWithEqualityFn()( }, runRenderablePlugin: async ( + genMask: boolean, pluginName: string, params: PluginParams = { upscale: 1 } ) => { const { renders, lineGroups } = get().editorState set((state) => { - state.isInpainting = true + state.isPluginRunning = true }) try { const start = new Date() const targetFile = await get().getCurrentTargetFile() - const res = await runPlugin(pluginName, targetFile, params.upscale) + const res = await runPlugin( + genMask, + pluginName, + targetFile, + params.upscale + ) const { blob } = res - const newRender = new Image() - await loadImage(newRender, blob) - get().setImageSize(newRender.width, newRender.height) - const newRenders = [...renders, newRender] - const newLineGroups = [...lineGroups, []] - get().updateEditorState({ - renders: newRenders, - lineGroups: newLineGroups, - }) + + if (!genMask) { + const newRender = new Image() + await loadImage(newRender, blob) + get().setImageSize(newRender.width, newRender.height) + const newRenders = [...renders, newRender] + const newLineGroups = [...lineGroups, []] + get().updateEditorState({ + renders: newRenders, + lineGroups: newLineGroups, + }) + } else { + const newMask = new Image() + await loadImage(newMask, blob) + get().updateInteractiveSegState({ + interactiveSegMask: newMask, + }) + } const end = new Date() const time = end.getTime() - start.getTime() toast({ @@ -555,7 +570,7 @@ export const useStore = createWithEqualityFn()( }) } set((state) => { - state.isInpainting = false + state.isPluginRunning = false }) }, @@ -803,11 +818,6 @@ export const useStore = createWithEqualityFn()( state.isInpainting = newValue }), - setIsPluginRunning: (newValue: boolean) => - set((state) => { - state.isPluginRunning = newValue - }), - setFile: async (file: File) => { if (get().settings.enableAutoExtractPrompt) { try { diff --git a/web_app/src/lib/types.ts b/web_app/src/lib/types.ts index e88b53f..47502d8 100644 --- a/web_app/src/lib/types.ts +++ b/web_app/src/lib/types.ts @@ -6,8 +6,14 @@ export interface Filename { mtime: number } +export interface PluginInfo { + name: string + support_gen_image: boolean + support_gen_mask: boolean +} + export interface ServerConfig { - plugins: string[] + plugins: PluginInfo[] enableFileManager: boolean enableAutoSaving: boolean enableControlnet: boolean