wip: controlnet
This commit is contained in:
parent
e5ac6a105a
commit
87f54bb87e
@ -53,8 +53,13 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
SD_CONTROLNET_HELP = """
|
SD_CONTROLNET_HELP = """
|
||||||
Run Stable Diffusion 1.5 inpainting model with Canny ControlNet control.
|
Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui.
|
||||||
"""
|
"""
|
||||||
|
SD_CONTROLNET_CHOICES = [
|
||||||
|
"control_v11p_sd15_canny",
|
||||||
|
"control_v11p_sd15_openpose",
|
||||||
|
"control_v11p_sd15_inpaint",
|
||||||
|
]
|
||||||
|
|
||||||
SD_LOCAL_MODEL_HELP = """
|
SD_LOCAL_MODEL_HELP = """
|
||||||
Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.
|
Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.
|
||||||
|
@ -4,9 +4,7 @@ import PIL.Image
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers import (
|
from diffusers import ControlNetModel
|
||||||
ControlNetModel,
|
|
||||||
)
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
@ -75,7 +73,7 @@ def load_from_local_model(
|
|||||||
num_in_channels=4 if is_native_control_inpaint else 9,
|
num_in_channels=4 if is_native_control_inpaint else 9,
|
||||||
from_safetensors=local_model_path.endswith("safetensors"),
|
from_safetensors=local_model_path.endswith("safetensors"),
|
||||||
device="cpu",
|
device="cpu",
|
||||||
load_safety_checker=False
|
load_safety_checker=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
inpaint_pipe = pipe_class(
|
inpaint_pipe = pipe_class(
|
||||||
@ -92,7 +90,7 @@ def load_from_local_model(
|
|||||||
|
|
||||||
del pipe
|
del pipe
|
||||||
gc.collect()
|
gc.collect()
|
||||||
return inpaint_pipe.to(torch_dtype)
|
return inpaint_pipe.to(torch_dtype=torch_dtype)
|
||||||
|
|
||||||
|
|
||||||
class ControlNet(DiffusionInpaintModel):
|
class ControlNet(DiffusionInpaintModel):
|
||||||
@ -120,6 +118,7 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||||
|
|
||||||
sd_controlnet_method = kwargs["sd_controlnet_method"]
|
sd_controlnet_method = kwargs["sd_controlnet_method"]
|
||||||
|
self.sd_controlnet_method = sd_controlnet_method
|
||||||
|
|
||||||
if sd_controlnet_method == "control_v11p_sd15_inpaint":
|
if sd_controlnet_method == "control_v11p_sd15_inpaint":
|
||||||
from diffusers import StableDiffusionControlNetPipeline as PipeClass
|
from diffusers import StableDiffusionControlNetPipeline as PipeClass
|
||||||
@ -206,18 +205,30 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
output_type="np.array",
|
output_type="np.array",
|
||||||
).images[0]
|
).images[0]
|
||||||
else:
|
else:
|
||||||
|
if "canny" in self.sd_controlnet_method:
|
||||||
canny_image = cv2.Canny(image, 100, 200)
|
canny_image = cv2.Canny(image, 100, 200)
|
||||||
canny_image = canny_image[:, :, None]
|
canny_image = canny_image[:, :, None]
|
||||||
canny_image = np.concatenate(
|
canny_image = np.concatenate(
|
||||||
[canny_image, canny_image, canny_image], axis=2
|
[canny_image, canny_image, canny_image], axis=2
|
||||||
)
|
)
|
||||||
canny_image = PIL.Image.fromarray(canny_image)
|
canny_image = PIL.Image.fromarray(canny_image)
|
||||||
|
control_image = canny_image
|
||||||
|
elif "openpose" in self.sd_controlnet_method:
|
||||||
|
from controlnet_aux import OpenposeDetector
|
||||||
|
|
||||||
|
processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
||||||
|
control_image = processor(image, hand_and_face=True)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.sd_controlnet_method} not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
|
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
|
||||||
image = PIL.Image.fromarray(image)
|
image = PIL.Image.fromarray(image)
|
||||||
|
|
||||||
output = self.model(
|
output = self.model(
|
||||||
image=image,
|
image=image,
|
||||||
control_image=canny_image,
|
control_image=control_image,
|
||||||
prompt=config.prompt,
|
prompt=config.prompt,
|
||||||
negative_prompt=config.negative_prompt,
|
negative_prompt=config.negative_prompt,
|
||||||
mask_image=mask_image,
|
mask_image=mask_image,
|
||||||
|
@ -35,13 +35,13 @@ class CPUTextEncoderWrapper:
|
|||||||
|
|
||||||
def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True):
|
def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True):
|
||||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
download_from_original_stable_diffusion_ckpt,
|
||||||
)
|
)
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
||||||
|
|
||||||
logger.info(f"Converting {local_model_path} to diffusers pipeline")
|
logger.info(f"Converting {local_model_path} to diffusers pipeline")
|
||||||
|
|
||||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
|
pipe = download_from_original_stable_diffusion_ckpt(
|
||||||
local_model_path,
|
local_model_path,
|
||||||
num_in_channels=9,
|
num_in_channels=9,
|
||||||
from_safetensors=local_model_path.endswith("safetensors"),
|
from_safetensors=local_model_path.endswith("safetensors"),
|
||||||
|
@ -12,6 +12,7 @@ from lama_cleaner.model.mat import MAT
|
|||||||
from lama_cleaner.model.paint_by_example import PaintByExample
|
from lama_cleaner.model.paint_by_example import PaintByExample
|
||||||
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
|
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
|
||||||
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
|
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
|
||||||
|
from lama_cleaner.model.utils import torch_gc
|
||||||
from lama_cleaner.model.zits import ZITS
|
from lama_cleaner.model.zits import ZITS
|
||||||
from lama_cleaner.model.opencv2 import OpenCV2
|
from lama_cleaner.model.opencv2 import OpenCV2
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
@ -59,7 +60,7 @@ class ModelManager:
|
|||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: Config):
|
||||||
return self.model(image, mask, config)
|
return self.model(image, mask, config)
|
||||||
|
|
||||||
def switch(self, new_name: str):
|
def switch(self, new_name: str, **kwargs):
|
||||||
if new_name == self.name:
|
if new_name == self.name:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
@ -75,3 +76,17 @@ class ModelManager:
|
|||||||
self.name = new_name
|
self.name = new_name
|
||||||
except NotImplementedError as e:
|
except NotImplementedError as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def switch_controlnet_method(self, control_method: str):
    """Reload the current model with a different ControlNet method.

    Does nothing when ``self.kwargs`` has no truthy ``sd_controlnet`` entry
    or when ``control_method`` is already the configured method.

    Args:
        control_method: Name of the ControlNet method to switch to; it is
            stored under ``self.kwargs["sd_controlnet_method"]`` and passed
            on to ``self.init_model``.

    Raises:
        Exception: Whatever ``self.init_model`` raises for the new method.
            Before re-raising, the previously configured method is restored
            and its model is rebuilt, so the manager is not left without a
            model and a later retry is not skipped as a no-op.
    """
    if not self.kwargs.get("sd_controlnet"):
        return
    previous_method = self.kwargs["sd_controlnet_method"]
    if previous_method == control_method:
        return

    # Drop the current model before building the next one.
    # NOTE(review): presumably torch_gc() releases cached device memory so
    # both models are never resident at once -- confirm in model/utils.
    del self.model
    torch_gc()

    device = switch_mps_device(self.name, self.device)
    self.kwargs["sd_controlnet_method"] = control_method
    try:
        self.model = self.init_model(self.name, device, **self.kwargs)
    except Exception:
        # Roll back: without this, self.model stays deleted and kwargs
        # already names the failed method.
        self.kwargs["sd_controlnet_method"] = previous_method
        self.model = self.init_model(self.name, device, **self.kwargs)
        raise
|
||||||
|
@ -41,10 +41,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sd-controlnet-method",
|
"--sd-controlnet-method",
|
||||||
default="control_v11p_sd15_inpaint",
|
default="control_v11p_sd15_inpaint",
|
||||||
choices=[
|
choices=SD_CONTROLNET_CHOICES,
|
||||||
"control_v11p_sd15_canny",
|
|
||||||
"control_v11p_sd15_inpaint",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
|
parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -436,6 +436,17 @@ def switch_model():
|
|||||||
return f"ok, switch to {new_name}", 200
|
return f"ok, switch to {new_name}", 200
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/controlnet_method", methods=["POST"])
def switch_controlnet_method():
    """Switch the ControlNet method of the loaded model.

    Reads the ``method`` field from the POSTed form data and forwards it to
    ``model.switch_controlnet_method``.

    Returns:
        A ``(message, status)`` tuple: 200 on success, 400 when the
        ``method`` field is missing or empty, 500 when the model raises
        ``NotImplementedError`` for the requested method.
    """
    new_method = request.form.get("method")
    if not new_method:
        # Without a method name there is nothing valid to switch to; reject
        # the request instead of handing None to the model manager.
        return "Missing required form field: method", 400

    try:
        model.switch_controlnet_method(new_method)
    except NotImplementedError:
        return f"Failed switch to {new_method} not implemented", 500
    return f"Switch to {new_method}", 200
|
||||||
|
|
||||||
|
|
||||||
@app.route("/")
|
@app.route("/")
|
||||||
def index():
|
def index():
|
||||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||||
|
BIN
lama_cleaner/tests/mask.png
Normal file
BIN
lama_cleaner/tests/mask.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.7 KiB |
@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from lama_cleaner.const import SD_CONTROLNET_CHOICES
|
||||||
|
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -22,7 +24,10 @@ device = torch.device(device)
|
|||||||
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
|
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
|
||||||
@pytest.mark.parametrize("cpu_textencoder", [True])
|
@pytest.mark.parametrize("cpu_textencoder", [True])
|
||||||
@pytest.mark.parametrize("disable_nsfw", [True])
|
@pytest.mark.parametrize("disable_nsfw", [True])
|
||||||
def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
|
@pytest.mark.parametrize("sd_controlnet_method", SD_CONTROLNET_CHOICES)
|
||||||
|
def test_runway_sd_1_5(
|
||||||
|
sd_device, strategy, sampler, cpu_textencoder, disable_nsfw, sd_controlnet_method
|
||||||
|
):
|
||||||
if sd_device == "cuda" and not torch.cuda.is_available():
|
if sd_device == "cuda" and not torch.cuda.is_available():
|
||||||
return
|
return
|
||||||
if device == "mps" and not torch.backends.mps.is_available():
|
if device == "mps" and not torch.backends.mps.is_available():
|
||||||
@ -34,19 +39,20 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
|
|||||||
sd_controlnet=True,
|
sd_controlnet=True,
|
||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=False,
|
||||||
disable_nsfw=disable_nsfw,
|
disable_nsfw=disable_nsfw,
|
||||||
sd_cpu_textencoder=cpu_textencoder,
|
sd_cpu_textencoder=cpu_textencoder,
|
||||||
|
sd_controlnet_method=sd_controlnet_method,
|
||||||
)
|
)
|
||||||
cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
|
cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
|
||||||
cfg.sd_sampler = sampler
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
name = f"device_{sd_device}_{sampler}_cpu_textencoder_disable_nsfw"
|
||||||
|
|
||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
cfg,
|
cfg,
|
||||||
f"sd_controlnet_{name}.png",
|
f"sd_controlnet_{sd_controlnet_method}_{name}.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
fx=1.2,
|
fx=1.2,
|
||||||
@ -68,11 +74,12 @@ def test_local_file_path(sd_device, sampler):
|
|||||||
sd_controlnet=True,
|
sd_controlnet=True,
|
||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=False,
|
||||||
disable_nsfw=True,
|
disable_nsfw=True,
|
||||||
sd_cpu_textencoder=False,
|
sd_cpu_textencoder=False,
|
||||||
cpu_offload=True,
|
cpu_offload=True,
|
||||||
sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt",
|
sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt",
|
||||||
|
sd_controlnet_method="control_v11p_sd15_canny",
|
||||||
)
|
)
|
||||||
cfg = get_config(
|
cfg = get_config(
|
||||||
HDStrategy.ORIGINAL,
|
HDStrategy.ORIGINAL,
|
||||||
@ -86,7 +93,48 @@ def test_local_file_path(sd_device, sampler):
|
|||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
cfg,
|
cfg,
|
||||||
f"sd_controlnet_local_model_{name}.png",
|
f"sd_controlnet_canny_local_model_{name}.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
    """Run sd1.5 + ControlNet inpaint method from a local safetensors file.

    Builds a ``ModelManager`` with ``sd_controlnet_method`` set to
    ``control_v11p_sd15_inpaint`` and a local ``sd_local_model_path``, then
    checks the result through ``assert_equal``. The test returns early when
    the parametrized device backend is not available.
    """
    if sd_device == "cuda" and not torch.cuda.is_available():
        return
    # Check the parametrized sd_device here, not the module-level `device`
    # object, so the guard matches the backend this case actually runs on.
    if sd_device == "mps" and not torch.backends.mps.is_available():
        return

    sd_steps = 1 if sd_device == "cpu" else 30
    model = ModelManager(
        name="sd1.5",
        sd_controlnet=True,
        device=torch.device(sd_device),
        hf_access_token="",
        sd_run_local=False,
        disable_nsfw=True,
        sd_cpu_textencoder=False,
        cpu_offload=True,
        sd_local_model_path="/Users/cwq/data/models/v1-5-pruned-emaonly.safetensors",
        sd_controlnet_method="control_v11p_sd15_inpaint",
    )
    cfg = get_config(
        HDStrategy.ORIGINAL,
        prompt="a fox sitting on a bench",
        sd_steps=sd_steps,
        controlnet_conditioning_scale=1.0,
        sd_strength=1.0,
    )
    cfg.sd_sampler = sampler

    name = f"device_{sd_device}_{sampler}"

    assert_equal(
        model,
        cfg,
        f"sd_controlnet_local_native_{name}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )
|
||||||
|
@ -20,7 +20,7 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload):
|
|||||||
model = ModelManager(name="instruct_pix2pix",
|
model = ModelManager(name="instruct_pix2pix",
|
||||||
device=torch.device(device),
|
device=torch.device(device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=False,
|
||||||
disable_nsfw=disable_nsfw,
|
disable_nsfw=disable_nsfw,
|
||||||
sd_cpu_textencoder=False,
|
sd_cpu_textencoder=False,
|
||||||
cpu_offload=cpu_offload)
|
cpu_offload=cpu_offload)
|
||||||
@ -45,7 +45,7 @@ def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
|
|||||||
model = ModelManager(name="instruct_pix2pix",
|
model = ModelManager(name="instruct_pix2pix",
|
||||||
device=torch.device(device),
|
device=torch.device(device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=False,
|
||||||
disable_nsfw=disable_nsfw,
|
disable_nsfw=disable_nsfw,
|
||||||
sd_cpu_textencoder=False,
|
sd_cpu_textencoder=False,
|
||||||
cpu_offload=cpu_offload)
|
cpu_offload=cpu_offload)
|
||||||
|
@ -14,3 +14,4 @@ gradio
|
|||||||
piexif==1.1.3
|
piexif==1.1.3
|
||||||
safetensors
|
safetensors
|
||||||
omegaconf
|
omegaconf
|
||||||
|
controlnet-aux==0.0.3
|
Loading…
Reference in New Issue
Block a user