add plugins
This commit is contained in:
parent
b48d964c2c
commit
5a38d28ad1
@ -7,8 +7,8 @@ import time
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from watchdog.events import FileSystemEventHandler
|
# from watchdog.events import FileSystemEventHandler
|
||||||
from watchdog.observers import Observer
|
# from watchdog.observers import Observer
|
||||||
|
|
||||||
from PIL import Image, ImageOps, PngImagePlugin
|
from PIL import Image, ImageOps, PngImagePlugin
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -19,7 +19,7 @@ from .storage_backends import FilesystemStorageBackend
|
|||||||
from .utils import aspect_to_string, generate_filename, glob_img
|
from .utils import aspect_to_string, generate_filename, glob_img
|
||||||
|
|
||||||
|
|
||||||
class FileManager(FileSystemEventHandler):
|
class FileManager:
|
||||||
def __init__(self, app=None):
|
def __init__(self, app=None):
|
||||||
self.app = app
|
self.app = app
|
||||||
self._default_root_directory = "media"
|
self._default_root_directory = "media"
|
||||||
@ -43,19 +43,19 @@ class FileManager(FileSystemEventHandler):
|
|||||||
"output": datetime.utcnow(),
|
"output": datetime.utcnow(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def start(self):
|
# def start(self):
|
||||||
self.image_dir_filenames = self._media_names(self.root_directory)
|
# self.image_dir_filenames = self._media_names(self.root_directory)
|
||||||
self.output_dir_filenames = self._media_names(self.output_dir)
|
# self.output_dir_filenames = self._media_names(self.output_dir)
|
||||||
|
#
|
||||||
logger.info(f"Start watching image directory: {self.root_directory}")
|
# logger.info(f"Start watching image directory: {self.root_directory}")
|
||||||
self.image_dir_observer = Observer()
|
# self.image_dir_observer = Observer()
|
||||||
self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
|
# self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
|
||||||
self.image_dir_observer.start()
|
# self.image_dir_observer.start()
|
||||||
|
#
|
||||||
logger.info(f"Start watching output directory: {self.output_dir}")
|
# logger.info(f"Start watching output directory: {self.output_dir}")
|
||||||
self.output_dir_observer = Observer()
|
# self.output_dir_observer = Observer()
|
||||||
self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
|
# self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
|
||||||
self.output_dir_observer.start()
|
# self.output_dir_observer.start()
|
||||||
|
|
||||||
def on_modified(self, event):
|
def on_modified(self, event):
|
||||||
if not os.path.isdir(event.src_path):
|
if not os.path.isdir(event.src_path):
|
||||||
|
@ -69,9 +69,33 @@ def parse_args():
|
|||||||
help="Disable model switch in frontend",
|
help="Disable model switch in frontend",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--quality", default=95, type=int, help=QUALITY_HELP,
|
"--quality",
|
||||||
|
default=95,
|
||||||
|
type=int,
|
||||||
|
help=QUALITY_HELP,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Plugins
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-interactive-seg",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable interactive segmentation. Always run on CPU",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-remove-bg",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable remove background. Always run on CPU",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-realesrgan",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable realesrgan super resolution",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda"]
|
||||||
|
)
|
||||||
|
#########
|
||||||
|
|
||||||
# useless args
|
# useless args
|
||||||
parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS)
|
parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS)
|
||||||
parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS)
|
parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS)
|
||||||
@ -123,6 +147,7 @@ def parse_args():
|
|||||||
if args.model not in SD15_MODELS:
|
if args.model not in SD15_MODELS:
|
||||||
logger.warning(f"--sd_controlnet only support {SD15_MODELS}")
|
logger.warning(f"--sd_controlnet only support {SD15_MODELS}")
|
||||||
|
|
||||||
|
os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR
|
||||||
if args.model_dir and args.model_dir is not None:
|
if args.model_dir and args.model_dir is not None:
|
||||||
if os.path.isfile(args.model_dir):
|
if os.path.isfile(args.model_dir):
|
||||||
parser.error(f"invalid --model-dir: {args.model_dir} is a file")
|
parser.error(f"invalid --model-dir: {args.model_dir} is a file")
|
||||||
@ -132,6 +157,7 @@ def parse_args():
|
|||||||
Path(args.model_dir).mkdir(exist_ok=True, parents=True)
|
Path(args.model_dir).mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
os.environ["XDG_CACHE_HOME"] = args.model_dir
|
os.environ["XDG_CACHE_HOME"] = args.model_dir
|
||||||
|
os.environ["U2NET_HOME"] = args.model_dir
|
||||||
|
|
||||||
if args.input and args.input is not None:
|
if args.input and args.input is not None:
|
||||||
if not os.path.exists(args.input):
|
if not os.path.exists(args.input):
|
||||||
|
3
lama_cleaner/plugins/__init__.py
Normal file
3
lama_cleaner/plugins/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .interactive_seg import InteractiveSeg, Click
|
||||||
|
from .remove_bg import RemoveBG
|
||||||
|
from .upscale import RealESRGANUpscaler
|
@ -1,14 +1,19 @@
|
|||||||
|
import json
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
from typing import Tuple, List
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from lama_cleaner.helper import only_keep_largest_contour, load_jit_model
|
from lama_cleaner.helper import (
|
||||||
|
load_jit_model,
|
||||||
|
load_img,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Click(BaseModel):
|
class Click(BaseModel):
|
||||||
@ -21,11 +26,11 @@ class Click(BaseModel):
|
|||||||
def coords_and_indx(self):
|
def coords_and_indx(self):
|
||||||
return (*self.coords, self.indx)
|
return (*self.coords, self.indx)
|
||||||
|
|
||||||
def scale(self, x_ratio: float, y_ratio: float) -> 'Click':
|
def scale(self, x_ratio: float, y_ratio: float) -> "Click":
|
||||||
return Click(
|
return Click(
|
||||||
coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio),
|
coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio),
|
||||||
is_positive=self.is_positive,
|
is_positive=self.is_positive,
|
||||||
indx=self.indx
|
indx=self.indx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -40,21 +45,32 @@ class ResizeTrans:
|
|||||||
image_height, image_width = image_nd.shape[2:4]
|
image_height, image_width = image_nd.shape[2:4]
|
||||||
self.image_height = image_height
|
self.image_height = image_height
|
||||||
self.image_width = image_width
|
self.image_width = image_width
|
||||||
image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode='bilinear', align_corners=True)
|
image_nd_r = F.interpolate(
|
||||||
|
image_nd,
|
||||||
|
(self.crop_height, self.crop_width),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
|
||||||
y_ratio = self.crop_height / image_height
|
y_ratio = self.crop_height / image_height
|
||||||
x_ratio = self.crop_width / image_width
|
x_ratio = self.crop_width / image_width
|
||||||
|
|
||||||
clicks_lists_resized = []
|
clicks_lists_resized = []
|
||||||
for clicks_list in clicks_lists:
|
for clicks_list in clicks_lists:
|
||||||
clicks_list_resized = [click.scale(y_ratio, x_ratio) for click in clicks_list]
|
clicks_list_resized = [
|
||||||
|
click.scale(y_ratio, x_ratio) for click in clicks_list
|
||||||
|
]
|
||||||
clicks_lists_resized.append(clicks_list_resized)
|
clicks_lists_resized.append(clicks_list_resized)
|
||||||
|
|
||||||
return image_nd_r, clicks_lists_resized
|
return image_nd_r, clicks_lists_resized
|
||||||
|
|
||||||
def inv_transform(self, prob_map):
|
def inv_transform(self, prob_map):
|
||||||
new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear',
|
new_prob_map = F.interpolate(
|
||||||
align_corners=True)
|
prob_map,
|
||||||
|
(self.image_height, self.image_width),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
|
||||||
return new_prob_map
|
return new_prob_map
|
||||||
|
|
||||||
@ -106,8 +122,9 @@ class ISPredictor(object):
|
|||||||
pred = torch.sigmoid(pred_logits)
|
pred = torch.sigmoid(pred_logits)
|
||||||
pred = self.post_process(pred)
|
pred = self.post_process(pred)
|
||||||
|
|
||||||
prediction = F.interpolate(pred, mode='bilinear', align_corners=True,
|
prediction = F.interpolate(
|
||||||
size=image_nd.size()[2:])
|
pred, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
|
||||||
|
)
|
||||||
|
|
||||||
for t in reversed(transforms):
|
for t in reversed(transforms):
|
||||||
prediction = t.inv_transform(prediction)
|
prediction = t.inv_transform(prediction)
|
||||||
@ -121,32 +138,49 @@ class ISPredictor(object):
|
|||||||
pred_mask = pred.cpu().numpy()[0][0]
|
pred_mask = pred.cpu().numpy()[0][0]
|
||||||
# morph_open to remove small noise
|
# morph_open to remove small noise
|
||||||
kernel_size = self.open_kernel_size
|
kernel_size = self.open_kernel_size
|
||||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
kernel = cv2.getStructuringElement(
|
||||||
|
cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
|
||||||
|
)
|
||||||
pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1)
|
pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1)
|
||||||
|
|
||||||
# Why dilate: make region slightly larger to avoid missing some pixels, this generally works better
|
# Why dilate: make region slightly larger to avoid missing some pixels, this generally works better
|
||||||
dilate_kernel_size = self.dilate_kernel_size
|
dilate_kernel_size = self.dilate_kernel_size
|
||||||
if dilate_kernel_size > 1:
|
if dilate_kernel_size > 1:
|
||||||
kernel = cv2.getStructuringElement(cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size))
|
kernel = cv2.getStructuringElement(
|
||||||
|
cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size)
|
||||||
|
)
|
||||||
pred_mask = cv2.dilate(pred_mask, kernel, 1)
|
pred_mask = cv2.dilate(pred_mask, kernel, 1)
|
||||||
return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)
|
return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
def get_points_nd(self, clicks_lists):
|
def get_points_nd(self, clicks_lists):
|
||||||
total_clicks = []
|
total_clicks = []
|
||||||
num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
|
num_pos_clicks = [
|
||||||
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
|
sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
|
||||||
|
]
|
||||||
|
num_neg_clicks = [
|
||||||
|
len(clicks_list) - num_pos
|
||||||
|
for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
|
||||||
|
]
|
||||||
num_max_points = max(num_pos_clicks + num_neg_clicks)
|
num_max_points = max(num_pos_clicks + num_neg_clicks)
|
||||||
if self.net_clicks_limit is not None:
|
if self.net_clicks_limit is not None:
|
||||||
num_max_points = min(self.net_clicks_limit, num_max_points)
|
num_max_points = min(self.net_clicks_limit, num_max_points)
|
||||||
num_max_points = max(1, num_max_points)
|
num_max_points = max(1, num_max_points)
|
||||||
|
|
||||||
for clicks_list in clicks_lists:
|
for clicks_list in clicks_lists:
|
||||||
clicks_list = clicks_list[:self.net_clicks_limit]
|
clicks_list = clicks_list[: self.net_clicks_limit]
|
||||||
pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
|
pos_clicks = [
|
||||||
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
|
click.coords_and_indx for click in clicks_list if click.is_positive
|
||||||
|
]
|
||||||
|
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
|
||||||
|
(-1, -1, -1)
|
||||||
|
]
|
||||||
|
|
||||||
neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
|
neg_clicks = [
|
||||||
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
|
click.coords_and_indx for click in clicks_list if not click.is_positive
|
||||||
|
]
|
||||||
|
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
|
||||||
|
(-1, -1, -1)
|
||||||
|
]
|
||||||
total_clicks.append(pos_clicks + neg_clicks)
|
total_clicks.append(pos_clicks + neg_clicks)
|
||||||
|
|
||||||
return torch.tensor(total_clicks, device=self.device)
|
return torch.tensor(total_clicks, device=self.device)
|
||||||
@ -156,19 +190,45 @@ INTERACTIVE_SEG_MODEL_URL = os.environ.get(
|
|||||||
"INTERACTIVE_SEG_MODEL_URL",
|
"INTERACTIVE_SEG_MODEL_URL",
|
||||||
"https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt",
|
"https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt",
|
||||||
)
|
)
|
||||||
INTERACTIVE_SEG_MODEL_MD5 = os.environ.get("INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf")
|
INTERACTIVE_SEG_MODEL_MD5 = os.environ.get(
|
||||||
|
"INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InteractiveSeg:
|
class InteractiveSeg:
|
||||||
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
|
name = "InteractiveSeg"
|
||||||
device = torch.device('cpu')
|
|
||||||
model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5).eval()
|
|
||||||
self.predictor = ISPredictor(model, device,
|
|
||||||
infer_size=infer_size,
|
|
||||||
open_kernel_size=open_kernel_size,
|
|
||||||
dilate_kernel_size=dilate_kernel_size)
|
|
||||||
|
|
||||||
def __call__(self, image, clicks, prev_mask=None):
|
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
|
||||||
|
device = torch.device("cpu")
|
||||||
|
model = load_jit_model(
|
||||||
|
INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5
|
||||||
|
).eval()
|
||||||
|
self.predictor = ISPredictor(
|
||||||
|
model,
|
||||||
|
device,
|
||||||
|
infer_size=infer_size,
|
||||||
|
open_kernel_size=open_kernel_size,
|
||||||
|
dilate_kernel_size=dilate_kernel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, rgb_np_img, files, form):
|
||||||
|
image = rgb_np_img
|
||||||
|
if "mask" in files:
|
||||||
|
mask, _ = load_img(files["mask"].read(), gray=True)
|
||||||
|
else:
|
||||||
|
mask = None
|
||||||
|
|
||||||
|
_clicks = json.loads(form["clicks"])
|
||||||
|
clicks = []
|
||||||
|
for i, click in enumerate(_clicks):
|
||||||
|
clicks.append(
|
||||||
|
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
new_mask = self.forward(image, clicks=clicks, prev_mask=mask)
|
||||||
|
return new_mask
|
||||||
|
|
||||||
|
def forward(self, image, clicks, prev_mask=None):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -183,7 +243,7 @@ class InteractiveSeg:
|
|||||||
if prev_mask is None:
|
if prev_mask is None:
|
||||||
mask = torch.zeros_like(image[:, :1, :, :])
|
mask = torch.zeros_like(image[:, :1, :, :])
|
||||||
else:
|
else:
|
||||||
logger.info('InteractiveSeg run with prev_mask')
|
logger.info("InteractiveSeg run with prev_mask")
|
||||||
mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()
|
mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()
|
||||||
|
|
||||||
pred_probs = self.predictor(image, clicks, mask)
|
pred_probs = self.predictor(image, clicks, mask)
|
22
lama_cleaner/plugins/remove_bg.py
Normal file
22
lama_cleaner/plugins/remove_bg.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveBG:
|
||||||
|
name = "RemoveBG"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
from rembg import new_session
|
||||||
|
|
||||||
|
self.session = new_session(model_name="u2net")
|
||||||
|
|
||||||
|
def __call__(self, rgb_np_img, files, form):
|
||||||
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
|
return self.forward(bgr_np_img)
|
||||||
|
|
||||||
|
def forward(self, bgr_np_img) -> np.ndarray:
|
||||||
|
from rembg import remove
|
||||||
|
|
||||||
|
# return BGRA image
|
||||||
|
output = remove(bgr_np_img, session=self.session)
|
||||||
|
return output
|
46
lama_cleaner/plugins/upscale.py
Normal file
46
lama_cleaner/plugins/upscale.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import cv2
|
||||||
|
|
||||||
|
from lama_cleaner.helper import download_model
|
||||||
|
|
||||||
|
|
||||||
|
class RealESRGANUpscaler:
|
||||||
|
name = "RealESRGAN"
|
||||||
|
|
||||||
|
def __init__(self, device):
|
||||||
|
super().__init__()
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
|
scale = 4
|
||||||
|
model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=4,
|
||||||
|
)
|
||||||
|
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
|
||||||
|
model_md5 = "99ec365d4afad750833258a1a24f44ca"
|
||||||
|
model_path = download_model(url, model_md5)
|
||||||
|
|
||||||
|
self.model = RealESRGANer(
|
||||||
|
scale=scale,
|
||||||
|
model_path=model_path,
|
||||||
|
model=model,
|
||||||
|
half=True if "cuda" in str(device) else False,
|
||||||
|
tile=640,
|
||||||
|
tile_pad=10,
|
||||||
|
pre_pad=10,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, rgb_np_img, files, form):
|
||||||
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
|
scale = float(form["scale"])
|
||||||
|
return self.forward(bgr_np_img, scale)
|
||||||
|
|
||||||
|
def forward(self, bgr_np_img, scale: float):
|
||||||
|
# 输出是 BGR
|
||||||
|
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
|
||||||
|
return upsampled
|
@ -1,7 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
@ -16,12 +15,12 @@ import cv2
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from watchdog.events import FileSystemEventHandler
|
|
||||||
|
|
||||||
from lama_cleaner.const import SD15_MODELS
|
from lama_cleaner.const import SD15_MODELS
|
||||||
from lama_cleaner.interactive_seg import InteractiveSeg, Click
|
|
||||||
from lama_cleaner.make_gif import make_compare_gif
|
from lama_cleaner.make_gif import make_compare_gif
|
||||||
|
from lama_cleaner.model.utils import torch_gc
|
||||||
from lama_cleaner.model_manager import ModelManager
|
from lama_cleaner.model_manager import ModelManager
|
||||||
|
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
from lama_cleaner.file_manager import FileManager
|
from lama_cleaner.file_manager import FileManager
|
||||||
|
|
||||||
@ -85,7 +84,6 @@ CORS(app, expose_headers=["Content-Disposition"])
|
|||||||
model: ModelManager = None
|
model: ModelManager = None
|
||||||
thumb: FileManager = None
|
thumb: FileManager = None
|
||||||
output_dir: str = None
|
output_dir: str = None
|
||||||
interactive_seg_model: InteractiveSeg = None
|
|
||||||
device = None
|
device = None
|
||||||
input_image_path: str = None
|
input_image_path: str = None
|
||||||
is_disable_model_switch: bool = False
|
is_disable_model_switch: bool = False
|
||||||
@ -94,6 +92,7 @@ is_enable_file_manager: bool = False
|
|||||||
is_enable_auto_saving: bool = False
|
is_enable_auto_saving: bool = False
|
||||||
is_desktop: bool = False
|
is_desktop: bool = False
|
||||||
image_quality: int = 95
|
image_quality: int = 95
|
||||||
|
plugins = {}
|
||||||
|
|
||||||
|
|
||||||
def get_image_ext(img_bytes):
|
def get_image_ext(img_bytes):
|
||||||
@ -319,35 +318,37 @@ def process():
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@app.route("/interactive_seg", methods=["POST"])
|
@app.route("/run_plugin/", methods=["POST"])
|
||||||
def interactive_seg():
|
def run_plugin():
|
||||||
input = request.files
|
form = request.form
|
||||||
origin_image_bytes = input["image"].read() # RGB
|
files = request.files
|
||||||
image, _ = load_img(origin_image_bytes)
|
|
||||||
if "mask" in input:
|
|
||||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
|
||||||
else:
|
|
||||||
mask = None
|
|
||||||
|
|
||||||
_clicks = json.loads(request.form["clicks"])
|
name = form["name"]
|
||||||
clicks = []
|
if name not in plugins:
|
||||||
for i, click in enumerate(_clicks):
|
return "Plugin not found", 500
|
||||||
clicks.append(
|
|
||||||
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
|
origin_image_bytes = files["image"].read() # RGB
|
||||||
)
|
rgb_np_img, _ = load_img(origin_image_bytes)
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
|
res = plugins[name](rgb_np_img, files, form)
|
||||||
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
|
logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
|
||||||
|
torch_gc()
|
||||||
|
|
||||||
response = make_response(
|
response = make_response(
|
||||||
send_file(
|
send_file(
|
||||||
io.BytesIO(numpy_to_bytes(new_mask, "png")),
|
io.BytesIO(numpy_to_bytes(res, "png")),
|
||||||
mimetype=f"image/png",
|
mimetype=f"image/png",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/plugins/", methods=["GET"])
|
||||||
|
def get_plugins():
|
||||||
|
return list(plugins.keys()), 200
|
||||||
|
|
||||||
|
|
||||||
@app.route("/model")
|
@app.route("/model")
|
||||||
def current_model():
|
def current_model():
|
||||||
return model.name, 200
|
return model.name, 200
|
||||||
@ -423,14 +424,21 @@ def set_input_photo():
|
|||||||
return "No Input Image"
|
return "No Input Image"
|
||||||
|
|
||||||
|
|
||||||
class FSHandler(FileSystemEventHandler):
|
def build_plugins(args):
|
||||||
def on_modified(self, event):
|
global plugins
|
||||||
print("File modified: %s" % event.src_path)
|
if args.enable_interactive_seg:
|
||||||
|
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
||||||
|
plugins[InteractiveSeg.name] = InteractiveSeg()
|
||||||
|
if args.enable_remove_bg:
|
||||||
|
logger.info(f"Initialize {RemoveBG.name} plugin")
|
||||||
|
plugins[RemoveBG.name] = RemoveBG()
|
||||||
|
if args.enable_realesrgan:
|
||||||
|
logger.info(f"Initialize {RealESRGANUpscaler.name} plugin")
|
||||||
|
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
global model
|
global model
|
||||||
global interactive_seg_model
|
|
||||||
global device
|
global device
|
||||||
global input_image_path
|
global input_image_path
|
||||||
global is_disable_model_switch
|
global is_disable_model_switch
|
||||||
@ -442,6 +450,8 @@ def main(args):
|
|||||||
global is_controlnet
|
global is_controlnet
|
||||||
global image_quality
|
global image_quality
|
||||||
|
|
||||||
|
build_plugins(args)
|
||||||
|
|
||||||
image_quality = args.quality
|
image_quality = args.quality
|
||||||
|
|
||||||
if args.sd_controlnet and args.model in SD15_MODELS:
|
if args.sd_controlnet and args.model in SD15_MODELS:
|
||||||
@ -496,8 +506,6 @@ def main(args):
|
|||||||
callback=diffuser_callback,
|
callback=diffuser_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
interactive_seg_model = InteractiveSeg()
|
|
||||||
|
|
||||||
if args.gui:
|
if args.gui:
|
||||||
app_width, app_height = args.gui_size
|
app_width, app_height = args.gui_size
|
||||||
from flaskwebgui import FlaskUI
|
from flaskwebgui import FlaskUI
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lama_cleaner.interactive_seg import InteractiveSeg, Click
|
from lama_cleaner.plugins import InteractiveSeg, Click
|
||||||
|
|
||||||
current_dir = Path(__file__).parent.absolute().resolve()
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
save_dir = current_dir / 'result'
|
save_dir = current_dir / "result"
|
||||||
save_dir.mkdir(exist_ok=True, parents=True)
|
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
|
img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
|
||||||
|
|
||||||
@ -15,17 +14,22 @@ img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
|
|||||||
def test_interactive_seg():
|
def test_interactive_seg():
|
||||||
interactive_seg_model = InteractiveSeg()
|
interactive_seg_model = InteractiveSeg()
|
||||||
img = cv2.imread(str(img_p))
|
img = cv2.imread(str(img_p))
|
||||||
pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)])
|
pred = interactive_seg_model.forward(
|
||||||
|
img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)]
|
||||||
|
)
|
||||||
cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred)
|
cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred)
|
||||||
|
|
||||||
|
|
||||||
def test_interactive_seg_with_negative_click():
|
def test_interactive_seg_with_negative_click():
|
||||||
interactive_seg_model = InteractiveSeg()
|
interactive_seg_model = InteractiveSeg()
|
||||||
img = cv2.imread(str(img_p))
|
img = cv2.imread(str(img_p))
|
||||||
pred = interactive_seg_model(img, clicks=[
|
pred = interactive_seg_model.forward(
|
||||||
Click(coords=(256, 256), indx=0, is_positive=True),
|
img,
|
||||||
Click(coords=(384, 256), indx=1, is_positive=False)
|
clicks=[
|
||||||
])
|
Click(coords=(256, 256), indx=0, is_positive=True),
|
||||||
|
Click(coords=(384, 256), indx=1, is_positive=False),
|
||||||
|
],
|
||||||
|
)
|
||||||
cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred)
|
cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred)
|
||||||
|
|
||||||
|
|
||||||
@ -33,5 +37,7 @@ def test_interactive_seg_with_prev_mask():
|
|||||||
interactive_seg_model = InteractiveSeg()
|
interactive_seg_model = InteractiveSeg()
|
||||||
img = cv2.imread(str(img_p))
|
img = cv2.imread(str(img_p))
|
||||||
mask = np.zeros_like(img)[:, :, 0]
|
mask = np.zeros_like(img)[:, :, 0]
|
||||||
pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask)
|
pred = interactive_seg_model.forward(
|
||||||
|
img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask
|
||||||
|
)
|
||||||
cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred)
|
cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred)
|
||||||
|
@ -1,10 +1,5 @@
|
|||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_model():
|
def test_load_model():
|
||||||
from lama_cleaner.interactive_seg import InteractiveSeg
|
from lama_cleaner.plugins import InteractiveSeg
|
||||||
from lama_cleaner.model_manager import ModelManager
|
from lama_cleaner.model_manager import ModelManager
|
||||||
|
|
||||||
interactive_seg_model = InteractiveSeg()
|
interactive_seg_model = InteractiveSeg()
|
||||||
|
27
lama_cleaner/tests/test_plugins.py
Normal file
27
lama_cleaner/tests/test_plugins.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
from lama_cleaner.plugins import RemoveBG, RealESRGANUpscaler
|
||||||
|
|
||||||
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
|
save_dir = current_dir / "result"
|
||||||
|
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
img_p = current_dir / "bunny.jpeg"
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_bg():
|
||||||
|
model = RemoveBG()
|
||||||
|
img = cv2.imread(str(img_p))
|
||||||
|
res = model.forward(img)
|
||||||
|
cv2.imwrite(str(save_dir / "test_remove_bg.png"), res)
|
||||||
|
|
||||||
|
|
||||||
|
def test_upscale():
|
||||||
|
model = RealESRGANUpscaler("cpu")
|
||||||
|
img = cv2.imread(str(img_p))
|
||||||
|
res = model.forward(img, 2)
|
||||||
|
cv2.imwrite(str(save_dir / "test_upscale_x2.png"), res)
|
||||||
|
|
||||||
|
res = model.forward(img, 4)
|
||||||
|
cv2.imwrite(str(save_dir / "test_upscale_x4.png"), res)
|
@ -14,7 +14,6 @@ markupsafe==2.0.1
|
|||||||
scikit-image==0.19.3
|
scikit-image==0.19.3
|
||||||
diffusers[torch]==0.14.0
|
diffusers[torch]==0.14.0
|
||||||
transformers>=4.25.1
|
transformers>=4.25.1
|
||||||
watchdog==2.2.1
|
|
||||||
gradio
|
gradio
|
||||||
piexif==1.1.3
|
piexif==1.1.3
|
||||||
safetensors
|
safetensors
|
||||||
|
Loading…
Reference in New Issue
Block a user