add load local model

This commit is contained in:
Qing 2023-03-29 22:05:34 +08:00
parent 61df5f69b3
commit f2e90d3f84
7 changed files with 131 additions and 10 deletions

View File

@ -52,6 +52,10 @@ SD_CONTROLNET_HELP = """
Run Stable Diffusion 1.5 inpainting model with Canny ControlNet control. Run Stable Diffusion 1.5 inpainting model with Canny ControlNet control.
""" """
SD_LOCAL_MODEL_HELP = """
Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.
"""
LOCAL_FILES_ONLY_HELP = """ LOCAL_FILES_ONLY_HELP = """
Use local files only, not connect to Hugging Face server. (sd/paint_by_example) Use local files only, not connect to Hugging Face server. (sd/paint_by_example)
""" """

12
lama_cleaner/installer.py Normal file
View File

@ -0,0 +1,12 @@
import subprocess
import sys
def install(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
def install_plugins_package():
install("rembg")
install("realesrgan")
install("gfpgan")

View File

@ -1,3 +1,5 @@
import gc
import PIL.Image import PIL.Image
import cv2 import cv2
import numpy as np import numpy as np
@ -31,6 +33,37 @@ class CPUTextEncoderWrapper:
return self.torch_dtype return self.torch_dtype
def load_from_local_model(local_model_path, torch_dtype, disable_nsfw):
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
logger.info(f"Converting {local_model_path} to diffusers pipeline")
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
local_model_path,
num_in_channels=9,
from_safetensors=local_model_path.endswith("safetensors"),
device="cpu",
)
inpaint_pipe = StableDiffusionInpaintPipeline(
vae=pipe.vae,
text_encoder=pipe.text_encoder,
tokenizer=pipe.tokenizer,
unet=pipe.unet,
scheduler=pipe.scheduler,
safety_checker=None if disable_nsfw else pipe.safety_checker,
feature_extractor=None if disable_nsfw else pipe.safety_checker,
requires_safety_checker=not disable_nsfw,
)
del pipe
gc.collect()
return inpaint_pipe.to(torch_dtype)
class SD(DiffusionInpaintModel): class SD(DiffusionInpaintModel):
pad_mod = 8 pad_mod = 8
min_size = 512 min_size = 512
@ -43,6 +76,7 @@ class SD(DiffusionInpaintModel):
model_kwargs = { model_kwargs = {
"local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]) "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
} }
disable_nsfw = False
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker") logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update( model_kwargs.update(
@ -52,15 +86,24 @@ class SD(DiffusionInpaintModel):
requires_safety_checker=False, requires_safety_checker=False,
) )
) )
disable_nsfw = True
use_gpu = device == torch.device("cuda") and torch.cuda.is_available() use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
if kwargs.get("sd_local_model_path", None):
self.model = load_from_local_model(
kwargs["sd_local_model_path"],
torch_dtype=torch_dtype,
disable_nsfw=disable_nsfw,
)
else:
self.model = StableDiffusionInpaintPipeline.from_pretrained( self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path, self.model_id_or_path,
revision="fp16" if use_gpu and fp16 else "main", revision="fp16" if use_gpu and fp16 else "main",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
use_auth_token=kwargs["hf_access_token"], use_auth_token=kwargs["hf_access_token"],
**model_kwargs **model_kwargs,
) )
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing

View File

@ -38,6 +38,7 @@ def parse_args():
"--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
) )
parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP) parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
parser.add_argument( parser.add_argument(
"--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
) )
@ -112,6 +113,10 @@ def parse_args():
action="store_true", action="store_true",
help=GIF_HELP, help=GIF_HELP,
) )
parser.add_argument(
"--install-plugins-package",
action="store_true",
)
######### #########
# useless args # useless args
@ -129,6 +134,11 @@ def parse_args():
# collect system info to help debug # collect system info to help debug
dump_environment_info() dump_environment_info()
if args.install_plugins_package:
from lama_cleaner.installer import install_plugins_package
install_plugins_package()
exit()
if args.config_installer: if args.config_installer:
if args.installer_config is None: if args.installer_config is None:
@ -165,6 +175,16 @@ def parse_args():
if args.model not in SD15_MODELS: if args.model not in SD15_MODELS:
logger.warning(f"--sd_controlnet only support {SD15_MODELS}") logger.warning(f"--sd_controlnet only support {SD15_MODELS}")
if args.sd_local_model_path and args.model == "sd1.5":
if not os.path.exists(args.sd_local_model_path):
parser.error(
f"invalid --sd-local-model-path: {args.sd_local_model_path} not exists"
)
if not os.path.isfile(args.sd_local_model_path):
parser.error(
f"invalid --sd-local-model-path: {args.sd_local_model_path} is a directory"
)
os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR
if args.model_dir and args.model_dir is not None: if args.model_dir and args.model_dir is not None:
if os.path.isfile(args.model_dir): if os.path.isfile(args.model_dir):

View File

@ -1,10 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import imghdr import imghdr
import io import io
import logging import logging
import multiprocessing import multiprocessing
import os
import random import random
import time import time
from pathlib import Path from pathlib import Path
@ -527,6 +529,7 @@ def main(args):
disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw, disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
sd_cpu_textencoder=args.sd_cpu_textencoder, sd_cpu_textencoder=args.sd_cpu_textencoder,
sd_run_local=args.sd_run_local, sd_run_local=args.sd_run_local,
sd_local_model_path=args.sd_local_model_path,
local_files_only=args.local_files_only, local_files_only=args.local_files_only,
cpu_offload=args.cpu_offload, cpu_offload=args.cpu_offload,
enable_xformers=args.sd_enable_xformers or args.enable_xformers, enable_xformers=args.sd_enable_xformers or args.enable_xformers,

View File

@ -1,3 +1,6 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from pathlib import Path from pathlib import Path
import pytest import pytest
@ -202,3 +205,37 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
) )
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
def test_local_file_path(sd_device, sampler):
if sd_device == "cuda" and not torch.cuda.is_available():
return
sd_steps = 1 if sd_device == "cpu" else 50
model = ModelManager(
name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=True,
sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt",
)
cfg = get_config(
HDStrategy.ORIGINAL,
prompt="a fox sitting on a bench",
sd_steps=sd_steps,
)
cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}"
assert_equal(
model,
cfg,
f"sd_local_model_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)

View File

@ -6,9 +6,11 @@ set PATH=C:\Windows\System32;%PATH%
@call conda-unpack @call conda-unpack
@call conda install -y cudatoolkit=11.3 @call conda install -y -c conda-forge cudatoolkit=11.7
@call pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 @call pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
@call pip install xformers
@call pip3 install -U lama-cleaner @call pip3 install -U lama-cleaner
@call lama-cleaner --install-plugins-package
@call lama-cleaner --config-installer --installer-config %0\..\installer_config.json @call lama-cleaner --config-installer --installer-config %0\..\installer_config.json