add plugins
This commit is contained in:
parent
b48d964c2c
commit
5a38d28ad1
@ -7,8 +7,8 @@ import time
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from watchdog.observers import Observer
|
||||
# from watchdog.events import FileSystemEventHandler
|
||||
# from watchdog.observers import Observer
|
||||
|
||||
from PIL import Image, ImageOps, PngImagePlugin
|
||||
from loguru import logger
|
||||
@ -19,7 +19,7 @@ from .storage_backends import FilesystemStorageBackend
|
||||
from .utils import aspect_to_string, generate_filename, glob_img
|
||||
|
||||
|
||||
class FileManager(FileSystemEventHandler):
|
||||
class FileManager:
|
||||
def __init__(self, app=None):
|
||||
self.app = app
|
||||
self._default_root_directory = "media"
|
||||
@ -43,19 +43,19 @@ class FileManager(FileSystemEventHandler):
|
||||
"output": datetime.utcnow(),
|
||||
}
|
||||
|
||||
def start(self):
|
||||
self.image_dir_filenames = self._media_names(self.root_directory)
|
||||
self.output_dir_filenames = self._media_names(self.output_dir)
|
||||
|
||||
logger.info(f"Start watching image directory: {self.root_directory}")
|
||||
self.image_dir_observer = Observer()
|
||||
self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
|
||||
self.image_dir_observer.start()
|
||||
|
||||
logger.info(f"Start watching output directory: {self.output_dir}")
|
||||
self.output_dir_observer = Observer()
|
||||
self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
|
||||
self.output_dir_observer.start()
|
||||
# def start(self):
|
||||
# self.image_dir_filenames = self._media_names(self.root_directory)
|
||||
# self.output_dir_filenames = self._media_names(self.output_dir)
|
||||
#
|
||||
# logger.info(f"Start watching image directory: {self.root_directory}")
|
||||
# self.image_dir_observer = Observer()
|
||||
# self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
|
||||
# self.image_dir_observer.start()
|
||||
#
|
||||
# logger.info(f"Start watching output directory: {self.output_dir}")
|
||||
# self.output_dir_observer = Observer()
|
||||
# self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
|
||||
# self.output_dir_observer.start()
|
||||
|
||||
def on_modified(self, event):
|
||||
if not os.path.isdir(event.src_path):
|
||||
|
@ -69,9 +69,33 @@ def parse_args():
|
||||
help="Disable model switch in frontend",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quality", default=95, type=int, help=QUALITY_HELP,
|
||||
"--quality",
|
||||
default=95,
|
||||
type=int,
|
||||
help=QUALITY_HELP,
|
||||
)
|
||||
|
||||
# Plugins
|
||||
parser.add_argument(
|
||||
"--enable-interactive-seg",
|
||||
action="store_true",
|
||||
help="Enable interactive segmentation. Always run on CPU",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-remove-bg",
|
||||
action="store_true",
|
||||
help="Enable remove background. Always run on CPU",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-realesrgan",
|
||||
action="store_true",
|
||||
help="Enable realesrgan super resolution",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda"]
|
||||
)
|
||||
#########
|
||||
|
||||
# useless args
|
||||
parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS)
|
||||
parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS)
|
||||
@ -123,6 +147,7 @@ def parse_args():
|
||||
if args.model not in SD15_MODELS:
|
||||
logger.warning(f"--sd_controlnet only support {SD15_MODELS}")
|
||||
|
||||
os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR
|
||||
if args.model_dir and args.model_dir is not None:
|
||||
if os.path.isfile(args.model_dir):
|
||||
parser.error(f"invalid --model-dir: {args.model_dir} is a file")
|
||||
@ -132,6 +157,7 @@ def parse_args():
|
||||
Path(args.model_dir).mkdir(exist_ok=True, parents=True)
|
||||
|
||||
os.environ["XDG_CACHE_HOME"] = args.model_dir
|
||||
os.environ["U2NET_HOME"] = args.model_dir
|
||||
|
||||
if args.input and args.input is not None:
|
||||
if not os.path.exists(args.input):
|
||||
|
3
lama_cleaner/plugins/__init__.py
Normal file
3
lama_cleaner/plugins/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .interactive_seg import InteractiveSeg, Click
|
||||
from .remove_bg import RemoveBG
|
||||
from .upscale import RealESRGANUpscaler
|
@ -1,14 +1,19 @@
|
||||
import json
|
||||
import json
|
||||
import os
|
||||
from typing import Tuple, List
|
||||
|
||||
import cv2
|
||||
from typing import Tuple, List
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
import numpy as np
|
||||
|
||||
from lama_cleaner.helper import only_keep_largest_contour, load_jit_model
|
||||
from lama_cleaner.helper import (
|
||||
load_jit_model,
|
||||
load_img,
|
||||
)
|
||||
|
||||
|
||||
class Click(BaseModel):
|
||||
@ -21,11 +26,11 @@ class Click(BaseModel):
|
||||
def coords_and_indx(self):
|
||||
return (*self.coords, self.indx)
|
||||
|
||||
def scale(self, x_ratio: float, y_ratio: float) -> 'Click':
|
||||
def scale(self, x_ratio: float, y_ratio: float) -> "Click":
|
||||
return Click(
|
||||
coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio),
|
||||
is_positive=self.is_positive,
|
||||
indx=self.indx
|
||||
indx=self.indx,
|
||||
)
|
||||
|
||||
|
||||
@ -40,21 +45,32 @@ class ResizeTrans:
|
||||
image_height, image_width = image_nd.shape[2:4]
|
||||
self.image_height = image_height
|
||||
self.image_width = image_width
|
||||
image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode='bilinear', align_corners=True)
|
||||
image_nd_r = F.interpolate(
|
||||
image_nd,
|
||||
(self.crop_height, self.crop_width),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
y_ratio = self.crop_height / image_height
|
||||
x_ratio = self.crop_width / image_width
|
||||
|
||||
clicks_lists_resized = []
|
||||
for clicks_list in clicks_lists:
|
||||
clicks_list_resized = [click.scale(y_ratio, x_ratio) for click in clicks_list]
|
||||
clicks_list_resized = [
|
||||
click.scale(y_ratio, x_ratio) for click in clicks_list
|
||||
]
|
||||
clicks_lists_resized.append(clicks_list_resized)
|
||||
|
||||
return image_nd_r, clicks_lists_resized
|
||||
|
||||
def inv_transform(self, prob_map):
|
||||
new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear',
|
||||
align_corners=True)
|
||||
new_prob_map = F.interpolate(
|
||||
prob_map,
|
||||
(self.image_height, self.image_width),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
return new_prob_map
|
||||
|
||||
@ -106,8 +122,9 @@ class ISPredictor(object):
|
||||
pred = torch.sigmoid(pred_logits)
|
||||
pred = self.post_process(pred)
|
||||
|
||||
prediction = F.interpolate(pred, mode='bilinear', align_corners=True,
|
||||
size=image_nd.size()[2:])
|
||||
prediction = F.interpolate(
|
||||
pred, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
|
||||
)
|
||||
|
||||
for t in reversed(transforms):
|
||||
prediction = t.inv_transform(prediction)
|
||||
@ -121,32 +138,49 @@ class ISPredictor(object):
|
||||
pred_mask = pred.cpu().numpy()[0][0]
|
||||
# morph_open to remove small noise
|
||||
kernel_size = self.open_kernel_size
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
||||
kernel = cv2.getStructuringElement(
|
||||
cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
|
||||
)
|
||||
pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1)
|
||||
|
||||
# Why dilate: make region slightly larger to avoid missing some pixels, this generally works better
|
||||
dilate_kernel_size = self.dilate_kernel_size
|
||||
if dilate_kernel_size > 1:
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size))
|
||||
kernel = cv2.getStructuringElement(
|
||||
cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size)
|
||||
)
|
||||
pred_mask = cv2.dilate(pred_mask, kernel, 1)
|
||||
return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)
|
||||
|
||||
def get_points_nd(self, clicks_lists):
|
||||
total_clicks = []
|
||||
num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
|
||||
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
|
||||
num_pos_clicks = [
|
||||
sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
|
||||
]
|
||||
num_neg_clicks = [
|
||||
len(clicks_list) - num_pos
|
||||
for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
|
||||
]
|
||||
num_max_points = max(num_pos_clicks + num_neg_clicks)
|
||||
if self.net_clicks_limit is not None:
|
||||
num_max_points = min(self.net_clicks_limit, num_max_points)
|
||||
num_max_points = max(1, num_max_points)
|
||||
|
||||
for clicks_list in clicks_lists:
|
||||
clicks_list = clicks_list[:self.net_clicks_limit]
|
||||
pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
|
||||
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
|
||||
clicks_list = clicks_list[: self.net_clicks_limit]
|
||||
pos_clicks = [
|
||||
click.coords_and_indx for click in clicks_list if click.is_positive
|
||||
]
|
||||
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
|
||||
(-1, -1, -1)
|
||||
]
|
||||
|
||||
neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
|
||||
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
|
||||
neg_clicks = [
|
||||
click.coords_and_indx for click in clicks_list if not click.is_positive
|
||||
]
|
||||
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
|
||||
(-1, -1, -1)
|
||||
]
|
||||
total_clicks.append(pos_clicks + neg_clicks)
|
||||
|
||||
return torch.tensor(total_clicks, device=self.device)
|
||||
@ -156,19 +190,45 @@ INTERACTIVE_SEG_MODEL_URL = os.environ.get(
|
||||
"INTERACTIVE_SEG_MODEL_URL",
|
||||
"https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt",
|
||||
)
|
||||
INTERACTIVE_SEG_MODEL_MD5 = os.environ.get("INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf")
|
||||
INTERACTIVE_SEG_MODEL_MD5 = os.environ.get(
|
||||
"INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf"
|
||||
)
|
||||
|
||||
|
||||
class InteractiveSeg:
|
||||
name = "InteractiveSeg"
|
||||
|
||||
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
|
||||
device = torch.device('cpu')
|
||||
model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5).eval()
|
||||
self.predictor = ISPredictor(model, device,
|
||||
device = torch.device("cpu")
|
||||
model = load_jit_model(
|
||||
INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5
|
||||
).eval()
|
||||
self.predictor = ISPredictor(
|
||||
model,
|
||||
device,
|
||||
infer_size=infer_size,
|
||||
open_kernel_size=open_kernel_size,
|
||||
dilate_kernel_size=dilate_kernel_size)
|
||||
dilate_kernel_size=dilate_kernel_size,
|
||||
)
|
||||
|
||||
def __call__(self, image, clicks, prev_mask=None):
|
||||
def __call__(self, rgb_np_img, files, form):
    """Plugin entry point: decode the request payload and run ``forward``.

    Args:
        rgb_np_img: Image handed to ``forward`` unchanged.
        files: Mapping of uploaded files; an optional ``"mask"`` entry is
            read and decoded with ``load_img(..., gray=True)`` and used as
            the previous mask.
        form: Mapping with a ``"clicks"`` entry holding a JSON list; each
            item is indexed as ``[0]``, ``[1]``, ``[2]``.

    Returns:
        Whatever ``forward`` returns for the decoded clicks and mask.
    """
    prev_mask = None
    if "mask" in files:
        prev_mask, _ = load_img(files["mask"].read(), gray=True)

    # Items [0] and [1] of each raw click are swapped when building the
    # coords tuple; item [2] equal to 1 marks a positive click.
    clicks = [
        Click(coords=(raw[1], raw[0]), indx=position, is_positive=raw[2] == 1)
        for position, raw in enumerate(json.loads(form["clicks"]))
    ]

    return self.forward(rgb_np_img, clicks=clicks, prev_mask=prev_mask)
|
||||
|
||||
def forward(self, image, clicks, prev_mask=None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
@ -183,7 +243,7 @@ class InteractiveSeg:
|
||||
if prev_mask is None:
|
||||
mask = torch.zeros_like(image[:, :1, :, :])
|
||||
else:
|
||||
logger.info('InteractiveSeg run with prev_mask')
|
||||
logger.info("InteractiveSeg run with prev_mask")
|
||||
mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()
|
||||
|
||||
pred_probs = self.predictor(image, clicks, mask)
|
22
lama_cleaner/plugins/remove_bg.py
Normal file
22
lama_cleaner/plugins/remove_bg.py
Normal file
@ -0,0 +1,22 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class RemoveBG:
    """Background-removal plugin built on ``rembg`` with the ``u2net`` model."""

    name = "RemoveBG"

    def __init__(self):
        # Function-scope import: ``rembg`` is loaded only when the plugin is
        # actually instantiated.
        from rembg import new_session

        self.session = new_session(model_name="u2net")

    def __call__(self, rgb_np_img, files, form):
        """Plugin entry point: convert the RGB input to BGR and run ``forward``.

        ``files`` and ``form`` are accepted but not used by this plugin.
        """
        converted = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
        return self.forward(converted)

    def forward(self, bgr_np_img) -> np.ndarray:
        """Run ``rembg.remove`` on ``bgr_np_img`` using the stored session."""
        from rembg import remove

        # Per the original author's note the result is a BGRA image.
        # NOTE(review): confirm that rembg preserves the input channel order.
        return remove(bgr_np_img, session=self.session)
|
46
lama_cleaner/plugins/upscale.py
Normal file
46
lama_cleaner/plugins/upscale.py
Normal file
@ -0,0 +1,46 @@
|
||||
import cv2
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
|
||||
|
||||
class RealESRGANUpscaler:
    """Super-resolution plugin wrapping ``RealESRGANer`` with x4plus weights."""

    name = "RealESRGAN"

    def __init__(self, device):
        """Build the RRDBNet network and the ``RealESRGANer`` wrapper around it.

        Args:
            device: Device passed through to ``RealESRGANer``. Half precision
                is enabled when its string form contains ``"cuda"``.
        """
        super().__init__()
        # Function-scope imports: these packages are loaded only when the
        # plugin is instantiated.
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        # Single definition of the network scale, shared by the architecture
        # and the wrapper so the two cannot drift apart.
        scale = 4
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=scale,
        )
        url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
        model_md5 = "99ec365d4afad750833258a1a24f44ca"
        model_path = download_model(url, model_md5)

        self.model = RealESRGANer(
            scale=scale,
            model_path=model_path,
            model=model,
            half="cuda" in str(device),
            tile=640,
            tile_pad=10,
            pre_pad=10,
            device=device,
        )

    def __call__(self, rgb_np_img, files, form):
        """Plugin entry point: convert RGB to BGR and upscale by ``form["scale"]``.

        ``files`` is accepted but not used by this plugin.
        """
        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
        scale = float(form["scale"])
        return self.forward(bgr_np_img, scale)

    def forward(self, bgr_np_img, scale: float):
        """Return the first element of ``self.model.enhance(...)`` for the image.

        Args:
            bgr_np_img: Image passed to ``enhance`` unchanged.
            scale: Value forwarded as ``outscale``.
        """
        # Output is BGR (per the original author's note).
        upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
        return upsampled
|
@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -16,12 +15,12 @@ import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
|
||||
from lama_cleaner.const import SD15_MODELS
|
||||
from lama_cleaner.interactive_seg import InteractiveSeg, Click
|
||||
from lama_cleaner.make_gif import make_compare_gif
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
|
||||
from lama_cleaner.schema import Config
|
||||
from lama_cleaner.file_manager import FileManager
|
||||
|
||||
@ -85,7 +84,6 @@ CORS(app, expose_headers=["Content-Disposition"])
|
||||
model: ModelManager = None
|
||||
thumb: FileManager = None
|
||||
output_dir: str = None
|
||||
interactive_seg_model: InteractiveSeg = None
|
||||
device = None
|
||||
input_image_path: str = None
|
||||
is_disable_model_switch: bool = False
|
||||
@ -94,6 +92,7 @@ is_enable_file_manager: bool = False
|
||||
is_enable_auto_saving: bool = False
|
||||
is_desktop: bool = False
|
||||
image_quality: int = 95
|
||||
plugins = {}
|
||||
|
||||
|
||||
def get_image_ext(img_bytes):
|
||||
@ -319,35 +318,37 @@ def process():
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/interactive_seg", methods=["POST"])
|
||||
def interactive_seg():
|
||||
input = request.files
|
||||
origin_image_bytes = input["image"].read() # RGB
|
||||
image, _ = load_img(origin_image_bytes)
|
||||
if "mask" in input:
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
else:
|
||||
mask = None
|
||||
@app.route("/run_plugin/", methods=["POST"])
|
||||
def run_plugin():
|
||||
form = request.form
|
||||
files = request.files
|
||||
|
||||
_clicks = json.loads(request.form["clicks"])
|
||||
clicks = []
|
||||
for i, click in enumerate(_clicks):
|
||||
clicks.append(
|
||||
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
|
||||
)
|
||||
name = form["name"]
|
||||
if name not in plugins:
|
||||
return "Plugin not found", 500
|
||||
|
||||
origin_image_bytes = files["image"].read() # RGB
|
||||
rgb_np_img, _ = load_img(origin_image_bytes)
|
||||
|
||||
start = time.time()
|
||||
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
|
||||
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
|
||||
res = plugins[name](rgb_np_img, files, form)
|
||||
logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
|
||||
torch_gc()
|
||||
|
||||
response = make_response(
|
||||
send_file(
|
||||
io.BytesIO(numpy_to_bytes(new_mask, "png")),
|
||||
io.BytesIO(numpy_to_bytes(res, "png")),
|
||||
mimetype=f"image/png",
|
||||
)
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/plugins/", methods=["GET"])
def get_plugins():
    """Return the names of the registered plugins together with HTTP 200."""
    plugin_names = list(plugins)
    return plugin_names, 200
|
||||
|
||||
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.name, 200
|
||||
@ -423,14 +424,21 @@ def set_input_photo():
|
||||
return "No Input Image"
|
||||
|
||||
|
||||
class FSHandler(FileSystemEventHandler):
|
||||
def on_modified(self, event):
|
||||
print("File modified: %s" % event.src_path)
|
||||
def build_plugins(args):
    """Instantiate each plugin enabled in ``args`` and store it in ``plugins``.

    Reads ``args.enable_interactive_seg``, ``args.enable_remove_bg``,
    ``args.enable_realesrgan`` and ``args.realesrgan_device``. Each enabled
    plugin is logged, constructed and registered under its class-level
    ``name``.
    """

    def _register(plugin_cls, *init_args):
        # Log first, then construct, then store under the plugin's name.
        logger.info(f"Initialize {plugin_cls.name} plugin")
        plugins[plugin_cls.name] = plugin_cls(*init_args)

    if args.enable_interactive_seg:
        _register(InteractiveSeg)
    if args.enable_remove_bg:
        _register(RemoveBG)
    if args.enable_realesrgan:
        _register(RealESRGANUpscaler, args.realesrgan_device)
|
||||
|
||||
|
||||
def main(args):
|
||||
global model
|
||||
global interactive_seg_model
|
||||
global device
|
||||
global input_image_path
|
||||
global is_disable_model_switch
|
||||
@ -442,6 +450,8 @@ def main(args):
|
||||
global is_controlnet
|
||||
global image_quality
|
||||
|
||||
build_plugins(args)
|
||||
|
||||
image_quality = args.quality
|
||||
|
||||
if args.sd_controlnet and args.model in SD15_MODELS:
|
||||
@ -496,8 +506,6 @@ def main(args):
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
interactive_seg_model = InteractiveSeg()
|
||||
|
||||
if args.gui:
|
||||
app_width, app_height = args.gui_size
|
||||
from flaskwebgui import FlaskUI
|
||||
|
@ -1,13 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lama_cleaner.interactive_seg import InteractiveSeg, Click
|
||||
from lama_cleaner.plugins import InteractiveSeg, Click
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / 'result'
|
||||
save_dir = current_dir / "result"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
|
||||
|
||||
@ -15,17 +14,22 @@ img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
|
||||
def test_interactive_seg():
|
||||
interactive_seg_model = InteractiveSeg()
|
||||
img = cv2.imread(str(img_p))
|
||||
pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)])
|
||||
pred = interactive_seg_model.forward(
|
||||
img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)]
|
||||
)
|
||||
cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred)
|
||||
|
||||
|
||||
def test_interactive_seg_with_negative_click():
|
||||
interactive_seg_model = InteractiveSeg()
|
||||
img = cv2.imread(str(img_p))
|
||||
pred = interactive_seg_model(img, clicks=[
|
||||
pred = interactive_seg_model.forward(
|
||||
img,
|
||||
clicks=[
|
||||
Click(coords=(256, 256), indx=0, is_positive=True),
|
||||
Click(coords=(384, 256), indx=1, is_positive=False)
|
||||
])
|
||||
Click(coords=(384, 256), indx=1, is_positive=False),
|
||||
],
|
||||
)
|
||||
cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred)
|
||||
|
||||
|
||||
@ -33,5 +37,7 @@ def test_interactive_seg_with_prev_mask():
|
||||
interactive_seg_model = InteractiveSeg()
|
||||
img = cv2.imread(str(img_p))
|
||||
mask = np.zeros_like(img)[:, :, 0]
|
||||
pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask)
|
||||
pred = interactive_seg_model.forward(
|
||||
img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask
|
||||
)
|
||||
cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred)
|
||||
|
@ -1,10 +1,5 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_load_model():
|
||||
from lama_cleaner.interactive_seg import InteractiveSeg
|
||||
from lama_cleaner.plugins import InteractiveSeg
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
|
||||
interactive_seg_model = InteractiveSeg()
|
||||
|
27
lama_cleaner/tests/test_plugins.py
Normal file
27
lama_cleaner/tests/test_plugins.py
Normal file
@ -0,0 +1,27 @@
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
|
||||
from lama_cleaner.plugins import RemoveBG, RealESRGANUpscaler
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / "result"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
img_p = current_dir / "bunny.jpeg"
|
||||
|
||||
|
||||
def test_remove_bg():
    """Run RemoveBG.forward on the sample image and write the result to disk."""
    source = cv2.imread(str(img_p))
    plugin = RemoveBG()
    output = plugin.forward(source)
    cv2.imwrite(str(save_dir / "test_remove_bg.png"), output)
|
||||
|
||||
|
||||
def test_upscale():
    """Upscale the sample image on CPU at x2 and x4 and write both results."""
    upscaler = RealESRGANUpscaler("cpu")
    source = cv2.imread(str(img_p))

    # One model instance is reused for both scales, in the same order.
    cases = (
        (2, "test_upscale_x2.png"),
        (4, "test_upscale_x4.png"),
    )
    for factor, filename in cases:
        output = upscaler.forward(source, factor)
        cv2.imwrite(str(save_dir / filename), output)
|
@ -14,7 +14,6 @@ markupsafe==2.0.1
|
||||
scikit-image==0.19.3
|
||||
diffusers[torch]==0.14.0
|
||||
transformers>=4.25.1
|
||||
watchdog==2.2.1
|
||||
gradio
|
||||
piexif==1.1.3
|
||||
safetensors
|
||||
|
Loading…
Reference in New Issue
Block a user