wip: controlnet

This commit is contained in:
Qing 2023-05-11 21:51:58 +08:00
parent e5ac6a105a
commit 87f54bb87e
10 changed files with 117 additions and 29 deletions

View File

@ -53,8 +53,13 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory.
""" """
SD_CONTROLNET_HELP = """ SD_CONTROLNET_HELP = """
Run Stable Diffusion 1.5 inpainting model with Canny ControlNet control. Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui.
""" """
SD_CONTROLNET_CHOICES = [
"control_v11p_sd15_canny",
"control_v11p_sd15_openpose",
"control_v11p_sd15_inpaint",
]
SD_LOCAL_MODEL_HELP = """ SD_LOCAL_MODEL_HELP = """
Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path. Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.

View File

@ -4,9 +4,7 @@ import PIL.Image
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from diffusers import ( from diffusers import ControlNetModel
ControlNetModel,
)
from loguru import logger from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.base import DiffusionInpaintModel
@ -75,7 +73,7 @@ def load_from_local_model(
num_in_channels=4 if is_native_control_inpaint else 9, num_in_channels=4 if is_native_control_inpaint else 9,
from_safetensors=local_model_path.endswith("safetensors"), from_safetensors=local_model_path.endswith("safetensors"),
device="cpu", device="cpu",
load_safety_checker=False load_safety_checker=False,
) )
inpaint_pipe = pipe_class( inpaint_pipe = pipe_class(
@ -92,7 +90,7 @@ def load_from_local_model(
del pipe del pipe
gc.collect() gc.collect()
return inpaint_pipe.to(torch_dtype) return inpaint_pipe.to(torch_dtype=torch_dtype)
class ControlNet(DiffusionInpaintModel): class ControlNet(DiffusionInpaintModel):
@ -120,6 +118,7 @@ class ControlNet(DiffusionInpaintModel):
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
sd_controlnet_method = kwargs["sd_controlnet_method"] sd_controlnet_method = kwargs["sd_controlnet_method"]
self.sd_controlnet_method = sd_controlnet_method
if sd_controlnet_method == "control_v11p_sd15_inpaint": if sd_controlnet_method == "control_v11p_sd15_inpaint":
from diffusers import StableDiffusionControlNetPipeline as PipeClass from diffusers import StableDiffusionControlNetPipeline as PipeClass
@ -206,18 +205,30 @@ class ControlNet(DiffusionInpaintModel):
output_type="np.array", output_type="np.array",
).images[0] ).images[0]
else: else:
canny_image = cv2.Canny(image, 100, 200) if "canny" in self.sd_controlnet_method:
canny_image = canny_image[:, :, None] canny_image = cv2.Canny(image, 100, 200)
canny_image = np.concatenate( canny_image = canny_image[:, :, None]
[canny_image, canny_image, canny_image], axis=2 canny_image = np.concatenate(
) [canny_image, canny_image, canny_image], axis=2
canny_image = PIL.Image.fromarray(canny_image) )
canny_image = PIL.Image.fromarray(canny_image)
control_image = canny_image
elif "openpose" in self.sd_controlnet_method:
from controlnet_aux import OpenposeDetector
processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
control_image = processor(image, hand_and_face=True)
else:
raise NotImplementedError(
f"{self.sd_controlnet_method} not implemented"
)
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L") mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
image = PIL.Image.fromarray(image) image = PIL.Image.fromarray(image)
output = self.model( output = self.model(
image=image, image=image,
control_image=canny_image, control_image=control_image,
prompt=config.prompt, prompt=config.prompt,
negative_prompt=config.negative_prompt, negative_prompt=config.negative_prompt,
mask_image=mask_image, mask_image=mask_image,

View File

@ -35,13 +35,13 @@ class CPUTextEncoderWrapper:
def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True): def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True):
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt, download_from_original_stable_diffusion_ckpt,
) )
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
logger.info(f"Converting {local_model_path} to diffusers pipeline") logger.info(f"Converting {local_model_path} to diffusers pipeline")
pipe = load_pipeline_from_original_stable_diffusion_ckpt( pipe = download_from_original_stable_diffusion_ckpt(
local_model_path, local_model_path,
num_in_channels=9, num_in_channels=9,
from_safetensors=local_model_path.endswith("safetensors"), from_safetensors=local_model_path.endswith("safetensors"),

View File

@ -12,6 +12,7 @@ from lama_cleaner.model.mat import MAT
from lama_cleaner.model.paint_by_example import PaintByExample from lama_cleaner.model.paint_by_example import PaintByExample
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14 from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model.zits import ZITS from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2 from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
@ -59,7 +60,7 @@ class ModelManager:
def __call__(self, image, mask, config: Config): def __call__(self, image, mask, config: Config):
return self.model(image, mask, config) return self.model(image, mask, config)
def switch(self, new_name: str): def switch(self, new_name: str, **kwargs):
if new_name == self.name: if new_name == self.name:
return return
try: try:
@ -75,3 +76,17 @@ class ModelManager:
self.name = new_name self.name = new_name
except NotImplementedError as e: except NotImplementedError as e:
raise e raise e
def switch_controlnet_method(self, control_method: str):
if not self.kwargs.get("sd_controlnet"):
return
if self.kwargs["sd_controlnet_method"] == control_method:
return
del self.model
torch_gc()
self.kwargs["sd_controlnet_method"] = control_method
self.model = self.init_model(
self.name, switch_mps_device(self.name, self.device), **self.kwargs
)

View File

@ -41,10 +41,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--sd-controlnet-method", "--sd-controlnet-method",
default="control_v11p_sd15_inpaint", default="control_v11p_sd15_inpaint",
choices=[ choices=SD_CONTROLNET_CHOICES,
"control_v11p_sd15_canny",
"control_v11p_sd15_inpaint",
],
) )
parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP) parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
parser.add_argument( parser.add_argument(

View File

@ -436,6 +436,17 @@ def switch_model():
return f"ok, switch to {new_name}", 200 return f"ok, switch to {new_name}", 200
@app.route("/controlnet_method", methods=["POST"])
def switch_controlnet_method():
new_method = request.form.get("method")
try:
model.switch_controlnet_method(new_method)
except NotImplementedError:
return f"Failed switch to {new_method} not implemented", 500
return f"Switch to {new_method}", 200
@app.route("/") @app.route("/")
def index(): def index():
return send_file(os.path.join(BUILD_DIR, "index.html")) return send_file(os.path.join(BUILD_DIR, "index.html"))

BIN
lama_cleaner/tests/mask.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.7 KiB

View File

@ -1,5 +1,7 @@
import os import os
from lama_cleaner.const import SD_CONTROLNET_CHOICES
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from pathlib import Path from pathlib import Path
@ -22,7 +24,10 @@ device = torch.device(device)
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc]) @pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
@pytest.mark.parametrize("cpu_textencoder", [True]) @pytest.mark.parametrize("cpu_textencoder", [True])
@pytest.mark.parametrize("disable_nsfw", [True]) @pytest.mark.parametrize("disable_nsfw", [True])
def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw): @pytest.mark.parametrize("sd_controlnet_method", SD_CONTROLNET_CHOICES)
def test_runway_sd_1_5(
sd_device, strategy, sampler, cpu_textencoder, disable_nsfw, sd_controlnet_method
):
if sd_device == "cuda" and not torch.cuda.is_available(): if sd_device == "cuda" and not torch.cuda.is_available():
return return
if device == "mps" and not torch.backends.mps.is_available(): if device == "mps" and not torch.backends.mps.is_available():
@ -34,19 +39,20 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
sd_controlnet=True, sd_controlnet=True,
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=False,
disable_nsfw=disable_nsfw, disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder, sd_cpu_textencoder=cpu_textencoder,
sd_controlnet_method=sd_controlnet_method,
) )
cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps) cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
cfg.sd_sampler = sampler cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" name = f"device_{sd_device}_{sampler}_cpu_textencoder_disable_nsfw"
assert_equal( assert_equal(
model, model,
cfg, cfg,
f"sd_controlnet_{name}.png", f"sd_controlnet_{sd_controlnet_method}_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1.2, fx=1.2,
@ -68,11 +74,12 @@ def test_local_file_path(sd_device, sampler):
sd_controlnet=True, sd_controlnet=True,
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=False,
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
cpu_offload=True, cpu_offload=True,
sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt", sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt",
sd_controlnet_method="control_v11p_sd15_canny",
) )
cfg = get_config( cfg = get_config(
HDStrategy.ORIGINAL, HDStrategy.ORIGINAL,
@ -86,7 +93,48 @@ def test_local_file_path(sd_device, sampler):
assert_equal( assert_equal(
model, model,
cfg, cfg,
f"sd_controlnet_local_model_{name}.png", f"sd_controlnet_canny_local_model_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
if sd_device == "cuda" and not torch.cuda.is_available():
return
if device == "mps" and not torch.backends.mps.is_available():
return
sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager(
name="sd1.5",
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=False,
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=True,
sd_local_model_path="/Users/cwq/data/models/v1-5-pruned-emaonly.safetensors",
sd_controlnet_method="control_v11p_sd15_inpaint",
)
cfg = get_config(
HDStrategy.ORIGINAL,
prompt="a fox sitting on a bench",
sd_steps=sd_steps,
controlnet_conditioning_scale=1.0,
sd_strength=1.0
)
cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}"
assert_equal(
model,
cfg,
f"sd_controlnet_local_native_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
) )

View File

@ -20,7 +20,7 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload):
model = ModelManager(name="instruct_pix2pix", model = ModelManager(name="instruct_pix2pix",
device=torch.device(device), device=torch.device(device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=False,
disable_nsfw=disable_nsfw, disable_nsfw=disable_nsfw,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
cpu_offload=cpu_offload) cpu_offload=cpu_offload)
@ -45,7 +45,7 @@ def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
model = ModelManager(name="instruct_pix2pix", model = ModelManager(name="instruct_pix2pix",
device=torch.device(device), device=torch.device(device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=False,
disable_nsfw=disable_nsfw, disable_nsfw=disable_nsfw,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
cpu_offload=cpu_offload) cpu_offload=cpu_offload)

View File

@ -13,4 +13,5 @@ transformers==4.27.4
gradio gradio
piexif==1.1.3 piexif==1.1.3
safetensors safetensors
omegaconf omegaconf
controlnet-aux==0.0.3