wip: controlnet
This commit is contained in:
parent
e5ac6a105a
commit
87f54bb87e
@ -53,8 +53,13 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory.
|
||||
"""
|
||||
|
||||
SD_CONTROLNET_HELP = """
|
||||
Run Stable Diffusion 1.5 inpainting model with Canny ControlNet control.
|
||||
Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui.
|
||||
"""
|
||||
SD_CONTROLNET_CHOICES = [
|
||||
"control_v11p_sd15_canny",
|
||||
"control_v11p_sd15_openpose",
|
||||
"control_v11p_sd15_inpaint",
|
||||
]
|
||||
|
||||
SD_LOCAL_MODEL_HELP = """
|
||||
Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.
|
||||
|
@ -4,9 +4,7 @@ import PIL.Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import (
|
||||
ControlNetModel,
|
||||
)
|
||||
from diffusers import ControlNetModel
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||
@ -75,7 +73,7 @@ def load_from_local_model(
|
||||
num_in_channels=4 if is_native_control_inpaint else 9,
|
||||
from_safetensors=local_model_path.endswith("safetensors"),
|
||||
device="cpu",
|
||||
load_safety_checker=False
|
||||
load_safety_checker=False,
|
||||
)
|
||||
|
||||
inpaint_pipe = pipe_class(
|
||||
@ -92,7 +90,7 @@ def load_from_local_model(
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
return inpaint_pipe.to(torch_dtype)
|
||||
return inpaint_pipe.to(torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
class ControlNet(DiffusionInpaintModel):
|
||||
@ -120,6 +118,7 @@ class ControlNet(DiffusionInpaintModel):
|
||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||
|
||||
sd_controlnet_method = kwargs["sd_controlnet_method"]
|
||||
self.sd_controlnet_method = sd_controlnet_method
|
||||
|
||||
if sd_controlnet_method == "control_v11p_sd15_inpaint":
|
||||
from diffusers import StableDiffusionControlNetPipeline as PipeClass
|
||||
@ -206,18 +205,30 @@ class ControlNet(DiffusionInpaintModel):
|
||||
output_type="np.array",
|
||||
).images[0]
|
||||
else:
|
||||
if "canny" in self.sd_controlnet_method:
|
||||
canny_image = cv2.Canny(image, 100, 200)
|
||||
canny_image = canny_image[:, :, None]
|
||||
canny_image = np.concatenate(
|
||||
[canny_image, canny_image, canny_image], axis=2
|
||||
)
|
||||
canny_image = PIL.Image.fromarray(canny_image)
|
||||
control_image = canny_image
|
||||
elif "openpose" in self.sd_controlnet_method:
|
||||
from controlnet_aux import OpenposeDetector
|
||||
|
||||
processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
||||
control_image = processor(image, hand_and_face=True)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{self.sd_controlnet_method} not implemented"
|
||||
)
|
||||
|
||||
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
|
||||
image = PIL.Image.fromarray(image)
|
||||
|
||||
output = self.model(
|
||||
image=image,
|
||||
control_image=canny_image,
|
||||
control_image=control_image,
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
mask_image=mask_image,
|
||||
|
@ -35,13 +35,13 @@ class CPUTextEncoderWrapper:
|
||||
|
||||
def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True):
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
||||
|
||||
logger.info(f"Converting {local_model_path} to diffusers pipeline")
|
||||
|
||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
local_model_path,
|
||||
num_in_channels=9,
|
||||
from_safetensors=local_model_path.endswith("safetensors"),
|
||||
|
@ -12,6 +12,7 @@ from lama_cleaner.model.mat import MAT
|
||||
from lama_cleaner.model.paint_by_example import PaintByExample
|
||||
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
|
||||
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model.zits import ZITS
|
||||
from lama_cleaner.model.opencv2 import OpenCV2
|
||||
from lama_cleaner.schema import Config
|
||||
@ -59,7 +60,7 @@ class ModelManager:
|
||||
def __call__(self, image, mask, config: Config):
|
||||
return self.model(image, mask, config)
|
||||
|
||||
def switch(self, new_name: str):
|
||||
def switch(self, new_name: str, **kwargs):
|
||||
if new_name == self.name:
|
||||
return
|
||||
try:
|
||||
@ -75,3 +76,17 @@ class ModelManager:
|
||||
self.name = new_name
|
||||
except NotImplementedError as e:
|
||||
raise e
|
||||
|
||||
def switch_controlnet_method(self, control_method: str):
|
||||
if not self.kwargs.get("sd_controlnet"):
|
||||
return
|
||||
if self.kwargs["sd_controlnet_method"] == control_method:
|
||||
return
|
||||
|
||||
del self.model
|
||||
torch_gc()
|
||||
|
||||
self.kwargs["sd_controlnet_method"] = control_method
|
||||
self.model = self.init_model(
|
||||
self.name, switch_mps_device(self.name, self.device), **self.kwargs
|
||||
)
|
||||
|
@ -41,10 +41,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--sd-controlnet-method",
|
||||
default="control_v11p_sd15_inpaint",
|
||||
choices=[
|
||||
"control_v11p_sd15_canny",
|
||||
"control_v11p_sd15_inpaint",
|
||||
],
|
||||
choices=SD_CONTROLNET_CHOICES,
|
||||
)
|
||||
parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
|
||||
parser.add_argument(
|
||||
|
@ -436,6 +436,17 @@ def switch_model():
|
||||
return f"ok, switch to {new_name}", 200
|
||||
|
||||
|
||||
@app.route("/controlnet_method", methods=["POST"])
|
||||
def switch_controlnet_method():
|
||||
new_method = request.form.get("method")
|
||||
|
||||
try:
|
||||
model.switch_controlnet_method(new_method)
|
||||
except NotImplementedError:
|
||||
return f"Failed switch to {new_method} not implemented", 500
|
||||
return f"Switch to {new_method}", 200
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||
|
BIN
lama_cleaner/tests/mask.png
Normal file
BIN
lama_cleaner/tests/mask.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.7 KiB |
@ -1,5 +1,7 @@
|
||||
import os
|
||||
|
||||
from lama_cleaner.const import SD_CONTROLNET_CHOICES
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
@ -22,7 +24,10 @@ device = torch.device(device)
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
|
||||
@pytest.mark.parametrize("cpu_textencoder", [True])
|
||||
@pytest.mark.parametrize("disable_nsfw", [True])
|
||||
def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
|
||||
@pytest.mark.parametrize("sd_controlnet_method", SD_CONTROLNET_CHOICES)
|
||||
def test_runway_sd_1_5(
|
||||
sd_device, strategy, sampler, cpu_textencoder, disable_nsfw, sd_controlnet_method
|
||||
):
|
||||
if sd_device == "cuda" and not torch.cuda.is_available():
|
||||
return
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
@ -34,19 +39,20 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
|
||||
sd_controlnet=True,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_run_local=False,
|
||||
disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=cpu_textencoder,
|
||||
sd_controlnet_method=sd_controlnet_method,
|
||||
)
|
||||
cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
||||
name = f"device_{sd_device}_{sampler}_cpu_textencoder_disable_nsfw"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_controlnet_{name}.png",
|
||||
f"sd_controlnet_{sd_controlnet_method}_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1.2,
|
||||
@ -68,11 +74,12 @@ def test_local_file_path(sd_device, sampler):
|
||||
sd_controlnet=True,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_run_local=False,
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True,
|
||||
sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt",
|
||||
sd_controlnet_method="control_v11p_sd15_canny",
|
||||
)
|
||||
cfg = get_config(
|
||||
HDStrategy.ORIGINAL,
|
||||
@ -86,7 +93,48 @@ def test_local_file_path(sd_device, sampler):
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_controlnet_local_model_{name}.png",
|
||||
f"sd_controlnet_canny_local_model_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
|
||||
def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
|
||||
if sd_device == "cuda" and not torch.cuda.is_available():
|
||||
return
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
return
|
||||
|
||||
sd_steps = 1 if sd_device == "cpu" else 30
|
||||
model = ModelManager(
|
||||
name="sd1.5",
|
||||
sd_controlnet=True,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=False,
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True,
|
||||
sd_local_model_path="/Users/cwq/data/models/v1-5-pruned-emaonly.safetensors",
|
||||
sd_controlnet_method="control_v11p_sd15_inpaint",
|
||||
)
|
||||
cfg = get_config(
|
||||
HDStrategy.ORIGINAL,
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
controlnet_conditioning_scale=1.0,
|
||||
sd_strength=1.0
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
name = f"device_{sd_device}_{sampler}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_controlnet_local_native_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
@ -20,7 +20,7 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload):
|
||||
model = ModelManager(name="instruct_pix2pix",
|
||||
device=torch.device(device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_run_local=False,
|
||||
disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=cpu_offload)
|
||||
@ -45,7 +45,7 @@ def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
|
||||
model = ModelManager(name="instruct_pix2pix",
|
||||
device=torch.device(device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_run_local=False,
|
||||
disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=cpu_offload)
|
||||
|
@ -14,3 +14,4 @@ gradio
|
||||
piexif==1.1.3
|
||||
safetensors
|
||||
omegaconf
|
||||
controlnet-aux==0.0.3
|
Loading…
Reference in New Issue
Block a user