diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py
index 96569d5..8db2dbe 100644
--- a/lama_cleaner/const.py
+++ b/lama_cleaner/const.py
@@ -8,6 +8,7 @@ MPS_SUPPORT_MODELS = [
     "realisticVision1.4",
     "sd2",
     "paint_by_example",
+    "controlnet",
 ]
 
 DEFAULT_MODEL = "lama"
diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py
index 46030cc..b9c937d 100644
--- a/lama_cleaner/model/controlnet.py
+++ b/lama_cleaner/model/controlnet.py
@@ -1,3 +1,5 @@
+import gc
+
 import PIL.Image
 import cv2
 import numpy as np
@@ -41,6 +43,44 @@ NAMES_MAP = {
 }
 
 
+def load_from_local_model(local_model_path, torch_dtype, controlnet):
+    """Build a ControlNet inpaint pipeline from a local SD checkpoint file.
+
+    The checkpoint is converted on CPU by the diffusers conversion helper, its
+    components are re-wrapped together with ``controlnet`` (safety checker
+    disabled), and the resulting pipeline is returned via ``.to(torch_dtype)``.
+    """
+    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+        load_pipeline_from_original_stable_diffusion_ckpt,
+    )
+    from .pipeline import StableDiffusionControlNetInpaintPipeline
+
+    logger.info(f"Converting {local_model_path} to diffusers controlnet pipeline")
+
+    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
+        local_model_path,
+        num_in_channels=9,
+        from_safetensors=str(local_model_path).endswith("safetensors"),
+        device="cpu",
+    )
+
+    inpaint_pipe = StableDiffusionControlNetInpaintPipeline(
+        vae=pipe.vae,
+        text_encoder=pipe.text_encoder,
+        tokenizer=pipe.tokenizer,
+        unet=pipe.unet,
+        controlnet=controlnet,
+        scheduler=pipe.scheduler,
+        safety_checker=None,
+        feature_extractor=None,
+        requires_safety_checker=False,
+    )
+
+    del pipe
+    gc.collect()
+    return inpaint_pipe.to(torch_dtype)
+
+
 class ControlNet(DiffusionInpaintModel):
     name = "controlnet"
     pad_mod = 8
@@ -71,13 +111,20 @@ class ControlNet(DiffusionInpaintModel):
         controlnet = ControlNetModel.from_pretrained(
             f"lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype
         )
-        self.model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
-            model_id,
-            controlnet=controlnet,
-            revision="fp16" if use_gpu and fp16 else "main",
-            torch_dtype=torch_dtype,
-            **model_kwargs,
-        )
+        if kwargs.get("sd_local_model_path", None):
+            self.model = load_from_local_model(
+                kwargs["sd_local_model_path"],
+                torch_dtype=torch_dtype,
+                controlnet=controlnet,
+            )
+        else:
+            self.model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+                model_id,
+                controlnet=controlnet,
+                revision="fp16" if use_gpu and fp16 else "main",
+                torch_dtype=torch_dtype,
+                **model_kwargs,
+            )
 
         # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
         self.model.enable_attention_slicing()
diff --git a/lama_cleaner/tests/test_controlnet.py b/lama_cleaner/tests/test_controlnet.py
new file mode 100644
index 0000000..1e58875
--- /dev/null
+++ b/lama_cleaner/tests/test_controlnet.py
@@ -0,0 +1,101 @@
+import os
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+from pathlib import Path
+
+import pytest
+import torch
+
+from lama_cleaner.model_manager import ModelManager
+from lama_cleaner.schema import HDStrategy, SDSampler
+from lama_cleaner.tests.test_model import get_config, assert_equal
+
+current_dir = Path(__file__).parent.absolute().resolve()
+save_dir = current_dir / "result"
+save_dir.mkdir(exist_ok=True, parents=True)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+device = torch.device(device)
+
+# Checkpoint used by test_local_file_path; override with SD_LOCAL_MODEL_PATH.
+LOCAL_MODEL_PATH = os.environ.get(
+    "SD_LOCAL_MODEL_PATH", "/Users/cwq/data/models/sd-v1-5-inpainting.ckpt"
+)
+
+
+@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
+@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
+@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
+@pytest.mark.parametrize("cpu_textencoder", [True])
+@pytest.mark.parametrize("disable_nsfw", [True])
+def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
+    """Run sd1.5 with ControlNet enabled on the sample image for each device."""
+    if sd_device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+    if sd_device == "mps" and not torch.backends.mps.is_available():
+        pytest.skip("MPS is not available")
+
+    sd_steps = 1 if sd_device == "cpu" else 30
+    model = ModelManager(
+        name="sd1.5",
+        sd_controlnet=True,
+        device=torch.device(sd_device),
+        hf_access_token="",
+        sd_run_local=True,
+        disable_nsfw=disable_nsfw,
+        sd_cpu_textencoder=cpu_textencoder,
+    )
+    cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
+    cfg.sd_sampler = sampler
+
+    name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
+
+    assert_equal(
+        model,
+        cfg,
+        f"sd_controlnet_{name}.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+        fx=1.2,
+        fy=1.2,
+    )
+
+
+@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
+@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
+def test_local_file_path(sd_device, sampler):
+    """Run sd1.5 with ControlNet using a checkpoint loaded from a local file."""
+    if sd_device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+    if sd_device == "mps" and not torch.backends.mps.is_available():
+        pytest.skip("MPS is not available")
+    if not os.path.isfile(LOCAL_MODEL_PATH):
+        pytest.skip(f"local checkpoint not found: {LOCAL_MODEL_PATH}")
+
+    sd_steps = 1 if sd_device == "cpu" else 30
+    model = ModelManager(
+        name="sd1.5",
+        sd_controlnet=True,
+        device=torch.device(sd_device),
+        hf_access_token="",
+        sd_run_local=True,
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
+        cpu_offload=True,
+        sd_local_model_path=LOCAL_MODEL_PATH,
+    )
+    cfg = get_config(
+        HDStrategy.ORIGINAL,
+        prompt="a fox sitting on a bench",
+        sd_steps=sd_steps,
+    )
+    cfg.sd_sampler = sampler
+
+    name = f"device_{sd_device}_{sampler}"
+
+    assert_equal(
+        model,
+        cfg,
+        f"sd_controlnet_local_model_{name}.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    )