make generate mask from RemoveBG && AnimeSeg work
This commit is contained in:
parent
6253016019
commit
aca85543ca
@ -28,11 +28,13 @@ from lama_cleaner.helper import (
|
|||||||
pil_to_bytes,
|
pil_to_bytes,
|
||||||
numpy_to_bytes,
|
numpy_to_bytes,
|
||||||
concat_alpha_channel,
|
concat_alpha_channel,
|
||||||
|
gen_frontend_mask,
|
||||||
)
|
)
|
||||||
from lama_cleaner.model.utils import torch_gc
|
from lama_cleaner.model.utils import torch_gc
|
||||||
from lama_cleaner.model_info import ModelInfo
|
from lama_cleaner.model_info import ModelInfo
|
||||||
from lama_cleaner.model_manager import ModelManager
|
from lama_cleaner.model_manager import ModelManager
|
||||||
from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg
|
from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg
|
||||||
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
from lama_cleaner.schema import (
|
from lama_cleaner.schema import (
|
||||||
GenInfoResponse,
|
GenInfoResponse,
|
||||||
ApiConfig,
|
ApiConfig,
|
||||||
@ -41,6 +43,7 @@ from lama_cleaner.schema import (
|
|||||||
InpaintRequest,
|
InpaintRequest,
|
||||||
RunPluginRequest,
|
RunPluginRequest,
|
||||||
SDSampler,
|
SDSampler,
|
||||||
|
PluginInfo,
|
||||||
)
|
)
|
||||||
from lama_cleaner.file_manager import FileManager
|
from lama_cleaner.file_manager import FileManager
|
||||||
|
|
||||||
@ -145,7 +148,8 @@ class Api:
|
|||||||
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
||||||
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
||||||
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
||||||
self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"])
|
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
||||||
|
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
||||||
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
||||||
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
@ -173,7 +177,14 @@ class Api:
|
|||||||
|
|
||||||
def api_server_config(self) -> ServerConfigResponse:
|
def api_server_config(self) -> ServerConfigResponse:
|
||||||
return ServerConfigResponse(
|
return ServerConfigResponse(
|
||||||
plugins=list(self.plugins.keys()),
|
plugins=[
|
||||||
|
PluginInfo(
|
||||||
|
name=it.name,
|
||||||
|
support_gen_image=it.support_gen_image,
|
||||||
|
support_gen_mask=it.support_gen_mask,
|
||||||
|
)
|
||||||
|
for it in self.plugins.values()
|
||||||
|
],
|
||||||
enableFileManager=self.file_manager is not None,
|
enableFileManager=self.file_manager is not None,
|
||||||
enableAutoSaving=self.config.output_dir is not None,
|
enableAutoSaving=self.config.output_dir is not None,
|
||||||
enableControlnet=self.model_manager.enable_controlnet,
|
enableControlnet=self.model_manager.enable_controlnet,
|
||||||
@ -237,22 +248,22 @@ class Api:
|
|||||||
headers={"X-Seed": str(req.sd_seed)},
|
headers={"X-Seed": str(req.sd_seed)},
|
||||||
)
|
)
|
||||||
|
|
||||||
def api_run_plugin(self, req: RunPluginRequest):
|
def api_run_plugin_gen_image(self, req: RunPluginRequest):
|
||||||
ext = "png"
|
ext = "png"
|
||||||
if req.name not in self.plugins:
|
if req.name not in self.plugins:
|
||||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
raise HTTPException(status_code=422, detail="Plugin not found")
|
||||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
if not self.plugins[req.name].support_gen_image:
|
||||||
bgr_np_img = self.plugins[req.name](rgb_np_img, req)
|
raise HTTPException(
|
||||||
torch_gc()
|
status_code=422, detail="Plugin does not support output image"
|
||||||
if req.name == InteractiveSeg.name:
|
|
||||||
return Response(
|
|
||||||
content=numpy_to_bytes(bgr_np_img, ext),
|
|
||||||
media_type=f"image/{ext}",
|
|
||||||
)
|
)
|
||||||
if bgr_np_img.shape[2] == 4:
|
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||||
rgba_np_img = bgr_np_img
|
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
|
||||||
|
torch_gc()
|
||||||
|
|
||||||
|
if bgr_or_rgba_np_img.shape[2] == 4:
|
||||||
|
rgba_np_img = bgr_or_rgba_np_img
|
||||||
else:
|
else:
|
||||||
rgba_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB)
|
rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
|
||||||
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
|
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
@ -265,6 +276,22 @@ class Api:
|
|||||||
media_type=f"image/{ext}",
|
media_type=f"image/{ext}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def api_run_plugin_gen_mask(self, req: RunPluginRequest):
|
||||||
|
if req.name not in self.plugins:
|
||||||
|
raise HTTPException(status_code=422, detail="Plugin not found")
|
||||||
|
if not self.plugins[req.name].support_gen_mask:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=422, detail="Plugin does not support output image"
|
||||||
|
)
|
||||||
|
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||||
|
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
|
||||||
|
torch_gc()
|
||||||
|
res_mask = gen_frontend_mask(bgr_or_gray_mask)
|
||||||
|
return Response(
|
||||||
|
content=numpy_to_bytes(res_mask, "png"),
|
||||||
|
media_type="image/png",
|
||||||
|
)
|
||||||
|
|
||||||
def api_samplers(self) -> List[str]:
|
def api_samplers(self) -> List[str]:
|
||||||
return [member.value for member in SDSampler.__members__.values()]
|
return [member.value for member in SDSampler.__members__.values()]
|
||||||
|
|
||||||
@ -290,7 +317,7 @@ class Api:
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _build_plugins(self) -> Dict:
|
def _build_plugins(self) -> Dict[str, BasePlugin]:
|
||||||
return build_plugins(
|
return build_plugins(
|
||||||
self.config.enable_interactive_seg,
|
self.config.enable_interactive_seg,
|
||||||
self.config.interactive_seg_model,
|
self.config.interactive_seg_model,
|
||||||
|
@ -350,3 +350,23 @@ def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
|
|||||||
(rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
(rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
||||||
)
|
)
|
||||||
return rgb_np_img
|
return rgb_np_img
|
||||||
|
|
||||||
|
|
||||||
|
def gen_frontend_mask(bgr_or_gray_mask):
|
||||||
|
if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
|
||||||
|
bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
# fronted brush color "ffcc00bb"
|
||||||
|
# TODO: how to set kernel size?
|
||||||
|
kernel_size = 9
|
||||||
|
bgr_or_gray_mask = cv2.dilate(
|
||||||
|
bgr_or_gray_mask,
|
||||||
|
np.ones((kernel_size, kernel_size), np.uint8),
|
||||||
|
iterations=1,
|
||||||
|
)
|
||||||
|
res_mask = np.zeros(
|
||||||
|
(bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8
|
||||||
|
)
|
||||||
|
res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
|
||||||
|
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
|
||||||
|
return res_mask
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
from enum import Enum
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pydantic import computed_field, BaseModel
|
from pydantic import computed_field, BaseModel
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .interactive_seg import InteractiveSeg
|
|
||||||
from .remove_bg import RemoveBG
|
|
||||||
from .realesrgan import RealESRGANUpscaler
|
|
||||||
from .gfpgan_plugin import GFPGANPlugin
|
|
||||||
from .restoreformer import RestoreFormerPlugin
|
|
||||||
from .anime_seg import AnimeSeg
|
from .anime_seg import AnimeSeg
|
||||||
|
from .gfpgan_plugin import GFPGANPlugin
|
||||||
|
from .interactive_seg import InteractiveSeg
|
||||||
|
from .realesrgan import RealESRGANUpscaler
|
||||||
|
from .remove_bg import RemoveBG
|
||||||
|
from .restoreformer import RestoreFormerPlugin
|
||||||
from ..const import InteractiveSegModel, Device, RealESRGANModel
|
from ..const import InteractiveSegModel, Device, RealESRGANModel
|
||||||
|
|
||||||
|
|
||||||
@ -23,7 +25,7 @@ def build_plugins(
|
|||||||
enable_restoreformer: bool,
|
enable_restoreformer: bool,
|
||||||
restoreformer_device: Device,
|
restoreformer_device: Device,
|
||||||
no_half: bool,
|
no_half: bool,
|
||||||
):
|
) -> Dict:
|
||||||
plugins = {}
|
plugins = {}
|
||||||
if enable_interactive_seg:
|
if enable_interactive_seg:
|
||||||
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
||||||
|
@ -416,6 +416,8 @@ ANIME_SEG_MODELS = {
|
|||||||
class AnimeSeg(BasePlugin):
|
class AnimeSeg(BasePlugin):
|
||||||
# Model from: https://github.com/SkyTNT/anime-segmentation
|
# Model from: https://github.com/SkyTNT/anime-segmentation
|
||||||
name = "AnimeSeg"
|
name = "AnimeSeg"
|
||||||
|
support_gen_image = True
|
||||||
|
support_gen_mask = True
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -426,10 +428,19 @@ class AnimeSeg(BasePlugin):
|
|||||||
ANIME_SEG_MODELS["md5"],
|
ANIME_SEG_MODELS["md5"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
|
mask = self.forward(rgb_np_img)
|
||||||
|
mask = Image.fromarray(mask, mode="L")
|
||||||
|
h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1]
|
||||||
|
empty = Image.new("RGBA", (w0, h0), 0)
|
||||||
|
img = Image.fromarray(rgb_np_img)
|
||||||
|
cutout = Image.composite(img, empty, mask)
|
||||||
|
return np.asarray(cutout)
|
||||||
|
|
||||||
|
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
return self.forward(rgb_np_img)
|
return self.forward(rgb_np_img)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.inference_mode()
|
||||||
def forward(self, rgb_np_img):
|
def forward(self, rgb_np_img):
|
||||||
s = 1024
|
s = 1024
|
||||||
|
|
||||||
@ -448,9 +459,4 @@ class AnimeSeg(BasePlugin):
|
|||||||
mask = self.model(tmpImg)
|
mask = self.model(tmpImg)
|
||||||
mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
|
mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
|
||||||
mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0))
|
mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0))
|
||||||
mask = Image.fromarray((mask * 255).astype("uint8"), mode="L")
|
return (mask * 255).astype("uint8")
|
||||||
|
|
||||||
empty = Image.new("RGBA", (w0, h0), 0)
|
|
||||||
img = Image.fromarray(rgb_np_img)
|
|
||||||
cutout = Image.composite(img, empty, mask)
|
|
||||||
return np.asarray(cutout)
|
|
||||||
|
@ -5,15 +5,23 @@ from lama_cleaner.schema import RunPluginRequest
|
|||||||
|
|
||||||
|
|
||||||
class BasePlugin:
|
class BasePlugin:
|
||||||
|
name: str
|
||||||
|
support_gen_image: bool = False
|
||||||
|
support_gen_mask: bool = False
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
err_msg = self.check_dep()
|
err_msg = self.check_dep()
|
||||||
if err_msg:
|
if err_msg:
|
||||||
logger.error(err_msg)
|
logger.error(err_msg)
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, req: RunPluginRequest) -> np.array:
|
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
# return RGBA np image or BGR np image
|
# return RGBA np image or BGR np image
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
|
# return GRAY or BGR np image, 255 means foreground, 0 means background
|
||||||
|
...
|
||||||
|
|
||||||
def check_dep(self):
|
def check_dep(self):
|
||||||
...
|
...
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import cv2
|
import cv2
|
||||||
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.helper import download_model
|
from lama_cleaner.helper import download_model
|
||||||
@ -8,6 +9,7 @@ from lama_cleaner.schema import RunPluginRequest
|
|||||||
|
|
||||||
class GFPGANPlugin(BasePlugin):
|
class GFPGANPlugin(BasePlugin):
|
||||||
name = "GFPGAN"
|
name = "GFPGAN"
|
||||||
|
support_gen_image = True
|
||||||
|
|
||||||
def __init__(self, device, upscaler=None):
|
def __init__(self, device, upscaler=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -37,7 +39,7 @@ class GFPGANPlugin(BasePlugin):
|
|||||||
self.face_enhancer.face_helper.face_det.to(device)
|
self.face_enhancer.face_helper.face_det.to(device)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
weight = 0.5
|
weight = 0.5
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
|
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
|
||||||
|
@ -4,6 +4,7 @@ from typing import List
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.helper import download_model
|
from lama_cleaner.helper import download_model
|
||||||
@ -34,6 +35,7 @@ SEGMENT_ANYTHING_MODELS = {
|
|||||||
|
|
||||||
class InteractiveSeg(BasePlugin):
|
class InteractiveSeg(BasePlugin):
|
||||||
name = "InteractiveSeg"
|
name = "InteractiveSeg"
|
||||||
|
support_gen_mask = True
|
||||||
|
|
||||||
def __init__(self, model_name, device):
|
def __init__(self, model_name, device):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -47,10 +49,11 @@ class InteractiveSeg(BasePlugin):
|
|||||||
)
|
)
|
||||||
self.prev_img_md5 = None
|
self.prev_img_md5 = None
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
||||||
return self.forward(rgb_np_img, req.clicks, img_md5)
|
return self.forward(rgb_np_img, req.clicks, img_md5)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
|
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
|
||||||
input_point = []
|
input_point = []
|
||||||
input_label = []
|
input_label = []
|
||||||
@ -70,13 +73,4 @@ class InteractiveSeg(BasePlugin):
|
|||||||
multimask_output=False,
|
multimask_output=False,
|
||||||
)
|
)
|
||||||
mask = masks[0].astype(np.uint8) * 255
|
mask = masks[0].astype(np.uint8) * 255
|
||||||
# TODO: how to set kernel size?
|
return mask
|
||||||
kernel_size = 9
|
|
||||||
mask = cv2.dilate(
|
|
||||||
mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1
|
|
||||||
)
|
|
||||||
# fronted brush color "ffcc00bb"
|
|
||||||
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
|
|
||||||
res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)]
|
|
||||||
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
|
|
||||||
return res_mask
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.const import RealESRGANModel
|
from lama_cleaner.const import RealESRGANModel
|
||||||
@ -11,6 +13,7 @@ from lama_cleaner.schema import RunPluginRequest
|
|||||||
|
|
||||||
class RealESRGANUpscaler(BasePlugin):
|
class RealESRGANUpscaler(BasePlugin):
|
||||||
name = "RealESRGAN"
|
name = "RealESRGAN"
|
||||||
|
support_gen_image = True
|
||||||
|
|
||||||
def __init__(self, name, device, no_half=False):
|
def __init__(self, name, device, no_half=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -77,13 +80,14 @@ class RealESRGANUpscaler(BasePlugin):
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
|
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
|
||||||
result = self.forward(bgr_np_img, req.scale)
|
result = self.forward(bgr_np_img, req.scale)
|
||||||
logger.info(f"RealESRGAN output shape: {result.shape}")
|
logger.info(f"RealESRGAN output shape: {result.shape}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def forward(self, bgr_np_img, scale: float):
|
def forward(self, bgr_np_img, scale: float):
|
||||||
# 输出是 BGR
|
# 输出是 BGR
|
||||||
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
|
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
|
||||||
|
@ -9,6 +9,8 @@ from lama_cleaner.schema import RunPluginRequest
|
|||||||
|
|
||||||
class RemoveBG(BasePlugin):
|
class RemoveBG(BasePlugin):
|
||||||
name = "RemoveBG"
|
name = "RemoveBG"
|
||||||
|
support_gen_mask = True
|
||||||
|
support_gen_image = True
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -20,17 +22,24 @@ class RemoveBG(BasePlugin):
|
|||||||
|
|
||||||
self.session = new_session(model_name="u2net")
|
self.session = new_session(model_name="u2net")
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
|
||||||
return self.forward(bgr_np_img)
|
|
||||||
|
|
||||||
def forward(self, bgr_np_img) -> np.ndarray:
|
|
||||||
from rembg import remove
|
from rembg import remove
|
||||||
|
|
||||||
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
# return BGRA image
|
# return BGRA image
|
||||||
output = remove(bgr_np_img, session=self.session)
|
output = remove(bgr_np_img, session=self.session)
|
||||||
return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
|
return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
|
||||||
|
|
||||||
|
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
|
from rembg import remove
|
||||||
|
|
||||||
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# return BGR image, 255 means foreground, 0 means background
|
||||||
|
output = remove(bgr_np_img, session=self.session, only_mask=True)
|
||||||
|
return output
|
||||||
|
|
||||||
def check_dep(self):
|
def check_dep(self):
|
||||||
try:
|
try:
|
||||||
import rembg
|
import rembg
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import cv2
|
import cv2
|
||||||
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.helper import download_model
|
from lama_cleaner.helper import download_model
|
||||||
@ -8,6 +9,7 @@ from lama_cleaner.schema import RunPluginRequest
|
|||||||
|
|
||||||
class RestoreFormerPlugin(BasePlugin):
|
class RestoreFormerPlugin(BasePlugin):
|
||||||
name = "RestoreFormer"
|
name = "RestoreFormer"
|
||||||
|
support_gen_image = True
|
||||||
|
|
||||||
def __init__(self, device, upscaler=None):
|
def __init__(self, device, upscaler=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -32,7 +34,7 @@ class RestoreFormerPlugin(BasePlugin):
|
|||||||
bg_upsampler=upscaler.model if upscaler is not None else None,
|
bg_upsampler=upscaler.model if upscaler is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
weight = 0.5
|
weight = 0.5
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
|
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
|
||||||
|
@ -3,12 +3,17 @@ from enum import Enum
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Literal, List
|
from typing import Optional, Literal, List
|
||||||
|
|
||||||
from PIL.Image import Image
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from pydantic import BaseModel, Field, validator, field_validator
|
|
||||||
|
|
||||||
from lama_cleaner.const import Device, InteractiveSegModel, RealESRGANModel
|
from lama_cleaner.const import Device, InteractiveSegModel, RealESRGANModel
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInfo(BaseModel):
|
||||||
|
name: str
|
||||||
|
support_gen_image: bool = False
|
||||||
|
support_gen_mask: bool = False
|
||||||
|
|
||||||
|
|
||||||
class CV2Flag(str, Enum):
|
class CV2Flag(str, Enum):
|
||||||
INPAINT_NS = "INPAINT_NS"
|
INPAINT_NS = "INPAINT_NS"
|
||||||
INPAINT_TELEA = "INPAINT_TELEA"
|
INPAINT_TELEA = "INPAINT_TELEA"
|
||||||
@ -272,7 +277,7 @@ class GenInfoResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ServerConfigResponse(BaseModel):
|
class ServerConfigResponse(BaseModel):
|
||||||
plugins: List[str]
|
plugins: List[PluginInfo]
|
||||||
enableFileManager: bool
|
enableFileManager: bool
|
||||||
enableAutoSaving: bool
|
enableAutoSaving: bool
|
||||||
enableControlnet: bool
|
enableControlnet: bool
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from lama_cleaner.helper import encode_pil_to_base64
|
from lama_cleaner.helper import encode_pil_to_base64, gen_frontend_mask
|
||||||
from lama_cleaner.plugins.anime_seg import AnimeSeg
|
from lama_cleaner.plugins.anime_seg import AnimeSeg
|
||||||
from lama_cleaner.schema import RunPluginRequest
|
from lama_cleaner.schema import RunPluginRequest
|
||||||
from lama_cleaner.tests.utils import check_device, current_dir, save_dir
|
from lama_cleaner.tests.utils import check_device, current_dir, save_dir
|
||||||
@ -35,34 +35,48 @@ def _save(img, name):
|
|||||||
|
|
||||||
def test_remove_bg():
|
def test_remove_bg():
|
||||||
model = RemoveBG()
|
model = RemoveBG()
|
||||||
rgba_np_img = model(
|
rgba_np_img = model.gen_image(
|
||||||
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||||
)
|
)
|
||||||
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
|
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
|
||||||
_save(res, "test_remove_bg.png")
|
_save(res, "test_remove_bg.png")
|
||||||
|
|
||||||
|
bgr_np_img = model.gen_mask(
|
||||||
|
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||||
|
)
|
||||||
|
|
||||||
|
res_mask = gen_frontend_mask(bgr_np_img)
|
||||||
|
_save(res_mask, "test_remove_bg_frontend_mask.png")
|
||||||
|
|
||||||
|
assert len(bgr_np_img.shape) == 2
|
||||||
|
_save(bgr_np_img, "test_remove_bg_mask.jpeg")
|
||||||
|
|
||||||
|
|
||||||
def test_anime_seg():
|
def test_anime_seg():
|
||||||
model = AnimeSeg()
|
model = AnimeSeg()
|
||||||
img = cv2.imread(str(current_dir / "anime_test.png"))
|
img = cv2.imread(str(current_dir / "anime_test.png"))
|
||||||
img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {})
|
img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {})
|
||||||
res = model(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
|
res = model.gen_image(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
|
||||||
assert len(res.shape) == 3
|
assert len(res.shape) == 3
|
||||||
assert res.shape[-1] == 4
|
assert res.shape[-1] == 4
|
||||||
_save(res, "test_anime_seg.png")
|
_save(res, "test_anime_seg.png")
|
||||||
|
|
||||||
|
res = model.gen_mask(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
|
||||||
|
assert len(res.shape) == 2
|
||||||
|
_save(res, "test_anime_seg_mask.png")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
||||||
def test_upscale(device):
|
def test_upscale(device):
|
||||||
check_device(device)
|
check_device(device)
|
||||||
model = RealESRGANUpscaler("realesr-general-x4v3", device)
|
model = RealESRGANUpscaler("realesr-general-x4v3", device)
|
||||||
res = model(
|
res = model.gen_image(
|
||||||
rgb_img,
|
rgb_img,
|
||||||
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2),
|
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2),
|
||||||
)
|
)
|
||||||
_save(res, f"test_upscale_x2_{device}.png")
|
_save(res, f"test_upscale_x2_{device}.png")
|
||||||
|
|
||||||
res = model(
|
res = model.gen_image(
|
||||||
rgb_img,
|
rgb_img,
|
||||||
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4),
|
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4),
|
||||||
)
|
)
|
||||||
@ -73,7 +87,9 @@ def test_upscale(device):
|
|||||||
def test_gfpgan(device):
|
def test_gfpgan(device):
|
||||||
check_device(device)
|
check_device(device)
|
||||||
model = GFPGANPlugin(device)
|
model = GFPGANPlugin(device)
|
||||||
res = model(rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64))
|
res = model.gen_image(
|
||||||
|
rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)
|
||||||
|
)
|
||||||
_save(res, f"test_gfpgan_{device}.png")
|
_save(res, f"test_gfpgan_{device}.png")
|
||||||
|
|
||||||
|
|
||||||
@ -81,7 +97,7 @@ def test_gfpgan(device):
|
|||||||
def test_restoreformer(device):
|
def test_restoreformer(device):
|
||||||
check_device(device)
|
check_device(device)
|
||||||
model = RestoreFormerPlugin(device)
|
model = RestoreFormerPlugin(device)
|
||||||
res = model(
|
res = model.gen_image(
|
||||||
rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64)
|
rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64)
|
||||||
)
|
)
|
||||||
_save(res, f"test_restoreformer_{device}.png")
|
_save(res, f"test_restoreformer_{device}.png")
|
||||||
@ -91,7 +107,7 @@ def test_restoreformer(device):
|
|||||||
def test_segment_anything(device):
|
def test_segment_anything(device):
|
||||||
check_device(device)
|
check_device(device)
|
||||||
model = InteractiveSeg("vit_l", device)
|
model = InteractiveSeg("vit_l", device)
|
||||||
new_mask = model(
|
new_mask = model.gen_mask(
|
||||||
rgb_img,
|
rgb_img,
|
||||||
RunPluginRequest(
|
RunPluginRequest(
|
||||||
name=InteractiveSeg.name,
|
name=InteractiveSeg.name,
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>Lama Cleaner</title>
|
<title>IOPaint</title>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div id="root"></div>
|
<div id="root"></div>
|
||||||
|
@ -272,7 +272,7 @@ export default function Editor(props: EditorProps) {
|
|||||||
console.log("[useEffect] centerView")
|
console.log("[useEffect] centerView")
|
||||||
// render 改变尺寸以后,undo/redo 重新 center
|
// render 改变尺寸以后,undo/redo 重新 center
|
||||||
viewportRef?.current?.centerView(minScale, 1)
|
viewportRef?.current?.centerView(minScale, 1)
|
||||||
}, [context?.canvas.height, context?.canvas.width, viewportRef, minScale])
|
}, [imageHeight, imageWidth, viewportRef, minScale])
|
||||||
|
|
||||||
// Zoom reset
|
// Zoom reset
|
||||||
const resetZoom = useCallback(() => {
|
const resetZoom = useCallback(() => {
|
||||||
@ -358,6 +358,7 @@ export default function Editor(props: EditorProps) {
|
|||||||
const targetFile = await getCurrentRender()
|
const targetFile = await getCurrentRender()
|
||||||
try {
|
try {
|
||||||
const res = await runPlugin(
|
const res = await runPlugin(
|
||||||
|
true,
|
||||||
PluginName.InteractiveSeg,
|
PluginName.InteractiveSeg,
|
||||||
targetFile,
|
targetFile,
|
||||||
undefined,
|
undefined,
|
||||||
|
@ -16,6 +16,7 @@ import {
|
|||||||
Smile,
|
Smile,
|
||||||
} from "lucide-react"
|
} from "lucide-react"
|
||||||
import { useStore } from "@/lib/states"
|
import { useStore } from "@/lib/states"
|
||||||
|
import { PluginInfo } from "@/lib/types"
|
||||||
|
|
||||||
export enum PluginName {
|
export enum PluginName {
|
||||||
RemoveBG = "RemoveBG",
|
RemoveBG = "RemoveBG",
|
||||||
@ -26,6 +27,7 @@ export enum PluginName {
|
|||||||
InteractiveSeg = "InteractiveSeg",
|
InteractiveSeg = "InteractiveSeg",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: get plugin config from server and using form-render??
|
||||||
const pluginMap = {
|
const pluginMap = {
|
||||||
[PluginName.RemoveBG]: {
|
[PluginName.RemoveBG]: {
|
||||||
IconClass: Slice,
|
IconClass: Slice,
|
||||||
@ -37,7 +39,7 @@ const pluginMap = {
|
|||||||
},
|
},
|
||||||
[PluginName.RealESRGAN]: {
|
[PluginName.RealESRGAN]: {
|
||||||
IconClass: Fullscreen,
|
IconClass: Fullscreen,
|
||||||
showName: "RealESRGAN 4x",
|
showName: "RealESRGAN",
|
||||||
},
|
},
|
||||||
[PluginName.GFPGAN]: {
|
[PluginName.GFPGAN]: {
|
||||||
IconClass: Smile,
|
IconClass: Smile,
|
||||||
@ -67,11 +69,11 @@ const Plugins = () => {
|
|||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
const onPluginClick = (pluginName: string) => {
|
const onPluginClick = (genMask: boolean, pluginName: string) => {
|
||||||
if (pluginName === PluginName.InteractiveSeg) {
|
if (pluginName === PluginName.InteractiveSeg) {
|
||||||
updateInteractiveSegState({ isInteractiveSeg: true })
|
updateInteractiveSegState({ isInteractiveSeg: true })
|
||||||
} else {
|
} else {
|
||||||
runRenderablePlugin(pluginName)
|
runRenderablePlugin(genMask, pluginName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,14 +89,14 @@ const Plugins = () => {
|
|||||||
<DropdownMenuSubContent>
|
<DropdownMenuSubContent>
|
||||||
<DropdownMenuItem
|
<DropdownMenuItem
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
runRenderablePlugin(PluginName.RealESRGAN, { upscale: 2 })
|
runRenderablePlugin(false, PluginName.RealESRGAN, { upscale: 2 })
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
upscale 2x
|
upscale 2x
|
||||||
</DropdownMenuItem>
|
</DropdownMenuItem>
|
||||||
<DropdownMenuItem
|
<DropdownMenuItem
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
runRenderablePlugin(PluginName.RealESRGAN, { upscale: 4 })
|
runRenderablePlugin(false, PluginName.RealESRGAN, { upscale: 4 })
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
upscale 4x
|
upscale 4x
|
||||||
@ -104,16 +106,44 @@ const Plugins = () => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const renderGenImageAndMaskPlugin = (plugin: PluginInfo) => {
|
||||||
|
const { IconClass, showName } = pluginMap[plugin.name as PluginName]
|
||||||
|
return (
|
||||||
|
<DropdownMenuSub key={plugin.name}>
|
||||||
|
<DropdownMenuSubTrigger disabled={disabled}>
|
||||||
|
<div className="flex gap-2 items-center">
|
||||||
|
<IconClass className="p-1" />
|
||||||
|
{showName}
|
||||||
|
</div>
|
||||||
|
</DropdownMenuSubTrigger>
|
||||||
|
<DropdownMenuSubContent>
|
||||||
|
<DropdownMenuItem onClick={() => onPluginClick(false, plugin.name)}>
|
||||||
|
Remove Background
|
||||||
|
</DropdownMenuItem>
|
||||||
|
<DropdownMenuItem onClick={() => onPluginClick(true, plugin.name)}>
|
||||||
|
Generate Mask
|
||||||
|
</DropdownMenuItem>
|
||||||
|
</DropdownMenuSubContent>
|
||||||
|
</DropdownMenuSub>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
const renderPlugins = () => {
|
const renderPlugins = () => {
|
||||||
return plugins.map((plugin: string) => {
|
return plugins.map((plugin: PluginInfo) => {
|
||||||
const { IconClass, showName } = pluginMap[plugin as PluginName]
|
const { IconClass, showName } = pluginMap[plugin.name as PluginName]
|
||||||
if (plugin === PluginName.RealESRGAN) {
|
if (plugin.name === PluginName.RealESRGAN) {
|
||||||
return renderRealESRGANPlugin()
|
return renderRealESRGANPlugin()
|
||||||
}
|
}
|
||||||
|
if (
|
||||||
|
plugin.name === PluginName.RemoveBG ||
|
||||||
|
plugin.name === PluginName.AnimeSeg
|
||||||
|
) {
|
||||||
|
return renderGenImageAndMaskPlugin(plugin)
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<DropdownMenuItem
|
<DropdownMenuItem
|
||||||
key={plugin}
|
key={plugin.name}
|
||||||
onClick={() => onPluginClick(plugin)}
|
onClick={() => onPluginClick(false, plugin.name)}
|
||||||
disabled={disabled}
|
disabled={disabled}
|
||||||
>
|
>
|
||||||
<div className="flex gap-2 items-center">
|
<div className="flex gap-2 items-center">
|
||||||
|
@ -64,11 +64,6 @@ export function Shortcuts() {
|
|||||||
<ShortCut content="Decrease Brush Size" keys={["["]} />
|
<ShortCut content="Decrease Brush Size" keys={["["]} />
|
||||||
<ShortCut content="Increase Brush Size" keys={["]"]} />
|
<ShortCut content="Increase Brush Size" keys={["]"]} />
|
||||||
<ShortCut content="View Original Image" keys={["Hold Tab"]} />
|
<ShortCut content="View Original Image" keys={["Hold Tab"]} />
|
||||||
<ShortCut
|
|
||||||
content="Multi-Stroke Drawing"
|
|
||||||
keys={[`Hold ${CmdOrCtrl()}`]}
|
|
||||||
/>
|
|
||||||
<ShortCut content="Cancel Drawing" keys={["Esc"]} />
|
|
||||||
|
|
||||||
<ShortCut content="Undo" keys={[CmdOrCtrl(), "Z"]} />
|
<ShortCut content="Undo" keys={[CmdOrCtrl(), "Z"]} />
|
||||||
<ShortCut content="Redo" keys={[CmdOrCtrl(), "Shift", "Z"]} />
|
<ShortCut content="Redo" keys={[CmdOrCtrl(), "Shift", "Z"]} />
|
||||||
|
@ -658,13 +658,13 @@ const DiffusionOptions = () => {
|
|||||||
updateSettings({ sdSampler: value })
|
updateSettings({ sdSampler: value })
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<SelectTrigger className="w-[180px]">
|
<SelectTrigger className="w-[175px] text-xs">
|
||||||
<SelectValue placeholder="Select sampler" />
|
<SelectValue placeholder="Select sampler" />
|
||||||
</SelectTrigger>
|
</SelectTrigger>
|
||||||
<SelectContent align="end">
|
<SelectContent align="end">
|
||||||
<SelectGroup>
|
<SelectGroup>
|
||||||
{samplers.map((sampler) => (
|
{samplers.map((sampler) => (
|
||||||
<SelectItem key={sampler} value={sampler}>
|
<SelectItem key={sampler} value={sampler} className="text-xs">
|
||||||
{sampler}
|
{sampler}
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
))}
|
))}
|
||||||
|
@ -17,6 +17,7 @@ const ToastViewport = React.forwardRef<
|
|||||||
"fixed top-0 z-[100] flex max-h-screen w-full flex-col-reverse p-4 sm:bottom-0 sm:right-0 sm:top-auto sm:flex-col md:max-w-[420px]",
|
"fixed top-0 z-[100] flex max-h-screen w-full flex-col-reverse p-4 sm:bottom-0 sm:right-0 sm:top-auto sm:flex-col md:max-w-[420px]",
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
|
tabIndex={-1}
|
||||||
{...props}
|
{...props}
|
||||||
/>
|
/>
|
||||||
))
|
))
|
||||||
@ -47,6 +48,7 @@ const Toast = React.forwardRef<
|
|||||||
<ToastPrimitives.Root
|
<ToastPrimitives.Root
|
||||||
ref={ref}
|
ref={ref}
|
||||||
className={cn(toastVariants({ variant }), className)}
|
className={cn(toastVariants({ variant }), className)}
|
||||||
|
tabIndex={-1}
|
||||||
{...props}
|
{...props}
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
@ -63,6 +65,7 @@ const ToastAction = React.forwardRef<
|
|||||||
"inline-flex h-8 shrink-0 items-center justify-center rounded-md border bg-transparent px-3 text-sm font-medium transition-colors hover:bg-secondary focus:outline-none focus:ring-1 focus:ring-ring disabled:pointer-events-none disabled:opacity-50 group-[.destructive]:border-muted/40 group-[.destructive]:hover:border-destructive/30 group-[.destructive]:hover:bg-destructive group-[.destructive]:hover:text-destructive-foreground group-[.destructive]:focus:ring-destructive",
|
"inline-flex h-8 shrink-0 items-center justify-center rounded-md border bg-transparent px-3 text-sm font-medium transition-colors hover:bg-secondary focus:outline-none focus:ring-1 focus:ring-ring disabled:pointer-events-none disabled:opacity-50 group-[.destructive]:border-muted/40 group-[.destructive]:hover:border-destructive/30 group-[.destructive]:hover:bg-destructive group-[.destructive]:hover:text-destructive-foreground group-[.destructive]:focus:ring-destructive",
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
|
tabIndex={-1}
|
||||||
{...props}
|
{...props}
|
||||||
/>
|
/>
|
||||||
))
|
))
|
||||||
@ -79,6 +82,7 @@ const ToastClose = React.forwardRef<
|
|||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
toast-close=""
|
toast-close=""
|
||||||
|
tabIndex={-1}
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
<Cross2Icon className="h-4 w-4" />
|
<Cross2Icon className="h-4 w-4" />
|
||||||
@ -93,6 +97,7 @@ const ToastTitle = React.forwardRef<
|
|||||||
<ToastPrimitives.Title
|
<ToastPrimitives.Title
|
||||||
ref={ref}
|
ref={ref}
|
||||||
className={cn("text-sm font-semibold [&+div]:text-xs", className)}
|
className={cn("text-sm font-semibold [&+div]:text-xs", className)}
|
||||||
|
tabIndex={-1}
|
||||||
{...props}
|
{...props}
|
||||||
/>
|
/>
|
||||||
))
|
))
|
||||||
@ -106,6 +111,7 @@ const ToastDescription = React.forwardRef<
|
|||||||
ref={ref}
|
ref={ref}
|
||||||
className={cn("text-sm opacity-90", className)}
|
className={cn("text-sm opacity-90", className)}
|
||||||
{...props}
|
{...props}
|
||||||
|
tabIndex={-1}
|
||||||
/>
|
/>
|
||||||
))
|
))
|
||||||
ToastDescription.displayName = ToastPrimitives.Description.displayName
|
ToastDescription.displayName = ToastPrimitives.Description.displayName
|
||||||
|
@ -114,13 +114,15 @@ export function fetchModelInfos(): Promise<ModelInfo[]> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function runPlugin(
|
export async function runPlugin(
|
||||||
|
genMask: boolean,
|
||||||
name: string,
|
name: string,
|
||||||
imageFile: File,
|
imageFile: File,
|
||||||
upscale?: number,
|
upscale?: number,
|
||||||
clicks?: number[][]
|
clicks?: number[][]
|
||||||
) {
|
) {
|
||||||
const imageBase64 = await convertToBase64(imageFile)
|
const imageBase64 = await convertToBase64(imageFile)
|
||||||
const res = await fetch(`${API_ENDPOINT}/run_plugin`, {
|
const p = genMask ? "run_plugin_gen_mask" : "run_plugin_gen_image"
|
||||||
|
const res = await fetch(`${API_ENDPOINT}/${p}`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
@ -156,7 +156,6 @@ type AppAction = {
|
|||||||
setFile: (file: File) => Promise<void>
|
setFile: (file: File) => Promise<void>
|
||||||
setCustomFile: (file: File) => void
|
setCustomFile: (file: File) => void
|
||||||
setIsInpainting: (newValue: boolean) => void
|
setIsInpainting: (newValue: boolean) => void
|
||||||
setIsPluginRunning: (newValue: boolean) => void
|
|
||||||
getIsProcessing: () => boolean
|
getIsProcessing: () => boolean
|
||||||
setBaseBrushSize: (newValue: number) => void
|
setBaseBrushSize: (newValue: number) => void
|
||||||
getBrushSize: () => number
|
getBrushSize: () => number
|
||||||
@ -190,6 +189,7 @@ type AppAction = {
|
|||||||
showPrevMask: () => Promise<void>
|
showPrevMask: () => Promise<void>
|
||||||
hidePrevMask: () => void
|
hidePrevMask: () => void
|
||||||
runRenderablePlugin: (
|
runRenderablePlugin: (
|
||||||
|
genMask: boolean,
|
||||||
pluginName: string,
|
pluginName: string,
|
||||||
params?: PluginParams
|
params?: PluginParams
|
||||||
) => Promise<void>
|
) => Promise<void>
|
||||||
@ -521,28 +521,43 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
},
|
},
|
||||||
|
|
||||||
runRenderablePlugin: async (
|
runRenderablePlugin: async (
|
||||||
|
genMask: boolean,
|
||||||
pluginName: string,
|
pluginName: string,
|
||||||
params: PluginParams = { upscale: 1 }
|
params: PluginParams = { upscale: 1 }
|
||||||
) => {
|
) => {
|
||||||
const { renders, lineGroups } = get().editorState
|
const { renders, lineGroups } = get().editorState
|
||||||
set((state) => {
|
set((state) => {
|
||||||
state.isInpainting = true
|
state.isPluginRunning = true
|
||||||
})
|
})
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const start = new Date()
|
const start = new Date()
|
||||||
const targetFile = await get().getCurrentTargetFile()
|
const targetFile = await get().getCurrentTargetFile()
|
||||||
const res = await runPlugin(pluginName, targetFile, params.upscale)
|
const res = await runPlugin(
|
||||||
|
genMask,
|
||||||
|
pluginName,
|
||||||
|
targetFile,
|
||||||
|
params.upscale
|
||||||
|
)
|
||||||
const { blob } = res
|
const { blob } = res
|
||||||
const newRender = new Image()
|
|
||||||
await loadImage(newRender, blob)
|
if (!genMask) {
|
||||||
get().setImageSize(newRender.width, newRender.height)
|
const newRender = new Image()
|
||||||
const newRenders = [...renders, newRender]
|
await loadImage(newRender, blob)
|
||||||
const newLineGroups = [...lineGroups, []]
|
get().setImageSize(newRender.width, newRender.height)
|
||||||
get().updateEditorState({
|
const newRenders = [...renders, newRender]
|
||||||
renders: newRenders,
|
const newLineGroups = [...lineGroups, []]
|
||||||
lineGroups: newLineGroups,
|
get().updateEditorState({
|
||||||
})
|
renders: newRenders,
|
||||||
|
lineGroups: newLineGroups,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const newMask = new Image()
|
||||||
|
await loadImage(newMask, blob)
|
||||||
|
get().updateInteractiveSegState({
|
||||||
|
interactiveSegMask: newMask,
|
||||||
|
})
|
||||||
|
}
|
||||||
const end = new Date()
|
const end = new Date()
|
||||||
const time = end.getTime() - start.getTime()
|
const time = end.getTime() - start.getTime()
|
||||||
toast({
|
toast({
|
||||||
@ -555,7 +570,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
set((state) => {
|
set((state) => {
|
||||||
state.isInpainting = false
|
state.isPluginRunning = false
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -803,11 +818,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
state.isInpainting = newValue
|
state.isInpainting = newValue
|
||||||
}),
|
}),
|
||||||
|
|
||||||
setIsPluginRunning: (newValue: boolean) =>
|
|
||||||
set((state) => {
|
|
||||||
state.isPluginRunning = newValue
|
|
||||||
}),
|
|
||||||
|
|
||||||
setFile: async (file: File) => {
|
setFile: async (file: File) => {
|
||||||
if (get().settings.enableAutoExtractPrompt) {
|
if (get().settings.enableAutoExtractPrompt) {
|
||||||
try {
|
try {
|
||||||
|
@ -6,8 +6,14 @@ export interface Filename {
|
|||||||
mtime: number
|
mtime: number
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface PluginInfo {
|
||||||
|
name: string
|
||||||
|
support_gen_image: boolean
|
||||||
|
support_gen_mask: boolean
|
||||||
|
}
|
||||||
|
|
||||||
export interface ServerConfig {
|
export interface ServerConfig {
|
||||||
plugins: string[]
|
plugins: PluginInfo[]
|
||||||
enableFileManager: boolean
|
enableFileManager: boolean
|
||||||
enableAutoSaving: boolean
|
enableAutoSaving: boolean
|
||||||
enableControlnet: boolean
|
enableControlnet: boolean
|
||||||
|
Loading…
Reference in New Issue
Block a user