add sdxl
This commit is contained in:
parent
78c8d8dbdb
commit
8fbc8059e1
@ -13,6 +13,7 @@ MPS_SUPPORT_MODELS = [
|
|||||||
"paint_by_example",
|
"paint_by_example",
|
||||||
"controlnet",
|
"controlnet",
|
||||||
"kandinsky2.2",
|
"kandinsky2.2",
|
||||||
|
"sdxl",
|
||||||
]
|
]
|
||||||
|
|
||||||
DEFAULT_MODEL = "lama"
|
DEFAULT_MODEL = "lama"
|
||||||
@ -23,6 +24,7 @@ AVAILABLE_MODELS = [
|
|||||||
"mat",
|
"mat",
|
||||||
"fcf",
|
"fcf",
|
||||||
"sd1.5",
|
"sd1.5",
|
||||||
|
"sdxl",
|
||||||
"anything4",
|
"anything4",
|
||||||
"realisticVision1.4",
|
"realisticVision1.4",
|
||||||
"cv2",
|
"cv2",
|
||||||
|
@ -151,6 +151,7 @@ def expand_image(
|
|||||||
hard_mask[:, 0 : origin_w // 2] = 255
|
hard_mask[:, 0 : origin_w // 2] = 255
|
||||||
if right != 0:
|
if right != 0:
|
||||||
hard_mask[:, origin_w // 2 :] = 255
|
hard_mask[:, origin_w // 2 :] = 255
|
||||||
|
|
||||||
hard_mask = cv2.copyMakeBorder(
|
hard_mask = cv2.copyMakeBorder(
|
||||||
hard_mask, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255
|
hard_mask, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255
|
||||||
)
|
)
|
||||||
@ -166,12 +167,12 @@ if __name__ == "__main__":
|
|||||||
init_image = cv2.imread(str(image_path))
|
init_image = cv2.imread(str(image_path))
|
||||||
init_image, mask_image = expand_image(
|
init_image, mask_image = expand_image(
|
||||||
init_image,
|
init_image,
|
||||||
200,
|
top=100,
|
||||||
200,
|
right=100,
|
||||||
0,
|
bottom=100,
|
||||||
0,
|
left=100,
|
||||||
60,
|
softness=20,
|
||||||
50,
|
space=20,
|
||||||
)
|
)
|
||||||
print(mask_image.dtype, mask_image.min(), mask_image.max())
|
print(mask_image.dtype, mask_image.min(), mask_image.max())
|
||||||
print(init_image.dtype, init_image.min(), init_image.max())
|
print(init_image.dtype, init_image.min(), init_image.max())
|
||||||
|
111
lama_cleaner/model/sdxl.py
Normal file
111
lama_cleaner/model/sdxl.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import PIL.Image
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
|
from lama_cleaner.model.utils import torch_gc, get_scheduler
|
||||||
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
|
|
||||||
|
class SDXL(DiffusionInpaintModel):
    """Inpainting model backed by the diffusers SDXL inpainting pipeline.

    The pipeline is loaded from the
    ``diffusers/stable-diffusion-xl-1.0-inpainting-0.1`` checkpoint and is
    registered under the model name ``"sdxl"``.
    """

    name = "sdxl"
    pad_mod = 8
    min_size = 512

    def init_model(self, device: torch.device, **kwargs):
        """Load the SDXL inpainting pipeline and place it on ``device``.

        Keys read from ``kwargs``:
            no_half: optional; when truthy, fp16 is never used.
            local_files_only: optional; falls back to ``sd_run_local``.
            sd_run_local: required (it is evaluated even when
                ``local_files_only`` is given).
            disable_nsfw: required.
            cpu_offload: optional, defaults to False.
            hf_access_token: required; forwarded as ``use_auth_token``.
            enable_xformers: optional, defaults to False.
            sd_cpu_textencoder: required when sequential CPU offload is not
                used; only triggers a warning.
            callback: optional; stored on ``self.callback`` for ``forward``.

        Raises:
            KeyError: if one of the required keys above is missing.
        """
        # Imported lazily so diffusers is only needed when this model is used.
        from diffusers.pipelines import AutoPipelineForInpainting

        fp16 = not kwargs.get("no_half", False)

        model_kwargs = {
            "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
        }
        if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
                dict(
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                )
            )

        # Half precision is only selected for an available CUDA device.
        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32

        self.model = AutoPipelineForInpainting.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            revision="main",
            torch_dtype=torch_dtype,
            use_auth_token=kwargs["hf_access_token"],
            **model_kwargs,
        )

        # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
        self.model.enable_attention_slicing()
        # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
        if kwargs.get("enable_xformers", False):
            self.model.enable_xformers_memory_efficient_attention()

        if kwargs.get("cpu_offload", False) and use_gpu:
            logger.info("Enable sequential cpu offload")
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)
            # The option is accepted for interface parity but is not applied.
            if kwargs["sd_cpu_textencoder"]:
                logger.warning("Stable Diffusion XL not support run TextEncoder on CPU")

        self.callback = kwargs.pop("callback", None)

    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        return: BGR IMAGE
        """
        # Rebuild the scheduler from the current scheduler's config so the
        # sampler requested in ``config`` is the one used for this call.
        scheduler_config = self.model.scheduler.config
        scheduler = get_scheduler(config.sd_sampler, scheduler_config)
        self.model.scheduler = scheduler

        if config.sd_mask_blur != 0:
            # 2 * n + 1 is always odd, as required for a Gaussian kernel size.
            k = 2 * config.sd_mask_blur + 1
            # Re-add the channel axis so the mask can be indexed as 3-D below.
            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]

        img_h, img_w = image.shape[:2]

        output = self.model(
            image=PIL.Image.fromarray(image),
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
            num_inference_steps=config.sd_steps,
            # NOTE(review): a strength of exactly 1.0 is replaced by 0.999;
            # presumably the pipeline misbehaves at 1.0 - confirm.
            strength=0.999 if config.sd_strength == 1.0 else config.sd_strength,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
            callback=self.callback,
            height=img_h,
            width=img_w,
            generator=torch.manual_seed(config.sd_seed),
            callback_steps=1,
        ).images[0]

        # Scale the pipeline output by 255 and convert to uint8, then to BGR.
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output

    def forward_post_process(self, result, image, mask, config):
        """Optionally histogram-match ``result`` and blur ``mask``.

        Returns the tuple ``(result, image, mask)``; ``image`` is passed
        through unchanged.
        """
        if config.sd_match_histograms:
            # The channel order of ``image`` is reversed before matching,
            # presumably so it lines up with the BGR ``result`` - verify.
            result = self._match_histograms(result, image[:, :, ::-1], mask)

        if config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)
        return result, image, mask

    @staticmethod
    def is_downloaded() -> bool:
        """Always report the model as downloaded."""
        # model will be downloaded when app start, and can't switch in frontend settings
        return True
|
@ -15,6 +15,7 @@ from lama_cleaner.model.mat import MAT
|
|||||||
from lama_cleaner.model.paint_by_example import PaintByExample
|
from lama_cleaner.model.paint_by_example import PaintByExample
|
||||||
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
|
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
|
||||||
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
|
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
|
||||||
|
from lama_cleaner.model.sdxl import SDXL
|
||||||
from lama_cleaner.model.utils import torch_gc
|
from lama_cleaner.model.utils import torch_gc
|
||||||
from lama_cleaner.model.zits import ZITS
|
from lama_cleaner.model.zits import ZITS
|
||||||
from lama_cleaner.model.opencv2 import OpenCV2
|
from lama_cleaner.model.opencv2 import OpenCV2
|
||||||
@ -35,6 +36,7 @@ models = {
|
|||||||
"paint_by_example": PaintByExample,
|
"paint_by_example": PaintByExample,
|
||||||
"instruct_pix2pix": InstructPix2Pix,
|
"instruct_pix2pix": InstructPix2Pix,
|
||||||
Kandinsky22.name: Kandinsky22,
|
Kandinsky22.name: Kandinsky22,
|
||||||
|
SDXL.name: SDXL,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,8 +66,12 @@ class Config(BaseModel):
|
|||||||
sd_scale: float = 1.0
|
sd_scale: float = 1.0
|
||||||
# Blur the edge of mask area. The higher the number the smoother blend with the original image
|
# Blur the edge of mask area. The higher the number the smoother blend with the original image
|
||||||
sd_mask_blur: int = 0
|
sd_mask_blur: int = 0
|
||||||
# Ignore this value, it's useless for inpainting
|
# Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||||
sd_strength: float = 0.75
|
# starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||||
|
# on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
||||||
|
# process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
||||||
|
# essentially ignores `image`.
|
||||||
|
sd_strength: float = 1.0
|
||||||
# The number of denoising steps. More denoising steps usually lead to a
|
# The number of denoising steps. More denoising steps usually lead to a
|
||||||
# higher quality image at the expense of slower inference.
|
# higher quality image at the expense of slower inference.
|
||||||
sd_steps: int = 50
|
sd_steps: int = 50
|
||||||
@ -80,8 +84,8 @@ class Config(BaseModel):
|
|||||||
sd_match_histograms: bool = False
|
sd_match_histograms: bool = False
|
||||||
|
|
||||||
# out-painting
|
# out-painting
|
||||||
sd_outpainting_softness: float = 30.0
|
sd_outpainting_softness: float = 20.0
|
||||||
sd_outpainting_space: float = 50.0
|
sd_outpainting_space: float = 20.0
|
||||||
|
|
||||||
# Configs for opencv inpainting
|
# Configs for opencv inpainting
|
||||||
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
||||||
|
@ -57,7 +57,7 @@ def test_outpainting(name, sd_device, rect):
|
|||||||
croper_y=rect[1],
|
croper_y=rect[1],
|
||||||
croper_width=rect[2],
|
croper_width=rect[2],
|
||||||
croper_height=rect[3],
|
croper_height=rect[3],
|
||||||
sd_guidance_scale=4,
|
sd_guidance_scale=8.0,
|
||||||
sd_sampler=SDSampler.dpm_plus_plus,
|
sd_sampler=SDSampler.dpm_plus_plus,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -75,7 +75,7 @@ def test_outpainting(name, sd_device, rect):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"rect",
|
"rect",
|
||||||
[
|
[
|
||||||
[-100, -100, 768, 768],
|
[-128, -128, 768, 768],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_kandinsky_outpainting(name, sd_device, rect):
|
def test_kandinsky_outpainting(name, sd_device, rect):
|
||||||
@ -86,7 +86,7 @@ def test_kandinsky_outpainting(name, sd_device, rect):
|
|||||||
return
|
return
|
||||||
|
|
||||||
model = ModelManager(
|
model = ModelManager(
|
||||||
name=name,
|
name="sd1.5",
|
||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
@ -105,7 +105,7 @@ def test_kandinsky_outpainting(name, sd_device, rect):
|
|||||||
croper_y=rect[1],
|
croper_y=rect[1],
|
||||||
croper_width=rect[2],
|
croper_width=rect[2],
|
||||||
croper_height=rect[3],
|
croper_height=rect[3],
|
||||||
sd_guidance_scale=4,
|
sd_guidance_scale=7,
|
||||||
sd_sampler=SDSampler.dpm_plus_plus,
|
sd_sampler=SDSampler.dpm_plus_plus,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -115,4 +115,6 @@ def test_kandinsky_outpainting(name, sd_device, rect):
|
|||||||
f"{name.replace('.', '_')}_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
|
f"{name.replace('.', '_')}_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
|
||||||
img_p=current_dir / "cat.png",
|
img_p=current_dir / "cat.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
fx=1,
|
||||||
|
fy=1,
|
||||||
)
|
)
|
||||||
|
110
lama_cleaner/tests/test_sdxl.py
Normal file
110
lama_cleaner/tests/test_sdxl.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lama_cleaner.model_manager import ModelManager
|
||||||
|
from lama_cleaner.schema import HDStrategy, SDSampler
|
||||||
|
from lama_cleaner.tests.test_model import get_config, assert_equal
|
||||||
|
|
||||||
|
# Directory containing this test module; test fixtures are resolved against it.
current_dir = Path(__file__).parent.absolute().resolve()
# Output directory for generated result images, created on import if missing.
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sd_device", ["mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize("cpu_textencoder", [False])
@pytest.mark.parametrize("disable_nsfw", [True])
def test_sdxl(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
    """Run SDXL inpainting on the sample image and check the result."""

    def callback(i, t, latents):
        # No-op progress callback handed to the model manager.
        pass

    # Skip (rather than silently pass) when the requested device is missing,
    # so the test report shows that nothing was actually exercised.
    if sd_device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    if sd_device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("MPS is not available")

    sd_steps = 20
    model = ModelManager(
        name="sdxl",
        device=torch.device(sd_device),
        hf_access_token="",
        sd_run_local=False,
        disable_nsfw=disable_nsfw,
        sd_cpu_textencoder=cpu_textencoder,
        callback=callback,
    )
    cfg = get_config(
        strategy,
        prompt="a fox sitting on a bench",
        sd_steps=sd_steps,
        sd_strength=0.99,
        sd_guidance_scale=7.0,
    )
    cfg.sd_sampler = sampler

    name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"

    assert_equal(
        model,
        cfg,
        f"sdxl_{name}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=2,
        fy=2,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sd_device", ["mps"])
@pytest.mark.parametrize(
    "rect",
    [
        [-128, -128, 1024, 1024],
    ],
)
def test_sdxl_outpainting(sd_device, rect):
    """Run SDXL outpainting with the cropper rectangle ``rect``.

    ``rect`` supplies croper_x, croper_y, croper_width and croper_height,
    in that order.
    """

    def callback(i, t, latents):
        # No-op progress callback handed to the model manager.
        pass

    # Skip (rather than silently pass) when the requested device is missing,
    # so the test report shows that nothing was actually exercised.
    if sd_device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    if sd_device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("MPS is not available")

    model = ModelManager(
        name="sdxl",
        device=torch.device(sd_device),
        hf_access_token="",
        sd_run_local=True,
        disable_nsfw=True,
        sd_cpu_textencoder=False,
        callback=callback,
    )

    cfg = get_config(
        HDStrategy.ORIGINAL,
        prompt="a dog sitting on a bench in the park",
        negative_prompt="lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature",
        sd_steps=20,
        use_croper=True,
        croper_is_outpainting=True,
        croper_x=rect[0],
        croper_y=rect[1],
        croper_width=rect[2],
        croper_height=rect[3],
        sd_strength=1.0,
        sd_guidance_scale=8.0,
        sd_sampler=SDSampler.ddim,
    )

    assert_equal(
        model,
        cfg,
        f"sdxl_outpainting_dog_ddim_{'_'.join(map(str, rect))}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=1.5,
        fy=1.5,
    )
|
Loading…
Reference in New Issue
Block a user