update
This commit is contained in:
parent
eb9764176c
commit
61d56288a5
@ -394,6 +394,9 @@ class DiffusionInpaintModel(InpaintModel):
|
|||||||
def set_scheduler(self, config: Config):
|
def set_scheduler(self, config: Config):
|
||||||
scheduler_config = self.model.scheduler.config
|
scheduler_config = self.model.scheduler.config
|
||||||
sd_sampler = config.sd_sampler
|
sd_sampler = config.sd_sampler
|
||||||
|
if config.sd_lcm_lora:
|
||||||
|
sd_sampler = SDSampler.lcm
|
||||||
|
logger.info(f"LCM Lora enabled, use {sd_sampler} sampler")
|
||||||
scheduler = get_scheduler(sd_sampler, scheduler_config)
|
scheduler = get_scheduler(sd_sampler, scheduler_config)
|
||||||
self.model.scheduler = scheduler
|
self.model.scheduler = scheduler
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import PIL.Image
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers import ControlNetModel
|
from diffusers import ControlNetModel, DiffusionPipeline
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
|
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
|
||||||
@ -69,6 +69,7 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
|
|
||||||
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
||||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
|
||||||
if model_info.model_type in [
|
if model_info.model_type in [
|
||||||
ModelType.DIFFUSERS_SD,
|
ModelType.DIFFUSERS_SD,
|
||||||
@ -131,6 +132,13 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
|
|
||||||
self.callback = kwargs.pop("callback", None)
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
|
def switch_controlnet_method(self, new_method: str):
|
||||||
|
self.sd_controlnet_method = new_method
|
||||||
|
controlnet = ControlNetModel.from_pretrained(
|
||||||
|
new_method, torch_dtype=self.torch_dtype, resume_download=True
|
||||||
|
).to(self.model.device)
|
||||||
|
self.model.controlnet = controlnet
|
||||||
|
|
||||||
def _get_control_image(self, image, mask):
|
def _get_control_image(self, image, mask):
|
||||||
if "canny" in self.sd_controlnet_method:
|
if "canny" in self.sd_controlnet_method:
|
||||||
control_image = make_canny_control_image(image)
|
control_image = make_canny_control_image(image)
|
||||||
|
@ -105,13 +105,18 @@ class ModelManager:
|
|||||||
if not self.available_models[self.name].support_controlnet:
|
if not self.available_models[self.name].support_controlnet:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.sd_controlnet != config.controlnet_enabled or (
|
if (
|
||||||
self.sd_controlnet and self.sd_controlnet_method != config.controlnet_method
|
self.sd_controlnet
|
||||||
|
and config.controlnet_method
|
||||||
|
and self.sd_controlnet_method != config.controlnet_method
|
||||||
):
|
):
|
||||||
# 可能关闭/开启 controlnet
|
|
||||||
# 可能开启了 controlnet,切换 controlnet 的方法
|
|
||||||
old_sd_controlnet = self.sd_controlnet
|
|
||||||
old_sd_controlnet_method = self.sd_controlnet_method
|
old_sd_controlnet_method = self.sd_controlnet_method
|
||||||
|
self.sd_controlnet_method = config.controlnet_method
|
||||||
|
self.model.switch_controlnet_method(config.controlnet_method)
|
||||||
|
logger.info(
|
||||||
|
f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}"
|
||||||
|
)
|
||||||
|
elif self.sd_controlnet != config.controlnet_enabled:
|
||||||
self.sd_controlnet = config.controlnet_enabled
|
self.sd_controlnet = config.controlnet_enabled
|
||||||
self.sd_controlnet_method = config.controlnet_method
|
self.sd_controlnet_method = config.controlnet_method
|
||||||
|
|
||||||
@ -120,10 +125,6 @@ class ModelManager:
|
|||||||
)
|
)
|
||||||
if not config.controlnet_enabled:
|
if not config.controlnet_enabled:
|
||||||
logger.info(f"Disable controlnet")
|
logger.info(f"Disable controlnet")
|
||||||
elif old_sd_controlnet_method != config.controlnet_method:
|
|
||||||
logger.info(
|
|
||||||
f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
||||||
|
|
||||||
|
@ -258,6 +258,13 @@ def process():
|
|||||||
croper_y=form["croperY"],
|
croper_y=form["croperY"],
|
||||||
croper_height=form["croperHeight"],
|
croper_height=form["croperHeight"],
|
||||||
croper_width=form["croperWidth"],
|
croper_width=form["croperWidth"],
|
||||||
|
|
||||||
|
use_extender=form["useExtender"],
|
||||||
|
extender_x=form["extenderX"],
|
||||||
|
extender_y=form["extenderY"],
|
||||||
|
extender_height=form["extenderHeight"],
|
||||||
|
extender_width=form["extenderWidth"],
|
||||||
|
|
||||||
sd_scale=form["sdScale"],
|
sd_scale=form["sdScale"],
|
||||||
sd_mask_blur=form["sdMaskBlur"],
|
sd_mask_blur=form["sdMaskBlur"],
|
||||||
sd_strength=form["sdStrength"],
|
sd_strength=form["sdStrength"],
|
||||||
|
@ -17,6 +17,7 @@ save_dir = current_dir / "result"
|
|||||||
save_dir.mkdir(exist_ok=True, parents=True)
|
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
model_name = "runwayml/stable-diffusion-inpainting"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
|
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
|
||||||
@ -35,7 +36,7 @@ def test_runway_sd_1_5(
|
|||||||
|
|
||||||
sd_steps = 1 if sd_device == "cpu" else 30
|
sd_steps = 1 if sd_device == "cpu" else 30
|
||||||
model = ModelManager(
|
model = ModelManager(
|
||||||
name="sd1.5",
|
name=model_name,
|
||||||
sd_controlnet=True,
|
sd_controlnet=True,
|
||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
@ -83,7 +84,7 @@ def test_local_file_path(sd_device, sampler):
|
|||||||
|
|
||||||
sd_steps = 1 if sd_device == "cpu" else 30
|
sd_steps = 1 if sd_device == "cpu" else 30
|
||||||
model = ModelManager(
|
model = ModelManager(
|
||||||
name="sd1.5",
|
name=model_name,
|
||||||
sd_controlnet=True,
|
sd_controlnet=True,
|
||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
@ -121,7 +122,7 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
|
|||||||
|
|
||||||
sd_steps = 1 if sd_device == "cpu" else 30
|
sd_steps = 1 if sd_device == "cpu" else 30
|
||||||
model = ModelManager(
|
model = ModelManager(
|
||||||
name="sd1.5",
|
name=model_name,
|
||||||
sd_controlnet=True,
|
sd_controlnet=True,
|
||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
@ -162,7 +163,7 @@ def test_controlnet_switch(sd_device, sampler):
|
|||||||
|
|
||||||
sd_steps = 1 if sd_device == "cpu" else 30
|
sd_steps = 1 if sd_device == "cpu" else 30
|
||||||
model = ModelManager(
|
model = ModelManager(
|
||||||
name="sd1.5",
|
name=model_name,
|
||||||
sd_controlnet=True,
|
sd_controlnet=True,
|
||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
|
@ -1,2 +1,3 @@
|
|||||||
wheel
|
wheel
|
||||||
twine
|
twine
|
||||||
|
pytest-loguru
|
Loading…
Reference in New Issue
Block a user