# IOPaint/lama_cleaner/model/base.py

import abc
from typing import Optional

import cv2
import torch
import numpy as np
from loguru import logger

from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
from lama_cleaner.schema import Config, HDStrategy


class InpaintModel:
    min_size: Optional[int] = None
    pad_mod = 8
    pad_to_square = False

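    # `pad_mod` pads inputs to a multiple of the model's expected modulo (typically
    # its downsampling factor); `min_size` and `pad_to_square` let subclasses ask
    # for a minimum or square input.
    #
    # Rough usage sketch (a minimal example; `MyModel` is a hypothetical subclass):
    #   model = MyModel(device=torch.device("cpu"))
    #   result_bgr = model(rgb_image, mask, config)  # __call__ picks an HD strategy
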
    def __init__(self, device, **kwargs):
        """
        Args:
            device: torch.device the model is loaded on and runs on.
        """
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs):
        ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool:
        ...

    @abc.abstractmethod
    def forward(self, image, mask, config: Config):
        """Input and output images have the same size
        images: [H, W, C] RGB
        masks: [H, W, 1] 255 marks the region to inpaint
        return: BGR IMAGE
        """
        ...

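    # _pad_forward pads image and mask so their sides are a multiple of `pad_mod`
    # (and optionally square / at least `min_size`), runs `forward`, crops the
    # result back to the original size, then blends: inpainted pixels inside the
    # mask, original pixels (flipped RGB -> BGR) outside it.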
    def _pad_forward(self, image, mask, config: Config):
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(
            image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )
        pad_mask = pad_img_to_modulo(
            mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )

        logger.info(f"final forward pad size: {pad_image.shape}")

        result = self.forward(pad_image, pad_mask, config)
        result = result[0:origin_height, 0:origin_width, :]

        result, image, mask = self.forward_post_process(result, image, mask, config)

        mask = mask[:, :, np.newaxis]
        result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
        return result

    def forward_post_process(self, result, image, mask, config):
        return result, image, mask

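    # __call__ dispatches on config.hd_strategy for large images:
    #   CROP   - inpaint each masked region in its own cropped patch (plus margin)
    #            and paste the patches back into the full image.
    #   RESIZE - downscale to `hd_strategy_resize_limit`, inpaint, resize the result
    #            back up and keep the original pixels wherever the mask is empty.
    # If neither strategy triggers, the whole image goes through _pad_forward.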
    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        inpaint_result = None
        logger.info(f"hd_strategy: {config.hd_strategy}")
        if config.hd_strategy == HDStrategy.CROP:
            if max(image.shape) > config.hd_strategy_crop_trigger_size:
                logger.info("Run crop strategy")
                boxes = boxes_from_mask(mask)
                crop_result = []
                for box in boxes:
                    crop_image, crop_box = self._run_box(image, mask, box, config)
                    crop_result.append((crop_image, crop_box))

                inpaint_result = image[:, :, ::-1]
                for crop_image, crop_box in crop_result:
                    x1, y1, x2, y2 = crop_box
                    inpaint_result[y1:y2, x1:x2, :] = crop_image

        elif config.hd_strategy == HDStrategy.RESIZE:
            if max(image.shape) > config.hd_strategy_resize_limit:
                origin_size = image.shape[:2]
                downsize_image = resize_max_size(
                    image, size_limit=config.hd_strategy_resize_limit
                )
                downsize_mask = resize_max_size(
                    mask, size_limit=config.hd_strategy_resize_limit
                )
                logger.info(
                    f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
                )
                inpaint_result = self._pad_forward(
                    downsize_image, downsize_mask, config
                )

                # only paste masked area result
                inpaint_result = cv2.resize(
                    inpaint_result,
                    (origin_size[1], origin_size[0]),
                    interpolation=cv2.INTER_CUBIC,
                )
                original_pixel_indices = mask < 127
                inpaint_result[original_pixel_indices] = image[:, :, ::-1][
                    original_pixel_indices
                ]

        if inpaint_result is None:
            inpaint_result = self._pad_forward(image, mask, config)

        return inpaint_result

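    # _crop_box expands a mask bounding box by `hd_strategy_crop_margin` on every
    # side and clamps it to the image, shifting the window inward at the borders so
    # the crop keeps as much surrounding context as possible.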
    def _crop_box(self, image, mask, box, config: Config):
        """
        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left,top,right,bottom]

        Returns:
            RGB crop of image, crop of mask, [l, t, r, b]
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[:2]

        w = box_w + config.hd_strategy_crop_margin * 2
        h = box_h + config.hd_strategy_crop_margin * 2

        _l = cx - w // 2
        _r = cx + w // 2
        _t = cy - h // 2
        _b = cy + h // 2

        l = max(_l, 0)
        r = min(_r, img_w)
        t = max(_t, 0)
        b = min(_b, img_h)

        # try to get more context when crop around image edge
        if _l < 0:
            r += abs(_l)
        if _r > img_w:
            l -= _r - img_w
        if _t < 0:
            b += abs(_t)
        if _b > img_h:
            t -= _b - img_h

        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]

        logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")

        return crop_img, crop_mask, [l, t, r, b]

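    # The three helpers below implement classic histogram matching: build per-channel
    # CDFs from the unmasked pixels of a source and a reference image, derive a
    # 256-entry lookup table, and remap the source. Presumably meant for subclasses
    # to color-match an inpainted result against the original image in
    # forward_post_process.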
    def _calculate_cdf(self, histogram):
        cdf = histogram.cumsum()
        normalized_cdf = cdf / float(cdf.max())
        return normalized_cdf

    def _calculate_lookup(self, source_cdf, reference_cdf):
        # Map each source intensity to the first reference intensity whose CDF
        # value is at least as large.
        lookup_table = np.zeros(256)
        lookup_val = 0
        for source_index, source_val in enumerate(source_cdf):
            for reference_index, reference_val in enumerate(reference_cdf):
                if reference_val >= source_val:
                    lookup_val = reference_index
                    break
            lookup_table[source_index] = lookup_val
        return lookup_table

    def _match_histograms(self, source, reference, mask):
        transformed_channels = []
        for channel in range(source.shape[-1]):
            source_channel = source[:, :, channel]
            reference_channel = reference[:, :, channel]

            # only calculate histograms for non-masked parts
            source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
            reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256])

            source_cdf = self._calculate_cdf(source_histogram)
            reference_cdf = self._calculate_cdf(reference_histogram)

            lookup = self._calculate_lookup(source_cdf, reference_cdf)

            transformed_channels.append(cv2.LUT(source_channel, lookup))

        result = cv2.merge(transformed_channels)
        result = cv2.convertScaleAbs(result)

        return result

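    # _apply_cropper cuts out the user-defined crop rectangle from the config
    # (croper_x/y/width/height), clamped to the image bounds.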
    def _apply_cropper(self, image, mask, config: Config):
        img_h, img_w = image.shape[:2]
        l, t, w, h = (
            config.croper_x,
            config.croper_y,
            config.croper_width,
            config.croper_height,
        )
        r = l + w
        b = t + h

        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)
        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]
        return crop_img, crop_mask, (l, t, r, b)

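    # _run_box = _crop_box followed by _pad_forward for a single mask bounding box;
    # used by the CROP strategy in __call__.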
    def _run_box(self, image, mask, box, config: Config):
        """
        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left,top,right,bottom]

        Returns:
            BGR IMAGE, [l, t, r, b]
        """
        crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)

        return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]