This commit is contained in:
Qing 2023-11-14 14:19:56 +08:00
parent 78c8d8dbdb
commit 8fbc8059e1
7 changed files with 246 additions and 14 deletions

View File

@ -13,6 +13,7 @@ MPS_SUPPORT_MODELS = [
"paint_by_example",
"controlnet",
"kandinsky2.2",
"sdxl",
]
DEFAULT_MODEL = "lama"
@ -23,6 +24,7 @@ AVAILABLE_MODELS = [
"mat",
"fcf",
"sd1.5",
"sdxl",
"anything4",
"realisticVision1.4",
"cv2",

View File

@ -151,6 +151,7 @@ def expand_image(
hard_mask[:, 0 : origin_w // 2] = 255
if right != 0:
hard_mask[:, origin_w // 2 :] = 255
hard_mask = cv2.copyMakeBorder(
hard_mask, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255
)
@ -166,12 +167,12 @@ if __name__ == "__main__":
init_image = cv2.imread(str(image_path))
init_image, mask_image = expand_image(
init_image,
200,
200,
0,
0,
60,
50,
top=100,
right=100,
bottom=100,
left=100,
softness=20,
space=20,
)
print(mask_image.dtype, mask_image.min(), mask_image.max())
print(init_image.dtype, init_image.min(), init_image.max())

111
lama_cleaner/model/sdxl.py Normal file
View File

@ -0,0 +1,111 @@
import PIL.Image
import cv2
import numpy as np
import torch
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import torch_gc, get_scheduler
from lama_cleaner.schema import Config
class SDXL(DiffusionInpaintModel):
    """Stable Diffusion XL inpainting model built on a diffusers pipeline.

    The pipeline is created in :meth:`init_model`, invoked in :meth:`forward`
    and its output is optionally colour-matched in
    :meth:`forward_post_process`.
    """

    # Key under which this model is registered / selected.
    name = "sdxl"
    # NOTE(review): consumed by the DiffusionInpaintModel base class, presumably
    # as the padding modulus and minimum working size -- confirm against base.py.
    pad_mod = 8
    min_size = 512

    def init_model(self, device: torch.device, **kwargs):
        """Load the SDXL inpainting pipeline and move it to ``device``.

        Keyword arguments read from ``kwargs``:

        * ``no_half`` (optional): when truthy, never use float16 weights.
        * ``local_files_only`` (optional) / ``sd_run_local`` (required):
          forwarded to ``from_pretrained`` as ``local_files_only``.
        * ``disable_nsfw`` (required), ``cpu_offload`` (optional): either one
          makes the pipeline load without a safety checker.
        * ``hf_access_token`` (required): forwarded as ``use_auth_token``.
        * ``enable_xformers`` (optional): enable xformers attention.
        * ``sd_cpu_textencoder`` (required): not supported here, only warned.
        * ``callback`` (optional): stored and passed to every pipeline call.

        Raises:
            KeyError: if one of the required keys above is missing.
        """
        from diffusers.pipelines import AutoPipelineForInpainting

        fp16 = not kwargs.get("no_half", False)

        model_kwargs = {
            "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
        }
        if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
                dict(
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                )
            )

        # Compare the device *type* rather than the whole device object:
        # torch.device("cuda:0") != torch.device("cuda"), so an equality test
        # would silently fall back to float32 / no offload on indexed devices.
        target = torch.device(device)
        use_gpu = target.type == "cuda" and torch.cuda.is_available()
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32

        self.model = AutoPipelineForInpainting.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            revision="main",
            torch_dtype=torch_dtype,
            use_auth_token=kwargs["hf_access_token"],
            **model_kwargs,
        )

        # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
        self.model.enable_attention_slicing()
        # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
        if kwargs.get("enable_xformers", False):
            self.model.enable_xformers_memory_efficient_attention()

        if kwargs.get("cpu_offload", False) and use_gpu:
            logger.info("Enable sequential cpu offload")
            # An un-indexed "cuda" device has index None -> keep the old default 0.
            self.model.enable_sequential_cpu_offload(gpu_id=target.index or 0)
        else:
            self.model = self.model.to(device)
            if kwargs["sd_cpu_textencoder"]:
                logger.warning("Stable Diffusion XL not support run TextEncoder on CPU")

        self.callback = kwargs.pop("callback", None)

    def forward(self, image, mask, config: Config):
        """Input image and output image have same size

        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        return: BGR IMAGE
        """
        # Rebuild the scheduler on every call so the sampler chosen in
        # ``config`` takes effect while reusing the pipeline's scheduler config.
        scheduler_config = self.model.scheduler.config
        scheduler = get_scheduler(config.sd_sampler, scheduler_config)
        self.model.scheduler = scheduler

        if config.sd_mask_blur != 0:
            # Kernel size must be odd; the trailing axis is re-added so the
            # mask can still be indexed as [H, W, 1] below.
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]

        img_h, img_w = image.shape[:2]

        output = self.model(
            image=PIL.Image.fromarray(image),
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
            num_inference_steps=config.sd_steps,
            # NOTE(review): exactly 1.0 is replaced by 0.999, presumably to
            # work around pipeline behaviour at full strength -- confirm.
            strength=0.999 if config.sd_strength == 1.0 else config.sd_strength,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
            callback=self.callback,
            height=img_h,
            width=img_w,
            generator=torch.manual_seed(config.sd_seed),
            callback_steps=1,
        ).images[0]

        # The pipeline output is scaled by 255 and converted to uint8 BGR.
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output

    def forward_post_process(self, result, image, mask, config):
        """Optionally histogram-match ``result`` and blur ``mask``.

        When ``config.sd_match_histograms`` is set, ``result`` is matched
        against the channel-reversed ``image``. When ``config.sd_mask_blur``
        is non-zero the mask is blurred with the same kernel size used in
        :meth:`forward`.

        Returns:
            The ``(result, image, mask)`` tuple after those adjustments.
        """
        if config.sd_match_histograms:
            result = self._match_histograms(result, image[:, :, ::-1], mask)

        if config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)
        return result, image, mask

    @staticmethod
    def is_downloaded() -> bool:
        """Always report the model as available."""
        # model will be downloaded when app start, and can't switch in frontend settings
        return True

View File

@ -15,6 +15,7 @@ from lama_cleaner.model.mat import MAT
from lama_cleaner.model.paint_by_example import PaintByExample
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
from lama_cleaner.model.sdxl import SDXL
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2
@ -35,6 +36,7 @@ models = {
"paint_by_example": PaintByExample,
"instruct_pix2pix": InstructPix2Pix,
Kandinsky22.name: Kandinsky22,
SDXL.name: SDXL,
}

View File

@ -66,8 +66,12 @@ class Config(BaseModel):
sd_scale: float = 1.0
# Blur the edge of mask area. The higher the number the smoother blend with the original image
sd_mask_blur: int = 0
# Ignore this value, it's useless for inpainting
sd_strength: float = 0.75
# Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
# starting point and more noise is added the higher the `strength`. The number of denoising steps depends
# on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
# process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
# essentially ignores `image`.
sd_strength: float = 1.0
# The number of denoising steps. More denoising steps usually lead to a
# higher quality image at the expense of slower inference.
sd_steps: int = 50
@ -80,8 +84,8 @@ class Config(BaseModel):
sd_match_histograms: bool = False
# out-painting
sd_outpainting_softness: float = 30.0
sd_outpainting_space: float = 50.0
sd_outpainting_softness: float = 20.0
sd_outpainting_space: float = 20.0
# Configs for opencv inpainting
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07

View File

@ -57,7 +57,7 @@ def test_outpainting(name, sd_device, rect):
croper_y=rect[1],
croper_width=rect[2],
croper_height=rect[3],
sd_guidance_scale=4,
sd_guidance_scale=8.0,
sd_sampler=SDSampler.dpm_plus_plus,
)
@ -75,7 +75,7 @@ def test_outpainting(name, sd_device, rect):
@pytest.mark.parametrize(
"rect",
[
[-100, -100, 768, 768],
[-128, -128, 768, 768],
],
)
def test_kandinsky_outpainting(name, sd_device, rect):
@ -86,7 +86,7 @@ def test_kandinsky_outpainting(name, sd_device, rect):
return
model = ModelManager(
name=name,
name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
@ -105,7 +105,7 @@ def test_kandinsky_outpainting(name, sd_device, rect):
croper_y=rect[1],
croper_width=rect[2],
croper_height=rect[3],
sd_guidance_scale=4,
sd_guidance_scale=7,
sd_sampler=SDSampler.dpm_plus_plus,
)
@ -115,4 +115,6 @@ def test_kandinsky_outpainting(name, sd_device, rect):
f"{name.replace('.', '_')}_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
img_p=current_dir / "cat.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1,
fy=1,
)

View File

@ -0,0 +1,110 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from pathlib import Path
import pytest
import torch
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import HDStrategy, SDSampler
from lama_cleaner.tests.test_model import get_config, assert_equal
# Absolute, symlink-resolved directory that contains this test module.
current_dir = Path(__file__).parent.absolute().resolve()

# "result" directory next to the tests; created at import time if missing.
save_dir = current_dir.joinpath("result")
save_dir.mkdir(parents=True, exist_ok=True)
@pytest.mark.parametrize("sd_device", ["mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize("cpu_textencoder", [False])
@pytest.mark.parametrize("disable_nsfw", [True])
def test_sdxl(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
    """Run SDXL inpainting through ``ModelManager`` and check it via ``assert_equal``.

    The test returns early (without failing) when CUDA is requested but not
    available on the current machine.
    """
    cuda_requested_but_missing = sd_device == "cuda" and not torch.cuda.is_available()
    if cuda_requested_but_missing:
        return

    def _noop_callback(i, t, latents):
        # Per-step hook handed to the model; intentionally does nothing.
        return None

    manager = ModelManager(
        name="sdxl",
        device=torch.device(sd_device),
        hf_access_token="",
        sd_run_local=False,
        disable_nsfw=disable_nsfw,
        sd_cpu_textencoder=cpu_textencoder,
        callback=_noop_callback,
    )

    config = get_config(
        strategy,
        prompt="a fox sitting on a bench",
        sd_steps=20,
        sd_strength=0.99,
        sd_guidance_scale=7.0,
    )
    config.sd_sampler = sampler

    # The output file name encodes every parametrized value of this run.
    result_name = (
        f"sdxl_device_{sd_device}_{sampler}"
        f"_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}.png"
    )
    source_image = current_dir / "overture-creations-5sI6fQgYIuo.png"
    source_mask = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"

    assert_equal(
        manager,
        config,
        result_name,
        img_p=source_image,
        mask_p=source_mask,
        fx=2,
        fy=2,
    )
@pytest.mark.parametrize("sd_device", ["mps"])
@pytest.mark.parametrize(
    "rect",
    [
        [-128, -128, 1024, 1024],
    ],
)
def test_sdxl_outpainting(sd_device, rect):
    """Run SDXL in outpainting mode for the cropper rectangle ``rect``.

    ``rect`` is ``[x, y, width, height]`` and is forwarded to the config as
    the ``croper_*`` values. The test returns early (without failing) when
    CUDA is requested but not available.
    """
    cuda_requested_but_missing = sd_device == "cuda" and not torch.cuda.is_available()
    if cuda_requested_but_missing:
        return

    def _noop_callback(i, t, latents):
        # Per-step hook handed to the model; intentionally does nothing.
        return None

    manager = ModelManager(
        name="sdxl",
        device=torch.device(sd_device),
        hf_access_token="",
        sd_run_local=True,
        disable_nsfw=True,
        sd_cpu_textencoder=False,
        callback=_noop_callback,
    )

    rect_x, rect_y, rect_width, rect_height = rect
    unwanted = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"

    config = get_config(
        HDStrategy.ORIGINAL,
        prompt="a dog sitting on a bench in the park",
        negative_prompt=unwanted,
        sd_steps=20,
        use_croper=True,
        croper_is_outpainting=True,
        croper_x=rect_x,
        croper_y=rect_y,
        croper_width=rect_width,
        croper_height=rect_height,
        sd_strength=1.0,
        sd_guidance_scale=8.0,
        sd_sampler=SDSampler.ddim,
    )

    # The rectangle is embedded in the output file name.
    rect_tag = "_".join(str(value) for value in rect)
    source_image = current_dir / "overture-creations-5sI6fQgYIuo.png"
    source_mask = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"

    assert_equal(
        manager,
        config,
        f"sdxl_outpainting_dog_ddim_{rect_tag}.png",
        img_p=source_image,
        mask_p=source_mask,
        fx=1.5,
        fy=1.5,
    )