add lcm lora
This commit is contained in:
parent
bb98c91c8c
commit
58b931fdb2
@ -37,6 +37,7 @@ AVAILABLE_MODELS = [
|
|||||||
]
|
]
|
||||||
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
|
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
|
||||||
MODELS_SUPPORT_FREEU = SD15_MODELS + ['sd2', "sdxl"]
|
MODELS_SUPPORT_FREEU = SD15_MODELS + ['sd2', "sdxl"]
|
||||||
|
MODELS_SUPPORT_LCM_LORA = SD15_MODELS + ["sdxl"]
|
||||||
|
|
||||||
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
||||||
DEFAULT_DEVICE = "cuda"
|
DEFAULT_DEVICE = "cuda"
|
||||||
|
@ -13,7 +13,8 @@ from lama_cleaner.helper import (
|
|||||||
switch_mps_device,
|
switch_mps_device,
|
||||||
)
|
)
|
||||||
from lama_cleaner.model.g_diffuser_bot import expand_image, np_img_grey_to_rgb
|
from lama_cleaner.model.g_diffuser_bot import expand_image, np_img_grey_to_rgb
|
||||||
from lama_cleaner.schema import Config, HDStrategy
|
from lama_cleaner.model.utils import get_scheduler
|
||||||
|
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
||||||
|
|
||||||
|
|
||||||
class InpaintModel:
|
class InpaintModel:
|
||||||
@ -381,3 +382,11 @@ class DiffusionInpaintModel(InpaintModel):
|
|||||||
# original_pixel_indices
|
# original_pixel_indices
|
||||||
# ]
|
# ]
|
||||||
return inpaint_result
|
return inpaint_result
|
||||||
|
|
||||||
|
def set_scheduler(self, config: Config):
|
||||||
|
scheduler_config = self.model.scheduler.config
|
||||||
|
sd_sampler = config.sd_sampler
|
||||||
|
if config.sd_lcm_lora:
|
||||||
|
sd_sampler = SDSampler.lcm
|
||||||
|
scheduler = get_scheduler(sd_sampler, scheduler_config)
|
||||||
|
self.model.scheduler = scheduler
|
||||||
|
@ -8,7 +8,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.model.utils import torch_gc, get_scheduler
|
from lama_cleaner.model.utils import torch_gc, get_scheduler
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config, SDSampler
|
||||||
|
|
||||||
|
|
||||||
class CPUTextEncoderWrapper:
|
class CPUTextEncoderWrapper:
|
||||||
@ -67,6 +67,7 @@ def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True):
|
|||||||
class SD(DiffusionInpaintModel):
|
class SD(DiffusionInpaintModel):
|
||||||
pad_mod = 8
|
pad_mod = 8
|
||||||
min_size = 512
|
min_size = 512
|
||||||
|
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||||
|
|
||||||
def init_model(self, device: torch.device, **kwargs):
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
||||||
@ -129,10 +130,7 @@ class SD(DiffusionInpaintModel):
|
|||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
|
self.set_scheduler(config)
|
||||||
scheduler_config = self.model.scheduler.config
|
|
||||||
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
|
|
||||||
self.model.scheduler = scheduler
|
|
||||||
|
|
||||||
if config.sd_mask_blur != 0:
|
if config.sd_mask_blur != 0:
|
||||||
k = 2 * config.sd_mask_blur + 1
|
k = 2 * config.sd_mask_blur + 1
|
||||||
|
@ -13,6 +13,7 @@ class SDXL(DiffusionInpaintModel):
|
|||||||
name = "sdxl"
|
name = "sdxl"
|
||||||
pad_mod = 8
|
pad_mod = 8
|
||||||
min_size = 512
|
min_size = 512
|
||||||
|
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
|
||||||
|
|
||||||
def init_model(self, device: torch.device, **kwargs):
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
from diffusers.pipelines import AutoPipelineForInpainting
|
from diffusers.pipelines import AutoPipelineForInpainting
|
||||||
@ -56,10 +57,7 @@ class SDXL(DiffusionInpaintModel):
|
|||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
|
self.set_scheduler(config)
|
||||||
scheduler_config = self.model.scheduler.config
|
|
||||||
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
|
|
||||||
self.model.scheduler = scheduler
|
|
||||||
|
|
||||||
if config.sd_mask_blur != 0:
|
if config.sd_mask_blur != 0:
|
||||||
k = 2 * config.sd_mask_blur + 1
|
k = 2 * config.sd_mask_blur + 1
|
||||||
@ -80,7 +78,7 @@ class SDXL(DiffusionInpaintModel):
|
|||||||
height=img_h,
|
height=img_h,
|
||||||
width=img_w,
|
width=img_w,
|
||||||
generator=torch.manual_seed(config.sd_seed),
|
generator=torch.manual_seed(config.sd_seed),
|
||||||
callback_steps=1
|
callback_steps=1,
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
output = (output * 255).round().astype("uint8")
|
output = (output * 255).round().astype("uint8")
|
||||||
|
@ -3,7 +3,7 @@ import gc
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU
|
from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU, MODELS_SUPPORT_LCM_LORA
|
||||||
from lama_cleaner.helper import switch_mps_device
|
from lama_cleaner.helper import switch_mps_device
|
||||||
from lama_cleaner.model.controlnet import ControlNet
|
from lama_cleaner.model.controlnet import ControlNet
|
||||||
from lama_cleaner.model.fcf import FcF
|
from lama_cleaner.model.fcf import FcF
|
||||||
@ -66,6 +66,7 @@ class ModelManager:
|
|||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: Config):
|
||||||
self.switch_controlnet_method(control_method=config.controlnet_method)
|
self.switch_controlnet_method(control_method=config.controlnet_method)
|
||||||
self.enable_disable_freeu(config)
|
self.enable_disable_freeu(config)
|
||||||
|
self.enable_disable_lcm_lora(config)
|
||||||
return self.model(image, mask, config)
|
return self.model(image, mask, config)
|
||||||
|
|
||||||
def switch(self, new_name: str, **kwargs):
|
def switch(self, new_name: str, **kwargs):
|
||||||
@ -137,3 +138,12 @@ class ModelManager:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.model.model.disable_freeu()
|
self.model.model.disable_freeu()
|
||||||
|
|
||||||
|
def enable_disable_lcm_lora(self, config: Config):
|
||||||
|
if self.name in MODELS_SUPPORT_LCM_LORA:
|
||||||
|
if config.sd_lcm_lora:
|
||||||
|
if not self.model.model.pipe.get_list_adapters():
|
||||||
|
self.model.model.load_lora_weights(self.model.lcm_lora_id)
|
||||||
|
else:
|
||||||
|
self.model.model.disable_lora()
|
||||||
|
|
||||||
|
@ -30,6 +30,8 @@ class SDSampler(str, Enum):
|
|||||||
dpm_plus_plus = "dpm++"
|
dpm_plus_plus = "dpm++"
|
||||||
uni_pc = "uni_pc"
|
uni_pc = "uni_pc"
|
||||||
|
|
||||||
|
lcm = "lcm"
|
||||||
|
|
||||||
|
|
||||||
class FREEUConfig(BaseModel):
|
class FREEUConfig(BaseModel):
|
||||||
s1: float = 1.0
|
s1: float = 1.0
|
||||||
@ -98,6 +100,9 @@ class Config(BaseModel):
|
|||||||
sd_freeu: bool = False
|
sd_freeu: bool = False
|
||||||
sd_freeu_config: FREEUConfig = FREEUConfig()
|
sd_freeu_config: FREEUConfig = FREEUConfig()
|
||||||
|
|
||||||
|
# lcm-lora
|
||||||
|
sd_lcm_lora: bool = False
|
||||||
|
|
||||||
# Configs for opencv inpainting
|
# Configs for opencv inpainting
|
||||||
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
||||||
cv2_flag: str = "INPAINT_NS"
|
cv2_flag: str = "INPAINT_NS"
|
||||||
|
Loading…
Reference in New Issue
Block a user