From 3c87b050d940d5cdfabcd7816c708f9243b9a400 Mon Sep 17 00:00:00 2001
From: Qing
Date: Sat, 15 Oct 2022 22:32:25 +0800
Subject: [PATCH] update sd inpainting pipeline

---
 lama_cleaner/model/sd.py          |  36 +----
 lama_cleaner/model/sd_pipeline.py | 228 +++++++++++++++++++++---------
 lama_cleaner/schema.py            |   1 +
 lama_cleaner/server.py            |   6 +-
 lama_cleaner/tests/test_model.py  |  19 +--
 5 files changed, 178 insertions(+), 112 deletions(-)

diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py
index d9745d4..e960541 100644
--- a/lama_cleaner/model/sd.py
+++ b/lama_cleaner/model/sd.py
@@ -4,9 +4,8 @@ import PIL.Image
 import cv2
 import numpy as np
 import torch
-from diffusers import PNDMScheduler, DDIMScheduler
+from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
 from loguru import logger
-from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
 
 from lama_cleaner.helper import norm_img
 
@@ -39,30 +38,6 @@ from lama_cleaner.schema import Config, SDSampler
 #     mask = torch.from_numpy(mask)
 #     return mask
 
-class DummyFeatureExtractorOutput:
-    def __init__(self, pixel_values):
-        self.pixel_values = pixel_values
-
-    def to(self, device):
-        return self
-
-
-class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-    def __call__(self, *args, **kwargs):
-        return DummyFeatureExtractorOutput(torch.empty(0, 3))
-
-
-class DummySafetyChecker:
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def __call__(self, clip_input, images):
-        return images, False
-
-
 class SD(InpaintModel):
     pad_mod = 64  # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
     min_size = 512
@@ -74,8 +49,7 @@ class SD(InpaintModel):
         if kwargs['sd_disable_nsfw']:
             logger.info("Disable Stable Diffusion Model NSFW checker")
             model_kwargs.update(dict(
-                feature_extractor=DummyFeatureExtractor(),
-                safety_checker=DummySafetyChecker(),
+                safety_checker=None,
             ))
 
         self.model = StableDiffusionInpaintPipeline.from_pretrained(
@@ -94,7 +68,7 @@ class SD(InpaintModel):
             self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True)
             self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True )
 
-        self.callbacks = kwargs.pop("callbacks", None)
+        self.callback = kwargs.pop("callback", None)
 
     @torch.cuda.amp.autocast()
     def forward(self, image, mask, config: Config):
@@ -133,6 +107,8 @@ class SD(InpaintModel):
                 "skip_prk_steps": True,
             }
             scheduler = PNDMScheduler(**PNDM_kwargs)
+        elif config.sd_sampler == SDSampler.k_lms:
+            scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
         else:
             raise ValueError(config.sd_sampler)
 
@@ -156,7 +132,7 @@ class SD(InpaintModel):
             num_inference_steps=config.sd_steps,
             guidance_scale=config.sd_guidance_scale,
             output_type="np.array",
-            callbacks=self.callbacks,
+            callback=self.callback,
         ).images[0]
 
         output = (output * 255).round().astype("uint8")

diff --git a/lama_cleaner/model/sd_pipeline.py b/lama_cleaner/model/sd_pipeline.py
index 827a383..b82f680 100644
--- a/lama_cleaner/model/sd_pipeline.py
+++ b/lama_cleaner/model/sd_pipeline.py
@@ -5,9 +5,10 @@ import numpy as np
 import torch
 
 import PIL
-from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler
+from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
-from diffusers.utils import logging
+from diffusers.utils import logging, deprecate
+from diffusers.configuration_utils import FrozenDict
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
@@ -59,7 +60,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
         safety_checker ([`StableDiffusionSafetyChecker`]):
-            Classification module that estimates whether generated images could be considered offsensive or harmful.
+            Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -71,13 +72,37 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
-        scheduler = scheduler.set_format("pt")
         logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -113,7 +138,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         back to computing attention in one step.
""" # set slice_size = `None` to disable `set_attention_slice` - self.enable_attention_slice(None) + self.enable_attention_slicing(None) @torch.no_grad() def __call__( @@ -124,11 +149,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callbacks: List[Callable[[int], None]] = None + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -141,8 +170,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): process. This is the image whose masked region will be inpainted. mask_image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be - converted to a single channel (luminance) before use. + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` is 1, the denoising process will be run on the masked area for the full number of iterations specified @@ -157,6 +187,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -165,10 +200,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
 
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -187,58 +228,39 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
 
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
         # set timesteps
-        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
-        extra_set_kwargs = {}
-        offset = 0
-        if accepts_offset:
-            offset = 1
-            extra_set_kwargs["offset"] = 1
-
-        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
-
-        # preprocess image
-        init_image = preprocess_image(init_image).to(self.device)
-
-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-
-        init_latents = 0.18215 * init_latents
-
-        # Expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size)
-        init_latents_orig = init_latents
-
-        # preprocess mask
-        mask = preprocess_mask(mask_image).to(self.device)
-        mask = torch.cat([mask] * batch_size)
-
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
-
-        # get the original timestep using init_timestep
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+        self.scheduler.set_timesteps(num_inference_steps)
 
         # get prompt text embeddings
-        text_input = self.tokenizer(
+        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
-           truncation=True,
            return_tensors="pt",
        )
-        text_encoder_device = self.text_encoder.device
-        text_embeddings = self.text_encoder(text_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+        text_encoder_device = self.text_encoder.device
+        text_embeddings = self.text_encoder(text_input_ids.to(text_encoder_device))[0].to(self.device)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -246,17 +268,80 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""]
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device))[0].to(self.device)
+
+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
 
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
+        # preprocess image
+        if not isinstance(init_image, torch.FloatTensor):
+            init_image = preprocess_image(init_image)
+
+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+        init_latents_orig = init_latents
+
+        # preprocess mask
+        if not isinstance(mask_image, torch.FloatTensor):
+            mask_image = preprocess_mask(mask_image)
+        mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
+
+        # check sizes
+        if not mask.shape == init_latents.shape:
+            raise ValueError("The mask and init_image should be the same size!")
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -267,10 +352,18 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             extra_step_kwargs["eta"] = eta
 
         latents = init_latents
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
-        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
+
+        for i, t in tqdm(enumerate(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -281,25 +374,28 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
-            # masking
-            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
+            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+
             latents = (init_latents_proper * mask) + (latents * (1 - mask))
 
-            if callbacks is not None:
-                for callback in callbacks:
-                    callback(i)
+            # call the callback, if provided
+            if callback is not None and i % callback_steps == 0:
+                callback(i, t, latents)
 
-        # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
         image = self.vae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
-        # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+                self.device
+            )
+            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+        else:
+            has_nsfw_concept = None
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)

diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py
index 265cc4c..a1bd7c1 100644
--- a/lama_cleaner/schema.py
+++ b/lama_cleaner/schema.py
@@ -17,6 +17,7 @@ class LDMSampler(str, Enum):
 class SDSampler(str, Enum):
     ddim = "ddim"
     pndm = "pndm"
+    k_lms = "k_lms"
 
 
 class Config(BaseModel):

diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py
index f83184a..8bec742 100644
--- a/lama_cleaner/server.py
+++ b/lama_cleaner/server.py
@@ -82,7 +82,7 @@ def get_image_ext(img_bytes):
     return w
 
 
-def diffuser_callback(step: int):
+def diffuser_callback(i, t, latents):
     pass
     # socketio.emit('diffusion_step', {'diffusion_step': step})
 
@@ -129,7 +129,7 @@ def process():
     )
 
     if config.sd_seed == -1:
-        config.sd_seed = random.randint(1, 9999999)
+        config.sd_seed = random.randint(1, 999999999)
 
     logger.info(f"Origin image shape: {original_shape}")
     image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
@@ -223,7 +223,7 @@ def main(args):
         sd_disable_nsfw=args.sd_disable_nsfw,
         sd_cpu_textencoder=args.sd_cpu_textencoder,
         sd_run_local=args.sd_run_local,
-        callbacks=[diffuser_callback],
+        callback=diffuser_callback,
     )
 
     if args.gui:

diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py
index 383231f..1b40f84 100644
--- a/lama_cleaner/tests/test_model.py
+++ b/lama_cleaner/tests/test_model.py
@@ -171,7 +171,7 @@ def test_sd(strategy, sampler):
                          sd_run_local=False,
                          sd_disable_nsfw=False,
                          sd_cpu_textencoder=False,
-                         callbacks=[callback])
+                         callback=callback)
     cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
     cfg.sd_sampler = sampler
 
@@ -193,9 +193,10 @@ def test_sd(strategy, sampler):
 
 
 @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-@pytest.mark.parametrize("sampler", [SDSampler.ddim])
+@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm, SDSampler.k_lms])
 @pytest.mark.parametrize("disable_nsfw", [True, False])
-def test_sd_run_local(strategy, sampler, disable_nsfw):
+@pytest.mark.parametrize("cpu_textencoder", [True, False])
+def test_sd_run_local(strategy, sampler, disable_nsfw, cpu_textencoder):
     def callback(step: int):
         print(f"sd_step_{step}")
 
@@ -207,7 +208,7 @@ def test_sd_run_local(strategy, sampler, disable_nsfw):
                          hf_access_token=None,
                          sd_run_local=True,
                          sd_disable_nsfw=disable_nsfw,
-                         sd_cpu_textencoder=True,
+                         sd_cpu_textencoder=cpu_textencoder,
                          )
     cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
     cfg.sd_sampler = sampler
@@ -215,19 +216,11 @@ def test_sd_run_local(strategy, sampler, disable_nsfw):
     assert_equal(
         model,
         cfg,
-        f"sd_{strategy.capitalize()}_{sampler}_local_result.png",
+        f"sd_{strategy.capitalize()}_{sampler}_local_disablensfw_{disable_nsfw}_cputextencoder_{cpu_textencoder}_result.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
     )
 
-    assert_equal(
-        model,
-        cfg,
-        f"sd_{strategy.capitalize()}_{sampler}_blur_mask_local_result.png",
-        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
-        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
-    )
-
 
 @pytest.mark.parametrize(
     "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
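
Usage note (not part of the patch): the sketch below exercises the interface this patch introduces — `safety_checker=None` to disable the NSFW filter, the new `negative_prompt` / `num_images_per_prompt` arguments, and the `callback(i, t, latents)` signature that replaces the old `callbacks` list. The checkpoint id and file names are placeholder assumptions, not values taken from the patch.

    import PIL.Image
    import torch

    from lama_cleaner.model.sd_pipeline import StableDiffusionInpaintPipeline


    def print_step(i: int, t: int, latents: torch.FloatTensor):
        # Invoked every `callback_steps` steps with the current latents.
        print(f"step={i} timestep={t} latents_shape={tuple(latents.shape)}")


    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # placeholder checkpoint id
        safety_checker=None,              # same effect as --sd-disable-nsfw
    )
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    result = pipe(
        prompt="a cat sitting on a bench",
        init_image=PIL.Image.open("image.png").convert("RGB"),   # placeholder files
        mask_image=PIL.Image.open("mask.png").convert("L"),
        negative_prompt="low quality, blurry",
        num_images_per_prompt=1,
        num_inference_steps=50,
        callback=print_step,
        callback_steps=1,
    )
    result.images[0].save("inpainted.png")

To reproduce the new `k_lms` sampler path, the scheduler can be swapped before calling the pipeline, mirroring what `SD.forward` does: `pipe.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")`.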