backend add freeu

This commit is contained in:
Qing 2023-11-14 22:04:16 +08:00
parent 2c9a53da8e
commit bb98c91c8c
4 changed files with 33 additions and 3 deletions

View File

@ -33,8 +33,10 @@ AVAILABLE_MODELS = [
"paint_by_example",
"instruct_pix2pix",
"kandinsky2.2",
"sdxl"
]
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
MODELS_SUPPORT_FREEU = SD15_MODELS + ['sd2', "sdxl"]
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
DEFAULT_DEVICE = "cuda"

View File

@ -3,7 +3,7 @@ import gc
from loguru import logger
from lama_cleaner.const import SD15_MODELS
from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU
from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model.controlnet import ControlNet
from lama_cleaner.model.fcf import FcF
@ -65,6 +65,7 @@ class ModelManager:
def __call__(self, image, mask, config: Config):
self.switch_controlnet_method(control_method=config.controlnet_method)
self.enable_disable_freeu(config)
return self.model(image, mask, config)
def switch(self, new_name: str, **kwargs):
@ -120,3 +121,19 @@ class ModelManager:
self.name, switch_mps_device(self.name, self.device), **self.kwargs
)
logger.info(f"Switch ControlNet method from {old_method} to {control_method}")
def enable_disable_freeu(self, config: Config):
if str(self.model.device) == "mps":
return
if self.name in MODELS_SUPPORT_FREEU:
if config.sd_freeu:
freeu_config = config.sd_freeu_config
self.model.model.enable_freeu(
s1=freeu_config.s1,
s2=freeu_config.s2,
b1=freeu_config.b1,
b2=freeu_config.b2,
)
else:
self.model.model.disable_freeu()

View File

@ -31,6 +31,13 @@ class SDSampler(str, Enum):
uni_pc = "uni_pc"
class FREEUConfig(BaseModel):
s1: float = 1.0
s2: float = 1.0
b1: float = 1.0
b2: float = 1.0
class Config(BaseModel):
class Config:
arbitrary_types_allowed = True
@ -87,6 +94,10 @@ class Config(BaseModel):
sd_outpainting_softness: float = 20.0
sd_outpainting_space: float = 20.0
# freeu
sd_freeu: bool = False
sd_freeu_config: FREEUConfig = FREEUConfig()
# Configs for opencv inpainting
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
cv2_flag: str = "INPAINT_NS"

View File

@ -9,8 +9,8 @@ pydantic
rich
loguru
yacs
diffusers==0.20.1
transformers==4.27.4
diffusers==0.23.0
transformers==4.34.1
gradio
piexif==1.1.3
safetensors