import io
import os
import sys
from typing import List, Optional
from urllib.parse import urlparse

import cv2
from PIL import Image, ImageOps
import numpy as np
import torch
from loguru import logger
from torch.hub import download_url_to_file, get_dir


def get_cache_path_by_url(url):
    parts = urlparse(url)
    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, "checkpoints")
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    return cached_file


def download_model(url):
    cached_file = get_cache_path_by_url(url)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
    return cached_file


def ceil_modulo(x, mod):
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def load_jit_model(url_or_path, device):
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)
    logger.info(f"Load model from: {model_path}")
    try:
        model = torch.jit.load(model_path).to(device)
    except Exception:
        logger.error(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        sys.exit(-1)
    model.eval()
    return model


def load_model(model: torch.nn.Module, url_or_path, device):
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)

    try:
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
        logger.info(f"Load model from: {model_path}")
    except Exception:
        logger.error(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        sys.exit(-1)
    model.eval()
    return model


def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    data = cv2.imencode(
        f".{ext}",
        image_numpy,
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )[1]
    image_bytes = data.tobytes()
    return image_bytes


def pil_to_bytes(pil_img, ext: str, exif=None) -> bytes:
    with io.BytesIO() as output:
        pil_img.save(output, format=ext, exif=exif, quality=95)
        image_bytes = output.getvalue()
    return image_bytes


def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
    alpha_channel = None
    exif = None
    image = Image.open(io.BytesIO(img_bytes))

    try:
        if return_exif:
            exif = image.getexif()
    except Exception:
        logger.error("Failed to extract exif from image")

    try:
        image = ImageOps.exif_transpose(image)
    except Exception:
        pass

    if gray:
        image = image.convert("L")
        np_img = np.array(image)
    else:
        if image.mode == "RGBA":
            np_img = np.array(image)
            alpha_channel = np_img[:, :, -1]
            np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
        else:
            image = image.convert("RGB")
            np_img = np.array(image)

    if return_exif:
        return np_img, alpha_channel, exif
    return np_img, alpha_channel


def norm_img(np_img):
    # HWC uint8 -> CHW float32 in [0, 1]
    if len(np_img.shape) == 2:
        np_img = np_img[:, :, np.newaxis]
    np_img = np.transpose(np_img, (2, 0, 1))
    np_img = np_img.astype("float32") / 255
    return np_img


def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    # Downscale so the longer side equals size_limit, if it exceeds size_limit
    h, w = np_img.shape[:2]
    if max(h, w) > size_limit:
        ratio = size_limit / max(h, w)
        new_w = int(w * ratio + 0.5)
        new_h = int(h * ratio + 0.5)
        return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
    else:
        return np_img


def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """
    Args:
        img: [H, W, C]
        mod: pad height and width up to a multiple of this value
        square: whether to pad the output to a square
        min_size: minimum output size, must be divisible by mod

    Returns:
        Symmetrically padded image [H', W', C] with H' and W' multiples of mod.
    """
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)

    if min_size is not None:
        assert min_size % mod == 0
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)

    if square:
        max_size = max(out_height, out_width)
        out_height = max_size
        out_width = max_size

    return np.pad(
        img,
        ((0, out_height - height), (0, out_width - width), (0, 0)),
        mode="symmetric",
    )


def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """
    Args:
        mask: (h, w, 1) 0~255

    Returns:
        List of [left, top, right, bottom] boxes, one per external contour.
    """
    height, width = mask.shape[:2]
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        box = np.array([x, y, x + w, y + h]).astype(int)

        box[::2] = np.clip(box[::2], 0, width)
        box[1::2] = np.clip(box[1::2], 0, height)
        boxes.append(box)
    return boxes


def only_keep_largest_contour(mask: np.ndarray) -> np.ndarray:
    """
    Args:
        mask: (h, w) 0~255

    Returns:
        Mask with only the largest external contour filled, or the input mask
        unchanged if no contour is found.
    """
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    max_area = 0
    max_index = -1
    for i, cnt in enumerate(contours):
        area = cv2.contourArea(cnt)
        if area > max_area:
            max_area = area
            max_index = i

    if max_index != -1:
        new_mask = np.zeros_like(mask)
        return cv2.drawContours(new_mask, contours, max_index, 255, -1)
    else:
        return mask
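

# Minimal usage sketch of the helpers above, assuming hypothetical input files
# "./image.png" and "./mask.png" (not part of the library): load an image,
# cap its longer side, pad it to a mod-8 size, normalize it for a model, and
# extract bounding boxes from a grayscale mask.
if __name__ == "__main__":
    with open("./image.png", "rb") as f:
        np_img, alpha_channel = load_img(f.read())

    np_img = resize_max_size(np_img, size_limit=1024)
    padded = pad_img_to_modulo(np_img, mod=8)
    model_input = norm_img(padded)  # CHW float32 in [0, 1]
    print("model input shape:", model_input.shape)

    mask = cv2.imread("./mask.png", cv2.IMREAD_GRAYSCALE)
    if mask is not None:
        print("mask boxes:", boxes_from_mask(mask))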