controlnet support load local ckpt
This commit is contained in:
parent
5fd253b07e
commit
65f12b490a
@ -8,6 +8,7 @@ MPS_SUPPORT_MODELS = [
|
||||
"realisticVision1.4",
|
||||
"sd2",
|
||||
"paint_by_example",
|
||||
"controlnet"
|
||||
]
|
||||
|
||||
DEFAULT_MODEL = "lama"
|
||||
|
@ -1,3 +1,5 @@
|
||||
import gc
|
||||
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -41,6 +43,38 @@ NAMES_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def load_from_local_model(local_model_path, torch_dtype, controlnet):
    """Build a ControlNet inpainting pipeline from a local checkpoint file.

    The checkpoint is first converted into a diffusers pipeline on the CPU;
    its components are then re-assembled into a
    ``StableDiffusionControlNetInpaintPipeline`` around ``controlnet``.

    Args:
        local_model_path: Location of the original checkpoint, as ``str`` or
            a path-like object. A name ending in ``safetensors`` is loaded
            as a safetensors file.
        torch_dtype: dtype the returned pipeline is converted to.
        controlnet: ControlNet model handed to the new pipeline unchanged.

    Returns:
        The assembled pipeline after ``.to(torch_dtype)``. The safety checker
        and feature extractor are both disabled.
    """
    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        load_pipeline_from_original_stable_diffusion_ckpt,
    )
    from .pipeline import StableDiffusionControlNetInpaintPipeline

    # Normalise once so pathlib.Path arguments work too: ``endswith`` below
    # exists only on str.
    local_model_path = str(local_model_path)

    logger.info(f"Converting {local_model_path} to diffusers controlnet pipeline")

    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
        local_model_path,
        # NOTE(review): 9 input channels is the layout used by Stable
        # Diffusion inpainting UNets - confirm the checkpoint is one.
        num_in_channels=9,
        from_safetensors=local_model_path.endswith("safetensors"),
        device="cpu",
    )

    inpaint_pipe = StableDiffusionControlNetInpaintPipeline(
        vae=pipe.vae,
        text_encoder=pipe.text_encoder,
        tokenizer=pipe.tokenizer,
        unet=pipe.unet,
        controlnet=controlnet,
        scheduler=pipe.scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
    )

    # Drop the intermediate pipeline object before returning; its components
    # are still referenced by ``inpaint_pipe``.
    del pipe
    gc.collect()
    return inpaint_pipe.to(torch_dtype)
|
||||
|
||||
|
||||
class ControlNet(DiffusionInpaintModel):
|
||||
name = "controlnet"
|
||||
pad_mod = 8
|
||||
@ -71,13 +105,20 @@ class ControlNet(DiffusionInpaintModel):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
f"lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype
|
||||
)
|
||||
self.model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
model_id,
|
||||
controlnet=controlnet,
|
||||
revision="fp16" if use_gpu and fp16 else "main",
|
||||
torch_dtype=torch_dtype,
|
||||
**model_kwargs,
|
||||
)
|
||||
if kwargs.get("sd_local_model_path", None):
|
||||
self.model = load_from_local_model(
|
||||
kwargs["sd_local_model_path"],
|
||||
torch_dtype=torch_dtype,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
else:
|
||||
self.model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
model_id,
|
||||
controlnet=controlnet,
|
||||
revision="fp16" if use_gpu and fp16 else "main",
|
||||
torch_dtype=torch_dtype,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||
self.model.enable_attention_slicing()
|
||||
|
92
lama_cleaner/tests/test_controlnet.py
Normal file
92
lama_cleaner/tests/test_controlnet.py
Normal file
@ -0,0 +1,92 @@
|
||||
import os
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import HDStrategy, SDSampler
|
||||
from lama_cleaner.tests.test_model import get_config, assert_equal
|
||||
|
||||
# Directory that holds this test module.
current_dir = Path(__file__).parent.absolute().resolve()

# Output directory for test results; created on import if missing.
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)

# Module-wide default device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
@pytest.mark.parametrize("cpu_textencoder", [True])
@pytest.mark.parametrize("disable_nsfw", [True])
def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
    """Run the "sd1.5" model with ControlNet enabled on each accelerator.

    The case is skipped when the requested backend is not available on the
    current machine.
    """
    # Guard on the parametrized ``sd_device``, not the module-level default
    # device, so that the MPS case is really skipped on machines without MPS.
    if sd_device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    if sd_device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("MPS is not available")

    sd_steps = 1 if sd_device == "cpu" else 30
    model = ModelManager(
        name="sd1.5",
        sd_controlnet=True,
        device=torch.device(sd_device),
        hf_access_token="",
        sd_run_local=True,
        disable_nsfw=disable_nsfw,
        sd_cpu_textencoder=cpu_textencoder,
    )
    cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
    cfg.sd_sampler = sampler

    name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"

    assert_equal(
        model,
        cfg,
        f"sd_controlnet_{name}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=1.2,
        fy=1.2,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
def test_local_file_path(sd_device, sampler):
    """Run ControlNet inpainting with weights loaded from a local checkpoint.

    The case is skipped when the requested backend is unavailable or when
    the local checkpoint file does not exist on this machine.
    """
    # Guard on the parametrized ``sd_device``, not the module-level default
    # device, so that the MPS case is really skipped on machines without MPS.
    if sd_device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    if sd_device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("MPS is not available")

    # Developer-local checkpoint; skip rather than fail on other machines.
    local_model_path = "/Users/cwq/data/models/sd-v1-5-inpainting.ckpt"
    if not os.path.exists(local_model_path):
        pytest.skip(f"local checkpoint not found: {local_model_path}")

    sd_steps = 1 if sd_device == "cpu" else 30
    model = ModelManager(
        name="sd1.5",
        sd_controlnet=True,
        device=torch.device(sd_device),
        hf_access_token="",
        sd_run_local=True,
        disable_nsfw=True,
        sd_cpu_textencoder=False,
        cpu_offload=True,
        sd_local_model_path=local_model_path,
    )
    cfg = get_config(
        HDStrategy.ORIGINAL,
        prompt="a fox sitting on a bench",
        sd_steps=sd_steps,
    )
    cfg.sd_sampler = sampler

    name = f"device_{sd_device}_{sampler}"

    assert_equal(
        model,
        cfg,
        f"sd_controlnet_local_model_{name}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )
|
Loading…
Reference in New Issue
Block a user