2022-04-15 18:11:51 +02:00
|
|
|
|
import abc
|
2022-07-14 10:49:03 +02:00
|
|
|
|
from typing import Optional
|
2022-04-15 18:11:51 +02:00
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
import torch
|
2022-11-10 18:45:20 +01:00
|
|
|
|
import numpy as np
|
2022-04-15 18:11:51 +02:00
|
|
|
|
from loguru import logger
|
|
|
|
|
|
2023-05-13 07:45:27 +02:00
|
|
|
|
from lama_cleaner.helper import (
|
|
|
|
|
boxes_from_mask,
|
|
|
|
|
resize_max_size,
|
|
|
|
|
pad_img_to_modulo,
|
|
|
|
|
switch_mps_device,
|
|
|
|
|
)
|
2023-12-01 03:15:35 +01:00
|
|
|
|
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
|
2023-11-15 01:50:35 +01:00
|
|
|
|
from lama_cleaner.model.utils import get_scheduler
|
|
|
|
|
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
2022-04-15 18:11:51 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InpaintModel:
    """Base class for inpainting models.

    Subclasses implement ``init_model`` and ``forward`` (plus the static
    ``is_downloaded`` / ``download`` hooks). This class handles padding the
    input, the high-resolution strategies (crop / resize) and post-processing.

    NOTE(review): methods are decorated with ``@abc.abstractmethod`` but the
    class does not inherit from ``abc.ABC``, so the abstract methods are not
    enforced when a subclass is instantiated.
    """

    name = "base"
    # Forwarded as ``min_size`` to ``pad_img_to_modulo`` (None: no minimum).
    min_size: Optional[int] = None
    # Forwarded as ``mod`` to ``pad_img_to_modulo``.
    pad_mod = 8
    # Forwarded as ``square`` to ``pad_img_to_modulo``.
    pad_to_square = False
    # Not read in this class; presumably consulted by callers/subclasses.
    is_erase_model = False

    def __init__(self, device, **kwargs):
        """
        Args:
            device: target device. It is passed through ``switch_mps_device``
                together with ``self.name`` before being stored on
                ``self.device``.
            **kwargs: forwarded unchanged to ``init_model``.
        """
        device = switch_mps_device(self.name, device)
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs):
        ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool:
        ...

    @abc.abstractmethod
    def forward(self, image, mask, config: Config):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W, 1] 255 marks the masked (to be inpainted) region
        return: BGR IMAGE
        """
        ...

    @staticmethod
    def download():
        ...

    def _pad_forward(self, image, mask, config: Config):
        """Pad image and mask, run ``forward``, then undo the padding.

        Args:
            image: [H, W, C] RGB image.
            mask: mask with the same height and width as ``image``.
            config: request configuration.

        Returns:
            BGR result with the same height and width as ``image``.
        """
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(
            image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )
        pad_mask = pad_img_to_modulo(
            mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )

        logger.info(f"final forward pad size: {pad_image.shape}")

        result = self.forward(pad_image, pad_mask, config)
        # Drop the padding added above.
        result = result[0:origin_height, 0:origin_width, :]

        result, image, mask = self.forward_post_process(result, image, mask, config)

        if config.sd_prevent_unmasked_area:
            # Alpha-blend the model output with the original image (flipped
            # RGB -> BGR) using mask / 255 as weight, so unmasked pixels are
            # kept. Assumes ``mask`` is 2-D [H, W] here — TODO confirm.
            mask = mask[:, :, np.newaxis]
            result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
        return result

    def forward_post_process(self, result, image, mask, config):
        """Hook for subclasses; the base implementation returns inputs unchanged."""
        return result, image, mask

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """Inpaint ``image`` according to ``config.hd_strategy``.

        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        inpaint_result = None
        logger.info(f"hd_strategy: {config.hd_strategy}")
        if config.hd_strategy == HDStrategy.CROP:
            if max(image.shape) > config.hd_strategy_crop_trigger_size:
                logger.info("Run crop strategy")
                boxes = boxes_from_mask(mask)
                crop_result = []
                for box in boxes:
                    crop_image, crop_box = self._run_box(image, mask, box, config)
                    crop_result.append((crop_image, crop_box))

                # Copy: ``image[:, :, ::-1]`` alone is a view, and pasting the
                # crops into it would overwrite the caller's input image.
                inpaint_result = image[:, :, ::-1].copy()
                for crop_image, crop_box in crop_result:
                    x1, y1, x2, y2 = crop_box
                    inpaint_result[y1:y2, x1:x2, :] = crop_image

        elif config.hd_strategy == HDStrategy.RESIZE:
            if max(image.shape) > config.hd_strategy_resize_limit:
                origin_size = image.shape[:2]
                downsize_image = resize_max_size(
                    image, size_limit=config.hd_strategy_resize_limit
                )
                downsize_mask = resize_max_size(
                    mask, size_limit=config.hd_strategy_resize_limit
                )

                logger.info(
                    f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
                )
                inpaint_result = self._pad_forward(
                    downsize_image, downsize_mask, config
                )

                # only paste masked area result
                inpaint_result = cv2.resize(
                    inpaint_result,
                    (origin_size[1], origin_size[0]),
                    interpolation=cv2.INTER_CUBIC,
                )
                # Restore full-resolution original pixels outside the mask.
                original_pixel_indices = mask < 127
                inpaint_result[original_pixel_indices] = image[:, :, ::-1][
                    original_pixel_indices
                ]

        if inpaint_result is None:
            # ORIGINAL strategy, or the image is below the crop/resize trigger.
            inpaint_result = self._pad_forward(image, mask, config)

        return inpaint_result

    def _crop_box(self, image, mask, box, config: Config):
        """Crop a window around ``box`` enlarged by ``config.hd_strategy_crop_margin``.

        The window is centered on the box. If it runs past an image edge, it is
        extended on the opposite side to keep more context. It is then clamped
        to the image bounds.

        Args:
            image: [H, W, C] RGB
            mask: mask with the same height and width as ``image``
            box: [left, top, right, bottom]

        Returns:
            (crop_img, crop_mask, [l, t, r, b]): the RGB image crop, the mask
            crop, and the crop coordinates within ``image``.
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[:2]

        w = box_w + config.hd_strategy_crop_margin * 2
        h = box_h + config.hd_strategy_crop_margin * 2

        _l = cx - w // 2
        _r = cx + w // 2
        _t = cy - h // 2
        _b = cy + h // 2

        l = max(_l, 0)
        r = min(_r, img_w)
        t = max(_t, 0)
        b = min(_b, img_h)

        # try to get more context when crop around image edge
        if _l < 0:
            r += abs(_l)
        if _r > img_w:
            l -= _r - img_w
        if _t < 0:
            b += abs(_t)
        if _b > img_h:
            t -= _b - img_h

        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]

        logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")

        return crop_img, crop_mask, [l, t, r, b]

    def _calculate_cdf(self, histogram):
        """Return the cumulative histogram normalized so its last value is 1."""
        cdf = histogram.cumsum()
        normalized_cdf = cdf / float(cdf.max())
        return normalized_cdf

    def _calculate_lookup(self, source_cdf, reference_cdf):
        """Build a histogram-matching lookup table.

        Entry ``i`` is the first reference intensity whose CDF value is
        >= ``source_cdf[i]``. Both CDFs are non-decreasing, so this is a
        left-sided binary search for every source entry.

        Returns:
            float64 array with one entry per source bin (256 for 8-bit data).
        """
        lookup_table = np.searchsorted(reference_cdf, source_cdf, side="left")
        # Normalized CDFs both end at exactly 1.0, so every search succeeds.
        # Clamp anyway so the table always holds valid intensity indices.
        lookup_table = np.minimum(lookup_table, len(reference_cdf) - 1)
        return lookup_table.astype(np.float64)

    def _match_histograms(self, source, reference, mask):
        """Match each channel's histogram of ``source`` to ``reference``.

        Histograms are computed only over pixels where ``mask == 0``. The
        resulting lookup table is then applied to the whole source channel.

        Returns:
            The transformed image as produced by ``cv2.convertScaleAbs``
            (8-bit). Returns ``source`` unchanged when every pixel is masked.
        """
        unmasked = mask == 0
        if not np.any(unmasked):
            # No unmasked pixels: the CDFs would be 0/0 = NaN and the lookup
            # table all zeros, producing a black image. Nothing to match
            # against, so keep the source as is.
            return source

        transformed_channels = []
        for channel in range(source.shape[-1]):
            source_channel = source[:, :, channel]
            reference_channel = reference[:, :, channel]

            # only calculate histograms for non-masked parts
            source_histogram, _ = np.histogram(source_channel[unmasked], 256, [0, 256])
            reference_histogram, _ = np.histogram(
                reference_channel[unmasked], 256, [0, 256]
            )

            source_cdf = self._calculate_cdf(source_histogram)
            reference_cdf = self._calculate_cdf(reference_histogram)

            lookup = self._calculate_lookup(source_cdf, reference_cdf)

            transformed_channels.append(cv2.LUT(source_channel, lookup))

        result = cv2.merge(transformed_channels)
        result = cv2.convertScaleAbs(result)

        return result

    def _apply_cropper(self, image, mask, config: Config):
        """Crop image and mask to the configured cropper rectangle.

        The rectangle (``croper_x``, ``croper_y``, ``croper_width``,
        ``croper_height``) is clamped to the image bounds.

        Returns:
            (crop_img, crop_mask, (l, t, r, b))
        """
        img_h, img_w = image.shape[:2]
        l, t, w, h = (
            config.croper_x,
            config.croper_y,
            config.croper_width,
            config.croper_height,
        )
        r = l + w
        b = t + h

        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]
        return crop_img, crop_mask, (l, t, r, b)

    def _run_box(self, image, mask, box, config: Config):
        """Inpaint the window around a single mask box.

        Args:
            image: [H, W, C] RGB
            mask: mask with the same height and width as ``image``
            box: [left, top, right, bottom]

        Returns:
            (BGR result for the crop, [l, t, r, b] crop coordinates)
        """
        crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)

        return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
2023-01-27 13:59:22 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DiffusionInpaintModel(InpaintModel):
    """Base class for diffusion-based inpainting models.

    Adds cropper / outpainting support, inference at ``config.sd_scale``,
    scheduler selection, and histogram matching / mask blur post-processing.
    """

    def __init__(self, device, **kwargs):
        if kwargs.get("model_id_or_path"):
            # Used for custom diffusers models.
            self.model_id_or_path = kwargs["model_id_or_path"]
        super().__init__(device, **kwargs)

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        if config.use_croper:
            if config.croper_is_outpainting:
                inpaint_result = self._do_outpainting(image, config)
            else:
                crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(
                    image, mask, config
                )
                crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
                # Copy: ``image[:, :, ::-1]`` alone is a view, and pasting the
                # crop into it would overwrite the caller's input image.
                inpaint_result = image[:, :, ::-1].copy()
                inpaint_result[t:b, l:r, :] = crop_image
        else:
            inpaint_result = self._scaled_pad_forward(image, mask, config)

        return inpaint_result

    def _do_outpainting(self, image, config: Config):
        """Outpaint the cropper rectangle, which may extend beyond the image.

        Returns:
            BGR image: the original image padded out to cover the cropper
            rectangle, with the generated result pasted in.
        """
        # The cropper and the image share one coordinate system; croper_x/y
        # may be negative.
        # Crop the outpainting region out of the image.
        image_h, image_w = image.shape[:2]
        cropper_l = config.croper_x
        cropper_t = config.croper_y
        cropper_r = config.croper_x + config.croper_width
        cropper_b = config.croper_y + config.croper_height
        image_l = 0
        image_t = 0
        image_r = image_w
        image_b = image_h

        # Intersection of cropper and image rectangles (as when computing IoU).
        l = max(cropper_l, image_l)
        t = max(cropper_t, image_t)
        r = min(cropper_r, image_r)
        b = min(cropper_b, image_b)

        assert (
            0 <= l < r and 0 <= t < b
        ), f"cropper and image not overlap, {l},{t},{r},{b}"

        cropped_image = image[t:b, l:r, :]
        # How far the cropper extends past each image edge.
        padding_l = max(0, image_l - cropper_l)
        padding_t = max(0, image_t - cropper_t)
        padding_r = max(0, cropper_r - image_r)
        padding_b = max(0, cropper_b - image_b)

        zero_padding_count = [padding_l, padding_t, padding_r, padding_b].count(0)

        if zero_padding_count not in [0, 3]:
            logger.warning(
                f"padding count({zero_padding_count}) not 0 or 3, may result in bad edge outpainting"
            )

        expanded_image, mask_image = expand_image(
            cropped_image,
            left=padding_l,
            top=padding_t,
            right=padding_r,
            bottom=padding_b,
            softness=config.sd_outpainting_softness,
            space=config.sd_outpainting_space,
        )

        # The final expanded result image, BGR.
        expanded_cropped_result_image = self._scaled_pad_forward(
            expanded_image, mask_image, config
        )

        # Zero-pad the full original image to cover the cropper, then RGB -> BGR.
        outpainting_image = cv2.copyMakeBorder(
            image,
            left=padding_l,
            top=padding_t,
            right=padding_r,
            bottom=padding_b,
            borderType=cv2.BORDER_CONSTANT,
            value=0,
        )[:, :, ::-1]

        # Paste the cropped result onto outpainting_image; no blending needed here.
        paste_t = 0 if config.croper_y < 0 else config.croper_y
        paste_l = 0 if config.croper_x < 0 else config.croper_x

        outpainting_image[
            paste_t : paste_t + expanded_cropped_result_image.shape[0],
            paste_l : paste_l + expanded_cropped_result_image.shape[1],
            :,
        ] = expanded_cropped_result_image
        return outpainting_image

    def _scaled_pad_forward(self, image, mask, config: Config):
        """Run ``_pad_forward`` at ``config.sd_scale`` and resize back.

        The longer side is limited to ``sd_scale * max(H, W)`` for inference.
        The result is resized back to the input size with bicubic
        interpolation; the whole result is returned (no masked-area paste).
        """
        longer_side_length = int(config.sd_scale * max(image.shape[:2]))
        origin_size = image.shape[:2]
        downsize_image = resize_max_size(image, size_limit=longer_side_length)
        downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
        if config.sd_scale != 1:
            logger.info(
                f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
            )
        inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
        inpaint_result = cv2.resize(
            inpaint_result,
            (origin_size[1], origin_size[0]),
            interpolation=cv2.INTER_CUBIC,
        )
        return inpaint_result

    def set_scheduler(self, config: Config):
        """Replace ``self.model.scheduler`` with one for ``config.sd_sampler``.

        When ``config.sd_lcm_lora`` is set, the LCM sampler is used instead.
        The new scheduler is built from the current scheduler's config.
        """
        scheduler_config = self.model.scheduler.config
        sd_sampler = config.sd_sampler
        if config.sd_lcm_lora:
            sd_sampler = SDSampler.lcm

        scheduler = get_scheduler(sd_sampler, scheduler_config)
        self.model.scheduler = scheduler

    def forward_post_process(self, result, image, mask, config):
        """Optionally match histograms to the input and blur the mask.

        Histogram matching uses the BGR-flipped input as reference. The mask
        is Gaussian-blurred with kernel size ``2 * sd_mask_blur + 1``.
        """
        if config.sd_match_histograms:
            result = self._match_histograms(result, image[:, :, ::-1], mask)

        if config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)
        return result, image, mask
|