From c7c309cb89dd7bfe4d68c59d032d2edfa315c7f7 Mon Sep 17 00:00:00 2001
From: Qing
Date: Wed, 30 Aug 2023 13:28:31 +0800
Subject: [PATCH] backend add outpainting

---
 lama_cleaner/model/base.py             | 101 ++++++++++++--
 lama_cleaner/model/g_diffuser_bot.py   | 181 +++++++++++++++++++++++++
 lama_cleaner/model/utils.py            |  10 +-
 lama_cleaner/schema.py                 |   5 +
 lama_cleaner/tests/test_outpainting.py |  70 ++++++++++
 5 files changed, 354 insertions(+), 13 deletions(-)
 create mode 100644 lama_cleaner/model/g_diffuser_bot.py
 create mode 100644 lama_cleaner/tests/test_outpainting.py

diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py
index 08c27b1..c452690 100644
--- a/lama_cleaner/model/base.py
+++ b/lama_cleaner/model/base.py
@@ -12,6 +12,7 @@ from lama_cleaner.helper import (
     pad_img_to_modulo,
     switch_mps_device,
 )
+from lama_cleaner.model.g_diffuser_bot import expand_image, np_img_grey_to_rgb
 from lama_cleaner.schema import Config, HDStrategy
 
 
@@ -266,15 +267,93 @@ class DiffusionInpaintModel(InpaintModel):
         """
         # boxes = boxes_from_mask(mask)
         if config.use_croper:
-            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
-            crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
-            inpaint_result = image[:, :, ::-1]
-            inpaint_result[t:b, l:r, :] = crop_image
+            if config.croper_is_outpainting:
+                inpaint_result = self._do_outpainting(image, config)
+            else:
+                crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(
+                    image, mask, config
+                )
+                crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
+                inpaint_result = image[:, :, ::-1]
+                inpaint_result[t:b, l:r, :] = crop_image
         else:
             inpaint_result = self._scaled_pad_forward(image, mask, config)
 
         return inpaint_result
 
+    def _do_outpainting(self, image, config: Config):
+        # The cropper and the image share the same coordinate system, so croper_x/y may be negative.
+        # Crop the part of the outpainting region that lies inside the image.
+        image_h, image_w = image.shape[:2]
+        cropper_l = config.croper_x
+        cropper_t = config.croper_y
+        cropper_r = config.croper_x + config.croper_width
+        cropper_b = config.croper_y + config.croper_height
+        image_l = 0
+        image_t = 0
+        image_r = image_w
+        image_b = image_h
+
+        # Intersect the cropper with the image, similar to computing an IoU
+        l = max(cropper_l, image_l)
+        t = max(cropper_t, image_t)
+        r = min(cropper_r, image_r)
+        b = min(cropper_b, image_b)
+
+        assert (
+            0 <= l < r and 0 <= t < b
+        ), f"cropper and image do not overlap, {l},{t},{r},{b}"
+
+        cropped_image = image[t:b, l:r, :]
+        padding_l = max(0, image_l - cropper_l)
+        padding_t = max(0, image_t - cropper_t)
+        padding_r = max(0, cropper_r - image_r)
+        padding_b = max(0, cropper_b - image_b)
+
+        zero_padding_count = [padding_l, padding_t, padding_r, padding_b].count(0)
+
+        if zero_padding_count not in [0, 3]:
+            logger.warning(
+                f"padding count({zero_padding_count}) is not 0 or 3, may result in bad edge outpainting"
+            )
+
+        expanded_image, mask_image = expand_image(
+            cropped_image,
+            left=padding_l,
+            top=padding_t,
+            right=padding_r,
+            bottom=padding_b,
+            softness=config.sd_outpainting_softness,
+            space=config.sd_outpainting_space,
+        )
+
+        # The final expanded result image, BGR
+        expanded_cropped_result_image = self._scaled_pad_forward(
+            expanded_image, mask_image, config
+        )
+
+        # RGB -> BGR
+        outpainting_image = cv2.copyMakeBorder(
+            image,
+            left=padding_l,
+            top=padding_t,
+            right=padding_r,
+            bottom=padding_b,
+            borderType=cv2.BORDER_CONSTANT,
+            value=0,
+        )[:, :, ::-1]
+
+        # Paste expanded_cropped_result_image onto outpainting_image; no blending is needed for this step.
+        paste_t = 0 if config.croper_y < 0 else config.croper_y
+        paste_l = 0 if config.croper_x < 0 else config.croper_x
+
+        outpainting_image[
+            paste_t : paste_t + expanded_cropped_result_image.shape[0],
+            paste_l : paste_l + expanded_cropped_result_image.shape[1],
+            :,
+        ] = expanded_cropped_result_image
+        return outpainting_image
+
     def _scaled_pad_forward(self, image, mask, config: Config):
         longer_side_length = int(config.sd_scale * max(image.shape[:2]))
         origin_size = image.shape[:2]
@@ -291,8 +370,14 @@ class DiffusionInpaintModel(InpaintModel):
             (origin_size[1], origin_size[0]),
             interpolation=cv2.INTER_CUBIC,
         )
-        original_pixel_indices = mask < 127
-        inpaint_result[original_pixel_indices] = image[:, :, ::-1][
-            original_pixel_indices
-        ]
+
+        # blend result, copied from g_diffuser_bot
+        # mask_rgb = 1.0 - np_img_grey_to_rgb(mask / 255.0)
+        # inpaint_result = np.clip(
+        #     inpaint_result * (1.0 - mask_rgb) + image * mask_rgb, 0.0, 255.0
+        # )
+        # original_pixel_indices = mask < 127
+        # inpaint_result[original_pixel_indices] = image[:, :, ::-1][
+        #     original_pixel_indices
+        # ]
         return inpaint_result
diff --git a/lama_cleaner/model/g_diffuser_bot.py b/lama_cleaner/model/g_diffuser_bot.py
new file mode 100644
index 0000000..450d618
--- /dev/null
+++ b/lama_cleaner/model/g_diffuser_bot.py
@@ -0,0 +1,181 @@
+# code copied from: https://github.com/parlance-zz/g-diffuser-bot
+import cv2
+import numpy as np
+
+
+def np_img_grey_to_rgb(data):
+    if data.ndim == 3:
+        return data
+    return np.expand_dims(data, 2) * np.ones((1, 1, 3))
+
+
+def convolve(data1, data2):  # fast convolution with fft
+    if data1.ndim != data2.ndim:  # promote to rgb if mismatch
+        if data1.ndim < 3:
+            data1 = np_img_grey_to_rgb(data1)
+        if data2.ndim < 3:
+            data2 = np_img_grey_to_rgb(data2)
+    return ifft2(fft2(data1) * fft2(data2))
+
+
+def fft2(data):
+    if data.ndim > 2:  # multiple channels
+        out_fft = np.zeros(
+            (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
+        )
+        for c in range(data.shape[2]):
+            c_data = data[:, :, c]
+            out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
+            out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
+    else:  # single channel
+        out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
+        out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
+        out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
+
+    return out_fft
+
+
+def ifft2(data):
+    if data.ndim > 2:  # multiple channels
+        out_ifft = np.zeros(
+            (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
+        )
+        for c in range(data.shape[2]):
+            c_data = data[:, :, c]
+            out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
+            out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
+    else:  # single channel
+        out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
+        out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
+        out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
+
+    return out_ifft
+
+
+def get_gradient_kernel(width, height, std=3.14, mode="linear"):
+    window_scale_x = float(
+        width / min(width, height)
+    )  # for non-square aspect ratios we still want a circular kernel
+    window_scale_y = float(height / min(width, height))
+    if mode == "gaussian":
+        x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
+        kx = np.exp(-x * x * std)
+        if window_scale_x != window_scale_y:
+            y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y
+            ky = np.exp(-y * y * std)
+        else:
+            y = x
+            ky = kx
+        return np.outer(kx, ky)
+    elif mode == "linear":
+        x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
+        if window_scale_x != window_scale_y:
+            y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y
+        else:
+            y = x
+        return np.clip(1.0 - np.sqrt(np.add.outer(x * x, y * y)) * std / 3.14, 0.0, 1.0)
+    else:
+        raise Exception("Error: Unknown mode in get_gradient_kernel: {0}".format(mode))
+
+
+def image_blur(data, std=3.14, mode="linear"):
+    width = data.shape[0]
+    height = data.shape[1]
+    kernel = get_gradient_kernel(width, height, std, mode=mode)
+    return np.real(convolve(data, kernel / np.sqrt(np.sum(kernel * kernel))))
+
+
+def soften_mask(np_rgba_image, softness, space):
+    if softness == 0:
+        return np_rgba_image
+    softness = min(softness, 1.0)
+    space = np.clip(space, 0.0, 1.0)
+    original_max_opacity = np.max(np_rgba_image[:, :, 3])
+    out_mask = np_rgba_image[:, :, 3] <= 0.0
+    blurred_mask = image_blur(np_rgba_image[:, :, 3], 3.5 / softness, mode="linear")
+    blurred_mask = np.maximum(blurred_mask - np.max(blurred_mask[out_mask]), 0.0)
+    np_rgba_image[
+        :, :, 3
+    ] *= blurred_mask  # preserve partial opacity in original input mask
+    np_rgba_image[:, :, 3] /= np.max(np_rgba_image[:, :, 3])  # renormalize
+    np_rgba_image[:, :, 3] = np.clip(
+        np_rgba_image[:, :, 3] - space, 0.0, 1.0
+    )  # make space
+    np_rgba_image[:, :, 3] /= np.max(np_rgba_image[:, :, 3])  # and renormalize again
+    np_rgba_image[:, :, 3] *= original_max_opacity  # restore original max opacity
+    return np_rgba_image
+
+
+def expand_image(
+    cv2_img, top: int, right: int, bottom: int, left: int, softness: float, space: float
+):
+    origin_h, origin_w = cv2_img.shape[:2]
+    new_width = cv2_img.shape[1] + left + right
+    new_height = cv2_img.shape[0] + top + bottom
+    new_img = np.zeros((new_height, new_width, 4), np.uint8)  # expanded image is rgba
+
+    print(
+        "Expanding input image from {0}x{1} to {2}x{3}".format(
+            cv2_img.shape[1], cv2_img.shape[0], new_width, new_height
+        )
+    )
+    if cv2_img.shape[2] == 3:  # rgb input image
+        new_img[
+            top : top + cv2_img.shape[0], left : left + cv2_img.shape[1], 0:3
+        ] = cv2_img
+        new_img[
+            top : top + cv2_img.shape[0], left : left + cv2_img.shape[1], 3
+        ] = 255  # fully opaque
+    elif cv2_img.shape[2] == 4:  # rgba input image
+        new_img[top : top + cv2_img.shape[0], left : left + cv2_img.shape[1]] = cv2_img
+    else:
+        raise Exception(
+            "Unsupported image format: {0} channels".format(cv2_img.shape[2])
+        )
+
+    if softness > 0.0:
+        new_img = soften_mask(new_img / 255.0, softness / 100.0, space / 100.0)
+        new_img = (np.clip(new_img, 0.0, 1.0) * 255.0).astype(np.uint8)
+
+    mask_image = 255.0 - new_img[:, :, 3]  # extract mask from alpha channel and invert
+    rgb_init_image = (
+        0.0 + new_img[:, :, 0:3]
+    )  # strip mask from init_img leaving only rgb channels
+
+    hard_mask = np.zeros_like(cv2_img[:, :, 0])
+    if top != 0:
+        hard_mask[0 : origin_h // 2, :] = 255
+    if bottom != 0:
+        hard_mask[origin_h // 2 :, :] = 255
+    if left != 0:
+        hard_mask[:, 0 : origin_w // 2] = 255
+    if right != 0:
+        hard_mask[:, origin_w // 2 :] = 255
+    hard_mask = cv2.copyMakeBorder(
+        hard_mask, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255
+    )
+    mask_image = np.where(hard_mask > 0, mask_image, 0)
+    return rgb_init_image.astype(np.uint8), mask_image.astype(np.uint8)
+
+
+if __name__ == "__main__":
+    from pathlib import Path
+
+    current_dir = Path(__file__).parent.absolute().resolve()
+    image_path = current_dir.parent / "tests" / "bunny.jpeg"
+    init_image = cv2.imread(str(image_path))
+    init_image, mask_image = expand_image(
+        init_image,
+        200,
+        200,
+        0,
+        0,
+        60,
+        50,
+    )
+    print(mask_image.dtype, mask_image.min(), mask_image.max())
+    print(init_image.dtype, init_image.min(), init_image.max())
+    mask_image = mask_image.astype(np.uint8)
+    init_image = init_image.astype(np.uint8)
+    cv2.imwrite("expanded_image.png", init_image)
+    cv2.imwrite("expanded_mask.png", mask_image)
diff --git a/lama_cleaner/model/utils.py b/lama_cleaner/model/utils.py
index 998db43..e1e9927 100644
--- a/lama_cleaner/model/utils.py
+++ b/lama_cleaner/model/utils.py
@@ -27,7 +27,7 @@ def make_beta_schedule(
     if schedule == "linear":
         betas = (
             torch.linspace(
-                linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
+                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
             )
             ** 2
         )
@@ -772,7 +772,7 @@ def conv2d_resample(
             f=f,
             up=up,
             padding=[px0, px1, py0, py1],
-            gain=up ** 2,
+            gain=up**2,
             flip_filter=flip_filter,
         )
         return x
@@ -814,7 +814,7 @@ def conv2d_resample(
             x=x,
             f=f,
             padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
-            gain=up ** 2,
+            gain=up**2,
             flip_filter=flip_filter,
         )
         if down > 1:
@@ -834,7 +834,7 @@ def conv2d_resample(
             f=(f if up > 1 else None),
             up=up,
             padding=[px0, px1, py0, py1],
-            gain=up ** 2,
+            gain=up**2,
             flip_filter=flip_filter,
         )
         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
@@ -870,7 +870,7 @@ class Conv2dLayer(torch.nn.Module):
         self.register_buffer("resample_filter", setup_filter(resample_filter))
         self.conv_clamp = conv_clamp
         self.padding = kernel_size // 2
-        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
         self.act_gain = activation_funcs[activation].def_gain
 
         memory_format = (
diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py
index 9eba2f5..b710cdc 100644
--- a/lama_cleaner/schema.py
+++ b/lama_cleaner/schema.py
@@ -55,6 +55,7 @@ class Config(BaseModel):
     # Crop image to this size before doing sd inpainting
     # The value is always on the original image scale
     use_croper: bool = False
+    croper_is_outpainting: bool = False
     croper_x: int = None
     croper_y: int = None
     croper_height: int = None
@@ -78,6 +79,10 @@ class Config(BaseModel):
     sd_seed: int = 42
     sd_match_histograms: bool = False
 
+    # out-painting
+    sd_outpainting_softness: float = 30.0
+    sd_outpainting_space: float = 50.0
+
     # Configs for opencv inpainting
     # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
     cv2_flag: str = "INPAINT_NS"
diff --git a/lama_cleaner/tests/test_outpainting.py b/lama_cleaner/tests/test_outpainting.py
new file mode 100644
index 0000000..ec59399
--- /dev/null
+++ b/lama_cleaner/tests/test_outpainting.py
@@ -0,0 +1,70 @@
+import os
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+from pathlib import Path
+
+import pytest
+import torch
+
+from lama_cleaner.model_manager import ModelManager
+from lama_cleaner.schema import HDStrategy, SDSampler
+from lama_cleaner.tests.test_model import get_config, assert_equal
+
+current_dir = Path(__file__).parent.absolute().resolve()
+save_dir = current_dir / "result"
+save_dir.mkdir(exist_ok=True, parents=True)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+device = torch.device(device)
+
+
+@pytest.mark.parametrize("sd_device", ["mps"])
+@pytest.mark.parametrize(
+    "rect",
+    [
+        [0, -100, 512, 512 - 128 + 100],
+        [0, 128, 512, 512 - 128 + 100],
+        [128, 0, 512 - 128 + 100, 512],
+        [-100, 0, 512 - 128 + 100, 512],
+        [0, 0, 512, 512 + 200],
+        [0, 0, 512 + 200, 512],
+        [-100, -100, 512 + 200, 512 + 200],
+    ],
+)
+def test_outpainting(sd_device, rect):
+    def callback(i, t, latents):
+        pass
+
+    if sd_device == "cuda" and not torch.cuda.is_available():
sd_device == "cuda" and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == "cuda" else 1 + model = ModelManager( + name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=True, + sd_cpu_textencoder=False, + callback=callback, + ) + cfg = get_config( + HDStrategy.ORIGINAL, + prompt="a dog sitting on a bench in the park", + sd_steps=30, + use_croper=True, + croper_is_outpainting=True, + croper_x=rect[0], + croper_y=rect[1], + croper_width=rect[2], + croper_height=rect[3], + sd_guidance_scale=14, + sd_sampler=SDSampler.dpm_plus_plus, + ) + + assert_equal( + model, + cfg, + f"sd15_outpainting_dpm++_{'_'.join(map(str, rect))}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + )