diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py index 809dc99..233f291 100644 --- a/lama_cleaner/const.py +++ b/lama_cleaner/const.py @@ -33,8 +33,10 @@ AVAILABLE_MODELS = [ "paint_by_example", "instruct_pix2pix", "kandinsky2.2", + "sdxl" ] SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"] +MODELS_SUPPORT_FREEU = SD15_MODELS + ['sd2', "sdxl"] AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] DEFAULT_DEVICE = "cuda" diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 7d5ed46..7b5278d 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -3,7 +3,7 @@ import gc from loguru import logger -from lama_cleaner.const import SD15_MODELS +from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU from lama_cleaner.helper import switch_mps_device from lama_cleaner.model.controlnet import ControlNet from lama_cleaner.model.fcf import FcF @@ -65,6 +65,7 @@ class ModelManager: def __call__(self, image, mask, config: Config): self.switch_controlnet_method(control_method=config.controlnet_method) + self.enable_disable_freeu(config) return self.model(image, mask, config) def switch(self, new_name: str, **kwargs): @@ -120,3 +121,19 @@ class ModelManager: self.name, switch_mps_device(self.name, self.device), **self.kwargs ) logger.info(f"Switch ControlNet method from {old_method} to {control_method}") + + def enable_disable_freeu(self, config: Config): + if str(self.model.device) == "mps": + return + + if self.name in MODELS_SUPPORT_FREEU: + if config.sd_freeu: + freeu_config = config.sd_freeu_config + self.model.model.enable_freeu( + s1=freeu_config.s1, + s2=freeu_config.s2, + b1=freeu_config.b1, + b2=freeu_config.b2, + ) + else: + self.model.model.disable_freeu() diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index 1893e5a..5c8651a 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -31,6 +31,13 @@ class SDSampler(str, Enum): uni_pc = "uni_pc" +class FREEUConfig(BaseModel): + s1: float = 1.0 + s2: float = 1.0 + b1: float = 1.0 + b2: float = 1.0 + + class Config(BaseModel): class Config: arbitrary_types_allowed = True @@ -87,6 +94,10 @@ class Config(BaseModel): sd_outpainting_softness: float = 20.0 sd_outpainting_space: float = 20.0 + # freeu + sd_freeu: bool = False + sd_freeu_config: FREEUConfig = FREEUConfig() + # Configs for opencv inpainting # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07 cv2_flag: str = "INPAINT_NS" diff --git a/requirements.txt b/requirements.txt index 8d35bf3..016a179 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,8 +9,8 @@ pydantic rich loguru yacs -diffusers==0.20.1 -transformers==4.27.4 +diffusers==0.23.0 +transformers==4.34.1 gradio piexif==1.1.3 safetensors