diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py index 738f283..809dc99 100644 --- a/lama_cleaner/const.py +++ b/lama_cleaner/const.py @@ -13,6 +13,7 @@ MPS_SUPPORT_MODELS = [ "paint_by_example", "controlnet", "kandinsky2.2", + "sdxl", ] DEFAULT_MODEL = "lama" @@ -23,6 +24,7 @@ AVAILABLE_MODELS = [ "mat", "fcf", "sd1.5", + "sdxl", "anything4", "realisticVision1.4", "cv2", diff --git a/lama_cleaner/model/g_diffuser_bot.py b/lama_cleaner/model/g_diffuser_bot.py index 450d618..bf7ce2a 100644 --- a/lama_cleaner/model/g_diffuser_bot.py +++ b/lama_cleaner/model/g_diffuser_bot.py @@ -151,6 +151,7 @@ def expand_image( hard_mask[:, 0 : origin_w // 2] = 255 if right != 0: hard_mask[:, origin_w // 2 :] = 255 + hard_mask = cv2.copyMakeBorder( hard_mask, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255 ) @@ -166,12 +167,12 @@ if __name__ == "__main__": init_image = cv2.imread(str(image_path)) init_image, mask_image = expand_image( init_image, - 200, - 200, - 0, - 0, - 60, - 50, + top=100, + right=100, + bottom=100, + left=100, + softness=20, + space=20, ) print(mask_image.dtype, mask_image.min(), mask_image.max()) print(init_image.dtype, init_image.min(), init_image.max()) diff --git a/lama_cleaner/model/sdxl.py b/lama_cleaner/model/sdxl.py new file mode 100644 index 0000000..d55c591 --- /dev/null +++ b/lama_cleaner/model/sdxl.py @@ -0,0 +1,111 @@ +import PIL.Image +import cv2 +import numpy as np +import torch +from loguru import logger + +from lama_cleaner.model.base import DiffusionInpaintModel +from lama_cleaner.model.utils import torch_gc, get_scheduler +from lama_cleaner.schema import Config + + +class SDXL(DiffusionInpaintModel): + name = "sdxl" + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + from diffusers.pipelines import AutoPipelineForInpainting + + fp16 = not kwargs.get("no_half", False) + + model_kwargs = { + "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]) + } + if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) + + use_gpu = device == torch.device("cuda") and torch.cuda.is_available() + torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + + self.model = AutoPipelineForInpainting.from_pretrained( + "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + revision="main", + torch_dtype=torch_dtype, + use_auth_token=kwargs["hf_access_token"], + **model_kwargs, + ) + + # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing + self.model.enable_attention_slicing() + # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention + if kwargs.get("enable_xformers", False): + self.model.enable_xformers_memory_efficient_attention() + + if kwargs.get("cpu_offload", False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs["sd_cpu_textencoder"]: + logger.warning("Stable Diffusion XL not support run TextEncoder on CPU") + + self.callback = kwargs.pop("callback", None) + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + + scheduler_config = self.model.scheduler.config + scheduler = get_scheduler(config.sd_sampler, scheduler_config) + self.model.scheduler = scheduler + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] + + img_h, img_w = image.shape[:2] + + output = self.model( + image=PIL.Image.fromarray(image), + prompt=config.prompt, + negative_prompt=config.negative_prompt, + mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), + num_inference_steps=config.sd_steps, + strength=0.999 if config.sd_strength == 1.0 else config.sd_strength, + guidance_scale=config.sd_guidance_scale, + output_type="np", + callback=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + callback_steps=1 + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + def forward_post_process(self, result, image, mask, config): + if config.sd_match_histograms: + result = self._match_histograms(result, image[:, :, ::-1], mask) + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0) + return result, image, mask + + @staticmethod + def is_downloaded() -> bool: + # model will be downloaded when app start, and can't switch in frontend settings + return True diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 4dca9e8..7d5ed46 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -15,6 +15,7 @@ from lama_cleaner.model.mat import MAT from lama_cleaner.model.paint_by_example import PaintByExample from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14 +from lama_cleaner.model.sdxl import SDXL from lama_cleaner.model.utils import torch_gc from lama_cleaner.model.zits import ZITS from lama_cleaner.model.opencv2 import OpenCV2 @@ -35,6 +36,7 @@ models = { "paint_by_example": PaintByExample, "instruct_pix2pix": InstructPix2Pix, Kandinsky22.name: Kandinsky22, + SDXL.name: SDXL, } diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index b710cdc..1893e5a 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -66,8 +66,12 @@ class Config(BaseModel): sd_scale: float = 1.0 # Blur the edge of mask area. The higher the number the smoother blend with the original image sd_mask_blur: int = 0 - # Ignore this value, it's useless for inpainting - sd_strength: float = 0.75 + # Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + # starting point and more noise is added the higher the `strength`. The number of denoising steps depends + # on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + # process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + # essentially ignores `image`. + sd_strength: float = 1.0 # The number of denoising steps. More denoising steps usually lead to a # higher quality image at the expense of slower inference. sd_steps: int = 50 @@ -80,8 +84,8 @@ class Config(BaseModel): sd_match_histograms: bool = False # out-painting - sd_outpainting_softness: float = 30.0 - sd_outpainting_space: float = 50.0 + sd_outpainting_softness: float = 20.0 + sd_outpainting_space: float = 20.0 # Configs for opencv inpainting # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07 diff --git a/lama_cleaner/tests/test_outpainting.py b/lama_cleaner/tests/test_outpainting.py index e306b30..3b3dbba 100644 --- a/lama_cleaner/tests/test_outpainting.py +++ b/lama_cleaner/tests/test_outpainting.py @@ -57,7 +57,7 @@ def test_outpainting(name, sd_device, rect): croper_y=rect[1], croper_width=rect[2], croper_height=rect[3], - sd_guidance_scale=4, + sd_guidance_scale=8.0, sd_sampler=SDSampler.dpm_plus_plus, ) @@ -75,7 +75,7 @@ def test_outpainting(name, sd_device, rect): @pytest.mark.parametrize( "rect", [ - [-100, -100, 768, 768], + [-128, -128, 768, 768], ], ) def test_kandinsky_outpainting(name, sd_device, rect): @@ -86,7 +86,7 @@ def test_kandinsky_outpainting(name, sd_device, rect): return model = ModelManager( - name=name, + name="sd1.5", device=torch.device(sd_device), hf_access_token="", sd_run_local=True, @@ -105,7 +105,7 @@ def test_kandinsky_outpainting(name, sd_device, rect): croper_y=rect[1], croper_width=rect[2], croper_height=rect[3], - sd_guidance_scale=4, + sd_guidance_scale=7, sd_sampler=SDSampler.dpm_plus_plus, ) @@ -115,4 +115,6 @@ def test_kandinsky_outpainting(name, sd_device, rect): f"{name.replace('.', '_')}_outpainting_dpm++_{'_'.join(map(str, rect))}.png", img_p=current_dir / "cat.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1, + fy=1, ) diff --git a/lama_cleaner/tests/test_sdxl.py b/lama_cleaner/tests/test_sdxl.py new file mode 100644 index 0000000..68230a2 --- /dev/null +++ b/lama_cleaner/tests/test_sdxl.py @@ -0,0 +1,110 @@ +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +from pathlib import Path + +import pytest +import torch + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import HDStrategy, SDSampler +from lama_cleaner.tests.test_model import get_config, assert_equal + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / "result" +save_dir.mkdir(exist_ok=True, parents=True) + + +@pytest.mark.parametrize("sd_device", ["mps"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +@pytest.mark.parametrize("cpu_textencoder", [False]) +@pytest.mark.parametrize("disable_nsfw", [True]) +def test_sdxl(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw): + def callback(i, t, latents): + pass + + if sd_device == "cuda" and not torch.cuda.is_available(): + return + + sd_steps = 20 + model = ModelManager( + name="sdxl", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=False, + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=cpu_textencoder, + callback=callback, + ) + cfg = get_config( + strategy, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + sd_strength=0.99, + sd_guidance_scale=7.0, + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" + + assert_equal( + model, + cfg, + f"sdxl_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=2, + fy=2, + ) + + +@pytest.mark.parametrize("sd_device", ["mps"]) +@pytest.mark.parametrize( + "rect", + [ + [-128, -128, 1024, 1024], + ], +) +def test_sdxl_outpainting(sd_device, rect): + def callback(i, t, latents): + pass + + if sd_device == "cuda" and not torch.cuda.is_available(): + return + + model = ModelManager( + name="sdxl", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=True, + sd_cpu_textencoder=False, + callback=callback, + ) + + cfg = get_config( + HDStrategy.ORIGINAL, + prompt="a dog sitting on a bench in the park", + negative_prompt="lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature", + sd_steps=20, + use_croper=True, + croper_is_outpainting=True, + croper_x=rect[0], + croper_y=rect[1], + croper_width=rect[2], + croper_height=rect[3], + sd_strength=1.0, + sd_guidance_scale=8.0, + sd_sampler=SDSampler.ddim, + ) + + assert_equal( + model, + cfg, + f"sdxl_outpainting_dog_ddim_{'_'.join(map(str, rect))}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.5, + fy=1.5, + )