backend add freeu
This commit is contained in:
parent
2c9a53da8e
commit
bb98c91c8c
@ -33,8 +33,10 @@ AVAILABLE_MODELS = [
|
|||||||
"paint_by_example",
|
"paint_by_example",
|
||||||
"instruct_pix2pix",
|
"instruct_pix2pix",
|
||||||
"kandinsky2.2",
|
"kandinsky2.2",
|
||||||
|
"sdxl"
|
||||||
]
|
]
|
||||||
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
|
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
|
||||||
|
MODELS_SUPPORT_FREEU = SD15_MODELS + ['sd2', "sdxl"]
|
||||||
|
|
||||||
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
||||||
DEFAULT_DEVICE = "cuda"
|
DEFAULT_DEVICE = "cuda"
|
||||||
|
@ -3,7 +3,7 @@ import gc
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.const import SD15_MODELS
|
from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU
|
||||||
from lama_cleaner.helper import switch_mps_device
|
from lama_cleaner.helper import switch_mps_device
|
||||||
from lama_cleaner.model.controlnet import ControlNet
|
from lama_cleaner.model.controlnet import ControlNet
|
||||||
from lama_cleaner.model.fcf import FcF
|
from lama_cleaner.model.fcf import FcF
|
||||||
@ -65,6 +65,7 @@ class ModelManager:
|
|||||||
|
|
||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: Config):
|
||||||
self.switch_controlnet_method(control_method=config.controlnet_method)
|
self.switch_controlnet_method(control_method=config.controlnet_method)
|
||||||
|
self.enable_disable_freeu(config)
|
||||||
return self.model(image, mask, config)
|
return self.model(image, mask, config)
|
||||||
|
|
||||||
def switch(self, new_name: str, **kwargs):
|
def switch(self, new_name: str, **kwargs):
|
||||||
@ -120,3 +121,19 @@ class ModelManager:
|
|||||||
self.name, switch_mps_device(self.name, self.device), **self.kwargs
|
self.name, switch_mps_device(self.name, self.device), **self.kwargs
|
||||||
)
|
)
|
||||||
logger.info(f"Switch ControlNet method from {old_method} to {control_method}")
|
logger.info(f"Switch ControlNet method from {old_method} to {control_method}")
|
||||||
|
|
||||||
|
def enable_disable_freeu(self, config: Config):
|
||||||
|
if str(self.model.device) == "mps":
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.name in MODELS_SUPPORT_FREEU:
|
||||||
|
if config.sd_freeu:
|
||||||
|
freeu_config = config.sd_freeu_config
|
||||||
|
self.model.model.enable_freeu(
|
||||||
|
s1=freeu_config.s1,
|
||||||
|
s2=freeu_config.s2,
|
||||||
|
b1=freeu_config.b1,
|
||||||
|
b2=freeu_config.b2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model.model.disable_freeu()
|
||||||
|
@ -31,6 +31,13 @@ class SDSampler(str, Enum):
|
|||||||
uni_pc = "uni_pc"
|
uni_pc = "uni_pc"
|
||||||
|
|
||||||
|
|
||||||
|
class FREEUConfig(BaseModel):
|
||||||
|
s1: float = 1.0
|
||||||
|
s2: float = 1.0
|
||||||
|
b1: float = 1.0
|
||||||
|
b2: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
@ -87,6 +94,10 @@ class Config(BaseModel):
|
|||||||
sd_outpainting_softness: float = 20.0
|
sd_outpainting_softness: float = 20.0
|
||||||
sd_outpainting_space: float = 20.0
|
sd_outpainting_space: float = 20.0
|
||||||
|
|
||||||
|
# freeu
|
||||||
|
sd_freeu: bool = False
|
||||||
|
sd_freeu_config: FREEUConfig = FREEUConfig()
|
||||||
|
|
||||||
# Configs for opencv inpainting
|
# Configs for opencv inpainting
|
||||||
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
||||||
cv2_flag: str = "INPAINT_NS"
|
cv2_flag: str = "INPAINT_NS"
|
||||||
|
@ -9,8 +9,8 @@ pydantic
|
|||||||
rich
|
rich
|
||||||
loguru
|
loguru
|
||||||
yacs
|
yacs
|
||||||
diffusers==0.20.1
|
diffusers==0.23.0
|
||||||
transformers==4.27.4
|
transformers==4.34.1
|
||||||
gradio
|
gradio
|
||||||
piexif==1.1.3
|
piexif==1.1.3
|
||||||
safetensors
|
safetensors
|
||||||
|
Loading…
Reference in New Issue
Block a user