diff --git a/lama_cleaner/plugins/__init__.py b/lama_cleaner/plugins/__init__.py index d912a0a..87ae130 100644 --- a/lama_cleaner/plugins/__init__.py +++ b/lama_cleaner/plugins/__init__.py @@ -1,3 +1,4 @@ from .interactive_seg import InteractiveSeg, Click from .remove_bg import RemoveBG from .realesrgan import RealESRGANUpscaler +from .gif import MakeGIF diff --git a/lama_cleaner/plugins/base_plugin.py b/lama_cleaner/plugins/base_plugin.py new file mode 100644 index 0000000..39d7491 --- /dev/null +++ b/lama_cleaner/plugins/base_plugin.py @@ -0,0 +1,15 @@ +from loguru import logger + + +class BasePlugin: + def __init__(self): + err_msg = self.check_dep() + if err_msg: + logger.error(err_msg) + exit(-1) + + def __call__(self, rgb_np_img, files, form): + ... + + def check_dep(self): + ... diff --git a/lama_cleaner/plugins/gif.py b/lama_cleaner/plugins/gif.py index 3f69901..d9921c3 100644 --- a/lama_cleaner/plugins/gif.py +++ b/lama_cleaner/plugins/gif.py @@ -4,6 +4,7 @@ import math from PIL import Image, ImageDraw from lama_cleaner.helper import load_img +from lama_cleaner.plugins.base_plugin import BasePlugin def keep_ratio_resize(img, size, resample=Image.BILINEAR): @@ -117,7 +118,7 @@ def make_compare_gif( [(right, 0), (right, height)], width=splitter_width, fill=splitter_color ) images.append(new_frame) - + for _ in range(30): images.append(clean_img) @@ -135,7 +136,7 @@ def make_compare_gif( return img_byte_arr.getvalue() -class MakeGIF: +class MakeGIF(BasePlugin): name = "MakeGIF" def __call__(self, rgb_np_img, files, form): diff --git a/lama_cleaner/plugins/interactive_seg.py b/lama_cleaner/plugins/interactive_seg.py index 4f0c21c..8fe5720 100644 --- a/lama_cleaner/plugins/interactive_seg.py +++ b/lama_cleaner/plugins/interactive_seg.py @@ -1,5 +1,4 @@ import json -import json import os from typing import Tuple, List @@ -14,6 +13,7 @@ from lama_cleaner.helper import ( load_jit_model, load_img, ) +from lama_cleaner.plugins.base_plugin import BasePlugin class Click(BaseModel): @@ -195,10 +195,11 @@ INTERACTIVE_SEG_MODEL_MD5 = os.environ.get( ) -class InteractiveSeg: +class InteractiveSeg(BasePlugin): name = "InteractiveSeg" def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3): + super().__init__() device = torch.device("cpu") model = load_jit_model( INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5 diff --git a/lama_cleaner/plugins/realesrgan.py b/lama_cleaner/plugins/realesrgan.py index 6652764..6876b82 100644 --- a/lama_cleaner/plugins/realesrgan.py +++ b/lama_cleaner/plugins/realesrgan.py @@ -4,6 +4,7 @@ import cv2 from loguru import logger from lama_cleaner.helper import download_model +from lama_cleaner.plugins.base_plugin import BasePlugin class RealESRGANModelName(str, Enum): @@ -15,7 +16,7 @@ class RealESRGANModelName(str, Enum): RealESRGANModelNameList = [e.value for e in RealESRGANModelName] -class RealESRGANUpscaler: +class RealESRGANUpscaler(BasePlugin): name = "RealESRGAN" def __init__(self, name, device): @@ -84,7 +85,7 @@ class RealESRGANUpscaler: def __call__(self, rgb_np_img, files, form): bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) - scale = float(form['upscale']) + scale = float(form["upscale"]) logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}") result = self.forward(bgr_np_img, scale) logger.info(f"RealESRGAN output shape: {result.shape}") @@ -94,3 +95,9 @@ class RealESRGANUpscaler: # 输出是 BGR upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0] return upsampled + + def check_dep(self): + try: + import realesrgan + except ImportError: + return "RealESRGAN is not installed, please install it first. pip install realesrgan" diff --git a/lama_cleaner/plugins/remove_bg.py b/lama_cleaner/plugins/remove_bg.py index e380bfd..15dfc5b 100644 --- a/lama_cleaner/plugins/remove_bg.py +++ b/lama_cleaner/plugins/remove_bg.py @@ -1,11 +1,14 @@ import cv2 import numpy as np +from lama_cleaner.plugins.base_plugin import BasePlugin -class RemoveBG: + +class RemoveBG(BasePlugin): name = "RemoveBG" def __init__(self): + super().__init__() from rembg import new_session self.session = new_session(model_name="u2net") @@ -20,3 +23,11 @@ class RemoveBG: # return BGRA image output = remove(bgr_np_img, session=self.session) return output + + def check_dep(self): + try: + import rembg + except ImportError: + return ( + "RemoveBG is not installed, please install it first. pip install rembg" + ) diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index ead7190..7216410 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -1,28 +1,26 @@ #!/usr/bin/env python3 +import imghdr import io import logging import multiprocessing import os import random import time -import imghdr from pathlib import Path -from typing import Union -from PIL import Image import cv2 -import torch import numpy as np +import torch +from PIL import Image from loguru import logger from lama_cleaner.const import SD15_MODELS +from lama_cleaner.file_manager import FileManager from lama_cleaner.model.utils import torch_gc from lama_cleaner.model_manager import ModelManager -from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler -from lama_cleaner.plugins.gif import MakeGIF +from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler, MakeGIF from lama_cleaner.schema import Config -from lama_cleaner.file_manager import FileManager try: torch._C._jit_override_can_fuse_on_cpu(False)