add plugin dep check

This commit is contained in:
Qing 2023-03-26 12:37:58 +08:00
parent 1433d21b9f
commit d938f2da3c
7 changed files with 48 additions and 14 deletions

View File

@ -1,3 +1,4 @@
from .interactive_seg import InteractiveSeg, Click
from .remove_bg import RemoveBG
from .realesrgan import RealESRGANUpscaler
from .gif import MakeGIF

View File

@ -0,0 +1,15 @@
from loguru import logger
class BasePlugin:
def __init__(self):
err_msg = self.check_dep()
if err_msg:
logger.error(err_msg)
exit(-1)
def __call__(self, rgb_np_img, files, form):
...
def check_dep(self):
...

View File

@ -4,6 +4,7 @@ import math
from PIL import Image, ImageDraw
from lama_cleaner.helper import load_img
from lama_cleaner.plugins.base_plugin import BasePlugin
def keep_ratio_resize(img, size, resample=Image.BILINEAR):
@ -135,7 +136,7 @@ def make_compare_gif(
return img_byte_arr.getvalue()
class MakeGIF:
class MakeGIF(BasePlugin):
name = "MakeGIF"
def __call__(self, rgb_np_img, files, form):

View File

@ -1,5 +1,4 @@
import json
import json
import os
from typing import Tuple, List
@ -14,6 +13,7 @@ from lama_cleaner.helper import (
load_jit_model,
load_img,
)
from lama_cleaner.plugins.base_plugin import BasePlugin
class Click(BaseModel):
@ -195,10 +195,11 @@ INTERACTIVE_SEG_MODEL_MD5 = os.environ.get(
)
class InteractiveSeg:
class InteractiveSeg(BasePlugin):
name = "InteractiveSeg"
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
super().__init__()
device = torch.device("cpu")
model = load_jit_model(
INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5

View File

@ -4,6 +4,7 @@ import cv2
from loguru import logger
from lama_cleaner.helper import download_model
from lama_cleaner.plugins.base_plugin import BasePlugin
class RealESRGANModelName(str, Enum):
@ -15,7 +16,7 @@ class RealESRGANModelName(str, Enum):
RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
class RealESRGANUpscaler:
class RealESRGANUpscaler(BasePlugin):
name = "RealESRGAN"
def __init__(self, name, device):
@ -84,7 +85,7 @@ class RealESRGANUpscaler:
def __call__(self, rgb_np_img, files, form):
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
scale = float(form['upscale'])
scale = float(form["upscale"])
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}")
result = self.forward(bgr_np_img, scale)
logger.info(f"RealESRGAN output shape: {result.shape}")
@ -94,3 +95,9 @@ class RealESRGANUpscaler:
# 输出是 BGR
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
return upsampled
def check_dep(self):
try:
import realesrgan
except ImportError:
return "RealESRGAN is not installed, please install it first. pip install realesrgan"

View File

@ -1,11 +1,14 @@
import cv2
import numpy as np
from lama_cleaner.plugins.base_plugin import BasePlugin
class RemoveBG:
class RemoveBG(BasePlugin):
name = "RemoveBG"
def __init__(self):
super().__init__()
from rembg import new_session
self.session = new_session(model_name="u2net")
@ -20,3 +23,11 @@ class RemoveBG:
# return BGRA image
output = remove(bgr_np_img, session=self.session)
return output
def check_dep(self):
try:
import rembg
except ImportError:
return (
"RemoveBG is not installed, please install it first. pip install rembg"
)

View File

@ -1,28 +1,26 @@
#!/usr/bin/env python3
import imghdr
import io
import logging
import multiprocessing
import os
import random
import time
import imghdr
from pathlib import Path
from typing import Union
from PIL import Image
import cv2
import torch
import numpy as np
import torch
from PIL import Image
from loguru import logger
from lama_cleaner.const import SD15_MODELS
from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
from lama_cleaner.plugins.gif import MakeGIF
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler, MakeGIF
from lama_cleaner.schema import Config
from lama_cleaner.file_manager import FileManager
try:
torch._C._jit_override_can_fuse_on_cpu(False)