From 58b931fdb2bbe34b0b51596b33f4a1eefa7e3011 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 15 Nov 2023 08:50:35 +0800 Subject: [PATCH] add lcm lora --- lama_cleaner/const.py | 1 + lama_cleaner/model/base.py | 11 ++++++++++- lama_cleaner/model/sd.py | 8 +++----- lama_cleaner/model/sdxl.py | 8 +++----- lama_cleaner/model_manager.py | 12 +++++++++++- lama_cleaner/schema.py | 5 +++++ 6 files changed, 33 insertions(+), 12 deletions(-) diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py index 233f291..30937e6 100644 --- a/lama_cleaner/const.py +++ b/lama_cleaner/const.py @@ -37,6 +37,7 @@ AVAILABLE_MODELS = [ ] SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"] MODELS_SUPPORT_FREEU = SD15_MODELS + ['sd2', "sdxl"] +MODELS_SUPPORT_LCM_LORA = SD15_MODELS + ["sdxl"] AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] DEFAULT_DEVICE = "cuda" diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index c452690..d74f4e9 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -13,7 +13,8 @@ from lama_cleaner.helper import ( switch_mps_device, ) from lama_cleaner.model.g_diffuser_bot import expand_image, np_img_grey_to_rgb -from lama_cleaner.schema import Config, HDStrategy +from lama_cleaner.model.utils import get_scheduler +from lama_cleaner.schema import Config, HDStrategy, SDSampler class InpaintModel: @@ -381,3 +382,11 @@ class DiffusionInpaintModel(InpaintModel): # original_pixel_indices # ] return inpaint_result + + def set_scheduler(self, config: Config): + scheduler_config = self.model.scheduler.config + sd_sampler = config.sd_sampler + if config.sd_lcm_lora: + sd_sampler = SDSampler.lcm + scheduler = get_scheduler(sd_sampler, scheduler_config) + self.model.scheduler = scheduler diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index b29626c..9aa1b54 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -8,7 +8,7 @@ from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.utils import torch_gc, get_scheduler -from lama_cleaner.schema import Config +from lama_cleaner.schema import Config, SDSampler class CPUTextEncoderWrapper: @@ -67,6 +67,7 @@ def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True): class SD(DiffusionInpaintModel): pad_mod = 8 min_size = 512 + lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5" def init_model(self, device: torch.device, **kwargs): from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline @@ -129,10 +130,7 @@ class SD(DiffusionInpaintModel): mask: [H, W, 1] 255 means area to repaint return: BGR IMAGE """ - - scheduler_config = self.model.scheduler.config - scheduler = get_scheduler(config.sd_sampler, scheduler_config) - self.model.scheduler = scheduler + self.set_scheduler(config) if config.sd_mask_blur != 0: k = 2 * config.sd_mask_blur + 1 diff --git a/lama_cleaner/model/sdxl.py b/lama_cleaner/model/sdxl.py index 05fd9cf..197ab77 100644 --- a/lama_cleaner/model/sdxl.py +++ b/lama_cleaner/model/sdxl.py @@ -13,6 +13,7 @@ class SDXL(DiffusionInpaintModel): name = "sdxl" pad_mod = 8 min_size = 512 + lcm_lora_id = "latent-consistency/lcm-lora-sdxl" def init_model(self, device: torch.device, **kwargs): from diffusers.pipelines import AutoPipelineForInpainting @@ -56,10 +57,7 @@ class SDXL(DiffusionInpaintModel): mask: [H, W, 1] 255 means area to repaint return: BGR IMAGE """ - - scheduler_config = self.model.scheduler.config - scheduler = get_scheduler(config.sd_sampler, scheduler_config) - self.model.scheduler = scheduler + self.set_scheduler(config) if config.sd_mask_blur != 0: k = 2 * config.sd_mask_blur + 1 @@ -80,7 +78,7 @@ class SDXL(DiffusionInpaintModel): height=img_h, width=img_w, generator=torch.manual_seed(config.sd_seed), - callback_steps=1 + callback_steps=1, ).images[0] output = (output * 255).round().astype("uint8") diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 7b5278d..43d9ced 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -3,7 +3,7 @@ import gc from loguru import logger -from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU +from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU, MODELS_SUPPORT_LCM_LORA from lama_cleaner.helper import switch_mps_device from lama_cleaner.model.controlnet import ControlNet from lama_cleaner.model.fcf import FcF @@ -66,6 +66,7 @@ class ModelManager: def __call__(self, image, mask, config: Config): self.switch_controlnet_method(control_method=config.controlnet_method) self.enable_disable_freeu(config) + self.enable_disable_lcm_lora(config) return self.model(image, mask, config) def switch(self, new_name: str, **kwargs): @@ -137,3 +138,12 @@ class ModelManager: ) else: self.model.model.disable_freeu() + + def enable_disable_lcm_lora(self, config: Config): + if self.name in MODELS_SUPPORT_LCM_LORA: + if config.sd_lcm_lora: + if not self.model.model.pipe.get_list_adapters(): + self.model.model.load_lora_weights(self.model.lcm_lora_id) + else: + self.model.model.disable_lora() + diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index 5c8651a..d384a08 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -30,6 +30,8 @@ class SDSampler(str, Enum): dpm_plus_plus = "dpm++" uni_pc = "uni_pc" + lcm = "lcm" + class FREEUConfig(BaseModel): s1: float = 1.0 @@ -98,6 +100,9 @@ class Config(BaseModel): sd_freeu: bool = False sd_freeu_config: FREEUConfig = FREEUConfig() + # lcm-lora + sd_lcm_lora: bool = False + # Configs for opencv inpainting # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07 cv2_flag: str = "INPAINT_NS"