diff --git a/lama_cleaner/file_manager/file_manager.py b/lama_cleaner/file_manager/file_manager.py
index 884c148..73d3897 100644
--- a/lama_cleaner/file_manager/file_manager.py
+++ b/lama_cleaner/file_manager/file_manager.py
@@ -7,8 +7,8 @@ import time
 from io import BytesIO
 from pathlib import Path

 import numpy as np
-from watchdog.events import FileSystemEventHandler
-from watchdog.observers import Observer
+# from watchdog.events import FileSystemEventHandler
+# from watchdog.observers import Observer
 from PIL import Image, ImageOps, PngImagePlugin
 from loguru import logger
@@ -19,7 +19,7 @@ from .storage_backends import FilesystemStorageBackend
 from .utils import aspect_to_string, generate_filename, glob_img


-class FileManager(FileSystemEventHandler):
+class FileManager:
     def __init__(self, app=None):
         self.app = app
         self._default_root_directory = "media"
@@ -43,19 +43,19 @@ class FileManager(FileSystemEventHandler):
             "output": datetime.utcnow(),
         }

-    def start(self):
-        self.image_dir_filenames = self._media_names(self.root_directory)
-        self.output_dir_filenames = self._media_names(self.output_dir)
-
-        logger.info(f"Start watching image directory: {self.root_directory}")
-        self.image_dir_observer = Observer()
-        self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
-        self.image_dir_observer.start()
-
-        logger.info(f"Start watching output directory: {self.output_dir}")
-        self.output_dir_observer = Observer()
-        self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
-        self.output_dir_observer.start()
+    # def start(self):
+    #     self.image_dir_filenames = self._media_names(self.root_directory)
+    #     self.output_dir_filenames = self._media_names(self.output_dir)
+    #
+    #     logger.info(f"Start watching image directory: {self.root_directory}")
+    #     self.image_dir_observer = Observer()
+    #     self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
+    #     self.image_dir_observer.start()
+    #
+    #     logger.info(f"Start watching output directory: {self.output_dir}")
+    #     self.output_dir_observer = Observer()
+    #     self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
+    #     self.output_dir_observer.start()

     def on_modified(self, event):
         if not os.path.isdir(event.src_path):
diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py
index c859881..ed55dc3 100644
--- a/lama_cleaner/parse_args.py
+++ b/lama_cleaner/parse_args.py
@@ -69,9 +69,33 @@
         help="Disable model switch in frontend",
     )
     parser.add_argument(
-        "--quality", default=95, type=int, help=QUALITY_HELP,
+        "--quality",
+        default=95,
+        type=int,
+        help=QUALITY_HELP,
     )

+    # Plugins
+    parser.add_argument(
+        "--enable-interactive-seg",
+        action="store_true",
+        help="Enable interactive segmentation. Always run on CPU",
+    )
+    parser.add_argument(
+        "--enable-remove-bg",
+        action="store_true",
+        help="Enable remove background. Always run on CPU",
+    )
+    parser.add_argument(
+        "--enable-realesrgan",
+        action="store_true",
+        help="Enable realesrgan super resolution",
+    )
+    parser.add_argument(
+        "--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda"]
+    )
+
+    #########
     # useless args
     parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS)
     parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS)
@@ -123,6 +147,7 @@
     if args.model not in SD15_MODELS:
         logger.warning(f"--sd_controlnet only support {SD15_MODELS}")

+    os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR
     if args.model_dir and args.model_dir is not None:
         if os.path.isfile(args.model_dir):
             parser.error(f"invalid --model-dir: {args.model_dir} is a file")
@@ -132,6 +157,7 @@
             Path(args.model_dir).mkdir(exist_ok=True, parents=True)

         os.environ["XDG_CACHE_HOME"] = args.model_dir
+        os.environ["U2NET_HOME"] = args.model_dir

     if args.input and args.input is not None:
         if not os.path.exists(args.input):
diff --git a/lama_cleaner/plugins/__init__.py b/lama_cleaner/plugins/__init__.py
new file mode 100644
index 0000000..711b023
--- /dev/null
+++ b/lama_cleaner/plugins/__init__.py
@@ -0,0 +1,3 @@
+from .interactive_seg import InteractiveSeg, Click
+from .remove_bg import RemoveBG
+from .upscale import RealESRGANUpscaler
diff --git a/lama_cleaner/interactive_seg.py b/lama_cleaner/plugins/interactive_seg.py
similarity index 65%
rename from lama_cleaner/interactive_seg.py
rename to lama_cleaner/plugins/interactive_seg.py
index 9b07349..4f0c21c 100644
--- a/lama_cleaner/interactive_seg.py
+++ b/lama_cleaner/plugins/interactive_seg.py
@@ -1,14 +1,19 @@
+import json
+import json
 import os
+from typing import Tuple, List

 import cv2
-from typing import Tuple, List
+import numpy as np
 import torch
 import torch.nn.functional as F
 from loguru import logger
 from pydantic import BaseModel
-import numpy as np

-from lama_cleaner.helper import only_keep_largest_contour, load_jit_model
+from lama_cleaner.helper import (
+    load_jit_model,
+    load_img,
+)


 class Click(BaseModel):
@@ -21,11 +26,11 @@
     def coords_and_indx(self):
         return (*self.coords, self.indx)

-    def scale(self, x_ratio: float, y_ratio: float) -> 'Click':
+    def scale(self, x_ratio: float, y_ratio: float) -> "Click":
         return Click(
             coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio),
             is_positive=self.is_positive,
-            indx=self.indx
+            indx=self.indx,
         )

@@ -40,21 +45,32 @@ class ResizeTrans:
         image_height, image_width = image_nd.shape[2:4]
         self.image_height = image_height
         self.image_width = image_width
-        image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode='bilinear', align_corners=True)
+        image_nd_r = F.interpolate(
+            image_nd,
+            (self.crop_height, self.crop_width),
+            mode="bilinear",
+            align_corners=True,
+        )

         y_ratio = self.crop_height / image_height
         x_ratio = self.crop_width / image_width

         clicks_lists_resized = []
         for clicks_list in clicks_lists:
-            clicks_list_resized = [click.scale(y_ratio, x_ratio) for click in clicks_list]
+            clicks_list_resized = [
+                click.scale(y_ratio, x_ratio) for click in clicks_list
+            ]
             clicks_lists_resized.append(clicks_list_resized)

         return image_nd_r, clicks_lists_resized

     def inv_transform(self, prob_map):
-        new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear',
-                                     align_corners=True)
+        new_prob_map = F.interpolate(
+            prob_map,
+            (self.image_height, self.image_width),
+            mode="bilinear",
+            align_corners=True,
+        )

         return new_prob_map
@@ -106,8 +122,9 @@ class ISPredictor(object):
         pred = torch.sigmoid(pred_logits)
         pred = self.post_process(pred)

-        prediction = F.interpolate(pred, mode='bilinear', align_corners=True,
-                                   size=image_nd.size()[2:])
+        prediction = F.interpolate(
+            pred, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
+        )

         for t in reversed(transforms):
             prediction = t.inv_transform(prediction)
@@ -121,32 +138,49 @@ class ISPredictor(object):
         pred_mask = pred.cpu().numpy()[0][0]
         # morph_open to remove small noise
         kernel_size = self.open_kernel_size
-        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
+        kernel = cv2.getStructuringElement(
+            cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
+        )
         pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1)

         # Why dilate: make region slightly larger to avoid missing some pixels, this generally works better
         dilate_kernel_size = self.dilate_kernel_size
         if dilate_kernel_size > 1:
-            kernel = cv2.getStructuringElement(cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size))
+            kernel = cv2.getStructuringElement(
+                cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size)
+            )
             pred_mask = cv2.dilate(pred_mask, kernel, 1)
         return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)

     def get_points_nd(self, clicks_lists):
         total_clicks = []
-        num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
-        num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
+        num_pos_clicks = [
+            sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
+        ]
+        num_neg_clicks = [
+            len(clicks_list) - num_pos
+            for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
+        ]
         num_max_points = max(num_pos_clicks + num_neg_clicks)
         if self.net_clicks_limit is not None:
             num_max_points = min(self.net_clicks_limit, num_max_points)
         num_max_points = max(1, num_max_points)

         for clicks_list in clicks_lists:
-            clicks_list = clicks_list[:self.net_clicks_limit]
-            pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
-            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
+            clicks_list = clicks_list[: self.net_clicks_limit]
+            pos_clicks = [
+                click.coords_and_indx for click in clicks_list if click.is_positive
+            ]
+            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
+                (-1, -1, -1)
+            ]

-            neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
-            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
+            neg_clicks = [
+                click.coords_and_indx for click in clicks_list if not click.is_positive
+            ]
+            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
+                (-1, -1, -1)
+            ]
             total_clicks.append(pos_clicks + neg_clicks)

         return torch.tensor(total_clicks, device=self.device)
@@ -156,19 +190,45 @@
 INTERACTIVE_SEG_MODEL_URL = os.environ.get(
     "INTERACTIVE_SEG_MODEL_URL",
     "https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt",
 )
-INTERACTIVE_SEG_MODEL_MD5 = os.environ.get("INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf")
+INTERACTIVE_SEG_MODEL_MD5 = os.environ.get(
+    "INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf"
+)


 class InteractiveSeg:
-    def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
-        device = torch.device('cpu')
-        model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5).eval()
-        self.predictor = ISPredictor(model, device,
-                                     infer_size=infer_size,
-                                     open_kernel_size=open_kernel_size,
-                                     dilate_kernel_size=dilate_kernel_size)
+    name = "InteractiveSeg"

-    def __call__(self, image, clicks, prev_mask=None):
+    def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
+        device = torch.device("cpu")
+        model = load_jit_model(
+            INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5
+        ).eval()
+        self.predictor = ISPredictor(
+            model,
+            device,
+            infer_size=infer_size,
+            open_kernel_size=open_kernel_size,
+            dilate_kernel_size=dilate_kernel_size,
+        )
+
+    def __call__(self, rgb_np_img, files, form):
+        image = rgb_np_img
+        if "mask" in files:
+            mask, _ = load_img(files["mask"].read(), gray=True)
+        else:
+            mask = None
+
+        _clicks = json.loads(form["clicks"])
+        clicks = []
+        for i, click in enumerate(_clicks):
+            clicks.append(
+                Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
+            )
+
+        new_mask = self.forward(image, clicks=clicks, prev_mask=mask)
+        return new_mask
+
+    def forward(self, image, clicks, prev_mask=None):
         """
         Args:
@@ -183,7 +243,7 @@ class InteractiveSeg:
         if prev_mask is None:
             mask = torch.zeros_like(image[:, :1, :, :])
         else:
-            logger.info('InteractiveSeg run with prev_mask')
+            logger.info("InteractiveSeg run with prev_mask")
             mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()

         pred_probs = self.predictor(image, clicks, mask)
diff --git a/lama_cleaner/plugins/remove_bg.py b/lama_cleaner/plugins/remove_bg.py
new file mode 100644
index 0000000..e380bfd
--- /dev/null
+++ b/lama_cleaner/plugins/remove_bg.py
@@ -0,0 +1,22 @@
+import cv2
+import numpy as np
+
+
+class RemoveBG:
+    name = "RemoveBG"
+
+    def __init__(self):
+        from rembg import new_session
+
+        self.session = new_session(model_name="u2net")
+
+    def __call__(self, rgb_np_img, files, form):
+        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
+        return self.forward(bgr_np_img)
+
+    def forward(self, bgr_np_img) -> np.ndarray:
+        from rembg import remove
+
+        # return BGRA image
+        output = remove(bgr_np_img, session=self.session)
+        return output
diff --git a/lama_cleaner/plugins/upscale.py b/lama_cleaner/plugins/upscale.py
new file mode 100644
index 0000000..e4eb8bc
--- /dev/null
+++ b/lama_cleaner/plugins/upscale.py
@@ -0,0 +1,46 @@
+import cv2
+
+from lama_cleaner.helper import download_model
+
+
+class RealESRGANUpscaler:
+    name = "RealESRGAN"
+
+    def __init__(self, device):
+        super().__init__()
+        from basicsr.archs.rrdbnet_arch import RRDBNet
+        from realesrgan import RealESRGANer
+
+        scale = 4
+        model = RRDBNet(
+            num_in_ch=3,
+            num_out_ch=3,
+            num_feat=64,
+            num_block=23,
+            num_grow_ch=32,
+            scale=4,
+        )
+        url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
+        model_md5 = "99ec365d4afad750833258a1a24f44ca"
+        model_path = download_model(url, model_md5)
+
+        self.model = RealESRGANer(
+            scale=scale,
+            model_path=model_path,
+            model=model,
+            half=True if "cuda" in str(device) else False,
+            tile=640,
+            tile_pad=10,
+            pre_pad=10,
+            device=device,
+        )
+
+    def __call__(self, rgb_np_img, files, form):
+        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
+        scale = float(form["scale"])
+        return self.forward(bgr_np_img, scale)
+
+    def forward(self, bgr_np_img, scale: float):
+        # The output is BGR
+        upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
+        return upsampled
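Note: the plugin classes above are duck-typed rather than subclasses of a common base. As a rough illustrative sketch (not part of the patch), the contract that server.py relies on is a class-level `name` plus a `__call__(rgb_np_img, files, form)` returning the processed image as a numpy array; the `MyPlugin` class below is hypothetical:

    import numpy as np

    class MyPlugin:
        # Hypothetical example plugin; `name` is the key used by the /run_plugin/ route.
        name = "MyPlugin"

        def __call__(self, rgb_np_img: np.ndarray, files, form) -> np.ndarray:
            # `files` holds the uploaded request files, `form` the request form fields.
            # A real plugin would transform the image; this one returns it unchanged.
            return rgb_np_img
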
diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py
index a896e64..a19b027 100644
--- a/lama_cleaner/server.py
+++ b/lama_cleaner/server.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python3

 import io
-import json
 import logging
 import multiprocessing
 import os
@@ -16,12 +15,12 @@
 import cv2
 import torch
 import numpy as np
 from loguru import logger
-from watchdog.events import FileSystemEventHandler

 from lama_cleaner.const import SD15_MODELS
-from lama_cleaner.interactive_seg import InteractiveSeg, Click
 from lama_cleaner.make_gif import make_compare_gif
+from lama_cleaner.model.utils import torch_gc
 from lama_cleaner.model_manager import ModelManager
+from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
 from lama_cleaner.schema import Config
 from lama_cleaner.file_manager import FileManager
@@ -85,7 +84,6 @@ CORS(app, expose_headers=["Content-Disposition"])
 model: ModelManager = None
 thumb: FileManager = None
 output_dir: str = None
-interactive_seg_model: InteractiveSeg = None
 device = None
 input_image_path: str = None
 is_disable_model_switch: bool = False
@@ -94,6 +92,7 @@ is_enable_file_manager: bool = False
 is_enable_auto_saving: bool = False
 is_desktop: bool = False
 image_quality: int = 95
+plugins = {}


 def get_image_ext(img_bytes):
@@ -319,35 +318,37 @@ def process():
     return response


-@app.route("/interactive_seg", methods=["POST"])
-def interactive_seg():
-    input = request.files
-    origin_image_bytes = input["image"].read()  # RGB
-    image, _ = load_img(origin_image_bytes)
-    if "mask" in input:
-        mask, _ = load_img(input["mask"].read(), gray=True)
-    else:
-        mask = None
+@app.route("/run_plugin/", methods=["POST"])
+def run_plugin():
+    form = request.form
+    files = request.files

-    _clicks = json.loads(request.form["clicks"])
-    clicks = []
-    for i, click in enumerate(_clicks):
-        clicks.append(
-            Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
-        )
+    name = form["name"]
+    if name not in plugins:
+        return "Plugin not found", 500
+
+    origin_image_bytes = files["image"].read()  # RGB
+    rgb_np_img, _ = load_img(origin_image_bytes)

     start = time.time()
-    new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
-    logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
+    res = plugins[name](rgb_np_img, files, form)
+    logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
+    torch_gc()
+
     response = make_response(
         send_file(
-            io.BytesIO(numpy_to_bytes(new_mask, "png")),
+            io.BytesIO(numpy_to_bytes(res, "png")),
             mimetype=f"image/png",
         )
     )
     return response


+@app.route("/plugins/", methods=["GET"])
+def get_plugins():
+    return list(plugins.keys()), 200
+
+
 @app.route("/model")
 def current_model():
     return model.name, 200
@@ -423,14 +424,21 @@ def set_input_photo():
     return "No Input Image"


-class FSHandler(FileSystemEventHandler):
-    def on_modified(self, event):
-        print("File modified: %s" % event.src_path)
+def build_plugins(args):
+    global plugins
+    if args.enable_interactive_seg:
+        logger.info(f"Initialize {InteractiveSeg.name} plugin")
+        plugins[InteractiveSeg.name] = InteractiveSeg()
+    if args.enable_remove_bg:
+        logger.info(f"Initialize {RemoveBG.name} plugin")
+        plugins[RemoveBG.name] = RemoveBG()
+    if args.enable_realesrgan:
+        logger.info(f"Initialize {RealESRGANUpscaler.name} plugin")
+        plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device)


 def main(args):
     global model
-    global interactive_seg_model
     global device
     global input_image_path
     global is_disable_model_switch
@@ -442,6 +450,8 @@ def main(args):
     global is_controlnet
     global image_quality

+    build_plugins(args)
+
     image_quality = args.quality

     if args.sd_controlnet and args.model in SD15_MODELS:
@@ -496,8 +506,6 @@ def main(args):
         callback=diffuser_callback,
     )

-    interactive_seg_model = InteractiveSeg()
-
     if args.gui:
         app_width, app_height = args.gui_size
         from flaskwebgui import FlaskUI
diff --git a/lama_cleaner/tests/test_interactive_seg.py b/lama_cleaner/tests/test_interactive_seg.py
index 7dcffd6..c8829d1 100644
--- a/lama_cleaner/tests/test_interactive_seg.py
+++ b/lama_cleaner/tests/test_interactive_seg.py
@@ -1,13 +1,12 @@
 from pathlib import Path

-import pytest
 import cv2
 import numpy as np

-from lama_cleaner.interactive_seg import InteractiveSeg, Click
+from lama_cleaner.plugins import InteractiveSeg, Click

 current_dir = Path(__file__).parent.absolute().resolve()
-save_dir = current_dir / 'result'
+save_dir = current_dir / "result"
 save_dir.mkdir(exist_ok=True, parents=True)
 img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"

@@ -15,17 +14,22 @@ img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
 def test_interactive_seg():
     interactive_seg_model = InteractiveSeg()
     img = cv2.imread(str(img_p))
-    pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)])
+    pred = interactive_seg_model.forward(
+        img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)]
+    )
     cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred)


 def test_interactive_seg_with_negative_click():
     interactive_seg_model = InteractiveSeg()
     img = cv2.imread(str(img_p))
-    pred = interactive_seg_model(img, clicks=[
-        Click(coords=(256, 256), indx=0, is_positive=True),
-        Click(coords=(384, 256), indx=1, is_positive=False)
-    ])
+    pred = interactive_seg_model.forward(
+        img,
+        clicks=[
+            Click(coords=(256, 256), indx=0, is_positive=True),
+            Click(coords=(384, 256), indx=1, is_positive=False),
+        ],
+    )
     cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred)


@@ -33,5 +37,7 @@ def test_interactive_seg_with_prev_mask():
     interactive_seg_model = InteractiveSeg()
     img = cv2.imread(str(img_p))
     mask = np.zeros_like(img)[:, :, 0]
-    pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask)
+    pred = interactive_seg_model.forward(
+        img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask
+    )
     cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred)
diff --git a/lama_cleaner/tests/test_model_md5.py b/lama_cleaner/tests/test_model_md5.py
index 7480fcd..3bec003 100644
--- a/lama_cleaner/tests/test_model_md5.py
+++ b/lama_cleaner/tests/test_model_md5.py
@@ -1,10 +1,5 @@
-import os
-import tempfile
-from pathlib import Path
-
-
 def test_load_model():
-    from lama_cleaner.interactive_seg import InteractiveSeg
+    from lama_cleaner.plugins import InteractiveSeg
     from lama_cleaner.model_manager import ModelManager

     interactive_seg_model = InteractiveSeg()
diff --git a/lama_cleaner/tests/test_plugins.py b/lama_cleaner/tests/test_plugins.py
new file mode 100644
index 0000000..0634a5d
--- /dev/null
+++ b/lama_cleaner/tests/test_plugins.py
@@ -0,0 +1,27 @@
+from pathlib import Path
+
+import cv2
+
+from lama_cleaner.plugins import RemoveBG, RealESRGANUpscaler
+
+current_dir = Path(__file__).parent.absolute().resolve()
+save_dir = current_dir / "result"
+save_dir.mkdir(exist_ok=True, parents=True)
+img_p = current_dir / "bunny.jpeg"
+
+
+def test_remove_bg():
+    model = RemoveBG()
+    img = cv2.imread(str(img_p))
+    res = model.forward(img)
+    cv2.imwrite(str(save_dir / "test_remove_bg.png"), res)
+
+
+def test_upscale():
+    model = RealESRGANUpscaler("cpu")
+    img = cv2.imread(str(img_p))
+    res = model.forward(img, 2)
+    cv2.imwrite(str(save_dir / "test_upscale_x2.png"), res)
+
+    res = model.forward(img, 4)
+    cv2.imwrite(str(save_dir / "test_upscale_x4.png"), res)
diff --git a/requirements.txt b/requirements.txt
index 27127f9..6a530b2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,6 @@ markupsafe==2.0.1
 scikit-image==0.19.3
 diffusers[torch]==0.14.0
 transformers>=4.25.1
-watchdog==2.2.1
 gradio
 piexif==1.1.3
 safetensors
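For reference, a minimal sketch of how the new /run_plugin/ endpoint could be exercised once a plugin is enabled (for example with --enable-remove-bg); the host, port and file names below are assumptions, not part of the patch:

    import requests

    # POST an image to the RemoveBG plugin; the server responds with PNG bytes.
    with open("bunny.jpeg", "rb") as f:
        resp = requests.post(
            "http://127.0.0.1:8080/run_plugin/",  # assumed local server and default port
            data={"name": "RemoveBG"},            # must match a key registered in `plugins`
            files={"image": f},
        )
    resp.raise_for_status()
    with open("bunny_no_bg.png", "wb") as out:
        out.write(resp.content)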