IOPaint/lama_cleaner/helper.py
2022-03-23 10:15:23 +08:00

108 lines
3.0 KiB
Python

import os
import sys
from typing import List
from urllib.parse import urlparse
import cv2
import numpy as np
import torch
from torch.hub import download_url_to_file, get_dir
def download_model(url):
    """Download a checkpoint into the torch hub cache (if absent) and return its path.

    The file is stored as ``<torch hub dir>/checkpoints/<basename of URL path>``.
    If that file already exists, no download happens.

    Args:
        url: URL of the checkpoint file. Its path must end in a file name.

    Returns:
        Local filesystem path of the cached checkpoint.

    Raises:
        ValueError: if the URL path has no file name component.
    """
    parts = urlparse(url)
    model_dir = os.path.join(get_dir(), "checkpoints")
    # Create exactly the directory we write into; exist_ok avoids a race with
    # concurrent callers creating it between a check and the makedirs call.
    os.makedirs(model_dir, exist_ok=True)

    filename = os.path.basename(parts.path)
    if not filename:
        # Without this check cached_file == model_dir, which exists, and the
        # directory path would be silently returned as if it were a checkpoint.
        raise ValueError(f"Cannot derive a checkpoint file name from URL: {url}")

    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        download_url_to_file(url, cached_file, None, progress=True)
    return cached_file
def ceil_modulo(x, mod):
    """Round ``x`` up to the next multiple of ``mod``.

    ``x`` is returned unchanged when it is already a multiple of ``mod``.
    """
    quotient, remainder = divmod(x, mod)
    return x if remainder == 0 else (quotient + 1) * mod
def numpy_to_bytes(image_numpy: np.ndarray, ext: str = ".jpg") -> bytes:
    """Encode an image array in memory and return the encoded file bytes.

    Args:
        image_numpy: image array passed straight to ``cv2.imencode``. OpenCV
            interprets 3-channel input in BGR order.
        ext: file extension that selects the encoder. Defaults to ``".jpg"``.

    Returns:
        The encoded image file contents.

    Raises:
        ValueError: if OpenCV reports that encoding failed.
    """
    ok, data = cv2.imencode(ext, image_numpy)
    # imencode signals failure through its first return value; ignoring it
    # would hand callers an empty/invalid buffer instead of an error.
    if not ok:
        raise ValueError(f"Failed to encode image as {ext!r}")
    return data.tobytes()
def load_img(img_bytes, gray: bool = False):
    """Decode encoded image bytes (e.g. a PNG/JPEG file's contents) with OpenCV.

    Args:
        img_bytes: encoded image file contents.
        gray: if True, decode as a single-channel grayscale image and return it
            as a 2-D ``(H, W)`` array. Otherwise return a 3-channel RGB
            ``(H, W, 3)`` array.

    Returns:
        The decoded image array.

    Raises:
        ValueError: if OpenCV cannot decode the bytes as an image.
    """
    nparr = np.frombuffer(img_bytes, np.uint8)
    flag = cv2.IMREAD_GRAYSCALE if gray else cv2.IMREAD_UNCHANGED
    np_img = cv2.imdecode(nparr, flag)
    if np_img is None:
        # imdecode returns None (rather than raising) on undecodable input.
        raise ValueError("Could not decode image bytes")

    if gray:
        # Already single-channel; a BGR->RGB conversion would fail here.
        return np_img

    if np_img.ndim == 2:
        # IMREAD_UNCHANGED keeps single-channel sources 2-D; expand to RGB.
        return cv2.cvtColor(np_img, cv2.COLOR_GRAY2RGB)
    if np_img.shape[2] == 4:
        # Drop the alpha channel while reordering BGRA -> RGB.
        return cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
    return cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
def norm_img(np_img):
    """Convert an ``(H, W)`` or ``(H, W, C)`` image to a CHW float32 array in [0, 1].

    A 2-D input gets a trailing channel axis of size 1 before the
    channels-last -> channels-first transpose. Values are divided by 255.
    """
    hwc = np_img[..., None] if np_img.ndim == 2 else np_img
    chw = hwc.transpose(2, 0, 1)
    return chw.astype(np.float32) / 255
def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    """Downscale an image so its longer side equals ``size_limit``.

    The aspect ratio is preserved, with each side rounded to the nearest pixel.
    Images whose longer side is already ``<= size_limit`` are returned
    unchanged (the same object, not a copy).

    Args:
        np_img: image array; its first two axes are treated as (height, width).
        size_limit: maximum allowed length of the longer side, in pixels.
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The resized image, or ``np_img`` itself if no resize was needed.
    """
    h, w = np_img.shape[:2]
    longest = max(h, w)
    if longest <= size_limit:
        return np_img

    ratio = size_limit / longest
    # Round to nearest, but never collapse a side to 0 px: an empty dsize makes
    # cv2.resize fall back to fx/fy (unset here) and fail on very thin images.
    new_w = max(1, int(w * ratio + 0.5))
    new_h = max(1, int(h * ratio + 0.5))
    return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
def pad_img_to_modulo(img, mod):
    """Pad the last two axes (height, width) up to the next multiple of ``mod``.

    Padding is added only at the bottom and right, using ``np.pad``'s
    ``"symmetric"`` mode. Any leading axes (e.g. the channel axis of a
    ``(C, H, W)`` array) are left unpadded. For the original ``(C, H, W)``
    input this behaves exactly as before, and 2-D ``(H, W)`` input is now
    accepted too.

    Args:
        img: array of rank >= 2 whose last two axes are height and width.
        mod: the multiple that height and width are rounded up to.

    Returns:
        The padded array.
    """
    height, width = img.shape[-2:]
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    pad_width = [(0, 0)] * (img.ndim - 2) + [
        (0, out_height - height),
        (0, out_width - width),
    ]
    return np.pad(img, pad_width, mode="symmetric")
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """Return one bounding box per external contour of a mask.

    Args:
        mask: array of shape ``(1, h, w)`` with values in ``0~1``. Pixels whose
            value * 255 (truncated to uint8) exceeds 127 count as foreground.

    Returns:
        List of integer arrays ``[x1, y1, x2, y2]``, one per external contour,
        where ``x2 = x1 + w`` and ``y2 = y1 + h`` of ``cv2.boundingRect``.
        x coordinates are clipped to ``[0, width]`` and y to ``[0, height]``.
    """
    height, width = mask.shape[1:]
    _, thresh = cv2.threshold(
        (mask.transpose(1, 2, 0) * 255).astype(np.uint8),
        127,
        255,
        cv2.THRESH_BINARY,
    )
    # OpenCV 3 returns (image, contours, hierarchy) while OpenCV 4 returns
    # (contours, hierarchy); indexing [-2] selects the contours in both.
    contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        # `np.int` (an alias of builtin int) was removed in NumPy 1.24.
        box = np.array([x, y, x + w, y + h], dtype=int)
        box[::2] = np.clip(box[::2], 0, width)
        box[1::2] = np.clip(box[1::2], 0, height)
        boxes.append(box)
    return boxes