add plugins

This commit is contained in:
Qing 2023-03-22 12:57:18 +08:00
parent b48d964c2c
commit 5a38d28ad1
11 changed files with 283 additions and 91 deletions

View File

@ -7,8 +7,8 @@ import time
from io import BytesIO
from pathlib import Path
import numpy as np
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
# from watchdog.events import FileSystemEventHandler
# from watchdog.observers import Observer
from PIL import Image, ImageOps, PngImagePlugin
from loguru import logger
@ -19,7 +19,7 @@ from .storage_backends import FilesystemStorageBackend
from .utils import aspect_to_string, generate_filename, glob_img
class FileManager(FileSystemEventHandler):
class FileManager:
def __init__(self, app=None):
self.app = app
self._default_root_directory = "media"
@ -43,19 +43,19 @@ class FileManager(FileSystemEventHandler):
"output": datetime.utcnow(),
}
def start(self):
self.image_dir_filenames = self._media_names(self.root_directory)
self.output_dir_filenames = self._media_names(self.output_dir)
logger.info(f"Start watching image directory: {self.root_directory}")
self.image_dir_observer = Observer()
self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
self.image_dir_observer.start()
logger.info(f"Start watching output directory: {self.output_dir}")
self.output_dir_observer = Observer()
self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
self.output_dir_observer.start()
# def start(self):
# self.image_dir_filenames = self._media_names(self.root_directory)
# self.output_dir_filenames = self._media_names(self.output_dir)
#
# logger.info(f"Start watching image directory: {self.root_directory}")
# self.image_dir_observer = Observer()
# self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
# self.image_dir_observer.start()
#
# logger.info(f"Start watching output directory: {self.output_dir}")
# self.output_dir_observer = Observer()
# self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
# self.output_dir_observer.start()
def on_modified(self, event):
if not os.path.isdir(event.src_path):

View File

@ -69,9 +69,33 @@ def parse_args():
help="Disable model switch in frontend",
)
parser.add_argument(
"--quality", default=95, type=int, help=QUALITY_HELP,
"--quality",
default=95,
type=int,
help=QUALITY_HELP,
)
# Plugins
parser.add_argument(
"--enable-interactive-seg",
action="store_true",
help="Enable interactive segmentation. Always run on CPU",
)
parser.add_argument(
"--enable-remove-bg",
action="store_true",
help="Enable remove background. Always run on CPU",
)
parser.add_argument(
"--enable-realesrgan",
action="store_true",
help="Enable realesrgan super resolution",
)
parser.add_argument(
"--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda"]
)
#########
# useless args
parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS)
parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS)
@ -123,6 +147,7 @@ def parse_args():
if args.model not in SD15_MODELS:
logger.warning(f"--sd_controlnet only support {SD15_MODELS}")
os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR
if args.model_dir and args.model_dir is not None:
if os.path.isfile(args.model_dir):
parser.error(f"invalid --model-dir: {args.model_dir} is a file")
@ -132,6 +157,7 @@ def parse_args():
Path(args.model_dir).mkdir(exist_ok=True, parents=True)
os.environ["XDG_CACHE_HOME"] = args.model_dir
os.environ["U2NET_HOME"] = args.model_dir
if args.input and args.input is not None:
if not os.path.exists(args.input):

View File

@ -0,0 +1,3 @@
from .interactive_seg import InteractiveSeg, Click
from .remove_bg import RemoveBG
from .upscale import RealESRGANUpscaler

View File

@ -1,14 +1,19 @@
import json
import json
import os
from typing import Tuple, List
import cv2
from typing import Tuple, List
import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from pydantic import BaseModel
import numpy as np
from lama_cleaner.helper import only_keep_largest_contour, load_jit_model
from lama_cleaner.helper import (
load_jit_model,
load_img,
)
class Click(BaseModel):
@ -21,11 +26,11 @@ class Click(BaseModel):
def coords_and_indx(self):
return (*self.coords, self.indx)
def scale(self, x_ratio: float, y_ratio: float) -> 'Click':
def scale(self, x_ratio: float, y_ratio: float) -> "Click":
return Click(
coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio),
is_positive=self.is_positive,
indx=self.indx
indx=self.indx,
)
@ -40,21 +45,32 @@ class ResizeTrans:
image_height, image_width = image_nd.shape[2:4]
self.image_height = image_height
self.image_width = image_width
image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode='bilinear', align_corners=True)
image_nd_r = F.interpolate(
image_nd,
(self.crop_height, self.crop_width),
mode="bilinear",
align_corners=True,
)
y_ratio = self.crop_height / image_height
x_ratio = self.crop_width / image_width
clicks_lists_resized = []
for clicks_list in clicks_lists:
clicks_list_resized = [click.scale(y_ratio, x_ratio) for click in clicks_list]
clicks_list_resized = [
click.scale(y_ratio, x_ratio) for click in clicks_list
]
clicks_lists_resized.append(clicks_list_resized)
return image_nd_r, clicks_lists_resized
def inv_transform(self, prob_map):
new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear',
align_corners=True)
new_prob_map = F.interpolate(
prob_map,
(self.image_height, self.image_width),
mode="bilinear",
align_corners=True,
)
return new_prob_map
@ -106,8 +122,9 @@ class ISPredictor(object):
pred = torch.sigmoid(pred_logits)
pred = self.post_process(pred)
prediction = F.interpolate(pred, mode='bilinear', align_corners=True,
size=image_nd.size()[2:])
prediction = F.interpolate(
pred, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
)
for t in reversed(transforms):
prediction = t.inv_transform(prediction)
@ -121,32 +138,49 @@ class ISPredictor(object):
pred_mask = pred.cpu().numpy()[0][0]
# morph_open to remove small noise
kernel_size = self.open_kernel_size
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
)
pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1)
# Why dilate: make region slightly larger to avoid missing some pixels, this generally works better
dilate_kernel_size = self.dilate_kernel_size
if dilate_kernel_size > 1:
kernel = cv2.getStructuringElement(cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size))
kernel = cv2.getStructuringElement(
cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size)
)
pred_mask = cv2.dilate(pred_mask, kernel, 1)
return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)
def get_points_nd(self, clicks_lists):
total_clicks = []
num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
num_pos_clicks = [
sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
]
num_neg_clicks = [
len(clicks_list) - num_pos
for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
]
num_max_points = max(num_pos_clicks + num_neg_clicks)
if self.net_clicks_limit is not None:
num_max_points = min(self.net_clicks_limit, num_max_points)
num_max_points = max(1, num_max_points)
for clicks_list in clicks_lists:
clicks_list = clicks_list[:self.net_clicks_limit]
pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
clicks_list = clicks_list[: self.net_clicks_limit]
pos_clicks = [
click.coords_and_indx for click in clicks_list if click.is_positive
]
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
(-1, -1, -1)
]
neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
neg_clicks = [
click.coords_and_indx for click in clicks_list if not click.is_positive
]
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
(-1, -1, -1)
]
total_clicks.append(pos_clicks + neg_clicks)
return torch.tensor(total_clicks, device=self.device)
@ -156,19 +190,45 @@ INTERACTIVE_SEG_MODEL_URL = os.environ.get(
"INTERACTIVE_SEG_MODEL_URL",
"https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt",
)
INTERACTIVE_SEG_MODEL_MD5 = os.environ.get("INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf")
INTERACTIVE_SEG_MODEL_MD5 = os.environ.get(
"INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf"
)
class InteractiveSeg:
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
device = torch.device('cpu')
model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5).eval()
self.predictor = ISPredictor(model, device,
infer_size=infer_size,
open_kernel_size=open_kernel_size,
dilate_kernel_size=dilate_kernel_size)
name = "InteractiveSeg"
def __call__(self, image, clicks, prev_mask=None):
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
device = torch.device("cpu")
model = load_jit_model(
INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5
).eval()
self.predictor = ISPredictor(
model,
device,
infer_size=infer_size,
open_kernel_size=open_kernel_size,
dilate_kernel_size=dilate_kernel_size,
)
def __call__(self, rgb_np_img, files, form):
image = rgb_np_img
if "mask" in files:
mask, _ = load_img(files["mask"].read(), gray=True)
else:
mask = None
_clicks = json.loads(form["clicks"])
clicks = []
for i, click in enumerate(_clicks):
clicks.append(
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
)
new_mask = self.forward(image, clicks=clicks, prev_mask=mask)
return new_mask
def forward(self, image, clicks, prev_mask=None):
"""
Args:
@ -183,7 +243,7 @@ class InteractiveSeg:
if prev_mask is None:
mask = torch.zeros_like(image[:, :1, :, :])
else:
logger.info('InteractiveSeg run with prev_mask')
logger.info("InteractiveSeg run with prev_mask")
mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()
pred_probs = self.predictor(image, clicks, mask)

View File

@ -0,0 +1,22 @@
import cv2
import numpy as np
class RemoveBG:
name = "RemoveBG"
def __init__(self):
from rembg import new_session
self.session = new_session(model_name="u2net")
def __call__(self, rgb_np_img, files, form):
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
return self.forward(bgr_np_img)
def forward(self, bgr_np_img) -> np.ndarray:
from rembg import remove
# return BGRA image
output = remove(bgr_np_img, session=self.session)
return output

View File

@ -0,0 +1,46 @@
import cv2
from lama_cleaner.helper import download_model
class RealESRGANUpscaler:
name = "RealESRGAN"
def __init__(self, device):
super().__init__()
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
scale = 4
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
model_md5 = "99ec365d4afad750833258a1a24f44ca"
model_path = download_model(url, model_md5)
self.model = RealESRGANer(
scale=scale,
model_path=model_path,
model=model,
half=True if "cuda" in str(device) else False,
tile=640,
tile_pad=10,
pre_pad=10,
device=device,
)
def __call__(self, rgb_np_img, files, form):
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
scale = float(form["scale"])
return self.forward(bgr_np_img, scale)
def forward(self, bgr_np_img, scale: float):
# 输出是 BGR
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
return upsampled

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python3
import io
import json
import logging
import multiprocessing
import os
@ -16,12 +15,12 @@ import cv2
import torch
import numpy as np
from loguru import logger
from watchdog.events import FileSystemEventHandler
from lama_cleaner.const import SD15_MODELS
from lama_cleaner.interactive_seg import InteractiveSeg, Click
from lama_cleaner.make_gif import make_compare_gif
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
from lama_cleaner.schema import Config
from lama_cleaner.file_manager import FileManager
@ -85,7 +84,6 @@ CORS(app, expose_headers=["Content-Disposition"])
model: ModelManager = None
thumb: FileManager = None
output_dir: str = None
interactive_seg_model: InteractiveSeg = None
device = None
input_image_path: str = None
is_disable_model_switch: bool = False
@ -94,6 +92,7 @@ is_enable_file_manager: bool = False
is_enable_auto_saving: bool = False
is_desktop: bool = False
image_quality: int = 95
plugins = {}
def get_image_ext(img_bytes):
@ -319,35 +318,37 @@ def process():
return response
@app.route("/interactive_seg", methods=["POST"])
def interactive_seg():
input = request.files
origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes)
if "mask" in input:
mask, _ = load_img(input["mask"].read(), gray=True)
else:
mask = None
@app.route("/run_plugin/", methods=["POST"])
def run_plugin():
form = request.form
files = request.files
_clicks = json.loads(request.form["clicks"])
clicks = []
for i, click in enumerate(_clicks):
clicks.append(
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
)
name = form["name"]
if name not in plugins:
return "Plugin not found", 500
origin_image_bytes = files["image"].read() # RGB
rgb_np_img, _ = load_img(origin_image_bytes)
start = time.time()
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
res = plugins[name](rgb_np_img, files, form)
logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
torch_gc()
response = make_response(
send_file(
io.BytesIO(numpy_to_bytes(new_mask, "png")),
io.BytesIO(numpy_to_bytes(res, "png")),
mimetype=f"image/png",
)
)
return response
@app.route("/plugins/", methods=["GET"])
def get_plugins():
return list(plugins.keys()), 200
@app.route("/model")
def current_model():
return model.name, 200
@ -423,14 +424,21 @@ def set_input_photo():
return "No Input Image"
class FSHandler(FileSystemEventHandler):
def on_modified(self, event):
print("File modified: %s" % event.src_path)
def build_plugins(args):
global plugins
if args.enable_interactive_seg:
logger.info(f"Initialize {InteractiveSeg.name} plugin")
plugins[InteractiveSeg.name] = InteractiveSeg()
if args.enable_remove_bg:
logger.info(f"Initialize {RemoveBG.name} plugin")
plugins[RemoveBG.name] = RemoveBG()
if args.enable_realesrgan:
logger.info(f"Initialize {RealESRGANUpscaler.name} plugin")
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device)
def main(args):
global model
global interactive_seg_model
global device
global input_image_path
global is_disable_model_switch
@ -442,6 +450,8 @@ def main(args):
global is_controlnet
global image_quality
build_plugins(args)
image_quality = args.quality
if args.sd_controlnet and args.model in SD15_MODELS:
@ -496,8 +506,6 @@ def main(args):
callback=diffuser_callback,
)
interactive_seg_model = InteractiveSeg()
if args.gui:
app_width, app_height = args.gui_size
from flaskwebgui import FlaskUI

View File

@ -1,13 +1,12 @@
from pathlib import Path
import pytest
import cv2
import numpy as np
from lama_cleaner.interactive_seg import InteractiveSeg, Click
from lama_cleaner.plugins import InteractiveSeg, Click
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / 'result'
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)
img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
@ -15,17 +14,22 @@ img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
def test_interactive_seg():
interactive_seg_model = InteractiveSeg()
img = cv2.imread(str(img_p))
pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)])
pred = interactive_seg_model.forward(
img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)]
)
cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred)
def test_interactive_seg_with_negative_click():
interactive_seg_model = InteractiveSeg()
img = cv2.imread(str(img_p))
pred = interactive_seg_model(img, clicks=[
Click(coords=(256, 256), indx=0, is_positive=True),
Click(coords=(384, 256), indx=1, is_positive=False)
])
pred = interactive_seg_model.forward(
img,
clicks=[
Click(coords=(256, 256), indx=0, is_positive=True),
Click(coords=(384, 256), indx=1, is_positive=False),
],
)
cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred)
@ -33,5 +37,7 @@ def test_interactive_seg_with_prev_mask():
interactive_seg_model = InteractiveSeg()
img = cv2.imread(str(img_p))
mask = np.zeros_like(img)[:, :, 0]
pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask)
pred = interactive_seg_model.forward(
img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask
)
cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred)

View File

@ -1,10 +1,5 @@
import os
import tempfile
from pathlib import Path
def test_load_model():
from lama_cleaner.interactive_seg import InteractiveSeg
from lama_cleaner.plugins import InteractiveSeg
from lama_cleaner.model_manager import ModelManager
interactive_seg_model = InteractiveSeg()

View File

@ -0,0 +1,27 @@
from pathlib import Path
import cv2
from lama_cleaner.plugins import RemoveBG, RealESRGANUpscaler
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)
img_p = current_dir / "bunny.jpeg"
def test_remove_bg():
model = RemoveBG()
img = cv2.imread(str(img_p))
res = model.forward(img)
cv2.imwrite(str(save_dir / "test_remove_bg.png"), res)
def test_upscale():
model = RealESRGANUpscaler("cpu")
img = cv2.imread(str(img_p))
res = model.forward(img, 2)
cv2.imwrite(str(save_dir / "test_upscale_x2.png"), res)
res = model.forward(img, 4)
cv2.imwrite(str(save_dir / "test_upscale_x4.png"), res)

View File

@ -14,7 +14,6 @@ markupsafe==2.0.1
scikit-image==0.19.3
diffusers[torch]==0.14.0
transformers>=4.25.1
watchdog==2.2.1
gradio
piexif==1.1.3
safetensors