IOPaint/iopaint/helper.py

import base64
import imghdr
import hashlib
import io
import os
import sys
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse

import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image, ImageOps, PngImagePlugin
from torch.hub import download_url_to_file, get_dir

from iopaint.const import MPS_UNSUPPORT_MODELS


def md5sum(filename):
    md5 = hashlib.md5()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def switch_mps_device(model_name, device):
    if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
        logger.info(f"{model_name} does not support mps, switching to cpu")
        return torch.device("cpu")
    return device


def get_cache_path_by_url(url):
    parts = urlparse(url)
    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, "checkpoints")
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    return cached_file


def download_model(url, model_md5: str = None):
    if os.path.exists(url):
        cached_file = url
    else:
        cached_file = get_cache_path_by_url(url)
        if not os.path.exists(cached_file):
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            hash_prefix = None
            download_url_to_file(url, cached_file, hash_prefix, progress=True)
            if model_md5:
                _md5 = md5sum(cached_file)
                if model_md5 == _md5:
                    logger.info(f"Download model success, md5: {_md5}")
                else:
                    try:
                        os.remove(cached_file)
                        logger.error(
                            f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint. "
                            f"If you still have errors, please try downloading the model manually first: https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
                        )
                    except:
                        logger.error(
                            f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart iopaint."
                        )
                    exit(-1)

    return cached_file
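
# Illustrative usage sketch (added for documentation; the URL and md5 below are
# placeholders, not real release assets):
#
#   ckpt = download_model(
#       "https://example.com/models/big-lama.pt",       # hypothetical URL
#       model_md5="0123456789abcdef0123456789abcdef",    # hypothetical md5
#   )
#   # ckpt is either the given local path or a file under torch.hub's checkpoints dir.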


def ceil_modulo(x, mod):
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def handle_error(model_path, model_md5, e):
    _md5 = md5sum(model_path)
    if _md5 != model_md5:
        try:
            os.remove(model_path)
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint. "
                f"If you still have errors, please try downloading the model manually first: https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
            )
        except:
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart iopaint."
            )
    else:
        logger.error(
            f"Failed to load model {model_path}, "
            f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
        )
    exit(-1)


def load_jit_model(url_or_path, device, model_md5: str):
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path, model_md5)

    logger.info(f"Loading model from: {model_path}")
    try:
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        handle_error(model_path, model_md5, e)
    model.eval()
    return model


def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path, model_md5)

    try:
        logger.info(f"Loading model from: {model_path}")
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
    except Exception as e:
        handle_error(model_path, model_md5, e)
    model.eval()
    return model
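
# Illustrative usage sketch (hypothetical URL and md5; `MyNet` stands in for a real
# torch.nn.Module defined elsewhere):
#
#   device = switch_mps_device("lama", torch.device("mps"))
#   net = load_model(MyNet(), "https://example.com/weights.pth", device, "<md5>")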


def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    data = cv2.imencode(
        f".{ext}",
        image_numpy,
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )[1]
    image_bytes = data.tobytes()
    return image_bytes


def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes:
    with io.BytesIO() as output:
        kwargs = {k: v for k, v in infos.items() if v is not None}
        if ext == "jpg":
            ext = "jpeg"
        if "png" == ext.lower() and "parameters" in kwargs:
            pnginfo_data = PngImagePlugin.PngInfo()
            pnginfo_data.add_text("parameters", kwargs["parameters"])
            kwargs["pnginfo"] = pnginfo_data

        pil_img.save(output, format=ext, quality=quality, **kwargs)
        image_bytes = output.getvalue()
    return image_bytes
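
# Illustrative usage sketch: encoding a PIL image to PNG bytes while preserving an
# A1111-style "parameters" text chunk (the key/value here are examples, not required):
#
#   img = Image.new("RGB", (64, 64))
#   data = pil_to_bytes(img, "png", infos={"parameters": "prompt: example"})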


def load_img(img_bytes, gray: bool = False, return_info: bool = False):
    alpha_channel = None
    image = Image.open(io.BytesIO(img_bytes))

    if return_info:
        infos = image.info

    try:
        image = ImageOps.exif_transpose(image)
    except:
        pass

    if gray:
        image = image.convert("L")
        np_img = np.array(image)
    else:
        if image.mode == "RGBA":
            np_img = np.array(image)
            alpha_channel = np_img[:, :, -1]
            np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
        else:
            image = image.convert("RGB")
            np_img = np.array(image)

    if return_info:
        return np_img, alpha_channel, infos
    return np_img, alpha_channel
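
# Illustrative round-trip sketch (the file path is a placeholder): decode bytes to a
# numpy RGB image plus optional alpha, then re-attach the alpha before re-encoding.
#
#   with open("example.png", "rb") as f:                 # hypothetical input file
#       rgb, alpha, infos = load_img(f.read(), return_info=True)
#   merged = concat_alpha_channel(rgb, alpha)
#   png_bytes = pil_to_bytes(Image.fromarray(merged), "png", infos=infos)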


def norm_img(np_img):
    if len(np_img.shape) == 2:
        np_img = np_img[:, :, np.newaxis]
    np_img = np.transpose(np_img, (2, 0, 1))
    np_img = np_img.astype("float32") / 255
    return np_img
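
# Illustrative sketch: norm_img converts an HWC uint8 image to a CHW float32 array in
# [0, 1], the layout expected before wrapping it in a torch tensor (shapes assumed):
#
#   rgb = np.zeros((512, 768, 3), dtype=np.uint8)
#   chw = norm_img(rgb)                              # (3, 512, 768), float32
#   batch = torch.from_numpy(chw).unsqueeze(0)       # (1, 3, 512, 768)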


def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    # Resize so the longer side equals size_limit, if it currently exceeds size_limit.
    h, w = np_img.shape[:2]
    if max(h, w) > size_limit:
        ratio = size_limit / max(h, w)
        new_w = int(w * ratio + 0.5)
        new_h = int(h * ratio + 0.5)
        return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
    else:
        return np_img


def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """
    Args:
        img: [H, W, C]
        mod: pad height/width up to a multiple of this value
        square: whether to pad the result to a square
        min_size: minimum output size; must be a multiple of mod

    Returns:
    """
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)

    if min_size is not None:
        assert min_size % mod == 0
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)

    if square:
        max_size = max(out_height, out_width)
        out_height = max_size
        out_width = max_size

    return np.pad(
        img,
        ((0, out_height - height), (0, out_width - width), (0, 0)),
        mode="symmetric",
    )
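
# Illustrative sketch: padding a 300x500 single-channel mask so both sides become
# multiples of 8 (the shapes below are example values):
#
#   m = np.zeros((300, 500, 1), dtype=np.uint8)
#   pad_img_to_modulo(m, mod=8).shape                # (304, 504, 1)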


def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """
    Args:
        mask: (h, w, 1) 0~255

    Returns:
        A list of [x1, y1, x2, y2] boxes, one per external contour.
    """
    height, width = mask.shape[:2]
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        box = np.array([x, y, x + w, y + h]).astype(int)

        box[::2] = np.clip(box[::2], 0, width)
        box[1::2] = np.clip(box[1::2], 0, height)
        boxes.append(box)
    return boxes
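
# Illustrative sketch: a single white rectangle yields one [x1, y1, x2, y2] box
# clipped to the mask bounds (example values):
#
#   m = np.zeros((100, 100, 1), dtype=np.uint8)
#   m[10:40, 20:60] = 255
#   boxes_from_mask(m)                               # [array([20, 10, 60, 40])]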


def only_keep_largest_contour(mask: np.ndarray) -> np.ndarray:
    """
    Args:
        mask: (h, w) 0~255

    Returns:
        A mask with only the largest external contour filled, or the input mask
        if no contour is found.
    """
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    max_area = 0
    max_index = -1
    for i, cnt in enumerate(contours):
        area = cv2.contourArea(cnt)
        if area > max_area:
            max_area = area
            max_index = i

    if max_index != -1:
        new_mask = np.zeros_like(mask)
        return cv2.drawContours(new_mask, contours, max_index, 255, -1)
    else:
        return mask


def is_mac():
    return sys.platform == "darwin"


def get_image_ext(img_bytes):
    # Note: imghdr is deprecated since Python 3.11 and removed in 3.13.
    w = imghdr.what("", img_bytes)
    if w is None:
        w = "jpeg"
    return w


def decode_base64_to_image(
    encoding: str, gray=False
) -> Tuple[np.ndarray, Optional[np.ndarray], Dict]:
    if encoding.startswith("data:image/") or encoding.startswith(
        "data:application/octet-stream;base64,"
    ):
        encoding = encoding.split(";")[1].split(",")[1]
    image = Image.open(io.BytesIO(base64.b64decode(encoding)))

    alpha_channel = None
    try:
        image = ImageOps.exif_transpose(image)
    except:
        pass
    # exif_transpose removes the EXIF rotation info, so image.info must be read after it.
    infos = image.info

    if gray:
        image = image.convert("L")
        np_img = np.array(image)
    else:
        if image.mode == "RGBA":
            np_img = np.array(image)
            alpha_channel = np_img[:, :, -1]
            np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
        else:
            image = image.convert("RGB")
            np_img = np.array(image)

    return np_img, alpha_channel, infos
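
# Illustrative sketch (the data URL below is a placeholder, not a real payload):
#
#   np_img, alpha, infos = decode_base64_to_image("data:image/png;base64,....")
#   np_img.shape                                     # (H, W, 3) uint8 RGB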


def encode_pil_to_base64(image: Image.Image, quality: int, infos: Dict) -> bytes:
    img_bytes = pil_to_bytes(
        image,
        "png",
        quality=quality,
        infos=infos,
    )
    return base64.b64encode(img_bytes)


def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
    if alpha_channel is not None:
        if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
            )
        rgb_np_img = np.concatenate(
            (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
        )
    return rgb_np_img
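
# Illustrative sketch: re-attaching a stored alpha channel to an inpainted RGB result;
# alpha is resized if the resolution changed (shapes are example values):
#
#   rgb = np.zeros((256, 256, 3), dtype=np.uint8)
#   alpha = np.full((128, 128), 255, dtype=np.uint8)
#   concat_alpha_channel(rgb, alpha).shape           # (256, 256, 4)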


def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
    # frontend brush color "ffcc00bb"
    # kernel_size = kernel_size*2+1
    mask[mask >= 127] = 255
    mask[mask < 127] = 0

    if operate == "reverse":
        mask = 255 - mask
    else:
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
        )
        if operate == "expand":
            mask = cv2.dilate(
                mask,
                kernel,
                iterations=1,
            )
        else:
            mask = cv2.erode(
                mask,
                kernel,
                iterations=1,
            )

    res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
    res_mask[mask > 128] = [255, 203, 0, int(255 * 0.73)]
    res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
    return res_mask
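
# Illustrative sketch: growing a binary brush mask by ~10 px and getting back the RGBA
# overlay the frontend expects ("reverse" and "expand" are named above; any other
# operate value erodes). Values below are example inputs:
#
#   m = np.zeros((64, 64), dtype=np.uint8)
#   m[20:40, 20:40] = 255
#   overlay = adjust_mask(m, kernel_size=10, operate="expand")   # (64, 64, 4) RGBA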


def gen_frontend_mask(bgr_or_gray_mask):
    if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
        bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)

    # frontend brush color "ffcc00bb"
    # TODO: how to set kernel size?
    kernel_size = 9
    bgr_or_gray_mask = cv2.dilate(
        bgr_or_gray_mask,
        np.ones((kernel_size, kernel_size), np.uint8),
        iterations=1,
    )
    res_mask = np.zeros(
        (bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8
    )
    res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
    res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
    return res_mask