From 61d56288a5ce84ee88616b2f6fa8449ef3a9fd88 Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 22 Dec 2023 14:00:30 +0800 Subject: [PATCH] update --- lama_cleaner/model/base.py | 3 +++ lama_cleaner/model/controlnet.py | 10 +++++++++- lama_cleaner/model_manager.py | 19 ++++++++++--------- lama_cleaner/server.py | 7 +++++++ lama_cleaner/tests/test_controlnet.py | 9 +++++---- requirements-dev.txt | 3 ++- 6 files changed, 36 insertions(+), 15 deletions(-) diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index e724032..72fe2ed 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -394,6 +394,9 @@ class DiffusionInpaintModel(InpaintModel): def set_scheduler(self, config: Config): scheduler_config = self.model.scheduler.config sd_sampler = config.sd_sampler + if config.sd_lcm_lora: + sd_sampler = SDSampler.lcm + logger.info(f"LCM Lora enabled, use {sd_sampler} sampler") scheduler = get_scheduler(sd_sampler, scheduler_config) self.model.scheduler = scheduler diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py index 749feef..e5a02ce 100644 --- a/lama_cleaner/model/controlnet.py +++ b/lama_cleaner/model/controlnet.py @@ -2,7 +2,7 @@ import PIL.Image import cv2 import numpy as np import torch -from diffusers import ControlNetModel +from diffusers import ControlNetModel, DiffusionPipeline from loguru import logger from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION @@ -69,6 +69,7 @@ class ControlNet(DiffusionInpaintModel): use_gpu = device == torch.device("cuda") and torch.cuda.is_available() torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + self.torch_dtype = torch_dtype if model_info.model_type in [ ModelType.DIFFUSERS_SD, @@ -131,6 +132,13 @@ class ControlNet(DiffusionInpaintModel): self.callback = kwargs.pop("callback", None) + def switch_controlnet_method(self, new_method: str): + self.sd_controlnet_method = new_method + controlnet = ControlNetModel.from_pretrained( + new_method, torch_dtype=self.torch_dtype, resume_download=True + ).to(self.model.device) + self.model.controlnet = controlnet + def _get_control_image(self, image, mask): if "canny" in self.sd_controlnet_method: control_image = make_canny_control_image(image) diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 5084396..69bbf99 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -105,13 +105,18 @@ class ModelManager: if not self.available_models[self.name].support_controlnet: return - if self.sd_controlnet != config.controlnet_enabled or ( - self.sd_controlnet and self.sd_controlnet_method != config.controlnet_method + if ( + self.sd_controlnet + and config.controlnet_method + and self.sd_controlnet_method != config.controlnet_method ): - # 可能关闭/开启 controlnet - # 可能开启了 controlnet,切换 controlnet 的方法 - old_sd_controlnet = self.sd_controlnet old_sd_controlnet_method = self.sd_controlnet_method + self.sd_controlnet_method = config.controlnet_method + self.model.switch_controlnet_method(config.controlnet_method) + logger.info( + f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}" + ) + elif self.sd_controlnet != config.controlnet_enabled: self.sd_controlnet = config.controlnet_enabled self.sd_controlnet_method = config.controlnet_method @@ -120,10 +125,6 @@ class ModelManager: ) if not config.controlnet_enabled: logger.info(f"Disable controlnet") - elif old_sd_controlnet_method != config.controlnet_method: - logger.info( - f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}" - ) else: logger.info(f"Enable controlnet: {config.controlnet_method}") diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index d66053c..6ec0b40 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -258,6 +258,13 @@ def process(): croper_y=form["croperY"], croper_height=form["croperHeight"], croper_width=form["croperWidth"], + + use_extender=form["useExtender"], + extender_x=form["extenderX"], + extender_y=form["extenderY"], + extender_height=form["extenderHeight"], + extender_width=form["extenderWidth"], + sd_scale=form["sdScale"], sd_mask_blur=form["sdMaskBlur"], sd_strength=form["sdStrength"], diff --git a/lama_cleaner/tests/test_controlnet.py b/lama_cleaner/tests/test_controlnet.py index 6b1cf51..5c4571b 100644 --- a/lama_cleaner/tests/test_controlnet.py +++ b/lama_cleaner/tests/test_controlnet.py @@ -17,6 +17,7 @@ save_dir = current_dir / "result" save_dir.mkdir(exist_ok=True, parents=True) device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device(device) +model_name = "runwayml/stable-diffusion-inpainting" @pytest.mark.parametrize("sd_device", ["cuda", "mps"]) @@ -35,7 +36,7 @@ def test_runway_sd_1_5( sd_steps = 1 if sd_device == "cpu" else 30 model = ModelManager( - name="sd1.5", + name=model_name, sd_controlnet=True, device=torch.device(sd_device), hf_access_token="", @@ -83,7 +84,7 @@ def test_local_file_path(sd_device, sampler): sd_steps = 1 if sd_device == "cpu" else 30 model = ModelManager( - name="sd1.5", + name=model_name, sd_controlnet=True, device=torch.device(sd_device), hf_access_token="", @@ -121,7 +122,7 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler): sd_steps = 1 if sd_device == "cpu" else 30 model = ModelManager( - name="sd1.5", + name=model_name, sd_controlnet=True, device=torch.device(sd_device), hf_access_token="", @@ -162,7 +163,7 @@ def test_controlnet_switch(sd_device, sampler): sd_steps = 1 if sd_device == "cpu" else 30 model = ModelManager( - name="sd1.5", + name=model_name, sd_controlnet=True, device=torch.device(sd_device), hf_access_token="", diff --git a/requirements-dev.txt b/requirements-dev.txt index d5ba964..5f854af 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,2 +1,3 @@ wheel -twine \ No newline at end of file +twine +pytest-loguru \ No newline at end of file