import abc
from typing import Optional

import cv2
import torch
import numpy as np
from loguru import logger

from lama_cleaner.helper import (
    boxes_from_mask,
    resize_max_size,
    pad_img_to_modulo,
    switch_mps_device,
)
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config, HDStrategy, SDSampler


class InpaintModel:
    """Base class for inpainting models.

    Subclasses implement ``init_model`` and ``forward``. Calling an instance
    runs the configured HD strategy (crop / resize / original size) around
    ``forward``.

    NOTE(review): this class does not use ``abc.ABCMeta``, so the
    ``@abc.abstractmethod`` decorators below are documentation only and are
    not enforced at instantiation time.
    """

    name = "base"
    min_size: Optional[int] = None
    pad_mod = 8
    pad_to_square = False
    is_erase_model = False

    def __init__(self, device, **kwargs):
        """
        Args:
            device: target device; passed through ``switch_mps_device`` (which
                receives ``self.name``) before being stored and handed to
                ``init_model``.
            **kwargs: forwarded unchanged to ``init_model``.
        """
        device = switch_mps_device(self.name, device)
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs):
        ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool:
        return False

    @abc.abstractmethod
    def forward(self, image, mask, config: Config):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W, 1] 255 marks the masked (to-be-inpainted) area
        return: BGR IMAGE
        """
        ...

    @staticmethod
    def download():
        ...

    def _pad_forward(self, image, mask, config: Config):
        """Pad the inputs with ``pad_img_to_modulo``, run ``forward``, crop back.

        Args:
            image: [H, W, C] RGB
            mask: [H, W]
            config: request configuration

        Returns:
            BGR image with the same height and width as ``image``.
        """
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(
            image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )
        pad_mask = pad_img_to_modulo(
            mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )

        logger.info(f"final forward pad size: {pad_image.shape}")

        # Pre-processing only affects the unpadded image/mask used for
        # post-processing and blending below; the model is fed the padded
        # arrays computed above.
        image, mask = self.forward_pre_process(image, mask, config)

        result = self.forward(pad_image, pad_mask, config)
        result = result[0:origin_height, 0:origin_width, :]

        result, image, mask = self.forward_post_process(result, image, mask, config)

        if config.sd_prevent_unmasked_area:
            # Alpha-blend: model output inside the mask, original pixels
            # (RGB -> BGR) outside of it.
            mask = mask[:, :, np.newaxis]
            result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
        return result

    def forward_pre_process(self, image, mask, config):
        """Hook run before ``forward``; the base implementation is a no-op."""
        return image, mask

    def forward_post_process(self, result, image, mask, config):
        """Hook run after ``forward``; the base implementation is a no-op."""
        return result, image, mask

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE

        The input ``image`` is never modified.
        """
        inpaint_result = None
        logger.info(f"hd_strategy: {config.hd_strategy}")
        if config.hd_strategy == HDStrategy.CROP:
            if max(image.shape) > config.hd_strategy_crop_trigger_size:
                logger.info("Run crop strategy")
                boxes = boxes_from_mask(mask)
                # Run every box against the untouched image before pasting.
                crop_result = [
                    self._run_box(image, mask, box, config) for box in boxes
                ]

                # image[:, :, ::-1] is a view of the caller's array; copy it
                # so pasting the crops does not overwrite the input.
                inpaint_result = image[:, :, ::-1].copy()
                for crop_image, crop_box in crop_result:
                    x1, y1, x2, y2 = crop_box
                    inpaint_result[y1:y2, x1:x2, :] = crop_image

        elif config.hd_strategy == HDStrategy.RESIZE:
            if max(image.shape) > config.hd_strategy_resize_limit:
                origin_size = image.shape[:2]
                downsize_image = resize_max_size(
                    image, size_limit=config.hd_strategy_resize_limit
                )
                downsize_mask = resize_max_size(
                    mask, size_limit=config.hd_strategy_resize_limit
                )

                logger.info(
                    f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
                )
                inpaint_result = self._pad_forward(
                    downsize_image, downsize_mask, config
                )

                # Upscale the result back, then restore the original pixels
                # outside the mask so only the masked area changes.
                inpaint_result = cv2.resize(
                    inpaint_result,
                    (origin_size[1], origin_size[0]),
                    interpolation=cv2.INTER_CUBIC,
                )
                original_pixel_indices = mask < 127
                inpaint_result[original_pixel_indices] = image[:, :, ::-1][
                    original_pixel_indices
                ]

        if inpaint_result is None:
            # No HD strategy triggered: run on the full-size image.
            inpaint_result = self._pad_forward(image, mask, config)

        return inpaint_result

    def _crop_box(self, image, mask, box, config: Config):
        """Crop a margin-expanded window around ``box``.

        The window is ``box`` grown by ``config.hd_strategy_crop_margin`` on
        each side. Where it would cross an image edge, it is shifted toward
        the opposite side to keep more context, then clamped to the image.

        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left,top,right,bottom]

        Returns:
            (RGB crop of image, crop of mask, [l, t, r, b])
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[:2]

        w = box_w + config.hd_strategy_crop_margin * 2
        h = box_h + config.hd_strategy_crop_margin * 2

        _l = cx - w // 2
        _r = cx + w // 2
        _t = cy - h // 2
        _b = cy + h // 2

        l = max(_l, 0)
        r = min(_r, img_w)
        t = max(_t, 0)
        b = min(_b, img_h)

        # try to get more context when crop around image edge
        if _l < 0:
            r += abs(_l)
        if _r > img_w:
            l -= _r - img_w
        if _t < 0:
            b += abs(_t)
        if _b > img_h:
            t -= _b - img_h

        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]

        logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")

        return crop_img, crop_mask, [l, t, r, b]

    def _calculate_cdf(self, histogram):
        """Return the cumulative histogram normalized so its maximum is 1."""
        cdf = histogram.cumsum()
        normalized_cdf = cdf / float(cdf.max())
        return normalized_cdf

    def _calculate_lookup(self, source_cdf, reference_cdf):
        """Build a 256-entry LUT mapping each source level to a reference level.

        Each source level maps to the first reference index whose CDF value is
        >= the source CDF value. If no such index exists, the previous mapping
        is carried forward (0 for the first level).
        """
        lookup_table = np.zeros(256)
        lookup_val = 0
        for source_index, source_val in enumerate(source_cdf):
            for reference_index, reference_val in enumerate(reference_cdf):
                if reference_val >= source_val:
                    lookup_val = reference_index
                    break
            lookup_table[source_index] = lookup_val
        return lookup_table

    def _match_histograms(self, source, reference, mask):
        """Match each channel's histogram of ``source`` to that of ``reference``.

        Histograms are built only from pixels where ``mask == 0`` (the
        unmasked area). The resulting lookup table is applied to the whole
        channel.

        Args:
            source: [H, W, C] image to transform; presumably 8-bit, since
                ``cv2.LUT`` requires 8-bit input.
            reference: [H, W, C] image whose histograms are the target.
            mask: [H, W] or [H, W, 1]; non-zero pixels are excluded from the
                histograms.

        Returns:
            The transformed image as uint8 (via ``cv2.convertScaleAbs``), or
            ``source`` unchanged when no pixel is unmasked.
        """
        if mask.ndim == 3:
            # A trailing channel axis (e.g. the [H, W, 1] mask produced by
            # DiffusionInpaintModel.forward_pre_process) would make the
            # boolean indexing of a 2-D channel below raise IndexError.
            mask = mask[:, :, 0]
        unmasked = mask == 0
        if not unmasked.any():
            # Empty histograms yield a 0/0 (NaN) CDF, which would map every
            # level to 0 and turn the result black.
            return source

        transformed_channels = []
        for channel in range(source.shape[-1]):
            source_channel = source[:, :, channel]
            reference_channel = reference[:, :, channel]

            # only calculate histograms for non-masked parts
            source_histogram, _ = np.histogram(source_channel[unmasked], 256, [0, 256])
            reference_histogram, _ = np.histogram(
                reference_channel[unmasked], 256, [0, 256]
            )

            source_cdf = self._calculate_cdf(source_histogram)
            reference_cdf = self._calculate_cdf(reference_histogram)

            lookup = self._calculate_lookup(source_cdf, reference_cdf)

            transformed_channels.append(cv2.LUT(source_channel, lookup))

        result = cv2.merge(transformed_channels)
        result = cv2.convertScaleAbs(result)

        return result

    def _apply_cropper(self, image, mask, config: Config):
        """Crop image/mask to the cropper rectangle, clamped to the image.

        The rectangle comes from ``config.croper_x/y/width/height``.

        Returns:
            (crop of image, crop of mask, (l, t, r, b))
        """
        img_h, img_w = image.shape[:2]
        l, t, w, h = (
            config.croper_x,
            config.croper_y,
            config.croper_width,
            config.croper_height,
        )
        r = l + w
        b = t + h

        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]
        return crop_img, crop_mask, (l, t, r, b)

    def _run_box(self, image, mask, box, config: Config):
        """Inpaint the margin-expanded window around one mask box.

        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left,top,right,bottom]

        Returns:
            (BGR IMAGE of the cropped window, [l, t, r, b] of that window)
        """
        crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)

        return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]


class DiffusionInpaintModel(InpaintModel):
    """Base class for diffusion-based models.

    Adds a cropper, an extender (outpainting), ``sd_scale`` resizing,
    scheduler selection and mask-blur / histogram-matching post-processing.
    """

    def __init__(self, device, **kwargs):
        # Set before super().__init__(), which calls init_model().
        self.model_info = kwargs["model_info"]
        self.model_id_or_path = self.model_info.path
        super().__init__(device, **kwargs)

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE

        The input ``image`` is never modified.
        """
        if config.use_croper:
            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
            crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
            # image[:, :, ::-1] is a view of the caller's array; copy it so
            # the paste below does not overwrite the input.
            inpaint_result = image[:, :, ::-1].copy()
            inpaint_result[t:b, l:r, :] = crop_image
        elif config.use_extender:
            inpaint_result = self._do_outpainting(image, config)
        else:
            inpaint_result = self._scaled_pad_forward(image, mask, config)

        return inpaint_result

    def _do_outpainting(self, image, config: Config):
        """Outpaint the rectangle ``config.extender_x/y/width/height``.

        Returns:
            BGR image: ``image`` zero-padded on every side where the extender
            rectangle exceeds it, with the inpainted extender region pasted
            in.
        """
        # The extender rectangle and the image share one coordinate system;
        # extender_x / extender_y may be negative.
        # Crop out of the image the part of the outpainting region that
        # overlaps it.
        image_h, image_w = image.shape[:2]
        cropper_l = config.extender_x
        cropper_t = config.extender_y
        cropper_r = config.extender_x + config.extender_width
        cropper_b = config.extender_y + config.extender_height
        image_l = 0
        image_t = 0
        image_r = image_w
        image_b = image_h

        # Intersection of the two rectangles (as when computing an IoU).
        l = max(cropper_l, image_l)
        t = max(cropper_t, image_t)
        r = min(cropper_r, image_r)
        b = min(cropper_b, image_b)

        assert (
            0 <= l < r and 0 <= t < b
        ), f"cropper and image not overlap, {l},{t},{r},{b}"

        cropped_image = image[t:b, l:r, :]
        # How far the extender rectangle sticks out past each image edge.
        padding_l = max(0, image_l - cropper_l)
        padding_t = max(0, image_t - cropper_t)
        padding_r = max(0, cropper_r - image_r)
        padding_b = max(0, cropper_b - image_b)

        zero_padding_count = [padding_l, padding_t, padding_r, padding_b].count(0)

        if zero_padding_count not in [0, 3]:
            logger.warning(
                f"padding count({zero_padding_count}) not 0 or 3, may result in bad edge outpainting"
            )

        expanded_image, mask_image = expand_image(
            cropped_image,
            left=padding_l,
            top=padding_t,
            right=padding_r,
            bottom=padding_b,
            softness=config.sd_outpainting_softness,
            space=config.sd_outpainting_space,
        )

        # Inpainted result for the expanded region, BGR.
        expanded_cropped_result_image = self._scaled_pad_forward(
            expanded_image, mask_image, config
        )

        # Zero-pad the full image to the outpainted extent, then RGB -> BGR.
        outpainting_image = cv2.copyMakeBorder(
            image,
            left=padding_l,
            top=padding_t,
            right=padding_r,
            bottom=padding_b,
            borderType=cv2.BORDER_CONSTANT,
            value=0,
        )[:, :, ::-1]

        # Paste the expanded result onto outpainting_image; no blending is
        # needed for this step.
        paste_t = 0 if config.extender_y < 0 else config.extender_y
        paste_l = 0 if config.extender_x < 0 else config.extender_x

        outpainting_image[
            paste_t : paste_t + expanded_cropped_result_image.shape[0],
            paste_l : paste_l + expanded_cropped_result_image.shape[1],
            :,
        ] = expanded_cropped_result_image
        return outpainting_image

    def _scaled_pad_forward(self, image, mask, config: Config):
        """Run ``_pad_forward`` at ``config.sd_scale`` times the input size.

        The result is resized back to the input size with cubic
        interpolation. It is not blended with the original pixels here; that
        happens in ``_pad_forward`` when ``config.sd_prevent_unmasked_area``
        is set.
        """
        longer_side_length = int(config.sd_scale * max(image.shape[:2]))
        origin_size = image.shape[:2]
        downsize_image = resize_max_size(image, size_limit=longer_side_length)
        downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
        if config.sd_scale != 1:
            logger.info(
                f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
            )
        inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
        # Resize the result back to the original input size.
        inpaint_result = cv2.resize(
            inpaint_result,
            (origin_size[1], origin_size[0]),
            interpolation=cv2.INTER_CUBIC,
        )
        return inpaint_result

    def set_scheduler(self, config: Config):
        """Replace ``self.model.scheduler`` with one for the requested sampler.

        The LCM sampler is forced when ``config.sd_lcm_lora`` is set.
        ``self.model`` is presumably created by the subclass's
        ``init_model`` — TODO confirm.
        """
        scheduler_config = self.model.scheduler.config
        sd_sampler = config.sd_sampler
        if config.sd_lcm_lora:
            sd_sampler = SDSampler.lcm
            logger.info(f"LCM Lora enabled, use {sd_sampler} sampler")
        scheduler = get_scheduler(sd_sampler, scheduler_config)
        self.model.scheduler = scheduler

    def forward_pre_process(self, image, mask, config):
        """Gaussian-blur the mask with a (2 * sd_mask_blur + 1) kernel.

        When blurred, the returned mask gets a trailing channel axis
        ([H, W] -> [H, W, 1]).
        """
        if config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]

        return image, mask

    def forward_post_process(self, result, image, mask, config):
        """Optionally histogram-match the result, then blur the mask again.

        Histogram matching uses the unmasked area of the BGR-converted input
        image as the reference. The mask is blurred a second time (on top of
        ``forward_pre_process``) before ``_pad_forward`` uses it for
        blending.
        """
        if config.sd_match_histograms:
            result = self._match_histograms(result, image[:, :, ::-1], mask)

        if config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)
        return result, image, mask