This commit is contained in:
Qing 2023-12-22 14:00:30 +08:00
parent eb9764176c
commit 61d56288a5
6 changed files with 36 additions and 15 deletions

View File

@ -394,6 +394,9 @@ class DiffusionInpaintModel(InpaintModel):
def set_scheduler(self, config: Config): def set_scheduler(self, config: Config):
scheduler_config = self.model.scheduler.config scheduler_config = self.model.scheduler.config
sd_sampler = config.sd_sampler sd_sampler = config.sd_sampler
if config.sd_lcm_lora:
sd_sampler = SDSampler.lcm
logger.info(f"LCM Lora enabled, use {sd_sampler} sampler")
scheduler = get_scheduler(sd_sampler, scheduler_config) scheduler = get_scheduler(sd_sampler, scheduler_config)
self.model.scheduler = scheduler self.model.scheduler = scheduler

View File

@ -2,7 +2,7 @@ import PIL.Image
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from diffusers import ControlNetModel from diffusers import ControlNetModel, DiffusionPipeline
from loguru import logger from loguru import logger
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
@ -69,6 +69,7 @@ class ControlNet(DiffusionInpaintModel):
use_gpu = device == torch.device("cuda") and torch.cuda.is_available() use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.torch_dtype = torch_dtype
if model_info.model_type in [ if model_info.model_type in [
ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD,
@ -131,6 +132,13 @@ class ControlNet(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None) self.callback = kwargs.pop("callback", None)
def switch_controlnet_method(self, new_method: str):
self.sd_controlnet_method = new_method
controlnet = ControlNetModel.from_pretrained(
new_method, torch_dtype=self.torch_dtype, resume_download=True
).to(self.model.device)
self.model.controlnet = controlnet
def _get_control_image(self, image, mask): def _get_control_image(self, image, mask):
if "canny" in self.sd_controlnet_method: if "canny" in self.sd_controlnet_method:
control_image = make_canny_control_image(image) control_image = make_canny_control_image(image)

View File

@ -105,13 +105,18 @@ class ModelManager:
if not self.available_models[self.name].support_controlnet: if not self.available_models[self.name].support_controlnet:
return return
if self.sd_controlnet != config.controlnet_enabled or ( if (
self.sd_controlnet and self.sd_controlnet_method != config.controlnet_method self.sd_controlnet
and config.controlnet_method
and self.sd_controlnet_method != config.controlnet_method
): ):
# 可能关闭/开启 controlnet
# 可能开启了 controlnet切换 controlnet 的方法
old_sd_controlnet = self.sd_controlnet
old_sd_controlnet_method = self.sd_controlnet_method old_sd_controlnet_method = self.sd_controlnet_method
self.sd_controlnet_method = config.controlnet_method
self.model.switch_controlnet_method(config.controlnet_method)
logger.info(
f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}"
)
elif self.sd_controlnet != config.controlnet_enabled:
self.sd_controlnet = config.controlnet_enabled self.sd_controlnet = config.controlnet_enabled
self.sd_controlnet_method = config.controlnet_method self.sd_controlnet_method = config.controlnet_method
@ -120,10 +125,6 @@ class ModelManager:
) )
if not config.controlnet_enabled: if not config.controlnet_enabled:
logger.info(f"Disable controlnet") logger.info(f"Disable controlnet")
elif old_sd_controlnet_method != config.controlnet_method:
logger.info(
f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}"
)
else: else:
logger.info(f"Enable controlnet: {config.controlnet_method}") logger.info(f"Enable controlnet: {config.controlnet_method}")

View File

@ -258,6 +258,13 @@ def process():
croper_y=form["croperY"], croper_y=form["croperY"],
croper_height=form["croperHeight"], croper_height=form["croperHeight"],
croper_width=form["croperWidth"], croper_width=form["croperWidth"],
use_extender=form["useExtender"],
extender_x=form["extenderX"],
extender_y=form["extenderY"],
extender_height=form["extenderHeight"],
extender_width=form["extenderWidth"],
sd_scale=form["sdScale"], sd_scale=form["sdScale"],
sd_mask_blur=form["sdMaskBlur"], sd_mask_blur=form["sdMaskBlur"],
sd_strength=form["sdStrength"], sd_strength=form["sdStrength"],

View File

@ -17,6 +17,7 @@ save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True) save_dir.mkdir(exist_ok=True, parents=True)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device) device = torch.device(device)
model_name = "runwayml/stable-diffusion-inpainting"
@pytest.mark.parametrize("sd_device", ["cuda", "mps"]) @pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@ -35,7 +36,7 @@ def test_runway_sd_1_5(
sd_steps = 1 if sd_device == "cpu" else 30 sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager( model = ModelManager(
name="sd1.5", name=model_name,
sd_controlnet=True, sd_controlnet=True,
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
@ -83,7 +84,7 @@ def test_local_file_path(sd_device, sampler):
sd_steps = 1 if sd_device == "cpu" else 30 sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager( model = ModelManager(
name="sd1.5", name=model_name,
sd_controlnet=True, sd_controlnet=True,
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
@ -121,7 +122,7 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
sd_steps = 1 if sd_device == "cpu" else 30 sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager( model = ModelManager(
name="sd1.5", name=model_name,
sd_controlnet=True, sd_controlnet=True,
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
@ -162,7 +163,7 @@ def test_controlnet_switch(sd_device, sampler):
sd_steps = 1 if sd_device == "cpu" else 30 sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager( model = ModelManager(
name="sd1.5", name=model_name,
sd_controlnet=True, sd_controlnet=True,
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",

View File

@ -1,2 +1,3 @@
wheel wheel
twine twine
pytest-loguru