Merge branch 'update_sd'
This commit is contained in:
commit
7cb6c7c827
23
README.md
23
README.md
@ -57,16 +57,19 @@ lama-cleaner --model=lama --device=cpu --port=8080
|
||||
|
||||
Available arguments:
|
||||
|
||||
| Name | Description | Default |
|
||||
| ----------------- | -------------------------------------------------------------------------------------------------------- | -------- |
|
||||
| --model | lama/ldm/zits/mat/fcf/sd. See details in [Inpaint Model](#inpainting-model) | lama |
|
||||
| --hf_access_token | stable-diffusion(sd) model need huggingface access token https://huggingface.co/docs/hub/security-tokens | |
|
||||
| --device | cuda or cpu | cuda |
|
||||
| --port | Port for backend flask web server | 8080 |
|
||||
| --gui | Launch lama-cleaner as a desktop application | |
|
||||
| --gui_size | Set the window size for the application | 1200 900 |
|
||||
| --input | Path to image you want to load by default | None |
|
||||
| --debug | Enable debug mode for flask web server | |
|
||||
| Name | Description | Default |
|
||||
|-------------------|-------------------------------------------------------------------------------------------------------------------------------| -------- |
|
||||
| --model | lama/ldm/zits/mat/fcf/sd1.4 See details in [Inpaint Model](#inpainting-model) | lama |
|
||||
| --hf_access_token | stable-diffusion(sd) model need [huggingface access token](https://huggingface.co/docs/hub/security-tokens) to download model | |
|
||||
| --sd-run-local | Once the model as downloaded, you can pass this arg and remove `--hf_access_token` | |
|
||||
| --sd-disable-nsfw | Disable stable-diffusion NSFW checker. | |
|
||||
| --sd-cpu-textencoder | Always run stable-diffusion TextEncoder model on CPU. | |
|
||||
| --device | cuda or cpu | cuda |
|
||||
| --port | Port for backend flask web server | 8080 |
|
||||
| --gui | Launch lama-cleaner as a desktop application | |
|
||||
| --gui_size | Set the window size for the application | 1200 900 |
|
||||
| --input | Path to image you want to load by default | None |
|
||||
| --debug | Enable debug mode for flask web server | |
|
||||
|
||||
## Inpainting Model
|
||||
|
||||
|
@ -6,6 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
from diffusers import PNDMScheduler, DDIMScheduler
|
||||
from loguru import logger
|
||||
from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
|
||||
|
||||
from lama_cleaner.helper import norm_img
|
||||
|
||||
@ -38,23 +39,61 @@ from lama_cleaner.schema import Config, SDSampler
|
||||
# mask = torch.from_numpy(mask)
|
||||
# return mask
|
||||
|
||||
class DummyFeatureExtractorOutput:
|
||||
def __init__(self, pixel_values):
|
||||
self.pixel_values = pixel_values
|
||||
|
||||
def to(self, device):
|
||||
return self
|
||||
|
||||
|
||||
class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return DummyFeatureExtractorOutput(torch.empty(0, 3))
|
||||
|
||||
|
||||
class DummySafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, clip_input, images):
|
||||
return images, False
|
||||
|
||||
|
||||
class SD(InpaintModel):
|
||||
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
|
||||
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
|
||||
min_size = 512
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from .sd_pipeline import StableDiffusionInpaintPipeline
|
||||
|
||||
model_kwargs = {"local_files_only": kwargs['sd_run_local']}
|
||||
if kwargs['sd_disable_nsfw']:
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(dict(
|
||||
feature_extractor=DummyFeatureExtractor(),
|
||||
safety_checker=DummySafetyChecker(),
|
||||
))
|
||||
|
||||
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
self.model_id_or_path,
|
||||
revision="fp16" if torch.cuda.is_available() else "main",
|
||||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||
use_auth_token=kwargs["hf_access_token"],
|
||||
**model_kwargs
|
||||
)
|
||||
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||
self.model.enable_attention_slicing()
|
||||
self.model = self.model.to(device)
|
||||
|
||||
if kwargs['sd_cpu_textencoder']:
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True)
|
||||
self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True )
|
||||
|
||||
self.callbacks = kwargs.pop("callbacks", None)
|
||||
|
||||
@torch.cuda.amp.autocast()
|
||||
|
@ -236,7 +236,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
text_encoder_device = self.text_encoder.device
|
||||
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@ -248,7 +250,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
@ -269,7 +271,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
|
@ -15,7 +15,22 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--hf_access_token",
|
||||
default="",
|
||||
help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
|
||||
help="Huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sd-disable-nsfw",
|
||||
action="store_true",
|
||||
help="Disable Stable Diffusion NSFW checker",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sd-cpu-textencoder",
|
||||
action="store_true",
|
||||
help="Always run Stable Diffusion TextEncoder model on CPU",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sd-run-local",
|
||||
action="store_true",
|
||||
help="After first time Stable Diffusion model downloaded, you can add this arg and remove --hf_access_token",
|
||||
)
|
||||
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
|
||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||
@ -38,7 +53,7 @@ def parse_args():
|
||||
if imghdr.what(args.input) is None:
|
||||
parser.error(f"invalid --input: {args.input} is not a valid image file")
|
||||
|
||||
if args.model.startswith("sd"):
|
||||
if args.model.startswith("sd") and not args.sd_run_local:
|
||||
if not args.hf_access_token.startswith("hf_"):
|
||||
parser.error(
|
||||
f"sd(stable-diffusion) model requires huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens"
|
||||
|
@ -218,6 +218,9 @@ def main(args):
|
||||
name=args.model,
|
||||
device=device,
|
||||
hf_access_token=args.hf_access_token,
|
||||
sd_disable_nsfw=args.sd_disable_nsfw,
|
||||
sd_cpu_textencoder=args.sd_cpu_textencoder,
|
||||
sd_run_local=args.sd_run_local,
|
||||
callbacks=[diffuser_callback],
|
||||
)
|
||||
|
||||
|
@ -9,6 +9,8 @@ from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / 'result'
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
|
||||
@ -40,7 +42,7 @@ def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.
|
||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||
res = model(img, mask, config)
|
||||
cv2.imwrite(
|
||||
str(current_dir / gt_name),
|
||||
str(save_dir / gt_name),
|
||||
res,
|
||||
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
||||
)
|
||||
@ -158,12 +160,17 @@ def test_fcf(strategy):
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm])
|
||||
def test_sd(strategy, sampler, capfd):
|
||||
def test_sd(strategy, sampler):
|
||||
def callback(step: int):
|
||||
print(f"sd_step_{step}")
|
||||
|
||||
sd_steps = 50
|
||||
model = ModelManager(name="sd1.4", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'],
|
||||
model = ModelManager(name="sd1.4",
|
||||
device=device,
|
||||
hf_access_token=os.environ['HF_ACCESS_TOKEN'],
|
||||
sd_run_local=False,
|
||||
sd_disable_nsfw=False,
|
||||
sd_cpu_textencoder=False,
|
||||
callbacks=[callback])
|
||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
||||
cfg.sd_sampler = sampler
|
||||
@ -184,6 +191,40 @@ def test_sd(strategy, sampler, capfd):
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
||||
)
|
||||
|
||||
# captured = capfd.readouterr()
|
||||
# for i in range(sd_steps):
|
||||
# assert f'sd_step_{i}' in captured.out
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
||||
def test_sd_run_local(strategy, sampler, disable_nsfw):
|
||||
def callback(step: int):
|
||||
print(f"sd_step_{step}")
|
||||
|
||||
sd_steps = 50
|
||||
model = ModelManager(
|
||||
name="sd1.4",
|
||||
device=device,
|
||||
# hf_access_token=os.environ.get('HF_ACCESS_TOKEN', None),
|
||||
hf_access_token=None,
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=True,
|
||||
)
|
||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_{strategy.capitalize()}_{sampler}_local_result.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_{strategy.capitalize()}_{sampler}_blur_mask_local_result.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
torch>=1.8.2
|
||||
torch>=1.9.0
|
||||
opencv-python
|
||||
flask_cors
|
||||
flask==1.1.4
|
||||
|
2
setup.py
2
setup.py
@ -21,7 +21,7 @@ def load_requirements():
|
||||
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
|
||||
setuptools.setup(
|
||||
name="lama-cleaner",
|
||||
version="0.20.1",
|
||||
version="0.21.0",
|
||||
author="PanicByte",
|
||||
author_email="cwq1913@gmail.com",
|
||||
description="Image inpainting tool powered by SOTA AI Model",
|
||||
|
Loading…
Reference in New Issue
Block a user