From f0b852725f7392c9b958337f65e02f906d33e1cc Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 27 Dec 2023 22:00:07 +0800 Subject: [PATCH] lots update --- lama_cleaner/const.py | 23 +- lama_cleaner/download.py | 56 +- lama_cleaner/helper.py | 4 + lama_cleaner/model/__init__.py | 2 + lama_cleaner/model/base.py | 4 +- lama_cleaner/model/controlnet.py | 41 +- lama_cleaner/model/kandinsky.py | 3 +- lama_cleaner/model/mi_gan.py | 2 +- lama_cleaner/model/paint_by_example.py | 2 +- lama_cleaner/model/pipeline/__init__.py | 3 - ...ine_stable_diffusion_controlnet_inpaint.py | 638 ------ lama_cleaner/model/power_paint/__init__.py | 0 .../model/power_paint/pipeline_powerpaint.py | 1243 ++++++++++++ .../pipeline_powerpaint_controlnet.py | 1775 +++++++++++++++++ lama_cleaner/model/power_paint/power_paint.py | 96 + .../model/power_paint/powerpaint_tokenizer.py | 540 +++++ lama_cleaner/model/sd.py | 24 +- lama_cleaner/model/sdxl.py | 12 +- lama_cleaner/model/utils.py | 23 +- lama_cleaner/model_info.py | 100 + lama_cleaner/model_manager.py | 53 +- lama_cleaner/schema.py | 128 +- lama_cleaner/server.py | 10 +- lama_cleaner/tests/utils.py | 0 lama_cleaner/web_config.py | 16 +- web_app/src/components/Extender.tsx | 20 +- web_app/src/components/PromptInput.tsx | 8 +- web_app/src/components/Settings.tsx | 12 +- .../components/SidePanel/DiffusionOptions.tsx | 182 +- web_app/src/lib/api.ts | 13 +- web_app/src/lib/const.ts | 7 +- web_app/src/lib/states.ts | 32 +- web_app/src/lib/types.ts | 13 + 33 files changed, 4085 insertions(+), 1000 deletions(-) delete mode 100644 lama_cleaner/model/pipeline/__init__.py delete mode 100644 lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py create mode 100644 lama_cleaner/model/power_paint/__init__.py create mode 100644 lama_cleaner/model/power_paint/pipeline_powerpaint.py create mode 100644 lama_cleaner/model/power_paint/pipeline_powerpaint_controlnet.py create mode 100644 lama_cleaner/model/power_paint/power_paint.py create mode 100644 lama_cleaner/model/power_paint/powerpaint_tokenizer.py create mode 100644 lama_cleaner/model_info.py create mode 100644 lama_cleaner/tests/utils.py diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py index b506349..a28beba 100644 --- a/lama_cleaner/const.py +++ b/lama_cleaner/const.py @@ -4,6 +4,11 @@ from enum import Enum from pydantic import BaseModel +DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline" +DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline" +DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline" +DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline" + MPS_UNSUPPORT_MODELS = [ "lama", "ldm", @@ -15,22 +20,8 @@ MPS_UNSUPPORT_MODELS = [ ] DEFAULT_MODEL = "lama" -AVAILABLE_MODELS = [ - "lama", - "ldm", - "zits", - "mat", - "fcf", - "manga", - "cv2", -] -DIFFUSERS_MODEL_FP16_REVERSION = [ - "runwayml/stable-diffusion-inpainting", - "Sanster/anything-4.0-inpainting", - "Sanster/Realistic_Vision_V1.4-inpainting", - "stabilityai/stable-diffusion-2-inpainting", - "timbrooks/instruct-pix2pix", -] +AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"] + AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] DEFAULT_DEVICE = "cuda" diff --git a/lama_cleaner/download.py b/lama_cleaner/download.py index 66b14f9..6e90f4c 100644 --- a/lama_cleaner/download.py +++ b/lama_cleaner/download.py @@ -5,23 +5,23 @@ from typing import List from loguru import logger from pathlib import Path -from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR -from lama_cleaner.runtime import setup_model_dir -from lama_cleaner.schema import ( - ModelInfo, - ModelType, - DIFFUSERS_SD_INPAINT_CLASS_NAME, - DIFFUSERS_SDXL_INPAINT_CLASS_NAME, +from lama_cleaner.const import ( + DEFAULT_MODEL_DIR, DIFFUSERS_SD_CLASS_NAME, + DIFFUSERS_SD_INPAINT_CLASS_NAME, DIFFUSERS_SDXL_CLASS_NAME, + DIFFUSERS_SDXL_INPAINT_CLASS_NAME, ) +from lama_cleaner.model.utils import handle_from_pretrained_exceptions +from lama_cleaner.model_info import ModelInfo, ModelType +from lama_cleaner.runtime import setup_model_dir def cli_download_model(model: str, model_dir: Path): setup_model_dir(model_dir) from lama_cleaner.model import models - if model in models: + if model in models and models[model].is_erase_model: logger.info(f"Downloading {model}...") models[model].download() logger.info(f"Done.") @@ -29,9 +29,10 @@ def cli_download_model(model: str, model_dir: Path): logger.info(f"Downloading model from Huggingface: {model}") from diffusers import DiffusionPipeline - downloaded_path = DiffusionPipeline.download( + downloaded_path = handle_from_pretrained_exceptions( + DiffusionPipeline.download, pretrained_model_name=model, - variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main", + variant="fp16", resume_download=True, ) logger.info(f"Done. Downloaded to {downloaded_path}") @@ -43,21 +44,33 @@ def folder_name_to_show_name(name: str) -> str: def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]: cache_dir = Path(cache_dir) + stable_diffusion_dir = cache_dir / "stable_diffusion" + stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl" # logger.info(f"Scanning single file sd/sdxl models in {cache_dir}") res = [] - for it in cache_dir.glob(f"*.*"): + for it in stable_diffusion_dir.glob(f"*.*"): if it.suffix not in [".safetensors", ".ckpt"]: continue if "inpaint" in str(it).lower(): - if "sdxl" in str(it).lower(): - model_type = ModelType.DIFFUSERS_SDXL_INPAINT - else: - model_type = ModelType.DIFFUSERS_SD_INPAINT + model_type = ModelType.DIFFUSERS_SD_INPAINT else: - if "sdxl" in str(it).lower(): - model_type = ModelType.DIFFUSERS_SDXL - else: - model_type = ModelType.DIFFUSERS_SD + model_type = ModelType.DIFFUSERS_SD + res.append( + ModelInfo( + name=it.name, + path=str(it.absolute()), + model_type=model_type, + is_single_file_diffusers=True, + ) + ) + + for it in stable_diffusion_xl_dir.glob(f"*.*"): + if it.suffix not in [".safetensors", ".ckpt"]: + continue + if "inpaint" in str(it).lower(): + model_type = ModelType.DIFFUSERS_SDXL_INPAINT + else: + model_type = ModelType.DIFFUSERS_SDXL res.append( ModelInfo( name=it.name, @@ -104,8 +117,9 @@ def scan_models() -> List[ModelInfo]: name = folder_name_to_show_name(it.parent.parent.parent.name) if name in diffusers_model_names: continue - - if _class_name == DIFFUSERS_SD_CLASS_NAME: + if "PowerPaint" in name: + model_type = ModelType.DIFFUSERS_OTHER + elif _class_name == DIFFUSERS_SD_CLASS_NAME: model_type = ModelType.DIFFUSERS_SD elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME: model_type = ModelType.DIFFUSERS_SD_INPAINT diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 1c12128..f48cecf 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -290,3 +290,7 @@ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]: return cv2.drawContours(new_mask, contours, max_index, 255, -1) else: return mask + + +def is_mac(): + return sys.platform == "darwin" diff --git a/lama_cleaner/model/__init__.py b/lama_cleaner/model/__init__.py index 1892ab7..473cb99 100644 --- a/lama_cleaner/model/__init__.py +++ b/lama_cleaner/model/__init__.py @@ -9,6 +9,7 @@ from .mat import MAT from .mi_gan import MIGAN from .opencv2 import OpenCV2 from .paint_by_example import PaintByExample +from .power_paint.power_paint import PowerPaint from .sd import SD15, SD2, Anything4, RealisticVision14, SD from .sdxl import SDXL from .zits import ZITS @@ -30,4 +31,5 @@ models = { InstructPix2Pix.name: InstructPix2Pix, Kandinsky22.name: Kandinsky22, SDXL.name: SDXL, + PowerPaint.name: PowerPaint, } diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index b4b43ba..dd65d55 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -14,7 +14,7 @@ from lama_cleaner.helper import ( ) from lama_cleaner.model.helper.g_diffuser_bot import expand_image from lama_cleaner.model.utils import get_scheduler -from lama_cleaner.schema import Config, HDStrategy, SDSampler, ModelInfo +from lama_cleaner.schema import Config, HDStrategy, SDSampler class InpaintModel: @@ -271,7 +271,7 @@ class InpaintModel: class DiffusionInpaintModel(InpaintModel): def __init__(self, device, **kwargs): - self.model_info: ModelInfo = kwargs["model_info"] + self.model_info = kwargs["model_info"] self.model_id_or_path = self.model_info.path super().__init__(device, **kwargs) diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py index 29591f8..b7288f8 100644 --- a/lama_cleaner/model/controlnet.py +++ b/lama_cleaner/model/controlnet.py @@ -5,7 +5,6 @@ import torch from diffusers import ControlNetModel, DiffusionPipeline from loguru import logger -from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.helper.controlnet_preprocess import ( make_canny_control_image, @@ -14,8 +13,8 @@ from lama_cleaner.model.helper.controlnet_preprocess import ( make_inpaint_control_image, ) from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper -from lama_cleaner.model.utils import get_scheduler -from lama_cleaner.schema import Config, ModelInfo, ModelType +from lama_cleaner.model.utils import get_scheduler, handle_from_pretrained_exceptions +from lama_cleaner.schema import Config, ModelType class ControlNet(DiffusionInpaintModel): @@ -39,11 +38,11 @@ class ControlNet(DiffusionInpaintModel): def init_model(self, device: torch.device, **kwargs): fp16 = not kwargs.get("no_half", False) - model_info: ModelInfo = kwargs["model_info"] - sd_controlnet_method = kwargs["sd_controlnet_method"] + model_info = kwargs["model_info"] + controlnet_method = kwargs["controlnet_method"] self.model_info = model_info - self.sd_controlnet_method = sd_controlnet_method + self.controlnet_method = controlnet_method model_kwargs = {} if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): @@ -76,7 +75,8 @@ class ControlNet(DiffusionInpaintModel): ) controlnet = ControlNetModel.from_pretrained( - sd_controlnet_method, torch_dtype=torch_dtype, resume_download=True + pretrained_model_name_or_path=controlnet_method, + resume_download=True, ) if model_info.is_single_file_diffusers: if self.model_info.model_type == ModelType.DIFFUSERS_SD: @@ -88,17 +88,12 @@ class ControlNet(DiffusionInpaintModel): model_info.path, controlnet=controlnet, **model_kwargs ).to(torch_dtype) else: - self.model = PipeClass.from_pretrained( - model_info.path, + self.model = handle_from_pretrained_exceptions( + PipeClass.from_pretrained, + pretrained_model_name_or_path=model_info.path, controlnet=controlnet, - revision="fp16" - if ( - model_info.path in DIFFUSERS_MODEL_FP16_REVERSION - and use_gpu - and fp16 - ) - else "main", - torch_dtype=torch_dtype, + variant="fp16", + dtype=torch_dtype, **model_kwargs, ) @@ -116,23 +111,23 @@ class ControlNet(DiffusionInpaintModel): self.callback = kwargs.pop("callback", None) def switch_controlnet_method(self, new_method: str): - self.sd_controlnet_method = new_method + self.controlnet_method = new_method controlnet = ControlNetModel.from_pretrained( new_method, torch_dtype=self.torch_dtype, resume_download=True ).to(self.model.device) self.model.controlnet = controlnet def _get_control_image(self, image, mask): - if "canny" in self.sd_controlnet_method: + if "canny" in self.controlnet_method: control_image = make_canny_control_image(image) - elif "openpose" in self.sd_controlnet_method: + elif "openpose" in self.controlnet_method: control_image = make_openpose_control_image(image) - elif "depth" in self.sd_controlnet_method: + elif "depth" in self.controlnet_method: control_image = make_depth_control_image(image) - elif "inpaint" in self.sd_controlnet_method: + elif "inpaint" in self.controlnet_method: control_image = make_inpaint_control_image(image, mask) else: - raise NotImplementedError(f"{self.sd_controlnet_method} not implemented") + raise NotImplementedError(f"{self.controlnet_method} not implemented") return control_image def forward(self, image, mask, config: Config): diff --git a/lama_cleaner/model/kandinsky.py b/lama_cleaner/model/kandinsky.py index 783a01b..0645af7 100644 --- a/lama_cleaner/model/kandinsky.py +++ b/lama_cleaner/model/kandinsky.py @@ -24,7 +24,7 @@ class Kandinsky(DiffusionInpaintModel): } self.model = AutoPipelineForInpainting.from_pretrained( - self.model_id_or_path, **model_kwargs + self.name, **model_kwargs ).to(device) self.callback = kwargs.pop("callback", None) @@ -66,4 +66,3 @@ class Kandinsky(DiffusionInpaintModel): class Kandinsky22(Kandinsky): name = "kandinsky-community/kandinsky-2-2-decoder-inpaint" - model_id_or_path = "kandinsky-community/kandinsky-2-2-decoder-inpaint" diff --git a/lama_cleaner/model/mi_gan.py b/lama_cleaner/model/mi_gan.py index 3e3f200..d8ec0fa 100644 --- a/lama_cleaner/model/mi_gan.py +++ b/lama_cleaner/model/mi_gan.py @@ -16,7 +16,7 @@ from lama_cleaner.model.base import InpaintModel MIGAN_MODEL_URL = os.environ.get( "MIGAN_MODEL_URL", - "/Users/cwq/code/github/MI-GAN/exported_models/migan_places512/models/migan_traced.pt", + "https://github.com/Sanster/models/releases/download/migan/migan_traced.pt", ) MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c") diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py index 80b9745..07d3842 100644 --- a/lama_cleaner/model/paint_by_example.py +++ b/lama_cleaner/model/paint_by_example.py @@ -28,7 +28,7 @@ class PaintByExample(DiffusionInpaintModel): ) self.model = DiffusionPipeline.from_pretrained( - "Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs + self.name, torch_dtype=torch_dtype, **model_kwargs ) # TODO: gpu_id diff --git a/lama_cleaner/model/pipeline/__init__.py b/lama_cleaner/model/pipeline/__init__.py deleted file mode 100644 index 9056bc6..0000000 --- a/lama_cleaner/model/pipeline/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .pipeline_stable_diffusion_controlnet_inpaint import ( - StableDiffusionControlNetInpaintPipeline, -) diff --git a/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py b/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py deleted file mode 100644 index f65e95d..0000000 --- a/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py +++ /dev/null @@ -1,638 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import gc -from typing import Union, List, Optional, Callable, Dict, Any - -# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py - -import torch -import PIL.Image - -from diffusers.pipelines.controlnet.pipeline_controlnet import * -from diffusers.utils import replace_example_docstring - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler - >>> from diffusers.utils import load_image - >>> import numpy as np - >>> import torch - - >>> import cv2 - >>> from PIL import Image - >>> # download an image - >>> image = load_image( - ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" - ... ) - >>> image = np.array(image) - >>> mask_image = load_image( - ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" - ... ) - >>> mask_image = np.array(mask_image) - >>> # get canny image - >>> canny_image = cv2.Canny(image, 100, 200) - >>> canny_image = canny_image[:, :, None] - >>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2) - >>> canny_image = Image.fromarray(canny_image) - - >>> # load control net and stable diffusion v1-5 - >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) - >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( - ... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16 - ... ) - - >>> # speed up diffusion process with faster scheduler and memory optimization - >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) - >>> # remove following line if xformers is not installed - >>> pipe.enable_xformers_memory_efficient_attention() - - >>> pipe.enable_model_cpu_offload() - - >>> # generate image - >>> generator = torch.manual_seed(0) - >>> image = pipe( - ... "futuristic-looking doggo", - ... num_inference_steps=20, - ... generator=generator, - ... image=image, - ... control_image=canny_image, - ... mask_image=mask_image - ... ).images[0] - ``` -""" - - -def prepare_mask_and_masked_image(image, mask): - """ - Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be - converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the - ``image`` and ``1`` for the ``mask``. - The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be - binarized (``mask > 0.5``) and cast to ``torch.float32`` too. - Args: - image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. - It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` - ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. - mask (_type_): The mask to apply to the image, i.e. regions to inpaint. - It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` - ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. - Raises: - ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask - should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. - TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not - (ot the other way around). - Returns: - tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 - dimensions: ``batch x channels x height x width``. - """ - if isinstance(image, torch.Tensor): - if not isinstance(mask, torch.Tensor): - raise TypeError( - f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not" - ) - - # Batch single image - if image.ndim == 3: - assert ( - image.shape[0] == 3 - ), "Image outside a batch should be of shape (3, H, W)" - image = image.unsqueeze(0) - - # Batch and add channel dim for single mask - if mask.ndim == 2: - mask = mask.unsqueeze(0).unsqueeze(0) - - # Batch single mask or add channel dim - if mask.ndim == 3: - # Single batched mask, no channel dim or single mask not batched but channel dim - if mask.shape[0] == 1: - mask = mask.unsqueeze(0) - - # Batched masks no channel dim - else: - mask = mask.unsqueeze(1) - - assert ( - image.ndim == 4 and mask.ndim == 4 - ), "Image and Mask must have 4 dimensions" - assert ( - image.shape[-2:] == mask.shape[-2:] - ), "Image and Mask must have the same spatial dimensions" - assert ( - image.shape[0] == mask.shape[0] - ), "Image and Mask must have the same batch size" - - # Check image is in [-1, 1] - if image.min() < -1 or image.max() > 1: - raise ValueError("Image should be in [-1, 1] range") - - # Check mask is in [0, 1] - if mask.min() < 0 or mask.max() > 1: - raise ValueError("Mask should be in [0, 1] range") - - # Binarize mask - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - - # Image as float32 - image = image.to(dtype=torch.float32) - elif isinstance(mask, torch.Tensor): - raise TypeError( - f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not" - ) - else: - # preprocess image - if isinstance(image, (PIL.Image.Image, np.ndarray)): - image = [image] - - if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): - image = [np.array(i.convert("RGB"))[None, :] for i in image] - image = np.concatenate(image, axis=0) - elif isinstance(image, list) and isinstance(image[0], np.ndarray): - image = np.concatenate([i[None, :] for i in image], axis=0) - - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - - # preprocess mask - if isinstance(mask, (PIL.Image.Image, np.ndarray)): - mask = [mask] - - if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): - mask = np.concatenate( - [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 - ) - mask = mask.astype(np.float32) / 255.0 - elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): - mask = np.concatenate([m[None, None, :] for m in mask], axis=0) - - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = torch.from_numpy(mask) - - masked_image = image * (mask < 0.5) - - return mask, masked_image - - -class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline): - r""" - Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance. - - This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - controlnet ([`ControlNetModel`]): - Provides additional conditioning to the unet during the denoising process - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - @classmethod - def from_single_file(cls, pretrained_model_link_or_path, **kwargs): - from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( - download_from_original_stable_diffusion_ckpt, - ) - - controlnet = kwargs.pop("controlnet", None) - - pipe = download_from_original_stable_diffusion_ckpt( - pretrained_model_link_or_path, - num_in_channels=9, - from_safetensors=pretrained_model_link_or_path.endswith("safetensors"), - device="cpu", - load_safety_checker=False, - ) - - inpaint_pipe = cls( - vae=pipe.vae, - text_encoder=pipe.text_encoder, - tokenizer=pipe.tokenizer, - unet=pipe.unet, - controlnet=controlnet, - scheduler=pipe.scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) - - del pipe - gc.collect() - return inpaint_pipe - - def prepare_mask_latents( - self, - mask, - masked_image, - batch_size, - height, - width, - dtype, - device, - generator, - do_classifier_free_guidance, - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample( - generator=generator[i] - ) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample( - generator=generator - ) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask - masked_image_latents = ( - torch.cat([masked_image_latents] * 2) - if do_classifier_free_guidance - else masked_image_latents - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - return mask, masked_image_latents - - def _default_height_width(self, height, width, image): - if isinstance(image, list): - image = image[0] - - if height is None: - if isinstance(image, PIL.Image.Image): - height = image.height - elif isinstance(image, torch.Tensor): - height = image.shape[3] - - height = (height // 8) * 8 # round down to nearest multiple of 8 - - if width is None: - if isinstance(image, PIL.Image.Image): - width = image.width - elif isinstance(image, torch.Tensor): - width = image.shape[2] - - width = (width // 8) * 8 # round down to nearest multiple of 8 - - return height, width - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - control_image: Union[ - torch.FloatTensor, - PIL.Image.Image, - List[torch.FloatTensor], - List[PIL.Image.Image], - ] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: float = 1.0, - ): - r""" - Function invoked when calling the pipeline for generation. - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will - be masked out with `mask_image` and repainted according to `prompt`. - control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`): - The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If - the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can - also be accepted as an image. The control image is automatically resized to fit the output image. - mask_image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted - to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) - instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): - The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. - Examples: - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height, width = self._default_height_width(height, width, control_image) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt=prompt, - image=control_image, - callback_steps=callback_steps, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare image - control_image = self.prepare_image( - control_image, - width, - height, - batch_size * num_images_per_prompt, - num_images_per_prompt, - device, - self.controlnet.dtype, - ) - - if do_classifier_free_guidance: - control_image = torch.cat([control_image] * 2) - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.controlnet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # EXTRA: prepare mask latents - mask, masked_image = prepare_mask_and_masked_image(image, mask_image) - mask, masked_image_latents = self.prepare_mask_latents( - mask, - masked_image, - batch_size * num_images_per_prompt, - height, - width, - prompt_embeds.dtype, - device, - generator, - do_classifier_free_guidance, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents - ) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) - - down_block_res_samples, mid_block_res_sample = self.controlnet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=control_image, - return_dict=False, - ) - - down_block_res_samples = [ - down_block_res_sample * controlnet_conditioning_scale - for down_block_res_sample in down_block_res_samples - ] - mid_block_res_sample *= controlnet_conditioning_scale - - # predict the noise residual - latent_model_input = torch.cat( - [latent_model_input, mask, masked_image_latents], dim=1 - ) - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - ).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # If we do sequential model offloading, let's offload unet and controlnet - # manually for max memory savings - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.unet.to("cpu") - self.controlnet.to("cpu") - torch.cuda.empty_cache() - - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) - - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept - ) diff --git a/lama_cleaner/model/power_paint/__init__.py b/lama_cleaner/model/power_paint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lama_cleaner/model/power_paint/pipeline_powerpaint.py b/lama_cleaner/model/power_paint/pipeline_powerpaint.py new file mode 100644 index 0000000..9b7f8c5 --- /dev/null +++ b/lama_cleaner/model/power_paint/pipeline_powerpaint.py @@ -0,0 +1,1243 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + LoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import ( + AsymmetricAutoencoderKL, + AutoencoderKL, + UNet2DConditionModel, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_mask_and_masked_image( + image, mask, height, width, return_image: bool = False +): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError( + f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not" + ) + + # Batch single image + if image.ndim == 3: + assert ( + image.shape[0] == 3 + ), "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert ( + image.ndim == 4 and mask.ndim == 4 + ), "Image and Mask must have 4 dimensions" + assert ( + image.shape[-2:] == mask.shape[-2:] + ), "Image and Mask must have the same spatial dimensions" + assert ( + image.shape[0] == mask.shape[0] + ), "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError( + f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not" + ) + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [ + i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image + ] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate( + [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 + ) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + + +class StableDiffusionInpaintPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: Union[AutoencoderKL, AsymmetricAutoencoderKL], + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if ( + hasattr(scheduler.config, "skip_prk_steps") + and scheduler.config.skip_prk_steps is False + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate( + "skip_prk_steps not set", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 + if unet.config.in_channels != 9: + logger.info( + f"You have loaded a UNet with {unet.config.in_channels} input channels which." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a + time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. + Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the + iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError( + "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher." + ) + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook( + cpu_offloaded_model, device, prev_module_hook=hook + ) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook( + self.safety_checker, device, prev_module_hook=hook + ) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + promptA, + promptB, + t, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_promptA=None, + negative_promptB=None, + t_nag=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + prompt = promptA + negative_prompt = negative_promptA + + if promptA is not None and isinstance(promptA, str): + batch_size = 1 + elif promptA is not None and isinstance(promptA, list): + batch_size = len(promptA) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + promptA = self.maybe_convert_prompt(promptA, self.tokenizer) + + text_inputsA = self.tokenizer( + promptA, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_inputsB = self.tokenizer( + promptB, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_idsA = text_inputsA.input_ids + text_input_idsB = text_inputsB.input_ids + untruncated_ids = self.tokenizer( + promptA, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_idsA.shape[ + -1 + ] and not torch.equal(text_input_idsA, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputsA.attention_mask.to(device) + else: + attention_mask = None + + # print("text_input_idsA: ",text_input_idsA) + # print("text_input_idsB: ",text_input_idsB) + # print('t: ',t) + + prompt_embedsA = self.text_encoder( + text_input_idsA.to(device), + attention_mask=attention_mask, + ) + prompt_embedsA = prompt_embedsA[0] + + prompt_embedsB = self.text_encoder( + text_input_idsB.to(device), + attention_mask=attention_mask, + ) + prompt_embedsB = prompt_embedsB[0] + prompt_embeds = prompt_embedsA * (t) + (1 - t) * prompt_embedsB + # print("prompt_embeds: ",prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokensA: List[str] + uncond_tokensB: List[str] + if negative_prompt is None: + uncond_tokensA = [""] * batch_size + uncond_tokensB = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokensA = [negative_promptA] + uncond_tokensB = [negative_promptB] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokensA = negative_promptA + uncond_tokensB = negative_promptB + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokensA = self.maybe_convert_prompt( + uncond_tokensA, self.tokenizer + ) + uncond_tokensB = self.maybe_convert_prompt( + uncond_tokensB, self.tokenizer + ) + + max_length = prompt_embeds.shape[1] + uncond_inputA = self.tokenizer( + uncond_tokensA, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_inputB = self.tokenizer( + uncond_tokensB, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_inputA.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embedsA = self.text_encoder( + uncond_inputA.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embedsB = self.text_encoder( + uncond_inputB.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = ( + negative_prompt_embedsA[0] * (t_nag) + + (1 - t_nag) * negative_prompt_embedsB[0] + ) + + # negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # print("prompt_embeds: ",prompt_embeds) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if strength < 0 or strength > 1: + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = ( + noise + if is_strength_max + else self.scheduler.add_noise(image_latents, noise, timestep) + ) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = ( + latents * self.scheduler.init_noise_sigma + if is_strength_max + else latents + ) + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample( + generator=generator[i] + ) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample( + generator=generator + ) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) + if do_classifier_free_guidance + else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__( + self, + promptA: Union[str, List[str]] = None, + promptB: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + tradoff: float = 1.0, + tradoff_nag: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_promptA: Optional[Union[str, List[str]]] = None, + negative_promptB: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + task_class: Union[torch.Tensor, float, int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked + out with `mask_image` and repainted according to `prompt`). + mask_image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted + while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel + (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the + expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipe = StableDiffusionInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + prompt = promptA + negative_prompt = negative_promptA + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) + if cross_attention_kwargs is not None + else None + ) + prompt_embeds = self._encode_prompt( + promptA, + promptB, + tradoff, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_promptA, + negative_promptB, + tradoff_nag, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + mask_condition = mask.clone() + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if ( + num_channels_latents + num_channels_mask + num_channels_masked_image + != self.unet.config.in_channels + ): + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + if num_channels_unet == 9: + latent_model_input = torch.cat( + [latent_model_input, mask, masked_image_latents], dim=1 + ) + + # predict the noise residual + if task_class is not None: + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + task_class=task_class, + )[0] + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = ( + 1 - init_mask + ) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + condition_kwargs = {} + if isinstance(self.vae, AsymmetricAutoencoderKL): + init_image = init_image.to( + device=device, dtype=masked_image_latents.dtype + ) + init_image_condition = init_image.clone() + init_image = self._encode_vae_image(init_image, generator=generator) + mask_condition = mask_condition.to( + device=device, dtype=masked_image_latents.dtype + ) + condition_kwargs = { + "image": init_image_condition, + "mask": mask_condition, + } + image = self.vae.decode( + latents / self.vae.config.scaling_factor, + return_dict=False, + **condition_kwargs, + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/lama_cleaner/model/power_paint/pipeline_powerpaint_controlnet.py b/lama_cleaner/model/power_paint/pipeline_powerpaint_controlnet.py new file mode 100644 index 0000000..cba0f8f --- /dev/null +++ b/lama_cleaner/model/power_paint/pipeline_powerpaint_controlnet.py @@ -0,0 +1,1775 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor,is_compiled_module +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.pipelines.controlnet import MultiControlNetModel + + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + + + >>> def make_inpaint_condition(image, image_mask): + ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 + + ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + ... image[image_mask > 0.5] = -1.0 # set as masked pixel + ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + ... image = torch.from_numpy(image) + ... return image + + + >>> control_image = make_inpaint_condition(init_image, mask_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image +def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + + +class StableDiffusionControlNetInpaintPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + + + This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as + [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) + as well as default text-to-image stable diffusion checkpoints, such as + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). + Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on + those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + promptA, + promptB, + t, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_promptA=None, + negative_promptB=None, + t_nag = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + prompt = promptA + negative_prompt = negative_promptA + + if promptA is not None and isinstance(promptA, str): + batch_size = 1 + elif promptA is not None and isinstance(promptA, list): + batch_size = len(promptA) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + promptA = self.maybe_convert_prompt(promptA, self.tokenizer) + + text_inputsA = self.tokenizer( + promptA, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_inputsB = self.tokenizer( + promptB, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_idsA = text_inputsA.input_ids + text_input_idsB = text_inputsB.input_ids + untruncated_ids = self.tokenizer(promptA, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_idsA.shape[-1] and not torch.equal( + text_input_idsA, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputsA.attention_mask.to(device) + else: + attention_mask = None + + # print("text_input_idsA: ",text_input_idsA) + # print("text_input_idsB: ",text_input_idsB) + # print('t: ',t) + + prompt_embedsA = self.text_encoder( + text_input_idsA.to(device), + attention_mask=attention_mask, + ) + prompt_embedsA = prompt_embedsA[0] + + prompt_embedsB = self.text_encoder( + text_input_idsB.to(device), + attention_mask=attention_mask, + ) + prompt_embedsB = prompt_embedsB[0] + prompt_embeds = prompt_embedsA*(t)+(1-t)*prompt_embedsB + # print("prompt_embeds: ",prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokensA: List[str] + uncond_tokensB: List[str] + if negative_prompt is None: + uncond_tokensA = [""] * batch_size + uncond_tokensB = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokensA = [negative_promptA] + uncond_tokensB = [negative_promptB] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokensA = negative_promptA + uncond_tokensB = negative_promptB + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokensA = self.maybe_convert_prompt(uncond_tokensA, self.tokenizer) + uncond_tokensB = self.maybe_convert_prompt(uncond_tokensB, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_inputA = self.tokenizer( + uncond_tokensA, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_inputB = self.tokenizer( + uncond_tokensB, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_inputA.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embedsA = self.text_encoder( + uncond_inputA.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embedsB = self.text_encoder( + uncond_inputB.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embedsA[0]*(t_nag)+(1-t_nag)*negative_prompt_embedsB[0] + + # negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # print("prompt_embeds: ",prompt_embeds) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + @torch.no_grad() + def predict_woControl( + self, + promptA: Union[str, List[str]] = None, + promptB: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + tradoff: float = 1.0, + tradoff_nag: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_promptA: Optional[Union[str, List[str]]] = None, + negative_promptB: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + task_class: Union[torch.Tensor, float, int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked + out with `mask_image` and repainted according to `prompt`). + mask_image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted + while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel + (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the + expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipe = StableDiffusionInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + prompt = promptA + negative_prompt = negative_promptA + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + promptA, + promptB, + tradoff, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_promptA, + negative_promptB, + tradoff_nag, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + mask_condition = mask.clone() + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + if task_class is not None: + noise_pred = self.unet( + sample = latent_model_input, + timestep = t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + task_class = task_class, + )[0] + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + condition_kwargs = {} + if isinstance(self.vae, AsymmetricAutoencoderKL): + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image_condition = init_image.clone() + init_image = self._encode_vae_image(init_image, generator=generator) + mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) + condition_kwargs = {"image": init_image_condition, "mask": mask_condition} + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + promptA: Union[str, List[str]] = None, + promptB: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image] = None, + mask_image: Union[torch.Tensor, PIL.Image.Image] = None, + control_image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + tradoff: float = 1.0, + tradoff_nag: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_promptA: Optional[Union[str, List[str]]] = None, + negative_promptB: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.5, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting + than for [`~StableDiffusionControlNetPipeline.__call__`]. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + + prompt = promptA + negative_prompt = negative_promptA + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + promptA, + promptB, + tradoff, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_promptA, + negative_promptB, + tradoff_nag, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4. Preprocess mask and image - resizes image and mask w.r.t height and width + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/lama_cleaner/model/power_paint/power_paint.py b/lama_cleaner/model/power_paint/power_paint.py new file mode 100644 index 0000000..014f403 --- /dev/null +++ b/lama_cleaner/model/power_paint/power_paint.py @@ -0,0 +1,96 @@ +from PIL import Image +import PIL.Image +import cv2 +import torch +from loguru import logger + +from lama_cleaner.model.base import DiffusionInpaintModel +from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper +from lama_cleaner.model.utils import handle_from_pretrained_exceptions +from lama_cleaner.schema import Config +from .powerpaint_tokenizer import add_task_to_prompt + + +class PowerPaint(DiffusionInpaintModel): + name = "Sanster/PowerPaint-V1-stable-diffusion-inpainting" + pad_mod = 8 + min_size = 512 + lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5" + + def init_model(self, device: torch.device, **kwargs): + from .pipeline_powerpaint import StableDiffusionInpaintPipeline + from .powerpaint_tokenizer import PowerPaintTokenizer + + fp16 = not kwargs.get("no_half", False) + model_kwargs = {} + if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) + + use_gpu = device == torch.device("cuda") and torch.cuda.is_available() + torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + + self.model = handle_from_pretrained_exceptions( + StableDiffusionInpaintPipeline.from_pretrained, + pretrained_model_name_or_path=self.name, + variant="fp16", + torch_dtype=torch_dtype, + **model_kwargs, + ) + self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer) + + if kwargs.get("cpu_offload", False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs["sd_cpu_textencoder"]: + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper( + self.model.text_encoder, torch_dtype + ) + + self.callback = kwargs.pop("callback", None) + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + self.set_scheduler(config) + + img_h, img_w = image.shape[:2] + promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt( + config.prompt, config.negative_prompt, config.powerpaint_task + ) + + output = self.model( + image=PIL.Image.fromarray(image), + promptA=promptA, + promptB=promptB, + tradoff=config.fitting_degree, + tradoff_nag=config.fitting_degree, + negative_promptA=negative_promptA, + negative_promptB=negative_promptB, + mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), + num_inference_steps=config.sd_steps, + strength=config.sd_strength, + guidance_scale=config.sd_guidance_scale, + output_type="np", + callback=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + callback_steps=1, + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output diff --git a/lama_cleaner/model/power_paint/powerpaint_tokenizer.py b/lama_cleaner/model/power_paint/powerpaint_tokenizer.py new file mode 100644 index 0000000..887ba29 --- /dev/null +++ b/lama_cleaner/model/power_paint/powerpaint_tokenizer.py @@ -0,0 +1,540 @@ +import torch +import torch.nn as nn +import copy +import random +from typing import Any, List, Optional, Union +from transformers import CLIPTokenizer + +from lama_cleaner.schema import PowerPaintTask + + +def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask): + if task == PowerPaintTask.object_remove: + promptA = prompt + " P_ctxt" + promptB = prompt + " P_ctxt" + negative_promptA = negative_prompt + " P_obj" + negative_promptB = negative_prompt + " P_obj" + elif task == PowerPaintTask.shape_guided: + promptA = prompt + " P_shape" + promptB = prompt + " P_ctxt" + negative_promptA = negative_prompt + negative_promptB = negative_prompt + elif task == PowerPaintTask.outpainting: + promptA = prompt + " P_ctxt" + promptB = prompt + " P_ctxt" + negative_promptA = negative_prompt + " P_obj" + negative_promptB = negative_prompt + " P_obj" + else: + promptA = prompt + " P_obj" + promptB = prompt + " P_obj" + negative_promptA = negative_prompt + negative_promptB = negative_prompt + + return promptA, promptB, negative_promptA, negative_promptB + + +class PowerPaintTokenizer: + def __init__(self, tokenizer: CLIPTokenizer): + self.wrapped = tokenizer + self.token_map = {} + placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"] + num_vec_per_token = 10 + for placeholder_token in placeholder_tokens: + output = [] + for i in range(num_vec_per_token): + ith_token = placeholder_token + f"_{i}" + output.append(ith_token) + self.token_map[placeholder_token] = output + + def __getattr__(self, name: str) -> Any: + if name == "wrapped": + return super().__getattr__("wrapped") + + try: + return getattr(self.wrapped, name) + except AttributeError: + try: + return super().__getattr__(name) + except AttributeError: + raise AttributeError( + "'name' cannot be found in both " + f"'{self.__class__.__name__}' and " + f"'{self.__class__.__name__}.tokenizer'." + ) + + def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs): + """Attempt to add tokens to the tokenizer. + + Args: + tokens (Union[str, List[str]]): The tokens to be added. + """ + num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs) + assert num_added_tokens != 0, ( + f"The tokenizer already contains the token {tokens}. Please pass " + "a different `placeholder_token` that is not already in the " + "tokenizer." + ) + + def get_token_info(self, token: str) -> dict: + """Get the information of a token, including its start and end index in + the current tokenizer. + + Args: + token (str): The token to be queried. + + Returns: + dict: The information of the token, including its start and end + index in current tokenizer. + """ + token_ids = self.__call__(token).input_ids + start, end = token_ids[1], token_ids[-2] + 1 + return {"name": token, "start": start, "end": end} + + def add_placeholder_token( + self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs + ): + """Add placeholder tokens to the tokenizer. + + Args: + placeholder_token (str): The placeholder token to be added. + num_vec_per_token (int, optional): The number of vectors of + the added placeholder token. + *args, **kwargs: The arguments for `self.wrapped.add_tokens`. + """ + output = [] + if num_vec_per_token == 1: + self.try_adding_tokens(placeholder_token, *args, **kwargs) + output.append(placeholder_token) + else: + output = [] + for i in range(num_vec_per_token): + ith_token = placeholder_token + f"_{i}" + self.try_adding_tokens(ith_token, *args, **kwargs) + output.append(ith_token) + + for token in self.token_map: + if token in placeholder_token: + raise ValueError( + f"The tokenizer already has placeholder token {token} " + f"that can get confused with {placeholder_token} " + "keep placeholder tokens independent" + ) + self.token_map[placeholder_token] = output + + def replace_placeholder_tokens_in_text( + self, + text: Union[str, List[str]], + vector_shuffle: bool = False, + prop_tokens_to_load: float = 1.0, + ) -> Union[str, List[str]]: + """Replace the keywords in text with placeholder tokens. This function + will be called in `self.__call__` and `self.encode`. + + Args: + text (Union[str, List[str]]): The text to be processed. + vector_shuffle (bool, optional): Whether to shuffle the vectors. + Defaults to False. + prop_tokens_to_load (float, optional): The proportion of tokens to + be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0. + + Returns: + Union[str, List[str]]: The processed text. + """ + if isinstance(text, list): + output = [] + for i in range(len(text)): + output.append( + self.replace_placeholder_tokens_in_text( + text[i], vector_shuffle=vector_shuffle + ) + ) + return output + + for placeholder_token in self.token_map: + if placeholder_token in text: + tokens = self.token_map[placeholder_token] + tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)] + if vector_shuffle: + tokens = copy.copy(tokens) + random.shuffle(tokens) + text = text.replace(placeholder_token, " ".join(tokens)) + return text + + def replace_text_with_placeholder_tokens( + self, text: Union[str, List[str]] + ) -> Union[str, List[str]]: + """Replace the placeholder tokens in text with the original keywords. + This function will be called in `self.decode`. + + Args: + text (Union[str, List[str]]): The text to be processed. + + Returns: + Union[str, List[str]]: The processed text. + """ + if isinstance(text, list): + output = [] + for i in range(len(text)): + output.append(self.replace_text_with_placeholder_tokens(text[i])) + return output + + for placeholder_token, tokens in self.token_map.items(): + merged_tokens = " ".join(tokens) + if merged_tokens in text: + text = text.replace(merged_tokens, placeholder_token) + return text + + def __call__( + self, + text: Union[str, List[str]], + *args, + vector_shuffle: bool = False, + prop_tokens_to_load: float = 1.0, + **kwargs, + ): + """The call function of the wrapper. + + Args: + text (Union[str, List[str]]): The text to be tokenized. + vector_shuffle (bool, optional): Whether to shuffle the vectors. + Defaults to False. + prop_tokens_to_load (float, optional): The proportion of tokens to + be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0 + *args, **kwargs: The arguments for `self.wrapped.__call__`. + """ + replaced_text = self.replace_placeholder_tokens_in_text( + text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load + ) + + return self.wrapped.__call__(replaced_text, *args, **kwargs) + + def encode(self, text: Union[str, List[str]], *args, **kwargs): + """Encode the passed text to token index. + + Args: + text (Union[str, List[str]]): The text to be encode. + *args, **kwargs: The arguments for `self.wrapped.__call__`. + """ + replaced_text = self.replace_placeholder_tokens_in_text(text) + return self.wrapped(replaced_text, *args, **kwargs) + + def decode( + self, token_ids, return_raw: bool = False, *args, **kwargs + ) -> Union[str, List[str]]: + """Decode the token index to text. + + Args: + token_ids: The token index to be decoded. + return_raw: Whether keep the placeholder token in the text. + Defaults to False. + *args, **kwargs: The arguments for `self.wrapped.decode`. + + Returns: + Union[str, List[str]]: The decoded text. + """ + text = self.wrapped.decode(token_ids, *args, **kwargs) + if return_raw: + return text + replaced_text = self.replace_text_with_placeholder_tokens(text) + return replaced_text + + +class EmbeddingLayerWithFixes(nn.Module): + """The revised embedding layer to support external embeddings. This design + of this class is inspired by https://github.com/AUTOMATIC1111/stable- + diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi + jack.py#L224 # noqa. + + Args: + wrapped (nn.Emebdding): The embedding layer to be wrapped. + external_embeddings (Union[dict, List[dict]], optional): The external + embeddings added to this layer. Defaults to None. + """ + + def __init__( + self, + wrapped: nn.Embedding, + external_embeddings: Optional[Union[dict, List[dict]]] = None, + ): + super().__init__() + self.wrapped = wrapped + self.num_embeddings = wrapped.weight.shape[0] + + self.external_embeddings = [] + if external_embeddings: + self.add_embeddings(external_embeddings) + + self.trainable_embeddings = nn.ParameterDict() + + @property + def weight(self): + """Get the weight of wrapped embedding layer.""" + return self.wrapped.weight + + def check_duplicate_names(self, embeddings: List[dict]): + """Check whether duplicate names exist in list of 'external + embeddings'. + + Args: + embeddings (List[dict]): A list of embedding to be check. + """ + names = [emb["name"] for emb in embeddings] + assert len(names) == len(set(names)), ( + "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'" + ) + + def check_ids_overlap(self, embeddings): + """Check whether overlap exist in token ids of 'external_embeddings'. + + Args: + embeddings (List[dict]): A list of embedding to be check. + """ + ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings] + ids_range.sort() # sort by 'start' + # check if 'end' has overlapping + for idx in range(len(ids_range) - 1): + name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1] + assert ids_range[idx][1] <= ids_range[idx + 1][0], ( + f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'." + ) + + def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]): + """Add external embeddings to this layer. + + Use case: + + >>> 1. Add token to tokenizer and get the token id. + >>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32') + >>> # 'how much' in kiswahili + >>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4) + >>> + >>> 2. Add external embeddings to the model. + >>> new_embedding = { + >>> 'name': 'ngapi', # 'how much' in kiswahili + >>> 'embedding': torch.ones(1, 15) * 4, + >>> 'start': tokenizer.get_token_info('kwaheri')['start'], + >>> 'end': tokenizer.get_token_info('kwaheri')['end'], + >>> 'trainable': False # if True, will registry as a parameter + >>> } + >>> embedding_layer = nn.Embedding(10, 15) + >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer) + >>> embedding_layer_wrapper.add_embeddings(new_embedding) + >>> + >>> 3. Forward tokenizer and embedding layer! + >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?'] + >>> input_ids = tokenizer( + >>> input_text, padding='max_length', truncation=True, + >>> return_tensors='pt')['input_ids'] + >>> out_feat = embedding_layer_wrapper(input_ids) + >>> + >>> 4. Let's validate the result! + >>> assert (out_feat[0, 3: 7] == 2.3).all() + >>> assert (out_feat[2, 5: 9] == 2.3).all() + + Args: + embeddings (Union[dict, list[dict]]): The external embeddings to + be added. Each dict must contain the following 4 fields: 'name' + (the name of this embedding), 'embedding' (the embedding + tensor), 'start' (the start token id of this embedding), 'end' + (the end token id of this embedding). For example: + `{name: NAME, start: START, end: END, embedding: torch.Tensor}` + """ + if isinstance(embeddings, dict): + embeddings = [embeddings] + + self.external_embeddings += embeddings + self.check_duplicate_names(self.external_embeddings) + self.check_ids_overlap(self.external_embeddings) + + # set for trainable + added_trainable_emb_info = [] + for embedding in embeddings: + trainable = embedding.get("trainable", False) + if trainable: + name = embedding["name"] + embedding["embedding"] = torch.nn.Parameter(embedding["embedding"]) + self.trainable_embeddings[name] = embedding["embedding"] + added_trainable_emb_info.append(name) + + added_emb_info = [emb["name"] for emb in embeddings] + added_emb_info = ", ".join(added_emb_info) + print(f"Successfully add external embeddings: {added_emb_info}.", "current") + + if added_trainable_emb_info: + added_trainable_emb_info = ", ".join(added_trainable_emb_info) + print( + "Successfully add trainable external embeddings: " + f"{added_trainable_emb_info}", + "current", + ) + + def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """Replace external input ids to 0. + + Args: + input_ids (torch.Tensor): The input ids to be replaced. + + Returns: + torch.Tensor: The replaced input ids. + """ + input_ids_fwd = input_ids.clone() + input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0 + return input_ids_fwd + + def replace_embeddings( + self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict + ) -> torch.Tensor: + """Replace external embedding to the embedding layer. Noted that, in + this function we use `torch.cat` to avoid inplace modification. + + Args: + input_ids (torch.Tensor): The original token ids. Shape like + [LENGTH, ]. + embedding (torch.Tensor): The embedding of token ids after + `replace_input_ids` function. + external_embedding (dict): The external embedding to be replaced. + + Returns: + torch.Tensor: The replaced embedding. + """ + new_embedding = [] + + name = external_embedding["name"] + start = external_embedding["start"] + end = external_embedding["end"] + target_ids_to_replace = [i for i in range(start, end)] + ext_emb = external_embedding["embedding"] + + # do not need to replace + if not (input_ids == start).any(): + return embedding + + # start replace + s_idx, e_idx = 0, 0 + while e_idx < len(input_ids): + if input_ids[e_idx] == start: + if e_idx != 0: + # add embedding do not need to replace + new_embedding.append(embedding[s_idx:e_idx]) + + # check if the next embedding need to replace is valid + actually_ids_to_replace = [ + int(i) for i in input_ids[e_idx : e_idx + end - start] + ] + assert actually_ids_to_replace == target_ids_to_replace, ( + f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. " + f"Expect '{target_ids_to_replace}' for embedding " + f"'{name}' but found '{actually_ids_to_replace}'." + ) + + new_embedding.append(ext_emb) + + s_idx = e_idx + end - start + e_idx = s_idx + 1 + else: + e_idx += 1 + + if e_idx == len(input_ids): + new_embedding.append(embedding[s_idx:e_idx]) + + return torch.cat(new_embedding, dim=0) + + def forward( + self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None + ): + """The forward function. + + Args: + input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or + [LENGTH, ]. + external_embeddings (Optional[List[dict]]): The external + embeddings. If not passed, only `self.external_embeddings` + will be used. Defaults to None. + + input_ids: shape like [bz, LENGTH] or [LENGTH]. + """ + assert input_ids.ndim in [1, 2] + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + if external_embeddings is None and not self.external_embeddings: + return self.wrapped(input_ids) + + input_ids_fwd = self.replace_input_ids(input_ids) + inputs_embeds = self.wrapped(input_ids_fwd) + + vecs = [] + + if external_embeddings is None: + external_embeddings = [] + elif isinstance(external_embeddings, dict): + external_embeddings = [external_embeddings] + embeddings = self.external_embeddings + external_embeddings + + for input_id, embedding in zip(input_ids, inputs_embeds): + new_embedding = embedding + for external_embedding in embeddings: + new_embedding = self.replace_embeddings( + input_id, new_embedding, external_embedding + ) + vecs.append(new_embedding) + + return torch.stack(vecs) + + +def add_tokens( + tokenizer, + text_encoder, + placeholder_tokens: list, + initialize_tokens: list = None, + num_vectors_per_token: int = 1, +): + """Add token for training. + + # TODO: support add tokens as dict, then we can load pretrained tokens. + """ + if initialize_tokens is not None: + assert len(initialize_tokens) == len( + placeholder_tokens + ), "placeholder_token should be the same length as initialize_token" + for ii in range(len(placeholder_tokens)): + tokenizer.add_placeholder_token( + placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token + ) + + # text_encoder.set_embedding_layer() + embedding_layer = text_encoder.text_model.embeddings.token_embedding + text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes( + embedding_layer + ) + embedding_layer = text_encoder.text_model.embeddings.token_embedding + + assert embedding_layer is not None, ( + "Do not support get embedding layer for current text encoder. " + "Please check your configuration." + ) + initialize_embedding = [] + if initialize_tokens is not None: + for ii in range(len(placeholder_tokens)): + init_id = tokenizer(initialize_tokens[ii]).input_ids[1] + temp_embedding = embedding_layer.weight[init_id] + initialize_embedding.append( + temp_embedding[None, ...].repeat(num_vectors_per_token, 1) + ) + else: + for ii in range(len(placeholder_tokens)): + init_id = tokenizer("a").input_ids[1] + temp_embedding = embedding_layer.weight[init_id] + len_emb = temp_embedding.shape[0] + init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0 + initialize_embedding.append(init_weight) + + # initialize_embedding = torch.cat(initialize_embedding,dim=0) + + token_info_all = [] + for ii in range(len(placeholder_tokens)): + token_info = tokenizer.get_token_info(placeholder_tokens[ii]) + token_info["embedding"] = initialize_embedding[ii] + token_info["trainable"] = True + token_info_all.append(token_info) + embedding_layer.add_embeddings(token_info_all) diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index 56134e2..1c7fcdd 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -3,9 +3,9 @@ import cv2 import torch from loguru import logger -from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper +from lama_cleaner.model.utils import handle_from_pretrained_exceptions from lama_cleaner.schema import Config, ModelType @@ -40,20 +40,18 @@ class SD(DiffusionInpaintModel): model_kwargs["num_in_channels"] = 9 self.model = StableDiffusionInpaintPipeline.from_single_file( - self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs + self.model_id_or_path, dtype=torch_dtype, **model_kwargs ) else: - self.model = StableDiffusionInpaintPipeline.from_pretrained( - self.model_id_or_path, - revision="fp16" - if self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION - else "main", - torch_dtype=torch_dtype, + self.model = handle_from_pretrained_exceptions( + StableDiffusionInpaintPipeline.from_pretrained, + pretrained_model_name_or_path=self.model_id_or_path, + variant="fp16", + dtype=torch_dtype, **model_kwargs, ) if kwargs.get("cpu_offload", False) and use_gpu: - # TODO: gpu_id logger.info("Enable sequential cpu offload") self.model.enable_sequential_cpu_offload(gpu_id=0) else: @@ -98,20 +96,20 @@ class SD(DiffusionInpaintModel): class SD15(SD): - name = "sd1.5" + name = "runwayml/stable-diffusion-inpainting" model_id_or_path = "runwayml/stable-diffusion-inpainting" class Anything4(SD): - name = "anything4" + name = "Sanster/anything-4.0-inpainting" model_id_or_path = "Sanster/anything-4.0-inpainting" class RealisticVision14(SD): - name = "realisticVision1.4" + name = "Sanster/Realistic_Vision_V1.4-inpainting" model_id_or_path = "Sanster/Realistic_Vision_V1.4-inpainting" class SD2(SD): - name = "sd2" + name = "stabilityai/stable-diffusion-2-inpainting" model_id_or_path = "stabilityai/stable-diffusion-2-inpainting" diff --git a/lama_cleaner/model/sdxl.py b/lama_cleaner/model/sdxl.py index 428c670..b25451f 100644 --- a/lama_cleaner/model/sdxl.py +++ b/lama_cleaner/model/sdxl.py @@ -8,11 +8,12 @@ from diffusers import AutoencoderKL from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel +from lama_cleaner.model.utils import handle_from_pretrained_exceptions from lama_cleaner.schema import Config, ModelType class SDXL(DiffusionInpaintModel): - name = "sdxl" + name = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1" pad_mod = 8 min_size = 512 lcm_lora_id = "latent-consistency/lcm-lora-sdxl" @@ -34,18 +35,19 @@ class SDXL(DiffusionInpaintModel): if os.path.isfile(self.model_id_or_path): self.model = StableDiffusionXLInpaintPipeline.from_single_file( self.model_id_or_path, - torch_dtype=torch_dtype, + dtype=torch_dtype, num_in_channels=num_in_channels, ) else: vae = AutoencoderKL.from_pretrained( "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype ) - self.model = StableDiffusionXLInpaintPipeline.from_pretrained( - self.model_id_or_path, - revision="main", + self.model = handle_from_pretrained_exceptions( + StableDiffusionXLInpaintPipeline.from_pretrained, + pretrained_model_name_or_path=self.model_id_or_path, torch_dtype=torch_dtype, vae=vae, + variant="fp16", ) if kwargs.get("cpu_offload", False) and use_gpu: diff --git a/lama_cleaner/model/utils.py b/lama_cleaner/model/utils.py index 7cbea95..18fcf6c 100644 --- a/lama_cleaner/model/utils.py +++ b/lama_cleaner/model/utils.py @@ -1,6 +1,7 @@ import gc import math import random +import traceback from typing import Any import torch @@ -16,8 +17,11 @@ from diffusers import ( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler, - LCMScheduler + LCMScheduler, ) +from huggingface_hub.utils import RevisionNotFoundError +from loguru import logger +from requests import HTTPError from lama_cleaner.schema import SDSampler from torch import conv2d, conv_transpose2d @@ -944,3 +948,20 @@ def get_scheduler(sd_sampler, scheduler_config): return LCMScheduler.from_config(scheduler_config) else: raise ValueError(sd_sampler) + + +def handle_from_pretrained_exceptions(func, **kwargs): + try: + return func(**kwargs) + except ValueError as e: + # 处理异常的逻辑 + if "You are trying to load the model files of the `variant=fp16`" in str(e): + logger.info("variant=fp16 not found, try revision=fp16") + return func(**{**kwargs, "variant": None, "revision": "fp16"}) + except OSError as e: + previous_traceback = traceback.format_exc() + if "RevisionNotFoundError: 404 Client Error." in previous_traceback: + logger.info("revision=fp16 not found, try revision=main") + return func(**{**kwargs, "variant": None, "revision": "main"}) + except Exception as e: + raise e diff --git a/lama_cleaner/model_info.py b/lama_cleaner/model_info.py new file mode 100644 index 0000000..199018b --- /dev/null +++ b/lama_cleaner/model_info.py @@ -0,0 +1,100 @@ +from enum import Enum +from typing import List + +from pydantic import computed_field, BaseModel + +from lama_cleaner.const import ( + SDXL_CONTROLNET_CHOICES, + SD2_CONTROLNET_CHOICES, + SD_CONTROLNET_CHOICES, +) +from lama_cleaner.model import InstructPix2Pix, Kandinsky22, PowerPaint, SD2 +from lama_cleaner.schema import ModelType + + +class ModelInfo(BaseModel): + name: str + path: str + model_type: ModelType + is_single_file_diffusers: bool = False + + @computed_field + @property + def need_prompt(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [ + InstructPix2Pix.name, + Kandinsky22.name, + PowerPaint.name, + ] + + @computed_field + @property + def controlnets(self) -> List[str]: + if self.model_type in [ + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SDXL_INPAINT, + ]: + return SDXL_CONTROLNET_CHOICES + if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]: + if self.name in [SD2.name]: + return SD2_CONTROLNET_CHOICES + else: + return SD_CONTROLNET_CHOICES + if self.name == PowerPaint.name: + return SD_CONTROLNET_CHOICES + return [] + + @computed_field + @property + def support_strength(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] + + @computed_field + @property + def support_outpainting(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [Kandinsky22.name, PowerPaint.name] + + @computed_field + @property + def support_lcm_lora(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] + + @computed_field + @property + def support_controlnet(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [PowerPaint.name] + + @computed_field + @property + def support_freeu(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [InstructPix2Pix.name] diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 01f85a3..f7fcfa0 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -7,7 +7,8 @@ from lama_cleaner.download import scan_models from lama_cleaner.helper import switch_mps_device from lama_cleaner.model import models, ControlNet, SD, SDXL from lama_cleaner.model.utils import torch_gc -from lama_cleaner.schema import Config, ModelInfo, ModelType +from lama_cleaner.model_info import ModelInfo, ModelType +from lama_cleaner.schema import Config class ModelManager: @@ -18,13 +19,20 @@ class ModelManager: self.available_models: Dict[str, ModelInfo] = {} self.scan_models() - self.sd_controlnet = False - self.sd_controlnet_method = "" + self.enable_controlnet = kwargs.get("enable_controlnet", False) + controlnet_method = kwargs.get("controlnet_method", None) + if ( + controlnet_method is None + and name in self.available_models + and self.available_models[name].support_controlnet + ): + controlnet_method = self.available_models[name].controlnets[0] + self.controlnet_method = controlnet_method self.model = self.init_model(name, device, **kwargs) @property def current_model(self) -> Dict: - return self.available_models[name].model_dump() + return self.available_models[self.name].model_dump() def init_model(self, name: str, device, **kwargs): logger.info(f"Loading model: {name}") @@ -35,15 +43,14 @@ class ModelManager: kwargs = { **kwargs, "model_info": model_info, - "sd_controlnet": self.sd_controlnet, - "sd_controlnet_method": self.sd_controlnet_method, + "enable_controlnet": self.enable_controlnet, + "controlnet_method": self.controlnet_method, } - if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]: - return models[name](device, **kwargs) - - if self.sd_controlnet: + if model_info.support_controlnet and self.enable_controlnet: return ControlNet(device, **kwargs) + elif model_info.name in models: + return models[name](device, **kwargs) else: if model_info.model_type in [ ModelType.DIFFUSERS_SD_INPAINT, @@ -75,15 +82,15 @@ class ModelManager: return old_name = self.name - old_sd_controlnet_method = self.sd_controlnet_method + old_controlnet_method = self.controlnet_method self.name = new_name if ( self.available_models[new_name].support_controlnet - and self.sd_controlnet_method + and self.controlnet_method not in self.available_models[new_name].controlnets ): - self.sd_controlnet_method = self.available_models[new_name].controlnets[0] + self.controlnet_method = self.available_models[new_name].controlnets[0] try: # TODO: enable/disable controlnet without reload model del self.model @@ -94,7 +101,7 @@ class ModelManager: ) except Exception as e: self.name = old_name - self.sd_controlnet_method = old_sd_controlnet_method + self.controlnet_method = old_controlnet_method logger.info(f"Switch model from {old_name} to {new_name} failed, rollback") self.model = self.init_model( old_name, switch_mps_device(old_name, self.device), **self.kwargs @@ -106,24 +113,24 @@ class ModelManager: return if ( - self.sd_controlnet + self.enable_controlnet and config.controlnet_method - and self.sd_controlnet_method != config.controlnet_method + and self.controlnet_method != config.controlnet_method ): - old_sd_controlnet_method = self.sd_controlnet_method - self.sd_controlnet_method = config.controlnet_method + old_controlnet_method = self.controlnet_method + self.controlnet_method = config.controlnet_method self.model.switch_controlnet_method(config.controlnet_method) logger.info( - f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}" + f"Switch Controlnet method from {old_controlnet_method} to {config.controlnet_method}" ) - elif self.sd_controlnet != config.controlnet_enabled: - self.sd_controlnet = config.controlnet_enabled - self.sd_controlnet_method = config.controlnet_method + elif self.enable_controlnet != config.enable_controlnet: + self.enable_controlnet = config.enable_controlnet + self.controlnet_method = config.controlnet_method self.model = self.init_model( self.name, switch_mps_device(self.name, self.device), **self.kwargs ) - if not config.controlnet_enabled: + if not config.enable_controlnet: logger.info(f"Disable controlnet") else: logger.info(f"Enable controlnet: {config.controlnet_method}") diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index b253d7b..3478cb6 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -1,19 +1,8 @@ -from typing import Optional, List from enum import Enum +from typing import Optional from PIL.Image import Image -from pydantic import BaseModel, computed_field - -from lama_cleaner.const import ( - SDXL_CONTROLNET_CHOICES, - SD2_CONTROLNET_CHOICES, - SD_CONTROLNET_CHOICES, -) - -DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline" -DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline" -DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline" -DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline" +from pydantic import BaseModel class ModelType(str, Enum): @@ -25,103 +14,6 @@ class ModelType(str, Enum): DIFFUSERS_OTHER = "diffusers_other" -FREEU_DEFAULT_CONFIGS = { - ModelType.DIFFUSERS_SD: dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4), - ModelType.DIFFUSERS_SDXL: dict(s1=0.6, s2=0.4, b1=1.1, b2=1.2), -} - - -class ModelInfo(BaseModel): - name: str - path: str - model_type: ModelType - is_single_file_diffusers: bool = False - - @computed_field - @property - def need_prompt(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] or self.name in [ - "timbrooks/instruct-pix2pix", - "kandinsky-community/kandinsky-2-2-decoder-inpaint", - ] - - @computed_field - @property - def controlnets(self) -> List[str]: - if self.model_type in [ - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SDXL_INPAINT, - ]: - return SDXL_CONTROLNET_CHOICES - if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]: - if self.name in ["stabilityai/stable-diffusion-2-inpainting"]: - return SD2_CONTROLNET_CHOICES - else: - return SD_CONTROLNET_CHOICES - return [] - - @computed_field - @property - def support_strength(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] - - @computed_field - @property - def support_outpainting(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] or self.name in [ - "kandinsky-community/kandinsky-2-2-decoder-inpaint", - ] - - @computed_field - @property - def support_lcm_lora(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] - - @computed_field - @property - def support_controlnet(self) -> bool: - return self.model_type in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] - - @computed_field - @property - def support_freeu(self) -> bool: - return ( - self.model_type - in [ - ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SDXL_INPAINT, - ] - or "timbrooks/instruct-pix2pix" in self.name - ) - - class HDStrategy(str, Enum): # Use original image size ORIGINAL = "Original" @@ -157,6 +49,13 @@ class FREEUConfig(BaseModel): b2: float = 1.4 +class PowerPaintTask(str, Enum): + text_guided = "text-guided" + shape_guided = "shape-guided" + object_remove = "object-remove" + outpainting = "outpainting" + + class Config(BaseModel): class Config: arbitrary_types_allowed = True @@ -239,6 +138,11 @@ class Config(BaseModel): p2p_image_guidance_scale: float = 1.5 # ControlNet - controlnet_enabled: bool = False + enable_controlnet: bool = False controlnet_conditioning_scale: float = 0.4 - controlnet_method: str = "control_v11p_sd15_canny" + controlnet_method: str = "lllyasviel/control_v11p_sd15_canny" + + # PowerPaint + powerpaint_task: PowerPaintTask = PowerPaintTask.text_guided + # control the fitting degree of the generated objects to the mask shape. + fitting_degree: float = 1.0 diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index adc9f2c..711d56a 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -63,6 +63,7 @@ from lama_cleaner.helper import ( numpy_to_bytes, resize_max_size, pil_to_bytes, + is_mac, ) NUM_THREADS = str(multiprocessing.cpu_count()) @@ -285,9 +286,10 @@ def process(): cv2_radius=form["cv2Radius"], paint_by_example_example_image=paint_by_example_example_image, p2p_image_guidance_scale=form["p2pImageGuidanceScale"], - controlnet_enabled=form["controlnet_enabled"], + enable_controlnet=form["enable_controlnet"], controlnet_conditioning_scale=form["controlnet_conditioning_scale"], controlnet_method=form["controlnet_method"], + powerpaint_task=form["powerpaintTask"], ) if config.sd_seed == -1: @@ -305,6 +307,8 @@ def process(): if "CUDA out of memory. " in str(e): # NOTE: the string may change? return "CUDA out of memory", 500 + elif "Invalid buffer size" in str(e) and is_mac(): + return "Out of memory", 500 else: logger.exception(e) return f"{str(e)}", 500 @@ -423,8 +427,8 @@ def get_server_config(): "plugins": list(global_config.plugins.keys()), "enableFileManager": global_config.enable_file_manager, "enableAutoSaving": global_config.enable_auto_saving, - "enableControlnet": global_config.model_manager.sd_controlnet, - "controlnetMethod": global_config.model_manager.sd_controlnet_method, + "enableControlnet": global_config.model_manager.enable_controlnet, + "controlnetMethod": global_config.model_manager.controlnet_method, "disableModelSwitch": global_config.disable_model_switch, "isDesktop": global_config.is_desktop, }, 200 diff --git a/lama_cleaner/tests/utils.py b/lama_cleaner/tests/utils.py new file mode 100644 index 0000000..e69de29 diff --git a/lama_cleaner/web_config.py b/lama_cleaner/web_config.py index 2f52d62..5107d27 100644 --- a/lama_cleaner/web_config.py +++ b/lama_cleaner/web_config.py @@ -15,8 +15,8 @@ def save_config( port, model, sd_local_model_path, - sd_controlnet, - sd_controlnet_method, + enable_controlnet, + controlnet_method, device, gui, no_gui_auto_close, @@ -176,13 +176,13 @@ def main(config_file: str): sd_local_model_path = gr.Textbox( init_config.sd_local_model_path, label=f"{SD_LOCAL_MODEL_HELP}" ) - sd_controlnet = gr.Checkbox( - init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}" + enable_controlnet = gr.Checkbox( + init_config.enable_controlnet, label=f"{SD_CONTROLNET_HELP}" ) - sd_controlnet_method = gr.Radio( + controlnet_method = gr.Radio( SD_CONTROLNET_CHOICES, label="ControlNet method", - value=init_config.sd_controlnet_method, + value=init_config.controlnet_method, ) no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}") cpu_offload = gr.Checkbox( @@ -205,8 +205,8 @@ def main(config_file: str): port, model, sd_local_model_path, - sd_controlnet, - sd_controlnet_method, + enable_controlnet, + controlnet_method, device, gui, no_gui_auto_close, diff --git a/web_app/src/components/Extender.tsx b/web_app/src/components/Extender.tsx index 1c37090..dacb85e 100644 --- a/web_app/src/components/Extender.tsx +++ b/web_app/src/components/Extender.tsx @@ -1,5 +1,5 @@ -import { EXTENDER_ALL, EXTENDER_X, EXTENDER_Y } from "@/lib/const" import { useStore } from "@/lib/states" +import { ExtenderDirection } from "@/lib/types" import { cn } from "@/lib/utils" import React, { useEffect, useState } from "react" import { twMerge } from "tailwind-merge" @@ -107,7 +107,7 @@ const Extender = (props: Props) => { const newY = evData.initY + offsetY let clampedY = newY let clampedHeight = newHeight - if (extenderDirection === EXTENDER_ALL) { + if (extenderDirection === ExtenderDirection.xy) { if (clampedY > 0) { clampedY = 0 clampedHeight = evData.initHeight - Math.abs(evData.initY) @@ -124,7 +124,7 @@ const Extender = (props: Props) => { const moveBottom = () => { const newHeight = evData.initHeight + offsetY let [clampedY, clampedHeight] = clampTopBottom(evData.initY, newHeight) - if (extenderDirection === EXTENDER_ALL) { + if (extenderDirection === ExtenderDirection.xy) { if (clampedHeight < Math.abs(clampedY) + imageHeight) { clampedHeight = Math.abs(clampedY) + imageHeight } @@ -138,7 +138,7 @@ const Extender = (props: Props) => { const newX = evData.initX + offsetX let clampedX = newX let clampedWidth = newWidth - if (extenderDirection === EXTENDER_ALL) { + if (extenderDirection === ExtenderDirection.xy) { if (clampedX > 0) { clampedX = 0 clampedWidth = evData.initWidth - Math.abs(evData.initX) @@ -155,7 +155,7 @@ const Extender = (props: Props) => { const moveRight = () => { const newWidth = evData.initWidth + offsetX let [clampedX, clampedWidth] = clampLeftRight(evData.initX, newWidth) - if (extenderDirection === EXTENDER_ALL) { + if (extenderDirection === ExtenderDirection.xy) { if (clampedWidth < Math.abs(clampedX) + imageWdith) { clampedWidth = Math.abs(clampedX) + imageWdith } @@ -296,7 +296,9 @@ const Extender = (props: Props) => { onPointerDown={onCropPointerDown} className="absolute top-0 h-full w-full" > - {[EXTENDER_Y, EXTENDER_ALL].includes(extenderDirection) ? ( + {[ExtenderDirection.y, ExtenderDirection.xy].includes( + extenderDirection + ) ? ( <>
{ <> )} - {[EXTENDER_X, EXTENDER_ALL].includes(extenderDirection) ? ( + {[ExtenderDirection.x, ExtenderDirection.xy].includes( + extenderDirection + ) ? ( <>
{ <> )} - {extenderDirection === EXTENDER_ALL ? ( + {extenderDirection === ExtenderDirection.xy ? ( <> {createDragHandle("cursor-nw-resize", "top", "left")} {createDragHandle("cursor-ne-resize", "top", "right")} diff --git a/web_app/src/components/PromptInput.tsx b/web_app/src/components/PromptInput.tsx index 09e9571..5b65851 100644 --- a/web_app/src/components/PromptInput.tsx +++ b/web_app/src/components/PromptInput.tsx @@ -36,9 +36,9 @@ const PromptInput = () => { updateSettings({ prompt: target.value }) } - const handleRepaintClick = async () => { - if (prompt.length !== 0 && !isProcessing) { - await runInpainting() + const handleRepaintClick = () => { + if (!isProcessing) { + runInpainting() } } @@ -69,7 +69,7 @@ const PromptInput = () => {