backend add outpainting
This commit is contained in:
parent
0b3a9a68a2
commit
c7c309cb89
@ -12,6 +12,7 @@ from lama_cleaner.helper import (
|
||||
pad_img_to_modulo,
|
||||
switch_mps_device,
|
||||
)
|
||||
from lama_cleaner.model.g_diffuser_bot import expand_image, np_img_grey_to_rgb
|
||||
from lama_cleaner.schema import Config, HDStrategy
|
||||
|
||||
|
||||
@ -266,7 +267,12 @@ class DiffusionInpaintModel(InpaintModel):
|
||||
"""
|
||||
# boxes = boxes_from_mask(mask)
|
||||
if config.use_croper:
|
||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||
if config.croper_is_outpainting:
|
||||
inpaint_result = self._do_outpainting(image, config)
|
||||
else:
|
||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(
|
||||
image, mask, config
|
||||
)
|
||||
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
||||
inpaint_result = image[:, :, ::-1]
|
||||
inpaint_result[t:b, l:r, :] = crop_image
|
||||
@ -275,6 +281,79 @@ class DiffusionInpaintModel(InpaintModel):
|
||||
|
||||
return inpaint_result
|
||||
|
||||
def _do_outpainting(self, image, config: Config):
|
||||
# cropper 和 image 在同一个坐标系下,croper_x/y 可能为负数
|
||||
# 从 image 中 crop 出 outpainting 区域
|
||||
image_h, image_w = image.shape[:2]
|
||||
cropper_l = config.croper_x
|
||||
cropper_t = config.croper_y
|
||||
cropper_r = config.croper_x + config.croper_width
|
||||
cropper_b = config.croper_y + config.croper_height
|
||||
image_l = 0
|
||||
image_t = 0
|
||||
image_r = image_w
|
||||
image_b = image_h
|
||||
|
||||
# 类似求 IOU
|
||||
l = max(cropper_l, image_l)
|
||||
t = max(cropper_t, image_t)
|
||||
r = min(cropper_r, image_r)
|
||||
b = min(cropper_b, image_b)
|
||||
|
||||
assert (
|
||||
0 <= l < r and 0 <= t < b
|
||||
), f"cropper and image not overlap, {l},{t},{r},{b}"
|
||||
|
||||
cropped_image = image[t:b, l:r, :]
|
||||
padding_l = max(0, image_l - cropper_l)
|
||||
padding_t = max(0, image_t - cropper_t)
|
||||
padding_r = max(0, cropper_r - image_r)
|
||||
padding_b = max(0, cropper_b - image_b)
|
||||
|
||||
zero_padding_count = [padding_l, padding_t, padding_r, padding_b].count(0)
|
||||
|
||||
if zero_padding_count not in [0, 3]:
|
||||
logger.warning(
|
||||
f"padding count({zero_padding_count}) not 0 or 3, may result in bad edge outpainting"
|
||||
)
|
||||
|
||||
expanded_image, mask_image = expand_image(
|
||||
cropped_image,
|
||||
left=padding_l,
|
||||
top=padding_t,
|
||||
right=padding_r,
|
||||
bottom=padding_b,
|
||||
softness=config.sd_outpainting_softness,
|
||||
space=config.sd_outpainting_space,
|
||||
)
|
||||
|
||||
# 最终扩大了的 image, BGR
|
||||
expanded_cropped_result_image = self._scaled_pad_forward(
|
||||
expanded_image, mask_image, config
|
||||
)
|
||||
|
||||
# RGB -> BGR
|
||||
outpainting_image = cv2.copyMakeBorder(
|
||||
image,
|
||||
left=padding_l,
|
||||
top=padding_t,
|
||||
right=padding_r,
|
||||
bottom=padding_b,
|
||||
borderType=cv2.BORDER_CONSTANT,
|
||||
value=0,
|
||||
)[:, :, ::-1]
|
||||
|
||||
# 把 cropped_result_image 贴到 outpainting_image 上,这一步不需要 blend
|
||||
paste_t = 0 if config.croper_y < 0 else config.croper_y
|
||||
paste_l = 0 if config.croper_x < 0 else config.croper_x
|
||||
|
||||
outpainting_image[
|
||||
paste_t : paste_t + expanded_cropped_result_image.shape[0],
|
||||
paste_l : paste_l + expanded_cropped_result_image.shape[1],
|
||||
:,
|
||||
] = expanded_cropped_result_image
|
||||
return outpainting_image
|
||||
|
||||
def _scaled_pad_forward(self, image, mask, config: Config):
|
||||
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
||||
origin_size = image.shape[:2]
|
||||
@ -291,8 +370,14 @@ class DiffusionInpaintModel(InpaintModel):
|
||||
(origin_size[1], origin_size[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
original_pixel_indices = mask < 127
|
||||
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
||||
original_pixel_indices
|
||||
]
|
||||
|
||||
# blend result, copy from g_diffuser_bot
|
||||
# mask_rgb = 1.0 - np_img_grey_to_rgb(mask / 255.0)
|
||||
# inpaint_result = np.clip(
|
||||
# inpaint_result * (1.0 - mask_rgb) + image * mask_rgb, 0.0, 255.0
|
||||
# )
|
||||
# original_pixel_indices = mask < 127
|
||||
# inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
||||
# original_pixel_indices
|
||||
# ]
|
||||
return inpaint_result
|
||||
|
181
lama_cleaner/model/g_diffuser_bot.py
Normal file
181
lama_cleaner/model/g_diffuser_bot.py
Normal file
@ -0,0 +1,181 @@
|
||||
# code copy from: https://github.com/parlance-zz/g-diffuser-bot
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def np_img_grey_to_rgb(data):
|
||||
if data.ndim == 3:
|
||||
return data
|
||||
return np.expand_dims(data, 2) * np.ones((1, 1, 3))
|
||||
|
||||
|
||||
def convolve(data1, data2): # fast convolution with fft
|
||||
if data1.ndim != data2.ndim: # promote to rgb if mismatch
|
||||
if data1.ndim < 3:
|
||||
data1 = np_img_grey_to_rgb(data1)
|
||||
if data2.ndim < 3:
|
||||
data2 = np_img_grey_to_rgb(data2)
|
||||
return ifft2(fft2(data1) * fft2(data2))
|
||||
|
||||
|
||||
def fft2(data):
|
||||
if data.ndim > 2: # multiple channels
|
||||
out_fft = np.zeros(
|
||||
(data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
|
||||
)
|
||||
for c in range(data.shape[2]):
|
||||
c_data = data[:, :, c]
|
||||
out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
|
||||
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
|
||||
else: # single channel
|
||||
out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
|
||||
out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
|
||||
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
|
||||
|
||||
return out_fft
|
||||
|
||||
|
||||
def ifft2(data):
|
||||
if data.ndim > 2: # multiple channels
|
||||
out_ifft = np.zeros(
|
||||
(data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
|
||||
)
|
||||
for c in range(data.shape[2]):
|
||||
c_data = data[:, :, c]
|
||||
out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
|
||||
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
|
||||
else: # single channel
|
||||
out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
|
||||
out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
|
||||
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
|
||||
|
||||
return out_ifft
|
||||
|
||||
|
||||
def get_gradient_kernel(width, height, std=3.14, mode="linear"):
|
||||
window_scale_x = float(
|
||||
width / min(width, height)
|
||||
) # for non-square aspect ratios we still want a circular kernel
|
||||
window_scale_y = float(height / min(width, height))
|
||||
if mode == "gaussian":
|
||||
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
|
||||
kx = np.exp(-x * x * std)
|
||||
if window_scale_x != window_scale_y:
|
||||
y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y
|
||||
ky = np.exp(-y * y * std)
|
||||
else:
|
||||
y = x
|
||||
ky = kx
|
||||
return np.outer(kx, ky)
|
||||
elif mode == "linear":
|
||||
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
|
||||
if window_scale_x != window_scale_y:
|
||||
y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y
|
||||
else:
|
||||
y = x
|
||||
return np.clip(1.0 - np.sqrt(np.add.outer(x * x, y * y)) * std / 3.14, 0.0, 1.0)
|
||||
else:
|
||||
raise Exception("Error: Unknown mode in get_gradient_kernel: {0}".format(mode))
|
||||
|
||||
|
||||
def image_blur(data, std=3.14, mode="linear"):
|
||||
width = data.shape[0]
|
||||
height = data.shape[1]
|
||||
kernel = get_gradient_kernel(width, height, std, mode=mode)
|
||||
return np.real(convolve(data, kernel / np.sqrt(np.sum(kernel * kernel))))
|
||||
|
||||
|
||||
def soften_mask(np_rgba_image, softness, space):
|
||||
if softness == 0:
|
||||
return np_rgba_image
|
||||
softness = min(softness, 1.0)
|
||||
space = np.clip(space, 0.0, 1.0)
|
||||
original_max_opacity = np.max(np_rgba_image[:, :, 3])
|
||||
out_mask = np_rgba_image[:, :, 3] <= 0.0
|
||||
blurred_mask = image_blur(np_rgba_image[:, :, 3], 3.5 / softness, mode="linear")
|
||||
blurred_mask = np.maximum(blurred_mask - np.max(blurred_mask[out_mask]), 0.0)
|
||||
np_rgba_image[
|
||||
:, :, 3
|
||||
] *= blurred_mask # preserve partial opacity in original input mask
|
||||
np_rgba_image[:, :, 3] /= np.max(np_rgba_image[:, :, 3]) # renormalize
|
||||
np_rgba_image[:, :, 3] = np.clip(
|
||||
np_rgba_image[:, :, 3] - space, 0.0, 1.0
|
||||
) # make space
|
||||
np_rgba_image[:, :, 3] /= np.max(np_rgba_image[:, :, 3]) # and renormalize again
|
||||
np_rgba_image[:, :, 3] *= original_max_opacity # restore original max opacity
|
||||
return np_rgba_image
|
||||
|
||||
|
||||
def expand_image(
|
||||
cv2_img, top: int, right: int, bottom: int, left: int, softness: float, space: float
|
||||
):
|
||||
origin_h, origin_w = cv2_img.shape[:2]
|
||||
new_width = cv2_img.shape[1] + left + right
|
||||
new_height = cv2_img.shape[0] + top + bottom
|
||||
new_img = np.zeros((new_height, new_width, 4), np.uint8) # expanded image is rgba
|
||||
|
||||
print(
|
||||
"Expanding input image from {0}x{1} to {2}x{3}".format(
|
||||
cv2_img.shape[1], cv2_img.shape[0], new_width, new_height
|
||||
)
|
||||
)
|
||||
if cv2_img.shape[2] == 3: # rgb input image
|
||||
new_img[
|
||||
top : top + cv2_img.shape[0], left : left + cv2_img.shape[1], 0:3
|
||||
] = cv2_img
|
||||
new_img[
|
||||
top : top + cv2_img.shape[0], left : left + cv2_img.shape[1], 3
|
||||
] = 255 # fully opaque
|
||||
elif cv2_img.shape[2] == 4: # rgba input image
|
||||
new_img[top : top + cv2_img.shape[0], left : left + cv2_img.shape[1]] = cv2_img
|
||||
else:
|
||||
raise Exception(
|
||||
"Unsupported image format: {0} channels".format(cv2_img.shape[2])
|
||||
)
|
||||
|
||||
if softness > 0.0:
|
||||
new_img = soften_mask(new_img / 255.0, softness / 100.0, space / 100.0)
|
||||
new_img = (np.clip(new_img, 0.0, 1.0) * 255.0).astype(np.uint8)
|
||||
|
||||
mask_image = 255.0 - new_img[:, :, 3] # extract mask from alpha channel and invert
|
||||
rgb_init_image = (
|
||||
0.0 + new_img[:, :, 0:3]
|
||||
) # strip mask from init_img leaving only rgb channels
|
||||
|
||||
hard_mask = np.zeros_like(cv2_img[:, :, 0])
|
||||
if top != 0:
|
||||
hard_mask[0 : origin_h // 2, :] = 255
|
||||
if bottom != 0:
|
||||
hard_mask[origin_h // 2 :, :] = 255
|
||||
if left != 0:
|
||||
hard_mask[:, 0 : origin_w // 2] = 255
|
||||
if right != 0:
|
||||
hard_mask[:, origin_w // 2 :] = 255
|
||||
hard_mask = cv2.copyMakeBorder(
|
||||
hard_mask, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255
|
||||
)
|
||||
mask_image = np.where(hard_mask > 0, mask_image, 0)
|
||||
return rgb_init_image.astype(np.uint8), mask_image.astype(np.uint8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pathlib import Path
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
image_path = current_dir.parent / "tests" / "bunny.jpeg"
|
||||
init_image = cv2.imread(str(image_path))
|
||||
init_image, mask_image = expand_image(
|
||||
init_image,
|
||||
200,
|
||||
200,
|
||||
0,
|
||||
0,
|
||||
60,
|
||||
50,
|
||||
)
|
||||
print(mask_image.dtype, mask_image.min(), mask_image.max())
|
||||
print(init_image.dtype, init_image.min(), init_image.max())
|
||||
mask_image = mask_image.astype(np.uint8)
|
||||
init_image = init_image.astype(np.uint8)
|
||||
cv2.imwrite("expanded_image.png", init_image)
|
||||
cv2.imwrite("expanded_mask.png", mask_image)
|
@ -27,7 +27,7 @@ def make_beta_schedule(
|
||||
if schedule == "linear":
|
||||
betas = (
|
||||
torch.linspace(
|
||||
linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
|
||||
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
|
||||
)
|
||||
** 2
|
||||
)
|
||||
@ -772,7 +772,7 @@ def conv2d_resample(
|
||||
f=f,
|
||||
up=up,
|
||||
padding=[px0, px1, py0, py1],
|
||||
gain=up ** 2,
|
||||
gain=up**2,
|
||||
flip_filter=flip_filter,
|
||||
)
|
||||
return x
|
||||
@ -814,7 +814,7 @@ def conv2d_resample(
|
||||
x=x,
|
||||
f=f,
|
||||
padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
|
||||
gain=up ** 2,
|
||||
gain=up**2,
|
||||
flip_filter=flip_filter,
|
||||
)
|
||||
if down > 1:
|
||||
@ -834,7 +834,7 @@ def conv2d_resample(
|
||||
f=(f if up > 1 else None),
|
||||
up=up,
|
||||
padding=[px0, px1, py0, py1],
|
||||
gain=up ** 2,
|
||||
gain=up**2,
|
||||
flip_filter=flip_filter,
|
||||
)
|
||||
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
||||
@ -870,7 +870,7 @@ class Conv2dLayer(torch.nn.Module):
|
||||
self.register_buffer("resample_filter", setup_filter(resample_filter))
|
||||
self.conv_clamp = conv_clamp
|
||||
self.padding = kernel_size // 2
|
||||
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
||||
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
|
||||
self.act_gain = activation_funcs[activation].def_gain
|
||||
|
||||
memory_format = (
|
||||
|
@ -55,6 +55,7 @@ class Config(BaseModel):
|
||||
# Crop image to this size before doing sd inpainting
|
||||
# The value is always on the original image scale
|
||||
use_croper: bool = False
|
||||
croper_is_outpainting: bool = False
|
||||
croper_x: int = None
|
||||
croper_y: int = None
|
||||
croper_height: int = None
|
||||
@ -78,6 +79,10 @@ class Config(BaseModel):
|
||||
sd_seed: int = 42
|
||||
sd_match_histograms: bool = False
|
||||
|
||||
# out-painting
|
||||
sd_outpainting_softness: float = 30.0
|
||||
sd_outpainting_space: float = 50.0
|
||||
|
||||
# Configs for opencv inpainting
|
||||
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
||||
cv2_flag: str = "INPAINT_NS"
|
||||
|
70
lama_cleaner/tests/test_outpainting.py
Normal file
70
lama_cleaner/tests/test_outpainting.py
Normal file
@ -0,0 +1,70 @@
|
||||
import os
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import HDStrategy, SDSampler
|
||||
from lama_cleaner.tests.test_model import get_config, assert_equal
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / "result"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device = torch.device(device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_device", ["mps"])
|
||||
@pytest.mark.parametrize(
|
||||
"rect",
|
||||
[
|
||||
[0, -100, 512, 512 - 128 + 100],
|
||||
[0, 128, 512, 512 - 128 + 100],
|
||||
[128, 0, 512 - 128 + 100, 512],
|
||||
[-100, 0, 512 - 128 + 100, 512],
|
||||
[0, 0, 512, 512 + 200],
|
||||
[0, 0, 512 + 200, 512],
|
||||
[-100, -100, 512 + 200, 512 + 200],
|
||||
],
|
||||
)
|
||||
def test_sdxl_outpainting(sd_device, rect):
|
||||
def callback(i, t, latents):
|
||||
pass
|
||||
|
||||
if sd_device == "cuda" and not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
sd_steps = 50 if sd_device == "cuda" else 1
|
||||
model = ModelManager(
|
||||
name="sd1.5",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback,
|
||||
)
|
||||
cfg = get_config(
|
||||
HDStrategy.ORIGINAL,
|
||||
prompt="a dog sitting on a bench in the park",
|
||||
sd_steps=30,
|
||||
use_croper=True,
|
||||
croper_is_outpainting=True,
|
||||
croper_x=rect[0],
|
||||
croper_y=rect[1],
|
||||
croper_width=rect[2],
|
||||
croper_height=rect[3],
|
||||
sd_guidance_scale=14,
|
||||
sd_sampler=SDSampler.dpm_plus_plus,
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd15_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
Loading…
Reference in New Issue
Block a user