From f2e90d3f843e4ed4930a100d58bb571a24be4cf4 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 29 Mar 2023 22:05:34 +0800 Subject: [PATCH] add load local model --- lama_cleaner/const.py | 4 ++ lama_cleaner/installer.py | 12 ++++++ lama_cleaner/model/sd.py | 57 +++++++++++++++++++++++++---- lama_cleaner/parse_args.py | 20 ++++++++++ lama_cleaner/server.py | 5 ++- lama_cleaner/tests/test_sd_model.py | 37 +++++++++++++++++++ scripts/user_scripts/win_config.bat | 6 ++- 7 files changed, 131 insertions(+), 10 deletions(-) create mode 100644 lama_cleaner/installer.py diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py index 01df33f..cdc89d7 100644 --- a/lama_cleaner/const.py +++ b/lama_cleaner/const.py @@ -52,6 +52,10 @@ SD_CONTROLNET_HELP = """ Run Stable Diffusion 1.5 inpainting model with Canny ControlNet control. """ +SD_LOCAL_MODEL_HELP = """ +Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path. +""" + LOCAL_FILES_ONLY_HELP = """ Use local files only, not connect to Hugging Face server. (sd/paint_by_example) """ diff --git a/lama_cleaner/installer.py b/lama_cleaner/installer.py new file mode 100644 index 0000000..f255e33 --- /dev/null +++ b/lama_cleaner/installer.py @@ -0,0 +1,12 @@ +import subprocess +import sys + + +def install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + +def install_plugins_package(): + install("rembg") + install("realesrgan") + install("gfpgan") diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index 39ced29..9117dd6 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -1,3 +1,5 @@ +import gc + import PIL.Image import cv2 import numpy as np @@ -31,6 +33,37 @@ class CPUTextEncoderWrapper: return self.torch_dtype +def load_from_local_model(local_model_path, torch_dtype, disable_nsfw): + from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + load_pipeline_from_original_stable_diffusion_ckpt, + ) + from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline + + logger.info(f"Converting {local_model_path} to diffusers pipeline") + + pipe = load_pipeline_from_original_stable_diffusion_ckpt( + local_model_path, + num_in_channels=9, + from_safetensors=local_model_path.endswith("safetensors"), + device="cpu", + ) + + inpaint_pipe = StableDiffusionInpaintPipeline( + vae=pipe.vae, + text_encoder=pipe.text_encoder, + tokenizer=pipe.tokenizer, + unet=pipe.unet, + scheduler=pipe.scheduler, + safety_checker=None if disable_nsfw else pipe.safety_checker, + feature_extractor=None if disable_nsfw else pipe.safety_checker, + requires_safety_checker=not disable_nsfw, + ) + + del pipe + gc.collect() + return inpaint_pipe.to(torch_dtype) + + class SD(DiffusionInpaintModel): pad_mod = 8 min_size = 512 @@ -43,6 +76,7 @@ class SD(DiffusionInpaintModel): model_kwargs = { "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]) } + disable_nsfw = False if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): logger.info("Disable Stable Diffusion Model NSFW checker") model_kwargs.update( @@ -52,16 +86,25 @@ class SD(DiffusionInpaintModel): requires_safety_checker=False, ) ) + disable_nsfw = True use_gpu = device == torch.device("cuda") and torch.cuda.is_available() torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 - self.model = StableDiffusionInpaintPipeline.from_pretrained( - self.model_id_or_path, - revision="fp16" if use_gpu and fp16 else "main", - torch_dtype=torch_dtype, - use_auth_token=kwargs["hf_access_token"], - **model_kwargs - ) + + if kwargs.get("sd_local_model_path", None): + self.model = load_from_local_model( + kwargs["sd_local_model_path"], + torch_dtype=torch_dtype, + disable_nsfw=disable_nsfw, + ) + else: + self.model = StableDiffusionInpaintPipeline.from_pretrained( + self.model_id_or_path, + revision="fp16" if use_gpu and fp16 else "main", + torch_dtype=torch_dtype, + use_auth_token=kwargs["hf_access_token"], + **model_kwargs, + ) # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing self.model.enable_attention_slicing() diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index 6d3de56..fcaef80 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -38,6 +38,7 @@ def parse_args(): "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP ) parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP) + parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP) parser.add_argument( "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP ) @@ -112,6 +113,10 @@ def parse_args(): action="store_true", help=GIF_HELP, ) + parser.add_argument( + "--install-plugins-package", + action="store_true", + ) ######### # useless args @@ -129,6 +134,11 @@ def parse_args(): # collect system info to help debug dump_environment_info() + if args.install_plugins_package: + from lama_cleaner.installer import install_plugins_package + + install_plugins_package() + exit() if args.config_installer: if args.installer_config is None: @@ -165,6 +175,16 @@ def parse_args(): if args.model not in SD15_MODELS: logger.warning(f"--sd_controlnet only support {SD15_MODELS}") + if args.sd_local_model_path and args.model == "sd1.5": + if not os.path.exists(args.sd_local_model_path): + parser.error( + f"invalid --sd-local-model-path: {args.sd_local_model_path} not exists" + ) + if not os.path.isfile(args.sd_local_model_path): + parser.error( + f"invalid --sd-local-model-path: {args.sd_local_model_path} is a directory" + ) + os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR if args.model_dir and args.model_dir is not None: if os.path.isfile(args.model_dir): diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index d2a20d0..0e0543d 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import imghdr import io import logging import multiprocessing -import os import random import time from pathlib import Path @@ -527,6 +529,7 @@ def main(args): disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw, sd_cpu_textencoder=args.sd_cpu_textencoder, sd_run_local=args.sd_run_local, + sd_local_model_path=args.sd_local_model_path, local_files_only=args.local_files_only, cpu_offload=args.cpu_offload, enable_xformers=args.sd_enable_xformers or args.enable_xformers, diff --git a/lama_cleaner/tests/test_sd_model.py b/lama_cleaner/tests/test_sd_model.py index 87b58a9..9b80886 100644 --- a/lama_cleaner/tests/test_sd_model.py +++ b/lama_cleaner/tests/test_sd_model.py @@ -1,3 +1,6 @@ +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" from pathlib import Path import pytest @@ -202,3 +205,37 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler): mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", ) + +@pytest.mark.parametrize("sd_device", ["cuda", "mps"]) +@pytest.mark.parametrize("sampler", [SDSampler.uni_pc]) +def test_local_file_path(sd_device, sampler): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + + sd_steps = 1 if sd_device == "cpu" else 50 + model = ModelManager( + name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt", + ) + cfg = get_config( + HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}" + + assert_equal( + model, + cfg, + f"sd_local_model_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/scripts/user_scripts/win_config.bat b/scripts/user_scripts/win_config.bat index bf352ef..2a11003 100644 --- a/scripts/user_scripts/win_config.bat +++ b/scripts/user_scripts/win_config.bat @@ -6,9 +6,11 @@ set PATH=C:\Windows\System32;%PATH% @call conda-unpack -@call conda install -y cudatoolkit=11.3 -@call pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 +@call conda install -y -c conda-forge cudatoolkit=11.7 +@call pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 +@call pip install xformers @call pip3 install -U lama-cleaner +@call lama-cleaner --install-plugins-package @call lama-cleaner --config-installer --installer-config %0\..\installer_config.json