diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py index 11148df..1b2b7e7 100644 --- a/lama_cleaner/const.py +++ b/lama_cleaner/const.py @@ -53,8 +53,13 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory. """ SD_CONTROLNET_HELP = """ -Run Stable Diffusion 1.5 inpainting model with Canny ControlNet control. +Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui. """ +SD_CONTROLNET_CHOICES = [ + "control_v11p_sd15_canny", + "control_v11p_sd15_openpose", + "control_v11p_sd15_inpaint", +] SD_LOCAL_MODEL_HELP = """ Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path. diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py index 4eece59..cb0a963 100644 --- a/lama_cleaner/model/controlnet.py +++ b/lama_cleaner/model/controlnet.py @@ -4,9 +4,7 @@ import PIL.Image import cv2 import numpy as np import torch -from diffusers import ( - ControlNetModel, -) +from diffusers import ControlNetModel from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel @@ -75,7 +73,7 @@ def load_from_local_model( num_in_channels=4 if is_native_control_inpaint else 9, from_safetensors=local_model_path.endswith("safetensors"), device="cpu", - load_safety_checker=False + load_safety_checker=False, ) inpaint_pipe = pipe_class( @@ -92,7 +90,7 @@ def load_from_local_model( del pipe gc.collect() - return inpaint_pipe.to(torch_dtype) + return inpaint_pipe.to(torch_dtype=torch_dtype) class ControlNet(DiffusionInpaintModel): @@ -120,6 +118,7 @@ class ControlNet(DiffusionInpaintModel): torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 sd_controlnet_method = kwargs["sd_controlnet_method"] + self.sd_controlnet_method = sd_controlnet_method if sd_controlnet_method == "control_v11p_sd15_inpaint": from diffusers import StableDiffusionControlNetPipeline as PipeClass @@ -206,18 +205,30 @@ class ControlNet(DiffusionInpaintModel): output_type="np.array", ).images[0] else: - canny_image = cv2.Canny(image, 100, 200) - canny_image = canny_image[:, :, None] - canny_image = np.concatenate( - [canny_image, canny_image, canny_image], axis=2 - ) - canny_image = PIL.Image.fromarray(canny_image) + if "canny" in self.sd_controlnet_method: + canny_image = cv2.Canny(image, 100, 200) + canny_image = canny_image[:, :, None] + canny_image = np.concatenate( + [canny_image, canny_image, canny_image], axis=2 + ) + canny_image = PIL.Image.fromarray(canny_image) + control_image = canny_image + elif "openpose" in self.sd_controlnet_method: + from controlnet_aux import OpenposeDetector + + processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") + control_image = processor(image, hand_and_face=True) + else: + raise NotImplementedError( + f"{self.sd_controlnet_method} not implemented" + ) + mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L") image = PIL.Image.fromarray(image) output = self.model( image=image, - control_image=canny_image, + control_image=control_image, prompt=config.prompt, negative_prompt=config.negative_prompt, mask_image=mask_image, diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index e1ecece..fec7fc6 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -35,13 +35,13 @@ class CPUTextEncoderWrapper: def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True): from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( - load_pipeline_from_original_stable_diffusion_ckpt, + download_from_original_stable_diffusion_ckpt, ) from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline logger.info(f"Converting {local_model_path} to diffusers pipeline") - pipe = load_pipeline_from_original_stable_diffusion_ckpt( + pipe = download_from_original_stable_diffusion_ckpt( local_model_path, num_in_channels=9, from_safetensors=local_model_path.endswith("safetensors"), diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index ac9abe3..74330bb 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -12,6 +12,7 @@ from lama_cleaner.model.mat import MAT from lama_cleaner.model.paint_by_example import PaintByExample from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14 +from lama_cleaner.model.utils import torch_gc from lama_cleaner.model.zits import ZITS from lama_cleaner.model.opencv2 import OpenCV2 from lama_cleaner.schema import Config @@ -59,7 +60,7 @@ class ModelManager: def __call__(self, image, mask, config: Config): return self.model(image, mask, config) - def switch(self, new_name: str): + def switch(self, new_name: str, **kwargs): if new_name == self.name: return try: @@ -75,3 +76,17 @@ class ModelManager: self.name = new_name except NotImplementedError as e: raise e + + def switch_controlnet_method(self, control_method: str): + if not self.kwargs.get("sd_controlnet"): + return + if self.kwargs["sd_controlnet_method"] == control_method: + return + + del self.model + torch_gc() + + self.kwargs["sd_controlnet_method"] = control_method + self.model = self.init_model( + self.name, switch_mps_device(self.name, self.device), **self.kwargs + ) diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index 1d5abfe..1a3f48c 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -41,10 +41,7 @@ def parse_args(): parser.add_argument( "--sd-controlnet-method", default="control_v11p_sd15_inpaint", - choices=[ - "control_v11p_sd15_canny", - "control_v11p_sd15_inpaint", - ], + choices=SD_CONTROLNET_CHOICES, ) parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP) parser.add_argument( diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index c7ef623..8d1984c 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -436,6 +436,17 @@ def switch_model(): return f"ok, switch to {new_name}", 200 +@app.route("/controlnet_method", methods=["POST"]) +def switch_controlnet_method(): + new_method = request.form.get("method") + + try: + model.switch_controlnet_method(new_method) + except NotImplementedError: + return f"Failed switch to {new_method} not implemented", 500 + return f"Switch to {new_method}", 200 + + @app.route("/") def index(): return send_file(os.path.join(BUILD_DIR, "index.html")) diff --git a/lama_cleaner/tests/mask.png b/lama_cleaner/tests/mask.png new file mode 100644 index 0000000..29cf20b Binary files /dev/null and b/lama_cleaner/tests/mask.png differ diff --git a/lama_cleaner/tests/test_controlnet.py b/lama_cleaner/tests/test_controlnet.py index 1e58875..d699e1b 100644 --- a/lama_cleaner/tests/test_controlnet.py +++ b/lama_cleaner/tests/test_controlnet.py @@ -1,5 +1,7 @@ import os +from lama_cleaner.const import SD_CONTROLNET_CHOICES + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" from pathlib import Path @@ -22,7 +24,10 @@ device = torch.device(device) @pytest.mark.parametrize("sampler", [SDSampler.uni_pc]) @pytest.mark.parametrize("cpu_textencoder", [True]) @pytest.mark.parametrize("disable_nsfw", [True]) -def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw): +@pytest.mark.parametrize("sd_controlnet_method", SD_CONTROLNET_CHOICES) +def test_runway_sd_1_5( + sd_device, strategy, sampler, cpu_textencoder, disable_nsfw, sd_controlnet_method +): if sd_device == "cuda" and not torch.cuda.is_available(): return if device == "mps" and not torch.backends.mps.is_available(): @@ -34,19 +39,20 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns sd_controlnet=True, device=torch.device(sd_device), hf_access_token="", - sd_run_local=True, + sd_run_local=False, disable_nsfw=disable_nsfw, sd_cpu_textencoder=cpu_textencoder, + sd_controlnet_method=sd_controlnet_method, ) cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps) cfg.sd_sampler = sampler - name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" + name = f"device_{sd_device}_{sampler}_cpu_textencoder_disable_nsfw" assert_equal( model, cfg, - f"sd_controlnet_{name}.png", + f"sd_controlnet_{sd_controlnet_method}_{name}.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", fx=1.2, @@ -68,11 +74,12 @@ def test_local_file_path(sd_device, sampler): sd_controlnet=True, device=torch.device(sd_device), hf_access_token="", - sd_run_local=True, + sd_run_local=False, disable_nsfw=True, sd_cpu_textencoder=False, cpu_offload=True, sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt", + sd_controlnet_method="control_v11p_sd15_canny", ) cfg = get_config( HDStrategy.ORIGINAL, @@ -86,7 +93,48 @@ def test_local_file_path(sd_device, sampler): assert_equal( model, cfg, - f"sd_controlnet_local_model_{name}.png", + f"sd_controlnet_canny_local_model_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("sd_device", ["cuda", "mps"]) +@pytest.mark.parametrize("sampler", [SDSampler.uni_pc]) +def test_local_file_path_controlnet_native_inpainting(sd_device, sampler): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + + sd_steps = 1 if sd_device == "cpu" else 30 + model = ModelManager( + name="sd1.5", + sd_controlnet=True, + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=False, + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + sd_local_model_path="/Users/cwq/data/models/v1-5-pruned-emaonly.safetensors", + sd_controlnet_method="control_v11p_sd15_inpaint", + ) + cfg = get_config( + HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + controlnet_conditioning_scale=1.0, + sd_strength=1.0 + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}" + + assert_equal( + model, + cfg, + f"sd_controlnet_local_native_{name}.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", ) diff --git a/lama_cleaner/tests/test_instruct_pix2pix.py b/lama_cleaner/tests/test_instruct_pix2pix.py index 9780cf3..9778160 100644 --- a/lama_cleaner/tests/test_instruct_pix2pix.py +++ b/lama_cleaner/tests/test_instruct_pix2pix.py @@ -20,7 +20,7 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload): model = ModelManager(name="instruct_pix2pix", device=torch.device(device), hf_access_token="", - sd_run_local=True, + sd_run_local=False, disable_nsfw=disable_nsfw, sd_cpu_textencoder=False, cpu_offload=cpu_offload) @@ -45,7 +45,7 @@ def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload): model = ModelManager(name="instruct_pix2pix", device=torch.device(device), hf_access_token="", - sd_run_local=True, + sd_run_local=False, disable_nsfw=disable_nsfw, sd_cpu_textencoder=False, cpu_offload=cpu_offload) diff --git a/requirements.txt b/requirements.txt index 8e4a89a..d64b969 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ transformers==4.27.4 gradio piexif==1.1.3 safetensors -omegaconf \ No newline at end of file +omegaconf +controlnet-aux==0.0.3 \ No newline at end of file