2023-01-17 14:05:17 +01:00
|
|
|
import io
|
2021-11-15 08:22:34 +01:00
|
|
|
import os
|
|
|
|
import sys
|
2022-07-14 10:49:03 +02:00
|
|
|
from typing import List, Optional
|
2021-11-15 08:22:34 +01:00
|
|
|
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
import cv2
|
2023-01-17 14:05:17 +01:00
|
|
|
from PIL import Image, ImageOps
|
2021-11-15 08:22:34 +01:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
2022-07-14 10:49:03 +02:00
|
|
|
from loguru import logger
|
2021-11-15 08:22:34 +01:00
|
|
|
from torch.hub import download_url_to_file, get_dir
|
|
|
|
|
|
|
|
|
2022-04-17 17:31:12 +02:00
|
|
|
def get_cache_path_by_url(url):
    """Return the local cache path used for a checkpoint URL.

    The file name is the last component of the URL's path (query string and
    fragment are ignored) and it lives in ``<torch hub dir>/checkpoints``.
    That directory is created if missing; the file itself is neither checked
    nor downloaded here.

    Args:
        url: checkpoint URL.

    Returns:
        Absolute or hub-relative path of the cached file.
    """
    parts = urlparse(url)
    model_dir = os.path.join(get_dir(), "checkpoints")
    # exist_ok avoids the race where another process creates the directory
    # between a separate existence check and the makedirs call.
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(parts.path)
    return os.path.join(model_dir, filename)
|
|
|
|
|
|
|
|
|
|
|
|
def download_model(url):
    """Return the local cache path for ``url``, downloading it on first use.

    If the cached file already exists it is reused as-is; otherwise the URL
    is fetched with a progress bar into the cache location.

    Args:
        url: checkpoint URL.

    Returns:
        Path of the cached file.
    """
    cached_file = get_cache_path_by_url(url)
    if os.path.exists(cached_file):
        return cached_file

    sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
    # hash_prefix=None: the downloaded file is not checksum-verified.
    download_url_to_file(url, cached_file, None, progress=True)
    return cached_file
|
|
|
|
|
|
|
|
|
|
|
|
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    Values that are already a multiple of ``mod`` are returned unchanged.
    """
    quotient, remainder = divmod(x, mod)
    if not remainder:
        return x
    return (quotient + 1) * mod
|
|
|
|
|
|
|
|
|
2022-07-14 10:49:03 +02:00
|
|
|
def load_jit_model(url_or_path, device):
    """Load a TorchScript model from a local path or URL, in eval mode.

    Args:
        url_or_path: local file path; if no such file exists the value is
            treated as a URL and resolved (downloaded or reused from cache)
            via ``download_model``.
        device: target passed to ``.to(...)`` after loading.

    Returns:
        The loaded TorchScript module, moved to ``device`` and in eval mode.

    On any load failure the error (with traceback) is logged and the process
    terminates via ``sys.exit(-1)``.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)
    logger.info(f"Load model from: {model_path}")

    try:
        # Deserialize onto CPU first so a file saved from a GPU can still be
        # opened on a machine without one, then move to the requested device.
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception:
        # The remedy suggested in the message assumes a bad model file on disk.
        logger.exception(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        sys.exit(-1)
    model.eval()
    return model
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(model: torch.nn.Module, url_or_path, device):
    """Load a state_dict checkpoint into ``model`` and put it in eval mode.

    Args:
        model: module whose parameters are overwritten in place; checkpoint
            keys must match exactly (``strict=True``).
        url_or_path: local checkpoint path; if no such file exists the value
            is treated as a URL and resolved via ``download_model``.
        device: target passed to ``model.to(...)``.

    Returns:
        The same ``model`` object, moved to ``device`` and in eval mode.

    On any load failure the error (with traceback) is logged and the process
    terminates via ``sys.exit(-1)``.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)

    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from trusted sources.
        # Loading to CPU first lets GPU-saved checkpoints open anywhere.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
    except Exception:
        # The remedy suggested in the message assumes a bad model file on disk.
        logger.exception(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        sys.exit(-1)
    logger.info(f"Load model from: {model_path}")
    model.eval()
    return model
|
|
|
|
|
|
|
|
|
2022-04-09 01:23:33 +02:00
|
|
|
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode an image array into bytes of the given file format.

    Args:
        image_numpy: image array handed to ``cv2.imencode``. NOTE(review):
            OpenCV encoders treat 3-channel data as BGR; confirm callers pass
            the channel order they intend.
        ext: file extension without the leading dot (e.g. ``"png"``, ``"jpg"``).

    Returns:
        The encoded image bytes.

    Raises:
        ValueError: if OpenCV reports that encoding failed.
    """
    # Maximum JPEG quality and zero PNG compression.
    ok, data = cv2.imencode(
        f".{ext}",
        image_numpy,
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )
    if not ok:
        # Previously the status flag was ignored and bogus bytes returned.
        raise ValueError(f"Failed to encode image as .{ext}")
    return data.tobytes()
|
|
|
|
|
|
|
|
|
2023-02-06 15:00:47 +01:00
|
|
|
def pil_to_bytes(pil_img, ext: str, exif=None) -> bytes:
    """Serialize a PIL image to bytes in the given format.

    Args:
        pil_img: image exposing PIL's ``save(fp, format=..., **params)``.
        ext: format name forwarded to ``save`` as ``format``.
        exif: optional EXIF payload to embed; not forwarded at all when None.

    Returns:
        The encoded image bytes.
    """
    save_kwargs = {"format": ext, "quality": 95}
    # Some Pillow writers use the exif value directly (the WebP writer, for
    # one, calls exif.startswith), so an explicit exif=None can raise.
    if exif is not None:
        save_kwargs["exif"] = exif
    with io.BytesIO() as output:
        pil_img.save(output, **save_kwargs)
        return output.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
    """Decode encoded image bytes into a numpy array.

    The image is rotated/flipped according to its EXIF orientation on a
    best-effort basis before conversion.

    Args:
        img_bytes: encoded image file contents (any format PIL can open).
        gray: if True, convert to single-channel ``"L"`` mode.
        return_exif: if True, also return the image's EXIF data.

    Returns:
        ``(np_img, alpha_channel)``, or ``(np_img, alpha_channel, exif)`` when
        ``return_exif`` is True.
        - ``np_img``: (H, W) for ``gray``; otherwise (H, W, 3) RGB.
        - ``alpha_channel``: the (H, W) alpha plane when the non-gray source
          is RGBA, else ``None``.
        - ``exif``: the EXIF mapping, or ``None`` if extraction failed. When
          the orientation was applied to the pixels, the Orientation tag is
          removed from it.
    """
    alpha_channel = None
    image = Image.open(io.BytesIO(img_bytes))

    exif = None
    if return_exif:
        try:
            exif = image.getexif()
        except Exception:
            logger.error("Failed to extract exif from image")

    try:
        image = ImageOps.exif_transpose(image)
    except Exception:
        # Best effort: keep the image as stored if orientation handling fails.
        pass
    else:
        # The pixels now carry the orientation. Drop the tag (0x0112) from
        # the returned EXIF so re-attaching it does not rotate a second time.
        if exif is not None:
            exif.pop(0x0112, None)

    if gray:
        image = image.convert("L")
        np_img = np.array(image)
    elif image.mode == "RGBA":
        np_img = np.array(image)
        alpha_channel = np_img[:, :, -1]
        np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
    else:
        image = image.convert("RGB")
        np_img = np.array(image)

    if return_exif:
        return np_img, alpha_channel, exif
    return np_img, alpha_channel
|
2021-11-15 08:22:34 +01:00
|
|
|
|
|
|
|
|
2021-11-27 13:37:37 +01:00
|
|
|
def norm_img(np_img):
    """Convert an HWC (or HW) image array to float32 CHW scaled by 1/255.

    A 2-D input is treated as a single channel, giving a (1, H, W) result.
    """
    if np_img.ndim == 2:
        # Add an explicit trailing channel axis so the transpose is uniform.
        np_img = np.expand_dims(np_img, axis=-1)
    chw = np_img.transpose(2, 0, 1)
    return chw.astype(np.float32) / 255
|
|
|
|
|
|
|
|
|
2021-11-27 13:37:37 +01:00
|
|
|
def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    """Downscale an image so that its longer side equals ``size_limit``.

    Images whose longer side is already ``<= size_limit`` are returned as-is
    (the same object, not a copy). Otherwise the aspect ratio is preserved;
    each side is rounded to the nearest pixel and never drops below 1.

    Args:
        np_img: image array whose first two dimensions are (height, width).
        size_limit: maximum allowed length of the longer side.
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The (possibly) resized image array.
    """
    h, w = np_img.shape[:2]
    longest = max(h, w)
    if longest <= size_limit:
        return np_img

    ratio = size_limit / longest
    # Clamp to 1: a very thin image could otherwise round its short side to
    # 0, which cv2.resize rejects.
    new_w = max(1, int(w * ratio + 0.5))
    new_h = max(1, int(h * ratio + 0.5))
    return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
|
|
|
|
|
|
|
|
|
2022-07-19 15:47:21 +02:00
|
|
|
def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """Pad an image on the bottom/right so both sides are multiples of ``mod``.

    Padding uses numpy's ``"symmetric"`` mode, which mirrors the image
    including its edge row/column.

    Args:
        img: [H, W, C] array; a 2-D [H, W] array is treated as one channel.
        mod: output height and width are rounded up to a multiple of this.
        square: whether to pad to a square (both sides become the larger one).
        min_size: optional lower bound for both output sides; must itself be
            a multiple of ``mod``.

    Returns:
        The padded [H', W', C] array; the original pixels occupy its
        top-left [H, W] region.
    """
    if img.ndim == 2:
        img = img[..., np.newaxis]
    height, width = img.shape[:2]

    target_h = ceil_modulo(height, mod)
    target_w = ceil_modulo(width, mod)

    if min_size is not None:
        assert min_size % mod == 0
        target_h = max(min_size, target_h)
        target_w = max(min_size, target_w)

    if square:
        target_h = target_w = max(target_h, target_w)

    pad_spec = ((0, target_h - height), (0, target_w - width), (0, 0))
    return np.pad(img, pad_spec, mode="symmetric")
|
2022-03-23 03:02:01 +01:00
|
|
|
|
|
|
|
|
|
|
|
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """Return one bounding box per external contour of a binarised mask.

    Args:
        mask: (h, w, 1) array with values in 0~255; pixels above 127 are
            treated as foreground.

    Returns:
        A list of int arrays ``[x1, y1, x2, y2]`` where ``x2 = x1 + width``
        and ``y2 = y1 + height`` of the contour's bounding rectangle, with x
        values clipped to [0, w] and y values to [0, h].
    """
    mask_h, mask_w = mask.shape[:2]
    _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    def to_box(contour):
        left, top, box_w, box_h = cv2.boundingRect(contour)
        box = np.array([left, top, left + box_w, top + box_h]).astype(int)
        box[0::2] = np.clip(box[0::2], 0, mask_w)
        box[1::2] = np.clip(box[1::2], 0, mask_h)
        return box

    return [to_box(contour) for contour in contours]
|
2022-11-27 14:25:27 +01:00
|
|
|
|
|
|
|
|
|
|
|
def only_keep_largest_contour(mask: np.ndarray) -> np.ndarray:
    """Keep only the largest external contour of a mask, filled solid.

    Args:
        mask: (h, w) array with values in 0~255; pixels above 127 are treated
            as foreground when finding contours.

    Returns:
        A new array shaped like ``mask`` containing only the largest-area
        external contour filled with 255. The fill covers the whole interior,
        including any holes. If no contour has a positive area, ``mask``
        itself is returned unchanged.
    """
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    max_area = 0
    max_index = -1
    for i, cnt in enumerate(contours):
        area = cv2.contourArea(cnt)
        if area > max_area:
            max_area = area
            max_index = i

    if max_index == -1:
        return mask
    new_mask = np.zeros_like(mask)
    # thickness=-1 fills the contour interior instead of drawing its outline.
    return cv2.drawContours(new_mask, contours, max_index, 255, -1)
|