add plugin dep check
This commit is contained in:
parent
1433d21b9f
commit
d938f2da3c
@ -1,3 +1,4 @@
|
|||||||
from .interactive_seg import InteractiveSeg, Click
|
from .interactive_seg import InteractiveSeg, Click
|
||||||
from .remove_bg import RemoveBG
|
from .remove_bg import RemoveBG
|
||||||
from .realesrgan import RealESRGANUpscaler
|
from .realesrgan import RealESRGANUpscaler
|
||||||
|
from .gif import MakeGIF
|
||||||
|
15
lama_cleaner/plugins/base_plugin.py
Normal file
15
lama_cleaner/plugins/base_plugin.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class BasePlugin:
|
||||||
|
def __init__(self):
|
||||||
|
err_msg = self.check_dep()
|
||||||
|
if err_msg:
|
||||||
|
logger.error(err_msg)
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
def __call__(self, rgb_np_img, files, form):
|
||||||
|
...
|
||||||
|
|
||||||
|
def check_dep(self):
|
||||||
|
...
|
@ -4,6 +4,7 @@ import math
|
|||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
from lama_cleaner.helper import load_img
|
from lama_cleaner.helper import load_img
|
||||||
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
|
|
||||||
|
|
||||||
def keep_ratio_resize(img, size, resample=Image.BILINEAR):
|
def keep_ratio_resize(img, size, resample=Image.BILINEAR):
|
||||||
@ -117,7 +118,7 @@ def make_compare_gif(
|
|||||||
[(right, 0), (right, height)], width=splitter_width, fill=splitter_color
|
[(right, 0), (right, height)], width=splitter_width, fill=splitter_color
|
||||||
)
|
)
|
||||||
images.append(new_frame)
|
images.append(new_frame)
|
||||||
|
|
||||||
for _ in range(30):
|
for _ in range(30):
|
||||||
images.append(clean_img)
|
images.append(clean_img)
|
||||||
|
|
||||||
@ -135,7 +136,7 @@ def make_compare_gif(
|
|||||||
return img_byte_arr.getvalue()
|
return img_byte_arr.getvalue()
|
||||||
|
|
||||||
|
|
||||||
class MakeGIF:
|
class MakeGIF(BasePlugin):
|
||||||
name = "MakeGIF"
|
name = "MakeGIF"
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, files, form):
|
def __call__(self, rgb_np_img, files, form):
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import json
|
import json
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from typing import Tuple, List
|
from typing import Tuple, List
|
||||||
|
|
||||||
@ -14,6 +13,7 @@ from lama_cleaner.helper import (
|
|||||||
load_jit_model,
|
load_jit_model,
|
||||||
load_img,
|
load_img,
|
||||||
)
|
)
|
||||||
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
|
|
||||||
|
|
||||||
class Click(BaseModel):
|
class Click(BaseModel):
|
||||||
@ -195,10 +195,11 @@ INTERACTIVE_SEG_MODEL_MD5 = os.environ.get(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InteractiveSeg:
|
class InteractiveSeg(BasePlugin):
|
||||||
name = "InteractiveSeg"
|
name = "InteractiveSeg"
|
||||||
|
|
||||||
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
|
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
|
||||||
|
super().__init__()
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
model = load_jit_model(
|
model = load_jit_model(
|
||||||
INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5
|
INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5
|
||||||
|
@ -4,6 +4,7 @@ import cv2
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.helper import download_model
|
from lama_cleaner.helper import download_model
|
||||||
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
|
|
||||||
|
|
||||||
class RealESRGANModelName(str, Enum):
|
class RealESRGANModelName(str, Enum):
|
||||||
@ -15,7 +16,7 @@ class RealESRGANModelName(str, Enum):
|
|||||||
RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
|
RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
|
||||||
|
|
||||||
|
|
||||||
class RealESRGANUpscaler:
|
class RealESRGANUpscaler(BasePlugin):
|
||||||
name = "RealESRGAN"
|
name = "RealESRGAN"
|
||||||
|
|
||||||
def __init__(self, name, device):
|
def __init__(self, name, device):
|
||||||
@ -84,7 +85,7 @@ class RealESRGANUpscaler:
|
|||||||
|
|
||||||
def __call__(self, rgb_np_img, files, form):
|
def __call__(self, rgb_np_img, files, form):
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
scale = float(form['upscale'])
|
scale = float(form["upscale"])
|
||||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}")
|
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}")
|
||||||
result = self.forward(bgr_np_img, scale)
|
result = self.forward(bgr_np_img, scale)
|
||||||
logger.info(f"RealESRGAN output shape: {result.shape}")
|
logger.info(f"RealESRGAN output shape: {result.shape}")
|
||||||
@ -94,3 +95,9 @@ class RealESRGANUpscaler:
|
|||||||
# 输出是 BGR
|
# 输出是 BGR
|
||||||
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
|
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
|
||||||
return upsampled
|
return upsampled
|
||||||
|
|
||||||
|
def check_dep(self):
|
||||||
|
try:
|
||||||
|
import realesrgan
|
||||||
|
except ImportError:
|
||||||
|
return "RealESRGAN is not installed, please install it first. pip install realesrgan"
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
|
|
||||||
class RemoveBG:
|
|
||||||
|
class RemoveBG(BasePlugin):
|
||||||
name = "RemoveBG"
|
name = "RemoveBG"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
from rembg import new_session
|
from rembg import new_session
|
||||||
|
|
||||||
self.session = new_session(model_name="u2net")
|
self.session = new_session(model_name="u2net")
|
||||||
@ -20,3 +23,11 @@ class RemoveBG:
|
|||||||
# return BGRA image
|
# return BGRA image
|
||||||
output = remove(bgr_np_img, session=self.session)
|
output = remove(bgr_np_img, session=self.session)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def check_dep(self):
|
||||||
|
try:
|
||||||
|
import rembg
|
||||||
|
except ImportError:
|
||||||
|
return (
|
||||||
|
"RemoveBG is not installed, please install it first. pip install rembg"
|
||||||
|
)
|
||||||
|
@ -1,28 +1,26 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import imghdr
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import imghdr
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.const import SD15_MODELS
|
from lama_cleaner.const import SD15_MODELS
|
||||||
|
from lama_cleaner.file_manager import FileManager
|
||||||
from lama_cleaner.model.utils import torch_gc
|
from lama_cleaner.model.utils import torch_gc
|
||||||
from lama_cleaner.model_manager import ModelManager
|
from lama_cleaner.model_manager import ModelManager
|
||||||
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
|
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler, MakeGIF
|
||||||
from lama_cleaner.plugins.gif import MakeGIF
|
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
from lama_cleaner.file_manager import FileManager
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||||
|
Loading…
Reference in New Issue
Block a user