backend add freeu
This commit is contained in:
parent
2c9a53da8e
commit
bb98c91c8c
@ -33,8 +33,10 @@ AVAILABLE_MODELS = [
|
||||
"paint_by_example",
|
||||
"instruct_pix2pix",
|
||||
"kandinsky2.2",
|
||||
"sdxl"
|
||||
]
|
||||
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
|
||||
MODELS_SUPPORT_FREEU = SD15_MODELS + ['sd2', "sdxl"]
|
||||
|
||||
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
||||
DEFAULT_DEVICE = "cuda"
|
||||
|
@ -3,7 +3,7 @@ import gc
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import SD15_MODELS
|
||||
from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU
|
||||
from lama_cleaner.helper import switch_mps_device
|
||||
from lama_cleaner.model.controlnet import ControlNet
|
||||
from lama_cleaner.model.fcf import FcF
|
||||
@ -65,6 +65,7 @@ class ModelManager:
|
||||
|
||||
def __call__(self, image, mask, config: Config):
|
||||
self.switch_controlnet_method(control_method=config.controlnet_method)
|
||||
self.enable_disable_freeu(config)
|
||||
return self.model(image, mask, config)
|
||||
|
||||
def switch(self, new_name: str, **kwargs):
|
||||
@ -120,3 +121,19 @@ class ModelManager:
|
||||
self.name, switch_mps_device(self.name, self.device), **self.kwargs
|
||||
)
|
||||
logger.info(f"Switch ControlNet method from {old_method} to {control_method}")
|
||||
|
||||
def enable_disable_freeu(self, config: Config):
|
||||
if str(self.model.device) == "mps":
|
||||
return
|
||||
|
||||
if self.name in MODELS_SUPPORT_FREEU:
|
||||
if config.sd_freeu:
|
||||
freeu_config = config.sd_freeu_config
|
||||
self.model.model.enable_freeu(
|
||||
s1=freeu_config.s1,
|
||||
s2=freeu_config.s2,
|
||||
b1=freeu_config.b1,
|
||||
b2=freeu_config.b2,
|
||||
)
|
||||
else:
|
||||
self.model.model.disable_freeu()
|
||||
|
@ -31,6 +31,13 @@ class SDSampler(str, Enum):
|
||||
uni_pc = "uni_pc"
|
||||
|
||||
|
||||
class FREEUConfig(BaseModel):
|
||||
s1: float = 1.0
|
||||
s2: float = 1.0
|
||||
b1: float = 1.0
|
||||
b2: float = 1.0
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@ -87,6 +94,10 @@ class Config(BaseModel):
|
||||
sd_outpainting_softness: float = 20.0
|
||||
sd_outpainting_space: float = 20.0
|
||||
|
||||
# freeu
|
||||
sd_freeu: bool = False
|
||||
sd_freeu_config: FREEUConfig = FREEUConfig()
|
||||
|
||||
# Configs for opencv inpainting
|
||||
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
||||
cv2_flag: str = "INPAINT_NS"
|
||||
|
@ -9,8 +9,8 @@ pydantic
|
||||
rich
|
||||
loguru
|
||||
yacs
|
||||
diffusers==0.20.1
|
||||
transformers==4.27.4
|
||||
diffusers==0.23.0
|
||||
transformers==4.34.1
|
||||
gradio
|
||||
piexif==1.1.3
|
||||
safetensors
|
||||
|
Loading…
Reference in New Issue
Block a user