add DiffusionInpaintModel

Qing 2023-01-27 20:59:22 +08:00
parent 96659f2aef
commit 205170e1e5
5 changed files with 55 additions and 122 deletions

View File

@@ -123,32 +123,3 @@ def make_compare_gif(
         loop=0
     )
     return img_byte_arr.getvalue()
-
-
-if __name__ == '__main__':
-    imgs = [
-        (
-            '/Users/qing/code/github/lama-cleaner/assets/unwant_person.jpg',
-            '/Users/qing/code/github/lama-cleaner/assets/unwant_person_clean.jpg'
-        ),
-        # (
-        #     '/Users/qing/code/github/lama-cleaner/assets/old_photo.jpg',
-        #     '/Users/qing/code/github/lama-cleaner/assets/old_photo_clean.jpg'
-        # ),
-        # (
-        #     '/Users/qing/code/github/lama-cleaner/assets/unwant_object.jpg',
-        #     '/Users/qing/code/github/lama-cleaner/assets/unwant_object_clean.jpg'
-        # ),
-        # (
-        #     '/Users/qing/code/github/lama-cleaner/assets/unwant_text.jpg',
-        #     '/Users/qing/code/github/lama-cleaner/assets/unwant_text_clean.jpg'
-        # ),
-        # (
-        #     '/Users/qing/code/github/lama-cleaner/assets/watermark.jpg',
-        #     '/Users/qing/code/github/lama-cleaner/assets/watermark_cleanup.jpg'
-        # ),
-    ]
-    for src_p, clean_p in imgs:
-        img_bytes = make_compare_gif(Image.open(src_p), Image.open(clean_p), max_side_length=600)
-        with open(Path(src_p).with_suffix('.gif'), 'wb') as f:
-            f.write(img_bytes)
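Editor's note: the deleted __main__ block was also the only usage example for make_compare_gif. A minimal sketch of the same call from outside the module; the import path and file paths below are illustrative, since the diff does not show which module defines the function:

from pathlib import Path
from PIL import Image
# Hypothetical import path; adjust to wherever make_compare_gif lives.
from lama_cleaner.helper import make_compare_gif

src_p = "assets/unwant_person.jpg"          # original image
clean_p = "assets/unwant_person_clean.jpg"  # inpainted result

# Same call signature as the removed demo block.
img_bytes = make_compare_gif(Image.open(src_p), Image.open(clean_p), max_side_length=600)
Path(src_p).with_suffix(".gif").write_bytes(img_bytes)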

View File

@@ -245,3 +245,42 @@ class InpaintModel:
         crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
         return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
+
+
+class DiffusionInpaintModel(InpaintModel):
+    @torch.no_grad()
+    def __call__(self, image, mask, config: Config):
+        """
+        images: [H, W, C] RGB, not normalized
+        masks: [H, W]
+        return: BGR IMAGE
+        """
+        # boxes = boxes_from_mask(mask)
+        if config.use_croper:
+            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
+            crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
+            inpaint_result = image[:, :, ::-1]
+            inpaint_result[t:b, l:r, :] = crop_image
+        else:
+            inpaint_result = self._scaled_pad_forward(image, mask, config)
+
+        return inpaint_result
+
+    def _scaled_pad_forward(self, image, mask, config: Config):
+        longer_side_length = int(config.sd_scale * max(image.shape[:2]))
+        origin_size = image.shape[:2]
+        downsize_image = resize_max_size(image, size_limit=longer_side_length)
+        downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
+        logger.info(
+            f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
+        )
+        inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
+        # only paste masked area result
+        inpaint_result = cv2.resize(
+            inpaint_result,
+            (origin_size[1], origin_size[0]),
+            interpolation=cv2.INTER_CUBIC,
+        )
+        original_pixel_indices = mask < 127
+        inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
+        return inpaint_result
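Editor's note: the `mask < 127` paste-back at the end of _scaled_pad_forward is what confines diffusion output to the user's mask while every other pixel reverts to the (RGB-to-BGR reversed) input. A self-contained sketch of just that step, using dummy arrays in place of the real model output:

import numpy as np

image = np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)  # fake RGB input
mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 255                                           # area to repaint
inpaint_result = np.full_like(image, 200)                      # fake model output (BGR)

# Pixels the user did NOT mark are restored from the original image.
original_pixel_indices = mask < 127
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]

assert (inpaint_result[0, 0] == image[0, 0, ::-1]).all()  # unmasked pixel restored
assert (inpaint_result[1, 1] == 200).all()                # masked pixel keeps model output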

View File

@@ -1,19 +1,16 @@
-import random
 import PIL
 import PIL.Image
 import cv2
-import numpy as np
 import torch
 from diffusers import DiffusionPipeline
 from loguru import logger
 
-from lama_cleaner.helper import resize_max_size
-from lama_cleaner.model.base import InpaintModel
+from lama_cleaner.model.base import DiffusionInpaintModel
+from lama_cleaner.model.utils import set_seed
 from lama_cleaner.schema import Config
 
 
-class PaintByExample(InpaintModel):
+class PaintByExample(DiffusionInpaintModel):
     pad_mod = 8
     min_size = 512
@@ -53,11 +50,7 @@ class PaintByExample(InpaintModel):
         mask: [H, W, 1] 255 means area to repaint
         return: BGR IMAGE
         """
-        seed = config.paint_by_example_seed
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
+        set_seed(config.paint_by_example_seed)
 
         output = self.model(
             image=PIL.Image.fromarray(image),
@@ -71,42 +64,6 @@ class PaintByExample(InpaintModel):
         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
         return output
 
-    def _scaled_pad_forward(self, image, mask, config: Config):
-        longer_side_length = int(config.sd_scale * max(image.shape[:2]))
-        origin_size = image.shape[:2]
-        downsize_image = resize_max_size(image, size_limit=longer_side_length)
-        downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
-        logger.info(
-            f"Resize image to do paint_by_example: {image.shape} -> {downsize_image.shape}"
-        )
-        inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
-        # only paste masked area result
-        inpaint_result = cv2.resize(
-            inpaint_result,
-            (origin_size[1], origin_size[0]),
-            interpolation=cv2.INTER_CUBIC,
-        )
-        original_pixel_indices = mask < 127
-        inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
-        return inpaint_result
-
-    @torch.no_grad()
-    def __call__(self, image, mask, config: Config):
-        """
-        images: [H, W, C] RGB, not normalized
-        masks: [H, W]
-        return: BGR IMAGE
-        """
-        if config.use_croper:
-            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
-            crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
-            inpaint_result = image[:, :, ::-1]
-            inpaint_result[t:b, l:r, :] = crop_image
-        else:
-            inpaint_result = self._scaled_pad_forward(image, mask, config)
-
-        return inpaint_result
-
     def forward_post_process(self, result, image, mask, config):
         if config.paint_by_example_match_histograms:
             result = self._match_histograms(result, image[:, :, ::-1], mask)
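Editor's note: after this refactor PaintByExample keeps only model-specific code; cropping, sd_scale resizing, seeding, and paste-back all come from DiffusionInpaintModel. A hypothetical minimal subclass illustrating what a new diffusion backend now needs to provide (class and method bodies are illustrative, not part of the commit):

from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.schema import Config


class MyDiffusionInpainter(DiffusionInpaintModel):  # hypothetical example
    pad_mod = 8
    min_size = 512

    def init_model(self, device, **kwargs):
        # Load a diffusers pipeline here; assumed hook, matching how
        # PaintByExample and SD initialize their pipelines.
        self.model = ...

    def forward(self, image, mask, config: Config):
        # image: [H, W, C] RGB; mask: [H, W, 1] 255 = repaint; returns BGR.
        # DiffusionInpaintModel.__call__ handles crop/scale/paste-back around this.
        ...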

View File

@@ -8,9 +8,8 @@ from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerD
     EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
 from loguru import logger
 
-from lama_cleaner.helper import resize_max_size
-from lama_cleaner.model.base import InpaintModel
-from lama_cleaner.model.utils import torch_gc
+from lama_cleaner.model.base import DiffusionInpaintModel
+from lama_cleaner.model.utils import torch_gc, set_seed
 from lama_cleaner.schema import Config, SDSampler
@@ -28,7 +27,7 @@ class CPUTextEncoderWrapper:
         return [self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0].to(input_device).to(self.torch_dtype)]
 
 
-class SD(InpaintModel):
+class SD(DiffusionInpaintModel):
     pad_mod = 8
     min_size = 512
@@ -73,25 +72,6 @@ class SD(InpaintModel):
         self.callback = kwargs.pop("callback", None)
 
-    def _scaled_pad_forward(self, image, mask, config: Config):
-        longer_side_length = int(config.sd_scale * max(image.shape[:2]))
-        origin_size = image.shape[:2]
-        downsize_image = resize_max_size(image, size_limit=longer_side_length)
-        downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
-        logger.info(
-            f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
-        )
-        inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
-        # only paste masked area result
-        inpaint_result = cv2.resize(
-            inpaint_result,
-            (origin_size[1], origin_size[0]),
-            interpolation=cv2.INTER_CUBIC,
-        )
-        original_pixel_indices = mask < 127
-        inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
-        return inpaint_result
-
     def forward(self, image, mask, config: Config):
         """Input image and output image have same size
         image: [H, W, C] RGB
@@ -118,11 +98,7 @@ class SD(InpaintModel):
         self.model.scheduler = scheduler
 
-        seed = config.sd_seed
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
+        set_seed(config.sd_seed)
 
         if config.sd_mask_blur != 0:
             k = 2 * config.sd_mask_blur + 1
@@ -147,24 +123,6 @@ class SD(InpaintModel):
         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
         return output
 
-    @torch.no_grad()
-    def __call__(self, image, mask, config: Config):
-        """
-        images: [H, W, C] RGB, not normalized
-        masks: [H, W]
-        return: BGR IMAGE
-        """
-        # boxes = boxes_from_mask(mask)
-        if config.use_croper:
-            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
-            crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
-            inpaint_result = image[:, :, ::-1]
-            inpaint_result[t:b, l:r, :] = crop_image
-        else:
-            inpaint_result = self._scaled_pad_forward(image, mask, config)
-
-        return inpaint_result
-
     def forward_post_process(self, result, image, mask, config):
         if config.sd_match_histograms:
             result = self._match_histograms(result, image[:, :, ::-1], mask)
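Editor's note: the context lines above show the mask-blur kernel computed as k = 2 * config.sd_mask_blur + 1, which is always odd, as OpenCV's Gaussian blur requires. A short sketch of that step; the GaussianBlur call itself falls outside the shown hunk, so it is an assumption here:

import cv2
import numpy as np

sd_mask_blur = 4
k = 2 * sd_mask_blur + 1  # 9 -> guaranteed odd kernel size
mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 255                    # hard-edged repaint area
blurred = cv2.GaussianBlur(mask, (k, k), 0)  # feathers the mask edge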

View File

@@ -1,4 +1,5 @@
 import math
+import random
 from typing import Any
 
 import torch
@@ -713,3 +714,10 @@ def torch_gc():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
+
+
+def set_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
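Editor's note: set_seed folds the four-line seeding boilerplate repeated in the SD and PaintByExample models into one helper (it uses np, so utils.py presumably already imports numpy, consistent with only `import random` being added here). A quick sketch of the reproducibility it provides:

import torch
from lama_cleaner.model.utils import set_seed

# Same seed twice -> identical samples from torch's global generator.
set_seed(42)
a = torch.randn(3)
set_seed(42)
b = torch.randn(3)
assert torch.equal(a, b)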