add plugins

This commit is contained in:
Qing 2023-03-22 12:57:18 +08:00
parent b48d964c2c
commit 5a38d28ad1
11 changed files with 283 additions and 91 deletions

View File

@ -7,8 +7,8 @@ import time
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from watchdog.events import FileSystemEventHandler # from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer # from watchdog.observers import Observer
from PIL import Image, ImageOps, PngImagePlugin from PIL import Image, ImageOps, PngImagePlugin
from loguru import logger from loguru import logger
@ -19,7 +19,7 @@ from .storage_backends import FilesystemStorageBackend
from .utils import aspect_to_string, generate_filename, glob_img from .utils import aspect_to_string, generate_filename, glob_img
class FileManager(FileSystemEventHandler): class FileManager:
def __init__(self, app=None): def __init__(self, app=None):
self.app = app self.app = app
self._default_root_directory = "media" self._default_root_directory = "media"
@ -43,19 +43,19 @@ class FileManager(FileSystemEventHandler):
"output": datetime.utcnow(), "output": datetime.utcnow(),
} }
def start(self): # def start(self):
self.image_dir_filenames = self._media_names(self.root_directory) # self.image_dir_filenames = self._media_names(self.root_directory)
self.output_dir_filenames = self._media_names(self.output_dir) # self.output_dir_filenames = self._media_names(self.output_dir)
#
logger.info(f"Start watching image directory: {self.root_directory}") # logger.info(f"Start watching image directory: {self.root_directory}")
self.image_dir_observer = Observer() # self.image_dir_observer = Observer()
self.image_dir_observer.schedule(self, self.root_directory, recursive=False) # self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
self.image_dir_observer.start() # self.image_dir_observer.start()
#
logger.info(f"Start watching output directory: {self.output_dir}") # logger.info(f"Start watching output directory: {self.output_dir}")
self.output_dir_observer = Observer() # self.output_dir_observer = Observer()
self.output_dir_observer.schedule(self, self.output_dir, recursive=False) # self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
self.output_dir_observer.start() # self.output_dir_observer.start()
def on_modified(self, event): def on_modified(self, event):
if not os.path.isdir(event.src_path): if not os.path.isdir(event.src_path):

View File

@ -69,9 +69,33 @@ def parse_args():
help="Disable model switch in frontend", help="Disable model switch in frontend",
) )
parser.add_argument( parser.add_argument(
"--quality", default=95, type=int, help=QUALITY_HELP, "--quality",
default=95,
type=int,
help=QUALITY_HELP,
) )
# Plugins
parser.add_argument(
"--enable-interactive-seg",
action="store_true",
help="Enable interactive segmentation. Always run on CPU",
)
parser.add_argument(
"--enable-remove-bg",
action="store_true",
help="Enable remove background. Always run on CPU",
)
parser.add_argument(
"--enable-realesrgan",
action="store_true",
help="Enable realesrgan super resolution",
)
parser.add_argument(
"--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda"]
)
#########
# useless args # useless args
parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS) parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS)
parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS) parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS)
@ -123,6 +147,7 @@ def parse_args():
if args.model not in SD15_MODELS: if args.model not in SD15_MODELS:
logger.warning(f"--sd_controlnet only support {SD15_MODELS}") logger.warning(f"--sd_controlnet only support {SD15_MODELS}")
os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR
if args.model_dir and args.model_dir is not None: if args.model_dir and args.model_dir is not None:
if os.path.isfile(args.model_dir): if os.path.isfile(args.model_dir):
parser.error(f"invalid --model-dir: {args.model_dir} is a file") parser.error(f"invalid --model-dir: {args.model_dir} is a file")
@ -132,6 +157,7 @@ def parse_args():
Path(args.model_dir).mkdir(exist_ok=True, parents=True) Path(args.model_dir).mkdir(exist_ok=True, parents=True)
os.environ["XDG_CACHE_HOME"] = args.model_dir os.environ["XDG_CACHE_HOME"] = args.model_dir
os.environ["U2NET_HOME"] = args.model_dir
if args.input and args.input is not None: if args.input and args.input is not None:
if not os.path.exists(args.input): if not os.path.exists(args.input):

View File

@ -0,0 +1,3 @@
from .interactive_seg import InteractiveSeg, Click
from .remove_bg import RemoveBG
from .upscale import RealESRGANUpscaler

View File

@ -1,14 +1,19 @@
import json
import json
import os import os
from typing import Tuple, List
import cv2 import cv2
from typing import Tuple, List import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
import numpy as np
from lama_cleaner.helper import only_keep_largest_contour, load_jit_model from lama_cleaner.helper import (
load_jit_model,
load_img,
)
class Click(BaseModel): class Click(BaseModel):
@ -21,11 +26,11 @@ class Click(BaseModel):
def coords_and_indx(self): def coords_and_indx(self):
return (*self.coords, self.indx) return (*self.coords, self.indx)
def scale(self, x_ratio: float, y_ratio: float) -> 'Click': def scale(self, x_ratio: float, y_ratio: float) -> "Click":
return Click( return Click(
coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio), coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio),
is_positive=self.is_positive, is_positive=self.is_positive,
indx=self.indx indx=self.indx,
) )
@ -40,21 +45,32 @@ class ResizeTrans:
image_height, image_width = image_nd.shape[2:4] image_height, image_width = image_nd.shape[2:4]
self.image_height = image_height self.image_height = image_height
self.image_width = image_width self.image_width = image_width
image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode='bilinear', align_corners=True) image_nd_r = F.interpolate(
image_nd,
(self.crop_height, self.crop_width),
mode="bilinear",
align_corners=True,
)
y_ratio = self.crop_height / image_height y_ratio = self.crop_height / image_height
x_ratio = self.crop_width / image_width x_ratio = self.crop_width / image_width
clicks_lists_resized = [] clicks_lists_resized = []
for clicks_list in clicks_lists: for clicks_list in clicks_lists:
clicks_list_resized = [click.scale(y_ratio, x_ratio) for click in clicks_list] clicks_list_resized = [
click.scale(y_ratio, x_ratio) for click in clicks_list
]
clicks_lists_resized.append(clicks_list_resized) clicks_lists_resized.append(clicks_list_resized)
return image_nd_r, clicks_lists_resized return image_nd_r, clicks_lists_resized
def inv_transform(self, prob_map): def inv_transform(self, prob_map):
new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear', new_prob_map = F.interpolate(
align_corners=True) prob_map,
(self.image_height, self.image_width),
mode="bilinear",
align_corners=True,
)
return new_prob_map return new_prob_map
@ -106,8 +122,9 @@ class ISPredictor(object):
pred = torch.sigmoid(pred_logits) pred = torch.sigmoid(pred_logits)
pred = self.post_process(pred) pred = self.post_process(pred)
prediction = F.interpolate(pred, mode='bilinear', align_corners=True, prediction = F.interpolate(
size=image_nd.size()[2:]) pred, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
)
for t in reversed(transforms): for t in reversed(transforms):
prediction = t.inv_transform(prediction) prediction = t.inv_transform(prediction)
@ -121,32 +138,49 @@ class ISPredictor(object):
pred_mask = pred.cpu().numpy()[0][0] pred_mask = pred.cpu().numpy()[0][0]
# morph_open to remove small noise # morph_open to remove small noise
kernel_size = self.open_kernel_size kernel_size = self.open_kernel_size
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
)
pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1) pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1)
# Why dilate: make region slightly larger to avoid missing some pixels, this generally works better # Why dilate: make region slightly larger to avoid missing some pixels, this generally works better
dilate_kernel_size = self.dilate_kernel_size dilate_kernel_size = self.dilate_kernel_size
if dilate_kernel_size > 1: if dilate_kernel_size > 1:
kernel = cv2.getStructuringElement(cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size)) kernel = cv2.getStructuringElement(
cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size)
)
pred_mask = cv2.dilate(pred_mask, kernel, 1) pred_mask = cv2.dilate(pred_mask, kernel, 1)
return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0) return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)
def get_points_nd(self, clicks_lists): def get_points_nd(self, clicks_lists):
total_clicks = [] total_clicks = []
num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] num_pos_clicks = [
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
]
num_neg_clicks = [
len(clicks_list) - num_pos
for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
]
num_max_points = max(num_pos_clicks + num_neg_clicks) num_max_points = max(num_pos_clicks + num_neg_clicks)
if self.net_clicks_limit is not None: if self.net_clicks_limit is not None:
num_max_points = min(self.net_clicks_limit, num_max_points) num_max_points = min(self.net_clicks_limit, num_max_points)
num_max_points = max(1, num_max_points) num_max_points = max(1, num_max_points)
for clicks_list in clicks_lists: for clicks_list in clicks_lists:
clicks_list = clicks_list[:self.net_clicks_limit] clicks_list = clicks_list[: self.net_clicks_limit]
pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] pos_clicks = [
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] click.coords_and_indx for click in clicks_list if click.is_positive
]
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
(-1, -1, -1)
]
neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] neg_clicks = [
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] click.coords_and_indx for click in clicks_list if not click.is_positive
]
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
(-1, -1, -1)
]
total_clicks.append(pos_clicks + neg_clicks) total_clicks.append(pos_clicks + neg_clicks)
return torch.tensor(total_clicks, device=self.device) return torch.tensor(total_clicks, device=self.device)
@ -156,19 +190,45 @@ INTERACTIVE_SEG_MODEL_URL = os.environ.get(
"INTERACTIVE_SEG_MODEL_URL", "INTERACTIVE_SEG_MODEL_URL",
"https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt", "https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt",
) )
INTERACTIVE_SEG_MODEL_MD5 = os.environ.get("INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf") INTERACTIVE_SEG_MODEL_MD5 = os.environ.get(
"INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf"
)
class InteractiveSeg: class InteractiveSeg:
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3): name = "InteractiveSeg"
device = torch.device('cpu')
model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5).eval()
self.predictor = ISPredictor(model, device,
infer_size=infer_size,
open_kernel_size=open_kernel_size,
dilate_kernel_size=dilate_kernel_size)
def __call__(self, image, clicks, prev_mask=None): def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
device = torch.device("cpu")
model = load_jit_model(
INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5
).eval()
self.predictor = ISPredictor(
model,
device,
infer_size=infer_size,
open_kernel_size=open_kernel_size,
dilate_kernel_size=dilate_kernel_size,
)
def __call__(self, rgb_np_img, files, form):
image = rgb_np_img
if "mask" in files:
mask, _ = load_img(files["mask"].read(), gray=True)
else:
mask = None
_clicks = json.loads(form["clicks"])
clicks = []
for i, click in enumerate(_clicks):
clicks.append(
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
)
new_mask = self.forward(image, clicks=clicks, prev_mask=mask)
return new_mask
def forward(self, image, clicks, prev_mask=None):
""" """
Args: Args:
@ -183,7 +243,7 @@ class InteractiveSeg:
if prev_mask is None: if prev_mask is None:
mask = torch.zeros_like(image[:, :1, :, :]) mask = torch.zeros_like(image[:, :1, :, :])
else: else:
logger.info('InteractiveSeg run with prev_mask') logger.info("InteractiveSeg run with prev_mask")
mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float() mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()
pred_probs = self.predictor(image, clicks, mask) pred_probs = self.predictor(image, clicks, mask)

View File

@ -0,0 +1,22 @@
import cv2
import numpy as np
class RemoveBG:
    """Plugin that removes the background of an image with rembg's ``u2net`` model.

    Instances follow the plugin calling convention used in this commit:
    ``plugin(rgb_np_img, files, form)``.
    """

    name = "RemoveBG"

    def __init__(self):
        # rembg is imported inside the method rather than at module level;
        # presumably so the dependency is only required when the plugin is
        # actually enabled -- TODO confirm.
        import rembg

        self.session = rembg.new_session(model_name="u2net")

    def __call__(self, rgb_np_img, files, form):
        """Plugin entry point.

        Args:
            rgb_np_img: input image as an RGB numpy array.
            files: uploaded files of the request (not used by this plugin).
            form: form fields of the request (not used by this plugin).

        Returns:
            Whatever :meth:`forward` returns for the BGR-converted image.
        """
        return self.forward(cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR))

    def forward(self, bgr_np_img) -> np.ndarray:
        """Run background removal on a BGR image using the cached session.

        The original author's note states that the result is a BGRA image.
        """
        import rembg

        return rembg.remove(bgr_np_img, session=self.session)

View File

@ -0,0 +1,46 @@
import cv2
from lama_cleaner.helper import download_model
class RealESRGANUpscaler:
    """Plugin that upscales images with the RealESRGAN x4plus weights.

    Instances follow the plugin calling convention used in this commit:
    ``plugin(rgb_np_img, files, form)``; the requested output scale is read
    from ``form["scale"]``.
    """

    name = "RealESRGAN"

    def __init__(self, device):
        """Build the RRDBNet backbone and wrap it in a ``RealESRGANer``.

        Args:
            device: device handed to ``RealESRGANer``; half precision is
                switched on when its string form contains ``"cuda"``.
        """
        super().__init__()
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        # Native scale of the network; used for both the backbone and the wrapper.
        net_scale = 4
        backbone = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=net_scale,
        )

        weights_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
        weights_md5 = "99ec365d4afad750833258a1a24f44ca"
        weights_path = download_model(weights_url, weights_md5)

        self.model = RealESRGANer(
            scale=net_scale,
            model_path=weights_path,
            model=backbone,
            half="cuda" in str(device),
            tile=640,
            tile_pad=10,
            pre_pad=10,
            device=device,
        )

    def __call__(self, rgb_np_img, files, form):
        """Plugin entry point.

        Args:
            rgb_np_img: input image as an RGB numpy array.
            files: uploaded files of the request (not used by this plugin).
            form: form fields of the request; ``form["scale"]`` is converted
                to ``float`` and used as the output scale.

        Returns:
            The upscaled image produced by :meth:`forward`.
        """
        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
        requested_scale = float(form["scale"])
        return self.forward(bgr_np_img, requested_scale)

    def forward(self, bgr_np_img, scale: float):
        """Upscale a BGR image by ``scale`` and return the enhanced image."""
        # The output is BGR (per the original author's note).
        enhanced = self.model.enhance(bgr_np_img, outscale=scale)
        return enhanced[0]

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import io import io
import json
import logging import logging
import multiprocessing import multiprocessing
import os import os
@ -16,12 +15,12 @@ import cv2
import torch import torch
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from watchdog.events import FileSystemEventHandler
from lama_cleaner.const import SD15_MODELS from lama_cleaner.const import SD15_MODELS
from lama_cleaner.interactive_seg import InteractiveSeg, Click
from lama_cleaner.make_gif import make_compare_gif from lama_cleaner.make_gif import make_compare_gif
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
from lama_cleaner.file_manager import FileManager from lama_cleaner.file_manager import FileManager
@ -85,7 +84,6 @@ CORS(app, expose_headers=["Content-Disposition"])
model: ModelManager = None model: ModelManager = None
thumb: FileManager = None thumb: FileManager = None
output_dir: str = None output_dir: str = None
interactive_seg_model: InteractiveSeg = None
device = None device = None
input_image_path: str = None input_image_path: str = None
is_disable_model_switch: bool = False is_disable_model_switch: bool = False
@ -94,6 +92,7 @@ is_enable_file_manager: bool = False
is_enable_auto_saving: bool = False is_enable_auto_saving: bool = False
is_desktop: bool = False is_desktop: bool = False
image_quality: int = 95 image_quality: int = 95
plugins = {}
def get_image_ext(img_bytes): def get_image_ext(img_bytes):
@ -319,35 +318,37 @@ def process():
return response return response
@app.route("/interactive_seg", methods=["POST"]) @app.route("/run_plugin/", methods=["POST"])
def interactive_seg(): def run_plugin():
input = request.files form = request.form
origin_image_bytes = input["image"].read() # RGB files = request.files
image, _ = load_img(origin_image_bytes)
if "mask" in input:
mask, _ = load_img(input["mask"].read(), gray=True)
else:
mask = None
_clicks = json.loads(request.form["clicks"]) name = form["name"]
clicks = [] if name not in plugins:
for i, click in enumerate(_clicks): return "Plugin not found", 500
clicks.append(
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1) origin_image_bytes = files["image"].read() # RGB
) rgb_np_img, _ = load_img(origin_image_bytes)
start = time.time() start = time.time()
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask) res = plugins[name](rgb_np_img, files, form)
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms") logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
torch_gc()
response = make_response( response = make_response(
send_file( send_file(
io.BytesIO(numpy_to_bytes(new_mask, "png")), io.BytesIO(numpy_to_bytes(res, "png")),
mimetype=f"image/png", mimetype=f"image/png",
) )
) )
return response return response
@app.route("/plugins/", methods=["GET"])
def get_plugins():
    """Return the names of the currently registered plugins as a JSON array.

    The names are the keys of the module-level ``plugins`` dict. The list is
    serialised explicitly with ``jsonify`` because Flask only converts a bare
    ``list`` return value to JSON from version 2.2 onwards; older releases
    reject it as an invalid view return type.

    Returns:
        A ``(response, 200)`` tuple whose response body is the JSON array of
        plugin names.
    """
    # Function-scope import: flask is already a dependency of this module, and
    # importing here avoids relying on which flask names the file header pulls in.
    from flask import jsonify

    return jsonify(list(plugins)), 200
@app.route("/model") @app.route("/model")
def current_model(): def current_model():
return model.name, 200 return model.name, 200
@ -423,14 +424,21 @@ def set_input_photo():
return "No Input Image" return "No Input Image"
class FSHandler(FileSystemEventHandler): def build_plugins(args):
def on_modified(self, event): global plugins
print("File modified: %s" % event.src_path) if args.enable_interactive_seg:
logger.info(f"Initialize {InteractiveSeg.name} plugin")
plugins[InteractiveSeg.name] = InteractiveSeg()
if args.enable_remove_bg:
logger.info(f"Initialize {RemoveBG.name} plugin")
plugins[RemoveBG.name] = RemoveBG()
if args.enable_realesrgan:
logger.info(f"Initialize {RealESRGANUpscaler.name} plugin")
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device)
def main(args): def main(args):
global model global model
global interactive_seg_model
global device global device
global input_image_path global input_image_path
global is_disable_model_switch global is_disable_model_switch
@ -442,6 +450,8 @@ def main(args):
global is_controlnet global is_controlnet
global image_quality global image_quality
build_plugins(args)
image_quality = args.quality image_quality = args.quality
if args.sd_controlnet and args.model in SD15_MODELS: if args.sd_controlnet and args.model in SD15_MODELS:
@ -496,8 +506,6 @@ def main(args):
callback=diffuser_callback, callback=diffuser_callback,
) )
interactive_seg_model = InteractiveSeg()
if args.gui: if args.gui:
app_width, app_height = args.gui_size app_width, app_height = args.gui_size
from flaskwebgui import FlaskUI from flaskwebgui import FlaskUI

View File

@ -1,13 +1,12 @@
from pathlib import Path from pathlib import Path
import pytest
import cv2 import cv2
import numpy as np import numpy as np
from lama_cleaner.interactive_seg import InteractiveSeg, Click from lama_cleaner.plugins import InteractiveSeg, Click
current_dir = Path(__file__).parent.absolute().resolve() current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / 'result' save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True) save_dir.mkdir(exist_ok=True, parents=True)
img_p = current_dir / "overture-creations-5sI6fQgYIuo.png" img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
@ -15,17 +14,22 @@ img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
def test_interactive_seg(): def test_interactive_seg():
interactive_seg_model = InteractiveSeg() interactive_seg_model = InteractiveSeg()
img = cv2.imread(str(img_p)) img = cv2.imread(str(img_p))
pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)]) pred = interactive_seg_model.forward(
img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)]
)
cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred) cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred)
def test_interactive_seg_with_negative_click(): def test_interactive_seg_with_negative_click():
interactive_seg_model = InteractiveSeg() interactive_seg_model = InteractiveSeg()
img = cv2.imread(str(img_p)) img = cv2.imread(str(img_p))
pred = interactive_seg_model(img, clicks=[ pred = interactive_seg_model.forward(
Click(coords=(256, 256), indx=0, is_positive=True), img,
Click(coords=(384, 256), indx=1, is_positive=False) clicks=[
]) Click(coords=(256, 256), indx=0, is_positive=True),
Click(coords=(384, 256), indx=1, is_positive=False),
],
)
cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred) cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred)
@ -33,5 +37,7 @@ def test_interactive_seg_with_prev_mask():
interactive_seg_model = InteractiveSeg() interactive_seg_model = InteractiveSeg()
img = cv2.imread(str(img_p)) img = cv2.imread(str(img_p))
mask = np.zeros_like(img)[:, :, 0] mask = np.zeros_like(img)[:, :, 0]
pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask) pred = interactive_seg_model.forward(
img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask
)
cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred) cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred)

View File

@ -1,10 +1,5 @@
import os
import tempfile
from pathlib import Path
def test_load_model(): def test_load_model():
from lama_cleaner.interactive_seg import InteractiveSeg from lama_cleaner.plugins import InteractiveSeg
from lama_cleaner.model_manager import ModelManager from lama_cleaner.model_manager import ModelManager
interactive_seg_model = InteractiveSeg() interactive_seg_model = InteractiveSeg()

View File

@ -0,0 +1,27 @@
from pathlib import Path
import cv2
from lama_cleaner.plugins import RemoveBG, RealESRGANUpscaler
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)
img_p = current_dir / "bunny.jpeg"
def test_remove_bg():
    """Smoke test: run RemoveBG on the sample image and save the result."""
    remover = RemoveBG()
    sample = cv2.imread(str(img_p))
    cutout = remover.forward(sample)
    out_path = save_dir / "test_remove_bg.png"
    cv2.imwrite(str(out_path), cutout)
def test_upscale():
    """Smoke test: upscale the sample image on CPU at x2 and x4 and save both."""
    upscaler = RealESRGANUpscaler("cpu")
    sample = cv2.imread(str(img_p))
    # (output scale, output file name) pairs, processed in this order.
    cases = (
        (2, "test_upscale_x2.png"),
        (4, "test_upscale_x4.png"),
    )
    for factor, file_name in cases:
        enlarged = upscaler.forward(sample, factor)
        cv2.imwrite(str(save_dir / file_name), enlarged)

View File

@ -14,7 +14,6 @@ markupsafe==2.0.1
scikit-image==0.19.3 scikit-image==0.19.3
diffusers[torch]==0.14.0 diffusers[torch]==0.14.0
transformers>=4.25.1 transformers>=4.25.1
watchdog==2.2.1
gradio gradio
piexif==1.1.3 piexif==1.1.3
safetensors safetensors