Add Kandinsky 2.2 inpainting model

This commit is contained in:
Qing 2023-08-30 21:30:11 +08:00
parent 7ba8fdbe76
commit 94211a4985
11 changed files with 170 additions and 12 deletions

View File

@ -270,6 +270,12 @@ function ModelSettingBlock() {
          'https://arxiv.org/abs/2211.09800',
          'https://github.com/timothybrooks/instruct-pix2pix'
        )
      case AIModel.KANDINSKY22:
        return renderModelDesc(
          'Kandinsky 2.2',
          'https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint',
          'https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint'
        )
      default:
        return <></>
    }

View File

@ -17,6 +17,7 @@ export enum AIModel {
  Mange = 'manga',
  PAINT_BY_EXAMPLE = 'paint_by_example',
  PIX2PIX = 'instruct_pix2pix',
  KANDINSKY22 = 'kandinsky2.2',
}
export enum ControlNetMethod {
@ -566,6 +567,13 @@ const defaultHDSettings: ModelsHDSettings = {
    hdStrategyCropMargin: 128,
    enabled: true,
  },
  [AIModel.KANDINSKY22]: {
    hdStrategy: HDStrategy.ORIGINAL,
    hdStrategyResizeLimit: 768,
    hdStrategyCropTrigerSize: 512,
    hdStrategyCropMargin: 128,
    enabled: false,
  },
}
export enum SDSampler {
@ -719,7 +727,8 @@ export const isSDState = selector({
      settings.model === AIModel.SD15 ||
      settings.model === AIModel.SD2 ||
      settings.model === AIModel.ANYTHING4 ||
      settings.model === AIModel.REALISTIC_VISION_1_4
      settings.model === AIModel.REALISTIC_VISION_1_4 ||
      settings.model === AIModel.KANDINSKY22
    )
  },
})

View File

@ -29,6 +29,7 @@ AVAILABLE_MODELS = [
    "sd2",
    "paint_by_example",
    "instruct_pix2pix",
    "kandinsky2.2",
]

SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
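AVAILABLE_MODELS presumably feeds the server's CLI model picker; a minimal sketch of that wiring, assuming an argparse entry point (the flag name and default are illustrative, not taken from this diff):

import argparse

from lama_cleaner.const import AVAILABLE_MODELS

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="lama", choices=AVAILABLE_MODELS)
args = parser.parse_args(["--model", "kandinsky2.2"])
print(args.model)  # kandinsky2.2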

View File

@ -218,7 +218,7 @@ class ControlNet(DiffusionInpaintModel):
                controlnet_conditioning_scale=config.controlnet_conditioning_scale,
                negative_prompt=config.negative_prompt,
                generator=torch.manual_seed(config.sd_seed),
                output_type="np.array",
                output_type="np",
                callback=self.callback,
            ).images[0]
        else:
@ -262,7 +262,7 @@ class ControlNet(DiffusionInpaintModel):
                mask_image=mask_image,
                num_inference_steps=config.sd_steps,
                guidance_scale=config.sd_guidance_scale,
                output_type="np.array",
                output_type="np",
                callback=self.callback,
                height=img_h,
                width=img_w,
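The output_type edits here (and the identical ones in instruct_pix2pix.py and sd.py below) track diffusers' image-processor refactor: "np.array" was never a documented value, and releases around this commit warn on unrecognized types before falling back to "np", which returns float32 RGB arrays in [0, 1]. A minimal sketch of the conversion the surrounding code then performs on such an array (the random array is only a stand-in for pipe(...).images[0]):

import cv2
import numpy as np

rgb_float = np.random.rand(512, 512, 3).astype(np.float32)  # stand-in for a pipeline output
bgr_uint8 = cv2.cvtColor((rgb_float * 255).round().astype("uint8"), cv2.COLOR_RGB2BGR)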

View File

@ -59,7 +59,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
            num_inference_steps=config.p2p_steps,
            image_guidance_scale=config.p2p_image_guidance_scale,
            guidance_scale=config.p2p_guidance_scale,
            output_type="np.array",
            output_type="np",
            generator=torch.manual_seed(config.sd_seed)
        ).images[0]

View File

@ -0,0 +1,91 @@
import PIL.Image
import cv2
import numpy as np
import torch

from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config


class Kandinsky(DiffusionInpaintModel):
    pad_mod = 64
    min_size = 512

    def init_model(self, device: torch.device, **kwargs):
        from diffusers import AutoPipelineForInpainting

        fp16 = not kwargs.get("no_half", False)
        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
        # Half precision only on CUDA; CPU and MPS stay in fp32.
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
        model_kwargs = {
            "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]),
            "torch_dtype": torch_dtype,
        }

        self.model = AutoPipelineForInpainting.from_pretrained(
            self.model_name, **model_kwargs
        ).to(device)

        self.callback = kwargs.pop("callback", None)

    def forward(self, image, mask, config: Config):
        """Input image and output image have the same size
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        return: BGR IMAGE
        """
        scheduler_config = self.model.scheduler.config
        scheduler = get_scheduler(config.sd_sampler, scheduler_config)
        self.model.scheduler = scheduler

        generator = torch.manual_seed(config.sd_seed)

        if config.sd_mask_blur != 0:
            # Feather the mask edge; Gaussian kernel sizes must be odd, hence 2k + 1.
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
        mask = mask.astype(np.float32) / 255

        img_h, img_w = image.shape[:2]

        output = self.model(
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            image=PIL.Image.fromarray(image),
            mask_image=mask[:, :, 0],
            height=img_h,
            width=img_w,
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            generator=generator,
            output_type="np",
            callback=self.callback,
        ).images[0]

        # Pipeline output is float RGB in [0, 1]; convert to the uint8 BGR
        # layout the rest of the app expects.
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output

    def forward_post_process(self, result, image, mask, config):
        if config.sd_match_histograms:
            result = self._match_histograms(result, image[:, :, ::-1], mask)

        if config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)
        return result, image, mask

    @staticmethod
    def is_downloaded() -> bool:
        # The model is downloaded when the app starts and can't be switched
        # from the frontend settings.
        return True


class Kandinsky22(Kandinsky):
    name = "kandinsky2.2"
    model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
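As a sanity check outside the wrapper, the same checkpoint can be driven with diffusers directly; a minimal sketch using the inpainting example images from the diffusers docs (the same overture-creations fixture the tests below use; prompt, step count, and guidance scale are illustrative):

import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = AutoPipelineForInpainting.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder-inpaint",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

base = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples"
image = load_image(f"{base}/overture-creations-5sI6fQgYIuo.png")
mask = load_image(f"{base}/overture-creations-5sI6fQgYIuo_mask.png")

result = pipe(
    prompt="a cat sitting on a bench",
    image=image,
    mask_image=mask,
    num_inference_steps=50,
    guidance_scale=4.0,
).images[0]
result.save("kandinsky22_inpaint.png")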

View File

@ -147,7 +147,7 @@ class SD(DiffusionInpaintModel):
            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            output_type="np.array",
            output_type="np",
            callback=self.callback,
            height=img_h,
            width=img_w,

View File

@ -7,6 +7,7 @@ from lama_cleaner.const import SD15_MODELS
from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model.controlnet import ControlNet
from lama_cleaner.model.fcf import FcF
from lama_cleaner.model.kandinsky import Kandinsky22
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.manga import Manga
@ -33,6 +34,7 @@ models = {
    "sd2": SD2,
    "paint_by_example": PaintByExample,
    "instruct_pix2pix": InstructPix2Pix,
    Kandinsky22.name: Kandinsky22,
}
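Since the frontend sends the model name as a plain string, every entry in AVAILABLE_MODELS must resolve to a key of this dict. A minimal sketch of a guard against name drift between the two lists (for example, a 2.1 vs 2.2 mix-up), assuming both modules import cleanly as laid out in this commit:

from lama_cleaner.const import AVAILABLE_MODELS
from lama_cleaner.model_manager import models

missing = [name for name in AVAILABLE_MODELS if name not in models]
assert not missing, f"models without a registered class: {missing}"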

View File

@ -1 +1,2 @@
*_result.png
*_result.png
result/

BIN
lama_cleaner/tests/cat.png Normal file

Binary file not shown.


View File

@ -17,6 +17,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
@pytest.mark.parametrize("name", ["sd1.5"])
@pytest.mark.parametrize("sd_device", ["mps"])
@pytest.mark.parametrize(
"rect",
@ -30,16 +31,15 @@ device = torch.device(device)
[-100, -100, 512 + 200, 512 + 200],
],
)
def test_sdxl_outpainting(sd_device, rect):
def test_outpainting(name, sd_device, rect):
def callback(i, t, latents):
pass
if sd_device == "cuda" and not torch.cuda.is_available():
return
sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(
name="sd1.5",
name=name,
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
@ -50,21 +50,69 @@ def test_sdxl_outpainting(sd_device, rect):
cfg = get_config(
HDStrategy.ORIGINAL,
prompt="a dog sitting on a bench in the park",
sd_steps=30,
sd_steps=50,
use_croper=True,
croper_is_outpainting=True,
croper_x=rect[0],
croper_y=rect[1],
croper_width=rect[2],
croper_height=rect[3],
sd_guidance_scale=14,
sd_guidance_scale=4,
sd_sampler=SDSampler.dpm_plus_plus,
)
assert_equal(
model,
cfg,
f"sd15_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
f"{name.replace('.', '_')}_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)
@pytest.mark.parametrize("name", ["kandinsky2.2"])
@pytest.mark.parametrize("sd_device", ["mps"])
@pytest.mark.parametrize(
"rect",
[
[-100, -100, 768, 768],
],
)
def test_kandinsky_outpainting(name, sd_device, rect):
def callback(i, t, latents):
pass
if sd_device == "cuda" and not torch.cuda.is_available():
return
model = ModelManager(
name=name,
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
HDStrategy.ORIGINAL,
prompt="a cat",
negative_prompt="lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature",
sd_steps=50,
use_croper=True,
croper_is_outpainting=True,
croper_x=rect[0],
croper_y=rect[1],
croper_width=rect[2],
croper_height=rect[3],
sd_guidance_scale=4,
sd_sampler=SDSampler.dpm_plus_plus,
)
assert_equal(
model,
cfg,
f"{name.replace('.', '_')}_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
img_p=current_dir / "cat.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)
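Both tests exercise outpainting through the croper: the crop rect may extend past the source image, and everything outside the original pixels is masked for repaint (255, per the mask convention in kandinsky.py above). A minimal sketch of that canvas/mask construction, independent of whatever helper the app actually uses (the real croper logic is not shown in this diff):

import numpy as np

def outpaint_canvas(image: np.ndarray, x: int, y: int, w: int, h: int):
    # Pad `image` to the crop rect (x, y, w, h); the rect may extend beyond the
    # image bounds. Returns the padded canvas and a mask where 255 marks the
    # newly created border to repaint.
    img_h, img_w = image.shape[:2]
    canvas = np.zeros((h, w, 3), dtype=image.dtype)
    mask = np.full((h, w, 1), 255, dtype=np.uint8)
    dst_x, dst_y = max(0, -x), max(0, -y)  # where the original lands on the canvas
    src_x, src_y = max(0, x), max(0, y)    # which part of the original survives the crop
    copy_w = min(img_w - src_x, w - dst_x)
    copy_h = min(img_h - src_y, h - dst_y)
    canvas[dst_y:dst_y + copy_h, dst_x:dst_x + copy_w] = image[src_y:src_y + copy_h, src_x:src_x + copy_w]
    mask[dst_y:dst_y + copy_h, dst_x:dst_x + copy_w] = 0
    return canvas, mask

For rect = [-100, -100, 768, 768] on the 512x512 overture image, the original lands at (100, 100) on the canvas, leaving a 100px band on the top/left and a 156px band on the bottom/right for the model to fill.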