Merge branch 'update_sd'

This commit is contained in:
Qing 2022-09-30 22:45:11 +08:00
commit 7cb6c7c827
8 changed files with 126 additions and 24 deletions

View File

@ -58,9 +58,12 @@ lama-cleaner --model=lama --device=cpu --port=8080
Available arguments: Available arguments:
| Name | Description | Default | | Name | Description | Default |
| ----------------- | -------------------------------------------------------------------------------------------------------- | -------- | |-------------------|-------------------------------------------------------------------------------------------------------------------------------| -------- |
| --model | lama/ldm/zits/mat/fcf/sd. See details in [Inpaint Model](#inpainting-model) | lama | | --model | lama/ldm/zits/mat/fcf/sd1.4 See details in [Inpaint Model](#inpainting-model) | lama |
| --hf_access_token | stable-diffusion(sd) model need huggingface access token https://huggingface.co/docs/hub/security-tokens | | | --hf_access_token | stable-diffusion(sd) model need [huggingface access token](https://huggingface.co/docs/hub/security-tokens) to download model | |
| --sd-run-local | Once the model as downloaded, you can pass this arg and remove `--hf_access_token` | |
| --sd-disable-nsfw | Disable stable-diffusion NSFW checker. | |
| --sd-cpu-textencoder | Always run stable-diffusion TextEncoder model on CPU. | |
| --device | cuda or cpu | cuda | | --device | cuda or cpu | cuda |
| --port | Port for backend flask web server | 8080 | | --port | Port for backend flask web server | 8080 |
| --gui | Launch lama-cleaner as a desktop application | | | --gui | Launch lama-cleaner as a desktop application | |

View File

@ -6,6 +6,7 @@ import numpy as np
import torch import torch
from diffusers import PNDMScheduler, DDIMScheduler from diffusers import PNDMScheduler, DDIMScheduler
from loguru import logger from loguru import logger
from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
from lama_cleaner.helper import norm_img from lama_cleaner.helper import norm_img
@ -38,6 +39,29 @@ from lama_cleaner.schema import Config, SDSampler
# mask = torch.from_numpy(mask) # mask = torch.from_numpy(mask)
# return mask # return mask
class DummyFeatureExtractorOutput:
def __init__(self, pixel_values):
self.pixel_values = pixel_values
def to(self, device):
return self
class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __call__(self, *args, **kwargs):
return DummyFeatureExtractorOutput(torch.empty(0, 3))
class DummySafetyChecker:
def __init__(self, *args, **kwargs):
pass
def __call__(self, clip_input, images):
return images, False
class SD(InpaintModel): class SD(InpaintModel):
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505 pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
@ -46,15 +70,30 @@ class SD(InpaintModel):
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
from .sd_pipeline import StableDiffusionInpaintPipeline from .sd_pipeline import StableDiffusionInpaintPipeline
model_kwargs = {"local_files_only": kwargs['sd_run_local']}
if kwargs['sd_disable_nsfw']:
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict(
feature_extractor=DummyFeatureExtractor(),
safety_checker=DummySafetyChecker(),
))
self.model = StableDiffusionInpaintPipeline.from_pretrained( self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path, self.model_id_or_path,
revision="fp16" if torch.cuda.is_available() else "main", revision="fp16" if torch.cuda.is_available() else "main",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
use_auth_token=kwargs["hf_access_token"], use_auth_token=kwargs["hf_access_token"],
**model_kwargs
) )
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing() self.model.enable_attention_slicing()
self.model = self.model.to(device) self.model = self.model.to(device)
if kwargs['sd_cpu_textencoder']:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True)
self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True )
self.callbacks = kwargs.pop("callbacks", None) self.callbacks = kwargs.pop("callbacks", None)
@torch.cuda.amp.autocast() @torch.cuda.amp.autocast()

View File

@ -236,7 +236,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
) )
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] text_encoder_device = self.text_encoder.device
text_embeddings = self.text_encoder(text_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@ -248,7 +250,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
) )
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
@ -269,7 +271,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

View File

@ -15,7 +15,22 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--hf_access_token", "--hf_access_token",
default="", default="",
help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens", help="Huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
)
parser.add_argument(
"--sd-disable-nsfw",
action="store_true",
help="Disable Stable Diffusion NSFW checker",
)
parser.add_argument(
"--sd-cpu-textencoder",
action="store_true",
help="Always run Stable Diffusion TextEncoder model on CPU",
)
parser.add_argument(
"--sd-run-local",
action="store_true",
help="After first time Stable Diffusion model downloaded, you can add this arg and remove --hf_access_token",
) )
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"]) parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
parser.add_argument("--gui", action="store_true", help="Launch as desktop app") parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
@ -38,7 +53,7 @@ def parse_args():
if imghdr.what(args.input) is None: if imghdr.what(args.input) is None:
parser.error(f"invalid --input: {args.input} is not a valid image file") parser.error(f"invalid --input: {args.input} is not a valid image file")
if args.model.startswith("sd"): if args.model.startswith("sd") and not args.sd_run_local:
if not args.hf_access_token.startswith("hf_"): if not args.hf_access_token.startswith("hf_"):
parser.error( parser.error(
f"sd(stable-diffusion) model requires huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens" f"sd(stable-diffusion) model requires huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens"

View File

@ -218,6 +218,9 @@ def main(args):
name=args.model, name=args.model,
device=device, device=device,
hf_access_token=args.hf_access_token, hf_access_token=args.hf_access_token,
sd_disable_nsfw=args.sd_disable_nsfw,
sd_cpu_textencoder=args.sd_cpu_textencoder,
sd_run_local=args.sd_run_local,
callbacks=[diffuser_callback], callbacks=[diffuser_callback],
) )

View File

@ -9,6 +9,8 @@ from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
current_dir = Path(__file__).parent.absolute().resolve() current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / 'result'
save_dir.mkdir(exist_ok=True, parents=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
@ -40,7 +42,7 @@ def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
res = model(img, mask, config) res = model(img, mask, config)
cv2.imwrite( cv2.imwrite(
str(current_dir / gt_name), str(save_dir / gt_name),
res, res,
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
) )
@ -158,12 +160,17 @@ def test_fcf(strategy):
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm]) @pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm])
def test_sd(strategy, sampler, capfd): def test_sd(strategy, sampler):
def callback(step: int): def callback(step: int):
print(f"sd_step_{step}") print(f"sd_step_{step}")
sd_steps = 50 sd_steps = 50
model = ModelManager(name="sd1.4", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'], model = ModelManager(name="sd1.4",
device=device,
hf_access_token=os.environ['HF_ACCESS_TOKEN'],
sd_run_local=False,
sd_disable_nsfw=False,
sd_cpu_textencoder=False,
callbacks=[callback]) callbacks=[callback])
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps) cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
cfg.sd_sampler = sampler cfg.sd_sampler = sampler
@ -184,6 +191,40 @@ def test_sd(strategy, sampler, capfd):
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
) )
# captured = capfd.readouterr()
# for i in range(sd_steps): @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
# assert f'sd_step_{i}' in captured.out @pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize("disable_nsfw", [True, False])
def test_sd_run_local(strategy, sampler, disable_nsfw):
def callback(step: int):
print(f"sd_step_{step}")
sd_steps = 50
model = ModelManager(
name="sd1.4",
device=device,
# hf_access_token=os.environ.get('HF_ACCESS_TOKEN', None),
hf_access_token=None,
sd_run_local=True,
sd_disable_nsfw=disable_nsfw,
sd_cpu_textencoder=True,
)
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
cfg.sd_sampler = sampler
assert_equal(
model,
cfg,
f"sd_{strategy.capitalize()}_{sampler}_local_result.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)
assert_equal(
model,
cfg,
f"sd_{strategy.capitalize()}_{sampler}_blur_mask_local_result.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
)

View File

@ -1,4 +1,4 @@
torch>=1.8.2 torch>=1.9.0
opencv-python opencv-python
flask_cors flask_cors
flask==1.1.4 flask==1.1.4

View File

@ -21,7 +21,7 @@ def load_requirements():
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files # https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
setuptools.setup( setuptools.setup(
name="lama-cleaner", name="lama-cleaner",
version="0.20.1", version="0.21.0",
author="PanicByte", author="PanicByte",
author_email="cwq1913@gmail.com", author_email="cwq1913@gmail.com",
description="Image inpainting tool powered by SOTA AI Model", description="Image inpainting tool powered by SOTA AI Model",