Merge branch 'update_sd'

This commit is contained in:
Qing 2022-09-30 22:45:11 +08:00
commit 7cb6c7c827
8 changed files with 126 additions and 24 deletions

View File

@ -57,16 +57,19 @@ lama-cleaner --model=lama --device=cpu --port=8080
Available arguments:
| Name | Description | Default |
| ----------------- | -------------------------------------------------------------------------------------------------------- | -------- |
| --model | lama/ldm/zits/mat/fcf/sd. See details in [Inpaint Model](#inpainting-model) | lama |
| --hf_access_token | stable-diffusion(sd) model need huggingface access token https://huggingface.co/docs/hub/security-tokens | |
| --device | cuda or cpu | cuda |
| --port | Port for backend flask web server | 8080 |
| --gui | Launch lama-cleaner as a desktop application | |
| --gui_size | Set the window size for the application | 1200 900 |
| --input | Path to image you want to load by default | None |
| --debug | Enable debug mode for flask web server | |
| Name | Description | Default |
|-------------------|-------------------------------------------------------------------------------------------------------------------------------| -------- |
| --model | lama/ldm/zits/mat/fcf/sd1.4 See details in [Inpaint Model](#inpainting-model) | lama |
| --hf_access_token | stable-diffusion(sd) model need [huggingface access token](https://huggingface.co/docs/hub/security-tokens) to download model | |
| --sd-run-local | Once the model as downloaded, you can pass this arg and remove `--hf_access_token` | |
| --sd-disable-nsfw | Disable stable-diffusion NSFW checker. | |
| --sd-cpu-textencoder | Always run stable-diffusion TextEncoder model on CPU. | |
| --device | cuda or cpu | cuda |
| --port | Port for backend flask web server | 8080 |
| --gui | Launch lama-cleaner as a desktop application | |
| --gui_size | Set the window size for the application | 1200 900 |
| --input | Path to image you want to load by default | None |
| --debug | Enable debug mode for flask web server | |
## Inpainting Model

View File

@ -6,6 +6,7 @@ import numpy as np
import torch
from diffusers import PNDMScheduler, DDIMScheduler
from loguru import logger
from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
from lama_cleaner.helper import norm_img
@ -38,23 +39,61 @@ from lama_cleaner.schema import Config, SDSampler
# mask = torch.from_numpy(mask)
# return mask
class DummyFeatureExtractorOutput:
def __init__(self, pixel_values):
self.pixel_values = pixel_values
def to(self, device):
return self
class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __call__(self, *args, **kwargs):
return DummyFeatureExtractorOutput(torch.empty(0, 3))
class DummySafetyChecker:
def __init__(self, *args, **kwargs):
pass
def __call__(self, clip_input, images):
return images, False
class SD(InpaintModel):
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
min_size = 512
def init_model(self, device: torch.device, **kwargs):
from .sd_pipeline import StableDiffusionInpaintPipeline
model_kwargs = {"local_files_only": kwargs['sd_run_local']}
if kwargs['sd_disable_nsfw']:
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict(
feature_extractor=DummyFeatureExtractor(),
safety_checker=DummySafetyChecker(),
))
self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path,
revision="fp16" if torch.cuda.is_available() else "main",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
use_auth_token=kwargs["hf_access_token"],
**model_kwargs
)
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing()
self.model = self.model.to(device)
if kwargs['sd_cpu_textencoder']:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True)
self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True )
self.callbacks = kwargs.pop("callbacks", None)
@torch.cuda.amp.autocast()

View File

@ -236,7 +236,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
text_encoder_device = self.text_encoder.device
text_embeddings = self.text_encoder(text_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@ -248,7 +250,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@ -269,7 +271,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

View File

@ -15,7 +15,22 @@ def parse_args():
parser.add_argument(
"--hf_access_token",
default="",
help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
help="Huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
)
parser.add_argument(
"--sd-disable-nsfw",
action="store_true",
help="Disable Stable Diffusion NSFW checker",
)
parser.add_argument(
"--sd-cpu-textencoder",
action="store_true",
help="Always run Stable Diffusion TextEncoder model on CPU",
)
parser.add_argument(
"--sd-run-local",
action="store_true",
help="After first time Stable Diffusion model downloaded, you can add this arg and remove --hf_access_token",
)
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
@ -38,7 +53,7 @@ def parse_args():
if imghdr.what(args.input) is None:
parser.error(f"invalid --input: {args.input} is not a valid image file")
if args.model.startswith("sd"):
if args.model.startswith("sd") and not args.sd_run_local:
if not args.hf_access_token.startswith("hf_"):
parser.error(
f"sd(stable-diffusion) model requires huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens"

View File

@ -218,6 +218,9 @@ def main(args):
name=args.model,
device=device,
hf_access_token=args.hf_access_token,
sd_disable_nsfw=args.sd_disable_nsfw,
sd_cpu_textencoder=args.sd_cpu_textencoder,
sd_run_local=args.sd_run_local,
callbacks=[diffuser_callback],
)

View File

@ -9,6 +9,8 @@ from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / 'result'
save_dir.mkdir(exist_ok=True, parents=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@ -40,7 +42,7 @@ def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
res = model(img, mask, config)
cv2.imwrite(
str(current_dir / gt_name),
str(save_dir / gt_name),
res,
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
)
@ -158,12 +160,17 @@ def test_fcf(strategy):
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm])
def test_sd(strategy, sampler, capfd):
def test_sd(strategy, sampler):
def callback(step: int):
print(f"sd_step_{step}")
sd_steps = 50
model = ModelManager(name="sd1.4", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'],
model = ModelManager(name="sd1.4",
device=device,
hf_access_token=os.environ['HF_ACCESS_TOKEN'],
sd_run_local=False,
sd_disable_nsfw=False,
sd_cpu_textencoder=False,
callbacks=[callback])
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
cfg.sd_sampler = sampler
@ -184,6 +191,40 @@ def test_sd(strategy, sampler, capfd):
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
)
# captured = capfd.readouterr()
# for i in range(sd_steps):
# assert f'sd_step_{i}' in captured.out
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize("disable_nsfw", [True, False])
def test_sd_run_local(strategy, sampler, disable_nsfw):
def callback(step: int):
print(f"sd_step_{step}")
sd_steps = 50
model = ModelManager(
name="sd1.4",
device=device,
# hf_access_token=os.environ.get('HF_ACCESS_TOKEN', None),
hf_access_token=None,
sd_run_local=True,
sd_disable_nsfw=disable_nsfw,
sd_cpu_textencoder=True,
)
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
cfg.sd_sampler = sampler
assert_equal(
model,
cfg,
f"sd_{strategy.capitalize()}_{sampler}_local_result.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)
assert_equal(
model,
cfg,
f"sd_{strategy.capitalize()}_{sampler}_blur_mask_local_result.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
)

View File

@ -1,4 +1,4 @@
torch>=1.8.2
torch>=1.9.0
opencv-python
flask_cors
flask==1.1.4

View File

@ -21,7 +21,7 @@ def load_requirements():
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
setuptools.setup(
name="lama-cleaner",
version="0.20.1",
version="0.21.0",
author="PanicByte",
author_email="cwq1913@gmail.com",
description="Image inpainting tool powered by SOTA AI Model",